In [None]:
from torch_geometric.nn import GCNConv
import torch
import torch.nn.functional as F
from torch_geometric.data import InMemoryDataset, Data
from tqdm import tqdm
import pickle as pkl
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
from torch_geometric.loader import DataLoader 
import argparse

In [2]:
# Integer-label -> letter decoders used to turn model inputs/outputs back into text.
# Residues: 20 one-letter amino-acid codes in alphabetical order, then 'X' at index 20
# (conventionally an unknown / non-standard residue).
int2primary = dict(enumerate('ACDEFGHIKLMNPQRSTVWYX'))
# 8-state (DSSP-style Q8) secondary-structure letters.
# NOTE(review): the GCN below predicts 3 classes (num_classes=3), so this Q8 map does not
# decode its outputs directly.
int2second = dict(enumerate('GHIBESTC'))


class ProteinDataset(InMemoryDataset):
    """PyG in-memory dataset of protein graphs built from a pickled dict of records.

    Each value ``v`` of the raw pickle is turned into one ``Data`` graph:

    * ``x``          -- ``v[0]`` as a float tensor (per-residue features; presumably a
      21-way encoding of the residue type -- TODO confirm against the raw data),
    * ``y``          -- ``argmax(v[1], axis=1)`` as a long tensor (class index per row),
    * ``edge_index`` -- ``[2, E]`` long tensor stacked from ``v[-1].row`` / ``v[-1].col``
      (assumes ``v[-1]`` is a scipy-style COO sparse adjacency -- TODO confirm),
    * ``seq_len``    -- ``v[-2]`` as a 0-d long tensor.

    Processed graphs are collated and cached in ``processed/data_q3.pt`` under ``root``.
    """

    def __init__(self,
                 root='/Data/deeksha/disha/ProtTrans/data/adjacency_data/',
                 transform=None,
                 pre_transform=None):
        super(ProteinDataset, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Raw pickle expected under ``root/raw``."""
        return ['g_data_train_short_q3.pkl']

    @property
    def processed_file_names(self):
        """Cached, collated tensors written by :meth:`process`."""
        return ['data_q3.pt']

    def download(self):
        """No-op: the raw pickle must already be present locally."""
        pass

    def process(self):
        """Convert every raw record into a ``Data`` graph and cache the collated result."""
        # NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
        with open(self.raw_paths[0], 'rb') as f:
            records = pkl.load(f)

        data_list = []
        # Build each graph directly from its record; stacking all records into one
        # np.array first is unnecessary and fails when sequences are ragged.
        for record in tqdm(records.values()):
            primary, secondary = record[0], record[1]
            seq_len, adj = record[-2], record[-1]
            data_list.append(Data(
                x=torch.tensor(np.asarray(primary), dtype=torch.float),
                y=torch.tensor(np.argmax(secondary, axis=1), dtype=torch.long),
                edge_index=torch.tensor(np.vstack((adj.row, adj.col)), dtype=torch.long),
                seq_len=torch.tensor(seq_len, dtype=torch.long),
            ))

        # Honour the pre_transform accepted by the constructor (previously ignored).
        if self.pre_transform is not None:
            data_list = [self.pre_transform(graph) for graph in data_list]

        collated, slices = self.collate(data_list)
        torch.save((collated, slices), self.processed_paths[0])

In [49]:
# Fixed, unshuffled split: the first 10,000 graphs are used for training and the rest for
# validation. Only the training loader reshuffles each epoch.
dataset = ProteinDataset()
n_train = 10000
train_dataset, val_dataset = dataset[:n_train], dataset[n_train:]
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)


In [None]:
# Sanity check: show the batched per-graph sequence lengths of one training batch.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    print(first_batch.seq_len)

In [4]:
class GCN(torch.nn.Module):
    """Two-layer graph convolutional network that outputs per-node class logits.

    Args:
        input_feature: size of each input node feature vector.
        hidden_channels: width of the hidden GCN layer.
        num_classes: number of logits produced per node.
        dropout: dropout probability applied after the first layer (training only).
        seed: if not None, passed to ``torch.manual_seed`` before the layers are built
            so weight initialisation is reproducible. This reseeds torch's *global* RNG,
            which also affects later dropout masks and DataLoader shuffling; pass
            ``None`` to leave the global RNG untouched.
    """

    def __init__(self, input_feature=21, hidden_channels=128, num_classes=3,
                 dropout=0.5, seed=1234567):
        super().__init__()
        if seed is not None:
            torch.manual_seed(seed)
        self.dropout = dropout
        self.conv1 = GCNConv(input_feature, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_classes)

    def forward(self, x, edge_index):
        """Return raw (unnormalised) logits, one row per node."""
        x = self.conv1(x.float(), edge_index).relu()
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.conv2(x, edge_index)

# Build the GCN with its default hyper-parameters, then place it on the CUDA device.
model = GCN()
model = model.to('cuda')


In [76]:
class CrossEntropy(object):
    """Cross-entropy averaged over the unpadded prefix of each graph in a batch.

    ``out`` and ``target`` are flat node-level tensors for a batch of graphs that all
    have the same number of nodes. They are reshaped to
    ``[batch, nodes_per_graph, num_classes]`` and ``[batch, nodes_per_graph]``. For
    graph ``i`` only the first ``seq_len[i]`` positions contribute. The mean loss is
    computed per graph and then averaged over the graphs in the batch.

    Args:
        num_classes: number of logits per node (default 3).
    """

    def __init__(self, num_classes=3):
        self.num_classes = num_classes

    def __call__(self, out, target, seq_len):
        """Return the scalar batch loss.

        Raises:
            ValueError: if ``seq_len`` is empty (no graphs to average over).
        """
        batch_size = len(seq_len)
        if batch_size == 0:
            raise ValueError('seq_len must contain at least one sequence length')
        out = out.view(batch_size, -1, self.num_classes)
        target = target.view(batch_size, -1)
        loss = 0
        for logits, labels, length in zip(out, target, seq_len):
            # F.cross_entropy with the default 'mean' reduction equals nn.CrossEntropyLoss()
            # but avoids constructing a new module per sample.
            loss = loss + F.cross_entropy(logits[:length], labels[:length])
        return loss / batch_size
        

def accuracy(out, labels, seq_length, num_classes=3):
    """Mean per-graph accuracy over the unpadded prefix of each graph in a batch.

    ``out`` is reshaped to ``[batch, nodes_per_graph, num_classes]`` and ``labels`` to
    ``[batch, nodes_per_graph]``, so every graph in the batch must have the same number
    of nodes. For graph ``i`` only the first ``seq_length[i]`` positions are scored.

    Args:
        out: flat logits tensor for the batch.
        labels: flat integer class labels for the batch.
        seq_length: per-graph lengths (tensor or plain list of ints).
        num_classes: number of logits per node (default 3).

    Returns:
        Python float: accuracy averaged over graphs (NaN contribution for a zero length).

    Raises:
        ValueError: if ``seq_length`` is empty.
    """
    batch_size = len(seq_length)
    if batch_size == 0:
        raise ValueError('seq_length must contain at least one sequence length')
    out = out.view(batch_size, -1, num_classes)
    labels = labels.view(batch_size, -1)
    total = 0.0
    for logits, target, length in zip(out, labels, seq_length):
        # int() accepts both Python ints and 0-d tensors, so the result is always a float
        # (previously `.item()` on the sum crashed when seq_length was a plain list).
        n = int(length)
        total += (logits[:n].argmax(dim=1) == target[:n]).float().mean().item()
    return total / batch_size

In [62]:
# Masked cross-entropy over real residues, minimised with Adam at a learning rate of 1e-2.
criterion = CrossEntropy()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-2)

In [77]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: module called as ``model(x, edge_index)``.
        train_loader: iterable of PyG batches exposing ``x``, ``edge_index``, ``y``
            and ``seq_len``.
        criterion: callable ``(out, y, seq_len) -> scalar loss tensor``.
        optimizer: torch optimizer over ``model``'s parameters.
        device: device every batch is moved to (previously ignored in favour of 'cuda').

    Returns:
        Mean training loss over the batches as a float (0.0 for an empty loader).
    """
    model.train()
    total_loss, num_batches = 0.0, 0
    for data in tqdm(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        # Data.to(device) already moves every tensor attribute, including seq_len;
        # re-wrapping it in torch.tensor() only copies it and emits a UserWarning.
        loss = criterion(out, data.y, data.seq_len)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / num_batches if num_batches else 0.0

def test(model, criterion, loader, accuracy, device):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        model: module called as ``model(x, edge_index)``; switched to eval mode.
        criterion: callable ``(out, y, seq_len) -> scalar loss tensor``.
        loader: iterable of PyG batches exposing ``x``, ``edge_index``, ``y``, ``seq_len``.
        accuracy: callable ``(out, y, seq_len) -> float``.
        device: device every batch is moved to.

    Returns:
        ``(mean_accuracy, mean_loss)``, each averaged over batches.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    total_accuracy, total_loss, num_batches = 0.0, 0.0, 0
    # Evaluation needs no autograd graph; no_grad saves memory and time.
    with torch.no_grad():
        for data in tqdm(loader):
            data = data.to(device)
            out = model(data.x, data.edge_index)
            seq_length = data.seq_len
            total_loss += criterion(out, data.y, seq_length).item()
            total_accuracy += accuracy(out, data.y, seq_length)
            num_batches += 1
    if num_batches == 0:
        raise ValueError('cannot evaluate on an empty loader')
    return total_accuracy / num_batches, total_loss / num_batches



In [None]:
# Train for a few epochs, recording accuracy/loss on the training and validation splits.
NUM_EPOCHS = 2
train_loss_list = []
train_acc_list = []
test_loss_list = []
test_acc_list = []
device = 'cuda'
for epoch in range(NUM_EPOCHS):
    train_epoch(model, train_loader, criterion, optimizer, device)
    # Train metrics must come from train_loader. Previously val_loader was used here,
    # which made the "train" and "test" curves identical.
    train_acc, train_loss = test(model, criterion, train_loader, accuracy, device)
    test_acc, test_loss = test(model, criterion, val_loader, accuracy, device)

    train_loss_list.append(train_loss)
    train_acc_list.append(train_acc)
    test_loss_list.append(test_loss)
    test_acc_list.append(test_acc)

    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')
    print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')
# save the model
torch.save(model.state_dict(), '/Data/deeksha/disha/ProtTrans/scripts/train/model/gcn_model.pt')

In [None]:
# Re-score the trained model on the validation split.
val_metrics = test(model=model, criterion=criterion, loader=val_loader,
                   accuracy=accuracy, device=device)
test_acc, test_loss = val_metrics

In [None]:
print(train_acc_list)

# Save one train-vs-test curve per metric. Each plot gets its own explicit figure,
# closed after saving, so plots never leak into each other via pyplot's current figure.
PLOT_DIR = '/Data/deeksha/disha/ProtTrans/scripts/train/model'
curve_specs = [
    ('loss', 'Cross-entropy loss', train_loss_list, test_loss_list),
    ('acc', 'Accuracy', train_acc_list, test_acc_list),
]
for short_name, y_label, train_values, test_values in curve_specs:
    fig, ax = plt.subplots()
    ax.plot(train_values, label=f'train_{short_name}')
    ax.plot(test_values, label=f'test_{short_name}')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(f'GCN (Q3) {y_label.lower()} per epoch')
    ax.legend()
    fig.savefig(f'{PLOT_DIR}/gcn_{short_name}_q3.png')
    plt.close(fig)

In [None]:
# Qualitative check: decode the first few validation proteins into letters and write
# primary / predicted / actual secondary structure side by side to demo.txt.
# int2primary is the residue decoder defined in the dataset cell at the top.
# Q3 decoder for the 3-class model output; kept under its own name so the 8-state
# int2second map defined earlier is not silently overwritten.
int2second_q3 = {0: 'H', 1: 'E', 2: 'C'}
NUM_DEMO_SAMPLES = 10

model.eval()
# Use a separate batch_size=1 loader instead of rebinding the shared val_loader, which
# would change the batch size of every later evaluation cell.
demo_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
primary_seqs, predicted_secondary, actual_secondary = [], [], []
with torch.no_grad():
    for batch in tqdm(demo_loader):
        data = batch.to(device)
        out = model(data.x, data.edge_index)
        # batch_size=1 -> seq_len is a single-element tensor; convert to a plain int so
        # it can safely slice both torch tensors and numpy arrays.
        n = int(data.seq_len)
        primary_seqs.append(data.x.cpu().numpy().argmax(axis=1)[:n])
        actual_secondary.append(data.y.cpu().numpy()[:n])
        predicted_secondary.append(out.argmax(dim=1)[:n].cpu().numpy())
        if len(predicted_secondary) == NUM_DEMO_SAMPLES:
            break

primary_seqs_alpha = [[int2primary[aa] for aa in sample] for sample in primary_seqs]
predicted_secondary_alpha = [[int2second_q3[aa] for aa in sample] for sample in predicted_secondary]
actual_secondary_alpha = [[int2second_q3[aa] for aa in sample] for sample in actual_secondary]

with open('demo.txt', 'w', encoding='utf-8') as f:
    for primary, predicted, actual in zip(primary_seqs_alpha,
                                          predicted_secondary_alpha,
                                          actual_secondary_alpha):
        f.write(f'Primary: {" ".join(primary)}\n')
        f.write(f'Predicted: {" ".join(predicted)}\n')
        f.write(f'Actual: {" ".join(actual)}\n\n')
