In [0]:
import json
import pdb
import math
import dgl
import torch
import weakref
import numbers
import operator
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.utils.data
import cytoolz.curried as ct
import humanfriendly as hf
import itertools

import dgl.function as fn
import dgl.nn.pytorch as dglnn

import torch.nn.functional as F
import torch.utils.tensorboard as tb

from pathlib import Path
from datetime import datetime
from tqdm import tqdm
import subprocess

# When True, the data-loading cell overwrites each test molecule's 'mulliken'
# node feature with the charges loaded from settings['PREDICT']['Q9MLK'].
TRUE_MLKN = True

In [0]:
# Prediction configuration; the file handle is closed as soon as it is parsed.
with open('SETTINGS.json') as settings_file:
    settings = json.load(settings_file)

# MODELS maps each coupling-type name to a list of (checkpoint path, scaled)
# pairs; INPUT is the torch-saved test set; Q9MLK holds replacement Mulliken
# charges; OUTPUT is the directory that receives the CSV files.
models = settings['PREDICT']['MODELS']
data = settings['PREDICT']['INPUT']
mlk = settings['PREDICT']['Q9MLK']
output = settings['PREDICT']['OUTPUT']

# Coupling-type names; a type's position in this list is its coupling_type id
# and the column of the network output that holds its prediction.
CTYPES = ['1JHC', '1JHN', '2JHC', '2JHH', '2JHN', '3JHC', '3JHH', '3JHN']

## Data loading and transformations

In [0]:
# Load the test molecules and, optionally, swap in externally computed
# Mulliken charges, matched to molecules by position.
# NOTE: torch.load unpickles these files; only load trusted inputs.
test_data = torch.load(data)
if TRUE_MLKN:
    truemlkn = torch.load(mlk)
    # Explicit errors instead of bare asserts: asserts vanish under `python -O`.
    if len(truemlkn) < len(test_data):
        raise ValueError(
            f'{mlk} has {len(truemlkn)} molecules but {data} has {len(test_data)}')
    for i in range(len(test_data)):
        natoms1 = test_data[i]['nodes']
        natoms2 = len(truemlkn[i]['mlkn'])
        if natoms1 != natoms2:
            raise ValueError(
                f'molecule {i}: {natoms1} nodes but {natoms2} Mulliken charges')
        test_data[i]['ndata']['mulliken'] = truemlkn[i]['mlkn']

In [0]:
def make_dglg(mol, precision='double'):
    """Build a bidirected DGL graph with self-loops from one molecule record.

    ``mol`` is a dict with keys 'nodes' (node count), 'src' / 'dst' (edge
    endpoint tensors), 'ndata' (node features) and 'edata' (per-edge tensors
    'type', 'distance', 'angle', 'dihedral', 'coupling_type', 'coupling',
    'train_id', 'test_id').

    Edge order in the returned graph, which the feature concatenation relies on:
      1. the original src -> dst edges,
      2. their reverses dst -> src (dihedral sign flipped),
      3. one self-loop per node: type 6, zero distance/angle/dihedral/coupling,
         and coupling_type / train_id / test_id set to -1.

    Args:
        mol: molecule record as described above.
        precision: currently unused; kept for interface compatibility.

    Returns:
        The populated ``dgl.DGLGraph``.
    """
    g = dgl.DGLGraph()
    nodes = mol['nodes']
    g.add_nodes(nodes)
    src, dst = mol['src'], mol['dst']
    g.ndata.update(mol['ndata'])
    g.add_edges(src, dst)
    g.add_edges(dst, src)
    g.add_edges(range(nodes), range(nodes))

    device = src.device
    zeros = torch.zeros(nodes, dtype=torch.float32, device=device)
    ones = torch.ones(nodes, dtype=torch.int64, device=device)
    mol_edata = mol['edata']
    # ``repeat(2)`` tiles each feature for the forward + reversed edge copies;
    # the trailing ``zeros`` / ``ones`` entries belong to the self-loops.
    edata = {
        'type': torch.cat([mol_edata['type'].repeat(2), ones * 6]),
        'distance': torch.cat([mol_edata['distance'].repeat(2), zeros]),
        'angle': torch.cat([mol_edata['angle'].repeat(2), zeros]),
        'dihedral': torch.cat([mol_edata['dihedral'], -mol_edata['dihedral'], zeros]),
        'coupling_type': torch.cat([mol_edata['coupling_type'].repeat(2).to(torch.int64), -ones]),
        'coupling': torch.cat([mol_edata['coupling'].repeat(2), zeros]),
        'train_id': torch.cat([mol_edata['train_id'].repeat(2), -ones]),
        'test_id': torch.cat([mol_edata['test_id'].repeat(2), -ones]),
    }

    g.edata.update(edata)
    return g

  
@ct.curry
def make_graph_batch(mols, precision='double'):
    """Collate molecule records into one batched DGL graph.

    Curried, so ``make_graph_batch(precision=...)`` yields a function of
    ``mols`` suitable as a DataLoader ``collate_fn``. Each record is converted
    with ``make_dglg``, and DGL's zero initializer is registered for both node
    and edge features of the batch.
    """
    batched = dgl.batch(list(map(lambda mol: make_dglg(mol, precision), mols)))
    for register in (batched.set_n_initializer, batched.set_e_initializer):
        register(dgl.init.zero_initializer)
    return batched

## Layers for building the model

In [0]:
class Embed(nn.Module):
    """Create the initial 'emb' features of nodes and edges.

    Node 'emb' = learned type embedding (``node_dim - 4`` wide) followed by the
    scalar node features el_aff, el_neg, 1st_ion and mulliken.
    Edge 'emb' = learned type embedding (``edge_dim - 3`` wide) followed by the
    scalar edge features distance, angle and dihedral.
    Both type vocabularies hold 10 entries.
    """

    _NODE_SCALARS = ('el_aff', 'el_neg', '1st_ion', 'mulliken')
    _EDGE_SCALARS = ('distance', 'angle', 'dihedral')

    def __init__(self, node_dim, edge_dim):
        super().__init__()
        self.emb_node_types = nn.Embedding(10, node_dim - 4)
        self.emb_edge_types = nn.Embedding(10, edge_dim - 3)

    @staticmethod
    def _with_scalars(type_emb, frame, keys):
        # Append each scalar feature as one extra column after the type embedding.
        columns = [type_emb] + [frame[key].unsqueeze(1) for key in keys]
        return torch.cat(columns, dim=-1)

    def forward(self, g):
        g.ndata['emb'] = self._with_scalars(
            self.emb_node_types(g.ndata['type']), g.ndata, self._NODE_SCALARS)
        g.edata['emb'] = self._with_scalars(
            self.emb_edge_types(g.edata['type']), g.edata, self._EDGE_SCALARS)
        return g
      

class GraphCast(nn.Module):
    """Cast the node and edge 'emb' features of a graph to a fixed dtype.

    Args:
        dtype: target ``torch.dtype`` applied to ``g.ndata['emb']`` and
            ``g.edata['emb']`` in ``forward``.
    """

    def __init__(self, dtype):
        super(GraphCast, self).__init__()
        self.dtype = dtype

    def forward(self, g):
        """Cast both 'emb' features on ``g`` and return ``g``."""
        g.ndata['emb'] = g.ndata['emb'].to(self.dtype)
        g.edata['emb'] = g.edata['emb'].to(self.dtype)
        return g
      
    
class EdgeSoftmax(torch.autograd.Function):
    """Softmax of per-edge scores over each destination node's incoming edges.

    Custom autograd Function built from DGL message passing. ``forward`` takes
    the graph ``g`` and a score tensor with one entry (row) per edge and
    returns a tensor of the same shape. ``backward`` returns no gradient for
    ``g`` and the softmax gradient for ``score``.
    """

    @staticmethod
    def forward(ctx, g, score):
        """Compute exp(s_e - max_in) / sum_in exp(s - max_in) per destination node.

        Temporary features are stored on ``g`` under names produced by
        ``dgl.utils.get_edata_name`` / ``get_ndata_name`` (presumably chosen to
        avoid clashing with existing keys) and popped before returning.
        """
        score_name = dgl.utils.get_edata_name(g, 'score')
        tmp_name = dgl.utils.get_ndata_name(g, 'tmp')
        out_name = dgl.utils.get_edata_name(g, 'out')
        g.edata[score_name] = score
        # Per-node max over incoming scores, subtracted for numerical stability.
        g.update_all(fn.copy_e(score_name, 'm'), fn.max('m', tmp_name))
        g.apply_edges(fn.e_sub_v(score_name, tmp_name, out_name))
        g.edata[out_name] = torch.exp(g.edata[out_name])
        # Normalize by the per-node sum of the exponentials.
        g.update_all(fn.copy_e(out_name, 'm'), fn.sum('m', tmp_name))
        g.apply_edges(fn.e_div_v(out_name, tmp_name, out_name))
        g.edata.pop(score_name)
        g.ndata.pop(tmp_name)
        out = g.edata.pop(out_name)
        # Only a weak reference to the graph is kept for backward.
        # NOTE(review): if ``g`` is garbage-collected before backward runs,
        # ``ctx.backward_cache()`` returns None and backward fails.
        ctx.backward_cache = weakref.ref(g)
        ctx.save_for_backward(out)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        """Softmax backward: grad = out * grad_out - out * sum_in(out * grad_out).

        The sum runs over the incoming edges of each edge's destination node.
        Returns ``(None, grad_score)``, i.e. no gradient for the graph argument.
        """
        g = ctx.backward_cache()
        out, = ctx.saved_tensors
        # clear backward cache explicitly
        ctx.backward_cache = None
        out_name = dgl.utils.get_edata_name(g, 'out')
        accum_name = dgl.utils.get_ndata_name(g, 'accum')
        grad_score_name = dgl.utils.get_edata_name(g, 'grad_score')
        g.edata[out_name] = out
        g.edata[grad_score_name] = out * grad_out
        # accum[v] = sum of out * grad_out over v's incoming edges.
        g.update_all(fn.copy_e(grad_score_name, 'm'), fn.sum('m', accum_name))
        # out_name now holds out * accum[destination].
        g.apply_edges(fn.e_mul_v(out_name, accum_name, out_name))
        g.ndata.pop(accum_name)
        grad_score = g.edata.pop(grad_score_name) - g.edata.pop(out_name)
        return None, grad_score

In [0]:
class Linear(nn.Linear):
    """``nn.Linear`` whose weight is filled by a caller-supplied initializer.

    Args:
        in_features, out_features, bias: as for ``nn.Linear``.
        init: callable applied in place to ``self.weight``; defaults to Xavier
            normal with gain 1.414. The bias, when present, is zero-filled.
    """

    def __init__(self, in_features, out_features, bias=True, init=ct.curry(nn.init.xavier_normal_)(gain=1.414)):
        # Must exist before nn.Linear.__init__, which calls reset_parameters().
        self.init = init
        super().__init__(in_features, out_features, bias)

    def reset_parameters(self):
        weight_init = self.init
        weight_init(self.weight)
        bias_param = self.bias
        if bias_param is None:
            return
        nn.init.zeros_(bias_param)
    
    
class MultiLinear(nn.Module):
    """A bank of ``n_linears`` independent linear maps applied head-wise.

    ``forward`` views its input as (batch, n_linears, -1), so each of the
    ``n_linears`` chunks must be ``in_features`` wide. Chunk ``h`` is multiplied
    by its own (in_features, out_features) weight. The output has shape
    (batch, n_linears, out_features).

    Args:
        in_features: width of each head's input chunk.
        out_features: width of each head's output.
        n_linears: number of independent heads.
        bias: if True, add a learned per-head bias (initialized to zero).
        init: callable applied in place to the (n_linears, in_features,
            out_features) weight tensor.
    """

    def __init__(self, in_features, out_features, n_linears=1, bias=True,
                init=ct.curry(nn.init.xavier_normal_)(gain=nn.init.calculate_gain('relu'))):
        super(MultiLinear, self).__init__()
        self.out_features = out_features
        self.n_linears = n_linears
        self.in_features = in_features
        self.init = init
        weights = torch.zeros(n_linears, in_features, out_features, dtype=torch.float32)
        # Initialize exactly once; the tensor is then wrapped as the Parameter.
        self.init(weights)
        self.lin = nn.Parameter(weights)
        if bias:
            # One bias row per head, broadcast over the batch in forward().
            b = torch.zeros((n_linears, 1, self.out_features), dtype=torch.float32)
            self.bias = nn.Parameter(b)
        else:
            self.bias = None

    def extra_repr(self):
        return f'{self.in_features}, {self.out_features}, {self.n_linears}'

    def forward(self, x):
        """Apply head h's weight to chunk h of every row."""
        batch = x.shape[0]
        # (batch, n_linears, in) -> (n_linears, batch, in): one bmm slice per head.
        x = x.view(batch, self.n_linears, -1).permute(1, 0, 2)
        if self.bias is not None:
            bias = self.bias.expand(self.n_linears, batch, self.out_features)
            y = torch.baddbmm(bias, x, self.lin)
        else:
            # Use the Parameter itself, not ``.data``: ``.data`` is detached from
            # autograd, so the weights would never receive gradients.
            y = torch.bmm(x, self.lin)
        return y.permute(1, 0, 2).contiguous()
         
        
class GraphLambda(nn.Module):
    """Apply one function to a node feature and/or an edge feature of a graph.

    Args:
        fn: callable tensor -> tensor; an ``nn.Module`` passed here is
            registered as the submodule ``fn``.
        node_key: ndata key to transform; a falsy value skips nodes.
        edge_key: edata key to transform; a falsy value skips edges.
    """

    def __init__(self, fn, node_key='emb', edge_key='emb'):
        super().__init__()
        self.fn = fn
        self.edge_key = edge_key
        self.node_key = node_key

    def forward(self, g):
        for key, attr in ((self.node_key, 'ndata'), (self.edge_key, 'edata')):
            if not key:
                continue
            frame = getattr(g, attr)
            frame[key] = self.fn(frame[key])
        return g
    
def ReduceMean():
    """GraphLambda that averages node and edge 'emb' over dimension -2."""
    return GraphLambda(lambda x: x.mean(dim=-2))


def ReduceCat():
    """GraphLambda that flattens node and edge 'emb' to (rows, -1)."""
    return GraphLambda(lambda x: x.view(x.shape[0], -1))

    
class GatedResidual(nn.Module):
    """Blend a wrapped graph module's output with its input via learned gates.

    For nodes and edges alike:
        z   = sigmoid(W_prev(prev) + W_curr(curr))
        out = z * curr + (1 - z) * prev
    where ``prev`` is the 'emb' feature before ``module`` runs and ``curr`` the
    one after. Gate weights use Xavier-normal init with the sigmoid gain; only
    the "current" gates have a bias, and it is zero-initialized.

    Args:
        module: graph -> graph module to wrap. The gates are square, so its
            output 'emb' widths must equal its input widths.
        node_dim: width of node 'emb'.
        edge_dim: width of edge 'emb'.
    """

    def __init__(self, module, node_dim, edge_dim):
        super().__init__()
        self.module = module
        gate_init = ct.curry(nn.init.xavier_normal_)(gain=nn.init.calculate_gain('sigmoid'))
        self.prev_node_gate = Linear(node_dim, node_dim, bias=False, init=gate_init)
        self.curr_node_gate = Linear(node_dim, node_dim, bias=True, init=gate_init)
        self.prev_edge_gate = Linear(edge_dim, edge_dim, bias=False, init=gate_init)
        self.curr_edge_gate = Linear(edge_dim, edge_dim, bias=True, init=gate_init)
        for gate in (self.curr_node_gate, self.curr_edge_gate):
            nn.init.zeros_(gate.bias.data)

    def forward(self, g):
        prev_node, prev_edge = g.ndata['emb'], g.edata['emb']

        g = self.module(g)
        curr_node, curr_edge = g.ndata['emb'], g.edata['emb']

        node_gate = torch.sigmoid(self.prev_node_gate(prev_node) + self.curr_node_gate(curr_node))
        edge_gate = torch.sigmoid(self.prev_edge_gate(prev_edge) + self.curr_edge_gate(curr_edge))
        g.ndata['emb'] = node_gate * curr_node + (1 - node_gate) * prev_node
        g.edata['emb'] = edge_gate * curr_edge + (1 - edge_gate) * prev_edge
        return g
    
      
class TripletMultiLinear(nn.Module):
    """Per-edge multi-head linear map over (src, edge, dst) feature triplets.

    Each edge's 'emb' is replaced by ``MultiLinear`` applied to the last-dim
    concatenation [src 'emb', edge 'emb', dst 'emb'], flattened to
    (num_edges, -1). Node features are left untouched.
    """

    def __init__(self, in_node_dim, in_edge_dim, out_edge_dim, n_lins, bias=False):
        super().__init__()
        triplet_width = 2 * in_node_dim + in_edge_dim
        self.lin = MultiLinear(triplet_width, out_edge_dim, n_lins, bias)

    def triplet_linear(self, edges):
        parts = (edges.src['emb'], edges.data['emb'], edges.dst['emb'])
        return {'triplets': torch.cat(parts, dim=-1)}

    def forward(self, g):
        g.apply_edges(self.triplet_linear)
        projected = self.lin(g.edata.pop('triplets'))
        g.edata['emb'] = projected.view(projected.shape[0], -1)
        return g


class TripletCat(nn.Module):
    """Store the last-dim concatenation [src 'emb', edge 'emb', dst 'emb'] on each edge.

    Args:
        out: edata key that receives the concatenation. The default 'emb'
            overwrites the edge embedding.
    """

    def __init__(self, out='emb'):
        super().__init__()
        self.out = out

    def triplet_linear(self, edges):
        parts = (edges.src['emb'], edges.data['emb'], edges.dst['emb'])
        return {self.out: torch.cat(parts, dim=-1)}

    def forward(self, g):
        g.apply_edges(self.triplet_linear)
        return g
    

class MagicAttn(nn.Module):
    """Multi-head attention that aggregates source-node features into nodes.

    Per edge and head, a score is computed from ``g.edata[attn_key]`` by a
    bias-free ``MultiLinear`` and a LeakyReLU. The scores are normalized with
    ``EdgeSoftmax`` over each node's incoming edges and multiplied by the edge
    'emb' viewed as (E, n_heads, -1). Each node's new 'emb' is the sum over
    incoming edges of (source ``msg_key`` feature) * (weighted edge value).
    Finally node and edge 'emb' are flattened to (rows, -1).

    Args:
        node_dim: not used by this implementation.
        edge_dim: per-head width of ``g.edata[attn_key]``.
        n_heads: number of attention heads.
        attn_key: edata key used to score edges.
        msg_key: ndata key holding the features sent along edges.
        alpha: negative slope of the LeakyReLU (also used for the init gain).
    """

    def __init__(self, node_dim, edge_dim, n_heads, attn_key='emb', msg_key='emb', alpha=.2):
        super().__init__()
        score_init = ct.curry(nn.init.xavier_normal_)(gain=nn.init.calculate_gain('leaky_relu', alpha))
        self.attn = MultiLinear(edge_dim, 1, n_heads, bias=False, init=score_init)
        self.leaky_relu = nn.LeakyReLU(alpha)
        self.n_heads = n_heads
        self.softmax = EdgeSoftmax.apply
        self.attn_key = attn_key
        self.msg_key = msg_key

    def forward(self, g):
        logits = self.leaky_relu(self.attn(g.edata[self.attn_key]))
        weights = self.softmax(g, logits)
        edge_emb = g.edata['emb']
        g.edata['a'] = weights * edge_emb.view(edge_emb.shape[0], self.n_heads, -1)
        node_msg = g.ndata[self.msg_key]
        if node_msg.ndimension() == 2:
            # Split flat node features into heads so they broadcast against 'a'.
            g.ndata[self.msg_key] = node_msg.view(g.number_of_nodes(), self.n_heads, -1)
        g.update_all(fn.src_mul_edge(self.msg_key, 'a', 'm'), fn.sum('m', 'emb'))
        return GraphLambda(lambda x: x.view(x.shape[0], -1))(g)

      
def NodeLinear(in_features, out_features, bias=False):
    """Graph module that applies a ``Linear`` layer to node 'emb' only."""
    layer = Linear(in_features, out_features, bias)
    return GraphLambda(fn=layer, node_key='emb', edge_key=None)


def EdgeLinear(in_features, out_features, bias=False):
    """Graph module that applies a ``Linear`` layer to edge 'emb' only."""
    layer = Linear(in_features, out_features, bias)
    return GraphLambda(fn=layer, node_key=None, edge_key='emb')

In [0]:
class MathDict(dict):
    """A dict with element-wise arithmetic and comparison operators.

    Binary operators apply per key of ``self``:
      * with another dict: ``op(self[k], other[k])`` (KeyError if missing);
      * with a number or tensor: ``op(self[k], other)``.
    Results are new MathDicts. Unsupported operand types yield NotImplemented,
    so Python raises TypeError instead of silently producing None. Reflected
    ``+ - * /`` are supported, e.g. ``2 * d`` or ``sum(list_of_dicts)``.
    """

    def __op__(self, other, op):
        if isinstance(other, dict):
            return MathDict({k: op(v, other[k]) for k, v in self.items()})
        if isinstance(other, (numbers.Number, torch.Tensor)):
            return MathDict({k: op(v, other) for k, v in self.items()})
        return NotImplemented

    def __add__(self, other):
        return self.__op__(other, op=operator.add)

    def __radd__(self, other):
        return self.__op__(other, op=lambda v, o: o + v)

    def __mul__(self, other):
        return self.__op__(other, op=operator.mul)

    def __rmul__(self, other):
        return self.__op__(other, op=lambda v, o: o * v)

    def __sub__(self, other):
        return self.__op__(other, op=operator.sub)

    def __rsub__(self, other):
        return self.__op__(other, op=lambda v, o: o - v)

    def __mod__(self, other):
        return self.__op__(other, op=operator.mod)

    def __truediv__(self, other):
        return self.__op__(other, op=operator.truediv)

    def __rtruediv__(self, other):
        return self.__op__(other, op=lambda v, o: o / v)

    def __lt__(self, other):
        return self.__op__(other, op=operator.lt)

    def __le__(self, other):
        return self.__op__(other, op=operator.le)

    def __gt__(self, other):
        return self.__op__(other, op=operator.gt)

    def __ge__(self, other):
        return self.__op__(other, op=operator.ge)


def load_checkpoint(path, model=None, map_location=None):
    """Load the 'model' state dict stored at ``path`` into a network.

    Args:
        path: checkpoint file; ``torch.load`` must yield a dict with a 'model'
            entry holding a state dict.
        model: module to load into; defaults to the notebook-global ``net``.
        map_location: forwarded to ``torch.load`` (e.g. 'cpu' to load a GPU
            checkpoint on a CPU-only machine).

    Returns:
        The module the weights were loaded into.
    """
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(path, map_location=map_location)
    target = net if model is None else model
    target.load_state_dict(checkpoint['model'])
    return target

## Code for extracting outputs from the network

In [0]:
# Per-coupling-type mean and std of the target, indexed by coupling_type id;
# get_outputs uses them to (de)normalize values.
ctype_means = torch.tensor([
    94.97615286418801, 47.47988448446838, -0.2706244378832182,
    -10.28660516398165, 3.124753613418501,3.6884695895354453,
    4.771023359735822, 0.9907298624943462], dtype=torch.float32)
ctype_stds = torch.tensor([
    18.277236880290143, 10.922171556272271, 4.523610750196489,
    3.9796071637303525, 3.6734741723096023, 3.0709074866562185,
    3.7049844341285763, 1.3153933535337567], dtype=torch.float32)


# Device copies for indexing with on-device coupling-type tensors. Fall back to
# CPU so this cell also runs on machines without CUDA.
_stats_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ctype_means_c = ctype_means.to(_stats_device)
ctype_stds_c = ctype_stds.to(_stats_device)
# Unit stds, used by get_outputs in place of ctype_stds_c when scaled=False.
ones = torch.ones_like(ctype_stds_c)


def get_outputs(g, id_col='train_id', norm_truth=False, denorm_out=False, scaled=True):
    """Collect predictions and metadata for the labeled edges of a graph.

    An edge counts as labeled when its 'coupling_type' is not -1.

    Args:
        g: graph after the network forward pass (edge 'emb' holds outputs).
        id_col: edata key returned as the row id (e.g. 'train_id', 'test_id').
        norm_truth: standardize ``truth`` with the per-type mean and std.
        denorm_out: map ``out`` back with the per-type std and mean.
        scaled: if False, unit stds are used (mean shift only).

    Returns:
        (out, ctype, truth, src, dst, tid), restricted to labeled edges;
        src and dst are moved to ``out``'s device.
    """
    mask = g.edata['coupling_type'] != -1
    edata = g.edata
    out = edata['emb'][mask].squeeze()
    truth = edata['coupling'][mask]
    ctype = edata['coupling_type'][mask]
    stds = ctype_stds_c if scaled else ones
    means = ctype_means_c

    if norm_truth:
        truth = (truth - means[ctype]) / stds[ctype]

    if denorm_out:
        out = out * stds[ctype].unsqueeze(-1) + means[ctype].unsqueeze(-1)

    all_src, all_dst = g.all_edges('uv')
    src = all_src[mask].to(out.device)
    dst = all_dst[mask].to(out.device)
    tid = edata[id_col][mask]
    return out, ctype, truth, src, dst, tid
        
        
def bidir_combine(out, ctype, truth, src, dst, tid):
    """Average the two directed copies of every labeled coupling edge.

    Each coupling appears twice among the labeled edges: once with src > dst
    ("forward") and once with src < dst ("backward"). The pair's prediction
    is the mean of both copies, and metadata is taken from the forward copy.

    Copies are paired by their unordered (src, dst) endpoints, not by
    position, so the result does not depend on the orientation each coupling
    was stored in or on edge order within a batch.

    Args:
        out: (E, ...) predictions for labeled edges.
        ctype, truth, src, dst, tid: (E,) per-edge tensors aligned with ``out``.

    Returns:
        (outs, ctype, truth, src, dst, tid), one row per coupling, in the order
        of the forward copies.

    Raises:
        ValueError: if forward and backward copies do not pair up one-to-one.
    """
    fwd = src > dst
    bwd = src < dst
    # Key each edge by its unordered endpoint pair (smaller id is the major digit).
    n_ids = int(torch.max(torch.cat([src, dst]))) + 1 if src.numel() else 1
    fwd_key = dst[fwd] * n_ids + src[fwd]
    bwd_key = src[bwd] * n_ids + dst[bwd]
    if fwd_key.shape != bwd_key.shape:
        raise ValueError(
            f'bidir_combine: {fwd_key.numel()} forward vs {bwd_key.numel()} backward edges')
    fwd_order = torch.argsort(fwd_key)
    bwd_order = torch.argsort(bwd_key)
    if not torch.equal(fwd_key[fwd_order], bwd_key[bwd_order]):
        raise ValueError('bidir_combine: forward and backward edges do not pair up')
    out_bwd = out[bwd]
    # Reorder backward predictions so row i belongs to the same pair as forward row i.
    out_bwd_aligned = torch.empty_like(out_bwd)
    out_bwd_aligned[fwd_order] = out_bwd[bwd_order]
    outs = (out[fwd] + out_bwd_aligned) / 2.
    return outs, ctype[fwd], truth[fwd], src[fwd], dst[fwd], tid[fwd]

## Model architecture

In [0]:
# NOTE(review): combine_fn is not referenced by the visible cells; predict()
# calls bidir_combine directly.
combine_fn = bidir_combine

# Per-head feature width (also the width of the initial Embed output); the
# attention blocks work with heads * emb features.
emb = 48
# Number of attention heads used by AttnBlock.
heads = 24
# Bias flag passed to TripletMultiLinear inside AttnBlock.
bias = False

def AttnBlock(in_emb, out_emb):
    """Build one attention block: project, split into heads, attend, re-project.

    Node and edge 'emb' are linearly mapped from ``in_emb`` to ``out_emb``
    features and viewed as ``heads`` heads of width ``out_emb // heads``. They
    are then updated by triplet-scored attention (``MagicAttn``) and a per-head
    triplet projection (``TripletMultiLinear``), and layer-normalized.

    Args:
        in_emb: input width of node and edge 'emb'.
        out_emb: output width; must be divisible by the global ``heads``.

    Raises:
        ValueError: if ``out_emb`` is not a multiple of ``heads``.
    """
    if out_emb % heads:
        raise ValueError(f'out_emb={out_emb} must be divisible by heads={heads}')
    # Size the sub-layers from the head width so they agree with out_emb
    # (for out_emb == emb * heads this equals the global emb).
    head_dim = out_emb // heads
    return nn.Sequential(
        EdgeLinear(in_emb, out_emb),
        NodeLinear(in_emb, out_emb),
        GraphLambda(lambda x: x.view(x.shape[0], heads, -1)),
        TripletCat(out='triplet'),
        MagicAttn(head_dim, 3 * head_dim, heads, attn_key='triplet'),
        TripletMultiLinear(head_dim, head_dim, head_dim, heads, bias=bias),
        GraphLambda(torch.nn.LayerNorm(heads * head_dim))
    )

# Embedding, one attention block, seven gated-residual attention blocks (each
# followed by PReLU), then a two-layer edge head with 8 outputs per edge (one
# column per coupling type). Layer order matches the checkpoints' state-dict keys.
_width = emb * heads
_n_gated_blocks = 7
_layers = [Embed(emb, emb), AttnBlock(emb, _width), GraphLambda(nn.PReLU())]
for _block_idx in range(_n_gated_blocks):
    _layers.append(GatedResidual(AttnBlock(_width, _width), _width, _width))
    _layers.append(GraphLambda(nn.PReLU()))
_layers += [
    EdgeLinear(_width, 512, bias=True), GraphLambda(nn.PReLU(), node_key=None),
    EdgeLinear(512, 8, bias=True),
]
net = nn.Sequential(*_layers)


# Run on the GPU when one is available; otherwise stay on CPU instead of
# crashing in net.cuda().
use_cuda = torch.cuda.is_available()
if use_cuda:
    net = net.cuda()

## Making the predictions

In [0]:
def predict(test_data, net, scaled, batch_size=512, num_workers=2):
    """Run ``net`` over ``test_data`` and return one prediction row per coupling.

    Molecules are collated with ``make_graph_batch`` and each batch is moved to
    the device of ``net``'s parameters. Labeled edge outputs are de-normalized
    by ``get_outputs``, and the two directions of each coupling are averaged by
    ``bidir_combine``.

    Args:
        test_data: sequence of molecule records accepted by ``make_dglg``.
        net: graph network; evaluated in eval mode without gradients.
        scaled: forwarded to ``get_outputs`` (False -> unit stds).
        batch_size: molecules per batch.
        num_workers: DataLoader worker processes.

    Returns:
        (tid, out, ctype) CPU tensors: test ids, de-normalized predictions
        (one column per network output) and coupling-type ids.

    Raises:
        ValueError: if ``test_data`` yields no batches.
    """
    loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size,
                                         shuffle=False, num_workers=num_workers,
                                         collate_fn=make_graph_batch(precision='single'))
    device = next(net.parameters()).device
    net.eval()
    per_batch = []
    with torch.no_grad():
        for batch in tqdm(loader):
            moved = batch.to(device)
            # NOTE(review): depending on the DGL version, ``to`` moves the graph in
            # place (returning None or the graph) or returns a moved copy; use the
            # returned graph when there is one.
            if moved is not None:
                batch = moved
            g = net(batch)
            outs = get_outputs(g, id_col='test_id', denorm_out=True, scaled=scaled)
            outs = bidir_combine(*outs)
            per_batch.append([t.detach().cpu() for t in outs])
    if not per_batch:
        raise ValueError('predict: test_data produced no batches')
    out, ctype, _truth, _src, _dst, tid = [torch.cat(parts) for parts in zip(*per_batch)]
    return tid, out, ctype

      
def save_df(tid, pred, ctype, ctype_name, output):
    """Write the predictions for one coupling type to a CSV file.

    Only rows whose ``ctype`` equals ``CTYPES.index(ctype_name)`` are kept, and
    their value is taken from that column of ``pred``. The file has the columns
    index (row position in the unfiltered predictions), tid, val and ctype.

    Args:
        tid: (N,) test ids (CPU tensor).
        pred: (N, n_outputs) predictions (CPU tensor).
        ctype: (N,) coupling-type ids (CPU tensor).
        ctype_name: coupling type to keep, e.g. '1JHC'.
        output: destination CSV path.
    """
    ctype_idx = CTYPES.index(ctype_name)
    df = pd.DataFrame({'tid': tid.numpy(),
                       'val': pred[:, ctype_idx].numpy(),
                       'ctype': ctype.numpy()})
    # index=False: reset_index() already stores the row position as 'index';
    # writing the pandas index too would only add an unnamed duplicate counter.
    df[df.ctype == ctype_idx].reset_index().to_csv(output, index=False)
    
    
def mean_folds(folds):
    """Average the 'val' column across several per-fold prediction CSVs.

    The 'val' columns are concatenated side by side, aligned on the default
    index ``pd.read_csv`` assigns, i.e. by row position. All folds must
    therefore list the same rows in the same order. The first fold's frame is
    returned with 'val' replaced by the row-wise mean.

    Args:
        folds: non-empty list of CSV paths.
    """
    frames = list(map(pd.read_csv, folds))
    averaged = pd.concat([frame['val'] for frame in frames], axis=1).mean(axis=1)
    first = frames[0]
    first['val'] = averaged
    return first

In [0]:
%%time
# For each coupling type, predict with every fold's checkpoint, save the fold's
# CSV under ``output`` and record its path; then average the folds per type and
# write the combined predictions file.
results = {ctype_name: [] for ctype_name in models}

for ct_name, ms in models.items():
    for i, (model, scaled) in enumerate(ms):
        fold_csv = f'{output}/{ct_name}_{i}.csv'
        load_checkpoint(model)
        tid, pred, ctype = predict(test_data, net, scaled)
        save_df(tid, pred, ctype, ct_name, fold_csv)
        # Record the path the file was actually written to: mean_folds reads it
        # back, and a bare file name only resolves when ``output`` is the cwd.
        results[ct_name].append(fold_csv)


predictions = pd.concat([mean_folds(results[x]) for x in CTYPES]).sort_values('tid')

df = pd.DataFrame({'id': predictions['tid'].values,
                   'scalar_coupling_constant': predictions['val'].values})
print(len(df))

df.to_csv(f'{output}/predictions.csv', index=False)

  6%|▌         | 5/90 [00:26<07:48,  5.52s/it]