In [5]:
import pandas as pd
import matplotlib.pyplot as plt
from scipy.spatial.distance import pdist, squareform
import seaborn as sns
import numpy as np
DATA = "/home/cyril/Documents/These/data/proteinnet"

## Data

In [6]:
import pandas as pd
from scipy.sparse import csr_matrix
AA = [4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 21, 22, 24, 25, 27]

def np_onehot(a, shape):
    if shape[0]  == -1:
        shape = (len(a), shape[1])
    onehot = np.zeros(shape)
    onehot[np.arange(len(a)), a] = 1
    return onehot

def sparse_distance(x, mask, max_dist = 8):
    mask = np.where(np.array(mask))[0]
    contactmap = squareform(pdist(np.array(x)[mask]))
    contactmap = (max_dist - contactmap) * (contactmap < max_dist)
    return csr_matrix(contactmap)

class TrainData(object):
    def __init__(self, file, size = 512):
        self.primary, self.cm, self.length, self.mask = [], [], [], []
        for i in range(1, 100):
            df = pd.read_json(f"{DATA}/contact_map_train_{i}.json")
            df["length"] = df.primary.apply(lambda x : len(x))
            df = df[df.length < size]
            self.length += list(df.length)
            self.primary += list(df.apply(lambda x : np_onehot(x["primary"], (-1, 28))[np.where(np.array(x["mask"]))[0]][:, AA], axis=1))
            self.cm += list(df.apply(lambda x : sparse_distance(x["tertiary"], x["mask"]), axis = 1))
            del df
            
    def __len__(self):
        return len(self.primary)
    
    def __getitem__(self, i):
        return self.primary[i], self.cm[i], self.length[i]
    
class ValData(object):
    def __init__(self, file, size = 512):
        self.primary, self.cm, self.length, self.mask = [], [], [], []
        df = pd.read_json(f"{DATA}/contact_map_valid.json")
        df["length"] = df.primary.apply(lambda x : len(x))
        df = df[df.length < size]
        self.length += list(df.length)
        self.primary += list(df.apply(lambda x : np_onehot(x["primary"], (-1, 28))[np.where(np.array(x["mask"]))[0]][:, AA], axis=1))
        self.tertiary += list(df.apply(lambda x : sparse_distance(x["tertiary"], x["mask"]), axis = 1))
        self.cm += list(df.apply(lambda x : sparse_distance(x["tertiary"], x["mask"]), axis = 1))
        del df
            
    def __len__(self):
        return len(self.primary)
    
    def __getitem__(self, i):
        return self.primary[i], self.cm[i], self.length[i]

In [4]:
training_set = TrainData(f"{DATA}/contact_map_train.json")
train_loader = DataLoader(training_set, shuffle = True)

validation_set = ValData(f"{DATA}/contact_map_valid.json")
val_loader = DataLoader(validation_set, shuffle = True)

AttributeError: 'ValData' object has no attribute 'tertiary'

## Model

In [18]:
import torch
from torch_geometric.nn import VGAE, GMMConv
from torch import nn, optim

In [19]:
class Encoder(nn.Module):
    def __init__(self, in_channels = 20, out_channels = 8, N = 128, bias=True):
        super(Encoder, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.conv1 = GMMConv(in_channels, 100, 1, 8)
        self.conv2 = GMMConv(100, 100, 1, 8)
        self.conv3 = GMMConv(100, 100, 1, 8)
        self.conv4 = GMMConv(100, out_channels, 1, 8)
        
    def forward(self, x):
        h = self.conv1(x)
        h = self.conv2(h)
        h = self.conv3(h)
        h = self.conv4(h)
        return h

In [20]:
device = torch.device('cpu')

encoder = Encoder()
model = VGAE(encoder).float()

optimizer = optim.Adam(model.parameters(), lr=0.01)
model

VGAE(
  (encoder): Encoder(
    (conv1): GMMConv(20, 100)
    (conv2): GMMConv(100, 100)
    (conv3): GMMConv(100, 100)
    (conv4): GMMConv(100, 8)
  )
  (decoder): InnerProductDecoder()
)

In [None]:
from sklearn.metrics import confusion_matrix

def hinge_loss(model, x, y, m = 1):
    e = -model(x)
    e_bar = torch.min(e+y*1e9, 1, keepdim=True).values.view(e.size(0), 1, 
                                                                 e.size(-1))
    loss = F.relu(m+(e-e_bar)*y)[:,:,10:-10]
    return loss.sum()/(e.size(0))


def aa_acc(x, recon_x):
    r"""
    Evaluate the ratio of amino acids retrieved in the reconstructed sequences

    Args:
        x (torch.Tensor): true sequence(s)
        recon_x (torch.Tensor): reconstructed sequence(s)
    """
    x = x[:, :, 10:-10]
    recon_x = recon_x[:,:,10:-10]
    empty = torch.max(x, 1)[0].view(-1)
    x = torch.argmax(x, 1).view(-1)
    recon_x = torch.argmax(recon_x, 1).view(-1)
    return (((x==recon_x) * (empty!=0)).int().sum().item())/((empty!=0).int().sum().item())


def train(epoch):
    mean_loss, mean_reg, mean_acc = 0, 0, 0
    model.train()
    for batch_idx, data in enumerate(train_loader):
        x = data[0].float().permute(0, 2, 1).to(device)
        s = data[1].float().permute(0, 2, 1).to(device)
        length = data[2].int().to(device)
        # Optimization
        optimizer.zero_grad()
        loss = hinge_loss(model, x, s)
        loss.backward()
        optimizer.step()
#         print(d_0["visible"].argmax(-1)[0], d_f["visible"].argmax(-1)[0])
        acc = aa_acc(s, model(x))

        del x; del s
        # Metrics
        mean_loss = (mean_loss*batch_idx + loss.item())/ (batch_idx+1)
        mean_acc = (mean_acc*batch_idx + acc)/ (batch_idx+1)
        m, s = int(time.time()-start)//60, int(time.time()-start)%60
        print(f'''Train Epoch: {epoch} [{int(100*batch_idx/len(train_loader))}%] || Time: {m} min {s} || Loss: {mean_loss:.3f} || Acc: {mean_acc:.3f}''', end="\r")
    
def val(epoch):
    mean_loss, mean_reg, mean_acc = 0, 0, 0
    model.eval()
    cm = np.zeros((8,8))
    for batch_idx, data in enumerate(val_loader):
        x = data[0].float().permute(0, 2, 1).to(device)
        s = data[1].float().permute(0, 2, 1).to(device)
        
        # Optimization
        loss = hinge_loss(model, x, s)
        acc = aa_acc(s, model(x))
        
        cm += confusion_matrix(s.argmax(1).view(-1), 
                         model(x).argmax(1).view(-1), labels = [0,1,2,3,4,5,6,7])
        # Metrics
        mean_loss = (mean_loss*batch_idx + loss.item())/ (batch_idx+1)
        mean_acc = (mean_acc*batch_idx + acc)/ (batch_idx+1)
        

        m, s = int(time.time()-start)//60, int(time.time()-start)%60
        print(f'''Val: {epoch} [{int(100*batch_idx/len(val_loader))}%] || Time: {m} min {s} || Loss: {mean_loss:.3f} ''', end="\r")
    
    print(f'''Val: {epoch} [100%] || Time: {m} min {s} || Loss: {mean_loss:.3f} || Acc: {mean_acc:.3f}           ''')
    cm = (np.array(cm.T, dtype=np.float)/np.sum(cm, 1)).T
    print(cm)