In [None]:
import sys
sys.path.append("/home/projects/kaggle_nesp/pse")
sys.path.append("/home/projects/kaggle_nesp/pse/nesp/pLM/esm")
sys.path.append("..")

import os
import pandas as pd
import numpy as np
import random 

os.environ["GPU"] = "0"

In [None]:
# ssm2_meta_pkl = "/home/data/02vip.datasets/metas/ssm2_meta.pkl"
# q6428_meta_pkl = "/home/data/02vip.datasets/metas/q6428_meta.pkl"

# ssm2_meta = pd.read_pickle(ssm2_meta_pkl)
# q6428_meta = pd.read_pickle(q6428_meta_pkl)

# # replace
# ssm2_meta.pdb_path = ssm2_meta.pdb_path.apply(lambda r: r.replace("/public/home/chengyifan/data", "/home/data"))
# ssm2_meta.wt_path = ssm2_meta.wt_path.apply(lambda r: r.replace("/public/home/chengyifan/data", "/home/data"))

# # check
# ssm2_meta['tag'] = ssm2_meta.apply(
#     lambda r: r.sequence[r.pos - 1] == r.mut,
#     axis = 1
# )

# ssm2_meta['pdb_tag'] = ssm2_meta.pdb_path.apply(os.path.exists)
# ssm2_meta['wt_tag'] = ssm2_meta.wt_path.apply(os.path.exists)

# save_ssm2_meta = ssm2_meta[ssm2_meta.tag & ssm2_meta.pdb_tag & ssm2_meta.wt_tag]
# # save_ssm2_meta.reset_index(drop=True).drop(columns=['tag', 'pdb_tag', 'wt_tag']).to_csv("./ssm2_meta.csv", index=False)

In [None]:
from utils.data import load_wt_mut_pdb_pair
from utils.misc import recursive_to

import esm

import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import radius_graph
from torch import nn
from torch_geometric.nn import GCN2Conv, GCNConv, Sequential, global_max_pool, GATConv
import torch.nn.functional as F

from scipy.stats import pearsonr

from tqdm import tqdm
tqdm.pandas()

from Bio.PDB.Polypeptide import index_to_one

In [None]:
ssm2_meta_path = "./ssm2_meta.csv"

ssm2_meta_df = pd.read_csv("./ssm2_meta.csv")

In [None]:
seed = 42
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True # slowly
torch.backends.cudnn.benchmark = False
os.environ['PYTHONHASHSEED'] = str(seed)

In [None]:
from typing import Any


class PairGraph(Data):
    def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any:
        if key == "edge_index_mut":
            return self.x_mut.size(0)
        if key == "edge_index_wt":
            return self.x_wt.size(0)
        return super().__inc__(key, value, *args, **kwargs)
    
    def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any:
        return super().__cat_dim__(key, value, *args, **kwargs)

In [None]:
class ProtDataset(Dataset):
    def __init__(self, meta_path, radius=5, graph_pt=None, **kwargs) -> None:
        super().__init__()
        self.meta_path = meta_path
        self.radius = radius
        self.graph_pt = graph_pt
        print(device)
        
        meta_df = pd.read_csv(meta_path)
        meta_df = self.__check_columns(meta_df)
        
        if graph_pt is not None:
            print("Parameter `graph_pt` is not None, loading graph data file from {} ...".format(graph_pt))
            self.datasets = torch.load(graph_pt)
            
        else:
            esmv2, batch_converter = ProtDataset.get_esm_model(device)
            print("esmv2 initial done.")
            
            self.datasets = []
            for _, items in tqdm(meta_df.iterrows()):
                seq, wt, pos, mut = items.sequence, items.wt, items.pos, items.mut
                pdb_path, wt_path = items.pdb_path, items.wt_path
                ddG = items.ddG
                
                batch = load_wt_mut_pdb_pair(wt_path, pdb_path)
                
                # check sequence
                resseq = batch['mut']['aa'].cpu()
                assert resseq.shape[0] == 1 and len(resseq.shape) == 2  # single chain
                parse_seq = [index_to_one(x) for x in np.array(resseq[0])]
                parse_seq = "".join(seq)
                if parse_seq != seq:
                    print("warnings: [MisMatch] the `sequence` in meta_df does not match the parsed `sequence` from pdb file. \
                                The sequence in meta is {}, the items has been skip.".format(seq))
                    continue
                    
                if parse_seq[pos-1] != mut:
                    print("warnings: [MisMatch] residue mismatch on mutant position, the mutation residue {} from meta does not match \
                          the residue {} parsed from the pdb file at position {}.".format(mut, parse_seq[pos-1], pos))
                    continue
                
                wt_seq = parse_seq[:pos-1] + wt + parse_seq[pos:]
                logits_wt, reprs_wt, contacts_wt = self.__esm2_infer(wt_seq, esmv2, batch_converter, device)
                logits_mut, reprs_mut, contacts_mut = self.__esm2_infer(parse_seq, esmv2, batch_converter, device)
                
                self.datasets.append(self.__build_graph(batch, reprs_wt, reprs_mut, radius=radius), pos, ddG)
        
    def __len__(self):
        return len(self.datasets)
    
    def __getitem__(self, idx):
        protgraph, mut_pos, target = self.datasets[idx]
        
        return protgraph, int(mut_pos), target
    
    def __build_graph(self, batch, wt_reprs, mut_reprs, radius):
        
        x_wt_pos_ca = batch['wt']['pos14'][0, :, 1, :]  # (N, 3) position of CA
        x_mut_pos_ca = batch['mut']['pos14'][0, :, 1, :]
        
        wt_reprs = torch.Tensor(wt_reprs[0, 1:-1, :]) # (N, V)
        mut_reprs = torch.Tensor(mut_reprs[0, 1:-1, :])
        # assert x_wt_pos_ca.shape[0] == wt_reprs.shape[0] and x_mut_pos_ca.shape[0] == mut_reprs.shape[0]
        
        edge_index_mut = radius_graph(x=x_mut_pos_ca, r=radius)
        edge_index_wt = radius_graph(x=x_wt_pos_ca, r=radius)
        
        data = PairGraph(x_mut=mut_reprs, x_wt=wt_reprs, edge_index_mut=edge_index_mut, edge_index_wt=edge_index_wt)

        return data
        
    def __check_columns(self, meta_df):
        cols = ['sequence', 'mut', 'pos', 'wt', 'pdb_path', 'wt_path', 'ddG']
        for col in cols:
            assert col in meta_df.columns, "column {} not in meta df.".format(col)
        
        # check pos
        for i, row in meta_df.iterrows():
            seq, pos, mut = row.sequence, row.pos, row.mut
            assert seq[pos - 1] == mut, "position error, residue {} in sequence at position {} does not match residue {} in meta df.".format(seq[pos-1], pos, mut)
        
        return meta_df
    
    def __esm2_infer(self, sequence, esmv2, batch_converter, device):
        data = [("protein1", sequence)]
        batch_labels, batch_strs, batch_tokens = batch_converter(data)
        batch_tokens = batch_tokens.to(device)

        with torch.no_grad():
            results = esmv2.forward(batch_tokens, repr_layers=[33], need_head_weights=True, return_contacts=True)
        logits = results['logits'].detach().cpu().numpy()
        reps = results['representations'][33].detach().cpu().numpy()
        contacts = results['contacts'].detach().cpu().numpy()
        
        return logits, reps, contacts
    
    @staticmethod
    def get_esm_model(device):
        t_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        batch_converter = alphabet.get_batch_converter()
        t_model.eval()  
        t_model.to(device)
        
        return t_model, batch_converter
    
    def save_graphs(self, save_dir):
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
            
        name = os.path.basename(self.meta_path).split(".")[0] + "_r" + str(self.radius) + ".pt"
        save_path = os.path.join(save_dir, name)
        
        torch.save(self.datasets, save_path)

In [92]:
# ssm2_dataset_r5 = ProtDataset(meta_path=ssm2_meta_path, radius=5, graph_pt="../data/ssm2_meta_r5.pt")
# ssm2_dataset_r6 = ProtDataset(meta_path=ssm2_meta_path, radius=6, graph_pt="../data/ssm2_meta_r6.pt", device=device)

ssm2_trainset_r5 = ProtDataset(meta_path="../data/ssm2_meta_r5_train.csv", radius=5, graph_pt="../data/ssm2_meta_r5_train.pt")
ssm2_testset_r5 = ProtDataset(meta_path='../data/ssm2_meta_r5_test.csv', radius=5, graph_pt="../data/ssm2_meta_r5_test.pt")

cuda
Parameter `graph_pt` is not None, loading graph data file from ../data/ssm2_meta_r5_train.pt ...
cuda
Parameter `graph_pt` is not None, loading graph data file from ../data/ssm2_meta_r5_test.pt ...


### split train val set

In [93]:
train_loader = DataLoader(ssm2_trainset_r5, batch_size=128, shuffle=True, follow_batch=['x_mut', 'x_wt'])
val_loader = DataLoader(ssm2_testset_r5, batch_size=128, shuffle=True, follow_batch=['x_mut', 'x_wt'])
batch = next(iter(train_loader))
prots, mut_pos, targets = batch
prots

PairGraphBatch(x_mut=[5504, 1280], x_mut_batch=[5504], x_mut_ptr=[129], x_wt=[5504, 1280], x_wt_batch=[5504], x_wt_ptr=[129], edge_index_mut=[2, 14016], edge_index_wt=[2, 14122])

In [None]:
class ProtMutNet(nn.Module):
    def __init__(self, gnn_mode="GAT") -> None:
        super().__init__()
        if gnn_mode == "GAT":
            self.graph_encoder = Sequential('x, edge_index', [
                (GATConv(1280, 512, head=4), 'x, edge_index -> x'),
                nn.ReLU(inplace=True),
                (GATConv(512, 256, head=4), 'x, edge_index -> x'),
                nn.ReLU(inplace=True),
                (GATConv(256, 128, head=4), 'x, edge_index -> x'),
                nn.ReLU(inplace=True),
            ])
        elif gnn_mode == "GCN":
            self.graph_encoder = Sequential('x, edge_index', [
                (GCNConv(1280, 512), 'x, edge_index -> x'),
                nn.ReLU(inplace=True),
                (GCNConv(512, 256), 'x, edge_index -> x'),
                nn.ReLU(inplace=True),
                (GCNConv(256, 128), 'x, edge_index -> x'),
                nn.ReLU(inplace=True),
            ])
        else:
            raise
        
        # self.local_project = nn.Sequential([nn.Linear(128, 128), nn.ReLU(inplace=True)])
        # self.global_project = nn.Sequential([nn.Linear(128, 128), nn.ReLU(inplace=True)])
        
        self.head = nn.Sequential(
            nn.Linear(256, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout1d(p=0.3),
            nn.Linear(256, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
        
    
    def forward(self, pairgraphs, mut_pos):
        prot_mut_embeds = self.graph_encoder(pairgraphs.x_mut, pairgraphs.edge_index_mut)
        prot_wt_embeds = self.graph_encoder(pairgraphs.x_wt, pairgraphs.edge_index_wt)
        
        # whole residue embeds
        global_mut_embeds = global_max_pool(prot_mut_embeds, batch=pairgraphs.x_mut_batch)
        global_wt_embeds = global_max_pool(prot_wt_embeds, batch=pairgraphs.x_wt_batch)
        
        # mutation amino acid embeds
        mut_pos_ids = mut_pos + pairgraphs.x_mut_ptr[:-1]
        wt_pos_ids = mut_pos + pairgraphs.x_wt_ptr[:-1]
        local_mut_embeds = prot_mut_embeds[mut_pos_ids]
        local_wt_embeds = prot_wt_embeds[wt_pos_ids]
        
        diff_embeds = torch.concat([global_mut_embeds, local_mut_embeds], dim=1) - torch.concat([global_wt_embeds, local_wt_embeds], dim=1)
        
        # prot_embeds_diff = torch.concat([local_mut_embeds, local_wt_embeds], dim=1)
        
        out = self.head(diff_embeds)
        
        return out 

In [None]:
gcn_net = ProtMutNet(gnn_mode="GCN")
gat_net = ProtMutNet(gnn_mode="GAT")

net = gat_net

net.to(device=device)

optim = torch.optim.Adam(net.parameters() ,lr=5e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=2, gamma=0.98)
loss_func = torch.nn.SmoothL1Loss()

In [97]:
from sklearn.isotonic import spearmanr, spearmanr


train_epochs = 50

for epoch in range(train_epochs):
    if epoch != 0:
        scheduler.step()
    
    # train
    losses = 0.   
    total_preds = []
    total_targets = [] 
    net.train()
    for batch in train_loader:
        prots, mut_pos, targets = batch
        
        prots = prots.to(device)
        targets = targets.to(device)
        mut_pos = mut_pos.to(device)
        
        preds = net(prots, mut_pos=mut_pos)
        
        loss = loss_func(preds.flatten(), targets)
        
        optim.zero_grad()
        loss.backward()
        optim.step()
        
        losses += loss.item()
        
        total_preds.extend(preds.cpu().detach().numpy().flatten().tolist())
        total_targets.extend(targets.cpu().numpy().tolist())
    
    mean_mae = losses / len(train_loader)
    pr, p_value = pearsonr(total_targets, total_preds)
    lr = [x['lr'] for x in optim.param_groups][0]
    print("epoch ==> {}, lr ==>{:.5f}, mean_mae ==> {}, pearsonr ==> {}".format(epoch, lr, mean_mae, pr))
    
    # test
    test_losses = 0.   
    total_preds = []
    total_targets = [] 
    net.eval()
    for batch in val_loader:
        prots, mut_pos, targets = batch
        
        prots = prots.to(device)
        targets = targets.to(device)
        mut_pos = mut_pos.to(device)
        
        with torch.no_grad():
            preds = net(prots, mut_pos=mut_pos)
        
        loss = loss_func(preds.flatten(), targets)
        test_losses += loss.item()
        
        total_preds.extend(preds.cpu().detach().numpy().flatten().tolist())
        total_targets.extend(targets.cpu().numpy().tolist())
    
    mean_mae = test_losses / len(val_loader)
    pr, p_value = pearsonr(total_targets, total_preds)
    sr, p_value = spearmanr(total_targets, total_preds)
    print("              test_mean_mae ==> {}, test_pearsonr ==> {}, test_spearsonr ==> {}".format(mean_mae, pr, sr))
print("done.")

epoch ==> 0, lr ==>0.00046, mean_mae ==> 0.16887036619745954, pearsonr ==> 0.7867216985958094
              test_mean_mae ==> 0.13537550474015567, test_pearsonr ==> 0.8876148007538485, test_spearsonr ==> 0.8200148049406488
epoch ==> 1, lr ==>0.00046, mean_mae ==> 0.16349446378192123, pearsonr ==> 0.7940141029964163
              test_mean_mae ==> 0.1582452654838562, test_pearsonr ==> 0.8883140342721183, test_spearsonr ==> 0.8046316852195086
epoch ==> 2, lr ==>0.00045, mean_mae ==> 0.16993432963381008, pearsonr ==> 0.7825307761692307
              test_mean_mae ==> 0.1394504331625425, test_pearsonr ==> 0.8883455822383632, test_spearsonr ==> 0.8040453511622456
epoch ==> 3, lr ==>0.00045, mean_mae ==> 0.15784757219407022, pearsonr ==> 0.8007719411116674
              test_mean_mae ==> 0.1451554332788174, test_pearsonr ==> 0.880785375424649, test_spearsonr ==> 0.8012034875532902
epoch ==> 4, lr ==>0.00044, mean_mae ==> 0.1552451336566283, pearsonr ==> 0.8033900883231406
              test_

In [None]:
import argparse

import torch
import torch.nn.functional as F
from tqdm import tqdm

import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import HeteroConv, Linear, SAGEConv
from torch_geometric.utils import trim_to_layer

# parser = argparse.ArgumentParser()
# parser.add_argument('--device', type=str, default='cuda')
# parser.add_argument('--use-sparse-tensor', action='store_true')
# args = parser.parse_args()
use_sparse_tensor = True
device = "cuda"

device = device if torch.cuda.is_available() else 'cpu'

transforms = [T.ToUndirected(merge=True)]
if use_sparse_tensor:
    transforms.append(T.ToSparseTensor())
dataset = OGB_MAG(root='/home/data/tmp', preprocess='metapath2vec',
                  transform=T.Compose(transforms))
data = dataset[0].to(device, 'x', 'y')


class HierarchicalHeteroGraphSage(torch.nn.Module):
    def __init__(self, edge_types, hidden_channels, out_channels, num_layers):
        super().__init__()

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HeteroConv(
                {
                    edge_type: SAGEConv((-1, -1), hidden_channels)
                    for edge_type in edge_types
                }, aggr='sum')
            self.convs.append(conv)

        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict, num_sampled_edges_dict,
                num_sampled_nodes_dict):

        for i, conv in enumerate(self.convs):
            x_dict, edge_index_dict, _ = trim_to_layer(
                layer=i,
                num_sampled_nodes_per_hop=num_sampled_nodes_dict,
                num_sampled_edges_per_hop=num_sampled_edges_dict,
                x=x_dict,
                edge_index=edge_index_dict,
            )

            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: x.relu() for key, x in x_dict.items()}

        return self.lin(x_dict['paper'])


model = HierarchicalHeteroGraphSage(
    edge_types=data.edge_types,
    hidden_channels=64,
    out_channels=dataset.num_classes,
    num_layers=2,
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

kwargs = {'batch_size': 1024, 'num_workers': 0}
train_loader = NeighborLoader(
    data,
    num_neighbors=[10] * 2,
    shuffle=True,
    input_nodes=('paper', data['paper'].train_mask),
    **kwargs,
)

val_loader = NeighborLoader(
    data,
    num_neighbors=[10] * 2,
    shuffle=False,
    input_nodes=('paper', data['paper'].val_mask),
    **kwargs,
)


def train():
    model.train()

    total_examples = total_loss = 0
    for batch in tqdm(train_loader):
        batch = batch.to(device)
        optimizer.zero_grad()

        out = model(
            batch.x_dict,
            batch.adj_t_dict
            if use_sparse_tensor else batch.edge_index_dict,
            num_sampled_nodes_dict=batch.num_sampled_nodes_dict,
            num_sampled_edges_dict=batch.num_sampled_edges_dict,
        )

        batch_size = batch['paper'].batch_size
        loss = F.cross_entropy(out[:batch_size], batch['paper'].y[:batch_size])
        loss.backward()
        optimizer.step()

        total_examples += batch_size
        total_loss += float(loss) * batch_size

    return total_loss / total_examples


@torch.no_grad()
def test(loader):
    model.eval()

    total_examples = total_correct = 0
    for batch in tqdm(loader):
        batch = batch.to(device)
        out = model(
            batch.x_dict,
            batch.adj_t_dict
            if use_sparse_tensor else batch.edge_index_dict,
            num_sampled_nodes_dict=batch.num_sampled_nodes_dict,
            num_sampled_edges_dict=batch.num_sampled_edges_dict,
        )

        batch_size = batch['paper'].batch_size
        pred = out[:batch_size].argmax(dim=-1)
        total_examples += batch_size
        total_correct += int((pred == batch['paper'].y[:batch_size]).sum())

    return total_correct / total_examples


for epoch in range(1, 6):
    loss = train()
    val_acc = test(val_loader)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_acc:.4f}')

In [None]:
train_data = next(iter(train_loader))

val_data = next(iter(val_loader))