In [1]:
from torchvision.datasets import MNIST, mnist, CIFAR10
from torchvision import transforms
import torch.nn.functional as F
import torch
from torch import nn
from torch.utils.data import DataLoader

In [2]:
from tqdm.notebook import tqdm_notebook
import matplotlib.pyplot as plt
import seaborn as sns

In [3]:
from torch.utils.tensorboard import SummaryWriter
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [4]:
class CustomTargetTransform:
    def __init__(self, num_classes=10):
        self.num_classes = num_classes

    def __call__(self, target):
        new_target = torch.zeros(self.num_classes, dtype=torch.float, device=device)
        new_target[target] = 1
        return new_target

transform = transforms.Compose([
    transforms.PILToTensor(),
    transforms.Lambda(lambda x: x.float().to(device))
])

In [5]:
dataset = mnist.FashionMNIST("data", download=True, train=True, transform=transform)
dataset_target = mnist.FashionMNIST("data", download=True, train=False, transform=transforms.PILToTensor())
target_data = torch.tensor(dataset_target.data).unsqueeze(1).float().to(device)
dataset.data.shape

  target_data = torch.tensor(dataset_target.data).unsqueeze(1).float().to(device)


torch.Size([60000, 28, 28])

In [6]:
class AutoEncoder(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(28 * 28, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 36),
            torch.nn.ReLU(),
            torch.nn.Linear(36, 18),
            torch.nn.ReLU(),
            torch.nn.Linear(18, 9)
        )
        
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(9, 18),
            torch.nn.ReLU(),
            torch.nn.Linear(18, 36),
            torch.nn.ReLU(),
            torch.nn.Linear(36, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 28 * 28),
            torch.nn.Sigmoid()
        )
    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

In [15]:
model = AutoEncoder().to(device)

er_f = torch.nn.MSELoss()
 
optim = torch.optim.Adam(model.parameters(),
                             lr = 1e-1,
                             weight_decay = 1e-8)

batch_count = 32
epoch_count = 10


In [8]:
writer = SummaryWriter(comment="autoencoder")

In [26]:
total_epochs = 0
torch.manual_seed(0)
data_loader = DataLoader(dataset, batch_size=batch_count, shuffle=True)
for epoch in tqdm_notebook(range(epoch_count)):
    for image, target in tqdm_notebook(data_loader, leave=False):
        print(image.swapaxes(2, 0).swapaxes(3, 1).shape)
        optim.zero_grad()
        outs = model(image)
        loss = er_f(outs, target)
        loss.backward()
        optim.step()
        writer.add_scalar("loss", loss, total_epochs)
        total_epochs += 1
    writer.add_images("source", image, total_epochs)
    writer.add_images("reconstructed", outs, total_epochs)

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1875 [00:00<?, ?it/s]

torch.Size([28, 28, 32, 1])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (896x28 and 784x128)