In [None]:
import torch
import torch.nn as nn
import torchvision as tv
from torch.utils.data import DataLoader

import numpy as np
import random
import os
from pathlib import Path
import matplotlib.pyplot as plt

from typing import Tuple, List

%matplotlib inline

In [None]:
def get_mnist_ds(is_train: bool):
    """Return the MNIST training split (`is_train=True`) or test split as a dataset.

    Images are converted with `ToTensor`; labels are left untransformed. Data is
    kept under the relative directory `mnistdata` and downloaded there if needed.
    """
    options = dict(
        root=Path('mnistdata'),
        train=is_train,
        transform=tv.transforms.ToTensor(),
        target_transform=None,
        download=True,
    )
    return tv.datasets.MNIST(**options)

In [None]:
def get_mnist_ds_loader(batch_size):
    """Build the (train, validation) DataLoaders over the MNIST train/test splits.

    The training loader shuffles with `batch_size`; the validation loader keeps a
    fixed order and uses `2 * batch_size` (presumably affordable because
    evaluation runs without gradients).
    """
    train_loader = DataLoader(get_mnist_ds(is_train=True), batch_size, shuffle=True)
    valid_loader = DataLoader(get_mnist_ds(is_train=False), 2 * batch_size, shuffle=False)
    return train_loader, valid_loader
    

In [None]:
# Training batch size; get_mnist_ds_loader gives the validation loader twice this.
BATCH_SIZE = 32
train_dl, valid_dl = get_mnist_ds_loader(BATCH_SIZE)

In [None]:
# Number of samples in the training and validation datasets.
tuple(len(loader.dataset) for loader in (train_dl, valid_dl))

In [None]:
# Shape of the training dataset's underlying `.data` attribute.
train_dl.dataset.data.size()

In [None]:
# Peek at the first five training batches: print the image and label batch
# shapes and show the first image (first channel) of each batch.
for batch_idx, (images, labels) in enumerate(train_dl):
    print(images.shape)
    print(labels.shape)
    plt.imshow(images[0][0], cmap='gray')
    plt.show()
    if batch_idx > 3:
        break

In [None]:
class Encoder(nn.Module):
    """MLP encoder: flattens each sample and maps it to an `out_size`-dim code.

    Args:
        inp_size: per-sample input shape (e.g. (28, 28)); its product is the
            number of input features after flattening.
        hidden_size: width of the single hidden layer.
        out_size: dimensionality of the produced code.
    """

    def __init__(self, inp_size: Tuple[int, int], hidden_size: int, out_size: int):
        super().__init__()
        n_features = np.prod(inp_size)
        # Layer order is fixed so state_dict keys (enc.1.*, enc.3.*) stay stable.
        layers = [
            nn.Flatten(),
            nn.Linear(n_features, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, out_size),
        ]
        self.enc = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten `x` from dim 1 onward and return codes of shape (batch, out_size)."""
        return self.enc(x)
        

In [None]:
class Decoder(nn.Module):
    """MLP decoder: maps a code vector back to a tensor of per-sample shape `out_size`.

    Args:
        inp_size: dimensionality of the input code.
        hidden_size: width of the single hidden layer.
        out_size: per-sample output shape (e.g. (28, 28)).

    The final Sigmoid keeps every output element in [0, 1].
    """

    def __init__(self, inp_size: int, hidden_size: int, out_size: Tuple[int, int]):
        super().__init__()
        self.out_size = out_size
        n_outputs = np.prod(out_size)
        # Layer order is fixed so state_dict keys (dec.0.*, dec.2.*) stay stable.
        self.dec = nn.Sequential(
            nn.Linear(inp_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_outputs),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Decode codes of shape (batch, inp_size) into shape (batch, *out_size)."""
        flat = self.dec(x)
        batch = flat.size(0)
        return flat.view(batch, *self.out_size)

In [None]:
class AutoEncoder(nn.Module):
    """Encoder/decoder pair that reconstructs its input.

    Args:
        inp_size: per-sample input shape (e.g. (28, 28)); the reconstruction
            has this shape too.
        hidden_size: hidden-layer width shared by the encoder and the decoder.
        out_size: size of the latent code produced by the encoder — NOT the
            size of the reconstruction (name kept for compatibility).
    """

    def __init__(self, inp_size: Tuple[int, int], hidden_size: int, out_size: int):
        super().__init__()
        code_size = out_size
        self.enc = Encoder(inp_size, hidden_size, code_size)
        # The decoder mirrors the encoder: code -> hidden -> original input shape.
        self.dec = Decoder(code_size, hidden_size, inp_size)

    def forward(self, x):
        """Encode `x`, then decode the code back to shape (batch, *inp_size)."""
        return self.dec(self.enc(x))

In [None]:
# Model configuration. AutoEncoder's `out_size` is the latent-code size.
IMG_SIZE = (28, 28)
HIDDEN_SIZE = 512
LATENT_SIZE = 20

ae = AutoEncoder(inp_size=IMG_SIZE, hidden_size=HIDDEN_SIZE, out_size=LATENT_SIZE)
ae

In [None]:
# Pick one validation sample: `x` is the ToTensor-transformed image, `y` its label.
x, y = valid_dl.dataset[30]

In [None]:
# Shape of the sample image and whether it participates in autograd.
print(x.shape, x.requires_grad, sep='\n')

In [None]:
# Forward pass with autograd disabled: the output is not tracked for gradients.
with torch.no_grad():
    out = ae(x)
print(out.shape, out.requires_grad, sep='\n')

In [None]:
# Same forward pass with autograd enabled: the output now tracks gradients.
out = ae(x)
print(out.shape, out.requires_grad, sep='\n')

In [None]:
# Input image of the selected sample (x[0] drops the leading channel dim for imshow).
plt.imshow(x[0], cmap='gray')
plt.title(f'input (label {y})')
plt.show()

In [None]:
# Model reconstruction of the same sample; detach() because `out` tracks gradients.
plt.imshow(out[0].detach(), cmap='gray')
plt.title('reconstruction')
plt.show()

In [None]:
def show_summary(valid_dl: DataLoader, model: nn.Module, elem_num: int = 10):
    """Plot validation images (top row) above the model's reconstructions (bottom row).

    The first `elem_num` samples of `valid_dl.dataset` are reconstructed in a
    single batched forward pass.

    Args:
        valid_dl: loader whose `.dataset` yields (image, label) pairs; only the
            dataset is used, not the loader's batching.
        model: autoencoder to visualize. It is put in eval mode for the forward
            pass and its previous train/eval mode is restored afterwards.
        elem_num: number of samples per row (capped at the dataset length).
    """
    dataset = valid_dl.dataset
    elem_num = min(elem_num, len(dataset))
    if elem_num == 0:
        return  # nothing to plot; make_grid rejects an empty list

    device = next(model.parameters()).device
    was_training = model.training

    # Index only the samples needed; filtering an enumeration of the dataset
    # would load and transform every sample just to discard most of them.
    actuals_batch = torch.stack([dataset[i][0] for i in range(elem_num)])

    model.eval()
    try:
        with torch.no_grad():
            # The model returns (batch, *inp_size) without a channel dim;
            # unsqueeze(1) restores it so both rows have matching image shapes.
            reconst_batch = model(actuals_batch.to(device)).cpu().unsqueeze(1)
    finally:
        model.train(was_training)

    grid_elems = [*actuals_batch, *reconst_batch]
    grid = tv.utils.make_grid(grid_elems, nrow=elem_num, padding=1, pad_value=1)

    plt.figure(figsize=(15, 15))
    plt.imshow(grid.permute(1, 2, 0))
    plt.axis('off')
    plt.show()

In [None]:
def show_summary_1(valid_dl: DataLoader, model: nn.Module, elem_num: int = 15):
    """Plot validation images (top row) above their reconstructions (bottom row).

    Unlike `show_summary`, each of the first `elem_num` samples is passed
    through the model individually.

    Args:
        valid_dl: loader whose `.dataset` yields (image, label) pairs; only the
            dataset is used, not the loader's batching.
        model: autoencoder to visualize. It is put in eval mode for the forward
            passes and its previous train/eval mode is restored afterwards.
        elem_num: number of samples per row (capped at the dataset length).
    """
    dataset = valid_dl.dataset
    elem_num = min(elem_num, len(dataset))
    if elem_num == 0:
        return  # nothing to plot; make_grid rejects an empty list

    device = next(model.parameters()).device
    was_training = model.training

    actuals, reconst = [], []
    model.eval()
    try:
        with torch.no_grad():
            for i in range(elem_num):
                x, _ = dataset[i]
                actuals.append(x)
                # `x` has no batch dim; assuming single-channel images, its
                # leading size-1 dim acts as the batch, and the reconstruction
                # keeps that leading dim.
                reconst.append(model(x.to(device)).cpu())
    finally:
        model.train(was_training)

    grid_elems = [*actuals, *reconst]
    grid = tv.utils.make_grid(grid_elems, nrow=elem_num, padding=1, pad_value=1)

    plt.figure(figsize=(15, 15))
    plt.imshow(grid.permute(1, 2, 0))
    plt.axis('off')
    plt.show()

In [None]:
%%time
# Batched variant: all summary samples go through the model in one forward pass.
show_summary(valid_dl,ae)

In [None]:
%%time
# Per-sample variant: one forward pass per image.
show_summary_1(valid_dl,ae)

In [None]:
def seed_everything(seed=0):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN behavior.

    Args:
        seed: value used for every generator.

    Note: PYTHONHASHSEED is read when the interpreter starts, so setting it here
    does not change hashing in the running kernel; it only reaches child
    processes started afterwards (they inherit os.environ).
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU explicitly (a no-op until CUDA is initialized).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning can select different algorithms between runs, which
    # defeats `deterministic`; keep it off for reproducibility.
    torch.backends.cudnn.benchmark = False

In [None]:
# Seed all RNGs before the weight re-initialization and shuffled training below.
# NOTE: `ae` was already constructed above, so its default init was not seeded.
seed_everything(seed=0)

In [None]:
# All training below runs on the CPU.
device = torch.device('cpu')
# nn.Module.to moves parameters in place and returns the module itself;
# the trailing semicolon suppresses the module repr in the cell output.
ae.to(device);

In [None]:
def init_params(m: nn.Module):
    """Re-initialize linear layers: orthogonal weight matrix, zero bias.

    Intended for `module.apply(init_params)`. Modules that are not an
    `nn.Linear` (or a subclass of it) are left untouched.
    """
    if isinstance(m, nn.Linear):
        # nn.init functions already run without autograd, so `.data` is unnecessary.
        nn.init.orthogonal_(m.weight)
        # Layers built with bias=False have `bias is None`.
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    

In [None]:
# Mean squared reconstruction error, optimized with Adam (lr = 1e-3).
mse_loss = nn.MSELoss(reduction='mean')
optim = torch.optim.Adam(params=ae.parameters(), lr=1e-3)

In [None]:
# Re-initialize all Linear layers (orthogonal weights, zero bias), then train.
# NOTE(review): `optim` is created in an earlier cell, so re-running this cell
# resets the weights but keeps Adam's moment estimates from the previous run.
ae.apply(init_params)
total_epochs = 5

LOG_INTERVAL = 10      # print the mean training cost every LOG_INTERVAL batches
SUMMARY_INTERVAL = 20  # plot reconstructions every SUMMARY_INTERVAL batches (many figures per epoch)

acc_cost = 0.0  # running sum of batch costs since the last log line
acc_steps = 0   # number of batches included in acc_cost

for epoch in range(total_epochs):
    for i, (x, _) in enumerate(train_dl):
        x = x.to(device)
        optim.zero_grad()
        # The model returns (batch, *inp_size) without a channel dim;
        # unsqueeze(1) restores it so the reconstruction matches x's shape.
        recon = ae(x).unsqueeze(1)
        # nn.MSELoss takes (prediction, target); the input image is the target.
        cost = mse_loss(recon, x)
        cost.backward()
        optim.step()

        acc_cost += cost.item()
        acc_steps += 1
        # Log after every LOG_INTERVAL-th batch and divide by the number of batches
        # actually accumulated, so the average is exact (including batches left
        # over from the end of the previous epoch).
        if (i + 1) % LOG_INTERVAL == 0:
            print(f"epoch {epoch+1} | iter {i} | acc_cost {acc_cost/acc_steps:.4f} | cost {cost:.4f}")
            acc_cost, acc_steps = 0.0, 0

        if (i % SUMMARY_INTERVAL) == 0:
            show_summary_1(valid_dl, ae)

In [None]:
# Architecture, then the name and shape of every tensor in the state dict.
print(ae)
state = ae.state_dict()
for key, tensor in state.items():
    print(f'key: {key} val.shape: {tensor.shape}')

# Full weight matrix of the encoder's first Linear layer.
print('\n', state['enc.enc.1.weight'])

In [None]:
#torch.save(ae.state_dict(),'ae.pt')

In [None]:
# Restore previously saved weights (the torch.save cell above is commented out,
# so this fails on a fresh run unless ae.pt already exists).
# map_location lets a checkpoint written on another device load onto `device`.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
ae.load_state_dict(torch.load('ae.pt', map_location=device))

In [None]:
# Encode the whole validation set into latent codes `z`, keeping the labels aligned.
ae.eval()
z, labels = [], []
with torch.no_grad():
    for x, y in valid_dl:
        z.append(ae.enc(x.to(device)).cpu())
        labels.append(y)
# torch.cat needs a sequence of tensors: concatenate the per-batch results.
z = torch.cat(z)
labels = torch.cat(labels)
print(z.shape)
print(labels.shape)