<a href="https://colab.research.google.com/github/GohVh/vae/blob/main/vae_github.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Load dataset

In [None]:
# IPython line magic: delete the /content/sample_data directory (no error if it is absent, because of -f).
%rm -rf '/content/sample_data'

In [None]:
# Mount Google Drive at /content/gdrive so the dataset can be copied into the runtime.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
# TEMPLATE PLACEHOLDER: the quoted text below is not a real path. Replace it with
# the source and destination of your dataset before running this cell.
%cp -av 'copy your dataset from gdrive'

In [None]:
!mkdir '/content/model'
!mkdir '/content/result'

# Initialize and import package

In [None]:
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision
from torch import nn, Tensor
from torchvision import utils
from torchvision import transforms
import torchvision.models as models
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader, Subset
import os
from PIL import Image
import numpy as np
from tqdm import tqdm
from time import sleep
import shutil
import pandas as pd
from torchvision.io import read_image
from torchsummary import summary

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Design DL model

## Variational Autoencoders (ResNet18Encoder-ResNet18Decoder)

In [None]:
class ResizeConv2d(nn.Module):
    """Upsample a feature map with ``F.interpolate`` and then convolve it.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Convolution kernel size (int or tuple of ints).
        scale_factor: Spatial scale factor passed to ``F.interpolate``.
        mode: Interpolation mode passed to ``F.interpolate``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, scale_factor, mode='nearest'):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        # Derive the padding from the kernel size so odd kernels keep the
        # (upsampled) spatial size. For kernel_size=3 this is still padding=1,
        # the value that used to be hard-coded.
        if isinstance(kernel_size, int):
            padding = kernel_size // 2
        else:
            padding = tuple(k // 2 for k in kernel_size)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=padding)

    def forward(self, x):
        """Return ``conv(interpolate(x))``."""
        x = F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode)
        return self.conv(x)

class ResNet18Enc(nn.Module):
    """VAE encoder built on torchvision's ResNet-18 (loaded with ``pretrained=True``).

    The backbone's final fully connected layer is replaced by a linear layer
    with ``2 * z_dim`` outputs: the first ``z_dim`` columns are returned as the
    mean, the remaining ``z_dim`` columns as the log-variance.
    """

    def __init__(self, z_dim=32):
        super().__init__()
        self.z_dim = z_dim
        backbone = models.resnet18(pretrained=True)
        self.num_features = backbone.fc.in_features
        # Swap the classification head for one that emits (mu, logvar) stacked
        # along the feature dimension.
        backbone.fc = nn.Linear(self.num_features, 2 * self.z_dim)
        self.ResNet18 = backbone

    def forward(self, x):
        """Return ``(mu, logvar)``, each with ``z_dim`` columns."""
        stats = self.ResNet18(x)
        # The head has exactly 2 * z_dim outputs, so this yields two chunks.
        mu, logvar = stats.split(self.z_dim, dim=1)
        return mu, logvar

class BasicBlockDec(nn.Module):
    """Residual decoder block (two 3x3 convolutions plus a shortcut).

    With ``stride == 1`` the block keeps both channel count and resolution and
    the shortcut is the identity. Otherwise the second convolution and the
    shortcut are ``ResizeConv2d`` layers that upsample by ``stride`` and reduce
    the channels to ``int(in_planes / stride)``.
    """

    def __init__(self, in_planes, stride=1):
        super().__init__()

        out_planes = int(in_planes / stride)
        upsample = stride != 1

        self.conv2 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(in_planes)

        # Submodules are created in the same order as before
        # (conv2, bn2, conv1, bn1, shortcut) so printing and seeded
        # initialisation are unchanged.
        if upsample:
            self.conv1 = ResizeConv2d(in_planes, out_planes, kernel_size=3, scale_factor=stride)
        else:
            self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_planes)

        if upsample:
            self.shortcut = nn.Sequential(
                ResizeConv2d(in_planes, out_planes, kernel_size=3, scale_factor=stride),
                nn.BatchNorm2d(out_planes),
            )
        else:
            # An empty Sequential returns its input unchanged.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(bn1(conv1(relu(bn2(conv2(x))))) + shortcut(x))``."""
        main = torch.relu(self.bn2(self.conv2(x)))
        main = self.bn1(self.conv1(main))
        return torch.relu(main + self.shortcut(x))

class ResNet18Dec(nn.Module):
    """ResNet-18 style decoder: latent vector -> ``(batch, nc, 224, 224)`` image.

    Args:
        num_Blocks: Number of ``BasicBlockDec`` blocks in layer1..layer4
            (indexed 0..3).
        z_dim: Size of the latent vector.
        nc: Number of output image channels.
    """

    def __init__(self, num_Blocks=(2, 2, 2, 2), z_dim=32, nc=3):
        super().__init__()
        self.in_planes = 512
        # Remembered so forward() can reshape to the configured channel count
        # instead of a hard-coded 3.
        self.nc = nc

        self.linear = nn.Linear(z_dim, 512)

        self.layer4 = self._make_layer(BasicBlockDec, 256, num_Blocks[3], stride=2)
        self.layer3 = self._make_layer(BasicBlockDec, 128, num_Blocks[2], stride=2)
        self.layer2 = self._make_layer(BasicBlockDec, 64, num_Blocks[1], stride=2)
        self.layer1 = self._make_layer(BasicBlockDec, 64, num_Blocks[0], stride=1)
        self.conv1 = ResizeConv2d(64, nc, kernel_size=3, scale_factor=2)

    def _make_layer(self, BasicBlockDec, planes, num_Blocks, stride):
        """Stack ``num_Blocks`` blocks; the strided (upsampling) block comes last.

        Updates ``self.in_planes`` to ``planes`` for the next layer.
        """
        strides = [stride] + [1] * (num_Blocks - 1)
        # reversed(): stride-1 blocks first, the upsampling block at the end.
        layers = [BasicBlockDec(self.in_planes, block_stride) for block_stride in reversed(strides)]
        self.in_planes = planes
        return nn.Sequential(*layers)

    def forward(self, z):
        """Decode latent vectors ``z`` into images with values in [0, 1] (sigmoid output)."""
        x = self.linear(z)
        x = x.view(z.size(0), 512, 1, 1)
        x = F.interpolate(x, scale_factor=7)  # 1x1 -> 7x7
        x = self.layer4(x)
        x = self.layer3(x)
        x = self.layer2(x)
        x = self.layer1(x)
        # align_corners=False is the default behaviour; passing it explicitly
        # silences the interpolate() warning without changing the result.
        x = F.interpolate(x, size=(112, 112), mode='bilinear', align_corners=False)
        x = torch.sigmoid(self.conv1(x))
        # Shape check / reshape using the configured channel count.
        x = x.view(x.size(0), self.nc, 224, 224)
        return x


In [None]:
class VAE(nn.Module):
    """Variational autoencoder: ``ResNet18Enc`` encoder + ``ResNet18Dec`` decoder.

    Args:
        z_dim: Size of the latent vector shared by encoder and decoder
            (defaults to 32, the default of both sub-modules).
    """

    def __init__(self, z_dim=32):
        super().__init__()
        self.encoder = ResNet18Enc(z_dim=z_dim)
        self.decoder = ResNet18Dec(z_dim=z_dim)

    def forward(self, x):
        """Return ``(reconstruction, mean, logvar)`` for the input batch ``x``."""
        mean, logvar = self.encoder(x)
        z = self.reparameterize(mean, logvar)
        x = self.decoder(z)
        return x, mean, logvar

    @staticmethod
    def reparameterize(mean, logvar):
        """Sample ``z = mean + std * eps`` with ``eps ~ N(0, I)``.

        The noise is created with ``randn_like`` so it always has the device
        and dtype of ``logvar``; no module-level ``device`` global is needed.
        """
        std = torch.exp(logvar / 2)  # in log-space, square root is divide by two
        epsilon = torch.randn_like(std)
        return epsilon * std + mean


In [None]:
def loss_func(recon_x, x, mu, logvar):
    """Return the VAE loss: summed binary cross-entropy plus KL divergence.

    Args:
        recon_x: Reconstructed batch (probabilities, same shape as ``x``).
        x: Target batch.
        mu: Latent means.
        logvar: Latent log-variances.

    Returns:
        Scalar tensor ``BCE + KLD``, both summed over all elements.
    """
    reconstruction = F.binary_cross_entropy(input=recon_x, target=x, reduction='sum')
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * kl_terms.sum()
    return reconstruction + kl_divergence

# Customized dataloaders and checkpoint saving

In [None]:
# Model checkpoint saving and loading
def save_checkpoint(state, is_best, checkpoint_path, best_model_path):
    """Write ``state`` to ``checkpoint_path``; also copy it to ``best_model_path`` if ``is_best``.

    Args:
        state: Object to serialise with ``torch.save``.
        is_best: When true, the freshly written checkpoint file is copied to
            ``best_model_path``.
        checkpoint_path: Destination of the regular checkpoint.
        best_model_path: Destination of the best-model copy.
    """
    torch.save(state, checkpoint_path)
    if not is_best:
        return
    shutil.copyfile(checkpoint_path, best_model_path)

def load_model(best_model_path, checkpoint_path, model, optimizer, checkpoint_type, train_loss_key, val_loss_key):
    """Restore ``model`` and ``optimizer`` from a saved checkpoint.

    Args:
        best_model_path: Path of the best-model checkpoint.
        checkpoint_path: Path of the most recent checkpoint.
        model: Module whose ``state_dict`` is restored in place.
        optimizer: Optimizer whose ``state_dict`` is restored in place.
        checkpoint_type: ``'current'`` restores from ``checkpoint_path``;
            ``'best'`` restores from ``best_model_path``.
        train_loss_key: Checkpoint key holding the training loss.
        val_loss_key: Checkpoint key holding the validation loss.

    Returns:
        ``(model, optimizer, epoch, trainloss, valloss, min_valloss)`` where
        ``min_valloss`` is always the validation loss of the best checkpoint.

    Raises:
        ValueError: If ``checkpoint_type`` is neither ``'current'`` nor ``'best'``
            (previously this surfaced as an ``UnboundLocalError``).
    """
    if checkpoint_type not in ('current', 'best'):
        raise ValueError(
            f"checkpoint_type must be 'current' or 'best', got {checkpoint_type!r}")

    # Load onto the CPU so checkpoints written on a GPU machine can be read on
    # a CPU-only one; load_state_dict copies the values into the existing
    # parameters / optimizer state.
    best_ckp = torch.load(best_model_path, map_location='cpu')
    if checkpoint_type == 'current':
        source_ckp = torch.load(checkpoint_path, map_location='cpu')
    else:
        # 'best' does not need the current checkpoint file at all.
        source_ckp = best_ckp

    model.load_state_dict(source_ckp['state_dict'])
    optimizer.load_state_dict(source_ckp['optimizer'])
    epoch = source_ckp['epoch']
    trainloss = source_ckp[train_loss_key]
    valloss = source_ckp[val_loss_key]
    min_valloss = best_ckp[val_loss_key]

    return model, optimizer, epoch, trainloss, valloss, min_valloss

def load_checkpoint(best_model_path, checkpoint_path, model, optimizer, checkpoint_type):
    """Restore model/optimizer through ``load_model`` and print a short summary.

    The losses are read from the ``'train loss'`` and ``'val loss'`` checkpoint
    keys. The stored epoch is printed but not returned.

    Returns:
        ``(model, optimizer, trainloss, valloss, min_valloss)``.
    """
    restored = load_model(
        best_model_path, checkpoint_path, model, optimizer,
        checkpoint_type, 'train loss', 'val loss',
    )
    model, optimizer, epoch, trainloss, valloss, min_valloss = restored

    print("optimizer = ", optimizer)
    print("start_epoch = ", epoch)
    print(f'train loss -> {trainloss}')
    print(f'val loss -> {valloss}')
    print(f'min val loss -> {min_valloss}')

    return model, optimizer, trainloss, valloss, min_valloss

In [None]:
# Splitting dataset for Handler option 1: using torchvision's default ImageFolder, for images that are
# categorized in class-specific folders. For example, classification tasks: dog, cat, zebra, etc.
def train_val_dataset(dataset, val_split=0.3):
    """Randomly split ``dataset`` into train and validation subsets.

    Args:
        dataset: Any indexable dataset supporting ``len()``.
        val_split: Fraction of samples (float in (0, 1)) or absolute number of
            samples (int) to place in the validation subset. A fraction is
            rounded up.

    Returns:
        Dict with ``'train'`` and ``'val'`` ``Subset`` objects that together
        cover every index exactly once.

    Raises:
        ValueError: If the split would leave either subset empty.
    """
    import math  # local import: keeps the file's top-level import block untouched

    n_samples = len(dataset)
    if isinstance(val_split, float):
        n_val = math.ceil(val_split * n_samples)
    else:
        n_val = int(val_split)
    if not 0 < n_val < n_samples:
        raise ValueError(
            f'val_split={val_split!r} leaves an empty train or val set for {n_samples} samples')

    # The original called sklearn's train_test_split, which this file never
    # imports; a random permutation from torch gives the same kind of split.
    shuffled = torch.randperm(n_samples).tolist()
    val_idx, train_idx = shuffled[:n_val], shuffled[n_val:]
    return {
        'train': Subset(dataset, train_idx),
        'val': Subset(dataset, val_idx),
    }

# Handler option 2: customized dataset handler, for images that are not categorized in class-specific
# folders. For example, regression tasks: house price prediction from a house image.
# The handler obtains the image file paths from a dataframe column.
class CustomImageDataset(Dataset):
    """Dataset that reads images from file paths listed in a dataframe column.

    Args:
        dataframe: Table containing the image paths.
        path_column: Name of the column holding the image file paths. The
            default is the original template placeholder and must be replaced
            (or overridden here) for real data.
        transform: Callable applied to each image tensor. Defaults to
            ``CenterCrop(224)``.
    """

    def __init__(self, dataframe, path_column='your label here', transform=None):
        self.dataframe = dataframe
        # Double brackets keep a one-column DataFrame so .iloc[index, 0] works.
        self.list_ID = self.dataframe[[path_column]]
        if transform is None:
            transform = transforms.Compose([transforms.CenterCrop(224)])
        self.transform = transform

    def __len__(self):
        return len(self.list_ID)

    def __getitem__(self, index):
        """Return the transformed image at ``index``, scaled by 1/255."""
        data_path = self.list_ID.iloc[index, 0]
        # NOTE(review): read_image presumably yields a uint8 CxHxW tensor, so
        # dividing by 255 maps it to [0, 1] — confirm for your image files.
        image = read_image(data_path) / 255
        return self.transform(image)

In [None]:
# Define directories and datasets.
# Pick ONE dataset handler. Previously both options ran back to back, so
# option 2 silently overwrote option 1 (and option 1 still had to succeed).
dataset_handler = 'csv'  # 'folder' -> Handler option 1, 'csv' -> Handler option 2

if dataset_handler == 'folder':
    # Handler option 1: torchvision ImageFolder.
    # NOTE(review): ImageFolder batches are (image, label) pairs, so the
    # training/test loops must use only the image element — confirm before
    # switching to this option.
    trainroot = "/content/train/"
    testroot = "/content/test/"
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.CenterCrop(224)])
    traindata = ImageFolder(trainroot, transform=transform)
    testdata = ImageFolder(testroot, transform=transform)
elif dataset_handler == 'csv':
    # Handler option 2: image paths listed in CSV files.
    trainroot = '/content/train.csv'
    testroot = '/content/test.csv'
    traindata = CustomImageDataset(pd.read_csv(trainroot))
    testdata = CustomImageDataset(pd.read_csv(testroot))
else:
    raise ValueError(f"dataset_handler must be 'folder' or 'csv', got {dataset_handler!r}")

# Only train and validation datasets are prepared; the test set is the validation set.
print(f'train set: {len(traindata)}')
print(f'val/test set: {len(testdata)}')

# Model checkpoint locations.
checkpoint_path = '/content/model/checkpoint.pth'
best_model_path = '/content/model/bestmodel.pth'

# Initialize
# Per-epoch loss log.
log_df = pd.DataFrame({'trainloss': [], 'valloss': []})
epoch_num = 50
batch_size = 10
z_dim = 32  # latent size; matches the default of ResNet18Enc / ResNet18Dec
# The original instantiated an undefined `ResUnet18`; the model defined in this
# notebook is VAE.
model = VAE(z_dim=z_dim).to(device)
# optimizer = optim.SGD(model.parameters(), lr=1e-2)
optimizer = optim.Adam(model.parameters(), lr=1e-5)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
# The best model is selected by validation loss. np.inf replaces np.Inf,
# an alias that NumPy 2.0 removed.
min_valloss = np.inf

train_loader = DataLoader(traindata, batch_size, shuffle=True)
# Validation/test data is not shuffled so the saved result images keep a stable order.
test_loader = DataLoader(testdata, batch_size, shuffle=False)

train set: 2280
val/test set: 1120


In [None]:
'training status key'
# 'init'     -> train from scratch; no checkpoint is loaded
# 'continue' -> restore model/optimizer state from a saved checkpoint before training
training_status = 'init'

'checkpoint_type'
# 'current' -> restore from the most recently written checkpoint
# 'best'    -> restore from the checkpoint with the best validation loss (used for test)
checkpoint_type = 'current'
# Optionally override the epoch count here, e.g. epoch_num = 33

In [None]:
# Restore the previous run's state only when resuming was requested.
if training_status == 'continue':
    (model, optimizer, trainloss, valloss, min_valloss) = load_checkpoint(
        best_model_path=best_model_path,
        checkpoint_path=checkpoint_path,
        model=model,
        optimizer=optimizer,
        checkpoint_type=checkpoint_type,
    )

# Train

In [None]:
# Train and validate once per epoch.
# NOTE(review): load_checkpoint does not return the stored epoch, so a resumed
# run ('continue') still counts from epoch 1 here.
for epoch in range(epoch_num):
    cumutrainloss, cumuvalloss = 0.0, 0.0

    # ---- training pass ----
    model.train()
    with tqdm(train_loader) as tepoch:
        tepoch.set_description(f"Epoch {epoch+1}")
        for x in tepoch:
            # ImageFolder-style loaders yield (image, label); keep the image only.
            if isinstance(x, (list, tuple)):
                x = x[0]
            x = x.to(device)

            optimizer.zero_grad()
            recon_x, mu, logvar = model(x)
            loss = loss_func(recon_x, x, mu, logvar)
            loss.backward()
            optimizer.step()

            # Accumulate a Python float: summing the loss tensors themselves
            # keeps every batch's autograd history alive for the whole epoch.
            batch_loss = loss.item()
            cumutrainloss += batch_loss
            tepoch.set_postfix(train_loss=batch_loss)
            # The former sleep(0.1) per batch only slowed training and was removed.

    scheduler.step()

    # ---- validation pass ----
    model.eval()
    with torch.no_grad():
        with tqdm(test_loader) as vepoch:
            for x in vepoch:
                if isinstance(x, (list, tuple)):
                    x = x[0]
                x = x.to(device)

                # One forward pass per batch (the original ran the model twice).
                vresult, vmu, vlogvar = model(x)
                batch_loss = loss_func(vresult, x, vmu, vlogvar).item()

                cumuvalloss += batch_loss
                vepoch.set_postfix(val_loss=batch_loss)

    trainloss = cumutrainloss / len(train_loader)
    valloss = cumuvalloss / len(test_loader)

    # DataFrame.append was removed in pandas 2.0; build the row and concat.
    new_row = pd.DataFrame([{'trainloss': trainloss, 'valloss': valloss}])
    log_df = new_row if log_df.empty else pd.concat([log_df, new_row], ignore_index=True)

    print(f'Epoch [{epoch+1}/{epoch_num}]: Train loss= {trainloss:.4f}, Val loss= {valloss:.4f}')
    checkpoint = {
        'epoch': epoch + 1,
        'train loss': trainloss,
        'val loss': valloss,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict()}

    # Always write the current checkpoint; also copy it as the best model when
    # the validation loss has not gotten worse.
    is_best = valloss <= min_valloss
    if is_best:
        print(f'Validation loss decreased ({min_valloss:.4f} --> {valloss:.4f}). Saving model ...')
    save_checkpoint(checkpoint, is_best, checkpoint_path, best_model_path)
    if is_best:
        min_valloss = valloss

    log_df.astype('float32').to_csv('/content/log.csv')

print('Finished training')

# Test

In [None]:
# Test using the validation set; a separate test set could be prepared instead.
# Each batch of reconstructions is saved as one image grid named after its batch index.
model.eval()  # make sure BatchNorm uses running statistics even if training was skipped
with torch.no_grad():
    for i, x in enumerate(test_loader):
        # ImageFolder-style loaders yield (image, label); keep the image only.
        if isinstance(x, (list, tuple)):
            x = x[0]
        x = x.to(device)
        vresult, vmu, vlogvar = model(x)
        utils.save_image(vresult, f'/content/result/{i}.png', normalize=True)