In [None]:
%load_ext autoreload
%autoreload 2
%load_ext tensorboard

In [None]:
import numpy as np
import torch
from torch import nn
# Train on the first CUDA GPU when available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Seed torch's RNGs (CPU and CUDA) once, up front: weight initialization, DataLoader shuffling,
# latent sampling and dropout/noise all draw from them, so runs become repeatable.
# Some GPU kernels may still be nondeterministic.
SEED = 0
torch.manual_seed(SEED)

In [None]:
from torchvision.datasets import MNIST, FashionMNIST, CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize
from torch.utils.data import DataLoader

class Binarize:
    """Stochastically binarize a tensor of per-element probabilities.

    Each element ``p`` (expected to lie in ``[0, 1]``) is replaced by a draw from
    Bernoulli(p): 1 with probability ``p``, otherwise 0. Usable as a torchvision
    transform after ``ToTensor()``; it is not part of the current ``transform``.
    """

    def __call__(self, sample):
        binarized = torch.bernoulli(sample)
        return binarized

# Plain [0, 1] tensors; the Encoder normalizes its input itself with a BatchNorm layer.
transform = Compose([ToTensor()])

DATA_ROOT = './data'
BATCH_SIZE = 64
NUM_WORKERS = 4
# Pinned host memory only speeds up host-to-GPU copies; on CPU-only machines
# DataLoader just warns about it, so enable it only when CUDA is present.
PIN_MEMORY = torch.cuda.is_available()

dataset_train = CIFAR10(DATA_ROOT, train=True, transform=transform, download=True)
dataset_test = CIFAR10(DATA_ROOT, train=False, transform=transform, download=True)

loader_train = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True,
                          pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS)
loader_test = DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False,
                         pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS)

# Phase name -> loader, as consumed by run_epoch / simple_trainer.
loaders = {
    'train': loader_train,
    'test': loader_test,
}

In [None]:
class ConvVAE(nn.Module):
    """Variational autoencoder built from a Gaussian encoder and a decoder.

    Args:
        encoder: module mapping an input batch to ``(mu, sigma)``, the mean and
            standard deviation of the approximate posterior over latent codes.
        decoder: module mapping latent codes back to data space.
        device: device on which ``sample_data`` draws prior samples.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, x):
        """Encode ``x``, draw one reparameterized latent sample, and decode it.

        Returns:
            ``(reconstruction, mu, sigma)``.
        """
        mu, sigma = self.encoder(x)
        z = self.sample_latent(mu, sigma)
        dec = self.decoder(z)
        return dec, mu, sigma

    def sample_latent(self, mu, sigma):
        """Reparameterization: ``z = mu + sigma * eps`` with ``eps ~ N(0, 1)`` elementwise.

        ``randn_like`` creates ``eps`` with ``mu``'s shape, dtype and device, so the
        noise follows the model even after ``.to()``. Gradients reach ``mu`` and
        ``sigma``; ``eps`` itself is a constant.
        """
        eps = torch.randn_like(mu)
        return mu + eps * sigma

    def sample_data(self, sample_size, sample=None):
        """Decode latent codes into data space.

        Args:
            sample_size: shape of the latent batch drawn from the standard normal
                prior, e.g. ``(16, 256)``; ignored when ``sample`` is given.
            sample: optional latent codes to decode instead of sampling the prior.

        Returns:
            The decoder output for the latent codes.
        """
        if sample is None:
            z = torch.randn(sample_size, device=self.device)
        else:
            # Decode caller-supplied latents (this path used to leave `z` unbound).
            z = sample.to(self.device)
        return self.decoder(z)
    
    
class Encoder(nn.Module):
    """Convolutional Gaussian encoder: image batch -> ``(mu, sigma)``.

    Args:
        img_dims: ``(channels, height, width)`` of the input images. The conv
            stack must reduce the spatial size to 4x4 for ``encoder_fc``
            (a ``4*4*128`` linear layer); 28x28 and 32x32 inputs both do.
        noise: optional corruption module applied right after input
            normalization (e.g. ``nn.Dropout`` for a denoising setup).
            ``None`` (default) means no corruption.
        latent_dim: width of ``mu`` and ``sigma``.
    """

    def __init__(self, img_dims, noise=None, latent_dim=256):
        super().__init__()
        # Build a fresh Identity per instance rather than sharing one default-argument module.
        self.noise = nn.Identity() if noise is None else noise
        # BatchNorm on the raw input acts as a learned per-channel input normalization.
        self.normalize = nn.BatchNorm2d(img_dims[0])
        self.encoder_conv = nn.Sequential(
            nn.Conv2d(img_dims[0], 32, kernel_size=3, stride=1, padding=1),  # 28 -> 28 | 32 -> 32
            nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # -> 14 | 16
            nn.BatchNorm2d(64), nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),  # -> 7 | 8
            nn.BatchNorm2d(64), nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # -> 4 | 4
        )
        self.encoder_fc = nn.Sequential(nn.Linear(4 * 4 * 128, 256), nn.BatchNorm1d(256), nn.ReLU())
        self.mu = nn.Linear(256, latent_dim)
        self.log_sigma = nn.Linear(256, latent_dim)

    def forward(self, x):
        """Return ``(mu, sigma)``, each of shape ``(batch, latent_dim)``; ``sigma`` is non-negative."""
        x = self.normalize(x)
        x = self.noise(x)
        x = self.encoder_conv(x)
        x = x.flatten(start_dim=1)
        x = self.encoder_fc(x)
        mu = self.mu(x)
        # Despite the head's name, sigma = softplus(head output), not exp(...).
        # softplus computes log(1 + exp(t)) without overflow; the naive formula
        # returns inf (and NaN gradients) for large t.
        sigma = nn.functional.softplus(self.log_sigma(x))
        return mu, sigma
    
    
class Decoder(nn.Module):
    """Transposed-convolution decoder: latent code -> image batch.

    Args:
        img_dims: ``(channels, height, width)`` of the output images. The
            transposed-conv stack produces 33x33 maps that are cropped
            (top-left aligned) to ``(height, width)``, so both must be <= 33.
        latent_dim: width of the incoming latent code.
        output_sigmoid: squash outputs into [0, 1] with a sigmoid. ``None``
            (default) keeps the original rule: sigmoid only for single-channel
            images; multi-channel outputs are left unbounded.
    """

    def __init__(self, img_dims, latent_dim=256, output_sigmoid=None):
        super().__init__()
        self.img_dims = img_dims
        self.output_sigmoid = (img_dims[0] == 1) if output_sigmoid is None else output_sigmoid
        self.decoder_fc = nn.Sequential(nn.Linear(latent_dim, 128 * 4 * 4))
        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1), nn.BatchNorm2d(64), nn.ReLU(),  # 4 -> 7
            nn.ConvTranspose2d(64, 64, 3, stride=2), nn.BatchNorm2d(64), nn.ReLU(),  # 7 -> 15
            nn.ConvTranspose2d(64, 32, 3, stride=2), nn.BatchNorm2d(32), nn.ReLU(),  # 15 -> 31
            nn.ConvTranspose2d(32, img_dims[0], 3, stride=1),  # 31 -> 33
        )

    def forward(self, x):
        """Decode ``(batch, latent_dim)`` codes into ``(batch, *img_dims)`` images."""
        x = self.decoder_fc(x)
        x = x.view(-1, 128, 4, 4)
        x = self.decoder_conv(x)
        # Crop the 33x33 output down to the target resolution.
        x = x[:, :, :self.img_dims[-2], :self.img_dims[-1]]
        if self.output_sigmoid:
            x = torch.sigmoid(x)
        return x

In [None]:
class ELBO(nn.Module):
    """Negative evidence lower bound, used as the VAE training loss.

    ``loss = main_criterion(x_rec, x_true) + beta * KL(N(mu, sigma^2) || N(0, 1))``,
    where the KL term is averaged over every element of ``mu``/``sigma``.

    Args:
        main_criterion: reconstruction loss called as ``main_criterion(x_rec, x_true)``.
        beta: weight of the KL term.
    """

    def __init__(self, main_criterion, beta):
        super().__init__()
        self.beta = beta
        self.main_criterion = main_criterion

    def forward(self, x_rec, x_true, mu, sigma):
        reconstruction = self.main_criterion(x_rec, x_true)
        kl = self.kl_gaussian_loss(mu, sigma)
        return reconstruction + self.beta * kl

    def kl_gaussian_loss(self, mu, sigma):
        """Elementwise-mean KL divergence between N(mu, sigma^2) and N(0, 1)."""
        per_element = sigma.pow(2) + mu.pow(2) - 2 * sigma.log() - 1
        return per_element.mean() * 0.5

In [None]:
import datetime
from tqdm.auto import tqdm
from tensorboard_pytorch import TensorboardPyTorch

def run_epoch(model, loaders, criterion, optim, writer, epoch, phase, device=None):
    """Run one pass over ``loaders[phase]`` and log the mean per-sample loss.

    The caller sets ``model.train()`` / ``model.eval()`` and ``torch.no_grad()``;
    this function only runs the batch loop and logging.

    Args:
        model: returns ``(x_rec, mu, sigma)`` for an input batch.
        loaders: mapping from phase name to a DataLoader yielding ``(x, label)``
            pairs; labels are ignored.
        criterion: called as ``criterion(x_rec, x_true, mu, sigma)``; returns a scalar loss.
        optim: optimizer stepped after every batch when ``'train'`` is in ``phase``.
        writer: logger exposing ``log_scalar`` and ``log_reconstructions_visualize``.
        epoch: zero-based epoch index; the loss scalar is logged at step ``epoch + 1``.
        phase: loader key, e.g. ``'train'`` or ``'test'``.
        device: where batches are moved; defaults to the device of the model's
            first parameter.

    Returns:
        The epoch's mean per-sample loss (unrounded).
    """
    if device is None:
        device = next(model.parameters()).device
    is_train = 'train' in phase
    running_loss = 0.0
    for x_true, _ in loaders[phase]:
        x_true = x_true.to(device)
        x_rec, mu, sigma = model(x_true)
        loss = criterion(x_rec, x_true, mu, sigma)
        if is_train:
            optim.zero_grad()
            loss.backward()
            optim.step()
        # Weight each batch by its size so a short final batch does not skew the epoch mean.
        running_loss += loss.item() * x_true.size(0)

    epoch_loss = running_loss / len(loaders[phase].dataset)
    writer.log_scalar(f'Loss/{phase}', round(epoch_loss, 4), epoch + 1)
    if 'test' in phase:
        # NOTE(review): reconstructions are logged with `epoch`, the scalar with `epoch + 1`;
        # confirm which step convention log_reconstructions_visualize expects.
        writer.log_reconstructions_visualize(model, loaders[phase], epoch)
    return epoch_loss
    
    
def simple_trainer(model, loaders, criterion, optim, writer, epoch_start, epoch_end):
    """For every epoch in ``[epoch_start, epoch_end)``, run a training pass then an evaluation pass.

    The training pass runs with the model in training mode. The evaluation pass
    switches the model to eval mode and disables autograd. Both passes hand the
    batch loop and logging to ``run_epoch``.
    """
    def one_pass(epoch, phase):
        run_epoch(model, loaders, criterion, optim, writer, epoch, phase=phase)

    for epoch in tqdm(range(epoch_start, epoch_end)):
        model.train(True)
        one_pass(epoch, 'train')

        model.train(False)  # equivalent to model.eval()
        with torch.no_grad():
            one_pass(epoch, 'test')
            

def configure_optimizers(model, optim, weight_decay=1e-4, **optim_kwargs):
    """Instantiate ``optim`` with weight decay applied only to linear/conv weights.

    Parameters are split into two groups:
      * decayed: ``weight`` of linear and (transposed) convolution layers;
      * not decayed: every ``bias``, plus ``weight`` of normalization and embedding layers.

    Args:
        model: module whose parameters are split into the two groups.
        optim: optimizer class, e.g. ``torch.optim.SGD``; instantiated here.
        weight_decay: weight decay for the decayed group.
        **optim_kwargs: forwarded to ``optim`` (e.g. ``lr``).

    Returns:
        The optimizer. Group 0 is decayed and group 1 is not; each group's
        parameters are ordered by name.

    Raises:
        AssertionError: if a parameter lands in both groups or in neither
            (e.g. a ``weight`` owned by a layer type not listed below).
    """
    whitelist_weight_modules = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
                                nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)
    blacklist_weight_modules = (nn.LayerNorm, nn.GroupNorm, nn.Embedding,
                                nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

    decay, no_decay = set(), set()
    for mn, m in model.named_modules():
        # recurse=False: visit each parameter once, at the module that directly owns it.
        for pn, _ in m.named_parameters(recurse=False):
            fpn = f'{mn}.{pn}' if mn else pn  # full name, as reported by model.named_parameters()
            if pn.endswith('bias'):
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                no_decay.add(fpn)

    # Validate that every parameter was classified exactly once.
    param_dict = dict(model.named_parameters())
    inter_params = decay & no_decay
    missing_params = param_dict.keys() - (decay | no_decay)
    assert not inter_params, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
    assert not missing_params, "parameters %s were not separated into either decay/no_decay set!" \
                               % (str(missing_params), )

    optim_groups = [
        {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": weight_decay},
        {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
    ]
    return optim(optim_groups, **optim_kwargs)

# Experiment setup: convolutional VAE on CIFAR-10

In [None]:
EPOCHS = 100
IMG_DIMS = (3, 32, 32)
LR = 0.01
WEIGHT_DECAY = 0.001
BETA = 1.0  # weight of the KL term in the ELBO

encoder = Encoder(IMG_DIMS)
decoder = Decoder(IMG_DIMS)
model = ConvVAE(encoder, decoder, device).to(device)
main_criterion = nn.MSELoss()
criterion = ELBO(main_criterion, beta=BETA).to(device)
# Weight decay is applied to linear/conv weights only; biases and BatchNorm parameters are exempt.
optim = configure_optimizers(model, torch.optim.SGD, weight_decay=WEIGHT_DECAY, lr=LR)

date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# The run name is built from the same constants the optimizer uses, so the log path cannot drift from the run.
writer = TensorboardPyTorch(
    f'tensorboard/ConvVAE/cifar10/mainloss:mse_optim:sgd_lr:{LR}_wd:{WEIGHT_DECAY}_epochs:{EPOCHS}/{date}',
    device,
)

In [None]:
# Train the CIFAR-10 VAE configured above for EPOCHS epochs. Each epoch runs a training pass and then a
# no-grad evaluation pass on loaders['test'], logging losses (and test reconstructions) to `writer`.
simple_trainer(model, loaders, criterion, optim, writer, epoch_start=0, epoch_end=EPOCHS)

In [None]:
# Launch TensorBoard inline over every run logged under ./tensorboard.
%tensorboard --logdir=tensorboard

# Denoising Autoencoders (DAE): corrupting the input with dropout or Gaussian noise

## Dropout

In [None]:
EPOCHS = 100
IMG_DIMS = (1, 28, 28)
DROPOUT_P = 0.2
LR = 0.01
WEIGHT_DECAY = 0.001
BETA = 1.0

# This experiment uses 1x28x28 FashionMNIST images (see IMG_DIMS and the log path).
# The loaders built at the top of the notebook serve 3x32x32 CIFAR-10 batches, which the
# encoder's 1-channel input BatchNorm rejects. Rebind the shared dataset/loader names to
# FashionMNIST so the following cells see matching data; skip this if they already do.
if not isinstance(loaders['train'].dataset, FashionMNIST):
    dataset_train = FashionMNIST('./data', train=True, transform=transform, download=True)
    dataset_test = FashionMNIST('./data', train=False, transform=transform, download=True)
    loader_train = DataLoader(dataset_train, batch_size=64, shuffle=True,
                              pin_memory=torch.cuda.is_available(), num_workers=4)
    loader_test = DataLoader(dataset_test, batch_size=64, shuffle=False,
                             pin_memory=torch.cuda.is_available(), num_workers=4)
    loaders = {'train': loader_train, 'test': loader_test}

encoder = Encoder(IMG_DIMS, nn.Dropout(DROPOUT_P))
decoder = Decoder(IMG_DIMS)
# `ConvAutoEncoder` is not defined anywhere above. Also, run_epoch unpacks (x_rec, mu, sigma) from
# the model and calls criterion(x_rec, x_true, mu, sigma), which a bare nn.BCELoss cannot accept.
# So use the ConvVAE defined above and wrap BCE in the ELBO. The Decoder applies a sigmoid to
# single-channel outputs, so they are valid BCE inputs.
model = ConvVAE(encoder, decoder, device).to(device)
criterion = ELBO(nn.BCELoss(), beta=BETA).to(device)
optim = torch.optim.SGD(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
writer = TensorboardPyTorch(
    f'tensorboard/ConvVAE/fmnist_noise:dropout:{DROPOUT_P}/mainloss:bce_optim:sgd_lr:{LR}_wd:{WEIGHT_DECAY}_epochs:{EPOCHS}/{date}',
    device,
)

In [None]:
# Train the dropout-corrupted model configured in the previous cell (a training pass plus a no-grad test pass per epoch).
simple_trainer(model, loaders, criterion, optim, writer, epoch_start=0, epoch_end=EPOCHS)

In [None]:
###### how dropout works
# Show one test image before and after nn.Dropout. A freshly constructed module is in training
# mode, so dropout is active: each element is zeroed with probability p and the survivors are
# scaled by 1 / (1 - p).
import matplotlib.pyplot as plt

DEMO_DROPOUT_P = 0.6
dropout = nn.Dropout(DEMO_DROPOUT_P)
x_true = next(iter(loader_test))[0]
x_dropped = dropout(x_true)


def _as_image(img):
    """Convert a (C, H, W) tensor to an imshow-ready array: (H, W) for 1 channel, (H, W, C) otherwise."""
    return np.squeeze(img.permute(1, 2, 0).numpy())


fig, (ax_in, ax_out) = plt.subplots(1, 2, figsize=(6, 3))
ax_in.imshow(_as_image(x_true[0]))
ax_in.set_title('input')
# Clip the rescaled survivors back into [0, 1] so RGB images render without clipping warnings.
ax_out.imshow(_as_image(x_dropped[0].clamp(0, 1)))
ax_out.set_title(f'after dropout (p={DEMO_DROPOUT_P})')
for ax in (ax_in, ax_out):
    ax.axis('off')
plt.show()

## Gaussian Noise

In [None]:
class GaussianNoise(nn.Module):
    """Additive Gaussian input corruption: ``x + std * N(0, 1) + mean``.

    As an ``nn.Module``, it follows the parent model's train/eval mode, the same
    way ``nn.Dropout`` does. Noise is added only in training mode, and evaluation
    sees clean inputs. A standalone instance starts in training mode, so calling
    it directly still adds noise.

    Args:
        mean: constant shift added to every element (training mode only).
        std: standard deviation of the zero-mean Gaussian noise.
    """

    def __init__(self, mean=0., std=1.):
        super().__init__()
        self.std = std
        self.mean = mean

    def forward(self, tensor):
        if not self.training:
            return tensor
        return tensor + torch.randn_like(tensor) * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

In [None]:
EPOCHS = 100
IMG_DIMS = (1, 28, 28)
NOISE_STD = 0.2
LR = 0.01
WEIGHT_DECAY = 0.001
BETA = 1.0

# This experiment uses 1x28x28 FashionMNIST images (see IMG_DIMS and the log path).
# The loaders built at the top of the notebook serve 3x32x32 CIFAR-10 batches, which the
# encoder's 1-channel input BatchNorm rejects. Rebind the shared dataset/loader names to
# FashionMNIST; skip this if they already do.
if not isinstance(loaders['train'].dataset, FashionMNIST):
    dataset_train = FashionMNIST('./data', train=True, transform=transform, download=True)
    dataset_test = FashionMNIST('./data', train=False, transform=transform, download=True)
    loader_train = DataLoader(dataset_train, batch_size=64, shuffle=True,
                              pin_memory=torch.cuda.is_available(), num_workers=4)
    loader_test = DataLoader(dataset_test, batch_size=64, shuffle=False,
                             pin_memory=torch.cuda.is_available(), num_workers=4)
    loaders = {'train': loader_train, 'test': loader_test}

encoder = Encoder(IMG_DIMS, GaussianNoise(std=NOISE_STD))
decoder = Decoder(IMG_DIMS)
# `ConvAutoEncoder` is not defined anywhere above. Also, run_epoch unpacks (x_rec, mu, sigma) from
# the model and calls criterion(x_rec, x_true, mu, sigma), which a bare nn.BCELoss cannot accept.
# So use the ConvVAE defined above and wrap BCE in the ELBO. The Decoder applies a sigmoid to
# single-channel outputs, so they are valid BCE inputs.
model = ConvVAE(encoder, decoder, device).to(device)
criterion = ELBO(nn.BCELoss(), beta=BETA).to(device)
optim = torch.optim.SGD(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
writer = TensorboardPyTorch(
    f'tensorboard/ConvVAE/fmnist_noise:gaussian:{NOISE_STD}/mainloss:bce_optim:sgd_lr:{LR}_wd:{WEIGHT_DECAY}_epochs:{EPOCHS}/{date}',
    device,
)

# Each epoch runs a training pass and then a no-grad test pass on loaders['test'].
simple_trainer(model, loaders, criterion, optim, writer, epoch_start=0, epoch_end=EPOCHS)