In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F 
from torchvision import datasets, transforms
from tqdm import tqdm
import matplotlib.pyplot as plt
import random
%matplotlib inline

In [25]:
# -----------------------------
# 1. Load MNIST dataset
# -----------------------------

DATA_DIR = "./data"

# ToTensor turns each PIL image into a float tensor scaled to [0, 1];
# Normalize then standardises it with 0.1307 / 0.3081, the commonly cited
# mean / std of the MNIST training pixels.
# TODO: recompute these from the training set if the preprocessing changes.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# download=True fetches the files into DATA_DIR on first run and reuses them after.
mnist_train = datasets.MNIST(DATA_DIR, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(DATA_DIR, train=False, download=True, transform=transform)

In [29]:
# Inspect one transformed sample: the image tensor's shape and the label's type.
img, label = mnist_train[0]
print(img.shape, type(label), sep="\n")

torch.Size([1, 28, 28])
<class 'int'>


In [16]:
def patch(img, patch_size=7):
    """Split one image into non-overlapping square patches.

    Args:
        img: 3-D tensor indexed (channels, height, width); the MNIST samples
            in this notebook are (1, 28, 28).
        patch_size: side length of each square patch.

    Returns:
        Tensor of shape (num_patches, patch_size, patch_size), patches in
        row-major order (left to right, then top to bottom). For a
        (1, 28, 28) image and the default patch_size this is (16, 7, 7).
        ``unfold`` drops trailing rows/columns that do not fill a whole patch.
    """
    row_strips = img.unfold(1, patch_size, patch_size)        # (C, H//p, W, p)
    patch_grid = row_strips.unfold(2, patch_size, patch_size)  # (C, H//p, W//p, p, p)
    flat_patches = patch_grid.contiguous().view(1, -1, patch_size, patch_size)
    return flat_patches.squeeze(0)

def batch_patch(batch_imgs, patch_size=7):
    """Split a batch of images into non-overlapping square patches.

    Args:
        batch_imgs: 4-D tensor indexed (batch, channels, height, width); the
            MNIST batches in this notebook are (B, 1, 28, 28).
        patch_size: side length of each square patch.

    Returns:
        Tensor of shape (B, C * (H // patch_size) * (W // patch_size),
        patch_size, patch_size), patches in row-major order. For
        (B, 1, 28, 28) input and the default patch_size this is
        (B, 16, 7, 7). ``unfold`` drops trailing rows/columns that do not
        fill a whole patch.
    """
    # (B, C, H//p, W//p, p, p)
    patch_grid = batch_imgs.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    # Merge (C, H//p, W//p) into a single patch axis. flatten() computes the
    # merged size explicitly, so an empty batch (B == 0) works; the previous
    # view(B, 1, -1, ...) raised because -1 is ambiguous for a 0-element tensor.
    return patch_grid.contiguous().flatten(1, 3)



In [17]:
img, label = mnist_train[0]  # img: (1, 28, 28)
patches = patch(img, patch_size=7)
# 28 / 7 = 4 patches per side -> 4 * 4 = 16 patches of 7x7.
assert patches.shape == (16, 7, 7), patches.shape
print(patches.shape)

torch.Size([16, 7, 7])


In [26]:
# Patch every training image and collect its label in ONE pass over the
# dataset. Each item access applies `transform` (ToTensor + Normalize), so the
# previous separate label comprehension re-processed all images just to read
# labels. Loop names are kept local to the comprehension so the module-level
# `img` / `label` from earlier cells are not overwritten.
patched_pairs = [
    (patch(image), target)
    for image, target in tqdm(mnist_train, desc="patching train set")
]
patch_list, label_list = zip(*patched_pairs)
all_patches = torch.stack(patch_list)   # (N, 16, 7, 7)
all_labels = torch.tensor(label_list)   # (N,), int64
del patched_pairs, patch_list, label_list  # the stacked tensors now hold the data


torch.Size([60000, 16, 7, 7])


In [27]:
print(all_patches.shape)  # expected: torch.Size([60000, 16, 7, 7])
print(all_labels.shape)   # expected: torch.Size([60000])
# Fail loudly instead of relying on eyeballing the printed shapes.
assert all_patches.shape[1:] == (16, 7, 7), all_patches.shape
assert len(all_patches) == len(all_labels) == len(mnist_train)


torch.Size([60000, 16, 7, 7])
torch.Size([60000])
