In [1]:
%matplotlib inline
import torch
from torch.distributions import Normal
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision.utils import make_grid
import torchvision.transforms as transforms
from tqdm import tqdm
from torchvision.utils import save_image

In [2]:
# Root folder for the MNIST files: the download cell writes under it and the
# later loading cells build their paths from it.
directory = "./"

In [3]:
# Download the MNIST training split into `directory` (download=True); the
# returned dataset object is the cell's last expression, so its repr is shown.
torchvision.datasets.MNIST(root=directory, train=True, download=True)

Dataset MNIST
    Number of datapoints: 60000
    Root location: E:/datasets/
    Split: Train

In [4]:
# Read the processed training file, which unpacks into an (images, labels) pair.
# NOTE(review): this path assumes torchvision wrote MNIST/processed/training.pt
# under `directory` — confirm for the installed torchvision version.
training_file = directory + "MNIST/processed/training.pt"
images, ground_truth = torch.load(training_file)

In [5]:
# Show the shape of the image tensor, then of the label tensor.
for tensor in (images, ground_truth):
    print(tensor.shape)

torch.Size([60000, 28, 28])
torch.Size([60000])


In [6]:
# Inspect the raw value range of the first image (smallest, then largest pixel).
first_image = images[0]
print(first_image.min())
print(first_image.max())

tensor(0, dtype=torch.uint8)
tensor(255, dtype=torch.uint8)


In [7]:
# Spot-check one training example by printing its label.
image_index = 15

sample_label = ground_truth[image_index]
print(sample_label)
# plt.imshow(images[image_index], cmap='gray')

tensor(7)


In [8]:
from torch.utils.data import Dataset


class MNISTDataset(Dataset):
    """Image-only dataset backed by a processed MNIST ``.pt`` file.

    The file at ``path`` must unpack into an ``(images, labels)`` pair. Items
    are returned as flattened float32 vectors; the labels are kept only to
    define the dataset length. Pixel values are cast, not rescaled, so they
    keep whatever range the stored images have.
    """

    def __init__(self, path):
        """Load the ``(images, labels)`` pair stored at ``path``."""
        # NOTE(review): torch.load unpickles the file — only point this at
        # trusted, locally produced files.
        self.images, self.ground_truth = torch.load(path)

    def __getitem__(self, idx):
        """Return image ``idx`` as a flattened float32 tensor (no label)."""
        # .float() already yields a float32 tensor, so wrapping it again in the
        # legacy torch.Tensor(...) constructor was redundant and is dropped.
        return torch.flatten(self.images[idx].float())

    def __len__(self):
        """Number of samples, taken from the label tensor."""
        return len(self.ground_truth)

In [9]:
# Build the flattened-image datasets for the training and test splits.
train_path = directory + "MNIST/processed/training.pt"
test_path = directory + "MNIST/processed/test.pt"
train_dataset = MNISTDataset(train_path)
test_dataset = MNISTDataset(test_path)

# Report how many samples each split holds.
print("Train dataset length:", len(train_dataset))
print("Test dataset length:", len(test_dataset))

Train dataset length: 60000
Test dataset length: 10000


In [10]:
# Shape of a single dataset item (a flattened image).
train_dataset[15].size()

torch.Size([784])

In [14]:
import torch.nn as nn

class AutoEncoder(nn.Module):
    """Fully connected autoencoder with two encoder and two decoder layers.

    Every linear layer is followed by a ReLU, including the encoder's last
    layer (so the latent code is non-negative) and the decoder's last layer
    (so reconstructions are non-negative).

    Args:
        input_shape: number of features in a flattened input vector.
        latent_dim: width of the hidden layers and of the latent code.
    """

    def __init__(self, input_shape, latent_dim=128):
        super().__init__()
        # Encoder: input_shape -> latent_dim -> latent_dim
        self.encoder_l1 = nn.Linear(input_shape, latent_dim)
        self.encoder_l2 = nn.Linear(latent_dim, latent_dim)

        # Decoder: latent_dim -> latent_dim -> input_shape
        self.decoder_l1 = nn.Linear(latent_dim, latent_dim)
        self.decoder_l2 = nn.Linear(latent_dim, input_shape)

    def forward(self, x):
        """Encode ``x`` and decode it again, returning the reconstruction."""
        return self.run_decoder(self.run_encoder(x))

    def run_encoder(self, x):
        """Map inputs to their latent code (ReLU after each encoder layer)."""
        for layer in (self.encoder_l1, self.encoder_l2):
            x = F.relu(layer(x))
        return x

    def run_decoder(self, latent):
        """Map a latent code back to input space (ReLU after each layer)."""
        for layer in (self.decoder_l1, self.decoder_l2):
            latent = F.relu(layer(latent))
        return latent

In [15]:
# Size the autoencoder from the length of one flattened dataset item.
model = AutoEncoder(input_shape=train_dataset[0].shape[0])
# Use the GPU when one is available and fall back to CPU otherwise, instead of
# failing on machines without CUDA. `.to` returns the module, so its repr is
# still the cell output.
model.to("cuda" if torch.cuda.is_available() else "cpu")

AutoEncoder(
  (encoder_l1): Linear(in_features=784, out_features=128, bias=True)
  (encoder_l2): Linear(in_features=128, out_features=128, bias=True)
  (decoder_l1): Linear(in_features=128, out_features=128, bias=True)
  (decoder_l2): Linear(in_features=128, out_features=784, bias=True)
)

In [16]:
# Put the model in training mode; the call returns the module, whose repr is shown.
model.train(mode=True)

AutoEncoder(
  (encoder_l1): Linear(in_features=784, out_features=128, bias=True)
  (encoder_l2): Linear(in_features=128, out_features=128, bias=True)
  (decoder_l1): Linear(in_features=128, out_features=128, bias=True)
  (decoder_l2): Linear(in_features=128, out_features=784, bias=True)
)

In [17]:
# Train the dense autoencoder to reconstruct its own input with an MSE loss.
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=128,
                                           shuffle=True)
mse = nn.MSELoss()

optimizer = optim.Adam(model.parameters(), lr=1e-3)

n_epochs = 100

# Send batches to whatever device the model lives on rather than hard-coding
# CUDA, so the loop also runs when the model is on the CPU.
ae_device = next(model.parameters()).device

for epoch in range(n_epochs):
    # Accumulates the per-batch losses of this epoch.
    loss = 0

    for batch_features in train_loader:
        batch_features = batch_features.to(ae_device)

        optimizer.zero_grad()

        # The input is its own target.
        outputs = model(batch_features)
        train_loss = mse(outputs, batch_features)

        train_loss.backward()
        optimizer.step()

        loss += train_loss.item()

    # Mean of the per-batch losses for this epoch.
    loss = loss / len(train_loader)

    # Log every fifth epoch, starting with the first.
    if epoch % 5 == 0:
        print("epoch : {}/{}, loss = {:.6f}".format(epoch + 1, n_epochs, loss))

# Always report the final epoch.
print("epoch: {}/{}, loss = {:.6f}".format(epoch + 1, n_epochs, loss))

epoch : 1/100, loss = 2066.458537
epoch : 6/100, loss = 903.045576
epoch : 11/100, loss = 823.641144
epoch : 16/100, loss = 788.662906
epoch : 21/100, loss = 766.621867
epoch : 26/100, loss = 752.876644
epoch : 31/100, loss = 742.759293
epoch : 36/100, loss = 733.817733
epoch : 41/100, loss = 729.049442
epoch : 46/100, loss = 725.443281
epoch : 51/100, loss = 722.037865
epoch : 56/100, loss = 717.985390
epoch : 61/100, loss = 715.416301
epoch : 66/100, loss = 713.759525
epoch : 71/100, loss = 711.816621
epoch : 76/100, loss = 710.086370
epoch : 81/100, loss = 708.871276
epoch : 86/100, loss = 706.901524
epoch : 91/100, loss = 706.631080
epoch : 96/100, loss = 706.380368
epoch: 100/100, loss = 705.344930


In [18]:
# Reconstruct one training image with the trained autoencoder.
model.eval()

image_idx = 100

# Place the sample on the model's device instead of assuming CUDA.
image = train_dataset[image_idx].to(next(model.parameters()).device)
# Inference only: skip building the autograd graph.
with torch.no_grad():
    x_hat = model(image)
# Back to a 28x28 NumPy image, rounded to whole pixel intensities.
x_hat = x_hat.cpu().numpy().reshape((28, 28))
x_hat = np.around(x_hat)

In [19]:
# Get Latent Representation

@torch.no_grad()
def get_latent_representation(model, dataloader):
    """Encode every batch of ``dataloader`` and stack the latent codes.

    Args:
        model: module exposing ``run_encoder(batch)``.
        dataloader: iterable yielding input batches as tensors.

    Returns:
        A NumPy array with the encoder outputs of all batches concatenated
        along the first axis, in the order the loader yielded them.

    Raises:
        ValueError: from ``np.concatenate`` if the loader yields no batches.
    """
    # Follow the model's own device rather than assuming CUDA; a model without
    # parameters leaves the batches where they are.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    latent_batches = []
    for batch in dataloader:
        if device is not None:
            batch = batch.to(device)
        latent_batches.append(model.run_encoder(batch).cpu().numpy())
    return np.concatenate(latent_batches)

# Encode the training set through a non-shuffled loader so that row i of
# `latent` corresponds to train_dataset[i]; the shuffled training loader would
# return the rows in a random order.
latent_loader = torch.utils.data.DataLoader(train_dataset,
                                            batch_size=128,
                                            shuffle=False)
latent = get_latent_representation(model, latent_loader)
latent.shape

(60000, 128)

In [20]:
# Variational autoencoder settings. These values match the default arguments
# of ConvVAE.__init__.
kernel_size, init_channels = 4, 8
image_channels = 1
latent_dim = 16

In [21]:
class ConvVAE(nn.Module):
    """Convolutional variational autoencoder.

    Encoder: four stride-2 convolutions, global average pooling to a
    64-vector, then linear heads producing the mean and log-variance of the
    latent Gaussian. Decoder: a linear layer back to 64 features, reshaped to
    a 64x1x1 map, followed by four transposed convolutions and a sigmoid.

    Args:
        image_channels: channels of the input and of the reconstruction.
        kernel_size: kernel size shared by all (transposed) convolutions.
        latent_dim: dimensionality of the latent Gaussian.
        init_channels: channel count after the first convolution; deeper
            layers use multiples of it.
    """

    def __init__(self,
                 image_channels=1,
                 kernel_size=4,
                 latent_dim=16,
                 init_channels=8):
        super().__init__()
        width1 = init_channels
        width2 = init_channels * 2
        width4 = init_channels * 4
        width8 = init_channels * 8
        # Settings shared by every stride-2 layer that uses padding=1.
        strided = dict(kernel_size=kernel_size, stride=2, padding=1)

        # Encoder convolutions.
        self.enc1 = nn.Conv2d(image_channels, width1, **strided)
        self.enc2 = nn.Conv2d(width1, width2, **strided)
        self.enc3 = nn.Conv2d(width2, width4, **strided)
        self.enc4 = nn.Conv2d(width4, 64, kernel_size=kernel_size,
                              stride=2, padding=0)

        # Dense bottleneck: pooled features -> (mu, log_var) -> decoder seed.
        self.fc1 = nn.Linear(64, 128)
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_log_var = nn.Linear(128, latent_dim)
        self.fc2 = nn.Linear(latent_dim, 64)

        # Decoder transposed convolutions.
        self.dec1 = nn.ConvTranspose2d(64, width8, kernel_size=kernel_size,
                                       stride=1, padding=0)
        self.dec2 = nn.ConvTranspose2d(width8, width4, **strided)
        self.dec3 = nn.ConvTranspose2d(width4, width2, **strided)
        self.dec4 = nn.ConvTranspose2d(width2, image_channels, **strided)

    def reparameterize(self, mu, log_var):
        """Draw ``mu + eps * std`` with ``eps ~ N(0, I)``, ``std = exp(0.5 * log_var)``."""
        std = (0.5 * log_var).exp()
        noise = torch.randn_like(std)
        return mu + (noise * std)

    def forward(self, x):
        """Return ``(reconstruction, mu, log_var)`` for an image batch ``x``."""
        # Encoder: ReLU after each convolution.
        features = x
        for conv in (self.enc1, self.enc2, self.enc3, self.enc4):
            features = F.relu(conv(features))

        # Collapse the spatial dimensions and flatten to (batch, channels).
        n_samples = features.shape[0]
        pooled = F.adaptive_avg_pool2d(features, 1).reshape(n_samples, -1)
        hidden = self.fc1(pooled)

        mu = self.fc_mu(hidden)
        log_var = self.fc_log_var(hidden)

        # Sample a latent code and lift it to a 64-channel 1x1 feature map.
        z = self.fc2(self.reparameterize(mu, log_var))
        decoded = z.view(-1, 64, 1, 1)

        # Decoder: ReLU after all but the last layer, sigmoid on the output.
        for deconv in (self.dec1, self.dec2, self.dec3):
            decoded = F.relu(deconv(decoded))
        reconstruction = torch.sigmoid(self.dec4(decoded))
        return reconstruction, mu, log_var

In [25]:
def final_loss(bce_loss, mu, logvar):
    """Add the KL-divergence term to a reconstruction loss.

    The KL term is ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))``, summed
    over every element of ``mu`` / ``logvar``.

    Args:
        bce_loss: reconstruction loss computed by the caller.
        mu: latent means.
        logvar: latent log-variances.

    Returns:
        ``bce_loss`` plus the KL term.
    """
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * kl_terms.sum()
    return bce_loss + kl_divergence

def train(model, dataloader, dataset_size, device, optimizer, criterion):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: module whose forward pass returns ``(reconstruction, mu, logvar)``.
        dataloader: iterable of batches; element 0 of each batch is the input.
        dataset_size: kept for backward compatibility; the progress-bar total
            now comes from the loader itself.
        device: device the inputs are moved to.
        optimizer: optimizer updating ``model``'s parameters.
        criterion: reconstruction loss called as ``criterion(reconstruction, data)``.

    Returns:
        Sum of the per-batch losses divided by the number of batches.

    Raises:
        ZeroDivisionError: if the loader yields no batches.
    """
    model.train()
    running_loss = 0.0
    counter = 0

    # len(dataloader) counts the final partial batch, which the previous
    # int(dataset_size / batch_size) estimate left out (the bar overran its
    # total). Loaders without a length get an open-ended bar.
    try:
        total_batches = len(dataloader)
    except TypeError:
        total_batches = None

    for data in tqdm(dataloader, total=total_batches):
        counter += 1
        # Keep only the inputs (element 0 of the batch) and move them over.
        data = data[0].to(device)
        optimizer.zero_grad()
        reconstruction, mu, logvar = model(data)
        bce_loss = criterion(reconstruction, data)
        loss = final_loss(bce_loss, mu, logvar)
        loss.backward()
        running_loss += loss.item()
        optimizer.step()

    train_loss = running_loss / counter
    return train_loss

def validate(model, dataloader, dataset, device, criterion):
    """Evaluate the model over ``dataloader`` and return the mean per-batch loss.

    Runs in eval mode with gradients disabled.

    Args:
        model: module whose forward pass returns ``(reconstruction, mu, logvar)``.
        dataloader: iterable of batches; element 0 of each batch is the input.
        dataset: kept for backward compatibility; the progress-bar total now
            comes from the loader itself.
        device: device the inputs are moved to.
        criterion: reconstruction loss called as ``criterion(reconstruction, data)``.

    Returns:
        Sum of the per-batch losses divided by the number of batches.

    Raises:
        ZeroDivisionError: if the loader yields no batches.
    """
    model.eval()
    running_loss = 0.0
    counter = 0

    # len(dataloader) counts the final partial batch, which the previous
    # int(len(dataset) / batch_size) estimate left out. Loaders without a
    # length get an open-ended bar.
    try:
        total_batches = len(dataloader)
    except TypeError:
        total_batches = None

    with torch.no_grad():
        for data in tqdm(dataloader, total=total_batches):
            counter += 1
            # Keep only the inputs (element 0 of the batch) and move them over.
            data = data[0].to(device)
            reconstruction, mu, logvar = model(data)
            bce_loss = criterion(reconstruction, data)
            loss = final_loss(bce_loss, mu, logvar)
            running_loss += loss.item()

    val_loss = running_loss / counter
    return val_loss

In [26]:
# Pick the GPU when present, otherwise the CPU. is_available must be *called*:
# the bare function object is always truthy, which selected 'cuda' even on
# machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# From here on `model` refers to the convolutional VAE.
model = ConvVAE().to(device)

# Training hyperparameters.
lr = 0.001
epochs = 50
batch_size = 64

# Resize the digits to 32x32 and convert them to tensors.
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor()
])

# download=False: the MNIST files must already exist under `directory`.
trainset = torchvision.datasets.MNIST(
    root=directory, train=True, download=False, transform=transform
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True
)

testset = torchvision.datasets.MNIST(
    root=directory, train=False, download=False, transform=transform
)
# Evaluation does not need shuffling; a fixed order keeps batches repeatable.
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False
)

In [27]:
# Optimise the VAE with Adam; the reconstruction term is a summed binary
# cross-entropy.
optimizer = optim.Adam(model.parameters(), lr=lr)

criterion = nn.BCELoss(reduction='sum')

# Each epoch: one pass over the training loader, then one over the test
# loader, printing both average losses.
for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")

    train_epoch_loss = train(
        model=model,
        dataloader=trainloader,
        dataset_size=len(trainset),
        device=device,
        optimizer=optimizer,
        criterion=criterion,
    )
    valid_epoch_loss = validate(
        model=model,
        dataloader=testloader,
        dataset=testset,
        device=device,
        criterion=criterion,
    )

    print(f"Train Loss: {train_epoch_loss:.4f}")
    print(f"Val Loss: {valid_epoch_loss:.4f}")

Epoch 1 of 50


938it [00:16, 57.10it/s]                                                                               
157it [00:01, 80.88it/s]                                                                               


Train Loss: 15595.7613
Val Loss: 11625.4322
Epoch 2 of 50


938it [00:16, 56.40it/s]                                                                               
157it [00:01, 85.48it/s]                                                                               


Train Loss: 11197.3220
Val Loss: 10927.7700
Epoch 3 of 50


938it [00:16, 56.75it/s]                                                                               
157it [00:01, 82.57it/s]                                                                               


Train Loss: 10771.3164
Val Loss: 10442.1131
Epoch 4 of 50


938it [00:16, 56.29it/s]                                                                               
157it [00:01, 90.41it/s]                                                                               


Train Loss: 10431.5850
Val Loss: 10253.3759
Epoch 5 of 50


938it [00:16, 56.24it/s]                                                                               
157it [00:01, 81.88it/s]                                                                               


Train Loss: 10293.7804
Val Loss: 10273.5549
Epoch 6 of 50


938it [00:16, 56.05it/s]                                                                               
157it [00:01, 95.39it/s]                                                                               


Train Loss: 10218.7800
Val Loss: 10177.7077
Epoch 7 of 50


938it [00:16, 57.03it/s]                                                                               
157it [00:01, 79.53it/s]                                                                               


Train Loss: 10163.8081
Val Loss: 10072.6485
Epoch 8 of 50


938it [00:16, 56.62it/s]                                                                               
157it [00:01, 79.64it/s]                                                                               


Train Loss: 10118.0810
Val Loss: 10028.5627
Epoch 9 of 50


938it [00:16, 57.37it/s]                                                                               
157it [00:01, 81.60it/s]                                                                               


Train Loss: 10083.0220
Val Loss: 10019.4551
Epoch 10 of 50


938it [00:16, 56.54it/s]                                                                               
157it [00:01, 80.36it/s]                                                                               


Train Loss: 10051.8128
Val Loss: 10026.9529
Epoch 11 of 50


938it [00:16, 57.88it/s]                                                                               
157it [00:01, 82.07it/s]                                                                               


Train Loss: 10027.5616
Val Loss: 9944.6373
Epoch 12 of 50


938it [00:16, 56.84it/s]                                                                               
157it [00:01, 88.20it/s]                                                                               


Train Loss: 10002.6872
Val Loss: 9973.0585
Epoch 13 of 50


938it [00:16, 57.77it/s]                                                                               
157it [00:01, 83.98it/s]                                                                               


Train Loss: 9985.1598
Val Loss: 9936.2282
Epoch 14 of 50


938it [00:16, 56.84it/s]                                                                               
157it [00:01, 83.14it/s]                                                                               


Train Loss: 9964.2664
Val Loss: 9951.2518
Epoch 15 of 50


938it [00:16, 57.01it/s]                                                                               
157it [00:01, 88.43it/s]                                                                               


Train Loss: 9956.1454
Val Loss: 9922.9013
Epoch 16 of 50


938it [00:16, 57.71it/s]                                                                               
157it [00:01, 88.88it/s]                                                                               


Train Loss: 9935.0501
Val Loss: 9881.2599
Epoch 17 of 50


938it [00:16, 56.32it/s]                                                                               
157it [00:01, 82.08it/s]                                                                               


Train Loss: 9865.2870
Val Loss: 9808.8569
Epoch 18 of 50


938it [00:16, 56.35it/s]                                                                               
157it [00:01, 80.26it/s]                                                                               


Train Loss: 9803.7313
Val Loss: 9756.2090
Epoch 19 of 50


938it [00:16, 56.57it/s]                                                                               
157it [00:02, 77.87it/s]                                                                               


Train Loss: 9775.9512
Val Loss: 9747.2231
Epoch 20 of 50


938it [00:16, 57.16it/s]                                                                               
157it [00:01, 81.03it/s]                                                                               


Train Loss: 9757.0365
Val Loss: 9713.4075
Epoch 21 of 50


938it [00:16, 56.35it/s]                                                                               
157it [00:01, 81.12it/s]                                                                               


Train Loss: 9735.8339
Val Loss: 9696.2199
Epoch 22 of 50


938it [00:16, 56.38it/s]                                                                               
157it [00:01, 81.13it/s]                                                                               


Train Loss: 9718.9127
Val Loss: 9663.3267
Epoch 23 of 50


938it [00:16, 56.72it/s]                                                                               
157it [00:01, 81.53it/s]                                                                               


Train Loss: 9708.8836
Val Loss: 9678.0130
Epoch 24 of 50


938it [00:16, 56.44it/s]                                                                               
157it [00:01, 81.58it/s]                                                                               


Train Loss: 9695.5101
Val Loss: 9690.4778
Epoch 25 of 50


938it [00:16, 56.37it/s]                                                                               
157it [00:01, 84.41it/s]                                                                               


Train Loss: 9684.0017
Val Loss: 9642.3997
Epoch 26 of 50


938it [00:16, 56.92it/s]                                                                               
157it [00:01, 84.02it/s]                                                                               


Train Loss: 9674.0161
Val Loss: 9639.9700
Epoch 27 of 50


938it [00:16, 57.48it/s]                                                                               
157it [00:01, 80.01it/s]                                                                               


Train Loss: 9663.8914
Val Loss: 9650.0596
Epoch 28 of 50


938it [00:16, 56.68it/s]                                                                               
157it [00:01, 81.25it/s]                                                                               


Train Loss: 9654.3253
Val Loss: 9615.5667
Epoch 29 of 50


938it [00:16, 56.40it/s]                                                                               
157it [00:01, 84.50it/s]                                                                               


Train Loss: 9643.8870
Val Loss: 9660.8429
Epoch 30 of 50


938it [00:16, 56.59it/s]                                                                               
157it [00:01, 80.37it/s]                                                                               


Train Loss: 9640.3193
Val Loss: 9625.5177
Epoch 31 of 50


938it [00:16, 55.55it/s]                                                                               
157it [00:01, 82.93it/s]                                                                               


Train Loss: 9634.5163
Val Loss: 9607.5148
Epoch 32 of 50


938it [00:16, 56.14it/s]                                                                               
157it [00:01, 80.44it/s]                                                                               


Train Loss: 9624.7772
Val Loss: 9695.2729
Epoch 33 of 50


938it [00:16, 56.49it/s]                                                                               
157it [00:01, 94.35it/s]                                                                               


Train Loss: 9620.3113
Val Loss: 9601.2465
Epoch 34 of 50


938it [00:16, 57.08it/s]                                                                               
157it [00:01, 85.67it/s]                                                                               


Train Loss: 9614.3552
Val Loss: 9633.7184
Epoch 35 of 50


938it [00:16, 57.55it/s]                                                                               
157it [00:01, 87.97it/s]                                                                               


Train Loss: 9610.9361
Val Loss: 9649.0442
Epoch 36 of 50


938it [00:16, 56.76it/s]                                                                               
157it [00:01, 82.67it/s]                                                                               


Train Loss: 9605.5630
Val Loss: 9591.9820
Epoch 37 of 50


938it [00:16, 56.50it/s]                                                                               
157it [00:01, 87.42it/s]                                                                               


Train Loss: 9596.9719
Val Loss: 9598.1109
Epoch 38 of 50


938it [00:16, 57.06it/s]                                                                               
157it [00:01, 82.32it/s]                                                                               


Train Loss: 9596.0973
Val Loss: 9629.1912
Epoch 39 of 50


938it [00:16, 56.98it/s]                                                                               
157it [00:01, 82.32it/s]                                                                               


Train Loss: 9589.6874
Val Loss: 9571.7346
Epoch 40 of 50


938it [00:16, 56.50it/s]                                                                               
157it [00:01, 81.22it/s]                                                                               


Train Loss: 9582.6654
Val Loss: 9621.4914
Epoch 41 of 50


938it [00:16, 56.11it/s]                                                                               
157it [00:01, 93.98it/s]                                                                               


Train Loss: 9582.8695
Val Loss: 9557.2184
Epoch 42 of 50


938it [00:16, 57.81it/s]                                                                               
157it [00:01, 81.18it/s]                                                                               


Train Loss: 9579.0488
Val Loss: 9565.9947
Epoch 43 of 50


938it [00:16, 57.04it/s]                                                                               
157it [00:01, 93.54it/s]                                                                               


Train Loss: 9572.7386
Val Loss: 9575.9015
Epoch 44 of 50


938it [00:16, 56.79it/s]                                                                               
157it [00:01, 92.45it/s]                                                                               


Train Loss: 9571.6927
Val Loss: 9559.6248
Epoch 45 of 50


938it [00:16, 56.59it/s]                                                                               
157it [00:01, 89.64it/s]                                                                               


Train Loss: 9567.5542
Val Loss: 9602.0897
Epoch 46 of 50


938it [00:16, 57.44it/s]                                                                               
157it [00:01, 81.35it/s]                                                                               


Train Loss: 9561.8259
Val Loss: 9586.3184
Epoch 47 of 50


938it [00:16, 56.55it/s]                                                                               
157it [00:01, 87.80it/s]                                                                               


Train Loss: 9560.3071
Val Loss: 9562.7909
Epoch 48 of 50


938it [00:16, 56.75it/s]                                                                               
157it [00:01, 85.99it/s]                                                                               


Train Loss: 9557.1179
Val Loss: 9557.8264
Epoch 49 of 50


938it [00:16, 55.94it/s]                                                                               
157it [00:02, 76.00it/s]                                                                               


Train Loss: 9552.2973
Val Loss: 9560.2611
Epoch 50 of 50


938it [00:16, 56.34it/s]                                                                               
157it [00:01, 92.34it/s]                                                                               

Train Loss: 9548.3431
Val Loss: 9580.2658



