In [74]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, models, transforms
import numpy as np
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt
import os
import time
import copy
import idx2numpy

In [75]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [76]:
# Augmenting / normalising pipeline for 28x28 grayscale images.
# NOTE(review): `transform` is not passed to either FashionMNIST dataset in this notebook
# (both use a bare transforms.ToTensor()), so it currently has no effect — confirm intent.
_transform_steps = [
    transforms.Resize(size=28),               # leaves 28x28 images unchanged
    transforms.CenterCrop(size=28),           # leaves 28x28 images unchanged
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),                    # PIL image -> float tensor in [0, 1]
    transforms.Normalize(mean=0.5, std=0.5),  # (x - 0.5) / 0.5 -> [-1, 1]
]
transform = transforms.Compose(_transform_steps)

In [77]:
# FashionMNIST training split (downloaded into ./FashionMNIST on first use);
# ToTensor() turns each image into a float tensor in [0, 1].
image_set = torchvision.datasets.FashionMNIST(
    root='./FashionMNIST',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

In [78]:
# FashionMNIST held-out split, same preprocessing as the training split.
test_set = torchvision.datasets.FashionMNIST(
    root='./FashionMNIST',
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)

In [79]:
# Mini-batch iterators: shuffle the training data each epoch, keep test order fixed.
# NOTE(review): batch size is hard-coded here rather than read from `batch_size`.
train_loader = torch.utils.data.DataLoader(image_set, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=32, shuffle=False)

In [105]:
# Sanity-check the tensor shape of one training image (channels, height, width).
first_image = image_set[0][0]
print(first_image.size())

torch.Size([1, 28, 28])


In [100]:
# Training hyper-parameters.
# NOTE(review): `batch_size` is not read by the DataLoader cell, which hard-codes 32.
num_epochs, batch_size, learning_rate = 100, 32, 0.1

In [246]:
class DownSamp(nn.Module):
    """Two-convolution block that halves (roughly) the spatial size.

    Args:
        input_channels: channels of the incoming feature map.
        output_channels: channels produced; the intermediate width is
            ``output_channels // 2``.

    conv1 (kernel 3, stride 2, padding 2) maps an H x W input to
    ((H + 1) // 2 + 1) x ((W + 1) // 2 + 1), e.g. 28 -> 15; conv2 keeps that size.
    """

    def __init__(self, input_channels, output_channels):
        super(DownSamp, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, output_channels//2, 3, 2, 2)
        self.bn1 = nn.BatchNorm2d(output_channels//2, eps=1e-4)
        self.conv2 = nn.Conv2d(output_channels//2, output_channels, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(output_channels, eps=1e-4)
        self.elu = nn.ELU()

    def forward(self, x):
        # Activate after the first conv too: without it conv1 -> bn1 -> conv2 is a
        # chain of affine maps with no non-linearity between the two convolutions.
        x = self.elu(self.bn1(self.conv1(x)))
        x = self.elu(self.bn2(self.conv2(x)))
        return x

class UpSamp(nn.Module):
    """Two 3x3 'same' convolutions followed by nearest-neighbour upsampling.

    Args:
        input_channels: channels of the incoming feature map.
        output_channels: channels produced; the intermediate width is
            ``output_channels // 2``.
        scale: spatial upsampling factor (default 2, e.g. 3x3 -> 6x6).
    """

    def __init__(self, input_channels, output_channels, scale=2):
        super(UpSamp, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, output_channels//2, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(output_channels//2, eps=1e-4)
        self.conv2 = nn.Conv2d(output_channels//2, output_channels, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(output_channels, eps=1e-4)
        self.up_sample = nn.Upsample(scale_factor=scale, mode = "nearest")
        self.elu = nn.ELU()

    def forward(self, x):
        # Non-linearity between the two convs; previously conv1 -> bn1 -> conv2
        # ran back-to-back as a purely affine chain.
        x = self.elu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        # Keep the original order for the last stage: upsample, then activate.
        x = self.elu(self.up_sample(x))
        return x

In [247]:
class Encoder(nn.Module):
    """Convolutional VAE encoder that outputs a spatial latent map.

    Args:
        input_channels: channels of the input image.
        ch: base width; the three DownSamp stages widen it to ch*2, ch*4, ch*8.
        latent_channels: channels of the ``mu`` / ``log_var`` maps.
    """

    def __init__(self, input_channels, ch = 64, latent_channels = 256):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, ch, kernel_size=5, stride=1, padding=2)
        self.res_down1 = DownSamp(ch, ch*2)
        self.res_down2 = DownSamp(ch*2, ch*4)
        self.res_down3 = DownSamp(ch*4, ch*8)
        # Unpadded 5x5 heads: each shrinks the spatial size by 4.
        self.mu = nn.Conv2d(ch*8, latent_channels, 5, 1)
        self.log_var = nn.Conv2d(ch*8, latent_channels, 5, 1)
        self.act = nn.ELU()

    def sample(self, mu, log_var):
        """Reparameterisation trick: return ``mu + sigma * eps`` with ``eps ~ N(0, I)``.

        ``log_var`` is log(sigma^2) (the KL term uses ``exp(log_var)`` as the
        variance), so the standard deviation is ``exp(0.5 * log_var)``.
        """
        # Fix: previously exp(log_var) — the variance — was used as the std dev.
        std = torch.exp(0.5 * log_var)
        epsilon = torch.randn_like(std)
        return mu + std * epsilon

    def forward(self, x):
        """Encode ``x``.

        Returns:
            (z, mu, log_var): ``z`` is a reparameterised sample in training mode
            and equals ``mu`` in eval mode.
        """
        x = self.act(self.conv1(x))
        x = self.res_down1(x)
        x = self.res_down2(x)
        x = self.res_down3(x)
        mu = self.mu(x)
        log_var = self.log_var(x)

        if self.training:
            x = self.sample(mu, log_var)
        else:
            x = mu
        return x, mu, log_var

In [263]:
class Decoder(nn.Module):
    """Maps a latent map back to image space (mirror of the encoder).

    Args:
        channels: number of output image channels.
        ch: base width; stages go ch*8 -> ch*4 -> ch*2 -> ch.
        latent_channels: channels of the incoming latent map.
    """

    def __init__(self, channels, ch = 64, latent_channels = 256):
        super().__init__()
        # Attribute names (and construction order) are kept: they are the
        # state_dict keys and determine parameter initialisation order.
        self.conv_up = nn.ConvTranspose2d(latent_channels, ch*8, kernel_size=4, stride=1, padding=1)
        self.res_up1 = UpSamp(ch*8, ch*4)
        self.res_up2 = UpSamp(ch*4, ch*2)
        self.res_up3 = UpSamp(ch*2, ch)
        # kernel 3 with padding 3 grows each spatial side by 4 (e.g. 24 -> 28).
        self.conv2 = nn.Conv2d(ch, channels, kernel_size=3, stride=1, padding=3)
        self.act = nn.ELU()

    def forward(self, x):
        """Decode latent ``x`` into an image-shaped tensor with values in [-1, 1]."""
        # ConvTranspose2d(k=4, s=1, p=1) grows each spatial side by 1 (e.g. 2 -> 3).
        h = self.act(self.conv_up(x))
        for stage in (self.res_up1, self.res_up2, self.res_up3):
            h = stage(h)
        # NOTE(review): tanh yields [-1, 1] while the datasets in this notebook use a
        # bare ToTensor() (targets in [0, 1]) — confirm the intended output range.
        return torch.tanh(self.conv2(h))

In [264]:
class VarAutoEnc(nn.Module):
    """Variational autoencoder: an Encoder feeding a Decoder.

    ``forward(x)`` returns ``(reconstruction, mu, log_var)``; ``mu`` and
    ``log_var`` are exposed so the caller can compute the KL term.
    """

    def __init__(self, input_channels = 1, ch = 64, latent_channels = 256):
        super().__init__()
        self.encoder = Encoder(input_channels, ch, latent_channels)
        # Decode back to the same number of channels as the input image.
        self.decoder = Decoder(input_channels, ch, latent_channels)

    def forward(self, x):
        z, mu, log_var = self.encoder(x)
        return self.decoder(z), mu, log_var

In [265]:
def KL_Divergence(mu, log_var):
    """KL(N(mu, exp(log_var)) || N(0, I)) regulariser for the VAE loss.

    Args:
        mu: posterior means; the first dimension is treated as the batch
            (``len(mu)``).
        log_var: posterior log-variances, same shape as ``mu``.

    Returns:
        Non-negative scalar tensor. The original scaling is preserved:
        ``len(mu) * mean(elementwise KL)``, i.e. the sum of the element-wise KL
        divided by the number of elements per sample.
    """
    # Closed-form element-wise Gaussian KL: -0.5 * (1 + log s^2 - mu^2 - s^2) >= 0.
    # The leading minus was missing before, so minimising `kl + mse` maximised KL.
    kl_per_element = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp())
    return len(mu) * kl_per_element.mean()

In [266]:
# Build the VAE on the selected device and a plain SGD optimiser over all parameters.
model = VarAutoEnc(input_channels = 1, ch = 64, latent_channels = 256).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)
# Mixed-precision gradient scaler; only meaningful on CUDA, so disable it explicitly
# elsewhere (an enabled scaler without CUDA warns and disables itself anyway).
# NOTE(review): the training loop in this notebook never calls it — remove if AMP is not planned.
scalar = torch.cuda.amp.GradScaler(enabled=(device.type == 'cuda'))

In [267]:
# Snapshot of the weights *before* training starts.
# NOTE(review): despite the name, "vae1.pt" holds untrained weights — confirm intent.
torch.save(model.state_dict(), "vae1.pt")
since = time.time()
for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')
    print('-'*10)
    model.train()
    running_loss = 0.0
    samples_seen = 0
    # Labels are not needed: the VAE is trained purely on reconstruction + KL.
    for inputs, _ in train_loader:
        inputs = inputs.to(device)
        recon_imgs, mu, log_var = model(inputs)
        kl_loss = KL_Divergence(mu, log_var)
        mse_loss = F.mse_loss(recon_imgs, inputs)
        loss = kl_loss + mse_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate a detached Python float (weighted by batch size) instead of
        # keeping the last loss tensor, which also held on to its autograd graph.
        batch_n = inputs.size(0)
        running_loss += loss.item() * batch_n
        samples_seen += batch_n
    # Mean per-sample loss over the epoch; guards against an empty loader.
    epoch_loss = running_loss / max(samples_seen, 1)
    print(f"Loss = {epoch_loss}")
torch.save(model.state_dict(), "vae100.pt")


time_elapsed = time.time() - since
print(f'Training complete in {time_elapsed//60:.0f}m {time_elapsed%60:.0f}s')

Epoch 1/100
----------
torch.Size([32, 1, 28, 28])
torch.Size([32, 256, 2, 2])
torch.Size([32, 512, 3, 3])
torch.Size([32, 256, 6, 6])
torch.Size([32, 128, 12, 12])
torch.Size([32, 64, 24, 24])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 256, 2, 2])
torch.Size([32, 512, 3, 3])
torch.Size([32, 256, 6, 6])
torch.Size([32, 128, 12, 12])
torch.Size([32, 64, 24, 24])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 256, 2, 2])
torch.Size([32, 512, 3, 3])
torch.Size([32, 256, 6, 6])
torch.Size([32, 128, 12, 12])
torch.Size([32, 64, 24, 24])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 256, 2, 2])
torch.Size([32, 512, 3, 3])
torch.Size([32, 256, 6, 6])
torch.Size([32, 128, 12, 12])
torch.Size([32, 64, 24, 24])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 256, 2, 2])
torch.Size([32, 512, 3, 3])
torch.Size([32, 256, 6, 6])
torch.Size([32, 128, 12, 12])
torch.Size([32, 64, 24, 24])
torch.Size

KeyboardInterrupt: 