<a href="https://colab.research.google.com/github/JohannesMDr/vae/blob/master/vae_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# README
* ref: https://github.com/hsinyilin19/ResNetVAE


In [0]:
import os
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable
from torch.utils import data

# dataset

In [0]:
class Dataset(data.Dataset):
    """Image dataset for auto-encoder training.

    Each sample is a dict whose ``'input'`` and ``'recon'`` entries are the
    same (optionally transformed) image, i.e. the reconstruction target is
    the input itself.
    """

    def __init__(self, filenames, labels, transform=None):
        """Store the sample description.

        Args:
            filenames: sequence of image paths, one per sample.
            labels: per-sample labels. Stored on the instance but not part
                of the returned sample.
            transform: optional callable applied to the loaded PIL image.
        """
        self.filenames = filenames
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the total number of samples."""
        return len(self.filenames)

    def __getitem__(self, index):
        """Load, transform and return the sample at ``index``.

        Returns:
            dict with keys ``'input'`` and ``'recon'``, both referring to
            the same image object.
        """
        filename = self.filenames[index]
        # Open inside a context manager so the file handle is released.
        # convert() reads the pixel data before the file is closed and
        # guarantees 3 channels regardless of the file's colour mode.
        with Image.open(filename) as img:
            X = img.convert('RGB')

        if self.transform:
            X = self.transform(X)

        return {
            'input': X,
            'recon': X
        }

    @staticmethod
    def _to_displayable(img):
        """Convert a sample image into an array that ``imshow`` accepts."""
        if not torch.is_tensor(img):
            # e.g. an untransformed PIL image: already height x width x channel
            return np.asarray(img)
        arr = img.detach().cpu().float().numpy()
        if arr.ndim == 3:
            arr = arr.transpose(1, 2, 0)    # CHW -> HWC
            if arr.shape[2] == 1:
                arr = arr[:, :, 0]
        # imshow expects float images in [0, 1]
        return np.clip(arr, 0.0, 1.0)

    @classmethod
    def show_x(cls, ax, sample):
        """Draw ``sample['input']`` on ``ax``; do nothing if the key is absent."""
        if 'input' in sample:
            ax.imshow(cls._to_displayable(sample['input']))
            ax.axis('off')

    @classmethod
    def show_y(cls, ax, sample):
        """Draw ``sample['recon']`` on ``ax``; do nothing if the key is absent."""
        if 'recon' in sample:
            ax.imshow(cls._to_displayable(sample['recon']))
            ax.axis('off')

    def show_sample(self, index, show_y=False, figsize=(16,8)):
        """Plot the sample at ``index``.

        Args:
            index: position of the sample in the dataset.
            show_y: when true, also plot the reconstruction target in a
                second panel.
            figsize: figure size passed to ``plt.subplots``.
        """
        # NOTE(review): ``plt`` (matplotlib.pyplot) is not imported in this
        # file; it must be available in the notebook namespace.
        sample = self[index]
        ncols = 2 if show_y else 1
        fig, ax = plt.subplots(1, ncols, figsize=figsize, squeeze=False)
        self.show_x(ax[0, 0], sample)
        if show_y:
            self.show_y(ax[0, 1], sample)
        plt.show()

In [0]:
def make_dataset(filenames=(), labels=None, val_length=0, transform=None):
    """Split the samples into a training and a validation ``Dataset``.

    The first ``val_length`` samples form the validation set and the
    remaining ones the training set. No shuffling is performed.

    Args:
        filenames: sequence of image paths.
        labels: optional per-sample labels; defaults to ``0`` for every
            sample.
        val_length: number of leading samples reserved for validation.
        transform: optional transform forwarded to both datasets.

    Returns:
        ``(trainset, valset)`` tuple of ``Dataset`` objects.

    Raises:
        ValueError: if ``labels`` and ``filenames`` differ in length or
            ``val_length`` is outside ``[0, len(filenames)]``.
    """
    filenames = list(filenames)
    labels = [0] * len(filenames) if labels is None else list(labels)
    if len(labels) != len(filenames):
        raise ValueError("filenames and labels must have the same length")
    if not 0 <= val_length <= len(filenames):
        raise ValueError("val_length must be between 0 and len(filenames)")
    trainset = Dataset(filenames[val_length:], labels[val_length:], transform)
    valset = Dataset(filenames[:val_length], labels[:val_length], transform)
    return trainset, valset

# NOTE(review): called without arguments this yields two empty datasets;
# pass the real image paths and val_length here.
trainset, valset = make_dataset()

# model

In [0]:
def conv2D_output_size(img_size, padding, kernel_size, stride):
    """Return the (height, width) produced by a 2-D convolution.

    Each of the first two axes is computed as
    ``floor((size + 2 * pad - (kernel - 1) - 1) / stride + 1)``.
    All four arguments are indexable pairs; only entries 0 and 1 are read.
    """
    def _axis_length(axis):
        # Extent covered by the kernel sweep along this axis.
        span = img_size[axis] + 2 * padding[axis] - (kernel_size[axis] - 1) - 1
        return np.floor(span / stride[axis] + 1).astype(int)

    return (_axis_length(0), _axis_length(1))

def convtrans2D_output_size(img_size, padding, kernel_size, stride):
    """Return the (height, width) produced by a 2-D transposed convolution.

    Each of the first two axes is computed as
    ``(size - 1) * stride - 2 * pad + kernel``.
    All four arguments are indexable pairs; only entries 0 and 1 are read.
    """
    height, width = (
        (img_size[axis] - 1) * stride[axis] - 2 * padding[axis] + kernel_size[axis]
        for axis in (0, 1)
    )
    return (height, width)

In [0]:
class ResNet_VAE(nn.Module):
    """Variational auto-encoder with a pretrained ResNet-152 encoder.

    The encoder maps an image batch to the mean and log-variance of a
    ``CNN_embed_dim``-dimensional Gaussian. The decoder expands a latent
    sample through two fully connected layers and three transposed
    convolutions, then resizes the result to 224x224.

    Args:
        fc_hidden1: width of the first fully connected encoder layer.
        fc_hidden2: width of the second fully connected encoder layer and
            of the first decoder layer.
        drop_p: dropout probability. Stored on the instance; dropout is not
            currently applied in ``encode``.
        CNN_embed_dim: dimensionality of the latent space.
    """

    def __init__(self, fc_hidden1=1024, fc_hidden2=768, drop_p=0.3, CNN_embed_dim=256):
        super().__init__()

        self.fc_hidden1, self.fc_hidden2, self.CNN_embed_dim = fc_hidden1, fc_hidden2, CNN_embed_dim
        # Keep the constructor argument instead of silently discarding it.
        self.drop_p = drop_p

        # CNN architectures
        self.ch1, self.ch2, self.ch3, self.ch4 = 16, 32, 64, 128
        self.k1, self.k2, self.k3, self.k4 = (5, 5), (3, 3), (3, 3), (3, 3)      # 2d kernel size
        self.s1, self.s2, self.s3, self.s4 = (2, 2), (2, 2), (2, 2), (2, 2)      # 2d strides
        self.pd1, self.pd2, self.pd3, self.pd4 = (0, 0), (0, 0), (0, 0), (0, 0)  # 2d padding

        # Encoder: pretrained ResNet-152 without its final fc layer
        resnet = models.resnet152(pretrained=True)
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.fc1 = nn.Linear(resnet.fc.in_features, self.fc_hidden1)
        self.bn1 = nn.BatchNorm1d(self.fc_hidden1, momentum=0.01)
        self.fc2 = nn.Linear(self.fc_hidden1, self.fc_hidden2)
        self.bn2 = nn.BatchNorm1d(self.fc_hidden2, momentum=0.01)
        # Latent mean and log-variance heads
        self.fc3_mu = nn.Linear(self.fc_hidden2, self.CNN_embed_dim)
        self.fc3_logvar = nn.Linear(self.fc_hidden2, self.CNN_embed_dim)

        # Latent sample -> 64 x 4 x 4 feature map
        self.fc4 = nn.Linear(self.CNN_embed_dim, self.fc_hidden2)
        self.fc_bn4 = nn.BatchNorm1d(self.fc_hidden2)
        self.fc5 = nn.Linear(self.fc_hidden2, 64 * 4 * 4)
        self.fc_bn5 = nn.BatchNorm1d(64 * 4 * 4)
        self.relu = nn.ReLU(inplace=True)

        # Decoder: with kernel 3, stride 2 and no padding the spatial size
        # grows 4 -> 9 -> 19 -> 39 across the three transposed convolutions.
        self.convTrans6 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=self.k4, stride=self.s4,
                               padding=self.pd4),
            nn.BatchNorm2d(32, momentum=0.01),
            nn.ReLU(inplace=True),
        )
        self.convTrans7 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=32, out_channels=8, kernel_size=self.k3, stride=self.s3,
                               padding=self.pd3),
            nn.BatchNorm2d(8, momentum=0.01),
            nn.ReLU(inplace=True),
        )

        self.convTrans8 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=8, out_channels=3, kernel_size=self.k2, stride=self.s2,
                               padding=self.pd2),
            nn.BatchNorm2d(3, momentum=0.01),
            nn.Sigmoid()    # y = (y1, y2, y3) \in [0 ,1]^3
        )

    def encode(self, x):
        """Return ``(mu, logvar)`` of the latent Gaussian for the batch ``x``.

        Assumes ``x`` is a (batch, 3, H, W) float tensor as expected by
        torchvision ResNets -- TODO confirm against the data pipeline.
        """
        x = self.resnet(x)
        x = x.view(x.size(0), -1)  # flatten the pooled ResNet features

        x = self.relu(self.bn1(self.fc1(x)))
        x = self.relu(self.bn2(self.fc2(x)))
        # Dropout is not applied here; ``self.drop_p`` is available should
        # F.dropout(x, p=self.drop_p, training=self.training) be enabled.
        return self.fc3_mu(x), self.fc3_logvar(x)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` in training mode; return ``mu`` otherwise.

        ``std`` is ``exp(0.5 * logvar)`` and ``eps`` is standard normal
        noise with the same shape, dtype and device as ``std``.
        """
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors ``z`` to a (batch, 3, 224, 224) tensor in [0, 1]."""
        x = self.relu(self.fc_bn4(self.fc4(z)))
        x = self.relu(self.fc_bn5(self.fc5(x))).view(-1, 64, 4, 4)
        x = self.convTrans6(x)
        x = self.convTrans7(x)
        x = self.convTrans8(x)
        # align_corners=False is the default; stating it avoids the warning
        # PyTorch emits when it is left unspecified.
        x = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
        return x

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            dict with keys ``'recon'`` (reconstruction), ``'z'`` (latent
            sample), ``'mu'`` and ``'logvar'``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_reconst = self.decode(z)

        return {
            'recon': x_reconst,
            'z': z,
            'mu': mu,
            'logvar': logvar
        }

In [0]:
def KL_div(mu, logvar):
    """Return the KL divergence of N(mu, exp(logvar)) from N(0, 1).

    The result is summed over every element of the inputs.
    """
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

class Loss(nn.Module):
    """VAE objective: summed binary cross-entropy plus KL divergence."""

    def __init__(self):
        super().__init__()
        # Store the function objects themselves. Calling
        # F.binary_cross_entropy() here with no arguments raises TypeError.
        self.crit_bce = F.binary_cross_entropy
        self.crit_kld = KL_div

    def forward(self, pred, batch):
        """Compute the total loss and its components.

        Args:
            pred: model output dict with keys ``'recon'``, ``'mu'`` and
                ``'logvar'``.
            batch: input dict whose ``'recon'`` entry is the target.

        Returns:
            ``(loss_all, loss)`` where ``loss`` is a dict with the
            ``'bce'`` and ``'kld'`` terms and ``loss_all`` is their sum.
        """
        loss = {
            'bce': self.crit_bce(pred['recon'], batch['recon'], reduction='sum'),
            'kld': self.crit_kld(pred['mu'], pred['logvar']),
        }
        loss_all = loss['bce'] + loss['kld']
        return loss_all, loss

In [0]:
class Trainer:
    """Run training / validation epochs for a model and visualise its results.

    Args:
        model: module called on the prepared input batch.
        opt: optimizer stepping the model parameters.
        train_ld: iterable of training batches (dicts with an ``'input'`` key).
        val_ld: iterable of validation batches.
        loss_fn: callable ``(pred, batch) -> (loss, stats_dict)``; a fresh
            ``Loss()`` is created when omitted.
        device: device the model and batches are moved to.
    """

    def __init__(self, model, opt, train_ld, val_ld, loss_fn=None, device='cuda'):
        self.model = model.to(device)
        # Build the default lazily: a ``Loss()`` default argument would be
        # evaluated once at class-definition time and shared by every Trainer.
        self.loss_fn = Loss() if loss_fn is None else loss_fn
        self.opt = opt
        self.train_ld = train_ld
        self.val_ld = val_ld
        self.device = device

    def _prepare_batch(self, batch):
        """Move ``batch`` to the device in place and return the model input.

        NOTE(review): every entry is right-padded by 1 along its last
        dimension and the model input by a further 3. ``ResNet_VAE.decode``
        always emits 224x224 images, so the padded ``'recon'`` target only
        matches the prediction for 223-pixel-wide inputs. Confirm the
        padding is intended or remove it.
        """
        for k in batch:
            batch[k] = F.pad(batch[k].to(self.device), (0,1))
        return F.pad(batch['input'], (0,3))

    def epoch(self, pbar, training=True):
        """Run one pass over ``pbar`` and return per-sample averages.

        Args:
            pbar: iterable of batches; its ``comment`` attribute is set to
                the latest batch loss.
            training: when true, back-propagate and step the optimizer.

        Returns:
            ``(loss, stats)``: the summed loss divided by the number of
            samples and a dict of the likewise averaged loss components.
            Both are zero / empty when ``pbar`` yields no batches.
        """
        n_samples = 0
        stat_log = defaultdict(float)
        loss_log = 0.0
        for batch in pbar:
            n_samples += batch['input'].size(0)
            x = self._prepare_batch(batch)
            y = self.model(x)
            loss, stat = self.loss_fn(y, batch)
            if training:
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()

            loss = loss.item()
            loss_log += loss
            for k, v in stat.items():
                stat_log[k] += v.item()
            pbar.comment = '{:.4f}'.format(loss)

        # An empty loader would otherwise divide by zero.
        if n_samples == 0:
            return loss_log, stat_log
        # Iterate the accumulator, not the last batch's ``stat``, which is
        # undefined when the loop body never ran.
        for k in stat_log:
            stat_log[k] /= n_samples
        loss_log /= n_samples
        return loss_log, stat_log

    def train(self, mb):
        """Run one training epoch, reporting progress under the master bar ``mb``."""
        self.model.train()
        pbar = progress_bar(self.train_ld, parent=mb)
        return self.epoch(pbar, training=True)

    def val(self, mb):
        """Run one validation epoch without gradient tracking."""
        self.model.eval()
        with torch.no_grad():
            pbar = progress_bar(self.val_ld, parent=mb)
            return self.epoch(pbar, training=False)

    def show_results(self, num=3, figsize=(8,8)):
        """Plot up to ``num`` validation samples: prediction left, target right."""
        # squeeze=False keeps ``ax`` two-dimensional even when num == 1.
        fig, ax = plt.subplots(num, 2, figsize=figsize, squeeze=False)
        self.model.eval()
        show_x = self.val_ld.dataset.show_x
        show_y = self.val_ld.dataset.show_y
        n = 0
        with torch.no_grad():
            for batch in self.val_ld:
                if n >= num:
                    break
                bs = batch['input'].size(0)
                x = self._prepare_batch(batch)
                y = self.model(x)
                for b in range(min(bs, num - n)):
                    sample_inp = {'input': batch['input'][b]}
                    sample_pred = {k: v[b] for k, v in y.items()}
                    sample_batch = {k: v[b] for k, v in batch.items()}
                    ax[n,0].set_title("pred")
                    show_x(ax[n,0], sample_inp)
                    show_y(ax[n,0], sample_pred)
                    ax[n,1].set_title("actual")
                    show_x(ax[n,1], sample_inp)
                    show_y(ax[n,1], sample_batch)
                    n += 1
        plt.show()

In [0]:
# Build the model and the data loaders.
model = ResNet_VAE()
trainset, valset = make_dataset()
trainloader = data.DataLoader(trainset, batch_size=8, shuffle=True, drop_last=True)
valloader = data.DataLoader(valset, batch_size=8, shuffle=False)

# Freeze the pretrained ResNet encoder. With train_bn set, its batch-norm
# layers stay trainable.
train_bn = True
bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
for layer in model.resnet.modules():
    if train_bn and isinstance(layer, bn_types): continue
    # recurse=False touches only this layer's own parameters, so skipping a
    # batch-norm layer is not undone by freezing its parent container.
    for param in layer.parameters(recurse=False):
        param.requires_grad = False

# NOTE(review): base_lr was never defined in this file; 1e-3 is AdamW's
# default learning rate. Confirm the intended value.
base_lr = 1e-3
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad),
    base_lr,
    weight_decay=1e-3)

trainer = Trainer(model, optimizer, trainloader, valloader)

In [0]:
# Plot three validation samples: model prediction beside the target.
trainer.show_results(num=3, figsize=(12, 8))

In [0]:
def _format_stats(stat):
    """Render a stats dict as ', key:value' fragments for the progress log."""
    return "".join(", {}:{:.3e}".format(k, v) for k, v in stat.items())


epoch = 40
drop = [10]     # epochs after which the learning rate is reduced
mb = master_bar(range(1, epoch+1))
mb.names = ['train', 'val']

tloss_list = []
vloss_list = []
best_vloss = None
for e in mb:
    tloss, stat = trainer.train(mb)
    tloss_list.append(tloss)
    mb.write("train {}, loss {:.4f}".format(e, tloss) + _format_stats(stat))

    vloss, stat = trainer.val(mb)
    vloss_list.append(vloss)
    # Report the validation loss here, not the training loss.
    mb.write("val {}, loss {:.4f}".format(e, vloss) + _format_stats(stat))

    # Checkpoint whenever the validation loss improves.
    if best_vloss is None or vloss < best_vloss:
        best_vloss = vloss
        # NOTE(review): the original path was the placeholder
        # "~~/best_model"; path_model is the directory the history plot is
        # also saved to. Confirm it is defined.
        torch.save(model.state_dict(), os.path.join(path_model, "best_model"))

    if e in drop:
        # Multiply the base rate by 0.2 once for every drop reached so far.
        lr = base_lr * (0.2 ** (drop.index(e)+1))
        mb.write('Drop LR to {}'.format(lr))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

In [0]:
# Plot the per-epoch training and validation losses and save the figure.
epoch_axis = np.arange(1, epoch + 1)
plt.plot(epoch_axis, tloss_list, label="train")
plt.plot(epoch_axis, vloss_list, label="val")
plt.legend()
history_path = path_model + '/history.png'
plt.savefig(history_path)
plt.show()

In [0]:
# Plot three validation samples again, using the model's current weights.
trainer.show_results(num=3, figsize=(12, 8))