In [1]:
import warnings
warnings.filterwarnings("ignore")

In [2]:
import numpy as np
import time
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch.utils.data import DataLoader
from torch import nn, optim
from torch.nn import functional as F

from pgm.layers import MarkovChainLayer, OneHotLayer, Layer
from pgm.data import Seq_SS_Data
from pgm.ebm import EnergyModel
from pgm.nn import ConvNet, ConvBlock
from pgm.metrics import hinge_loss, aa_acc
from pgm.utils import I

## Data

In [3]:
DATA = "/home/cyril/Documents/These/data/secondary_structure"
batch_size = 16

In [4]:
aa = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K','L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y']
keys = [4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 21, 22, 24, 25, 27]
aa_dict = {k:i for i,k in enumerate(keys)}

AAINDEX = "/home/cyril/Documents/These/data/aaindex"
df_index = pd.read_csv(f"{AAINDEX}/aa_index.csv", index_col = 0)
df_index[aa] = ((df_index[aa].values.T - df_index[aa].mean(1).values)/df_index[aa].std(1).values).T
df_index = df_index.fillna(0)
AA_MAT = df_index[aa].values.T

In [5]:
def np_onehot(a, shape):
    onehot = np.zeros(shape)
    onehot[np.arange(len(a)), a] = 1
    return onehot

def to_aaindex(a, size):
    aaindex = np.zeros((size, AA_MAT.shape[1]))
    for i, x in enumerate(a):
        if x == 26:
            aaindex[i] = np.zeros(AA_MAT.shape[1])
        else:
            aaindex[i] = AA_MAT[aa_dict[x]]
    return aaindex

class Seq_SS_Data(object):
    def __init__(self, file, size = 512):
        self.primary, self.ss3 = [], []
        df = pd.read_json(file)
        df["length"] = df.primary.apply(lambda x : len(x))
        df = df[df.length <= size]
        
        self.length = list(df.length)
        self.primary = list(df.primary.apply(lambda d : np.array(d)))
        self.ss3 = list(df.ss3.apply(lambda d : np_onehot(d, (size, 3))))
        del df
        self.ss3 = np.array(self.ss3)
            
    def __len__(self):
        return len(self.primary)
    
    def __getitem__(self, i):
        if self.length[i] <= 128:
            return to_aaindex(self.primary[i], 128), self.ss3[i, :128], self.length[i]
        cursor = np.random.randint(0, self.length[i]-128)
        return to_aaindex(self.primary[i][cursor:cursor+128], 128), self.ss3[i, cursor:cursor+128], self.length[i]

In [6]:
train_dataset = Seq_SS_Data(f"{DATA}/secondary_structure_train.json", size = 512)
train_loader = DataLoader(train_dataset, batch_size = batch_size, 
                        shuffle = True, drop_last=True)

In [7]:
val_dataset = Seq_SS_Data(f"{DATA}/secondary_structure_valid.json", size = 512)
val_loader = DataLoader(val_dataset, batch_size = batch_size, 
                        shuffle = True, drop_last=True)

In [8]:
t = np.zeros((3, 3))
for seq, length in zip(train_dataset.ss3, train_dataset.length):
    seq = np.argmax(seq, -1)
    x = seq[0]
    y = seq[0]
    for i in range(1, length):
        x, y = y, seq[i]
        t[x,y] += 1
t /= np.sum(t, 1)
t = torch.tensor(t).float()
t = (t+t.t())/2

## EM

In [9]:
from torch.distributions.normal import Normal

In [61]:
def leaky_relu(): 
    return nn.LeakyReLU(0.2, inplace = False)

class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size = 3, bias=True):
        super(ResBlock, self).__init__()
        pad = (kernel_size-1)//2
        self.conv_1 = ConvBlock(nn.Conv1d, leaky_relu, nn.BatchNorm1d, 
                                in_channels, out_channels, kernel_size,
                                stride=1, padding=pad, bias=bias)
        self.conv2 = ConvBlock(nn.Conv1d, None, nn.BatchNorm1d, 
                                in_channels, out_channels, 1,
                                stride=1, padding=0, bias=bias)

    def forward(self, x):
        identity = x
        out = self.conv_1(x)
        identity = self.conv2(x)
        out += identity
        return out
    
class ResBlock2(nn.Module):
    def __init__(self, in_channels, out_channels, bias=True):
        super(ResBlock2, self).__init__()
        self.conv_11 = ConvBlock(nn.Conv1d, leaky_relu, nn.BatchNorm1d, 
                                in_channels, out_channels, 3,
                                stride=1, padding=1, bias=bias)
        self.conv_12 = ConvBlock(nn.Conv1d, leaky_relu, nn.BatchNorm1d, 
                                out_channels, out_channels, 3,
                                stride=2, padding=1, bias=bias)
        self.conv2 = ConvBlock(nn.Conv1d, None, None, 
                                in_channels, out_channels, 3,
                                stride=2, padding=1, bias=bias)

    def forward(self, x):
        identity = x
        out = self.conv_11(x)
        out = self.conv_12(out)
        identity = self.conv2(x)
        out += identity
        return out

    
class ConvNet(nn.Module):
    def __init__(self, in_channels, out_channels = 100, bias=True):
        super(ConvNet, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.conv1 = ResBlock(in_channels, 100, 11)
        self.conv2 = ResBlock(100, 100, 11)
        self.conv3 = ResBlock(100, 100, 11)
        self.conv4 = ResBlock(100, 100, 11)
#         self.conv5 = ResBlock(100, 100, 11)
        self.conv5 = ConvBlock(nn.Conv1d, nn.PReLU, nn.BatchNorm1d,
                        100, out_channels, 1,
                        stride=1, padding=0, dilation=1)
    def forward(self, x):
        h = self.conv1(x)
        h = self.conv2(h)
        h = self.conv3(h)
        h = self.conv4(h)
        h = self.conv5(h)
        return h
    
class ConvNet2(nn.Module):
    def __init__(self, in_channels, out_channels = 100, bias=True):
        super(ConvNet2, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        
#         self.att1 = SelfAttention(in_channels)
        self.l_conv1 = ResBlock(in_channels, 100, 5)
        self.l_conv2 = ResBlock(100, out_channels, 5)
        self.l_pool1 = nn.MaxPool1d(2) # 64

        self.m_conv1 = ResBlock2(out_channels, out_channels // 2, 5)
        self.m_pool1 = nn.MaxPool1d(2) # 16
        
        self.g_conv1 = ResBlock2(out_channels, out_channels, 5)
        self.g_pool1 = nn.MaxPool1d(2) # 16
        self.g_conv2 = ResBlock2(out_channels, out_channels // 2, 5)
        self.g_pool2 = nn.MaxPool1d(2) # 4

        self.g_dense = nn.Linear(out_channels * 2, out_channels)

    def forward(self, x):
        batch_size, _, N = x.size()
        h_l = self.l_conv1(x)
        h_l = self.l_conv2(h_l)
        h = self.l_pool1(h_l)
        
        h_m = self.m_pool1(self.m_conv1(h))
        h_m = torch.cat([h_m[:,:,i:i+1].expand(-1, -1, N//h_m.size(-1)) for i in range(h_m.size(-1))], -1)

        h_g = self.g_pool1(self.g_conv1(h))
        h_g = self.g_pool2(self.g_conv2(h_g)).view(batch_size, -1)
        h_g = self.g_dense(h_g).view(batch_size, -1, 1)
        h_g = h_g.expand(-1, -1, N)

        return F.softmax(torch.cat([h_l, h_m, h_g], 1)
    
class ConvNet3(nn.Module):
    def __init__(self, in_channels, out_channels = 100, N = 128, bias=True):
        super(ConvNet3, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.conv1 = ResBlock(in_channels, 100, 11)
        self.conv2 = ResBlock(100, 100, 11)
        self.lstm1 = nn.LSTM(input_size = 100, 
                            hidden_size = 50, 
                            num_layers = 2,
                            bias = True,
                            bidirectional = True)
        self.lstm2 = nn.LSTM(input_size = 100, 
                    hidden_size = 50, 
                    num_layers = 2,
                    bias = True,
                    bidirectional = True)

        self.linear = nn.Conv1d(100, out_channels, 1)
        
    def forward(self, x):
        h = self.conv1(x)
        h = self.conv2(h)
        h = h.permute(2, 0, 1)
        h = self.lstm1(h)[0]
        h = self.lstm2(h)[0]
        return F.softmax(self.linear(h.permute(1, 2, 0)),1)

    
class AttentionNet(nn.Module):
    def __init__(self, in_channels, out_channels = 100, N = 128, bias=True):
        super(ConvNet3, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.conv1 = ResBlock(in_channels, 100, 11)
        self.conv2 = ResBlock(100, 100, 11)
        self.lstm1 = nn.LSTM(input_size = 100, 
                            hidden_size = 50, 
                            num_layers = 2,
                            bias = True,
                            bidirectional = True)
        self.lstm2 = nn.LSTM(input_size = 100, 
                    hidden_size = 50, 
                    num_layers = 2,
                    bias = True,
                    bidirectional = True)

        self.linear = nn.Conv1d(100, out_channels, 1)
        
    def forward(self, x):
        h = self.conv1(x)
        h = self.conv2(h)
        h = h.permute(2, 0, 1)
        h = self.lstm1(h)[0]
        h = self.lstm2(h)[0]
        return F.softmax(self.linear(h.permute(1, 2, 0)),1)


In [15]:
class HMMLayer(Layer):
    r"""
    Layer of One Hot neurons linked by a Markov Chain

    Args:
        T (Numpy Array): Transition Matrix for the Markov Chain
        N (Integer): Number of neurons
        q (Integer): Number of values the neuron can take
        name (String): Name of the layer
    """
    def __init__(self, T, N = 100, q = 21, h = 100, name = "layer0"):
        super(HMMLayer, self).__init__(name)
        self.full_name = f"MC_{name}"
        self.N = N
        self.q = q
        self.h = h
        self.T = T.float()
        self.mu = nn.Parameter(torch.rand(self.q, self.h).float()/self.h, requires_grad=True)
        self.sig = nn.Parameter(torch.rand(self.q, self.h, self.h).float()/self.h, requires_grad=True)
        
    def E(self, x):
        tau = taus(x, self.T, self.mu, None)
#         mu = torch.cat([((tau[:,s:s+1,:] * Dx).sum(-1)/tau.sum(-1)).mean(0).view(1, -1) for s in range(self.q)],0)
#         sig = torch.cat([((tau[:,s:s+1,:] * Dx).sum(-1)/tau.sum(-1)).mean(0).view(1, -1) for s in range(self.q)],0)


        return tau
    
    def max_likelihood(self, x):
        return calcul_tau(x, self.T, self.mu, max_op)

In [12]:
import math

def log_pdf(x, mu, sigma = None):
    N = mu.size(0)
    y = x - mu
    Sy = y
    if sigma is not None:
        Sy = F.linear(x, sigma)
    logZ = (N/2)*math.log(2 * math.pi)
    return -0.5*(y * Sy).sum(1) - logZ

In [13]:
sum_op = lambda x : torch.sum(x, 1)
max_op = lambda x : torch.max(x, dim = 1)[0]

def alphas(Dx, T, mu, sigma, operator = sum_op):
    batch_size, _, N = Dx.size()
    k = mu.size(0)
    log_alphas = torch.zeros(batch_size, k, N)
    log_norm = torch.zeros(batch_size, N)
    
    log_alphas[:,:,0] = torch.cat([log_pdf(Dx[:,:,0], mu[j]).view(-1,1) for j in range(k)],1)
    log_norm[:,0] = torch.log((torch.exp(log_alphas[:,:,0])).sum(1))
    for s in range(1,N):
        for j in range(k):
            log_alphas[:,j,s] = log_norm[:, s-1] + log_pdf(Dx[:,:,s], mu[j]) 
            log_alphas[:,j,s] += torch.log(operator(torch.cat([torch.exp((log_alphas[:,i,s-1] - log_norm[:,s-1] + torch.log(T[i,j])).view(-1,1)) for i in range(k)],1)))
        log_norm[:,s] = log_norm[:,s-1] + torch.log(torch.exp(log_alphas[:,:,s]-log_norm[:,s-1].view(-1, 1)).sum(1)) 
    return(log_alphas, log_norm)

def betas(Dx, T, mu, sigma, operator = sum_op):
    batch_size, q, N = Dx.size()
    k = mu.size(0)
    log_betas = torch.zeros(batch_size, k, N)
    log_norm = torch.zeros(batch_size, N)
    
    log_norm[:, -1] = torch.log(torch.exp(log_betas[:, :, -1]).sum(1))
    for s in range(2,N+1):
        for j in range(k):
            log_betas[:,j,-s] = log_norm[:,-s+1] + torch.log(operator(torch.cat(
                    [torch.exp(
                        log_pdf(Dx[:,:,-s+1], mu[i]) + log_betas[:,i,-s+1] - log_norm[:,-s+1] + torch.log(T[j,i])).view(-1,1) for i in range(k)],1)))
        log_norm[:,-s] = log_norm[:,-s+1] + torch.log(torch.exp(log_betas[:,:,-s] - log_norm[:,-s+1].view(-1, 1)).sum(1))

    return(log_betas, log_norm)

def taus(Dx, T, mu, sigma, operator = sum_op):
    batch_size, _, N = Dx.size()        
    log_alp, lnorma = alphas(Dx, T, mu, sigma, operator)
    log_bet, lnormb = betas(Dx, T, mu, sigma, operator)
    tau = log_alp - lnorma.view(batch_size, 1, N) + log_bet - lnormb.view(batch_size, 1, N)

    return tau

## Train

In [19]:
N, qx, qs, h = 128, 77, 3, 8
# del model
device = torch.device('cpu')

x = OneHotLayer(torch.zeros(qx*N), N = N, q = qx, name = "x")
s = HMMLayer(t, N = N, q = qs, h = h, name = "ss")

Dx = ConvNet(qx, qs)
Ds = I
# model = EnergyModel(x, s, Dx, Ds).float()
model = Dx
# model = nn.Transformer(d_model = 77, 
#                        nhead = 7, 
#                        dim_feedforward = 256)

optimizer = optim.Adam(model.parameters(), lr=0.1)
model

ConvNet(
  (conv1): ResBlock(
    (conv_1): ConvBlock(
      (conv): Conv1d(77, 100, kernel_size=(11,), stride=(1,), padding=(5,))
      (activation): LeakyReLU(negative_slope=0.2)
      (normalization): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (conv2): ConvBlock(
      (conv): Conv1d(77, 100, kernel_size=(1,), stride=(1,))
      (normalization): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (conv2): ResBlock(
    (conv_1): ConvBlock(
      (conv): Conv1d(100, 100, kernel_size=(11,), stride=(1,), padding=(5,))
      (activation): LeakyReLU(negative_slope=0.2)
      (normalization): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (conv2): ConvBlock(
      (conv): Conv1d(100, 100, kernel_size=(1,), stride=(1,))
      (normalization): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (conv3): ResBlock(
    (c

In [20]:
###### TRANSFORMER #######

from sklearn.metrics import confusion_matrix

AA_TENS = torch.tensor(AA_MAT).view(1, 1, 20, 77)
MASK = model.generate_square_subsequent_mask(128)

def to_seq(x):
    return ((x.permute(1, 2, 0).reshape(16, 128, 1, 77).expand(16, 128, 20, 77) - AA_TENS)**2).sum(-1).argmin(-1)

def hinge_loss(model, x, y, m = 1):
    e = -model(x)
    e_bar = torch.min(e+y*1e9, 1, keepdim=True).values.view(e.size(0), 1, 
                                                                 e.size(-1))
    loss = F.relu(m+(e-e_bar)*y)[:,:,10:-10]
    return loss.sum()/(e.size(0))


def aa_acc(x, recon_x):
    r"""
    Evaluate the ratio of amino acids retrieved in the reconstructed sequences

    Args:
        x (torch.Tensor): true sequence(s)
        recon_x (torch.Tensor): reconstructed sequence(s)
    """
    x = x[:,10:-10]
    recon_x = recon_x[:,10:-10]
    return (x==recon_x).int().float().mean().item()


def train(epoch):
    mean_loss, mean_reg, mean_acc = 0, 0, 0
    model.train()
    for batch_idx, data in enumerate(train_loader):
        x = data[0].float().permute(1, 0, 2).to(device)
        s = data[1].float().permute(1, 0, 2).to(device)
        length = data[2].int().to(device)
        # Optimization
        optimizer.zero_grad()
        p = model(x, x, memory_mask = MASK)
        loss = F.mse_loss(p, x)
        acc = aa_acc(to_seq(x), to_seq(p))
        loss.backward()
        optimizer.step()

        del x; del s; del p
        # Metrics
        mean_loss = (mean_loss*batch_idx + loss.item())/ (batch_idx+1)
        mean_acc = (mean_acc*batch_idx + acc)/ (batch_idx+1)
        m, s = int(time.time()-start)//60, int(time.time()-start)%60
        print(f'''Train Epoch: {epoch} [{int(100*batch_idx/len(train_loader))}%] || Time: {m} min {s} || Loss: {mean_loss:.3f} || Acc: {mean_acc:.3f}''', end="\r")
    
def val(epoch):
    mean_loss, mean_reg, mean_acc = 0, 0, 0
    model.eval()
    cm = np.zeros((3,3))
    for batch_idx, data in enumerate(val_loader):
        x = data[0].float().permute(1, 0, 2).to(device)
        s = data[1].float().permute(1, 0, 2).to(device)
        p = model(x, x, memory_mask = MASK)

        # Optimization
        loss = F.mse_loss(p, x)
        acc = aa_acc(to_seq(x), to_seq(p))

        # Metrics
        mean_loss = (mean_loss*batch_idx + loss.item())/ (batch_idx+1)
        mean_acc = (mean_acc*batch_idx + acc)/ (batch_idx+1)
        

        m, s = int(time.time()-start)//60, int(time.time()-start)%60
    
    print(f'''Val: {epoch} [100%] || Time: {m} min {s} || Loss: {mean_loss:.3f} || Acc: {mean_acc:.3f}        ''')
    cm = (np.array(cm.T, dtype=np.float)/np.sum(cm, 1)).T
    print(cm)

AttributeError: 'ConvNet' object has no attribute 'generate_square_subsequent_mask'

In [21]:
####### CONVO ######

from sklearn.metrics import confusion_matrix

def hinge_loss(model, x, y, m = 1):
    e = -model(x)
    e_bar = torch.min(e+y*1e9, 
                      1, keepdim=True).values.view(e.size(0), 1, 
                                                    e.size(-1))
    loss = F.relu(m+(e-e_bar)*y)[:,:,10:-10]
    return loss.sum()/(e.size(0))


def aa_acc(x, recon_x):
    r"""
    Evaluate the ratio of amino acids retrieved in the reconstructed sequences

    Args:
        x (torch.Tensor): true sequence(s)
        recon_x (torch.Tensor): reconstructed sequence(s)
    """
    x = x[:, :, 10:-10]
    recon_x = recon_x[:,:,10:-10]
    empty = torch.max(x, 1)[0].view(-1)
    x = torch.argmax(x, 1).view(-1)
    recon_x = torch.argmax(recon_x, 1).view(-1)
    return (((x==recon_x) * (empty!=0)).int().sum().item())/((empty!=0).int().sum().item())


def train(epoch):
    mean_loss, mean_reg, mean_acc = 0, 0, 0
    model.train()
    for batch_idx, data in enumerate(train_loader):
        x = data[0].float().permute(0, 2, 1).to(device)
        s = data[1].float().permute(0, 2, 1).to(device)
        length = data[2].int().to(device)
        # Optimization
        optimizer.zero_grad()
        loss = hinge_loss(model, x, s)
#         p = model(x)
#         loss = F.cross_entropy(p, s.argmax(1))
        loss.backward()
        optimizer.step()
#         print(d_0["visible"].argmax(-1)[0], d_f["visible"].argmax(-1)[0])
        acc = aa_acc(s, model(x))

        del x; del s
        # Metrics
        mean_loss = (mean_loss*batch_idx + loss.item())/ (batch_idx+1)
        mean_acc = (mean_acc*batch_idx + acc)/ (batch_idx+1)
        m, s = int(time.time()-start)//60, int(time.time()-start)%60
        print(f'''Train Epoch: {epoch} [{int(100*batch_idx/len(train_loader))}%] || Time: {m} min {s} || Loss: {mean_loss:.3f} || Acc: {mean_acc:.3f}''', end="\r")
    
def val(epoch):
    mean_loss, mean_reg, mean_acc = 0, 0, 0
    model.eval()
    cm = np.zeros((3,3))
    for batch_idx, data in enumerate(val_loader):
        x = data[0].float().permute(0, 2, 1).to(device)
        s = data[1].float().permute(0, 2, 1).to(device)
        p = model(x)

        # Optimization
        loss = hinge_loss(model, x, s)
#         loss = F.cross_entropy(p, s.argmax(1))
        acc = aa_acc(s, p)
        
        cm += confusion_matrix(s.argmax(1).view(-1), 
                         p.argmax(1).view(-1), labels = [0,1,2])
        # Metrics
        mean_loss = (mean_loss*batch_idx + loss.item())/ (batch_idx+1)
        mean_acc = (mean_acc*batch_idx + acc)/ (batch_idx+1)
        

        m, s = int(time.time()-start)//60, int(time.time()-start)%60
    
    print(f'''Val: {epoch} [100%] || Time: {m} min {s} || Loss: {mean_loss:.3f} || Acc: {mean_acc:.3f}           ''')
    cm = (np.array(cm.T, dtype=np.float)/np.sum(cm, 1)).T
    print(cm)

In [22]:
start = time.time()

for i in range(50):
    train(i)
    val(i)

Val: 0 [100%] || Time: 1 min 5 || Loss: 286.896 || Acc: 0.675           
[[0.74933265 0.06660711 0.18406024]
 [0.18907586 0.48433181 0.32659233]
 [0.19557271 0.07252514 0.73190215]]
Val: 1 [100%] || Time: 2 min 16 || Loss: 286.294 || Acc: 0.677           
[[0.75359821 0.07031541 0.17608638]
 [0.22675492 0.54665797 0.22658712]
 [0.21026001 0.12369282 0.66604717]]
Val: 2 [100%] || Time: 3 min 19 || Loss: 289.151 || Acc: 0.660           
[[0.65579844 0.16627322 0.17792833]
 [0.09801595 0.73955127 0.16243278]
 [0.1161046  0.23279111 0.65110429]]
Val: 3 [100%] || Time: 4 min 35 || Loss: 285.324 || Acc: 0.682           
[[0.66965584 0.1220325  0.20831167]
 [0.11120562 0.66052145 0.22827293]
 [0.11656096 0.16416936 0.71926968]]
Val: 4 [100%] || Time: 5 min 39 || Loss: 284.876 || Acc: 0.685           
[[0.78808818 0.04182705 0.17008477]
 [0.20497161 0.45310399 0.3419244 ]
 [0.20370503 0.06088634 0.73540863]]
Val: 5 [100%] || Time: 6 min 38 || Loss: 284.497 || Acc: 0.687           
[[0.78767196

KeyboardInterrupt: 

| Model  | Train Acc  | Val Acc |
|---|---|---|
| HMM  | 0.57  | 0.54  |
| HMM + 5-Conv  | 0.971  | 0.70 |
|   |   |   |

In [23]:
for batch_idx, data in enumerate(val_loader):
    x = data[0].float().permute(0, 2, 1).to(device)
    s = data[1].float().permute(0, 2, 1).to(device)
    p = model(x)
    break