In [3]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.optim import Adam
import numpy as np
from PIL import Image
import sys
import torch.nn.functional as F
from math import ceil
import data_module as dm


In [4]:
# Training and model configuration

dataset_path = '~/datasets'
cuda = 0  # any truthy value selects the GPU below
DEVICE = torch.device("cuda") if cuda else torch.device("cpu")
batch_size, lr, max_epochs = 64, 1e-3, 2


x_dim = 3 * 32 * 32       # 3072 values per flattened 3x32x32 image
hidden_dim = 27 * 4 * 4   # 432 features per sample after ConvNet's max-pool
latent_dim = 64

In [5]:
# Default pipeline: the only step is the conversion to a tensor, no augmentation.
_to_tensor_only = [transforms.ToTensor()]
cifar_transform = transforms.Compose(_to_tensor_only)


In [6]:
class ConvNet(nn.Module):
    """Convolutional feature extractor for 3x32x32 images.

    Two stride-2 convolutions (32 -> 16 -> 8 spatially) followed by a 2x2
    max-pool (8 -> 4) leave 27 channels of 4x4, i.e. 432 features per sample.
    """

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=9, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=9, out_channels=27, kernel_size=4, stride=2, padding=1)
        self.maxpool = nn.MaxPool2d(2)

    def forward(self, x):
        """Encode a batch of images into flat feature vectors.

        Args:
            x: tensor holding 3*32*32 values per sample, e.g. shaped
                (batch, 3072) or (batch, 3, 32, 32).

        Returns:
            Tensor of shape (batch, 432) with non-negative entries.
        """
        x = x.reshape(-1, 3, 32, 32)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.maxpool(x)
        # Flatten from the tensor's own shape instead of the module-level
        # `hidden_dim`: a mismatched global would otherwise silently fold
        # feature values into the batch dimension.
        return x.flatten(start_dim=1)

In [7]:
class DeConvNet(nn.Module):
    """Upsampling network mapping 432 features back to 3072 image values.

    A linear layer expands the input to 27 channels of 8x8, then two
    transposed convolutions double the resolution twice (8 -> 16 -> 32).
    Every stage, including the last, is followed by a ReLU.
    """

    def __init__(self):
        super().__init__()
        # 432 -> 1728 (= 27 * 8 * 8); corresponds to the encoder max pool layer
        self.deconv0 = nn.Linear(432, 1728)
        self.deconv1 = nn.ConvTranspose2d(27, 9, kernel_size=4, stride=2, padding=1)
        self.deconv2 = nn.ConvTranspose2d(9, 3, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        """Return a (batch, 3072) non-negative reconstruction for (batch, 432) input."""
        feature_map = F.relu(self.deconv0(x)).view(-1, 27, 8, 8)
        feature_map = F.relu(self.deconv1(feature_map))
        image = F.relu(self.deconv2(feature_map))
        return image.view(-1, 3072)

In [8]:
# Custom loss tracking feature. Each call to render() fills one column of
# loss_array with the losses it is given (one row per loss, in call order).
# When the array is full, render() saves it, prints a message and exits.
batch_index = 0
# NOTE(review): the +3 looks like slack columns per epoch; confirm the intent.
max_iterations = len(dm.module.train_dataloader())+3
loss_array = np.zeros(shape=(3,ceil(max_iterations*max_epochs)))
def render(*args):
    """Record one batch worth of loss values and persist the history to disk.

    Args:
        *args: loss values for the current batch (numbers or 0-dim tensors).
            The k-th argument is written to row k of ``loss_array`` at column
            ``batch_index``.

    Side effects:
        Writes ``loss_array`` to 'loss_array_2.npy' and advances the global
        ``batch_index`` by one.

    Raises:
        SystemExit: when a row or column index falls outside ``loss_array``;
            the array is saved before exiting.
    """
    # The original signature was render(self, *args). This is a plain function,
    # so the first loss passed by the caller was swallowed as `self` and the
    # last row of loss_array was never filled.
    global batch_index
    try:
        for row, loss_category in enumerate(args):
            # float() detaches tensor inputs so numpy only ever sees a scalar.
            loss_array[row][batch_index] = float(loss_category)
    except IndexError:
        print('done with specified number of iterations')
        np.save('loss_array_2', loss_array)
        sys.exit()
    # One save per call is enough; the file ends up identical to saving after
    # every individual write.
    np.save('loss_array_2', loss_array)
    batch_index += 1

In [9]:
loss_array.shape
# (3, ceil(max_iterations * max_epochs)): rows index the loss passed to render(), columns index batch_index

(3, 1414)

In [10]:
# The encoder models the approximate posterior distribution of the latent variables Z
# given the observed data X. It is denoted q(Z|X) and approximates the true posterior p(Z|X).
class Encoder(nn.Module):
    """Convolutional encoder returning the mean and log-variance of q(Z|X).

    Args:
        input_dim: accepted for a uniform constructor signature; not used here.
        hidden_dim: size of the flat feature vector produced by ``ConvNet``.
        latent_dim: dimensionality of the latent variable Z.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()

        self.conv = ConvNet()
        # Mean head: two stacked linear layers with no activation in between.
        self.FC_mean  = nn.Sequential(nn.Linear(hidden_dim, latent_dim*2), nn.Linear(latent_dim*2, latent_dim))
        # Log-variance head.
        self.FC_var   = nn.Linear (hidden_dim, latent_dim)
        # nn.ReLU's only constructor argument is the boolean `inplace`, so the
        # previous nn.ReLU(0.2) did not set a slope; it enabled in-place mode.
        # NOTE(review): if a leaky slope of 0.2 was intended, use nn.LeakyReLU(0.2).
        # ConvNet's output is already non-negative, so the values are the same
        # either way.
        self.ReLU = nn.ReLU()

    def forward(self, x):
        """Return ``(mean, logvar)``, each with ``latent_dim`` values per sample."""
        x      = self.ReLU(self.conv(x))
        mean     = self.FC_mean(x)
        logvar  = self.FC_var(x)

        return mean, logvar

In [11]:
# The decoder models p(X|Z): the likelihood of the observed data X given the latent variables Z.
class Decoder(nn.Module):
    """Maps a latent code to a flat reconstruction through a linear layer and DeConvNet.

    Args:
        latent_dim: size of the incoming latent vector.
        hidden_dim: width of the linear layer feeding ``DeConvNet``.
        output_dim: accepted but not used by this class.
    """

    def __init__(self, latent_dim, hidden_dim, output_dim):
        super().__init__()
        self.FC_hidden = nn.Linear(in_features=latent_dim, out_features=hidden_dim)
        self.reconstructor = DeConvNet()

        self.ReLU = nn.ReLU()

    def forward(self, x):
        """Decode latent vectors ``x`` into non-negative reconstructions."""
        hidden = self.FC_hidden(x)
        hidden = self.ReLU(hidden)
        reconstruction = self.reconstructor(hidden)
        return self.ReLU(reconstruction)


In [12]:
# high level model declaration. See above for layer structure
class Model(nn.Module):
    """Variational autoencoder: encode, sample a latent code, decode.

    Args:
        Encoder: callable mapping an input batch to ``(mean, logvar)``.
        Decoder: callable mapping a latent batch to a reconstruction.
    """

    def __init__(self, Encoder, Decoder):
        super(Model, self).__init__()
        self.Encoder = Encoder
        self.Decoder = Decoder

    def reparameterization(self, mean, var):
        """Sample ``mean + var * eps`` with ``eps`` drawn from a standard normal.

        Args:
            mean: location of the latent distribution.
            var: scale multiplied with the noise. Despite the name, ``forward``
                passes the standard deviation ``exp(0.5 * logvar)`` here.

        Returns:
            Tensor with the same shape as ``mean``.
        """
        # randn_like already matches var's device and dtype, so no extra move
        # through the global DEVICE is needed (and none can go wrong).
        epsilon = torch.randn_like(var)
        return mean + var * epsilon

    def forward(self, x):
        """Return ``(x_hat, mean, logvar)`` for the input batch ``x``."""
        mean, logvar = self.Encoder(x)
        std = torch.exp(0.5 * logvar)  # convert logvar to sigma
        z = self.reparameterization(mean, std)
        x_hat = self.Decoder(z)

        return x_hat, mean, logvar

In [13]:
# Assemble the VAE: the encoder turns x_dim input values into a mean and a log-variance
# of size latent_dim each; the decoder is built from latent_dim, hidden_dim and x_dim.
encoder = Encoder(x_dim, hidden_dim, latent_dim)
decoder = Decoder(latent_dim, hidden_dim, x_dim)
model = Model(encoder, decoder)
model = model.to(DEVICE)

In [14]:
def loss_function(x, x_hat, mean, logvar):
    """Return reconstruction error plus KL divergence, and log all three terms.

    Args:
        x: original inputs.
        x_hat: reconstructions, same shape as ``x``.
        mean: latent means from the encoder.
        logvar: latent log-variances from the encoder.

    Returns:
        Scalar tensor: summed squared error + KL term.
    """
    # Reconstruction term: squared error summed over every element of the batch.
    reconstruction_loss = nn.functional.mse_loss(x_hat, x, reduction='sum')
    # Closed-form KL( N(mean, exp(logvar)) || N(0, I) ), summed over the batch.
    kl_terms = 1 + logvar - mean.pow(2) - logvar.exp()
    KL_Div = -0.5 * kl_terms.sum()
    total_loss = reconstruction_loss + KL_Div
    render(reconstruction_loss, KL_Div, total_loss)
    return total_loss


# Adam over every parameter of `model`, using the configured learning rate.
optimizer = Adam(params=model.parameters(), lr=lr)

In [15]:


def save_and_upload_image(tensor_batch, filename):
    """
    Save the 0th element of a batched tensor as an image file.

    Despite the name, nothing is uploaded: the image is only written to
    ``filename``.

    Args:
    tensor_batch (torch.Tensor): Batched tensor whose first element holds
        3*32*32 values in channel-first order, nominally in [0, 1].
    filename (str): The name of the file to save the image as.
    """
    # Take the 0th element of the batch as a channel-first 3x32x32 image.
    img_tensor = tensor_batch[0].reshape(3, 32, 32)

    # Drop the autograd graph and move to host memory so .numpy() also works
    # for tensors living on an accelerator.
    img_tensor = img_tensor.detach().cpu()

    # Clamp before scaling: values outside [0, 1] (for example unbounded ReLU
    # reconstructions) would otherwise fall outside 0..255 and be corrupted by
    # the uint8 cast.
    img_tensor = img_tensor.clamp(0.0, 1.0)

    # Channel-last layout for PIL, scaled to the range [0, 255].
    img_array = (img_tensor.permute(1, 2, 0).numpy() * 255).astype(np.uint8)

    img = Image.fromarray(img_array)
    img.save(filename)


# Usage: save_and_upload_image(batched_tensor, "filename.png") -- it is called from the training loop.


In [None]:
# Check hyperparams and paths before running this training loop!
data = dm.module
data.setup('train')
# Anomaly detection helps locate NaN/inf gradients but slows every step;
# switch it off once the model trains cleanly.
torch.autograd.set_detect_anomaly(True)
# Select training mode before optimising (it used to be called after the loop).
model.train()
for epoch in range(max_epochs):
    # Running totals for the epoch, kept apart from the per-batch `loss`
    # tensor. The old code reused one name for both, so the "average" printed
    # below was computed from the last batch only.
    epoch_loss = 0.0
    samples_seen = 0
    for i, (x, _) in enumerate(data.train_dataloader()):
        x = x.view(-1, x_dim)
        x = x.to(DEVICE)

        optimizer.zero_grad()

        x_hat, mean, logvar = model(x)
        loss = loss_function(x, x_hat, mean, logvar)
        # log loss every 25 steps, save images every 1500 steps
        if i % 25 == 0:
            print('loss right now is: ', loss)
            if i % 1500 == 0:
                print('saving outmage')
                save_and_upload_image(x, "inmage.png")
                save_and_upload_image(x_hat, "outmage.png")

        epoch_loss += loss.item()
        # Count real samples: the last batch may be smaller than batch_size.
        samples_seen += x.size(0)

        loss.backward()
        optimizer.step()


    # max(..., 1) keeps the report from dividing by zero on an empty loader.
    print("Epoch", epoch + 1, "complete!", "Average Loss: ", epoch_loss / max(samples_seen, 1))

loss right now is:  tensor(45629.5312, grad_fn=<AddBackward0>)
saving outmage
loss right now is:  tensor(23049.1016, grad_fn=<AddBackward0>)
loss right now is:  tensor(12782.9072, grad_fn=<AddBackward0>)
