In [1]:
import sys
import os

import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder

import torch
from torch_geometric.data import InMemoryDataset
from torch_geometric.data import Data
from tqdm import tqdm

In [2]:
df = pd.read_pickle('/Users/alicegao/work/psi-lab-sandbox/rna_ss/data_processing/rnafold_mini_data/data/rand_seqs_var_len_sample_mfe_10_50_100.pkl.gz')

In [3]:
df.head()

Unnamed: 0,ensemble_diversity,free_energy,len,mfe_frequency,one_idx,seq
0,2.48,-1.1,19,0.648747,"([0, 1, 2, 10, 11, 12], [12, 11, 10, 2, 1, 0])",GACCGCUAAUGUCGAAUCU
1,1.42,-6.0,30,0.705145,"([2, 3, 4, 7, 8, 9, 10, 15, 16, 17, 18, 21, 22...",ACCUGAUUACCAAACGGUGUCCAGAAAGCC
2,3.84,-5.3,46,0.204397,"([7, 8, 9, 10, 11, 18, 19, 20, 32, 33, 34, 40,...",CUCAAAUUCCAUUAUACACGAUAUAUUCUCGCUCGGCACCGUGGAG
3,1.35,-4.0,18,0.864861,"([0, 1, 2, 3, 10, 11, 12, 13], [13, 12, 11, 10...",CUGCUAACGCGCAGCAAU
4,13.75,-7.1,48,0.147935,"([5, 6, 7, 8, 11, 12, 16, 17, 20, 21, 22, 23, ...",GUCUAGCUUGAGAACUUUGAAGGCGCACCGUGUCAAAGCCGUGUAUCG


In [4]:
class RnaFeDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super(RnaFeDataset, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return []
    
    @property
    def processed_file_names(self):
#         pass
        return ['rna_fe.dataset']

    def download(self):
        pass
    
    def process(self):
        
        data_list = []
        
        for _, row in df.iterrows():
            seq = row['seq']
            fe = row['free_energy']
            one_idx = row['one_idx']
            
            # use integer encoding for now
            seq = seq.upper().replace('A', '1').replace('C', '2').replace('G', '3').replace('T', '4').replace('U', '4').replace('N', '0')
            node_features = np.asarray(list(map(int, list(seq))), dtype=np.int16)
            node_features = torch.LongTensor(node_features).unsqueeze(1)  # FIXME dtype
            
            # build edges
            edge_from = []
            edge_to = []
            # chain - undirected edge for now
            node_left = range(0, len(seq) - 1)
            node_right = range(1, len(seq))
            edge_from.extend(node_left)
            edge_to.extend(node_right)
            edge_from.extend(node_right)
            edge_to.extend(node_left)
            # pair matrix - undirected edge 
            for idx_left, idx_right in zip(one_idx[0], one_idx[1]):
                edge_from.append(idx_left)
                edge_to.append(idx_right)
                edge_from.append(idx_right)
                edge_to.append(idx_left)
            edge_index = torch.tensor([edge_from, edge_to], dtype=torch.long)
            
            # target value
            assert not np.isnan(fe)
            y = fe
            
            # make data point
            data = Data(x=node_features, edge_index=edge_index, y=y)
            data_list.append(data)
        
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

In [5]:
dataset = RnaFeDataset(root=os.path.join(os.getcwd(), 'dataset/rna_1/'))

In [6]:
len(dataset)

100

In [7]:
dataset = dataset.shuffle()
train_dataset = dataset[:80]
val_dataset = dataset[80:90]
test_dataset = dataset[90:]
len(train_dataset), len(val_dataset), len(test_dataset)

(80, 10, 10)

In [8]:
from torch_geometric.data import DataLoader

In [9]:
batch_size= 5
train_loader = DataLoader(train_dataset, batch_size=batch_size)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

In [10]:

import torch
from torch.nn import Sequential as Seq, Linear, ReLU
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops
from torch_geometric.nn import GraphConv, TopKPooling, GatedGraphConv
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
import torch.nn.functional as F


In [11]:
class SAGEConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(SAGEConv, self).__init__(aggr='max') #  "Max" aggregation.
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.act = torch.nn.ReLU()
#         self.update_lin = torch.nn.Linear(in_channels + out_channels, in_channels, bias=False)
        self.update_lin = torch.nn.Linear(in_channels + out_channels, out_channels, bias=False)
        self.update_act = torch.nn.ReLU()
        
    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]
        
        
        edge_index, _ = remove_self_loops(edge_index)
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        
        
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)

    def message(self, x_j):
        # x_j has shape [E, in_channels]

#         print('sage, before', x_j.shape)
        x_j = self.lin(x_j)
#         print('sage, after lin', x_j.shape)
        x_j = self.act(x_j)
#         print('sage, after act', x_j.shape)
        
        return x_j

    def update(self, aggr_out, x):
        # aggr_out has shape [N, out_channels]


        new_embedding = torch.cat([aggr_out, x], dim=1)
        
#         print('sage, before update', new_embedding.shape)
        new_embedding = self.update_lin(new_embedding)
        new_embedding = self.update_act(new_embedding)
#         print('sage, after update', new_embedding.shape)
        
        return new_embedding

In [12]:
embed_dim = 5
n_hid = 10

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.conv1 = SAGEConv(embed_dim, n_hid)
#         self.pool1 = TopKPooling(n_hid, ratio=0.8)
        self.conv2 = SAGEConv(n_hid, n_hid)
#         self.pool2 = TopKPooling(n_hid, ratio=0.8)
        self.conv3 = SAGEConv(n_hid, n_hid)
#         self.pool3 = TopKPooling(n_hid, ratio=0.8)
        self.item_embedding = torch.nn.Embedding(num_embeddings=5, embedding_dim=embed_dim)
        self.lin1 = torch.nn.Linear(n_hid * 2, n_hid)
        self.lin2 = torch.nn.Linear(n_hid, n_hid//2)
        self.lin3 = torch.nn.Linear(n_hid//2, 1)
        self.bn1 = torch.nn.BatchNorm1d(n_hid)
        self.bn2 = torch.nn.BatchNorm1d(n_hid//2)
        self.act1 = torch.nn.ReLU()
        self.act2 = torch.nn.ReLU()        
  
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.item_embedding(x)
        x = x.squeeze(1)    
#         print('after embedding', x.shape)
#         print(x[0, :])

        x = F.relu(self.conv1(x, edge_index))
#         print('after conv1', x.shape)
#         print(x[0, :])

        #x, edge_index, _, batch, _ = self.pool1(x, edge_index, None, batch)
#         x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)
        x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = F.relu(self.conv2(x, edge_index))
#         print('after conv2', x.shape)
#         print(x[0, :])
     
        #x, edge_index, _, batch, _ = self.pool2(x, edge_index, None, batch)
#         x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch)
        x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = F.relu(self.conv3(x, edge_index))

        #x, edge_index, _, batch, _ = self.pool3(x, edge_index, None, batch)
#         x, edge_index, _, batch, _, _ = self.pool3(x, edge_index, None, batch)
        x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = x1 + x2 + x3

        x = self.lin1(x)
        x = self.act1(x)
        x = self.lin2(x)
        x = self.act2(x)      
        x = F.dropout(x, p=0.5, training=self.training)

#         x = torch.sigmoid(self.lin3(x)).squeeze(1)
        x = self.lin3(x).squeeze(1)

        return x

In [13]:
device = torch.device('cpu')
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
# crit = torch.nn.BCELoss()
crit = torch.nn.MSELoss()

In [14]:
# def train():
#     model.train()

#     loss_all = 0
#     for data in train_loader:
#         data = data.to(device)
#         optimizer.zero_grad()
#         output = model(data)
#         label = data.y.to(device)
#         loss = crit(output, label)
#         loss.backward()
#         loss_all += data.num_graphs * loss.item()
#         optimizer.step()
#     return loss_all / len(train_dataset)

In [15]:
from scipy.stats import spearmanr

def evaluate(loader):
    model.eval()

    predictions = []
    labels = []

    with torch.no_grad():
        for data in loader:

            data = data.to(device)
            pred = model(data).detach().cpu().numpy()

            label = data.y.detach().cpu().numpy()
            predictions.append(pred)
            labels.append(label)

    predictions = np.hstack(predictions)
    labels = np.hstack(labels)
    
    corr, pval = spearmanr(labels, predictions)
    return corr

In [16]:
for epoch in range(10):
#     loss = train()


    model.train()

    loss_all = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        label = data.y.to(device)
        loss = crit(output, label)
        loss.backward()
        loss_all += data.num_graphs * loss.item()
        optimizer.step()
    loss = loss_all / len(train_dataset)

    train_acc = evaluate(train_loader)
    val_acc = evaluate(val_loader)    
    test_acc = evaluate(test_loader)
    print('Epoch: {:03d}, Loss: {:.5f}, Train corr: {:.5f}, Val corr: {:.5f}, Test corr: {:.5f}'.
          format(epoch, loss, train_acc, val_acc, test_acc))

Epoch: 000, Loss: 43.31988, Train corr: 0.28330, Val corr: 0.81709, Test corr: 0.38298
Epoch: 001, Loss: 31.71949, Train corr: 0.41872, Val corr: 0.83538, Test corr: 0.39514
Epoch: 002, Loss: 27.84745, Train corr: 0.46341, Val corr: 0.78660, Test corr: 0.54104
Epoch: 003, Loss: 28.94456, Train corr: 0.52783, Val corr: 0.84758, Test corr: 0.63830
Epoch: 004, Loss: 26.15729, Train corr: 0.55678, Val corr: 0.89026, Test corr: 0.69301
Epoch: 005, Loss: 24.83331, Train corr: 0.57378, Val corr: 0.94514, Test corr: 0.72949
Epoch: 006, Loss: 28.18099, Train corr: 0.57836, Val corr: 0.89026, Test corr: 0.70517
Epoch: 007, Loss: 24.37705, Train corr: 0.59364, Val corr: 0.93294, Test corr: 0.70517
Epoch: 008, Loss: 28.49526, Train corr: 0.60867, Val corr: 0.89026, Test corr: 0.69301
Epoch: 009, Loss: 26.67515, Train corr: 0.62437, Val corr: 0.93294, Test corr: 0.63222


In [None]:
# TODO check in to github
# TODO how to force re-process dataset?
# add back pooling
# take into account length? - mean+max might not work
# more data