In [1]:
import argparse
import gzip
import pickle as pkl
from pathlib import Path

import numpy as np
import pandas as pd
import torch as T
from torch import nn
import torch.nn.functional as F

import chemprop
import rdkit
from tqdm.notebook import tqdm

In [None]:
import covid
from covid.modules import *
from covid.data import *

from covid.modules.chemistry import MPNEncoder

In [6]:
# Dropout probability shared by the chemical and protein encoders.
DROPOUT_RATE = 0.2

# Seed the global RNGs so weight initialisation, dropout masks and any
# global-RNG sampling are reproducible under Restart & Run All.
SEED = 0
np.random.seed(SEED)
T.manual_seed(SEED)

In [7]:
def create_chemprop_args(**overrides):
    """Build the ``argparse.Namespace`` of chemprop settings used in this notebook.

    With no arguments the namespace holds the defaults this notebook was written
    with. Any keyword argument replaces (or adds) the attribute of the same name,
    e.g. ``create_chemprop_args(depth=5)``.

    NOTE(review): ``dropout`` is 0.0 here while the encoders use DROPOUT_RATE;
    in the visible cells this namespace is only passed to ``BatchMolGraph`` —
    confirm whether these settings are consumed anywhere else.

    Returns:
        argparse.Namespace: a new namespace on every call, so mutating it does
        not affect later calls.
    """
    settings = {
        'seed': 0,
        'ensemble_size': 1,
        'hidden_size': 300,
        'bias': False,
        'depth': 3,
        'dropout': 0.0,
        'activation': 'ReLU',
        'undirected': False,
        'atom_messages': False,
    }
    settings.update(overrides)
    return argparse.Namespace(**settings)

In [8]:
# Default chemprop settings; in the visible cells they are only passed to BatchMolGraph.
args: argparse.Namespace = create_chemprop_args()

In [9]:
# Load the preprocessed STITCH table plus the protein and chemical lookup objects.
# NOTE: pickle.load can execute arbitrary code — only load these files from a trusted source.
DATA_DIR = Path('./data')


def _load_gzipped_pickle(path):
    """Return the object stored in the gzip-compressed pickle file at ``path``."""
    with gzip.open(path, 'rb') as f:
        return pkl.load(f)


all_data = pd.read_csv(DATA_DIR / 'stitch_preprocessed.csv.gz')
all_proteins = _load_gzipped_pickle(DATA_DIR / 'stitch_proteins.pkl.gz')
all_chemicals = _load_gzipped_pickle(DATA_DIR / 'stitch_chemicals.pkl.gz')

In [10]:
# Batch the molecular graphs of the chemicals referenced by the first five rows
# of all_data (column item_id_a is used as the key into all_chemicals).
# Element 4 of each all_chemicals entry is expected to be a chemprop MolGraph,
# since BatchMolGraph takes a list of them — TODO confirm against preprocessing.
batch = chemprop.features.BatchMolGraph(
    list(map(lambda chem_id: all_chemicals[chem_id][4],
             all_data['item_id_a'].iloc[:5].values)),
    args,
)

In [11]:
# Extra per-molecule features (element 5 of each all_chemicals entry) for the
# same first five item_id_a chemicals, stacked along a new leading dimension.
# NOTE(review): element 5 is assumed to be array-like numeric features — confirm.
f_batch = T.stack(tuple(
    T.tensor(all_chemicals[chem_id][5])
    for chem_id in all_data['item_id_a'].iloc[:5].values
))

In [14]:
# Message-passing encoder for the chemical side. The atom/bond feature sizes are
# read from the batch so they match chemprop's featurisation.
chem_model = MPNEncoder(batch.atom_fdim, batch.bond_fdim,
                        layers_per_message=2, dropout=DROPOUT_RATE)

In [87]:
# Protein encoder: per-position features -> residual 1D conv stack -> global max
# pool over the sequence -> dropout + Linear(400, 400) -> 400-d embedding.
protein_model = nn.Sequential(
    # Pointwise (kernel size 1, stride 1, no padding) convolution: 23 -> 100 channels.
    # NOTE(review): 23 is assumed to be the channel count produced by encode_protein — confirm.
    apply_to_protein_batch(nn.Conv1d(23, 100, (1, ), 1, 0)),

    # Residual blocks at 100 channels with inner kernel sizes 3, 5 and 7.
    create_resnet_block_1d(100, 16, inner_kernel=3, for_protein_batch=True),
    create_resnet_block_1d(100, 16, inner_kernel=5, for_protein_batch=True),
    create_resnet_block_1d(100, 16, inner_kernel=7, for_protein_batch=True),

    # Downscale and widen to 200 channels.
    DownscaleConv1d(100, 200, 4, maxpool=True, for_protein_batch=True),
    apply_to_protein_batch(nn.Dropout(DROPOUT_RATE)),

    # Residual blocks at 200 channels with inner kernel sizes 3 and 5.
    create_resnet_block_1d(200, 32, inner_kernel=3, for_protein_batch=True),
    create_resnet_block_1d(200, 32, inner_kernel=5, for_protein_batch=True),

    # Downscale again and widen to 400 channels.
    DownscaleConv1d(200, 400, 4, 'silu', maxpool=True, for_protein_batch=True),

    # Global max pool over the sequence dimension (output length 1).
    # Replaces MaxPool1d(10000, ceil_mode=True). That layer gives the same result
    # while the length is at most 10000 (assuming apply_to_protein_batch pools each
    # protein separately). For longer inputs it yields more than one position,
    # which the Squeeze(-1) -> Linear(400, 400) tail below is not set up for.
    apply_to_protein_batch(nn.AdaptiveMaxPool1d(1)),

    # Convert protein batch to standard batch format.
    ProteinBatchToPaddedBatch(),

    Squeeze(-1),
    nn.Dropout(DROPOUT_RATE),
    nn.Linear(400, 400),
    # NOTE(review): Tanhshrink (x - tanh(x)) followed by Tanh is an unusual pair.
    # Tanhshrink is close to zero for small inputs, which may explain the many
    # near-zero outputs. Confirm this is intended.
    nn.Tanhshrink(),
    nn.Tanh(),
)

In [88]:
# Draw a few distinct proteins for a smoke test of protein_model.
# np.random.choice samples WITH replacement by default, so the old draw could
# repeat a protein. A dedicated seeded generator with replace=False gives a
# reproducible batch of distinct proteins.
protein_rng = np.random.default_rng(0)
protein_names = protein_rng.choice(
    list(all_proteins.keys()), size=min(5, len(all_proteins)), replace=False
)
x = create_protein_batch(
    [encode_protein(all_proteins[name]).unsqueeze(0) for name in protein_names]
)

In [101]:
# Smoke test of the protein encoder on the sampled batch.
protein_embeddings = protein_model(x)
# The final layers are Linear(400, 400) then Tanh: expect one 400-d row per
# sampled protein, with values in [-1, 1].
assert protein_embeddings.shape == (len(protein_names), 400), protein_embeddings.shape
assert protein_embeddings.abs().max() <= 1
protein_embeddings

tensor([[ 6.0182e-06,  8.1120e-01, -2.6744e-03,  ...,  3.9421e-01,
         -4.6660e-02, -5.5384e-05],
        [-3.6413e-04,  8.2134e-01, -5.5198e-02,  ...,  6.1332e-02,
          5.6472e-03, -3.7811e-02],
        [-1.6838e-05,  8.0771e-02, -4.7725e-05,  ...,  2.8394e-01,
          1.7621e-06,  7.4050e-02],
        [-1.6651e-03,  7.5296e-01, -2.6436e-02,  ...,  3.3927e-03,
         -1.0249e-02,  7.0409e-02],
        [ 1.1029e-02,  7.8234e-01, -7.2476e-02,  ...,  8.1359e-02,
          1.9961e-02, -4.0960e-02]], grad_fn=<TanhBackward>)

In [102]:
# Smoke test of the chemical encoder on the five-molecule batch.
chem_embeddings = chem_model(batch, f_batch)
# Expect one embedding row per molecule in the batch.
assert chem_embeddings.shape[0] == f_batch.shape[0], chem_embeddings.shape
chem_embeddings

tensor([[0.3962, 0.2333, 0.0000,  ..., 0.2204, 0.3245, 0.3500],
        [0.3026, 0.2606, 0.0000,  ..., 0.2903, 0.3789, 0.2969],
        [0.3438, 0.2644, 0.0000,  ..., 0.2807, 0.3148, 0.2879],
        [0.3375, 0.2875, 0.0000,  ..., 0.2471, 0.3576, 0.2391],
        [0.3318, 0.4053, 0.0000,  ..., 0.3227, 0.4726, 0.2699]],
       grad_fn=<StackBackward>)