In [77]:
import torch
import torchvision
import torch.nn as nn
import torchvision
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

In [78]:
import random
from PIL import Image


class CustomDataset(Dataset):
    def __init__(self, train, to_tensor=True):
        self.mnist = torchvision.datasets.MNIST(
            "files/mnist", train=train, download=True
        )
        self.fashion_mnist = torchvision.datasets.FashionMNIST(
            "files/fashion_mnist", train=train, download=True
        )
        self.to_tensor_transform = transforms.Compose([transforms.ToTensor()])

        self.dataset = []
        self.mnist_size = len(self.mnist)

        for i in range(self.mnist_size):
            number_img, number_label = self.mnist[i]
            left_img, left_label = random.choice(self.fashion_mnist)

            while True:
                right_img, right_label = random.choice(self.fashion_mnist)
                if left_label != right_label:
                    break

            img = self.__concat_images(left_img, number_img, right_img)
            label = left_label if number_label % 2 == 0 else right_label

            if to_tensor:
                img = self.to_tensor_transform(img)

            self.dataset.append((img, label))

    def __len__(self):
        return self.mnist_size

    def __getitem__(self, idx):
        return self.dataset[idx]

    def __concat_images(self, left, center, right):
        IMG_SIZE = 28
        img = Image.new("L", (IMG_SIZE * 3, IMG_SIZE))
        img.paste(left, (0, 0))
        img.paste(center, (IMG_SIZE, 0))
        img.paste(right, (2 * IMG_SIZE, 0))
        return img

    @staticmethod
    def decode_label(label):
        mapping = {
            0: "T-shirt/Top",
            1: "Trouser",
            2: "Pullover",
            3: "Dress",
            4: "Coat",
            5: "Sandal",
            6: "Shirt",
            7: "Sneaker",
            8: "Bag",
            9: "Ankle Boot",
        }

        return mapping[label]

In [79]:
train = CustomDataset(train=True)
test = CustomDataset(train=False)

In [80]:
print(test[1][1])
test[1][0]

7


tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])

In [81]:
train_loader = DataLoader(train, batch_size=128)


In [124]:
class CNN(nn.Module):
    def __init__(self):
        super().__init__()

        self.layers = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(0, -1),
            nn.Dropout(0.5),
            nn.Linear(512, 32),
            nn.ReLU(),
            nn.Linear(32, 10),
            nn.Softmax()
        )
    
    def forward(self, x):
        return self.layers(x)

In [126]:
model = CNN()
error = nn.CrossEntropyLoss()
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
t = model.forward(train[0][0])
t

tensor([0.1009, 0.1053, 0.0866, 0.1020, 0.1009, 0.0966, 0.1125, 0.1104, 0.0843,
        0.1005], grad_fn=<SoftmaxBackward0>)