# Variational Autoencoder with Transfer Learning

In [None]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions
import torchvision
import torch


import numpy as np
import matplotlib.pyplot as plt
from data import get_img_dataset
from torchvision import transforms
from torch.utils.data import TensorDataset, DataLoader, ConcatDataset
from torchvision.datasets import ImageFolder
from project3Lib.transforms import EnhanceContrast
from masked_dataset import MaskedDataset
from pathlib import Path

from torch.nn import CrossEntropyLoss
from torch.optim import Adam, SGD, RMSprop, lr_scheduler
import tqdm
from sklearn.metrics import accuracy_score, f1_score, plot_confusion_matrix, confusion_matrix, ConfusionMatrixDisplay

from project3Lib.VAE import * 
import project3Lib.utils as utils


import random
# Seed Python's, PyTorch's and NumPy's global RNGs (in that order) with 0.
for _seed_rng in (random.seed, torch.manual_seed, np.random.seed):
    _seed_rng(0)


In [None]:
# Use the first CUDA device when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
print('Device state:', device)

## Loading variational autoencoder training data (larger dataset for transfer learning)

In [None]:
# Load the transfer-learning data: every image is served twice, once through
# the augmenting pipeline and once through the plain preprocessing pipeline.

# Deterministic preprocessing shared by both pipelines.
train_transform = [
    transforms.Resize(128),              # shortest side -> 128 pixels
    transforms.CenterCrop(128),          # central 128 x 128 crop
    transforms.ToTensor(),               # PIL image -> tensor
    EnhanceContrast(reduce_dim=False),
]

# Augmenting pipeline: random perturbations first, then the shared steps.
# NOTE(review): ColorJitter() is built with its default arguments - confirm
# that it perturbs colours as intended.
transform = transforms.Compose(
    [
        transforms.RandomRotation(90),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(),
        *train_transform,
    ]
)
train_transform = transforms.Compose(train_transform)

# Same image folder read through both pipelines, then concatenated.
dataset = ImageFolder(root='./data/tl_dataset', transform=transform)
dataset2 = ImageFolder(root='./data/tl_dataset', transform=train_transform)
dataset = ConcatDataset([dataset, dataset2])

batch_size = 16
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Classes for Variational Autoencoder

In [None]:
# Build a 3-channel VAE, move it to the selected device (GPU or CPU) and fit
# it on the transfer-learning loader; the `save` path is the checkpoint file
# that the reload cell reads back.
vae = VariationalAutoencoder(imgChannels=3)
vae = vae.to(device)
vae.train(dataloader, epochs=30, save='trained_weights/vae_trans_final.torch')

In [None]:
# Rebuild the VAE architecture, restore the trained weights and move the model
# to the active device.  map_location remaps the stored tensors onto `device`,
# so a checkpoint written during a CUDA session can still be restored when the
# notebook falls back to the CPU.
vae_loaded = VariationalAutoencoder(imgChannels=3)
vae_state = torch.load('trained_weights/vae_trans_final.torch', map_location=device)
vae_loaded.load_state_dict(vae_state)
vae_loaded = vae_loaded.to(device)

In [None]:
# Presumably plots the latent-space embedding that the restored VAE produces
# for the batches of the transfer-learning loader (plot_latent is not defined
# in this notebook - confirm against project3Lib.VAE).
plot_latent(vae_loaded, dataloader)

## Perturbed Reconstructions

We now take a sample image, compute its latent-space embedding, perturb specific pairs of dimensions in this embedding, and reconstruct images from the perturbed embeddings. 30 such reconstruction grids are shown below:

In [None]:
# One reconstruction grid per latent-dimension pair:
# (0, 1), (16, 17), ..., (464, 465).
for i in range(30):
    first_dim = i * 16
    plot_reconstructed(
        vae_loaded,
        dataloader,
        r0=(-40, 40),
        r1=(-40, 40),
        n=10,
        dims=(first_dim, first_dim + 1),
    )

# Constructing classifier using trained encoder

## Loading project dataset 

In [None]:
# Contrast enhancement is the only step of the plain pipeline.
only_enhance = [EnhanceContrast(reduce_dim=False)]

# Augmenting pipeline: random perturbations followed by the enhancement step,
# used to increase the size and variety of the dataset.
transform = [
    transforms.RandomRotation(90),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(),
    *only_enhance,
]

# Only the folder of unique images is used.
input_path = "data/unique_images"

# Load the splits twice: once with the plain pipeline, once with the
# augmenting one (its test split is discarded).
train_dataset, val_dataset, test_dataset = get_img_dataset(
    only_enhance, data_path=input_path, use_same_transforms=True
)
train_dataset2, val_dataset2, _ = get_img_dataset(
    transform, data_path=input_path, use_same_transforms=True
)

# Train and validation sets hold the plain split plus two passes over the
# augmented split; the test split stays un-augmented.
train_dataset = ConcatDataset([train_dataset] + [train_dataset2] * 2)
val_dataset = ConcatDataset([val_dataset] + [val_dataset2] * 2)

In [None]:
batch_size = 16

# Batch the three splits: training and validation are shuffled, the test
# order is kept fixed.
trainloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
testloader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)
validloader = DataLoader(val_dataset, shuffle=True, batch_size=batch_size)

# Report how many training samples fall into each class.
train_labels = [label for _, label in train_dataset]
class_sizes = np.unique(train_labels, return_counts=True)
print(f"Class sizes{class_sizes}")

In [None]:
# Classifier built on top of the pretrained VAE encoder.  imgChannels is 3 so
# that it matches the 3-channel VAE checkpoint loaded through `vae_path`, the
# 3-channel input images, and the 3-channel model that reloads these weights.
model = Encoder_Classifier(imgChannels=3, vae_path='trained_weights/vae_trans_final.torch').to(device)

epochs = 30

# Stage 1: train on the large transfer-learning dataset.
model.traindata(dataloader, validloader, epochs=epochs)

# Stage 2: continue training on the small project dataset and save the
# resulting weights.
model.traindata(trainloader, validloader, epochs=epochs, save="trained_weights/vae_transfer_fc_final.pt")


In [None]:
# Rebuild the classifier, restore the fine-tuned weights and move the model to
# the active device.  map_location remaps the stored tensors onto `device`, so
# weights saved during a CUDA session also load on a CPU-only machine.
model = Encoder_Classifier(imgChannels=3, vae_path='trained_weights/vae_trans_final.torch')
classifier_state = torch.load('trained_weights/vae_transfer_fc_final.pt', map_location=device)
model.load_state_dict(classifier_state)
model = model.to(device)


In [None]:
# Decode the test set once; every item is an (image, label) pair.
test_samples = list(test_dataset)
x_test = [image for image, _ in test_samples]
y_test = [label for _, label in test_samples]

# model.predict returns a (prediction, output) pair; only the prediction is
# needed for the metrics below.
preds = []
for t in x_test:
    pred, out = model.predict(t.to(device))
    preds.append(pred)

# scikit-learn metrics take the ground truth first: (y_true, y_pred), the same
# order used for the confusion matrix.
print(f"Accuracy: {accuracy_score(y_test, preds)}")
print(f"F1 score: {f1_score(y_test, preds)}")

In [None]:
# Confusion matrix normalised over the true labels, then rendered.
cm = confusion_matrix(y_true=y_test, y_pred=preds, normalize="true")
cmd = ConfusionMatrixDisplay(confusion_matrix=cm)
cmd.plot()

# Interpretability Methods

## Loading Masked Test Set

In [None]:
# Keep a handle on the mask-free test split, then load the splits again with
# the masked dataset class; only the masked test split is kept.
test_dataset_nomasks = test_dataset
common_transform = [EnhanceContrast(reduce_dim=False)]
mask_dir = Path("data/masks")
_, _, test_dataset_mask = get_img_dataset(
    common_transforms=common_transform,
    data_path=input_path,
    folder_type=MaskedDataset,
    mask_folder=mask_dir,
)

# Getting Shapley features for the encoder-based classifier

In [None]:
import shap
import numpy as np

In [None]:
def get_only_oneclass(target, dataset):
    """Return the ``(item, label)`` pairs of ``dataset`` whose label equals ``target``.

    ``dataset`` must yield two-element items; the result is a new list of
    tuples in the original order.
    """
    selected = []
    for item, label in dataset:
        if label == target:
            selected.append((item, label))
    return selected

In [None]:
# Deep Explainer: every training image is stacked into one background batch
# on the active device and handed to SHAP together with the classifier.
bg = torch.stack([image for image, _ in train_dataset]).to(device)
e = shap.DeepExplainer(model, bg)

# Record the two model outputs for each background image.
outs = []
for i in bg:
    pred, out = model.predict(i.to(device))
    first, second = out[0][0].item(), out[0][1].item()
    outs.append((first, second))

mean_first = np.mean([first for first, _ in outs])
mean_second = np.mean([second for _, second in outs])
print(f"Mean values {mean_first}, {mean_second}")

In [None]:
# SHAP attributions for every masked test image, scored against the
# ground-truth mask; only class-1 images contribute to the reported mean.
def _channels_last(batch):
    """Reorder the axes of a 4-D array from (0, 1, 2, 3) to (0, 2, 3, 1)."""
    return np.swapaxes(np.swapaxes(batch, 1, -1), 1, 2)


ious = []
shap_dumps = {0: "Plots/VAE_SHAP_0", 1: "Plots/VAE_SHAP_1"}
for i, (image, mask, target) in enumerate(test_dataset_mask):
    image = image.reshape((1, 3, 128, 128))  # add a leading batch axis
    pred, out = model.predict(image.to(device).squeeze())

    shap_values = e.shap_values(image)
    shap_numpy = [_channels_last(s) for s in shap_values]
    test_numpy = _channels_last(image.cpu().numpy())
    print(f"Image #{i}: True Class {target}, Prediction {pred}, Probabilities {out}")

    shap.image_plot(shap_numpy, test_numpy)

    # Attribution at index 1, compared with the mask replicated over the
    # three image channels.
    predicted_mask = np.copy(shap_values[1].reshape(3, 128, 128))
    mask = torch.stack([mask.reshape((128, 128))] * 3)
    pixels = int(np.sum(mask.numpy().flatten()))
    iou = utils.evaluate_interpretability(predicted_mask, mask, pixels)
    print(iou)
    if target == 1:
        ious.append(iou)
    # The attributions of the first two images are written to disk.
    if i in shap_dumps:
        np.save(shap_dumps[i], predicted_mask)
print(f"Mean IOU: {np.mean(ious)}")

# Integrated Gradients with Captum

In [None]:
# Display one of the encoder's conv layers while choosing a Grad-CAM target.
# NOTE(review): the Grad-CAM cells in this notebook pass encConv2, not this
# layer or "layer 4" as the earlier comment said - confirm which is intended.
model.encoder.encConv6

In [None]:
# Integrated-gradients attributions for every masked test image, scored
# against the ground-truth mask; only class-1 images enter the mean.
ious = []
intgrad_dumps = {0: "Plots/VAE_IntGrad_0", 1: "Plots/VAE_IntGrad_1"}
for i, (image, mask, target) in enumerate(test_dataset_mask):
    data = (image, target)
    a, b = utils.plot_grads(data, model, idx=-1, plot=False, grad_type="integ_grads")
    # Take the first returned map for class-1 images, the second otherwise.
    class_1 = a if target == 1 else b
    predicted_mask = np.copy(class_1.reshape(3, 128, 128).cpu())
    # Replicate the 128 x 128 mask across the three attribution channels.
    mask = torch.stack([mask.reshape((128, 128))] * 3)
    pixels = int(np.sum(mask.numpy().flatten()))
    iou = utils.evaluate_interpretability(predicted_mask, mask, pixels)
    print(iou)
    if target == 1:
        ious.append(iou)
    # The attributions of the first two images are written to disk.
    if i in intgrad_dumps:
        np.save(intgrad_dumps[i], predicted_mask)
print(f"The mean iou is {np.mean(ious)}")

In [None]:
# Run the project's integrated-gradients plotting helper over the mask-free
# test split; save_name="vae" is presumably used to name the saved figures
# (confirm in project3Lib.utils).
utils.plot_grads_dataloader(test_dataset, model, grad_type= "integ_grads" , plot=True, save_name="vae")

# Grad Cam

In [None]:
# Grad-CAM maps (taken at encoder layer encConv2) for every masked test image,
# scored against the ground-truth mask; only class-1 images enter the mean.
ious = []
gradcam_dumps = {0: "Plots/VAE_GradCam_0", 1: "Plots/VAE_GradCam_1"}
for i, (image, mask, target) in enumerate(test_dataset_mask):
    data = (image, target)
    a, b = utils.plot_grads(data, model, layer=model.encoder.encConv2, plot=False, grad_type="grad_cam")
    # Take the first returned map for class-1 images, the second otherwise.
    class_1 = a if target == 1 else b
    cam = class_1.detach().cpu().numpy()
    predicted_mask = np.copy(cam.reshape(128, 128))
    # The map is a single 128 x 128 plane, so the mask is compared as-is
    # rather than stacked over three channels.
    mask = mask.reshape((128, 128))
    pixels = int(np.sum(mask.numpy().flatten()))
    iou = utils.evaluate_interpretability(predicted_mask, mask, pixels)
    print(iou)
    if target == 1:
        ious.append(iou)
    # The maps of the first two images are written to disk.
    if i in gradcam_dumps:
        np.save(gradcam_dumps[i], predicted_mask)
print(f"The mean iou is {np.mean(ious)}")

In [None]:
# Run the project's Grad-CAM plotting helper over the mask-free test split,
# using the same encoder layer (encConv2) as the IoU loop in this section.
utils.plot_grads_dataloader(test_dataset, model, layer=model.encoder.encConv2, grad_type= "grad_cam" ,plot=True)