In [1]:
import warnings
warnings.filterwarnings("ignore")

In [2]:
import numpy as np
import time

import torch
from torch.utils.data import DataLoader
from torch import nn, optim
from torch.nn import functional as F

from pgm.layers import MarkovChainLayer, OneHotLayer, Layer
from pgm.data import Seq_SS_Data
from pgm.ebm import EnergyModel
from pgm.nn import ConvNet, ConvBlock
from pgm.metrics import hinge_loss, aa_acc
from pgm.utils import I

## Data

In [3]:
DATA = "/home/cyril/Documents/These/data/secondary_structure"
batch_size = 32

In [4]:
train_dataset = Seq_SS_Data(f"{DATA}/secondary_structure_train.json", size = 512)
train_loader = DataLoader(train_dataset, batch_size = batch_size, 
                          shuffle = True, drop_last=True)

val_dataset = Seq_SS_Data(f"{DATA}/secondary_structure_valid.json", size = 512)
val_loader = DataLoader(val_dataset, batch_size = batch_size, 
                        shuffle = True, drop_last=True)

In [5]:
t = np.zeros((3, 3))
for seq, length in zip(train_dataset.ss3, train_dataset.length):
    seq = np.argmax(seq, -1)
    x = seq[0]
    y = seq[0]
    for i in range(2, length):
        x, y = y, seq[i]
        t[x,y] += 1
t /= np.sum(t, 1)
t = torch.tensor(t).float()
t = (t+t.t())/2

## EM

In [6]:
from torch.distributions.normal import Normal

In [7]:
def leaky_relu(): 
    return nn.LeakyReLU(0.2, inplace = False)

class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size = 3, bias=True):
        super(ResBlock, self).__init__()
        pad = (kernel_size-1)//2
        self.conv_1 = ConvBlock(nn.Conv1d, leaky_relu, nn.BatchNorm1d, 
                                in_channels, out_channels, kernel_size,
                                stride=1, padding=pad, bias=bias)
        self.conv2 = ConvBlock(nn.Conv1d, None, nn.BatchNorm1d, 
                                in_channels, out_channels, 1,
                                stride=1, padding=0, bias=bias)

    def forward(self, x):
        identity = x
        out = self.conv_1(x)
        identity = self.conv2(x)
        out += identity
        return out
    
class ResBlock2(nn.Module):
    def __init__(self, in_channels, out_channels, bias=True):
        super(ResBlock2, self).__init__()
        self.conv_11 = ConvBlock(nn.Conv1d, leaky_relu, nn.BatchNorm1d, 
                                in_channels, out_channels, 3,
                                stride=1, padding=1, bias=bias)
        self.conv_12 = ConvBlock(nn.Conv1d, leaky_relu, nn.BatchNorm1d, 
                                out_channels, out_channels, 3,
                                stride=2, padding=1, bias=bias)
        self.conv2 = ConvBlock(nn.Conv1d, None, None, 
                                in_channels, out_channels, 3,
                                stride=2, padding=1, bias=bias)

    def forward(self, x):
        identity = x
        out = self.conv_11(x)
        out = self.conv_12(out)
        identity = self.conv2(x)
        out += identity
        return out

    
class ConvNet(nn.Module):
    def __init__(self, in_channels, out_channels = 100, bias=True):
        super(ConvNet, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.conv1 = ResBlock(in_channels, 100, 11)
        self.conv2 = ResBlock(100, 100, 11)
        self.conv3 = ResBlock(100, 100, 11)
        self.conv4 = ResBlock(100, 100, 11)
#         self.conv5 = ResBlock(100, 100, 11)
        self.conv5 = ConvBlock(nn.Conv1d, nn.PReLU, nn.BatchNorm1d,
                        100, out_channels, 1,
                        stride=1, padding=0, dilation=1)
    def forward(self, x):
        h = self.conv1(x)
        h = self.conv2(h)
#         h = self.conv3(h)
#         h = self.conv4(h)
        h = self.conv5(h)
        return h
    
class ConvNet2(nn.Module):
    def __init__(self, in_channels, out_channels = 100, bias=True):
        super(ConvNet2, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        
#         self.att1 = SelfAttention(in_channels)
        self.l_conv1 = ResBlock(in_channels, 100, 5)
        self.l_conv2 = ResBlock(100, out_channels, 5)
        self.l_pool1 = nn.MaxPool1d(2) # 256

        self.m_conv1 = ResBlock2(out_channels, out_channels // 2, 5)
        self.m_pool1 = nn.MaxPool1d(2) # 64
        
        self.g_conv1 = ResBlock2(out_channels, out_channels, 5)
        self.g_pool1 = nn.MaxPool1d(2) # 64
        self.g_conv2 = ResBlock2(out_channels, out_channels // 2, 5)
        self.g_pool2 = nn.MaxPool1d(2) # 16
        self.g_conv3 = ResBlock2(out_channels // 2, out_channels // 2, 5)
        self.g_pool3 = nn.MaxPool1d(2) # 4


        self.g_dense = nn.Linear(out_channels * 2, out_channels//2)

    def forward(self, x):
        batch_size, _, N = x.size()
        h_l = self.l_conv1(x)
        h_l = self.l_conv2(h_l)
        h = self.l_pool1(h_l)
        
        h_m = self.m_pool1(self.m_conv1(h))
        h_m = torch.cat([h_m[:,:,i:i+1].expand(-1, -1, N//h_m.size(-1)) for i in range(h_m.size(-1))], -1)

        h_g = self.g_pool1(self.g_conv1(h))
        h_g = self.g_pool2(self.g_conv2(h_g))
        h_g = self.g_pool3(self.g_conv3(h_g)).view(batch_size, -1)
        h_g = self.g_dense(h_g).view(batch_size, -1, 1)
        h_g = h_g.expand(-1, -1, N)

        return torch.cat([h_l, h_m, h_g], 1)
    
class ConvNet3(nn.Module):
    def __init__(self, in_channels, out_channels = 100, N = 128, bias=True):
        super(ConvNet3, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.conv1 = ResBlock(in_channels, 100, 11)
        self.conv2 = ResBlock(100, 100, 11)
        self.lstm = nn.LSTM(input_size = 100, 
                            hidden_size = out_channels, 
                            num_layers = 2,
                            bias = True,
                            bidirectional = True)
        
    def forward(self, x):
        h = self.conv1(x)
        h = self.conv2(h)
        h = h.permute(2, 0, 1)
        h = self.lstm(h)[0]
        return h.permute(1, 2, 0)


In [8]:
class HMMLayer(Layer):
    r"""
    Layer of One Hot neurons linked by a Markov Chain

    Args:
        T (Numpy Array): Transition Matrix for the Markov Chain
        N (Integer): Number of neurons
        q (Integer): Number of values the neuron can take
        name (String): Name of the layer
    """
    def __init__(self, T, N = 100, q = 21, h = 100, name = "layer0"):
        super(HMMLayer, self).__init__(name)
        self.full_name = f"MC_{name}"
        self.N = N
        self.q = q
        self.h = h
        self.T = nn.Parameter(T.float(), requires_grad=True)
        self.mu = nn.Parameter(torch.rand(self.q, self.h)/self.h, requires_grad=True)
        self.sig = nn.Parameter(torch.rand(self.q, self.h, self.h)/self.h, requires_grad=True)
        
    def E(self, x, l):
        tau = taus(x, self.T, self.mu, self.sig, l)
        return tau
    
    def max_likelihood(self, x):
        return calcul_tau(x, self.T, self.mu, max_op)

In [9]:
class EnergyModel(nn.Module):
    def __init__(self, xlay, ylay, Dx = I, Dy = I):
        super(EnergyModel, self).__init__()
        self.Dx, self.Dy = Dx, Dy
        self.xlay, self.ylay = xlay, ylay
                
    def energy(self, x):
        x = x.detach()
        Dx = self.Dx(x)
        e = self.ylay.E(Dx)
        return e
    
    def forward(self, x, l):
        x = x.detach()
        Dx = self.Dx(x)
        e = self.ylay.E(Dx, l)
        return e
    
    def ypredict(self, x):
        x = x.detach()
        Dx = self.Dx(x)
        e = self.ylay.max_likelihood(Dx)
        return e

In [10]:
import math

def log_pdf(x, mu, sigma = None):
    N = mu.size(0)
    y = x - mu
    Sy = y
    sigma = (1/N) * torch.eye(N)
    if sigma is not None:
        Sy = sigma.mv(x)
    logZ = (N/2)*math.log(2 * math.pi/N)
    return -0.5*(y * Sy).sum() - logZ

In [11]:
sum_op = lambda x : torch.sum(x, 0)
max_op = lambda x : torch.max(x, dim = 0)[0]

def alphas(Dx, T, mu, sigma, operator = sum_op):
    _, N = Dx.size()
    k = mu.size(0)
    log_alphas = torch.zeros(k, N)
    log_norm = torch.zeros(N)
    
    log_alphas[:,0] = torch.cat([log_pdf(Dx[:,0], mu[j]).view(-1,1) for j in range(k)],1)
    log_norm[0] = torch.log((torch.exp(log_alphas[:,0])).sum())
    for s in range(1,N):
        for j in range(k):
            log_alphas[j,s] = log_norm[s-1] + log_pdf(Dx[:,s], mu[j]) 
            log_alphas[j,s] += torch.logsumexp(log_alphas[:,s-1] - log_norm[s-1] + torch.log(T[:,j]), 0)
        log_norm[s] = log_norm[s-1] + torch.log(torch.exp(log_alphas[:,s]-log_norm[s-1]).sum()) 
    return(log_alphas, log_norm)

def betas(Dx, T, mu, sigma, operator = sum_op):
    q, N = Dx.size()
    k = mu.size(0)
    log_betas = torch.zeros(k, N)
    log_norm = torch.zeros(N)
    
    log_norm[-1] = torch.log(torch.exp(log_betas[:, -1]).sum())
    for s in range(2,N+1):
        for j in range(k):
            log_betas[j,-s] = log_norm[-s+1] + torch.logsumexp(
                    torch.tensor([log_pdf(Dx[:,-s+1], mu[i]) + log_betas[i,-s+1] - log_norm[-s+1] + torch.log(T[j,i]) for i in range(k)]), 0)
        log_norm[-s] = log_norm[-s+1] + torch.log(torch.exp(log_betas[:,-s] - log_norm[-s+1]).sum())

    return(log_betas, log_norm)

def taus(Dx, T, mu, sigma, lengths, operator = sum_op):
    batch_size, _, N = Dx.size()
    tau = torch.zeros(batch_size, mu.size(0), N)
    for i, l in enumerate(lengths):
        log_alp, lnorma = alphas(Dx[i,:,:l], T, mu, sigma, operator)
        log_bet, lnormb = betas(Dx[i,:,:l], T, mu, sigma, operator)
        tau[i,:, :l] = log_alp - lnorma.view(1, -1) + log_bet - lnormb.view(1, -1) 
#     tau /= tau.sum(-1).view(batch_size, tau.size(1), 1)

#     for s in range(N):
#         alp_norm[:,:, s] = torch.exp(log_alp[:,:,s] - lnorma[:, s].view(-1, 1))
#         beta_norm[:,:, s] = torch.exp(log_bet[:,:,s] - lnormb[:, s].view(-1, 1))
#     tau2 = torch.zeros((batch_size, k, k, N-1))
    
#     for q1 in range(k):
#         for q2 in range(k):
#             tau2[:,q1,q2,:] = alp_norm[:,q1,:]*bet_norm[:,q2,:]
#     for s2 in range(k):
#         for s in range(N-1):
#             tau2[:,:,:,:]*=log_pdf(Dx[:,:,s+1], mu[s2], sigma[s2])
#     tau2 /= tau2.sum(-1).view(batch_size, k, k, 1)
    
    return tau

## Train

In [12]:
N, qx, qs, h = 512, 21, 3, 1
# del model
device = torch.device('cpu')


x = OneHotLayer(torch.zeros(qx*N), N = N, q = qx, name = "x")
s = HMMLayer(t, N = N, q = qs, h = h, name = "ss")

Dx = ConvNet(qx, h)
Ds = I
model = EnergyModel(x, s, Dx, Ds)

optimizer = optim.Adam(model.parameters(), lr=0.01)
model

EnergyModel(
  (Dx): ConvNet(
    (conv1): ResBlock(
      (conv_1): ConvBlock(
        (conv): Conv1d(21, 100, kernel_size=(11,), stride=(1,), padding=(5,))
        (activation): LeakyReLU(negative_slope=0.2)
        (normalization): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv2): ConvBlock(
        (conv): Conv1d(21, 100, kernel_size=(1,), stride=(1,))
        (normalization): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (conv2): ResBlock(
      (conv_1): ConvBlock(
        (conv): Conv1d(100, 100, kernel_size=(11,), stride=(1,), padding=(5,))
        (activation): LeakyReLU(negative_slope=0.2)
        (normalization): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv2): ConvBlock(
        (conv): Conv1d(100, 100, kernel_size=(1,), stride=(1,))
        (normalization): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, tr

In [13]:
from sklearn.metrics import confusion_matrix

def hinge_loss(model, x, y, l, m = 1):
    e = -model(x, l)
    e_bar = torch.min(e+y*1e9, 1, keepdim=True).values.view(e.size(0), 1,                                              e.size(-1))
    loss = F.relu(m+(e-e_bar)*y)
    return loss.sum()/(l.sum()), e


def train(epoch):
    mean_loss, mean_reg, mean_acc = 0, 0, 0
    for batch_idx, data in enumerate(train_loader):
        x = data[0].float().permute(0, 2, 1).to(device)
        s = data[1].float().permute(0, 2, 1).to(device)
        l = data[2].int().to(device)
        length = data[2].int().to(device)
        # Optimization
        optimizer.zero_grad()
        loss, pred = hinge_loss(model, x, s, l)
        loss.backward()
        optimizer.step()
        acc = aa_acc(s, pred)

        del x; del s
        # Metrics
        mean_loss = (mean_loss*batch_idx + loss.item())/ (batch_idx+1)
        mean_acc = acc
        m, s = int(time.time()-start)//60, int(time.time()-start)%60
        print(f'''Train Epoch: {epoch} [{int(100*batch_idx/len(train_loader))}%] || Time: {m} min {s} || Loss: {mean_loss:.3f} || Acc: {mean_acc:.3f}''', end="\r")
        
    
def val(epoch):
    mean_loss, mean_reg, mean_acc = 0, 0, 0
    cm = np.zeros((3,3))
    for batch_idx, data in enumerate(val_loader):
        x = data[0].float().permute(0, 2, 1).to(device)
        s = data[1].float().permute(0, 2, 1).to(device)
        l = data[2].int().to(device)

        
        # Optimization
        loss, pred = hinge_loss(model, x, s, l)
        acc = aa_acc(s, pred)
        
        cm += confusion_matrix(s.argmax(1).view(-1), 
                         model(x).argmax(1).view(-1))
        # Metrics
        mean_loss = (mean_loss*batch_idx + loss.item())/ (batch_idx+1)
        mean_acc = (mean_acc*batch_idx + acc)/ (batch_idx+1)
        

        m, s = int(time.time()-start)//60, int(time.time()-start)%60
        print(f'''Val: {epoch} [{int(100*batch_idx/len(val_loader))}%] || Time: {m} min {s} || Loss: {mean_loss:.3f} ''', end="\r")
    
    print(f'''Val: {epoch} [100%] || Time: {m} min {s} || Loss: {mean_loss:.3f} || Acc: {mean_acc:.3f}           ''')
    cm = (np.array(cm.T, dtype=np.float)/np.sum(cm, 1)).T
    print(cm)

In [14]:
start = time.time()
model.train()
for epoch in range(30):
    train(epoch)
    if not epoch%1:
        val(epoch)

Train Epoch: 0 [0%] || Time: 0 min 51 || Loss: nan || Acc: 0.39171

KeyboardInterrupt: 

| Model  | Train Acc  | Val Acc |
|---|---|---|
| HMM  | 0.57  | 0.54  |
| HMM + 5-Conv  | 0.971  | 0.70 |
|   |   |   |