<a href="https://colab.research.google.com/github/PedroLatasa/PracticaIA/blob/main/Untitled1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [70]:
import torch
import gzip,pickle
import numpy as np
from torchvision import transforms
import matplotlib.pyplot as plt
from io import BytesIO
import imageio.v2 as imageio
import os
from IPython.display import Video, display
from PIL import Image


In [71]:
dataset_dir = '/content/mnist.pkl.gz'

# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only load an archive obtained from the trusted URL below.
try:
    with gzip.open(dataset_dir, 'rb') as f:
        # encoding='latin1' is the documented way to unpickle NumPy arrays that
        # were pickled by Python 2 (presumably how this archive was produced).
        # The old "retry without encoding" fallback was removed: it re-read from
        # an already-advanced stream, and only mattered on Python 2.
        train_set, valid_set, test_set = pickle.load(f, encoding='latin1')
except FileNotFoundError as err:
    # Only a missing file means "download it"; any other error (corrupt
    # archive, unpickling failure) propagates unchanged so it can be diagnosed.
    raise ValueError(
        "Download data from http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz "
        f"and place it at {dataset_dir}"
    ) from err

## Read and concatenate all data into X (images) and T (labels)
X_tr, Y_tr = train_set
X_va, Y_va = valid_set
X_te, Y_te = test_set

# np.vstack stacks the image matrices row-wise (one flattened image per row);
# np.hstack joins the label vectors end to end, so row i of X pairs with T[i].
X = np.vstack((X_tr, X_va, X_te))
T = np.hstack((Y_tr, Y_va, Y_te))

In [78]:
class MNISTDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing sample ``X[idx]`` with label ``Y[idx]``.

    Args:
        X: Indexable collection of samples (e.g. the stacked NumPy array of
            flattened images).
        Y: Indexable labels; ``len(Y)`` defines the dataset length.
        transforms: Optional sequence of callables applied, in order, to each
            sample in ``__getitem__``. ``None`` (the default) means no transforms.
            (The name shadows the imported torchvision ``transforms`` module
            inside ``__init__`` only; kept for keyword-argument compatibility.)
    """

    def __init__(self, X, Y, transforms=None):
        ## call parent method
        super().__init__()
        self.X = X
        self.Y = Y
        self.transforms = transforms

    def __len__(self):
        """Return the number of labels."""
        return len(self.Y)

    def __getitem__(self, idx):
        """Return ``(transformed X[idx], Y[idx])``."""
        sample = self.X[idx]
        # Treat the default ``None`` as "no transforms" instead of failing with
        # "'NoneType' object is not iterable".
        pipeline = self.transforms if self.transforms is not None else ()
        for transform in pipeline:
            sample = transform(sample)

        return sample, self.Y[idx]

def reshape(X):
    """Reshape one flattened sample into a 28x28 image via ``np.reshape``."""
    image_shape = (28, 28)
    return np.reshape(X, image_shape)

mnist_dataset = MNISTDataset(X=X, Y=T, transforms=[reshape])  # samples only reshaped to 28x28


In [79]:
mnist_loader = torch.utils.data.DataLoader(
    dataset=mnist_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=1,
)

# Render the first 11 shuffled samples (counter 0..10) as a 1-fps MP4 and
# embed it in the notebook output.
n_frames = 11
video_filename = "/content/aux4.mp4"
fig, ax = plt.subplots(1, 1)

try:
    # Using the writer as a context manager guarantees it is closed (and the
    # video finalized) even if rendering a frame raises.
    with imageio.get_writer(video_filename, format="FFMPEG", mode="I", fps=1, codec="libx264") as writer:
        for counter, (x, t) in enumerate(mnist_loader):
            ax.clear()
            ax.imshow(x[0], cmap='gray')
            ax.set_title(f'Class {t.item()}')

            ## add frame for video creation: rasterize the figure to PNG in memory
            with BytesIO() as buf:
                fig.savefig(buf, format="png", dpi=100)
                buf.seek(0)
                frame = imageio.imread(buf)
            writer.append_data(frame)

            if counter == n_frames - 1:
                break
finally:
    # The figure is only an off-screen canvas for the frames.
    plt.close(fig)

# With embed=True the video bytes are embedded in the displayed output, so the
# temporary file can be removed once it has been displayed.
display(Video(data=video_filename, embed=True))
os.remove(video_filename)

In [83]:
class RandomBlankSquare:
    """Zero out a random axis-aligned rectangle of a 2-D image.

    The rectangle covers rows ``[top, bottom)`` and columns ``[left, right)``.
    The bounds are drawn as uniform integers from NumPy's global RNG:

    - ``left`` from ``[0, min_h)`` and ``right`` from ``[min_h, max_h)``;
    - ``top`` from ``[0, min_v)`` and ``bottom`` from ``[min_v, max_v)``.

    As a result, row ``min_v - 1`` and column ``min_h - 1`` are always blanked.

    Args:
        min_h: Exclusive upper bound for the left edge and inclusive lower bound
            for the (exclusive) right edge; must be >= 1.
        max_h: Exclusive upper bound for the right edge; must be > ``min_h``.
        min_v: Same as ``min_h``, for rows (axis 0).
        max_v: Same as ``max_h``, for rows (axis 0).
    """

    def __init__(self, min_h, max_h, min_v, max_v):
        self.min_h = min_h
        self.max_h = max_h
        self.min_v = min_v
        self.max_v = max_v

    def __call__(self, x):
        """Return a copy of ``x`` with a random rectangle set to 0.

        The input is copied first. Dataset samples are views into the shared
        data array (basic indexing plus ``np.reshape`` do not copy), so writing
        in place would permanently blank pixels of the dataset itself.
        """
        left = np.random.randint(0, self.min_h)
        right = np.random.randint(self.min_h, self.max_h)

        top = np.random.randint(0, self.min_v)
        bottom = np.random.randint(self.min_v, self.max_v)

        x = np.array(x, copy=True)
        x[top:bottom, left:right] = 0.0

        return x

class RandomRotation:
    """Rotate a 2-D image by a random angle in ``[-max_degree, max_degree)``.

    The angle comes from NumPy's global RNG. The rotation itself is done by
    PIL's ``Image.rotate`` with its default arguments, and the result is
    converted back to a NumPy array.
    """

    def __init__(self, max_degree):
        # Largest absolute rotation angle, in degrees.
        self.max_deg = max_degree

    def __call__(self, x):
        # Map a uniform draw u in [0, 1) onto the symmetric range [-max_deg, max_deg).
        u = np.random.random(size=(1,)).item()
        angle = u * 2 * self.max_deg - self.max_deg

        rotated = Image.fromarray(x).rotate(angle)
        return np.array(rotated)

# Augmented dataset: reshape to 28x28, blank a random rectangle, then rotate
# by a random angle of up to +/-70 degrees.
mnist_dataset = MNISTDataset(
    X,
    T,
    transforms=[
        reshape,
        RandomBlankSquare(min_h=10, max_h=20, min_v=10, max_v=20),
        RandomRotation(70),
    ],
)

mnist_loader = torch.utils.data.DataLoader(
    dataset=mnist_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=1,
)

# Render the first 11 shuffled, augmented samples (counter 0..10) as a 1-fps
# MP4 and embed it in the notebook output.
n_frames = 11
video_filename = "/content/aux4.mp4"
fig, ax = plt.subplots(1, 1)

try:
    # Using the writer as a context manager guarantees it is closed (and the
    # video finalized) even if rendering a frame raises.
    with imageio.get_writer(video_filename, format="FFMPEG", mode="I", fps=1, codec="libx264") as writer:
        for counter, (x, t) in enumerate(mnist_loader):
            # Clear the previous frame; otherwise every imshow adds another
            # image artist on top of the old ones.
            ax.clear()
            ax.imshow(x[0], cmap='gray')
            ax.set_title(f'Class {t.item()}')

            ## add frame for video creation: rasterize the figure to PNG in memory
            with BytesIO() as buf:
                fig.savefig(buf, format="png", dpi=100)
                buf.seek(0)
                frame = imageio.imread(buf)
            writer.append_data(frame)

            if counter == n_frames - 1:
                break
finally:
    # The figure is only an off-screen canvas for the frames.
    plt.close(fig)

# With embed=True the video bytes are embedded in the displayed output, so the
# temporary file can be removed once it has been displayed.
display(Video(data=video_filename, embed=True))
os.remove(video_filename)