In [1]:
import pandas as pd
import numpy as np
import torch as T
from torch import nn
import torch.nn.functional as F

import chemprop
import rdkit

import pickle as pkl
import gzip

from tqdm.notebook import tqdm
import argparse

In [2]:
import covid
from covid.modules import *
from covid.data import *

from covid.modules.chemistry import MPNEncoder

In [3]:
# Dropout probability shared by the chemical encoder and the protein model built below.
DROPOUT_RATE = 0.2

In [4]:
def create_chemprop_args():
    """Build the fixed hyper-parameter namespace handed to chemprop.

    Returns:
        argparse.Namespace: a fresh namespace on every call carrying the
        attributes ``seed``, ``ensemble_size``, ``hidden_size``, ``bias``,
        ``depth``, ``dropout``, ``activation``, ``undirected`` and
        ``atom_messages`` with the constant values listed below.
    """
    return argparse.Namespace(
        seed=0,
        ensemble_size=1,
        hidden_size=300,
        bias=False,
        depth=3,
        dropout=0.0,
        activation='ReLU',
        undirected=False,
        atom_messages=False,
    )

In [5]:
# Hyper-parameter namespace that is passed to BatchMolGraph when the example batch is built below.
args = create_chemprop_args()

In [6]:
# Load the preprocessed STITCH table from a gzip-compressed CSV
# (pandas infers the compression from the `.gz` suffix by default).
all_data = pd.read_csv('./data/stitch_preprocessed.csv.gz')
# NOTE(security): pickle.load can execute arbitrary code while unpickling --
# only load these two files if they come from a trusted source.
# all_proteins is used below as a mapping: its keys are sampled and its values
# are passed to encode_protein.
with gzip.open('./data/stitch_proteins.pkl.gz', 'rb') as f:
    all_proteins = pkl.load(f)
# all_chemicals is indexed below by `item_id_a` values; element [4] of each
# entry is fed to BatchMolGraph and element [5] is converted to a tensor.
with gzip.open('./data/stitch_chemicals.pkl.gz', 'rb') as f:
    all_chemicals = pkl.load(f)

In [7]:
class BatchMolGraph(chemprop.features.BatchMolGraph):
    """chemprop ``BatchMolGraph`` with a tensor-style ``.to()`` method."""

    def to(self, *args, **kwargs):
        """Call ``.to(*args, **kwargs)`` on every tensor attribute of the batch.

        ``f_atoms``, ``f_bonds``, ``a2b``, ``b2a`` and ``b2revb`` are always
        converted; ``b2b`` and ``a2a`` are converted only when they are not
        ``None``. Each attribute is rebound to the converted value.

        Returns:
            BatchMolGraph: ``self``, so the call can be chained or reassigned.
        """
        # Attributes that are converted unconditionally.
        for attr in ('f_atoms', 'f_bonds', 'a2b', 'b2a', 'b2revb'):
            setattr(self, attr, getattr(self, attr).to(*args, **kwargs))

        # Attributes that may still be None and are then left untouched.
        for attr in ('b2b', 'a2a'):
            current = getattr(self, attr)
            if current is not None:
                setattr(self, attr, current.to(*args, **kwargs))

        return self

In [8]:
# Example graph batch: element [4] of the chemical entry for each of the
# first five `item_id_a` values, wrapped in the BatchMolGraph subclass above.
batch = BatchMolGraph(
    [all_chemicals[x][4] for x in all_data['item_id_a'].iloc[:5].values], 
    args
)

In [9]:
# Companion tensor for the same five chemicals: element [5] of each entry is
# converted to a tensor and the five tensors are stacked along a new first axis.
f_batch = T.stack([
    T.tensor(all_chemicals[x][5]) for x in all_data['item_id_a'].iloc[:5].values
])

In [10]:
# Chemical encoder (MPNEncoder from covid.modules.chemistry), sized from the
# atom/bond feature dimensions reported by the example batch.
chem_model = MPNEncoder(
    batch.atom_fdim, 
    batch.bond_fdim, 
    layers_per_message=2, 
    dropout=DROPOUT_RATE
)

In [11]:
# Protein encoder: a stack of 1-D convolutional stages followed by a linear head.
# NOTE(review): the head is nn.Linear(600, 600), so the last output dimension
# should be 600, yet the recorded cell outputs below show torch.Size([5, 300]);
# those outputs look stale -- re-run the notebook to confirm.
protein_model = nn.Sequential(
    # 23->100 channels pointwise convolution (kernel size 1, stride 1, padding 0)
    apply_to_protein_batch(nn.Conv1d(23, 100, (1, ), 1, 0)),
    
    # Three residual blocks at 100 channels with inner kernels 3, 5 and 7
    create_resnet_block_1d(100, 32, inner_kernel=3, for_protein_batch=True),
    create_resnet_block_1d(100, 32, inner_kernel=5, for_protein_batch=True),
    create_resnet_block_1d(100, 32, inner_kernel=7, for_protein_batch=True),
    
    # Scale it down (100 -> 300 channels), then apply dropout
    DownscaleConv1d(100, 300, 4, maxpool=True, for_protein_batch=True),
    apply_to_protein_batch(nn.Dropout(DROPOUT_RATE)),
    
    # Two residual blocks at 300 channels with inner kernels 3 and 5
    create_resnet_block_1d(300, 32, inner_kernel=3, for_protein_batch=True),
    create_resnet_block_1d(300, 32, inner_kernel=5, for_protein_batch=True),
    
    # Scale it down again (300 -> 600 channels)
    DownscaleConv1d(300,600,4,'silu', maxpool=True, for_protein_batch=True),
    # NOTE(review): the kernel of 10000 presumably exceeds any remaining sequence
    # length, making this a global max pool over the sequence axis -- confirm.
    apply_to_protein_batch(nn.MaxPool1d(10000, ceil_mode=True)),
    
    # Convert protein batch to standard batch format
    ProteinBatchToPaddedBatch(),
    
    # Drop the trailing axis, then dropout and a 600 -> 600 linear layer
    Squeeze(-1),
    nn.Dropout(DROPOUT_RATE),
    nn.Linear(600,600),
    # NOTE(review): Tanhshrink immediately followed by Tanh -- confirm that
    # both activations are intended.
    nn.Tanhshrink(),
    nn.Tanh(),
)

In [12]:
# Draw 5 protein names (with replacement, as before) from a fixed-seed
# generator so the smoke-test batch is identical on every Restart & Run All.
protein_names = np.random.default_rng(0).choice(list(all_proteins.keys()), 5)
# Encode each sampled protein, add a leading axis with unsqueeze(0), and
# collate the five encodings into a single protein batch.
x = create_protein_batch(
    [encode_protein(all_proteins[name]).unsqueeze(0) for name in protein_names]
)

In [13]:
# Smoke test: run the protein model on the five-protein batch and show the output shape.
# NOTE(review): the model ends in nn.Linear(600, 600), so a recorded width of 300
# would be stale -- re-run to confirm.
protein_model(x).shape

torch.Size([5, 300])

In [14]:
# Smoke test: run the chemical encoder on the five-chemical batch and show the output shape.
chem_model(batch, f_batch).shape

torch.Size([5, 300])

In [15]:
# Prefer the GPU when one is present, but fall back to the CPU so the notebook
# still runs on machines without CUDA instead of raising on `.to('cuda')`.
device = 'cuda' if T.cuda.is_available() else 'cpu'

# protein_model is an nn.Sequential, whose .to() moves parameters in place;
# chem_model is presumably an nn.Module as well, so neither is reassigned.
protein_model.to(device)
chem_model.to(device)
# The inputs are rebound to whatever their .to() returns
# (BatchMolGraph.to, defined above, returns self).
x = x.to(device)
f_batch = f_batch.to(device)
batch = batch.to(device)

In [16]:
# Repeat the protein-model shape check after the device move in the previous cell.
protein_model(x).shape

torch.Size([5, 300])

In [17]:
# Repeat the chemical-encoder shape check after the device move above.
chem_model(batch, f_batch).shape

torch.Size([5, 300])

In [None]:
class 