In [None]:
import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from torchvision import transforms
from matplotlib import pyplot as plt
import os
import time
import pandas as pd
import subprocess
import os
from datetime import datetime,timedelta
from datetime import date
import urllib.request
from PIL import Image
import pickle
import albumentations as A

In [None]:
def show_tensor_images(image_tensor, num_images=25, size=(3, 256, 256)):
    """
    Display a batch of images as one grid figure.

    The batch is rescaled with ``(x + 1) / 2`` (so values in [-1, 1] land in
    [0, 1]), detached, moved to the CPU, and its first ``num_images`` entries
    are tiled five per row with ``make_grid`` before being drawn.

    Args:
        image_tensor: batch of images; sliced along dim 0 and passed to ``make_grid``.
        num_images: maximum number of images taken from the front of the batch.
        size: kept for call-site compatibility; the body does not use it.
    """
    rescaled = ((image_tensor + 1) / 2).detach().cpu()
    grid = make_grid(rescaled[:num_images], nrow=5)
    # The grid is permuted so that its first axis becomes the last one before plotting.
    channels_last = grid.permute(1, 2, 0).squeeze()
    _, axes = plt.subplots(figsize=(15, 5))
    axes.imshow(channels_last)
    plt.show()

In [None]:
# Seed the default and the CUDA random number generators so repeated runs start from the same state.
torch.manual_seed(42)
torch.cuda.manual_seed(42)
 
##### Declare the model architecture
# Latent dimensionality; the VAE class reads this global when sizing its encoder/decoder layers.
d = 20

In [None]:
# The layer sizes below are largely arbitrary -- feel free to experiment with them.
class VAE(nn.Module):
    """
    Convolutional variational auto-encoder that maps a 3-channel image to a
    1-channel 256 x 256 map.

    Pipeline: ``conv1`` -> ``fc1`` -> ``encoder`` (mu, logvar) ->
    ``reparameterise`` -> ``decoder`` -> ``fc2`` -> ``tconv1`` (sigmoid output).

    The first linear layer of ``fc1`` expects a 32 x 64 x 64 feature map, which
    is what ``conv1`` produces for 256 x 256 inputs (two stride-2 max-pools).

    Args:
        latent_dim: size of the latent vector. When omitted, the notebook-level
            global ``d`` is used, so ``VAE()`` behaves as before. The value is
            stored on the instance so that later changes to ``d`` cannot
            desynchronise ``forward`` from the layers built here.
    """

    def __init__(self, latent_dim=None):
        super(VAE, self).__init__()
        if latent_dim is None:
            latent_dim = d
        self.latent_dim = latent_dim

        # Feature extractor: two conv/pool stages, each pool halves the spatial size.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1, stride=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(3, stride=2, padding=1),
        )

        # Compress the flattened feature map down to a 32-dimensional vector.
        self.fc1 = nn.Sequential(
            nn.Linear(64 * 64 * 32, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64, 32),
        )

        # Produces 2 * latent_dim values per sample: the mean followed by the log-variance.
        self.encoder = nn.Sequential(
            nn.Linear(32, latent_dim ** 2),
            nn.ReLU(),
            nn.Linear(latent_dim ** 2, latent_dim * 2),
        )

        # Maps a latent sample back to a 32-dimensional vector.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, latent_dim ** 2),
            nn.ReLU(),
            nn.Linear(latent_dim ** 2, 32),
        )

        # Expand to 256 * 256 values, reshaped into a 1 x 256 x 256 map in forward().
        self.fc2 = nn.Sequential(
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, 256 * 256 * 1),
            nn.ReLU(),
            nn.BatchNorm1d(256 * 256 * 1),
        )

        # Size-preserving transposed convolution; the sigmoid keeps the output in [0, 1].
        self.tconv1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
        )

    def reparameterise(self, mu, logvar):
        """
        Draw a latent sample with the reparameterisation trick.

        In training mode returns ``mu + eps * std`` with ``eps ~ N(0, 1)`` and
        ``std = exp(0.5 * logvar)``; predicting the log-variance keeps ``std``
        positive. In evaluation mode returns ``mu`` itself.
        """
        if self.training:
            std = torch.exp(0.5 * logvar)
            # randn_like draws noise with the same shape, dtype and device as std.
            eps = torch.randn_like(std)
            return mu + eps * std
        return mu

    def forward(self, x):
        """
        Encode ``x``, sample a latent vector and decode it.

        Args:
            x: image batch with 3 channels; spatial size 256 x 256 is required
                by the first linear layer of ``fc1``.

        Returns:
            Tuple ``(reconstruction, mu, logvar, z)`` where ``reconstruction``
            has shape (N, 1, 256, 256).
        """
        conv1_output = self.conv1(x)
        # Flatten everything except the batch axis. Unlike view(-1, ...), a wrong
        # input size then fails in the linear layer instead of silently changing
        # the batch size.
        fc1_output = self.fc1(conv1_output.flatten(start_dim=1))

        # Reshape the encoder output to (N, 2, latent_dim): row 0 = mean, row 1 = log-variance.
        mu_logvar = self.encoder(fc1_output).view(-1, 2, self.latent_dim)
        mu = mu_logvar[:, 0, :]
        logvar = mu_logvar[:, 1, :]

        # z = mu + eps * std while training, z = mu in eval mode.
        z = self.reparameterise(mu, logvar)
        decoder_output = self.decoder(z)

        fc2_output = self.fc2(decoder_output)
        # Reshape to a single-channel image before the transposed convolution.
        tconv1_output = self.tconv1(fc2_output.view(fc2_output.size(0), 1, 256, 256))
        return tconv1_output, mu, logvar, z

In [None]:
# Build the network with its default arguments; the optimizer and the training loop operate on `model`.
model = VAE()

In [None]:
# Optimisation hyper-parameters.
learning_rate = 3e-3
lambda_recon = 900  # multiplier applied to the reconstruction term in loss_function
recon_criterion = nn.L1Loss()  # reconstruction criterion used by loss_function

# Adam over every parameter of the model.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
### Loss function
# Weighted reconstruction term plus a strongly down-weighted KL-divergence term.

def loss_function(x_hat, x, mu, logvar, step = 1):
    """
    Compute the training objective for one batch.

    The result is ``recon_criterion(x_hat, x) * lambda_recon`` plus
    ``0.000005 * sum(exp(logvar) - logvar - 1 + mu**2)``. The second term is,
    up to a constant factor, the KL divergence between N(mu, exp(logvar)) and
    a standard normal; it keeps the latent codes from drifting away from the
    prior.

    Args:
        x_hat: model output.
        x: target that ``x_hat`` is compared against.
        mu: latent means.
        logvar: latent log-variances.
        step: step counter; both terms are printed whenever it is a multiple of 100.

    Returns:
        The sum of the KL term and the reconstruction term.
    """
    recon_term = recon_criterion(x_hat, x) * lambda_recon

    divergence = logvar.exp() - logvar - 1 + mu.pow(2)
    kl_term = 0.000005 * torch.sum(divergence)

    if not step % 100:
        print(f"KLD term: {kl_term}, pix2pix term: {recon_term}")

    return kl_term + recon_term

In [None]:
# Load the pickled training samples from the mounted Drive path.
# NOTE: pickle.load can execute arbitrary code stored in the file -- only load pickles you trust.
# `dataset` is deliberately not pre-initialised to an empty list: if the load fails, the cell
# errors out and later cells cannot silently pick up an empty placeholder.
with open(r"/content/drive/MyDrive/ai4good/train.pkl", "rb") as file:
    dataset = pickle.load(file)

In [None]:
# Per-channel mean and standard deviation (three channels) used as normalisation defaults.
MEAN = (0.5, 0.5, 0.5,)
STD = (0.5, 0.5, 0.5,)
# Presumably the image side length in pixels; not referenced by the other visible cells.
SIZE = 256

class Transform():
    """
    Image preprocessing: ``ToTensor`` followed by per-channel ``Normalize``.

    Args:
        mean: per-channel means handed to ``transforms.Normalize``.
        std: per-channel standard deviations handed to ``transforms.Normalize``.
    """

    def __init__(self, mean=MEAN, std=STD):
        steps = [transforms.ToTensor(), transforms.Normalize(mean, std)]
        self.data_transform = transforms.Compose(steps)

    def __call__(self, img: Image.Image):
        """Run ``img`` through the composed pipeline and return the result."""
        pipeline = self.data_transform
        return pipeline(img)

class Mask_Transform():
    """Mask preprocessing: conversion with ``ToTensor`` only, without normalisation."""

    def __init__(self):
        self.data_transform = transforms.Compose([transforms.ToTensor()])

    def __call__(self, img: Image.Image):
        """Convert ``img`` with the stored pipeline and return the result."""
        converted = self.data_transform(img)
        return converted

    
class Dataset(object):
    """
    Indexable collection of (image, mask) pairs with joint augmentation.

    Args:
        data: indexable collection whose items unpack into ``(image, mask)``.
        aug: callable invoked as ``aug(image=..., mask=...)`` that returns a
            mapping with ``'image'`` and ``'mask'`` entries.
    """

    def __init__(self, data, aug):
        self.data = data
        self.aug = aug
        self.transformer = Transform()
        self.mask_transform = Mask_Transform()

    def __len__(self):
        """Number of samples, i.e. ``len(self.data)``."""
        return len(self.data)

    def __getitem__(self, idx: int):
        """Return the augmented, transformed ``(image, mask)`` pair stored at ``idx``."""
        raw_image, raw_mask = self.data[idx]
        # Image and mask go through the augmentation together in a single call.
        augmented = self.aug(image=raw_image, mask=raw_mask)
        image = self.transformer(augmented['image'])
        mask = self.mask_transform(augmented['mask'])
        return image, mask

In [None]:
# Geometric (D4 group) augmentations, applied to image and mask together.
#
# Intensity normalisation is intentionally NOT part of this pipeline: Dataset runs every image
# through Transform, which already applies ToTensor + Normalize(MEAN, STD). Normalising here as
# well rescaled the image twice (range [-3, 1] instead of [-1, 1]).
# NOTE(review): this assumes the stored images are uint8 arrays, so that ToTensor maps them to
# [0, 1] before Normalize -- confirm against the contents of the pickle.
train_transforms = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    # A.RandomBrightness(limit = 0.1, p = 0.5),  # optional photometric jitter, currently disabled
])

In [None]:
# Alternative (disabled): read the images from disk per batch. It is slow and sometimes
# raises a PyTorch error that is still being investigated.
#data_read = list(zip(img_paths, mask_paths))
#train_read = Dataset_read(data_read)
# Wrap the in-memory samples; `train_transforms` is passed as the augmentation callable.
train_ds = Dataset(dataset, aug = train_transforms)

In [None]:
# Sanity check: number of samples in the training dataset.
len(train_ds)

In [None]:
epochs = 10
codes = dict(μ=list(), logσ2=list(), y=list())
cur_step = 0          # global step counter across all epochs
display_step = 30     # visualise every `display_step` steps
input_dim = 3
out_dim = 1
target_shape = 256
batch_size = 16
dataloader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

# Per-epoch loss history. The per-batch loss tensor below is named `batch_loss`
# so that it no longer overwrites this dictionary.
loss = {
    'train_loss':[],
    'test_loss' : []
}
for epoch in range(0, epochs + 1):
    # Training -- epoch 0 is skipped so the untrained network gets checkpointed first.
    if epoch > 0:
        model.train()
        train_loss = 0
        batches_seen = 0  # batches processed in THIS epoch; denominator of the running mean
        for image, mask in dataloader:
            # ===================forward=====================
            fake_map, mu, logvar, z = model(image)
            batch_loss = loss_function(fake_map, mask, mu, logvar, cur_step)
            train_loss += batch_loss.item()
            batches_seen += 1
            # ===================backward====================
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            # ===================log========================
            if cur_step % display_step == 0:
                # train_loss is reset every epoch, so it is averaged over the batches of
                # this epoch rather than over the global step counter.
                print(f"Epoch {epoch}: Step {cur_step}: Loss = {train_loss / batches_seen}")
                show_tensor_images(image, size=(input_dim, target_shape, target_shape))
                show_tensor_images(mask, size=(out_dim, target_shape, target_shape))
                show_tensor_images(fake_map, size=(out_dim, target_shape, target_shape))
            cur_step += 1

        if batches_seen:
            loss['train_loss'].append(train_loss / batches_seen)

    # Checkpoint AFTER the epoch's training so that vae_{epoch}.pth holds the weights
    # obtained after `epoch` epochs (vae_0.pth = untrained) and the final epoch is saved too.
    torch.save(model.state_dict(), f'/content/drive/MyDrive/vae_{epoch}.pth')

    # Testing
    means, logvars, labels = list(), list(), list()

    # TODO: evaluation is disabled because no `test_loader` is defined in this notebook.
    # Intended logic once a held-out loader exists:
    #
    # with torch.no_grad():
    #     model.eval()
    #     test_loss = 0
    #     for x, _ in test_loader:
    #         x_hat, mu, logvar, _ = model(x)
    #         test_loss += loss_function(x_hat, x, mu, logvar).item()
    #         means.append(mu.detach())
    #         logvars.append(logvar.detach())
    # codes['μ'].append(torch.cat(means))
    # codes['logσ2'].append(torch.cat(logvars))
    # test_loss /= len(test_loader.dataset)
    # loss['test_loss'].append(test_loss)
    # print(f'====> Test set loss: {test_loss:.4f}')