In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
import os
import pickle
from tqdm import notebook, tqdm

In [2]:
# Print the full path of every file found under the Kaggle input mount so the
# dataset locations used by the following cells can be checked by eye.
for root, _, names in os.walk('/kaggle/input'):
    for path in (os.path.join(root, name) for name in names):
        print(path)

/kaggle/input/preproc-vae-notebook/valid_sq.pickle
/kaggle/input/preproc-vae-notebook/custom.css
/kaggle/input/preproc-vae-notebook/__notebook__.ipynb
/kaggle/input/preproc-vae-notebook/__results__.html
/kaggle/input/preproc-vae-notebook/__output__.json
/kaggle/input/preproc-vae-notebook/train_sq.pickle
/kaggle/input/preproc-vae-notebook/test_sq.pickle


In [3]:
# Directory holding the pickled train / valid / test splits.
DATA_DIR = '/kaggle/input/preproc-vae-notebook'


def _load_inverted(split):
    """Unpickle ``<split>_sq.pickle`` from ``DATA_DIR`` and return ``127 - <object>``.

    Args:
        split: Split name used as the file-name prefix ('train', 'valid' or 'test').

    Returns:
        The unpickled object subtracted from 127.
        NOTE(review): presumably a NumPy array with values in 0..127, so this
        flips the intensity polarity; later cells divide by 127 -- confirm
        against the preprocessing notebook.
    """
    # SECURITY: pickle.load can execute arbitrary code while unpickling. Only
    # point this at files you produced yourself or otherwise trust.
    with open(os.path.join(DATA_DIR, f'{split}_sq.pickle'), 'rb') as handle:
        return 127 - pickle.load(handle)


train_data = _load_inverted('train')
valid_data = _load_inverted('valid')
test_data = _load_inverted('test')

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable

class MultiResBlock(nn.Module):
    """Chain of three 3x3 convolutions whose outputs are concatenated and added
    to a 1x1-convolution shortcut of the input.

    Every convolution is followed by batch normalisation and a sigmoid. All
    convolutions preserve the spatial size (3x3 with padding 1, 1x1 with
    padding 0), so the output has shape ``(N, ch_out, H, W)`` for an input of
    shape ``(N, ch_in, H, W)``.

    The three chained stages produce ``ch_out // 6``, ``ch_out // 3`` and
    "whatever is left" channels, so their concatenation always has exactly
    ``ch_out`` channels and can be added to the shortcut.

    Note: the batch-norm layers are built with ``track_running_stats=False``,
    so they normalise with the statistics of the current batch in both
    training and evaluation mode.

    Args:
        ch_in: Number of input channels.
        ch_out: Number of output channels.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        first_ch = ch_out // 6
        second_ch = ch_out // 3
        # The previous hard-coded width ``ch_out // 2 + 1`` only summed to
        # ch_out when ch_out % 6 was in {1, 2, 3, 4}; taking the remainder gives
        # the same widths in those cases and a valid block in all others.
        third_ch = ch_out - first_ch - second_ch

        # Shortcut branch.
        self.conv1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, padding=0)
        self.bnorm1x1 = nn.BatchNorm2d(ch_out, track_running_stats=False)
        # Chained 3x3 branch: each stage consumes the previous stage's output.
        self.fconv = nn.Conv2d(ch_in, first_ch, kernel_size=3, padding=1)
        self.fbnorm = nn.BatchNorm2d(first_ch, track_running_stats=False)
        self.sconv = nn.Conv2d(first_ch, second_ch, kernel_size=3, padding=1)
        self.sbnorm = nn.BatchNorm2d(second_ch, track_running_stats=False)
        self.tconv = nn.Conv2d(second_ch, third_ch, kernel_size=3, padding=1)
        self.tbnorm = nn.BatchNorm2d(third_ch, track_running_stats=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``shortcut(x) + concat(stage1, stage2, stage3)``."""
        res1x1 = self.sigmoid(self.bnorm1x1(self.conv1x1(x)))
        first = self.sigmoid(self.fbnorm(self.fconv(x)))
        second = self.sigmoid(self.sbnorm(self.sconv(first)))
        third = self.sigmoid(self.tbnorm(self.tconv(second)))
        # Concatenate along the channel axis; total width equals ch_out.
        resconv = torch.cat((first, second, third), dim=1)
        return res1x1 + resconv

class VAE(nn.Module):
    """Convolutional variational autoencoder built from ``MultiResBlock`` stages.

    Encoder: four ``MultiResBlock`` + 2x max-pool stages, so a square input of
    side ``input_size`` is reduced to side ``input_size // 16`` with
    ``num_neurons3`` channels. That feature map is flattened and projected to
    the mean and log-variance of a ``latent_dim``-dimensional Gaussian.

    Decoder: a linear layer maps a latent sample back to the bottleneck shape,
    followed by four ``MultiResBlock`` + 2x upsample stages, a final 3x3
    convolution to one channel and a sigmoid.

    Args:
        num_neurons: Channel width of the first encoder stage (and of the last
            decoder stages).
        num_neurons2: Channel width of the second encoder stage.
        num_neurons3: Channel width of the two deepest encoder stages, i.e. of
            the bottleneck feature map.
        kernel_size: Accepted for interface compatibility; not used by the
            layers (all MultiResBlock convolutions have fixed kernel sizes).
        pool_size: Accepted for interface compatibility; not used (pooling is
            fixed at 2x2 with stride 2).
        latent_dim: Dimensionality of the latent vector.
        input_size: Side length of the square, single-channel input images.
            Must be a positive multiple of 16. The default of 256 reproduces
            the previously hard-coded 16x16 bottleneck.

    Raises:
        ValueError: If ``input_size`` is not a positive multiple of 16.
    """

    def __init__(self, num_neurons, num_neurons2, num_neurons3, kernel_size, pool_size, latent_dim,
                 input_size=256):
        super().__init__()

        # Four 2x poolings in the encoder and four 2x upsamplings in the decoder.
        downscale = 2 ** 4
        if input_size <= 0 or input_size % downscale != 0:
            raise ValueError(
                f"input_size must be a positive multiple of {downscale}, got {input_size}"
            )
        # Shape of the encoder output / decoder input feature map.
        self.bottleneck_channels = num_neurons3
        self.bottleneck_size = input_size // downscale
        flat_dim = self.bottleneck_channels * self.bottleneck_size * self.bottleneck_size

        self.mrb1 = MultiResBlock(1, num_neurons)
        self.mrb2 = MultiResBlock(num_neurons, num_neurons2)
        self.mrb3 = MultiResBlock(num_neurons2, num_neurons3)
        self.mrb3b = MultiResBlock(num_neurons3, num_neurons3)

        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.up = nn.Upsample(scale_factor=2)

        self.mrb4 = MultiResBlock(num_neurons3, num_neurons2)
        self.mrb5 = MultiResBlock(num_neurons2, num_neurons)
        self.mrb6 = MultiResBlock(num_neurons, num_neurons)
        self.mrb7 = MultiResBlock(num_neurons, num_neurons)
        # Output projection to one channel. Despite the attribute name this is a
        # 3x3 convolution; the name is kept so saved state dicts still load.
        self.conv1x1 = nn.Conv2d(num_neurons, 1, kernel_size=3, padding=1)

        # Fully connected layers between the bottleneck and the latent space.
        self.fc_mean = nn.Linear(flat_dim, latent_dim)
        self.fc_logvar = nn.Linear(flat_dim, latent_dim)

        self.fc_decoder = nn.Linear(latent_dim, flat_dim)
        self.sigmoid = nn.Sigmoid()

    def encoder(self, x):
        """Run the four MultiResBlock + max-pool stages; spatial size shrinks 16x."""
        x = self.maxpool(self.mrb1(x))
        x = self.maxpool(self.mrb2(x))
        x = self.maxpool(self.mrb3(x))
        x = self.maxpool(self.mrb3b(x))
        return x

    def decoder(self, x):
        """Run the four MultiResBlock + upsample stages and map to one sigmoid channel."""
        x = self.up(self.mrb4(x))
        x = self.up(self.mrb5(x))
        x = self.up(self.mrb6(x))
        x = self.up(self.mrb7(x))
        x = self.conv1x1(x)
        return self.sigmoid(x)

    def latent(self, z_mu, z_logvar):
        """Reparameterisation trick: ``z = mu + sd * e`` with ``e ~ N(0, I)``.

        Args:
            z_mu: Mean of the approximate posterior.
            z_logvar: Log-variance of the approximate posterior.

        Returns:
            A sample ``z`` with the same shape, dtype and device as the inputs.
        """
        sd = torch.exp(z_logvar * 0.5)
        # Draw the noise on the same device as `sd` rather than forcing CUDA, so
        # the model also runs when the notebook falls back to the CPU.
        e = torch.randn_like(sd)
        return z_mu + e * sd

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            ``(x_out, z_mean, z_logvar)``: the reconstruction and the parameters
            of the approximate posterior.
        """
        features = self.encoder(x)
        # Keep the batch axis fixed. With `view(-1, ...)` a wrong-sized input
        # would be silently folded into the batch dimension instead of failing.
        flat = features.view(features.size(0), -1)
        z_mean = self.fc_mean(flat)
        z_logvar = self.fc_logvar(flat)

        z = self.latent(z_mean, z_logvar)
        z_decoder = self.fc_decoder(z).view(
            -1, self.bottleneck_channels, self.bottleneck_size, self.bottleneck_size
        )
        x_out = self.decoder(z_decoder)
        return x_out, z_mean, z_logvar


def criterion(x_out, target, z_mean, z_logvar, alpha=0.5, beta=0.5):
    """Weighted VAE objective: summed reconstruction error plus analytic KL term.

    Args:
        x_out: Reconstruction produced by the model.
        target: Tensor the reconstruction is compared against.
        z_mean: Mean of the approximate posterior.
        z_logvar: Log-variance of the approximate posterior.
        alpha: Weight of the reconstruction term.
        beta: Weight of the KL term.

    Returns:
        Tuple ``(loss, mse, kl)``. ``mse`` is the squared error summed over all
        elements, ``kl`` is the closed-form KL divergence to a standard normal
        prior summed over all elements, and ``loss`` is
        ``(alpha * mse + beta * kl)`` divided by the size of the first axis of
        ``x_out``.
    """
    n_batch = x_out.size(0)
    # Squared error summed (not averaged) over every element.
    recon_err = F.mse_loss(x_out, target, reduction='sum')
    # Closed form for a diagonal Gaussian posterior against a N(0, I) prior.
    kl_inner = 1 + z_logvar - z_mean.pow(2) - z_logvar.exp()
    kl_div = -0.5 * kl_inner.sum()
    total = ((alpha * recon_err) + (beta * kl_div)) / n_batch
    return total, recon_err, kl_div


In [5]:
import pickle
import numpy as np
import random


def batch_generator(data, batch_size=1, shuffle=False):
    """Yield ``data`` in consecutive chunks of at most ``batch_size`` items.

    Args:
        data: Container that supports ``len()`` and indexing with a sequence of
            integer positions (e.g. a NumPy array).
        batch_size: Number of items per chunk; the final chunk may be shorter.
        shuffle: When true, visit the items in a random order drawn from
            ``np.random.permutation``; otherwise in their stored order.

    Yields:
        ``data[idx]`` for each successive slice ``idx`` of the visiting order.
    """
    total = len(data)
    order = np.random.permutation(total) if shuffle else range(total)
    yield from (
        data[order[lo:lo + batch_size]]
        for lo in range(0, total, batch_size)
    )


In [6]:
############### Model Parameters ################

# Channel widths handed to the VAE constructor (first, second, deepest stage).
num_neurons, num_neurons2, num_neurons3 = 32, 64, 128

# Also handed to the VAE constructor.
kernel_size, pool_size = 3, 2

# Training schedule: number of passes over the training set and samples per batch.
epochs, batch_size = 25, 64

In [7]:
############## Additional functions ###############

def train(model, optimizer, data, batch_size, epoch, device):
    """Run one optimisation pass over ``data`` and print the mean batch metrics.

    Args:
        model: Module returning ``(output, z_mu, z_logvar)`` when called on a batch.
        optimizer: Optimiser stepped once per batch.
        data: Training samples; batches are taken with ``batch_generator`` and
            converted with ``torch.from_numpy``, so this must be a NumPy array.
        batch_size: Number of samples per batch.
        epoch: Current epoch index (accepted but not used here).
        device: Device the batches are moved to.

    Returns:
        Sum of the per-batch loss values over the whole pass.
    """
    model.train()
    running_loss = 0
    loss_log, mse_log, kl_log = [], [], []

    for batch in batch_generator(data, batch_size, shuffle=True):
        # Scale by 127 and insert a channel axis at position 1.
        inputs = torch.from_numpy(batch).float().to(device)
        inputs = torch.unsqueeze(inputs / 127, 1)

        model.zero_grad()
        recon, mu, logvar = model(inputs)
        # The input itself is the reconstruction target.
        loss, mse, kl = criterion(recon, inputs, mu, logvar)
        loss.backward()
        optimizer.step()

        # Training statistics
        batch_loss = loss.item()
        running_loss += batch_loss
        loss_log.append(batch_loss)
        mse_log.append(mse.item())
        kl_log.append(kl.item())

    print(f'train loss={np.mean(loss_log):.6f} | mse loss:{np.mean(mse_log):.6f} | kl loss:{np.mean(kl_log):.6f}')
    return running_loss

def validate(model, data, batch_size, epoch, device):
    """Evaluate ``model`` on ``data`` without gradients and print the mean batch metrics.

    Args:
        model: Module returning ``(output, z_mu, z_logvar)`` when called on a batch.
        data: Validation samples; batches are taken with ``batch_generator`` and
            converted with ``torch.from_numpy``, so this must be a NumPy array.
        batch_size: Number of samples per batch.
        epoch: Current epoch index, shown in the printed summary.
        device: Device the batches are moved to.

    Returns:
        Sum of the per-batch loss values over the whole pass.
    """
    model.eval()
    running_loss = 0
    loss_log, mse_log, kl_log = [], [], []

    with torch.no_grad():
        for batch in batch_generator(data, batch_size, shuffle=True):
            # Same preprocessing as in training: scale by 127, add a channel axis.
            inputs = torch.from_numpy(batch).float().to(device)
            inputs = torch.unsqueeze(inputs / 127, 1)

            recon, mu, logvar = model(inputs)
            loss, mse, kl = criterion(recon, inputs, mu, logvar)

            batch_loss = loss.item()
            running_loss += batch_loss
            loss_log.append(batch_loss)
            mse_log.append(mse.item())
            kl_log.append(kl.item())

    print(f'| epoch {epoch:03d} | validation loss={np.mean(loss_log):.6f} | validation mse loss:{np.mean(mse_log):.6f} | validation kl loss:{np.mean(kl_log):.6f}')
    return running_loss
    

In [8]:
# Use the first CUDA GPU when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [9]:

# Build the VAE with a 256-dimensional latent space and move it to the chosen device.
model = VAE(num_neurons, num_neurons2, num_neurons3, kernel_size, pool_size, latent_dim=256)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# Summed per-batch losses of the most recent epoch, as returned by train() / validate().
tr_loss = 0
val_loss = 0
# `valid_loss` used to be initialised here but never updated, because the loop
# wrote to `val_loss`. Both names are now kept in sync so either can be read
# after training.
valid_loss = 0
for epoch in range(epochs):
    tr_loss = train(model, optimizer, train_data, batch_size=batch_size, epoch=epoch, device=device)
    val_loss = validate(model, valid_data, batch_size, epoch, device)
    valid_loss = val_loss

train loss=2472.915163 | mse loss:309312.019567 | kl loss:7158.791315
| epoch 000 | validation loss=295.902927 | validation mse loss:36267.954532 | validation kl loss:609.253700
train loss=204.579540 | mse loss:25802.019926 | kl loss:352.326644
| epoch 001 | validation loss=159.955256 | validation mse loss:19762.213472 | validation kl loss:186.696231
train loss=139.399070 | mse loss:17666.706411 | kl loss:149.188806
| epoch 002 | validation loss=130.741865 | validation mse loss:16149.155090 | validation kl loss:115.361711
train loss=120.926655 | mse loss:15342.558492 | kl loss:112.658813
| epoch 003 | validation loss=120.356443 | validation mse loss:14815.017930 | validation kl loss:47.934926
train loss=114.235430 | mse loss:14324.380886 | kl loss:272.720163
| epoch 004 | validation loss=114.059134 | validation mse loss:14135.351322 | validation kl loss:45.883880
train loss=113.454105 | mse loss:13790.643879 | kl loss:708.730565
| epoch 005 | validation loss=112.798089 | validation mse

In [None]:
# Persist the trained model. The variable was misspelled `modeL`, which raised
# a NameError. Saving the module object pickles the whole model, so the VAE and
# MultiResBlock class definitions must be available when it is loaded again.
torch.save(model, 'VAE_lt256_lr1e-4_inv.pt')