In [None]:
# import bcolz 
import importlib
import numpy as np
import torch.utils.data
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook as tqdm

from torch.autograd import Variable
import torch

# Load Dataset


## Load your favorite dataset

In [None]:
# Before running the cells below, provide the following arrays:
# x_val (VALIDATION_SIZE, TIME_STEPS), y_val (VALIDATION_SIZE, N_CLASSES) => validation dataset
# x_train (TRAIN_SIZE, TIME_STEPS), y_train (TRAIN_SIZE, N_CLASSES)       => train dataset

## Size params

In [None]:
# Both sizes are read from the trailing axis of the training arrays:
# LENGTH from the inputs, NCLASSES from the label matrix.
LENGTH, NCLASSES = x_train.shape[-1], y_train.shape[-1]

## Normalize by max

In [None]:
def _scale_rows_by_max(signals):
    """Divide every row of ``signals`` by its own maximum along the last axis.

    A row whose maximum is exactly zero is returned unscaled; dividing it by
    zero would fill it with NaN/inf values that then propagate into training.
    """
    row_max = signals.max(axis=-1, keepdims=True)
    return signals / np.where(row_max == 0, 1, row_max)


x_val = _scale_rows_by_max(x_val)
x_train = _scale_rows_by_max(x_train)
# Despite the `_oh` suffix these hold integer class indices: argmax over the
# class axis collapses each label row to the position of its largest entry.
y_train_oh = np.argmax(y_train, axis=-1)
y_val_oh = np.argmax(y_val, axis=-1)

## Compute Class Weights

In [None]:
# Inverse of the per-class column sums of y_train, rescaled so that the
# largest weight equals 1.
weights = 1.0 / np.sum(y_train, axis=0)
normalized_weights = weights / weights.max()
# Index -> weight lookup, kept as a module-level name.
class_weight = dict(enumerate(normalized_weights))
# Same values as a float32 tensor on the GPU, for torch.nn.CrossEntropyLoss.
weights = torch.from_numpy(np.array(list(class_weight.values())))
weights = weights.float().cuda()

# Construct Model

In [None]:
import gc; gc.collect()

## (De)Convolution block


In [None]:
class Conv_block(torch.nn.Module):
    """Conv1d -> BatchNorm1d -> ReLU, followed by a factor-2 resampling stage.

    With ``is_conv=True`` the resampling stage is ``AvgPool1d(2)`` (halves the
    length); otherwise it is a linear ``Upsample`` with ``scale_factor=2``
    (doubles the length).

    NOTE(review): in PyTorch, BatchNorm ``momentum`` is the weight given to the
    newest batch statistic, so ``momentum=0.99`` makes the running estimates
    follow the latest batch almost exclusively. Confirm this is intended and
    not a carry-over from a framework that uses the opposite convention.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding, is_conv=True):
        super(Conv_block, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        # Encoder blocks downsample, decoder blocks upsample.
        if is_conv:
            self.pool_op = torch.nn.AvgPool1d(2)
        else:
            self.pool_op = torch.nn.Upsample(scale_factor=2, mode='linear')
        self.conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding)
        self.bn = torch.nn.BatchNorm1d(out_channels, eps=0.001, momentum=0.99)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Apply convolution, normalisation and activation, then resample."""
        activated = self.relu(self.bn(self.conv(x)))
        return self.pool_op(activated)

## Encoder/Classifier block


In [None]:
class Encoder(torch.nn.Module):
    """Convolutional trunk with a classification head and a variational head.

    Three ``Conv_block`` stages each halve the sequence length. On top of the
    resulting feature maps:

    * the classification head (``conv_final`` + ``gp_final``) yields one score
      per class;
    * the variational head (``adapt_pool`` -> ``adapt_conv`` -> two ``Linear``
      layers) yields ``mean`` and ``logvar`` of the latent code.

    Args:
        in_channels: number of channels of the input signal.
        in_length: number of time steps of the input signal.
        nclasses: number of output classes of the classification head.
        latent_size: dimensionality of the latent code.
        encoder_out_channels: channels produced by ``adapt_conv`` before the
            features are flattened for the latent ``Linear`` layers.
    """

    def __init__(self, in_channels, in_length, nclasses, latent_size, encoder_out_channels):
        super(Encoder, self).__init__()

        self.in_channels = in_channels
        self.in_length = in_length
        self.nclasses = nclasses
        self.latent_size = latent_size
        self.encoder_out_channels = encoder_out_channels
        # `length` tracks the sequence length after each pooling stage.
        length = self.in_length
        # NOTE(review): PyTorch applies `momentum` to the newest batch
        # statistic; 0.99 means the running stats track the latest batch.
        self.bn0 = torch.nn.BatchNorm1d(self.in_channels, eps=0.001, momentum=0.99)
        # Layer 1
        in_channels = self.in_channels
        out_channels = 32
        kernel_size = 201
        padding = kernel_size // 2
        self.conv_block_1 = Conv_block(in_channels, out_channels, kernel_size, padding)
        length = length // 2
        # Layer 2
        in_channels = out_channels
        out_channels = 32
        kernel_size = 201
        padding = kernel_size // 2
        self.conv_block_2 = Conv_block(in_channels, out_channels, kernel_size, padding)
        length = length // 2

        # Layer 3
        in_channels = out_channels
        last_featuremaps_channels = 64
        kernel_size = 201
        padding = kernel_size // 2
        self.conv_block_3 = Conv_block(in_channels, last_featuremaps_channels, kernel_size, padding)
        length = length // 2

        # Classification head. The number of output channels comes from the
        # constructor argument, not from a notebook-level global, so the
        # module can be built for any number of classes.
        in_channels = last_featuremaps_channels
        out_channels = self.nclasses
        kernel_size = 30
        padding = kernel_size // 2
        self.conv_final = torch.nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding)
        # Pooling window equals the tracked length, leaving one value per class.
        self.gp_final = torch.nn.AvgPool1d(length)

        # Variational head
        in_channels = last_featuremaps_channels
        out_channels = self.encoder_out_channels
        kernel_size = 51
        padding = kernel_size // 2
        self.adapt_pool = torch.nn.AvgPool1d(2)
        length = length // 2
        self.adapt_conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding)
        self.encode_mean = torch.nn.Linear(length * out_channels, self.latent_size)
        self.encode_logvar = torch.nn.Linear(length * out_channels, self.latent_size)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Encode a batch of signals.

        Args:
            x: tensor that can be viewed as ``(-1, in_channels, in_length)``.

        Returns:
            list ``[class_scores, mean, logvar, z]`` where ``class_scores`` has
            shape ``(batch, nclasses)`` and the other three have shape
            ``(batch, latent_size)``; ``z`` is a random sample drawn from the
            distribution described by ``mean`` and ``logvar``.
        """
        x = x.view(-1, self.in_channels, self.in_length)
        x = self.bn0(x)
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        x = self.conv_block_3(x)
        cv_final = self.conv_final(x)
        oh_class = self.gp_final(cv_final)
        x = self.adapt_pool(x)
        x = self.adapt_conv(x)
        x = x.view(x.size(0), -1)
        # NOTE(review): the ReLUs restrict mean and logvar to be non-negative
        # (so the variance is never below 1) - confirm this is intended.
        mean = self.relu(self.encode_mean(x))
        logvar = self.relu(self.encode_logvar(x))
        return [oh_class.view(oh_class.size(0), self.nclasses),
                mean, logvar,
                self._sample_latent(mean, logvar)]

    def _sample_latent(self, mean, logvar):  # z ~ N(mean, sigma^2)
        """Draw ``z = mean + sigma * eps`` with ``eps ~ N(0, 1)``.

        ``logvar`` is the log of the variance, so the standard deviation is
        ``exp(0.5 * logvar)``. The noise is created with the dtype and device
        of ``mean`` (torch RNG), so the module works on CPU and GPU alike.
        """
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(mean)
        return mean + sigma * noise

## Decoder block

In [None]:
class Decoder(torch.nn.Module):
    """Map a latent code back to a single-channel signal.

    A ``Linear`` layer expands the latent code to ``in_channels`` feature maps
    of length ``length // 8``; three upsampling ``Conv_block`` stages and a
    final ``Conv1d`` + ``tanh`` produce the output.

    Every convolution here uses an even kernel (200) with padding 100 and so
    lengthens its input by one sample. The output is therefore longer than
    ``length`` (``8 * (length // 8) + 15`` samples) and callers crop it.

    Args:
        length: number of time steps of the signal being reconstructed.
        in_channels: channels of the feature maps built from the latent code.
        nclasses: unused; kept for interface compatibility.
        latent_size: dimensionality of the latent code.
    """

    def __init__(self, length, in_channels, nclasses, latent_size):
        super(Decoder, self).__init__()

        self.in_channels = in_channels
        self.length = length
        self.latent_size = latent_size
        # Length of the feature maps the latent code is expanded to.
        start_length = self.length // 2 // 2 // 2
        # Adapt Layer
        self.relu = torch.nn.ReLU()
        self.tanh = torch.nn.Tanh()
        self.adapt_nn = torch.nn.Linear(latent_size, self.in_channels * start_length)
        # Layer 1
        in_channels = self.in_channels
        out_channels = 64
        kernel_size = 200
        padding = kernel_size // 2
        self.deconv_block_1 = Conv_block(in_channels, out_channels, kernel_size, padding, is_conv=False)
        # Layer 2
        in_channels = out_channels
        out_channels = 32
        kernel_size = 200
        padding = kernel_size // 2
        self.deconv_block_2 = Conv_block(in_channels, out_channels, kernel_size, padding, is_conv=False)

        # Layer 3
        in_channels = out_channels
        out_channels = 32
        kernel_size = 200
        padding = kernel_size // 2
        self.deconv_block_3 = Conv_block(in_channels, out_channels, kernel_size, padding, is_conv=False)

        # Output projection to a single channel
        in_channels = out_channels
        out_channels = 1
        kernel_size = 200
        padding = kernel_size // 2
        self.decode_conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding)

    def forward(self, z):
        """Decode latent codes ``z`` of shape ``(batch, latent_size)``.

        Returns:
            tensor of shape ``(batch, 1, L)`` with values in ``[-1, 1]``.
        """
        # Activations stay on whatever device the module's weights and `z`
        # live on; no explicit transfer is needed.
        x = self.relu(self.adapt_nn(z))
        x = x.view(x.size(0), self.in_channels, self.length // 2 // 2 // 2)
        x = self.deconv_block_1(x)
        x = self.deconv_block_2(x)
        x = self.deconv_block_3(x)
        x = self.decode_conv(x)
        out = self.tanh(x)
        return out

## Variational autoencoder

In [None]:
class VAE(torch.nn.Module):
    """Encoder/decoder pair: classify the input and reconstruct it from a latent sample.

    Args:
        length: number of time steps of the input signal.
        nclasses: number of classes predicted by the encoder.
        latent_size: dimensionality of the latent code.
        transition_channels: channel count used on both sides of the latent
            bottleneck (encoder output channels / decoder input channels).
    """

    def __init__(self, length, nclasses, latent_size, transition_channels):
        super(VAE, self).__init__()
        self.encoder = Encoder(1, length, nclasses, latent_size, transition_channels)
        self.decoder = Decoder(length, transition_channels, nclasses, latent_size)

    def count_parameters(self):
        """Return the total number of scalar parameters in the model."""
        return np.sum([np.prod(x.size()) for x in self.parameters()])

    def forward(self, x):
        """Run the encoder and decode the sampled latent code.

        Returns:
            tuple ``(oh_class, mean, z, x_decoded)``. The encoder's ``logvar``
            is not part of this tuple; call ``self.encoder`` directly when it
            is needed (e.g. for the KL term of the loss).
        """
        # The encoder returns four values: class scores, mean, logvar, sample.
        oh_class, mean, logvar, z = self.encoder(x)
        x_decoded = self.decoder(z)
        return oh_class, mean, z, x_decoded

In [None]:
# Latent code of size 10, with 4 channels on either side of the bottleneck.
model = VAE(
    length=LENGTH,
    nclasses=NCLASSES,
    latent_size=10,
    transition_channels=4,
).cuda()
model.count_parameters()

In [None]:
# x = Variable(torch.from_numpy(x_train[:2])).cuda().float()
# a = model.encoder(x)

## Losses

In [None]:
class SSD(torch.nn.Module):
    """Sum of squared differences, divided by the size of the first dimension."""

    def __init__(self):
        super(SSD, self).__init__()

    def forward(self, x_decoded, x):
        """Return ``sum((x - x_decoded) ** 2) / x_decoded.size(0)``."""
        residual = x - x_decoded
        n_samples = x_decoded.size(0)
        return torch.pow(residual, 2).sum() / n_samples
class Variational_loss(torch.nn.Module):
    """Reconstruction error (``SSD``) plus the KL-divergence term of a VAE.

    NOTE(review): ``SSD`` is divided by the batch size while the KL term is
    summed over the whole batch, so their relative weight changes with the
    batch size - confirm this is intended.
    """

    def __init__(self):
        super(Variational_loss, self).__init__()

    def forward(self, x_decoded, x, mu, logvar):
        """Return ``SSD(x_decoded, x) + sum(0.5 * (mu^2 + exp(logvar) - logvar - 1))``."""
        reconstruction = SSD()(x_decoded, x)
        kl_elements = 0.5 * (mu ** 2 + torch.exp(logvar) - logvar - 1)
        return reconstruction + torch.sum(kl_elements)

class VAE_loss(torch.nn.Module):
    """Weighted classification loss plus a down-weighted variational loss.

    Args:
        weights: per-class weights passed to ``torch.nn.CrossEntropyLoss``.
    """

    def __init__(self, weights):
        super(VAE_loss, self).__init__()
        self.classification_loss = torch.nn.CrossEntropyLoss(weights)
        self.variational_loss = Variational_loss()
        # Scale applied to the variational term before it is added.
        self.c = 0.001

    def forward(self, x_decoded, x, mu, oh_class, y, logvar=None):
        """Compute the combined loss.

        Args:
            x_decoded: reconstruction, ``(batch, 1, L)`` or ``(batch, L)``; it
                is cropped to the length of ``x``.
            x: input signals, ``(batch, length)`` or ``(batch, 1, length)``.
            mu: latent means, ``(batch, latent_size)``.
            oh_class: class scores, ``(batch, nclasses)``.
            y: integer class targets, ``(batch,)``.
            logvar: latent log-variances with the shape of ``mu``. When
                omitted it is taken as zero, which reduces the KL term to
                ``0.5 * sum(mu ** 2)``.

        Returns:
            tuple ``(total, classification_loss, scaled_variational_loss)``.
        """
        if logvar is None:
            logvar = torch.zeros_like(mu)
        # Flatten everything after the batch axis explicitly: a bare squeeze()
        # would also drop the batch axis for a batch of size one.
        target = x.reshape(x.size(0), -1).to(mu.device)
        reconstruction = x_decoded.reshape(x_decoded.size(0), -1)
        # The decoder output is longer than its input; keep the leading part.
        reconstruction = reconstruction[:, :target.size(1)].to(mu.device)

        a = self.classification_loss(oh_class.to(y.device), y)
        b = self.variational_loss(reconstruction, target, mu, logvar) * self.c
        return a + b, a, b

## Data loader

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Pairs two equally indexed containers as (input, target) samples."""

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __len__(self):
        """Number of samples, taken from the inputs."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the ``idx``-th input together with its target."""
        sample, target = self.x[idx], self.y[idx]
        return sample, target
batch_size = 256

# One shuffled loader per split, train first and validation second; targets
# are the integer class indices computed earlier.
train_loader, val_loader = [
    torch.utils.data.DataLoader(
        dataset=Dataset(split_x, split_y),
        batch_size=batch_size,
        shuffle=True)
    for split_x, split_y in ((x_train, y_train_oh), (x_val, y_val_oh))
]

## Tester

In [None]:
def test(model, loader):
    acc = []
    for batch_id, (x, y) in tqdm(enumerate(loader), total=len(loader)):
        x = Variable(x).float().cuda()
        y = Variable(y).cuda()
        out = model(x)
        y_pred = out[0]
        _, index = torch.max(y_pred, -1)
        acc.append((index == y).cpu().data.numpy())
    acc = np.concatenate(acc).mean()
    return acc

# Train

## Train Classifier/Encoder first

### Freeze variational encoder layers

In [None]:
classifier = model.encoder
# Layers of the variational branch; they are excluded from classifier
# pre-training (adapt_pool is a pooling layer and contributes no parameters).
layers = (classifier.adapt_conv, classifier.adapt_pool, classifier.encode_mean, classifier.encode_logvar)
for layer in layers:
    for param in layer.parameters():
        param.requires_grad = False
parameters = [param for param in classifier.parameters() if param.requires_grad]
# A list rather than a one-shot iterator: the optimizer cell can be re-run
# without handing the optimizer an already exhausted iterator.
classifier_parameters = list(parameters)

### Create optimizer and classification loss

In [None]:
# Adam over the trainable classifier parameters; class-weighted cross entropy.
optim_classifier = torch.optim.Adam(params=classifier_parameters)
Loss = torch.nn.CrossEntropyLoss(weight=weights)

## Train

In [None]:
# One entry per epoch: 4 epochs at 1e-2, 4 at 1e-3, 3 at 1e-4, 3 at 1e-5.
learning_rates = [
    rate
    for rate, repeats in ((0.01, 4), (0.001, 4), (0.0001, 3), (0.00001, 3))
    for _ in range(repeats)
]
for lr in tqdm(learning_rates, total=len(learning_rates)):
    # The optimizer has a single parameter group; set this epoch's rate on it.
    optim_classifier.param_groups[0]['lr'] = lr
    for i, (x, y) in tqdm(enumerate(train_loader), total=len(train_loader)):
        x = Variable(x).float().cuda()
        y = Variable(y.long()).cuda()

        # Only the class scores enter the loss at this stage.
        oh_class, mu, logvar, z = model.encoder(x)
        loss = Loss(oh_class.cuda(), y)

        optim_classifier.zero_grad()
        loss.backward()
        optim_classifier.step()
    # Loss of the last batch of the epoch, then accuracy on both splits.
    print('Loss:', loss.data)
    print('Train Accuracy: ', test(model.encoder, train_loader))
    print('Validation Accuracy:', test(model.encoder, val_loader))


## Train Decoder

In [None]:
nepochs = 50
lr_decay = 0.912011  # 0.001*(a^50) = 0.00001, a = 0.912011

# The classifier pre-training stage switched off gradients for the variational
# layers of the encoder; re-enable every parameter so the reconstruction and
# KL terms can train them.
# NOTE(review): assumes the whole model is meant to be trained at this stage.
for param in model.parameters():
    param.requires_grad = True

# Starting rate 0.001, decayed once per epoch (see lr_decay above).
optim = torch.optim.Adam(model.parameters(), lr=0.001)
vae_criterion = VAE_loss(weights)

for epoch in tqdm(range(nepochs), total=nepochs):
    optim.param_groups[0]['lr'] *= lr_decay

    for i, (x, y) in tqdm(enumerate(train_loader), total=len(train_loader)):
        x = Variable(x).float().cuda()
        y = Variable(y.long()).cuda()

        # Encoder and decoder are called separately so that logvar is
        # available for the KL term.
        oh_class, mu, logvar, z = model.encoder(x)
        x_decoded = model.decoder(z)

        # Flatten per sample and crop the (longer) reconstruction to LENGTH.
        x_flat = x.view(x.size(0), -1)[:, :LENGTH]
        x_decoded_flat = x_decoded.view(x_decoded.size(0), -1)[:, :LENGTH]

        class_loss = vae_criterion.classification_loss(oh_class, y)
        var_loss = vae_criterion.variational_loss(
            x_decoded_flat, x_flat, mu, logvar) * vae_criterion.c
        loss = class_loss + var_loss

        optim.zero_grad()
        loss.backward()
        optim.step()
        if not i % 50:
            # Extra classification-only step on the same batch.
            print('training encoder only\n')
            oh_class = model.encoder(x)[0]
            enc_aux_loss = vae_criterion.classification_loss(oh_class, y)
            optim.zero_grad()
            enc_aux_loss.backward()
            optim.step()
        print('Loss:', loss.data)
        print('Class loss:', class_loss.data)
        print('Var loss:', var_loss.data)

    # Accuracy depends only on the class scores, i.e. on the encoder.
    print('Train Accuracy: ', test(model.encoder, train_loader))
    print('Validation Accuracy:', test(model.encoder, val_loader))
