In [1]:
import torch
from torch_geometric.data import Data
from tqdm import tqdm

# Creating Edges
Index string_db, and connect nodes with a combined score over 500 (arbitrary)

In [2]:
import json

stringdb = json.load(open("../Corpuses/string_db.json"))

Mapping all gene names to a vocab, then creating an edge connection tensor

In [3]:
edge_index = [[], []] # starts, ends
vocab = {}
idx = 0
for k, v in tqdm(stringdb.items()):
    if 'combined_score' in v and int(v['combined_score']) > 500:
        if k not in vocab:
            vocab[k] = idx
            idx += 1
        if v['protein2'] not in vocab:
            vocab[v['protein2']] = idx
            idx += 1
        
        edge_index[0].extend([vocab[k], vocab[v['protein2']]])
        edge_index[1].extend([vocab[v['protein2']], vocab[k]])
edge_index = torch.tensor(edge_index, dtype=torch.long)
print(edge_index.shape)

100%|██████████| 19386/19386 [00:00<00:00, 430696.58it/s]

torch.Size([2, 3702])





# Creating Node Feature Vectors

In [4]:
import sys
import torch
from bioservices import UniProt
sys.path.append("../")
from protbert.inference import ProtBertModule

protbert = ProtBertModule(torch.device('cpu'))
u = UniProt(verbose=False)

Creating directory /home/dan/.config/bioservices 


ModuleNotFoundError: No module named 'transformers'

In [5]:
from pprint import pprint
from tqdm import tqdm

feat_vecs = [] # attach in protbert inference here

for node in tqdm(vocab.keys()):
    # Finding associated protein for each gene
    o = u.search(node, frmt="fasta").split("\n")
    fasta = "".join(o[1:])
    
    # inference
    fasta = protbert.encode(fasta)
    preds = protbert(fasta).logits
    preds = torch.squeeze(preds).cpu().detach().numpy()
    feat_vecs.append(preds)
    del preds
    del fasta

  0%|          | 1/3289 [00:02<1:56:57,  2.13s/it]

In [5]:
# Random feat vecs without inference
feat_vecs = []
for node in tqdm(vocab.keys()):
    feat_vecs.append(torch.rand(5))

feat_vecs = torch.stack(feat_vecs)

100%|██████████| 3289/3289 [00:00<00:00, 48156.87it/s]


In [6]:
dataset = Data(x=feat_vecs, edge_index=edge_index)

In [8]:
dataset.x.shape[1]

5

# Learning

Playing with ways to incorporate methylation values into batch

# Model

In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import DataLoader
import torch_geometric.nn as tgnn
import pytorch_lightning as pl

class PSPred(pl.LightningModule):
    def __init__(self, num_classes, num_nodes, graph):
        super().__init__(self)
        self.save_hyperparameters()

        self.mlp = nn.Sequential(
            tgnn.CGConv(self.hparams.num_nodes + graph.x.shape[1], 128),
            nn.ReLU(inplace=True),
            tgnn.CGConv(128, self.hparams.num_classes),
        )
    
    def forward(self, methyl, idx):
        gnn = torch.cat([self.graph.x, methyl])
        gnn = self.mlp(gnn)
        return gnn

    def training_step(self, batch, idx):
        methyl, labels = batch
        preds = self.forward(methyl, idx)
        loss_fn = nn.CrossEntropyLoss()
        loss = loss_fn(preds, labels)

        self.log("train_loss": loss)
        return {"loss": loss}

    def validation_step(self, batch, idx):
        methyl, labels = batch
        preds = self.forward(methyl, idx)
        loss_fn = nn.CrossEntropyLoss()
        loss = loss_fn(preds, labels)

        self.log("val_loss", loss)
    
    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=1, shuffle=True)
    
    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=1)


# Dataset

In [None]:
import torch
from torch.utils.data import Dataset

class MethylationDataset(Dataset):
    def __init__(self, data):
        self.data = data
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        cur = self.data[idx]
        return {"data": cur[0], "label": cur[1]}


# DataModule

In [None]:
import torch
import random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch_geometric.nn as tgnn
import pytorch_lightning as pl

class PSDataModule(pl.LightningDataModule):
    """Handling the methylation data in batches"""
    def __init__(self, data_dir, feat_dim):
        super().__init__()
        self.data_dir = data_dir
        self.feat_dim = feat_dim

    def setup(self, data=None):
        # Load in the data
        # Methylation data structured as feature vectors padded to size N

        # Generate random training data
        if not data:
            data = []

            for i in range(100):
                cur = []
                for j in range(len(vocab.keys())):
                    cur.append((random.random(), random.randint(0,1))) # random label 
                data.append(cur)
        

        train, validate, test = np.split(data, [int(len(data)*0.8), int(len(data)*0.9)])

        self.train_dataset = MethylationDataset(train)
        self.test_dataset = MethylationDataset(test)
        self.val_dataset = MethylationDataset(val)

    
    def train_dataloader(self):
        return DataLoader(self.train_dataset, shuffle=True)

    def test_dataloader(self):
        return DataLoader(self.test_dataset)

    def val_dataloader(self):
        return DataLoader(self.val_dataset)
    
    

    
