In [None]:
import torch
from torch import optim
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchsummary import summary
import matplotlib.pyplot as plt
import numpy as np
from torchviz import make_dot
from torch.utils.data import DataLoader
from pathlib import Path
import pandas as pd

from autoencoder import AutoEncoder
import pickle

# https://stackoverflow.com/questions/8223811/a-top-like-utility-for-monitoring-cuda-activity-on-a-gpu

In [None]:
# Prefer the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
data_path = "/home/amal/UbuntuDocuments/data/torch_datasets"

# Both splits share the same options: images are returned as tensors by
# PILToTensor (no scaling here — the training/eval loops divide by 255).
celeba_kwargs = dict(transform=transforms.PILToTensor(), download=True)
train_data = datasets.CelebA(data_path, split="train", **celeba_kwargs)
validation_data = datasets.CelebA(data_path, split="valid", **celeba_kwargs)

In [None]:
def validate(
    model: torch.nn.Module,
    val_dataloader: DataLoader,
    device,
    loss_fn=None,
):
    """Compute the mean reconstruction loss of ``model`` over ``val_dataloader``.

    The model is put into eval mode and run under ``torch.no_grad()``. The
    first element of each batch (``batch[0]``) is divided by 255, cast to
    float32, moved to ``device``, and used both as the model input and as the
    reconstruction target.

    Args:
        model: autoencoder to evaluate; called as ``model(inputs)``.
        val_dataloader: yields batches whose first element is the image tensor
            (assumed to hold values in [0, 255] — it is divided by 255 here).
        device: device the inputs are moved to; should match the model's device.
        loss_fn: callable ``loss_fn(output, target)`` returning a scalar tensor.
            Defaults to ``nn.MSELoss(reduction="mean")``, the same loss the
            notebook configures for training.

    Returns:
        Mean of the per-batch loss values, or ``nan`` if the loader yields no
        batches.
    """
    if loss_fn is None:
        loss_fn = nn.MSELoss(reduction="mean")

    val_loss = []
    model.eval()
    with torch.no_grad():
        for batch in val_dataloader:
            # Scale to [0, 1] floats. Using .float() instead of
            # torch.tensor(<tensor>) avoids a redundant copy and its UserWarning.
            inputs = batch[0].float().div(255).to(device)
            output = model(inputs)
            val_loss.append(loss_fn(output, inputs).item())

    if not val_loss:
        # np.mean([]) would also give nan, but with a RuntimeWarning.
        return float("nan")
    return np.mean(val_loss)


def train_epoch(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    train_loader: DataLoader,
    loss_fn,
    device,
    log_every: int = 25,
):
    """Train ``model`` as an autoencoder for one epoch and return the batch losses.

    The first element of each batch (``batch[0]``) is divided by 255, cast to
    float32, moved to ``device``, and used both as the input and as the
    reconstruction target. Gradients are cleared before each backward pass,
    then ``optimizer`` takes a step.

    Args:
        model: module to train; switched to train mode here.
        optimizer: optimizer over ``model``'s parameters (used for every step).
        train_loader: yields batches whose first element is the image tensor
            (assumed to hold values in [0, 255] — it is divided by 255 here).
        loss_fn: callable ``loss_fn(output, target)`` returning a scalar tensor.
        device: device the inputs are moved to; should match the model's device.
        log_every: print a progress line every ``log_every`` batches;
            ``0`` or a negative value disables the printing.

    Returns:
        List of per-batch loss values as Python floats, in iteration order.
    """
    model.train()
    batch_loss = []
    for indx, batch in enumerate(train_loader):
        if log_every > 0 and indx % log_every == 0:
            print(f"running index: {indx}")

        # Scale to [0, 1] floats. Using .float() instead of
        # torch.tensor(<tensor>) avoids a redundant copy and its UserWarning.
        inputs = batch[0].float().div(255).to(device)
        output = model(inputs)

        loss = loss_fn(output, inputs)
        # Clear gradients first so nothing accumulated before this call
        # leaks into the first update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_loss.append(loss.item())

    return batch_loss


In [None]:
batch_size = 256
latent_dim = 200
# Shuffle the training set so each epoch sees the batches in a new order.
# Validation stays unshuffled so its loss is computed over the same batches
# every epoch and remains comparable.
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(validation_data, batch_size=batch_size)
# Reconstruction loss: mean squared error averaged over all elements.
loss_func = nn.MSELoss(reduction="mean")

N_epochs = 15
# Directory that receives one checkpoint per epoch (autoencoder_epoch_<n>.pth).
out_path = Path("/home/amal/UbuntuDocuments/projects/generative_modelling/saved_models")

# Set to True to resume model and optimizer state from a saved checkpoint.
load_model = False

In [None]:
# Fresh model and Adam optimizer; optionally restore both from a saved checkpoint.
model = AutoEncoder(latent_dim=latent_dim).to(device)
opt = optim.Adam(model.parameters())
if load_model:
    checkpoint_path = out_path / "autoencoder_epoch_14.pth"
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only
    # machine (and vice versa).
    # NOTE(review): checkpoints whose loss fields are numpy scalars may be
    # rejected by torch.load's weights_only default (PyTorch >= 2.6); pass
    # weights_only=False only for checkpoints you wrote yourself.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    opt.load_state_dict(checkpoint["optimizer_state_dict"])


In [None]:
# Resume the per-epoch loss history if a previous run left one behind.
# NOTE: pickle.load can execute arbitrary code — only load a losses.pkl that
# this notebook wrote itself.
losses_file = Path("losses.pkl")
if losses_file.is_file():
    with open(losses_file, "rb") as f:
        losses = pickle.load(f)
else:
    losses = []

# Per-batch training losses from this session only (plotted below).
batch_loss = []

# Continue epoch numbering after the epochs already recorded in `losses`.
# NOTE(review): this assumes the weights were resumed (load_model = True) from
# the matching checkpoint. With a freshly initialised model, the epoch numbers
# and the appended history would not match the weights — confirm before training.
start_epoch = len(losses)

TRAIN_MODEL = False

if TRAIN_MODEL:
    # torch.save does not create missing directories.
    out_path.mkdir(parents=True, exist_ok=True)

    for epoch in range(start_epoch, start_epoch + N_epochs):
        print(f"---\nRunning epoch {epoch + 1}")

        b_losses = train_epoch(
            model,
            opt,
            train_dataloader,
            loss_func,
            device
        )
        # Store plain Python floats rather than numpy scalars, so checkpoints
        # contain only tensors and builtins and stay loadable under
        # torch.load's weights_only mode.
        epoch_loss = float(np.mean(b_losses))
        val_loss = float(validate(
            model,
            val_dataloader,
            device,
        ))

        losses.append({
            "epoch_loss": epoch_loss,
            "val_loss": val_loss
        })
        batch_loss.extend(b_losses)

        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': opt.state_dict(),
                'train_loss': epoch_loss,
                'val_loss': val_loss,
            },
            out_path / f"autoencoder_epoch_{epoch}.pth",
        )
        # Persist the history after every epoch, so an interrupted run keeps
        # the losses for every checkpoint already written.
        with open(losses_file, "wb") as f:
            pickle.dump(losses, f)




In [None]:
# Per-batch training loss recorded during this session's training run.
fig, ax = plt.subplots()
ax.plot(batch_loss)
ax.set_yscale("log")
ax.set_xlabel("batch")
ax.set_ylabel("training loss")
ax.set_title("Per-batch training loss (this session)");

In [None]:
# One row per recorded epoch, with columns epoch_loss and val_loss.
loss_df = pd.DataFrame.from_records(losses)

In [None]:
# Per-epoch mean training loss and validation loss (one row per epoch).
loss_df

In [None]:
# Training vs. validation loss per epoch (row position == epoch index).
if loss_df.empty:
    # Without any history the frame has no epoch_loss/val_loss columns.
    print("No loss history yet: set TRAIN_MODEL = True and train first.")
else:
    fig, ax = plt.subplots()
    ax.plot(loss_df.index, loss_df["epoch_loss"], label="train_loss")
    ax.plot(loss_df.index, loss_df["val_loss"], label="val loss")
    ax.set_yscale("log")
    ax.set_xlabel("epoch")
    ax.set_ylabel("loss")
    ax.set_title("Training vs. validation loss")
    ax.legend()

## Load checkpoint and evaluate model

In [None]:
# Rebuild the model and restore the weights from autoencoder_epoch_14.pth
# for qualitative evaluation.
model = AutoEncoder(latent_dim=latent_dim).to(device)
checkpoint_path = out_path / "autoencoder_epoch_14.pth"
# map_location lets a GPU-saved checkpoint be loaded on a CPU-only machine.
# NOTE(review): checkpoints whose loss fields are numpy scalars may be
# rejected by torch.load's weights_only default (PyTorch >= 2.6); pass
# weights_only=False only for checkpoints you wrote yourself.
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

In [None]:
val_index = 5

# Scale the PILToTensor image to [0, 1] floats on the model's device.
# .float() replaces torch.tensor(<tensor>), which copies and emits a UserWarning;
# the resulting values are identical.
test = validation_data[val_index][0].float().div(255).to(device)

In [None]:
# Inference only: no_grad avoids building an autograd graph for this forward pass.
with torch.no_grad():
    test_out = model(test.unsqueeze(0))  # add a batch dimension of size 1

In [None]:
# Shape of the reconstruction; presumably (1, C, H, W) matching the input
# if the autoencoder preserves spatial size — TODO confirm.
test_out.shape

In [None]:
# Reconstruction: take the single batch item and move it to the CPU. Clamp to
# [0, 1] (imshow would otherwise clip out-of-range floats with a warning), then
# permute channels-first to channels-last for imshow.
plt.imshow(test_out[0].detach().cpu().clamp(0, 1).permute(1, 2, 0))
plt.title("Reconstruction")
plt.axis("off");

In [None]:
# Original validation image (PILToTensor output, channels-first) for comparison;
# permute to channels-last for imshow.
plt.imshow(validation_data[val_index][0].permute(1, 2, 0))
plt.title("Original")
plt.axis("off");

In [None]:
# NOTE(review): presumably applies `model.conv_output_shape` 5 times, starting
# from the (218, 178) input image size, to get the spatial size after the
# encoder's conv stack — confirm against AutoEncoder.recursive_apply.
model.recursive_apply(
            (218, 178), model.conv_output_shape, 5
        )

In [None]:
# Scratch arithmetic: presumably the flattened size of a 7x6 feature map with
# 128 channels (cf. the recursive_apply output above) — TODO confirm.
7*6*128

In [None]:
# Scratch arithmetic comparing two candidate flattened feature-map sizes
# (14x12x128 vs 7x6x256); presumably for sizing the encoder's flatten layer —
# TODO confirm.
14*12*128, 7*6*256