In [15]:
import os
import pickle
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns


import torch
from torch.utils.data import DataLoader
from torch import nn, optim
from torch.nn import functional as F

from structure.utils import *

import time

import warnings
# Silence all Python warnings for the rest of this kernel session.
# NOTE(review): a blanket "ignore" can also hide useful deprecation / numerical warnings.
warnings.filterwarnings("ignore")

# Sanity check that PyTorch can see a CUDA device (the captured output below is True).
torch.cuda.is_available()

True

## Decoder

In [16]:
# Module-level alias for 1-D batch normalisation. Not referenced by the classes visible
# here: Upsampler and CNN_Encoder each take their own `BN` parameter, which shadows it.
BN = nn.BatchNorm1d

class FC_Decoder(nn.Module):
    """Fully connected decoder mapping a latent vector to per-position probabilities.

    Each hidden block is ``Linear -> ReLU -> Dropout``. When conditions are given,
    the condition vector is concatenated to the input of every ``Linear`` layer
    (each hidden block and the output layer), so every Linear expects
    ``previous_width + n_conditions`` inputs.

    Args:
        latent_dim: size of the latent vector ``z``.
        N: sequence length.
        q: number of residue channels.
        decoder_hidden: widths of the hidden Linear layers.
        decoder_dropout: dropout probability after each hidden block (zipped with
            ``decoder_hidden``; the shorter of the two sets the depth).
        n_conditions: size of the optional condition vector.

    Forward returns probabilities of shape ``(batch, q, N)`` (softmax over ``q``).
    """
    def __init__(self, latent_dim, N, q = 21, decoder_hidden=(250,), decoder_dropout=(0.,),
                   n_conditions=0):
        super(FC_Decoder, self).__init__()
        self.N, self.q = N, q
        self.n_conditions = n_conditions
        last_h = latent_dim
        layers = []
        for h, d in zip(decoder_hidden, decoder_dropout):
            # Every Linear sees the previous activations plus the conditions.
            layers += [nn.Linear(last_h + n_conditions, h), nn.ReLU(),
                       nn.Dropout(d)]
            last_h = h
        self.lays = nn.ModuleList(layers)
        self.out_lay = nn.Linear(last_h + n_conditions, N*q)

    def forward(self, x, z, conditions = None):
        """Decode ``z`` (``x`` is accepted for interface parity with the recurrent decoders and unused)."""
        h = z
        for lay in self.lays:
            # Concatenate once per block, right before its Linear layer.
            if conditions is not None and isinstance(lay, nn.Linear):
                h = torch.cat([h, conditions], 1)
            h = lay(h)
        if conditions is not None:
            h = torch.cat([h, conditions], 1)
        h = self.out_lay(h).view(-1, self.q, self.N)
        return F.softmax(h, 1)
                 
class Upsampler(nn.Module):
    """Project a latent vector to a short feature map, then double its length ``n_deconv`` times.

    A Linear layer produces ``low_res_dim`` positions with
    ``min(min_deconv_dim * 2**n_deconv, max_filters)`` channels. Each of the
    ``n_deconv`` blocks (``ConvTranspose1d(stride=2) -> BN -> PReLU -> dropout``)
    halves the nominal channel count (capped at ``max_filters``) down to
    ``min_deconv_dim``.

    Args:
        latent_dim: input vector size.
        low_res_dim: number of positions produced by the Linear projection.
        min_deconv_dim: channel count of the last block.
        n_deconv: number of transposed-convolution blocks.
        kernel_size: transposed-convolution kernel size.
        BN: normalisation module class, called with the channel count.
        dropout: module class instantiated with no arguments after each PReLU.
        max_filters: upper bound on any channel count.
    """
    def __init__(self, latent_dim, low_res_dim, min_deconv_dim=21,
              n_deconv=3, kernel_size=2, BN=nn.BatchNorm1d, dropout=nn.Identity,
                max_filters=336):
        super(Upsampler, self).__init__()
        # Channel widths from widest (input of the first block) to narrowest (output of the last).
        widths = [min(min_deconv_dim * 2 ** k, max_filters) for k in range(n_deconv, -1, -1)]
        self.low_res_dim = low_res_dim
        self.low_res_features = widths[0]
        self.linear = nn.Linear(latent_dim, low_res_dim * self.low_res_features)
        blocks = []
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            blocks.extend([
                nn.ConvTranspose1d(in_ch, out_ch, kernel_size, stride=2, bias=False),
                BN(out_ch),
                nn.PReLU(),
                dropout(),
            ])
        self.lays = nn.ModuleList(blocks)

    def forward(self, h):
        """Map ``(batch, latent_dim)`` to ``(batch, channels, length)``."""
        out = self.linear(h).view(-1, self.low_res_features, self.low_res_dim)
        for layer in self.lays:
            out = layer(out)
        return out
    
class RecurrentSequenceDecoder(nn.Module):
    """Autoregressive GRU decoder producing per-position residue probabilities.

    The input ``x`` (first ``q`` channels) is shifted one position to the right
    and fed, together with the latent code, to a unidirectional GRU. The output
    at position ``i`` therefore depends only on ``z`` and on inputs at positions
    ``< i``.

    Args:
        latent_dim: size of the latent vector ``z``.
        N: sequence length (stored; the length actually used is taken from ``x``).
        q: number of residue channels read from ``x`` and emitted.
        ncell: GRU hidden size.
        project_x: if True, apply a 1x1 convolution to the shifted input.
        upsample: if True, map ``z`` to a per-position feature map with
            ``Upsampler`` instead of repeating it at every position. The
            upsampled length (``intermediate_dim * 2**3`` with the default
            kernel) must equal the sequence length.
        min_deconv_dim, intermediate_dim, max_filters: ``Upsampler`` settings.
        input_dropout: module class instantiated with no arguments and applied
            to the shifted input.
        n_conditions: size of the optional condition vector concatenated to ``z``.
        cond_concat_each_timestep: accepted for interface compatibility; unused.

    Forward returns a ``(batch, q, length)`` tensor, softmax-normalised over ``q``.
    """
    def __init__(self, latent_dim, N, q = 21, ncell=512,
                               project_x=True,
                               upsample=False, min_deconv_dim=42,
                               input_dropout=nn.Identity, intermediate_dim=63,
                               max_filters=336, n_conditions=0,
                               cond_concat_each_timestep=False):
        super(RecurrentSequenceDecoder, self).__init__()
        self.N, self.q = N, q
        self.n_conditions = n_conditions
        self.upsample = upsample

        self.dropout1 = input_dropout()
        # Always define `project` so forward() can test it when project_x=False.
        self.project = nn.Conv1d(q, q, kernel_size = 1) if project_x else None
        if upsample:
            self.upsampler = Upsampler(latent_dim + n_conditions, intermediate_dim,
                                       min_deconv_dim, n_deconv = 3, max_filters = max_filters)
            self.gru = nn.GRU(min_deconv_dim + q, ncell)
        else:
            self.gru = nn.GRU(latent_dim + n_conditions + q, ncell)
        self.conv = nn.Conv1d(ncell, q, 1)

    def forward(self, x, z, conditions = None):
        x = x[:, :self.q]
        # Shift right by one position: pad on the left, drop the last position.
        x = F.pad(x, (1, 0))[:, :, :-1]
        x = self.dropout1(x)
        if self.project is not None:
            x = self.project(x)
        if conditions is not None:
            z = torch.cat([z, conditions], 1)

        if self.upsample:
            z = self.upsampler(z)
        else:
            # Repeat the latent code at every position.
            z = z.unsqueeze(-1).expand(-1, -1, x.size(-1))
        # GRU expects (length, batch, features).
        h = torch.cat([x, z], 1).permute(2, 0, 1)
        h = self.gru(h)[0].permute(1, 2, 0)
        return F.softmax(self.conv(h), 1)
    
class RecurrentStructureDecoder(nn.Module):
    """Autoregressive GRU decoder with a residue head and a 4-channel structure head.

    Same recurrent trunk as the sequence decoder: the first ``q`` channels of
    ``x`` are shifted one position right and fed with the latent code to a
    unidirectional GRU. Two 1x1 convolutions read the GRU states:
    ``conv_sequence`` (``q`` channels) and ``conv_structure`` (4 channels).
    NOTE(review): the meaning of the 4 structure channels is not visible here.

    Args: see the sequence decoder; ``cond_concat_each_timestep`` is unused.

    Forward returns a ``(batch, q + 4, length)`` tensor. Channels ``[:q]`` are the
    residue probabilities and channels ``[q:]`` the structure probabilities,
    each softmax-normalised separately over its own channels.
    """
    def __init__(self, latent_dim, N, q = 21, ncell=512,
                               project_x=True,
                               upsample=False, min_deconv_dim=42,
                               input_dropout=nn.Identity, intermediate_dim=63,
                               max_filters=336, n_conditions=0,
                               cond_concat_each_timestep=False):
        super(RecurrentStructureDecoder, self).__init__()
        self.N, self.q = N, q
        self.n_conditions = n_conditions
        self.upsample = upsample

        self.dropout1 = input_dropout()
        # Always define `project` so forward() can test it when project_x=False.
        self.project = nn.Conv1d(q, q, kernel_size = 1) if project_x else None
        if upsample:
            self.upsampler = Upsampler(latent_dim + n_conditions, intermediate_dim,
                                       min_deconv_dim, n_deconv = 3, max_filters = max_filters)
            self.gru = nn.GRU(min_deconv_dim + q, ncell)
        else:
            self.gru = nn.GRU(latent_dim + n_conditions + q, ncell)
        self.conv_sequence = nn.Conv1d(ncell, q, 1)
        self.conv_structure = nn.Conv1d(ncell, 4, 1)

    def forward(self, x, z, conditions = None):
        # The GRU input width only accounts for q residue channels.
        x = x[:, :self.q]
        # Shift right by one position: pad on the left, drop the last position.
        x = F.pad(x, (1, 0))[:, :, :-1]
        x = self.dropout1(x)
        if self.project is not None:
            x = self.project(x)
        if conditions is not None:
            z = torch.cat([z, conditions], 1)

        if self.upsample:
            z = self.upsampler(z)
        else:
            z = z.unsqueeze(-1).expand(-1, -1, x.size(-1))
        h = torch.cat([x, z], 1).permute(2, 0, 1)
        h = self.gru(h)[0].permute(1, 2, 0)
        sequence_probs = F.softmax(self.conv_sequence(h), 1)
        structure_probs = F.softmax(self.conv_structure(h), 1)
        return torch.cat([sequence_probs, structure_probs], 1)

## Encoder

In [17]:
def size(x, n_conv, kernel_size=2, stride=2):
    """Sequence length after the CNN_Encoder convolution stack.

    The stack has ``n_conv`` unpadded Conv1d layers: the first with stride 1,
    the remaining ``n_conv - 1`` with ``stride``. Each maps a length ``L`` to
    ``(L - kernel_size) // s + 1``.

    Args:
        x: input sequence length.
        n_conv: number of convolutions (0 leaves the length unchanged).
        kernel_size: convolution kernel size.
        stride: stride of every convolution after the first.

    Returns:
        The output length (may be <= 0 if ``x`` is too short).
    """
    length = x
    for i in range(n_conv):
        s = 1 if i == 0 else stride
        length = (length - kernel_size) // s + 1
    return length

class FC_Encoder(nn.Module):
    """Fully connected encoder mapping a flattened ``N x q`` input to ``(mean, logvar)``.

    Hidden blocks are ``Linear -> ReLU -> Dropout``. Optional conditions are
    concatenated to the flattened input once, before the first block.

    Args:
        latent_dim: size of the latent space.
        N, q: input dimensions; the input must flatten to ``N * q`` features.
        encoder_hidden: hidden Linear widths.
        encoder_dropout: dropout after each hidden block (zipped with
            ``encoder_hidden``).
        n_conditions: size of the optional condition vector.

    Forward returns two ``(batch, latent_dim)`` tensors. ProtVAE treats the
    second as a log-variance.
    """
    def __init__(self, latent_dim, N, q = 21, encoder_hidden=(250, 250, 250),
               encoder_dropout=(0.7, 0., 0.), n_conditions=0):
        super(FC_Encoder, self).__init__()
        self.N, self.q = N, q
        self.n_conditions = n_conditions
        last_h = N*q + n_conditions
        layers = []
        for h, d in zip(encoder_hidden, encoder_dropout):
            layers += [nn.Linear(last_h, h), nn.ReLU(),
                       nn.Dropout(d)]
            last_h = h
        self.lays = nn.ModuleList(layers)
        self.mean, self.var = nn.Linear(last_h, latent_dim), nn.Linear(last_h, latent_dim)

    def forward(self, x, conditions = None):
        # reshape (not view) so non-contiguous inputs, e.g. permuted batches, are accepted.
        h = x.reshape(x.size(0), -1)
        if conditions is not None:
            h = torch.cat([h, conditions], 1)
        for lay in self.lays:
            h = lay(h)
        z_mean, z_logvar = self.mean(h), self.var(h)
        return z_mean, z_logvar

class Cond_MLP(nn.Module):
    """Small MLP embedding a condition vector: ``n_layers`` x (Linear -> PReLU), then Linear -> PReLU.

    Args:
        out_dim: embedding size.
        n_layers: number of hidden ``Linear -> PReLU`` blocks (0 is allowed).
        h: hidden width.
        n_conditions: input condition size.
    """
    def __init__(self, out_dim, n_layers = 2, h = 6, n_conditions = 3):
        super(Cond_MLP, self).__init__()
        self.n_conditions = n_conditions
        last_h = n_conditions
        layers = []
        for _ in range(n_layers):
            layers += [nn.Linear(last_h, h), nn.PReLU()]
            last_h = h
        self.lays = nn.ModuleList(layers)
        # Use the actual width reaching the output (n_conditions when n_layers == 0).
        self.out = nn.Linear(last_h, out_dim)
        self.act = nn.PReLU()

    def forward(self, conditions):
        h = conditions
        for lay in self.lays:
            h = lay(h)
        return self.act(self.out(h))

class CNN_Encoder(nn.Module):
    """Convolutional encoder mapping a ``(batch, q, N)`` input to latent ``(mean, logvar)``.

    ``n_conv`` unpadded blocks ``Conv1d -> BatchNorm/Identity -> PReLU -> dropout``.
    The first block has stride 1 and the rest stride 2. Channel widths are
    ``num_filters * 2**i`` capped at ``max_filters``. The flattened feature map,
    optionally concatenated with ``Cond_MLP(conditions)``, feeds two Linear heads.

    Args:
        latent_dim: size of the latent space.
        N: input sequence length; used to size the Linear heads.
        cond_concat_dim: width of the condition embedding (defaults to the
            flattened convolution feature size).
        q: number of input channels.
        num_filters, max_filters: width of the first block and the cap.
        kernel_size: convolution kernel size.
        BN: if True use BatchNorm1d and bias-free convolutions.
        dropout: module *instance* inserted after every PReLU; the same
            instance is reused in every block.
        n_conv: number of convolution blocks.
        n_conditions: size of the condition vector (0 disables conditioning).
        n_dense_cond: hidden width of ``Cond_MLP``.

    Raises:
        ValueError: if ``N`` is too short for ``n_conv`` convolutions.
    """
    def __init__(self, latent_dim, N, cond_concat_dim = None, q = 21, num_filters=21, kernel_size=2,
                BN = True, dropout = nn.Identity(), max_filters = 10000, n_conv = 5,
                n_conditions = 0, n_dense_cond = 6):
        super(CNN_Encoder, self).__init__()
        self.N, self.q = N, q
        self.n_conditions = n_conditions
        last_h, length = q, N
        layers = []
        for i in range(n_conv):
            h = min(num_filters * (2**i), max_filters)
            stride = 1 if i == 0 else 2
            layers += [nn.Conv1d(last_h, h, kernel_size, stride=stride, bias=not BN),
                       nn.BatchNorm1d(h) if BN else nn.Identity(),
                       nn.PReLU(),
                       dropout]
            last_h = h
            # Unpadded Conv1d output length: floor((L - kernel_size) / stride) + 1.
            length = (length - kernel_size) // stride + 1
        if length < 1:
            raise ValueError(f"sequence length N={N} is too short for {n_conv} convolutions "
                             f"with kernel_size={kernel_size}")
        self.lays = nn.ModuleList(layers)
        size_h = length * last_h
        cond_concat_dim = size_h if cond_concat_dim is None else cond_concat_dim
        if n_conditions > 0:
            self.cond_mlp = Cond_MLP(cond_concat_dim, n_conditions=n_conditions, h=n_dense_cond)

        head_in = size_h + cond_concat_dim if n_conditions > 0 else size_h
        self.mean = nn.Linear(head_in, latent_dim)
        self.var = nn.Linear(head_in, latent_dim)

    def forward(self, x, conditions = None):
        h = x
        for lay in self.lays:
            h = lay(h)
        h = h.view(h.size(0), -1)
        if conditions is not None:
            h_cond = self.cond_mlp(conditions)
            h = torch.cat([h, h_cond], 1)
        z_mean, z_logvar = self.mean(h), self.var(h)
        return z_mean, z_logvar

## VAE

In [18]:
class ProtVAE(nn.Module):
    """Variational autoencoder wrapping an encoder/decoder pair.

    ``encoder(x, conditions)`` must return ``(mu, logvar)``.
    ``decoder(x, z, conditions)`` must return probabilities of shape
    ``(batch, channels, N)`` whose first ``q`` channels are residues.

    Args:
        encoder, decoder: the two sub-networks.
        N: sequence length used when sampling from the prior.
        q: number of residue channels used when sampling from the prior.
        n_conditions: size of the optional condition vector (stored only).
        latent_dim: latent size, required by ``prior_sample``. If omitted it is
            read from ``encoder.mean.out_features`` when that exists (the
            encoders in this notebook expose their mean head as ``mean``).
    """
    def __init__(self, encoder, decoder,
                 N, q = 21, n_conditions = 0, latent_dim = None):
        super(ProtVAE, self).__init__()
        self.q, self.N = q, N
        self.encoder = encoder
        self.decoder = decoder
        self.n_conditions = n_conditions
        if latent_dim is None:
            latent_dim = getattr(getattr(encoder, "mean", None), "out_features", None)
        self.latent_dim = latent_dim

    def encode(self, x, conditions = None):
        """Return ``(mu, logvar)`` from the encoder."""
        return self.encoder(x, conditions)

    def decode(self, x, z, conditions = None):
        """Return decoder probabilities for input ``x`` and latent ``z``."""
        return self.decoder(x, z, conditions)

    def sample(self, mu, logvar):
        """Reparameterisation: ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x, conditions = None):
        """Encode, sample ``z`` and decode. Returns ``(probabilities, mu, logvar)``."""
        mu, logvar = self.encode(x, conditions)
        z = self.sample(mu, logvar)

        return self.decode(x, z, conditions), mu, logvar

    def load_weights(self, filename, map_location = None):
        """Load a state_dict written by ``save_weights``.

        ``map_location`` is forwarded to ``torch.load`` (e.g. ``"cpu"`` to load
        CUDA-trained weights on a CPU-only machine).
        """
        return self.load_state_dict(torch.load(filename, map_location=map_location))

    def save_weights(self, filename):
        """Save ``state_dict()`` with ``torch.save``, whatever the file extension."""
        return torch.save(self.state_dict(), filename)

    def prior_sample(self, n_samples = 1, mean = 0, stddev = 1,
                     conditions = None, batch_size = 5000, sample_func = None):
        """Draw ``z ~ N(mean, stddev**2)`` and decode it into one-hot sequences.

        Args:
            n_samples: number of sequences to generate.
            mean, stddev: parameters of the latent Gaussian.
            conditions: optional tensor with one condition row per sample.
            batch_size: maximum number of sequences decoded at once.
            sample_func: see ``_decode``; defaults to greedy argmax.

        Returns:
            Tensor of shape ``(n_samples, q, N)`` on the model's device.

        Raises:
            ValueError: if ``latent_dim`` is unknown.
        """
        if self.latent_dim is None:
            raise ValueError("latent_dim is unknown: pass latent_dim=... to ProtVAE")
        device = self._device()
        chunks = []
        for start in range(0, n_samples, batch_size):
            stop = min(start + batch_size, n_samples)
            z_sample = mean + stddev * torch.randn(stop - start, self.latent_dim, device=device)
            batch_cond = None if conditions is None else conditions[start:stop]
            chunks.append(self._decode(z_sample, sample_func=sample_func, conditions=batch_cond))
        if not chunks:
            return torch.zeros(0, self.q, self.N, device=device)
        return torch.cat(chunks, 0)

    def _decode(self, z, sample_func=None, conditions=None):
        """Generate sequences position by position from latent codes ``z``.

        Args:
            z: latent codes, shape ``(batch, latent_dim)``.
            sample_func: maps a ``(batch, q)`` probability matrix to ``(batch,)``
                residue indices; ``None`` means greedy argmax.
            conditions: optional condition rows, one per latent code.

        Returns:
            One-hot tensor of shape ``(batch, q, N)`` on ``z``'s device.

        Each step re-runs the decoder on the whole partial sequence, so
        generation costs ``N`` decoder calls. Call ``model.eval()`` first if the
        decoder uses dropout or batch norm.
        """
        n = z.size(0)
        x = torch.zeros(n, self.q, self.N, device=z.device)
        rows = torch.arange(n, device=z.device)
        with torch.no_grad():
            for i in range(self.N):
                # iteration is over positions in sequence, which can't be parallelized
                probs = self.decode(x, z, conditions)[:, :self.q, i]
                if sample_func is None:
                    pred_ind = probs.argmax(1)  # convert probability to index
                else:
                    pred_ind = sample_func(probs)
                pred_ind = torch.as_tensor(pred_ind, device=z.device).long().view(-1)
                x[rows, pred_ind, i] = 1
        return x

    def _device(self):
        """Device of the first parameter, or CPU if the model has none."""
        param = next(self.parameters(), None)
        return param.device if param is not None else torch.device("cpu")

In [43]:
class HMM_Dataset(object):
    """Map-style dataset of zero-padded per-position feature matrices loaded from a pickle.

    The pickle must hold a mapping whose values are 2-D arrays of shape
    ``(length, >= n_features)``. Each value is truncated to its first
    ``n_features`` columns and zero-padded to ``max_len`` rows.

    Args:
        file: path to the pickle file.
        max_len: padded number of positions.
        n_features: number of feature columns kept per position.

    Items are ``(features, length)`` with ``features`` of shape ``(max_len, n_features)``.

    Raises:
        ValueError: if an entry has more than ``max_len`` positions.
    """
    def __init__(self, file, max_len=64, n_features=20):
        # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
        with open(file, 'rb') as handle:
            data = pickle.load(handle)
        self.length = []
        self.primary = np.zeros((len(data), max_len, n_features))
        for i, v in enumerate(data.values()):
            if v.shape[0] > max_len:
                raise ValueError(f"entry {i} has {v.shape[0]} positions, more than max_len={max_len}")
            self.primary[i, :v.shape[0]] = v[:, :n_features]
            self.length.append(v.shape[0])

    def __len__(self):
        return len(self.primary)

    def __getitem__(self, i):
        return self.primary[i], self.length[i]


class HMM_Dataset2(HMM_Dataset):
    """Same loader with 68 padded positions and 21 features per position.

    NOTE(review): the original copied ``v[:, :40]`` into a 21-wide buffer. That
    only works when entries have exactly 21 columns. Slicing to ``n_features``
    keeps that case unchanged and avoids the shape error otherwise. Confirm
    whether 40 features were intended (pass ``n_features=40`` if so).
    """
    def __init__(self, file, max_len=68, n_features=21):
        super(HMM_Dataset2, self).__init__(file, max_len=max_len, n_features=n_features)

In [30]:
# Root directory of the datasets. Set the PROT_VAE_DATA environment variable to point
# elsewhere instead of editing this machine-specific default.
DATA = os.environ.get("PROT_VAE_DATA", "/home/cyril/Documents/These/data/")
DATASET = "PF00397"

In [45]:
# Training split: batches of 16, reshuffled every epoch, last partial batch kept.
train_dataset = HMM_Dataset(f"{DATA}/{DATASET}/hmm.pkl")
train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    drop_last=False,
)

# Validation split (currently disabled):
# val_dataset = HMM_Dataset(f"{DATA}/{DATASET}/val.pkl")
# val_loader = DataLoader(val_dataset, batch_size = 16,
#                         shuffle=True, drop_last=False)

In [68]:
# Seed torch's global RNG so weight initialisation is reproducible.
torch.manual_seed(0)
h, N, q = 16, 64, 20
# Fall back to CPU when no GPU is visible.
device = "cuda" if torch.cuda.is_available() else "cpu"
enc = CNN_Encoder(latent_dim = h, N = N, q = q)
dec = RecurrentSequenceDecoder(latent_dim = h, N = N, q = q, upsample = False)
# Pass q explicitly: ProtVAE defaults to q=21, but this data has 20 residue channels.
model = ProtVAE(enc, dec, N, q = q).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
model

ProtVAE(
  (encoder): CNN_Encoder(
    (lays): ModuleList(
      (0): Conv1d(20, 21, kernel_size=(2,), stride=(1,), bias=False)
      (1): BatchNorm1d(21, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): PReLU(num_parameters=1)
      (3): Identity()
      (4): Conv1d(21, 42, kernel_size=(2,), stride=(2,), bias=False)
      (5): BatchNorm1d(42, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): PReLU(num_parameters=1)
      (7): Identity()
      (8): Conv1d(42, 84, kernel_size=(2,), stride=(2,), bias=False)
      (9): BatchNorm1d(84, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (10): PReLU(num_parameters=1)
      (11): Identity()
      (12): Conv1d(84, 168, kernel_size=(2,), stride=(2,), bias=False)
      (13): BatchNorm1d(168, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (14): PReLU(num_parameters=1)
      (15): Identity()
      (16): Conv1d(168, 336, kernel_size=(2,), stride=(2,), bias=

In [65]:
# Position indices 0..63 as a (1, 64) tensor on `device`. Comparing them with
# per-sequence lengths gives a padding mask. 64 matches HMM_Dataset's padded length; keep in sync.
threser = (torch.arange(64).view(1,  -1)).to(device)

def aa_acc(x, recon_x, length):
    """Fraction of non-padding positions where the argmax residue of ``recon_x`` matches ``x``.

    Args:
        x: targets, shape ``(batch, channels, positions)`` (one-hot or probabilities).
        recon_x: predictions with the same shape as ``x``.
        length: ``(batch,)`` number of valid positions per sequence; positions
            ``>= length`` are ignored.

    Returns:
        Accuracy as a Python float; 0.0 when no position is valid.
    """
    # Build the mask from x itself instead of a fixed-length global, so any length/device works.
    positions = torch.arange(x.size(-1), device=x.device).view(1, -1)
    valid = positions < length.to(x.device).view(-1, 1)
    matches = x.argmax(1) == recon_x.argmax(1)
    n_valid = valid.sum().item()
    return (matches & valid).sum().item() / max(n_valid, 1)

def loss_function(x_r, x, mu, logvar):
    """VAE objective: summed binary cross-entropy plus the analytic Gaussian KL term.

    The KL divergence between N(mu, exp(logvar)) and N(0, I) follows Appendix B of
    Kingma & Welling, "Auto-Encoding Variational Bayes", ICLR 2014
    (https://arxiv.org/abs/1312.6114):  -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2).

    Args:
        x_r: predicted probabilities.
        x: targets with the same shape as ``x_r``.
        mu, logvar: encoder outputs.

    Returns:
        Scalar tensor ``BCE + KLD`` (both summed over the batch).
    """
    reconstruction = F.binary_cross_entropy(x_r, x, reduction='sum')
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu ** 2 - torch.exp(logvar))
    total = reconstruction + kl_divergence
    return total

def train(model, optimizer, epoch, loader=None):
    """Run one training epoch, printing running loss and residue accuracy.

    Args:
        model: module returning ``(x_r, mu, logvar)`` when called on a batch.
        optimizer: optimiser over ``model``'s parameters.
        epoch: epoch number, used only in the progress line.
        loader: iterable of ``(x, length)`` batches as produced by HMM_Dataset
            (``x`` is ``(batch, positions, channels)``, permuted here to
            ``(batch, channels, positions)``). Defaults to the global ``train_loader``.

    Returns:
        ``(mean_loss, mean_acc)`` over batches (``(0.0, 0.0)`` for an empty loader).
    """
    loader = train_loader if loader is None else loader
    device = next(model.parameters()).device  # move batches to wherever the model lives
    start = time.time()
    model.train()
    train_loss, train_acc, n_batches = 0., 0., 0
    for batch_idx, (x, length) in enumerate(loader):
        x = x.float().permute(0, 2, 1).to(device)
        length = length.int().to(device)
        optimizer.zero_grad()
        x_r, mu, logvar = model(x)
        # Score only the channels the decoder emits (20 with the decoder built above).
        target = x[:, :x_r.size(1)]
        loss = loss_function(x_r, target, mu, logvar)
        loss.backward()
        optimizer.step()
        # Display-only cap on the logged loss; gradients used the raw loss.
        train_loss += torch.clamp(loss, 0, 100000).item()
        train_acc += aa_acc(target, x_r, length)
        n_batches = batch_idx + 1
        print(f'''Train Epoch: {epoch} [{100. * batch_idx / len(loader):.0f}%]\tLoss: {train_loss/n_batches:.2f} || Time: {time.time()-start:.2f}s || Acc: {train_acc/n_batches:.2f}''', end="\r")

    n_batches = max(n_batches, 1)
    print(f'''Train Epoch: {epoch} [100%]\tLoss: {train_loss/n_batches:.2f} || Time: {time.time()-start:.2f}s || Acc: {train_acc/n_batches:.2f}''')
    return train_loss / n_batches, train_acc / n_batches
    
def validation(model, epoch, loader=None):
    """Evaluate the model on a validation loader without updating it.

    Args:
        model: module returning ``(x_r, mu, logvar)`` when called on a batch.
        epoch: epoch number, used only in the printed summary.
        loader: iterable of ``(x, length)`` batches shaped like the training
            ones. Defaults to the global ``val_loader``, which must then be defined.

    Returns:
        ``(mean_loss, mean_acc)`` over batches (``(0.0, 0.0)`` for an empty loader).
    """
    loader = val_loader if loader is None else loader
    device = next(model.parameters()).device
    start = time.time()
    model.eval()
    val_loss, val_acc, n_batches = 0., 0., 0
    with torch.no_grad():  # no gradients needed for evaluation
        for batch_idx, (x, length) in enumerate(loader):
            x = x.float().permute(0, 2, 1).to(device)
            length = length.int().to(device)
            x_r, mu, logvar = model(x)
            # Score only the channels the decoder emits, as in train().
            target = x[:, :x_r.size(1)]
            val_loss += loss_function(x_r, target, mu, logvar).item()
            val_acc += aa_acc(target, x_r, length)
            n_batches = batch_idx + 1
    n_batches = max(n_batches, 1)
    print(f'''Val Epoch: {epoch} \tLoss: {val_loss/n_batches:.2f} || Time: {time.time()-start:.2f}s || Acc: {val_acc/n_batches:.2f}''')
    return val_loss / n_batches, val_acc / n_batches

In [70]:
# Number of passes over the training set.
N_EPOCHS = 200
for epoch in range(N_EPOCHS):
    train(model, optimizer, epoch)
    # validation(model, epoch)

Train Epoch: 0 [100%]	Loss: 988.06 || Time: 17.18s || Acc: 0.84
Train Epoch: 1 [100%]	Loss: 988.37 || Time: 19.35s || Acc: 0.84
Train Epoch: 2 [100%]	Loss: 989.17 || Time: 18.69s || Acc: 0.84
Train Epoch: 3 [100%]	Loss: 993.09 || Time: 19.39s || Acc: 0.84
Train Epoch: 4 [8%]	Loss: 975.27 || Time: 1.63s || Acc: 0.84

KeyboardInterrupt: 

In [36]:
# Make sure the target directory exists before saving.
# NOTE: save_weights uses torch.save, so this file is a torch state_dict, not HDF5,
# despite the .h5 extension; the name is kept so existing loaders still find it.
os.makedirs(f"{DATA}/{DATASET}/weights", exist_ok=True)
model.save_weights(f"{DATA}/{DATASET}/weights/alignements.h5")