In [1]:
import numpy as np
import pandas as pd
import sys
from tqdm.notebook import tqdm
import networkx as nx
sys.path.append('..')

import torch
from torch.functional import F
import torch.nn as nn

from torch_geometric.data import Data, DataLoader, Dataset
from torch_geometric.utils import from_networkx, to_networkx, degree
from torch_geometric.nn import GATConv, GCNConv, global_add_pool, PNAConv, BatchNorm, CGConv, global_max_pool
from torch_geometric.utils.metric import accuracy, precision, f1_score
import torch_geometric.transforms as T

from models.graph_transformer.euclidean_graph_transformer import GraphTransformerEncoder

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from sklearn.metrics import roc_auc_score, accuracy_score, average_precision_score, precision_score

from utils.data_gen import load_prot_embs, to_categorical

dev = torch.device('cuda:0')

In [2]:
prot_embs, global_dict = load_prot_embs(512, norm=False)

In [3]:
def wcsv2graph(fname, global_dict, y):
    """
    Weighted Graph Creator
    """
    sample = pd.read_csv('../snac_data/' + fname)
    
    G = nx.from_pandas_edgelist(sample, source='node1', target='node2', 
                            edge_attr=['sign', 'weight'], create_using=nx.DiGraph())

    n1a1d = sample[['node1','downact1']]
    n1a1u = sample[['node1','upact1']]
    n1a1d.columns = ['node','downact']
    n1a1u.columns = ['node', 'upact']

    n2a2d = sample[['node2','downact2']]
    n2a2u = sample[['node2','upact2']]
    n2a2d.columns = ['node','downact']
    n2a2u.columns = ['node', 'upact']
    
    nad = pd.concat([n1a1d,n2a2d])
    nad = nad.drop_duplicates('node')
    nad = nad.set_index('node')
    nad['downacts'] = nad[['downact']].apply(lambda x: np.hstack(x), axis=1)
    nad = nad.drop(['downact'], axis=1)['downacts'].to_dict()
    
    nau = pd.concat([n1a1u,n2a2u])
    nau = nau.drop_duplicates('node')
    nau = nau.set_index('node')
    nau['upacts'] = nau[['upact']].apply(lambda x: np.hstack(x), axis=1)
    nau = nau.drop(['upact'], axis=1)['upacts'].to_dict()
    
    nx.set_node_attributes(G, global_dict,'global_idx')
    nx.set_node_attributes(G, nad, 'downacts')
    nx.set_node_attributes(G, nau, 'upacts')
    
    data = from_networkx(G)
    
    G = to_networkx(data)
    data.weight = data.weight.float()
    data.downacts, data.upacts = data.downacts.double(), data.upacts.float()
    data.acts = torch.cat([data.downacts.float(), data.upacts.float()], dim=-1).double()
    data.downacts = data.upacts = None
    
    data.sign[data.sign < 0] = 0
    data.sign = data.sign.long()
    data.sign = to_categorical(data.sign, 2).reshape(-1,2).float()
    
    data.y = torch.tensor(y)
    data.label = torch.tensor(np.argmax(y)).view(-1).long()
    
    return data

class SNLDataset(Dataset):
    def __init__(self, fnames, y, global_dict):
        super(SNLDataset, self).__init__()
        self.fnames = fnames
        self.gd = global_dict
        self.y = y
        
    def len(self):
        return len(self.fnames)
        
    def get(self, idx):
        return wcsv2graph(self.fnames[idx], self.gd, self.y[idx])

In [4]:
labelled_ugraphs = pd.read_csv('../snac_data/graph_classification_all.csv')
weighted_df = pd.read_csv('../snac_data/file_info_weighted.csv')

val_set_1 = pd.read_csv('../snac_data/splits/val_set_1.csv')
val_set_2 = pd.read_csv('../snac_data/splits/val_set_2.csv')
val_set_3 = pd.read_csv('../snac_data/splits/val_set_3.csv')
val_set_4 = pd.read_csv('../snac_data/splits/val_set_4.csv')
test_set = pd.read_csv('../snac_data/splits/test_set.csv')

In [5]:
wsample_path = weighted_df.files_weighted.to_numpy()[200]
data = wcsv2graph(wsample_path, global_dict, [0,0,1])

In [60]:
usm = pd.DataFrame(labelled_ugraphs.groupby('sig_id').moa_v1.unique()).reset_index()
usm_corr = np.array([np.array(i) for i in usm.moa_v1.to_numpy()]).reshape(-1)
usm['moa_v1'] = usm_corr

X_df = pd.merge(weighted_df, usm, on='sig_id')
val_df =  pd.merge(X_df, val_set_4, on='sig_id')
test_df = pd.merge(X_df, test_set, on='sig_id')

for sig in tqdm(test_set.sig_id):
    X_df = X_df[X_df['sig_id'] != sig]

HBox(children=(FloatProgress(value=0.0, max=1031.0), HTML(value='')))




In [61]:
X_train, y_train = X_df.files_weighted.to_numpy(), X_df.moa_v1.to_numpy()
X_val, y_val = val_df.files_weighted.to_numpy(), val_df.moa_v1_x.to_numpy()
X_test, y_test = test_df.files_weighted.to_numpy(), test_df.moa_v1_x.to_numpy()

In [62]:
le = OneHotEncoder()
y = np.concatenate([y_train, y_val, y_test])
le = le.fit(y.reshape(-1,1))
y_train = le.transform(y_train.reshape(len(y_train),-1)).toarray()
y_val = le.transform(y_val.reshape(len(y_val),-1)).toarray()
y_test = le.transform(y_test.reshape(len(y_test), -1)).toarray()

In [64]:
train_data = SNLDataset(X_train, y_train, global_dict)
val_data = SNLDataset(X_val, y_val, global_dict)
test_data = SNLDataset(X_test, y_test, global_dict)

train_loader = DataLoader(train_data, batch_size=1, num_workers=12, shuffle=True)
val_loader = DataLoader(val_data, batch_size=1, num_workers=12, shuffle=True)
test_loader = DataLoader(test_data, batch_size=1, num_workers=12)

In [10]:
# Compute in-degree histogram over training data.
def deg_distr():
    deg = torch.zeros(58, dtype=torch.long)
    for data in tqdm(train_data):
        d = degree(data.edge_index[0], num_nodes=data.num_nodes, dtype=torch.long)
        deg += torch.bincount(d, minlength=58)
    return deg

In [11]:
class PostEncoding(nn.Module):
    def __init__(self, emb_dim):
        super(PostEncoding,self).__init__()
        self.emb_dim = emb_dim
        
        self.act_emb = nn.Linear(2, emb_dim)
        
    def forward(self, data):
        return self.act_emb(data.acts.float())

class Net(torch.nn.Module):
    def __init__(self, n_classes, pretrained_weights=None, train_embs=True):
        super(Net, self).__init__()
        self.pretrained_weights = pretrained_weights
        if pretrained_weights is not None:
            self.n_prots, self.in_channels = pretrained_weights.shape
        else:
            self.n_prots = 919
            self.in_channels = 512

        self.node_emb = nn.Embedding(self.n_prots, self.in_channels, sparse=True)
        self.node_emb.weight.requires_grad = train_embs
        
        self.edge_emb = nn.Embedding(2, 50)
        self.pe = PostEncoding(self.in_channels)

        aggregators = ['mean', 'min', 'max', 'std']
        scalers = ['identity', 'amplification', 'attenuation']

        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        for _ in range(2):
            conv = PNAConv(in_channels=self.in_channels, out_channels=self.in_channels,
                           aggregators=aggregators, scalers=scalers, deg=deg,
                           edge_dim=50, towers=4, pre_layers=1, post_layers=1,
                           divide_input=False)
            self.convs.append(conv)
            self.batch_norms.append(BatchNorm(self.in_channels))
            
        self.fc1 = nn.Linear(self.in_channels, 2 * self.in_channels)
        self.fc_out = nn.Linear(2 * self.in_channels, n_classes)
        self.act = nn.PReLU()    
        
        self.init_weights()
        
    def init_weights(self):
        if self.pretrained_weights is not None:
            initrange = 0.1
            self.node_emb.weight.data.copy_(torch.from_numpy(self.pretrained_weights))
            self.edge_emb.weight.data.uniform_(-initrange, initrange)
        else:
            initrange = 0.1
            emb_layer.weight.data.uniform_(-initrange, initrange)
            
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                      m.bias.data.fill_(0.0)

    def forward(self, data):
        x = self.node_emb(data.global_idx)
        edge_attr = self.edge_emb(data.sign)
        act = self.pe(data)
        x = torch.add(x, act)

        for conv, batch_norm in zip(self.convs, self.batch_norms):
            x = F.relu(batch_norm(conv(x, data.edge_index, edge_attr)))

        x = global_add_pool(x, data.batch)
        x = self.act(self.fc1(x))
        return F.log_softmax(self.fc_out(x), dim=-1)
    
class LabelSmoothingLoss(nn.Module):
    def __init__(self, classes, smoothing=0.0, dim=-1):
        super(LabelSmoothingLoss, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        #pred = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            true_dist = torch.zeros_like(pred)
            true_dist.fill_(self.smoothing / (self.cls - 1))
            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
    
class FF(nn.Module):
    def __init__(self, encoder, in_channels, n_classes):
        super(FF, self).__init__()
        
        self.encoder = encoder
        self.fc_out = nn.Linear(in_channels, n_classes)
        
    def forward(self, data):
        x = self.encoder(data)
        
        x = global_max_pool(x, data.batch)
        x = F.log_softmax(self.fc_out(x), dim=1)
        return x

In [12]:
N_CLASSES = 255
EMB_DIM = 512

model = BigNet(N_CLASSES, pretrained_weights=prot_embs).to(dev)
encoder = GraphTransformerEncoder(n_layers=2, n_heads=8, n_hid=512, 
                            pretrained_weights=prot_embs, summarizer=None).to(dev)
model = FF(encoder, EMB_DIM, N_CLASSES).to(dev)
ls = LabelSmoothingLoss(N_CLASSES, smoothing=0.2)

In [13]:
lr = 5.0
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.98)

In [65]:
def train(epoch):
    model.train()
    total_train_acc = []
    for tb in train_data_iterator:
        tb = tb.to(dev)
        optimizer.zero_grad()
        
        pred = model(tb).view(tb.num_graphs, -1)
        y = tb.y.reshape(tb.num_graphs, -1)
        y_l = tb.label
        
        #loss = F.nll_loss(pred, y_l)
        loss = ls(pred, y_l)
        
        acc_t = accuracy(torch.argmax(pred, dim=1), y_l)
        roc_t = roc_auc_score(y.long().detach().cpu().numpy(), 
                                pred.detach().cpu().numpy(), average='samples')
        
        train_data_iterator.set_postfix(Epoch=epoch+1,
                                        loss ='%.4f' % float(loss.item()),
                                        acc = '%.4f' % float(acc_t),
                                        roc= '%.4f' % float(roc_t))
        
        loss.backward(retain_graph=True)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.01)
        optimizer.step()
        total_train_acc.append(acc_t)

        del tb
    
    print(f'Total Training Accuracy: {np.mean(total_train_acc)}')
    
    val_data_iterator = tqdm(test_loader,leave=True,unit='batch',
                        postfix={'Epoch': epoch+1,'val_loss': '%.4f' % 0.0, 
                                 'acc': '%.4f' % 0.0,
                                 'roc': '%.4f' % 0.0})
    total_val_acc = []
    with torch.no_grad():
        for tb in val_data_iterator:
            tb = tb.to(dev)

            pred = model(tb)
            y = tb.y.reshape(tb.num_graphs, -1)
            y_l = tb.label

            #loss = F.nll_loss(pred, y_l)
            loss = ls(pred, y_l)
            
            acc_v = accuracy(torch.argmax(pred, dim=1), y_l)
            roc_v = roc_auc_score(y.long().detach().cpu().numpy(), 
                                        pred.detach().cpu().numpy(), average='samples')

            val_data_iterator.set_postfix(Epoch=epoch+1,
                                            val_loss ='%.4f' % float(loss.item()),
                                            acc = '%.4f' % float(acc_v),
                                            roc= '%.4f' % float(roc_v))
            total_val_acc.append(acc_v)
    
        
    print(f'Total Validation Accuracy: {np.mean(total_val_acc)}')
    scheduler.step()

In [66]:
m1 = FF(encoder, EMB_DIM, N_CLASSES).to(dev)
m2 = FF(encoder, EMB_DIM, N_CLASSES).to(dev)
m3 = FF(encoder, EMB_DIM, N_CLASSES).to(dev)
m4 = FF(encoder, EMB_DIM, N_CLASSES).to(dev)

m1.load_state_dict(torch.load('../snac_data/checkpoints/val_set_1/moa_vs_1_val_acc_0.5059288537549407.pt'))
m2.load_state_dict(torch.load('../snac_data/checkpoints/val_set_2/moa_vs_2_val_acc_0.2686084142394822.pt'))
m3.load_state_dict(torch.load('../snac_data/checkpoints/val_set_3/moa_vs_3_val_acc_0.47396963123644253.pt'))
m4.load_state_dict(torch.load('../snac_data/checkpoints/val_set_4/moa_vs_4_val_acc_0.3865546218487395.pt'))

<All keys matched successfully>

In [70]:
class Ensemble(nn.Module):
    def __init__(self, models):
        super(Ensemble, self).__init__()
        
        self.mlist = nn.ModuleList([*models])
        self.fc1 = nn.Linear(2048, 255)
        
        
        for m in self.mlist:
            m.fc_out = nn.Identity()
            for param in m.parameters():
                param.requires_grad_(False)
        
    def forward(self, data):
        total_preds = []
        m1 = self.mlist[0]
        m2 = self.mlist[1]
        m3 = self.mlist[2]
        m4 = self.mlist[3]
            
        preds = torch.cat([m1(data), m2(data), m3(data), m4(data)]).reshape(-1)
        preds = F.log_softmax(self.fc1(preds), dim=-1) 
        return preds
    
model = Ensemble([m1, m2, m3, m4]).to(dev)

In [71]:
EPOCHS = 25
lr = 1.0
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.98)

for epoch in range(EPOCHS):
    print('-' * 100)
    print('-' * 100)
    train_data_iterator = tqdm(train_loader,leave=True,unit='batch',
                        postfix={'Epoch': epoch+1,'loss': '%.4f' % 0.0,
                        'loss': '%.4f' % 0.0,
                        'acc': '%.4f' % 0.0,
                        'roc': '%.4f' % 0.0})
    train(epoch)

----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=2616.0), HTML(value='')))


Total Training Accuracy: 0.04128440366972477


HBox(children=(FloatProgress(value=0.0, max=1031.0), HTML(value='')))

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

In [15]:
def test_eval(models, tloader):
    preds = []
    for model in models:
        model.eval()
        with torch.no_grad():
            for tb in tqdm(tloader):
                tb = tb.to(dev)
                y_l = tb.label
                pred = model(tb).view(1,-1).detach().cpu().numpy()

                preds.append(pred)
            
    return preds

def per_drug_acc(df):
    unique_drugs = df['rdkit.x'].unique()
    s = 0
    for drug in unique_drugs:
        filt = df[df['rdkit.x']==drug]
        score = accuracy_score(filt['true'], filt['predicted'])
        nunique_moa = filt['predicted'].nunique()
        if score >= (1/nunique_moa):
            s = s + 1
    return(s/len(unique_drugs))

def per_sig_acc(df):
    unique_sigs = df['sig_id'].unique()
    s = 0
    for sig in unique_sigs:
        filt = df[df['sig_id']==sig]
        score = accuracy_score(filt['true'], filt['predicted'])
        nunique_moa = filt['predicted'].nunique()
        if score >= (1/nunique_moa):
            s = s + 1
    return(s/len(unique_sigs))

In [16]:
preds = test_eval([m1, m2, m3, m4], test_loader)

HBox(children=(FloatProgress(value=0.0, max=1031.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1031.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1031.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1031.0), HTML(value='')))




In [18]:
preds = np.array(preds)
p_reshaped = preds.reshape(4, 1031, 255)

In [19]:
pred_mean = torch.mean(torch.tensor(p_reshaped), dim=0)

In [20]:
pred_labels = torch.argmax(pred_mean, dim=-1)
true_labels = np.argmax(y_test, axis=-1)

In [21]:
test_set['true'] = true_labels
test_set['predicted'] = pred_labels

In [22]:
per_drug_acc(test_set)

0.375

In [23]:
per_sig_acc(test_set)

0.1794871794871795