In [0]:
import json
import pdb
import math
import dgl
import torch
import weakref
import numbers
import operator
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.utils.data
import cytoolz.curried as ct
import humanfriendly as hf
import itertools

import dgl.function as fn
import dgl.nn.pytorch as dglnn

import torch.nn.functional as F
import torch.utils.tensorboard as tb

from pathlib import Path
from datetime import datetime
from tqdm import tqdm_notebook as tqdm
import subprocess


# Global run configuration.
SEED = 43          # seeds the train/eval split RNG
SCALED = True      # when False, per-type std scaling of targets is disabled (stds set to 1)
TRUE_MLKN = True   # when True, overwrite node 'mulliken' features with the Q9MLK values

# Coupling-type ids included in the loss; id order is
# ['1JHC', '1JHN', '2JHC', '2JHH', '2JHN', '3JHC', '3JHH', '3JHN']
TYPES_TO_PROCESS = list(range(8))

In [0]:
# Load the run configuration. A context manager closes the file handle
# deterministically; json.load(open(...)) leaked it.
with open('SETTINGS.json') as settings_file:
    settings = json.load(settings_file)

RUNS = settings['TRAIN']['TBOARD_LOG']     # root directory for TensorBoard logs
CHKPS = settings['TRAIN']['MODEL_CHKPS']   # root directory for checkpoints
prefix = settings['TRAIN']['MODEL_NAME']   # run name used in log/checkpoint paths
dataset = settings['TRAIN']['INPUT']       # path passed to torch.load for the graph data
mlk = settings['TRAIN']['Q9MLK']           # path passed to torch.load for replacement Mulliken charges
START = settings['TRAIN']['START_FROM']    # checkpoint to resume from (falsy = fresh run)

In [0]:
# RUNS = "gdrive/My Drive/CHAMPS/runs"
# DATA  = 'gdrive/My Drive/CHAMPS/champs/final/input_data'
# CHKPS = 'gdrive/My Drive/CHAMPS//checkpoints'
#prefix = settings['TRAIN']['MODEL_NAME']

In [0]:
# NOTE: torch.load unpickles its input; only load trusted files.
gdata = torch.load(dataset)
if TRUE_MLKN:
    # Replace each molecule's node 'mulliken' feature with the values from the
    # separate Mulliken file. The two lists must be aligned molecule-for-molecule.
    # Check the lengths up front, because a longer Mulliken list would
    # otherwise be silently misaligned.
    gdata_true = torch.load(mlk)
    assert len(gdata_true) == len(gdata), (len(gdata_true), len(gdata))
    for i in range(len(gdata)):
        assert len(gdata[i]['ndata']['mulliken']) == len(gdata_true[i]['mulliken']), i
        gdata[i]['ndata']['mulliken'] = gdata_true[i]['mulliken']

In [5]:
# Deterministic ~90/10 split: a molecule is held out for evaluation when its
# uniform draw exceeds 0.9. Original order is preserved in both lists.
rng = np.random.RandomState(seed=SEED)
is_eval = rng.rand(len(gdata)) > .9
eval_data = list(itertools.compress(gdata, is_eval))
train_data = list(itertools.compress(gdata, ~is_eval))
assert len(train_data) + len(eval_data) == len(gdata)
len(eval_data)

8653

# Optimizer (LAMB-style, decoupled weight decay) and cyclic learning-rate scheduler

In [0]:
class LambW(torch.optim.Optimizer):
    """LAMB-style optimizer with decoupled weight decay.

    Each step forms the Adam direction ``m / (sqrt(v) + eps)`` (no bias
    correction) and scales it per parameter tensor by the trust ratio
    ``clamp(||p||, 0, 10) / ||adam_step||``. When either norm is zero, or when
    ``adam=True``, the ratio is 1. Weight decay is applied directly to the
    weights (``p -= lr * weight_decay * p``), not folded into the gradient.

    Args:
        params: iterable of parameters or param-group dicts.
        lr: learning rate (>= 0).
        betas: running-average coefficients for the gradient and its square.
        eps: term added to the denominator for numerical stability.
        weight_decay: decoupled weight-decay coefficient.
        adam: if True, force the trust ratio to 1.

    Raises:
        ValueError: for a negative lr/eps or betas outside [0, 1).
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
                 weight_decay=0, adam=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay)
        self.adam = adam
        super(LambW, self).__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None``.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p.data)
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                state['step'] += 1
                # Keyword forms replace the deprecated add_(Number, Tensor) and
                # addcmul_(Number, Tensor, Tensor) overloads; the math is unchanged.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                step_size = group['lr']

                # The weight norm is clamped to [0, 10] before it enters the trust ratio.
                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
                # Adam direction, without bias correction.
                adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])

                adam_norm = adam_step.pow(2).sum().sqrt()
                if weight_norm == 0 or adam_norm == 0:
                    trust_ratio = 1
                else:
                    trust_ratio = weight_norm / adam_norm
                # Diagnostics stored in the per-parameter state.
                state['weight_norm'] = weight_norm
                state['adam_norm'] = adam_norm
                state['trust_ratio'] = trust_ratio
                if self.adam:
                    trust_ratio = 1

                if group['weight_decay'] != 0:
                    # Decoupled decay: shrink the weights directly.
                    p.data.add_(p.data, alpha=-group['weight_decay'] * group['lr'])
                p.data.add_(adam_step, alpha=-step_size * float(trust_ratio))

        return loss
    
    
class CyclicLR(torch.optim.lr_scheduler._LRScheduler):
    """Cyclical learning-rate (and optional momentum) scheduler.

    Taken from the PyTorch implementation and patched to work with Adam-style
    optimizers. Instead of a ``'momentum'`` entry, the cycled momentum is
    written to ``group['betas'][1]`` of each param group.
    NOTE(review): Adam's momentum analogue is usually ``betas[0]``; confirm that
    cycling index 1 is intended. It is inactive when ``cycle_momentum=False``.

    ``last_epoch`` counts batches, not epochs, so ``step()`` is meant to be
    called once per batch. The lr moves linearly from ``base_lr`` to ``max_lr``
    over ``step_size_up`` steps and back over ``step_size_down`` steps
    (``step_size_down`` defaults to ``step_size_up``). The amplitude is scaled
    by ``scale_fn``, applied per cycle or per iteration.

    Args:
        optimizer: wrapped optimizer. Its groups need a ``'betas'`` entry when
            ``cycle_momentum`` is True.
        base_lr, max_lr: a scalar, or one value per param group.
        step_size_up, step_size_down: half-cycle lengths, in steps.
        mode: 'triangular', 'triangular2' or 'exp_range' (ignored when
            ``scale_fn`` is given).
        gamma: decay base used by 'exp_range'.
        scale_fn, scale_mode: custom amplitude scaling; ``scale_mode`` says
            whether it receives the cycle number ('cycle') or the step count.
        cycle_momentum, base_momentum, max_momentum: momentum cycling, which
            moves opposite to the lr.
        last_epoch: index of the last step taken. -1 starts fresh and resets
            each group's lr to ``base_lr``.

    Raises:
        TypeError: if ``optimizer`` is not a ``torch.optim.Optimizer``.
        ValueError: for an unknown ``mode`` without ``scale_fn``, or for
            per-group lists of the wrong length.
    """

    def __init__(self,
                 optimizer,
                 base_lr,
                 max_lr,
                 step_size_up=2000,
                 step_size_down=None,
                 mode='triangular',
                 gamma=1.,
                 scale_fn=None,
                 scale_mode='cycle',
                 cycle_momentum=True,
                 base_momentum=0.8,
                 max_momentum=0.9,
                 last_epoch=-1):

        if not isinstance(optimizer, torch.optim.Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        # On a fresh start, write base_lr into each group. The parent
        # constructor then picks these values up as self.base_lrs.
        base_lrs = self._format_param('base_lr', optimizer, base_lr)
        if last_epoch == -1:
            for lr, group in zip(base_lrs, optimizer.param_groups):
                group['lr'] = lr

        self.max_lrs = self._format_param('max_lr', optimizer, max_lr)

        # Cycle length in steps, and the fraction of it spent increasing the lr.
        step_size_up = float(step_size_up)
        step_size_down = float(step_size_down) if step_size_down is not None else step_size_up
        self.total_size = step_size_up + step_size_down
        self.step_ratio = step_size_up / self.total_size

        if mode not in ['triangular', 'triangular2', 'exp_range'] \
                and scale_fn is None:
            raise ValueError('mode is invalid and scale_fn is None')

        self.mode = mode
        self.gamma = gamma

        if scale_fn is None:
            if self.mode == 'triangular':
                self.scale_fn = self._triangular_scale_fn
                self.scale_mode = 'cycle'
            elif self.mode == 'triangular2':
                self.scale_fn = self._triangular2_scale_fn
                self.scale_mode = 'cycle'
            elif self.mode == 'exp_range':
                self.scale_fn = self._exp_range_scale_fn
                self.scale_mode = 'iterations'
        else:
            self.scale_fn = scale_fn
            self.scale_mode = scale_mode

        self.cycle_momentum = cycle_momentum
        if cycle_momentum:

            # The momentum being cycled is the second entry of Adam's betas.
            base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
            if last_epoch == -1:
                for momentum, group in zip(base_momentums, optimizer.param_groups):
                    group['betas'] = (group['betas'][0], momentum)
            self.base_momentums = list(map(lambda group: group['betas'][1], optimizer.param_groups))
            self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum)

        super(CyclicLR, self).__init__(optimizer, last_epoch)

    def _format_param(self, name, optimizer, param):
        """Return correctly formatted lr/momentum for each param group."""
        if isinstance(param, (list, tuple)):
            if len(param) != len(optimizer.param_groups):
                raise ValueError("expected {} values for {}, got {}".format(
                    len(optimizer.param_groups), name, len(param)))
            return param
        else:
            return [param] * len(optimizer.param_groups)

    def _triangular_scale_fn(self, x):
        # Constant amplitude.
        return 1.

    def _triangular2_scale_fn(self, x):
        # Amplitude halves every cycle (x is the 1-based cycle number).
        return 1 / (2. ** (x - 1))

    def _exp_range_scale_fn(self, x):
        # Amplitude decays as gamma**step (x is the step count).
        return self.gamma**(x)

    def get_lr(self):
        """Calculates the learning rate at batch index. This function treats
        `self.last_epoch` as the last batch index.

        If `self.cycle_momentum` is ``True``, this function has a side effect of
        updating the optimizer's momentum.
        """
        # cycle: 1-based index of the current cycle; x: position within it, in [0, 1).
        cycle = math.floor(1 + self.last_epoch / self.total_size)
        x = 1. + self.last_epoch / self.total_size - cycle
        # scale_factor rises 0 -> 1 during the up phase and falls 1 -> 0 during the down phase.
        if x <= self.step_ratio:
            scale_factor = x / self.step_ratio
        else:
            scale_factor = (x - 1) / (self.step_ratio - 1)

        lrs = []
        for base_lr, max_lr in zip(self.base_lrs, self.max_lrs):
            base_height = (max_lr - base_lr) * scale_factor
            if self.scale_mode == 'cycle':
                lr = base_lr + base_height * self.scale_fn(cycle)
            else:
                lr = base_lr + base_height * self.scale_fn(self.last_epoch)
            lrs.append(lr)

        if self.cycle_momentum:
            # Momentum moves opposite to the lr: max_momentum at scale_factor 0.
            momentums = []
            for base_momentum, max_momentum in zip(self.base_momentums, self.max_momentums):
                base_height = (max_momentum - base_momentum) * scale_factor
                if self.scale_mode == 'cycle':
                    momentum = max_momentum - base_height * self.scale_fn(cycle)
                else:
                    momentum = max_momentum - base_height * self.scale_fn(self.last_epoch)
                momentums.append(momentum)
            for param_group, momentum in zip(self.optimizer.param_groups, momentums):
                param_group['betas'] = (param_group['betas'][0], momentum)

        return lrs

# Code for Stochastic Weight Averaging (SWA)

In [0]:
class MathDict(dict):
    """A dict with element-wise arithmetic and comparison operators.

    Binary operators are applied value by value:
      * with another dict: ``result[k] = op(self[k], other[k])`` for every key
        of ``self`` (``other`` must contain all of them);
      * with a number or a ``torch.Tensor``: ``result[k] = op(self[k], other)``.
    Results are new ``MathDict`` instances and ``self`` is not modified.
    Any other operand type returns ``NotImplemented``, so Python raises a
    ``TypeError`` instead of silently producing ``None``.
    """

    def __op__(self, other, op):
        if isinstance(other, dict):
            return MathDict({k: op(v, other[k]) for k, v in self.items()})
        if isinstance(other, (numbers.Number, torch.Tensor)):
            return MathDict({k: op(v, other) for k, v in self.items()})
        return NotImplemented

    def __add__(self, other):
        return self.__op__(other, op=operator.add)

    def __mul__(self, other):
        return self.__op__(other, op=operator.mul)

    # Addition and multiplication commute, so `3 * md` and `{...} + md` work too.
    __radd__ = __add__
    __rmul__ = __mul__

    def __sub__(self, other):
        return self.__op__(other, op=operator.sub)

    def __mod__(self, other):
        return self.__op__(other, op=operator.mod)

    def __truediv__(self, other):
        return self.__op__(other, op=operator.truediv)

    def __lt__(self, other):
        return self.__op__(other, op=operator.lt)

    def __le__(self, other):
        return self.__op__(other, op=operator.le)

    def __gt__(self, other):
        return self.__op__(other, op=operator.gt)

    def __ge__(self, other):
        return self.__op__(other, op=operator.ge)
    
class SWA(object):
    """Equal-weight running average of model state dicts."""

    def __init__(self):
        super(SWA, self).__init__()
        self.wswa = None    # current average (a MathDict); None before the first model
        self.nmodels = 0    # number of models folded into the average

    def add_model(self, model):
        """Fold one state dict into the running mean.

        Args:
            model: mapping from name to tensor/number, with the same keys as
                every previously added model.
        """
        count = self.nmodels
        if count:
            self.wswa = (self.wswa * count + model) / (count + 1)
        else:
            self.wswa = MathDict(model)
        self.nmodels = count + 1

@ct.curry
def last_n_swa(n, end_epoch, checkpoints_path, device=None):
    """Average the model weights of the last ``n`` epoch checkpoints.

    Loads ``{epoch}.torch`` for epochs ``end_epoch - n + 1 .. end_epoch`` from
    ``checkpoints_path``. Returns the checkpoint dict of ``end_epoch`` (the last
    one loaded) with ``'model'`` replaced by the averaged state dict and
    ``'swa_nmodels'`` set to the number of checkpoints averaged.

    Args:
        n: number of checkpoints to average; must be >= 1.
        end_epoch: last epoch included in the average.
        checkpoints_path: directory holding the per-epoch checkpoint files.
        device: forwarded to ``torch.load`` as ``map_location``.

    Raises:
        ValueError: if ``n < 1``. Without this check there is nothing to load
            and ``chk`` would be unbound.
    """
    if n < 1:
        raise ValueError(f'last_n_swa needs n >= 1, got {n}')
    base = Path(checkpoints_path)
    swa = SWA()
    for epoch in range(end_epoch - n + 1, end_epoch + 1):
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        chk = torch.load(base / f'{epoch}.torch', map_location=device)
        swa.add_model(chk['model'])
    chk['model'] = swa.wswa
    chk['swa_nmodels'] = swa.nmodels
    return chk

# Preparing the data

In [0]:
def make_dglg(mol, precision='double'):
    """Build a bidirected DGL graph with self-loops from one molecule record.

    Edge ids follow insertion order:
      [0, E)        the molecule's edges src -> dst,
      [E, 2E)       the same edges reversed, dst -> src,
      [2E, 2E + N)  one self-loop per node.
    Every ``edata`` entry below is concatenated in exactly this order, so the
    ``add_edges`` calls and the ``torch.cat`` lists must stay in sync.

    Self-loops get type 6, zero geometry and coupling, and -1 for
    coupling_type / train_id / test_id. ``coupling_type == -1`` is what marks
    an edge as unlabeled. Reversed edges copy every feature except
    ``dihedral``, which is negated.

    Args:
        mol: dict with 'nodes' (node count), 'src'/'dst' (edge endpoint
            tensors), 'ndata' (per-node features) and 'edata' (per-edge
            features: type, distance, angle, dihedral, coupling_type,
            coupling, train_id, test_id).
        precision: accepted for interface compatibility; not used here.

    Returns:
        A ``dgl.DGLGraph``.
    """
    g = dgl.DGLGraph()
    nodes = mol['nodes']
    g.add_nodes(nodes)
    src, dst = mol['src'], mol['dst']
    g.ndata.update(mol['ndata'])
    # NOTE(review): self-assignment, apparently a no-op.
    g.ndata['type'] = g.ndata['type']
    g.add_edges(src, dst)
    g.add_edges(dst, src)
    g.add_edges(range(nodes), range(nodes))
    
    # Filler values for the self-loop edges (float zeros, int64 ones).
    device = src.device
    zeros = torch.zeros(nodes, dtype=torch.float32, device=device)
    ones = torch.ones(nodes, dtype=torch.int64, device=device) 
    edata = {}
    # repeat(2) duplicates the per-edge values for the reversed copies.
    edata['type'] = torch.cat([mol['edata']['type'].repeat(2), ones * 6])
    edata['distance'] = torch.cat([mol['edata']['distance'].repeat(2), zeros])
    edata['angle'] = torch.cat([mol['edata']['angle'].repeat(2), zeros])
    # The dihedral changes sign for the reversed edge.
    edata['dihedral'] = torch.cat([mol['edata']['dihedral'], -mol['edata']['dihedral'], zeros])
    edata['coupling_type'] = torch.cat([mol['edata']['coupling_type'].repeat(2).to(torch.int64), -ones])
    edata['coupling'] = torch.cat([mol['edata']['coupling'].repeat(2), zeros])
    edata['train_id'] = torch.cat([mol['edata']['train_id'].repeat(2), -ones])
    edata['test_id'] = torch.cat([mol['edata']['test_id'].repeat(2), -ones])

    g.edata.update(edata)
    return g

  
@ct.curry
def make_graph_batch(mols, precision='double'):
    """Collate a list of molecule records into one batched DGL graph.

    Curried, so ``make_graph_batch(precision=...)`` can be passed as a
    DataLoader ``collate_fn``. DGL's zero initializer is set for both the node
    and the edge frames of the batch.
    """
    batched = dgl.batch([make_dglg(mol, precision) for mol in mols])
    for set_initializer in (batched.set_n_initializer, batched.set_e_initializer):
        set_initializer(dgl.init.zero_initializer)
    return batched

# Layers for the network

In [0]:
class Embed(nn.Module):
    """Initial node/edge features: a learned type embedding plus raw scalars.

    Writes ``g.ndata['emb']`` of width ``node_dim``: ``node_dim - 4`` embedding
    dims followed by el_aff, el_neg, 1st_ion and mulliken. Writes
    ``g.edata['emb']`` of width ``edge_dim``: ``edge_dim - 3`` embedding dims
    followed by distance, angle and dihedral. Both type vocabularies have
    10 entries.
    """

    def __init__(self, node_dim, edge_dim):
        super(Embed, self).__init__()
        self.emb_node_types = nn.Embedding(10, node_dim - 4)
        self.emb_edge_types = nn.Embedding(10, edge_dim - 3)

    def forward(self, g):
        node_cols = [self.emb_node_types(g.ndata['type'])]
        node_cols.extend(g.ndata[key].unsqueeze(1) for key in ('el_aff', 'el_neg', '1st_ion', 'mulliken'))
        g.ndata['emb'] = torch.cat(node_cols, dim=-1)

        edge_cols = [self.emb_edge_types(g.edata['type'])]
        edge_cols.extend(g.edata[key].unsqueeze(1) for key in ('distance', 'angle', 'dihedral'))
        g.edata['emb'] = torch.cat(edge_cols, dim=-1)
        return g
      

class GraphCast(nn.Module):
    """Cast the node and edge ``'emb'`` features of a graph to ``dtype``."""

    def __init__(self, dtype):
        # super() must name this class; GraphAct is not defined anywhere.
        super(GraphCast, self).__init__()
        self.dtype = dtype

    def forward(self, g):
        g.ndata['emb'] = g.ndata['emb'].to(self.dtype)
        g.edata['emb'] = g.edata['emb'].to(self.dtype)
        return g
      
    
class EdgeSoftmax(torch.autograd.Function):
    """Softmax of edge scores over the incoming edges of each destination node.

    For each edge e = (u -> v), forward computes
    ``exp(s_e - max_v) / sum_{e' -> v} exp(s_e' - max_v)`` with DGL message
    passing, where ``max_v`` is the largest score entering v (subtracted for
    numerical stability). Backward computes
    ``grad_s = out * grad_out - out * sum_{e' -> v}(out_e' * grad_out_e')``.

    Temporary fields are added to ``g`` under names obtained from
    ``dgl.utils.get_edata_name`` / ``get_ndata_name`` (presumably chosen to
    avoid clashing with existing fields) and are popped again afterwards.

    Usage: ``EdgeSoftmax.apply(g, score)``.
    """

    @staticmethod
    def forward(ctx, g, score):
        score_name = dgl.utils.get_edata_name(g, 'score')
        tmp_name = dgl.utils.get_ndata_name(g, 'tmp')
        out_name = dgl.utils.get_edata_name(g, 'out')
        g.edata[score_name] = score
        # Per-destination max of the incoming scores.
        g.update_all(fn.copy_e(score_name, 'm'), fn.max('m', tmp_name))
        # Shift each score by its destination's max, then exponentiate.
        g.apply_edges(fn.e_sub_v(score_name, tmp_name, out_name))
        g.edata[out_name] = torch.exp(g.edata[out_name])
        # Per-destination sum of the exponentials, used as the normalizer.
        g.update_all(fn.copy_e(out_name, 'm'), fn.sum('m', tmp_name))
        g.apply_edges(fn.e_div_v(out_name, tmp_name, out_name))
        g.edata.pop(score_name)
        g.ndata.pop(tmp_name)
        out = g.edata.pop(out_name)
        # Only a weak reference to the graph is kept. NOTE(review): if `g` is
        # garbage-collected before backward runs, ctx.backward_cache() returns
        # None and backward fails.
        ctx.backward_cache = weakref.ref(g)
        ctx.save_for_backward(out)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        g = ctx.backward_cache()
        out, = ctx.saved_tensors
        # clear backward cache explicitly
        ctx.backward_cache = None
        out_name = dgl.utils.get_edata_name(g, 'out')
        accum_name = dgl.utils.get_ndata_name(g, 'accum')
        grad_score_name = dgl.utils.get_edata_name(g, 'grad_score')
        g.edata[out_name] = out
        g.edata[grad_score_name] = out * grad_out
        # accum_v = sum over v's incoming edges of out * grad_out
        g.update_all(fn.copy_e(grad_score_name, 'm'), fn.sum('m', accum_name))
        # out_e * accum_v for every edge
        g.apply_edges(fn.e_mul_v(out_name, accum_name, out_name))
        g.ndata.pop(accum_name)
        # grad_s = out * grad_out - out * accum; the graph itself gets no gradient.
        grad_score = g.edata.pop(grad_score_name) - g.edata.pop(out_name)
        return None, grad_score

In [0]:
class Linear(nn.Linear):
    """``nn.Linear`` with a pluggable weight initializer and a zero bias.

    Args:
        in_features, out_features, bias: as for ``nn.Linear``.
        init: callable applied to ``self.weight`` at (re)initialization.
            Defaults to Xavier-normal with gain 1.414.
    """

    def __init__(self, in_features, out_features, bias=True, init=ct.curry(nn.init.xavier_normal_)(gain=1.414)):
        # Must be set before nn.Linear.__init__, which calls reset_parameters().
        self.init = init
        super(Linear, self).__init__(in_features, out_features, bias)

    def reset_parameters(self):
        """Initialize the weight with ``self.init`` and the bias (if any) with zeros."""
        self.init(self.weight)
        if self.bias is None:
            return
        nn.init.zeros_(self.bias)
    
class MultiLinear(nn.Module):
    """``n_linears`` independent linear maps applied to consecutive input slices.

    The input is viewed as ``(batch, n_linears, -1)``. Slice ``j`` is multiplied
    by ``lin[j]`` (shape ``(in_features, out_features)``) and, if enabled,
    offset by ``bias[j]``. The output has shape ``(batch, n_linears, out_features)``.

    Args:
        in_features: width of each input slice.
        out_features: width of each output slice.
        n_linears: number of independent maps.
        bias: whether to add a learned per-map bias (initialized to zero).
        init: callable applied in place to the ``(n_linears, in_features,
            out_features)`` weight tensor. The default is Xavier-normal with
            the ReLU gain.
    """

    def __init__(self, in_features, out_features, n_linears=1, bias=True,
                 init=lambda w: nn.init.xavier_normal_(w, gain=nn.init.calculate_gain('relu'))):
        super(MultiLinear, self).__init__()
        self.out_features = out_features
        self.n_linears = n_linears
        self.in_features = in_features
        self.init = init
        weights = torch.zeros(n_linears, in_features, out_features, dtype=torch.float32)
        self.lin = nn.Parameter(weights)
        # Initialize once. The weights used to be initialized twice, and the
        # first result was discarded.
        self.init(self.lin.data)
        if bias:
            b = torch.zeros((n_linears, 1, self.out_features), dtype=torch.float32)
            self.bias = nn.Parameter(b)
        else:
            self.bias = None

    def extra_repr(self):
        return f'{self.in_features}, {self.out_features}, {self.n_linears}'

    def forward(self, x):
        batch = x.shape[0]
        # (batch, n_linears * in) -> (n_linears, batch, in) for batched matmul.
        x = x.view(batch, self.n_linears, -1).permute(1, 0, 2)
        if self.bias is not None:
            y = torch.baddbmm(self.bias.expand(self.n_linears, batch, self.out_features),
                              x,
                              self.lin)
        else:
            # Use the Parameter itself, not `.data`, so gradients reach the weights.
            y = torch.bmm(x, self.lin)
        return y.permute(1, 0, 2).contiguous()
            
class GraphLambda(nn.Module):
    """Apply ``fn`` to a node and/or edge feature of a graph, storing the result back.

    Args:
        fn: callable (function or module) mapping a feature tensor to a new one.
        node_key: ``g.ndata`` field to transform; falsy to skip nodes.
        edge_key: ``g.edata`` field to transform; falsy to skip edges.
    """

    def __init__(self, fn, node_key='emb', edge_key='emb'):
        super(GraphLambda, self).__init__()
        self.fn = fn
        self.edge_key = edge_key
        self.node_key = node_key

    def forward(self, g):
        if self.node_key:
            node_frame = g.ndata
            node_frame[self.node_key] = self.fn(node_frame[self.node_key])
        if self.edge_key:
            edge_frame = g.edata
            edge_frame[self.edge_key] = self.fn(edge_frame[self.edge_key])
        return g


def ReduceMean():
    """GraphLambda averaging node and edge 'emb' over dimension -2."""
    return GraphLambda(lambda x: x.mean(dim=-2))


def ReduceCat():
    """GraphLambda flattening node and edge 'emb' to (rows, -1)."""
    return GraphLambda(lambda x: x.view(x.shape[0], -1))

class Residual(nn.Module):
    """Skip connection around ``module`` for node and edge ``'emb'``.

    Sets ``emb = module(g).emb + emb_before`` for both nodes and edges. The sum
    is formed out of place, so neither the wrapped module's output tensor
    (which autograd may still need) nor the saved input tensors are mutated.
    """

    def __init__(self, module):
        super(Residual, self).__init__()
        self.module = module

    def forward(self, g):
        nemb = g.ndata['emb']
        eemb = g.edata['emb']
        g = self.module(g)
        g.ndata['emb'] = g.ndata['emb'] + nemb
        g.edata['emb'] = g.edata['emb'] + eemb
        return g
    
class GatedResidual(nn.Module):
    """Gated skip connection around ``module``, applied to nodes and edges separately.

        z   = sigmoid(gate_prev(prev) + gate_curr(curr))
        out = z * curr + (1 - z) * prev

    ``prev`` is the ``'emb'`` feature before ``module`` and ``curr`` the one
    after it. The gates are square Linears with Xavier-normal init and the
    sigmoid gain. Only the ``curr`` gates have a bias, initialized to zero.
    """

    def __init__(self, module, node_dim, edge_dim):
        super(GatedResidual, self).__init__()
        self.module = module
        gate_init = ct.curry(nn.init.xavier_normal_)(gain=nn.init.calculate_gain('sigmoid'))
        self.prev_node_gate = Linear(node_dim, node_dim, bias=False, init=gate_init)
        self.curr_node_gate = Linear(node_dim, node_dim, bias=True, init=gate_init)
        self.prev_edge_gate = Linear(edge_dim, edge_dim, bias=False, init=gate_init)
        self.curr_edge_gate = Linear(edge_dim, edge_dim, bias=True, init=gate_init)
        for gate in (self.curr_node_gate, self.curr_edge_gate):
            nn.init.zeros_(gate.bias.data)

    def forward(self, g):
        prev_node, prev_edge = g.ndata['emb'], g.edata['emb']
        g = self.module(g)
        curr_node, curr_edge = g.ndata['emb'], g.edata['emb']

        node_z = torch.sigmoid(self.prev_node_gate(prev_node) + self.curr_node_gate(curr_node))
        edge_z = torch.sigmoid(self.prev_edge_gate(prev_edge) + self.curr_edge_gate(curr_edge))
        g.ndata['emb'] = node_z * curr_node + (1 - node_z) * prev_node
        g.edata['emb'] = edge_z * curr_edge + (1 - edge_z) * prev_edge
        return g
    
class TripletLinear(nn.Module):
    """Replace each edge's ``'emb'`` with a linear map of [src emb, edge emb, dst emb]."""

    def __init__(self, in_node_dim, in_edge_dim, out_edge_dim, bias=False):
        super(TripletLinear, self).__init__()
        triplet_dim = 2 * in_node_dim + in_edge_dim
        self.lin = Linear(triplet_dim, out_edge_dim, bias)

    def triplet_linear(self, edges):
        """Edge UDF: concatenate source, edge and destination embeddings."""
        parts = (edges.src['emb'], edges.data['emb'], edges.dst['emb'])
        return {'triplets': torch.cat(parts, dim=-1)}

    def forward(self, g):
        g.apply_edges(self.triplet_linear)
        triplets = g.edata.pop('triplets')
        g.edata['emb'] = self.lin(triplets)
        return g

      
class TripletMultiLinear(nn.Module):
    """Multi-map linear projection of edge triplets [src emb, edge emb, dst emb].

    Concatenates the three embeddings along the last axis and applies
    ``n_lins`` independent linear maps (``MultiLinear``) to consecutive
    ``2 * in_node_dim + in_edge_dim``-wide slices. The result is flattened back
    into ``g.edata['emb']`` as ``(num_edges, -1)``. Node features are unchanged.

    NOTE(review): with flat 2-D inputs of shape (rows, n_lins * d), the slices
    do not line up with per-head [src, edge, dst] triplets; they would only if
    the inputs were (rows, n_lins, d). Confirm this is intended.
    """

    def __init__(self, in_node_dim, in_edge_dim, out_edge_dim, n_lins, bias=False):
        super(TripletMultiLinear, self).__init__()
        self.lin = MultiLinear(2 * in_node_dim + in_edge_dim, out_edge_dim, n_lins, bias)

    def triplet_linear(self, edges):
        """Edge UDF: concatenate source, edge and destination embeddings."""
        return {'triplets': torch.cat([edges.src['emb'], edges.data['emb'], edges.dst['emb']], dim=-1)}

    def forward(self, g):
        g.apply_edges(self.triplet_linear)
        g.edata['emb'] = self.lin(g.edata.pop('triplets'))
        flatten_edges = GraphLambda(lambda x: x.view(x.shape[0], -1), node_key=None)
        return flatten_edges(g)


class TripletCat(nn.Module):
    """Store [src emb, edge emb, dst emb] (concatenated on the last axis) as edge field ``out``.

    No projection is applied. With a non-default ``out``, the edge ``'emb'``
    field itself is not written.
    """

    def __init__(self, out='emb'):
        super(TripletCat, self).__init__()
        self.out = out

    def triplet_linear(self, edges):
        """Edge UDF: concatenate source, edge and destination embeddings."""
        pieces = [edges.src['emb'], edges.data['emb'], edges.dst['emb']]
        return {self.out: torch.cat(pieces, dim=-1)}

    def forward(self, g):
        g.apply_edges(self.triplet_linear)
        return g
    

class MagicAttn(nn.Module):
    """Multi-head graph attention whose scores come from an edge feature.

    For each edge, one score per head is computed from ``g.edata[attn_key]``
    (MultiLinear, then LeakyReLU) and normalized with ``EdgeSoftmax`` over the
    edges entering each destination node. The normalized scores multiply the
    per-head edge ``'emb'``, and that product weights the source nodes'
    ``g.ndata[msg_key]``. The weighted messages are summed into the destination
    node's ``'emb'``. Finally, node and edge ``'emb'`` are flattened to
    ``(rows, -1)``.

    Args:
        node_dim: unused.
        edge_dim: per-head width of ``g.edata[attn_key]``.
        n_heads: number of attention heads.
        attn_key: edge field the attention scores are computed from.
        msg_key: node field used as the message. The output always goes to
            node ``'emb'``, whatever this key is.
        alpha: LeakyReLU negative slope (also used for the init gain).
    """

    def __init__(self, node_dim, edge_dim, n_heads, attn_key='emb', msg_key='emb', alpha=.2):
        super(MagicAttn, self).__init__()
        # One scalar score per head: n_heads independent (edge_dim -> 1) maps.
        self.attn = MultiLinear(
            edge_dim, 1, n_heads, bias=False,
            init=ct.curry(nn.init.xavier_normal_)(gain=nn.init.calculate_gain('leaky_relu', alpha)))
        self.leaky_relu = nn.LeakyReLU(alpha)
        self.n_heads = n_heads
        self.softmax = EdgeSoftmax.apply
        self.attn_key = attn_key
        self.msg_key = msg_key
        
    def forward(self, g):
        # Raw scores per edge and head; presumably shaped (E, n_heads, 1).
        alpha_prime = self.leaky_relu(self.attn(g.edata[self.attn_key]))
        # Normalized scores times the per-head edge embedding (always edge 'emb').
        g.edata['a'] = self.softmax(g, alpha_prime) * g.edata['emb'].view(g.edata['emb'].shape[0], self.n_heads, -1)
        attn_emb = g.ndata[self.msg_key]
        if attn_emb.ndimension() == 2:
            # Split a flat (N, n_heads * d) node feature into heads.
            g.ndata[self.msg_key] = attn_emb.view(g.number_of_nodes(), self.n_heads, -1)
        # m = src[msg_key] * a on each edge, summed at the destination into 'emb'.
        # (fn.src_mul_edge is an older DGL builtin; presumably u_mul_e in newer releases.)
        g.update_all(fn.src_mul_edge(self.msg_key, 'a', 'm'), fn.sum('m', 'emb'))
        # Flatten the heads back to (rows, -1) for nodes and edges.
        return GraphLambda(lambda x: x.view(x.shape[0], -1))(g)


class HeadExpand(nn.Module):
    """Tile node and edge ``'emb'`` ``n_heads`` times along the feature axis.

    The features before expansion are saved as ``'emb_orig'``. A ``(rows, d)``
    feature becomes ``(rows, n_heads * d)``, with the ``d`` values repeated
    once per head.
    """

    def __init__(self, n_heads):
        super(HeadExpand, self).__init__()
        self.n_heads = n_heads

    def _tile(self, x):
        """(rows, d) -> (rows, n_heads * d) by repeating along a new head axis."""
        rows = x.shape[0]
        return x.unsqueeze(1).expand((-1, self.n_heads, -1)).reshape(rows, -1)

    def forward(self, g):
        g.edata['emb_orig'] = g.edata['emb']
        g.ndata['emb_orig'] = g.ndata['emb']
        g.edata['emb'] = self._tile(g.edata['emb'])
        g.ndata['emb'] = self._tile(g.ndata['emb'])
        return g
   
  
class OrigEmb(nn.Module):
    """Restore edge and node ``'emb'`` from their ``'emb_orig'`` copies."""

    def __init__(self):
        super(OrigEmb, self).__init__()

    def forward(self, g):
        for frame in (g.edata, g.ndata):
            frame['emb'] = frame['emb_orig']
        return g

      
def NodeLinear(in_features, out_features, bias=False):
    """Linear layer applied to node ``'emb'`` only; edges are left untouched."""
    projection = Linear(in_features, out_features, bias)
    return GraphLambda(projection, edge_key=None)


def EdgeLinear(in_features, out_features, bias=False):
    """Linear layer applied to edge ``'emb'`` only; nodes are left untouched."""
    projection = Linear(in_features, out_features, bias)
    return GraphLambda(projection, node_key=None)
  
def load_checkpoint(path):
    """Restore model/optimizer state and training counters from a checkpoint.

    Expects the dict written by the training loop: 'model', 'optim',
    'prefix', 'iteration', 'epoch' and optionally 'batch_size'. Loads the first
    two into the globals ``net`` and ``optim``, and sets the globals
    ``prefix``, ``i`` and ``start_epoch``.

    Raises:
        AssertionError: if the checkpoint records a batch size that differs
            from the current global ``batch_size``.
    """
    global start_epoch, i, prefix, batch_size
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(path)
    net.load_state_dict(checkpoint['model'])
    optim.load_state_dict(checkpoint['optim'])
    prefix = checkpoint['prefix']
    i = checkpoint['iteration']
    start_epoch = checkpoint['epoch']
    # Test for the *key* 'batch_size'. The old `batch_size in checkpoint`
    # looked up the integer value as a key, so the check never ran.
    if 'batch_size' in checkpoint:
        assert batch_size == checkpoint['batch_size'], (batch_size, checkpoint['batch_size'])

In [0]:
# Per-coupling-type target statistics, indexed by coupling_type id in the order
# ['1JHC', '1JHN', '2JHC', '2JHH', '2JHN', '3JHC', '3JHH', '3JHN'].
# get_outputs standardizes targets as (y - mean) / std (norm_truth=True) and
# maps outputs back as out * std + mean (denorm_out=True).
# NOTE(review): presumably computed on the training couplings; confirm the source.
ctype_means = torch.tensor([
    94.97615286418801, 47.47988448446838, -0.2706244378832182,
    -10.28660516398165, 3.124753613418501,3.6884695895354453,
    4.771023359735822, 0.9907298624943462], dtype=torch.float32)
ctype_stds = torch.tensor([
    18.277236880290143, 10.922171556272271, 4.523610750196489,
    3.9796071637303525, 3.6734741723096023, 3.0709074866562185,
    3.7049844341285763, 1.3153933535337567], dtype=torch.float32)

# Without scaling, targets are only mean-centered.
if not SCALED:
    ctype_stds = torch.tensor([1., 1., 1., 1., 1., 1., 1., 1.], dtype=torch.float32)

# GPU copies used inside get_outputs. These .cuda() calls need a CUDA device
# at this point, independent of the later `use_cuda` flag.
ctype_means_c = ctype_means.cuda()
ctype_stds_c = ctype_stds.cuda()
ctype_stds_c_inv = 1/ctype_stds_c


def get_outputs(g, id_col='train_id', norm_truth=False, denorm_out=False):
    """Gather predictions and targets for the labeled (coupling) edges of ``g``.

    An edge is labeled when ``g.edata['coupling_type'] != -1``.

    Args:
        g: graph after the network forward pass; ``g.edata['emb']`` holds the
            per-edge outputs.
        id_col: edge field returned as the id column (e.g. 'train_id' or 'test_id').
        norm_truth: standardize targets per coupling type using the globals
            ``ctype_means_c`` / ``ctype_stds_c_inv``.
        denorm_out: map outputs back to coupling units using the globals
            ``ctype_stds_c`` / ``ctype_means_c``.

    Returns:
        Tuple ``(out, ctype, truth, src, dst, tid)`` restricted to labeled
        edges. ``out`` keeps its leading row axis even when exactly one edge is
        labeled. Singleton feature axes are dropped, as before.
    """
    labeled = g.edata['coupling_type'] != -1
    out = g.edata['emb'][labeled]
    # Drop singleton feature axes but never the row axis. A bare squeeze()
    # collapsed (1, k) to (k,), which breaks per-row indexing downstream.
    out = out.reshape(out.shape[0], *(s for s in out.shape[1:] if s != 1))
    truth = g.edata['coupling'][labeled]
    ctype = g.edata['coupling_type'][labeled]

    if norm_truth:
        truth = (truth - ctype_means_c[ctype]) * ctype_stds_c_inv[ctype]

    if denorm_out:
        out = out * ctype_stds_c[ctype].unsqueeze(-1) + ctype_means_c[ctype].unsqueeze(-1)

    src, dst = g.all_edges('uv')
    src, dst = src[labeled].to(out.device), dst[labeled].to(out.device)
    tid = g.edata[id_col][labeled]

    return out, ctype, truth, src, dst, tid
  
  
def run_eval(net, tb_writer, epoch, iteration, eval_loader, label='Eval'):
    """Evaluate ``net`` on ``eval_loader`` and log loss and metrics to ``tb_writer``.

    The loss (``loss_fn``) is computed on standardized targets. The metrics
    (``scalar_metrics``) use outputs mapped back to coupling units. Everything
    is logged at global step ``iteration``. The module's train/eval mode is
    restored afterwards, even if an error occurs.

    Relies on the notebook globals ``use_cuda``, ``get_outputs``, ``loss_fn``,
    ``scalar_metrics`` and ``tqdm``.
    """
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            outs_n = []
            outs_d = []
            # Take the progress total from the loader that is actually iterated,
            # not from the globals eval_data / batch_size.
            for batch in tqdm(eval_loader, total=len(eval_loader),
                              desc=f'{label} epoch: {epoch}', leave=False):
                if use_cuda:
                    batch.to(next(net.parameters()).device)
                g = net(batch)
                outs_n.append([t.detach().cpu() for t in get_outputs(g, norm_truth=True)])
                outs_d.append([t.detach().cpu() for t in get_outputs(g, denorm_out=True)])

            # Concatenate each output column across batches.
            out, ctype, truth, src, dst, tid = [torch.cat(ts) for ts in zip(*outs_n)]
            loss = loss_fn(out, truth, ctype)
            tb_writer.add_scalar('loss', loss.item(), iteration)

            out, ctype, truth, src, dst, tid = [torch.cat(ts) for ts in zip(*outs_d)]
            scalars = scalar_metrics(out, ctype, truth, src, dst, tid)
            for mlabel, scalar in scalars.items():
                tb_writer.add_scalar(mlabel, scalar, iteration)
            tb_writer.file_writer.flush()
    finally:
        net.train(was_training)

# Loss functions

In [0]:
def logMAE(preds, truth):
    """Natural log of the mean absolute error; the 1e-9 offset avoids log(0)."""
    mae = F.l1_loss(preds, truth)
    return torch.log(mae + 1e-9)
  
    
@ct.curry
def logMAE_per_ctype(out, ctypes, truth, ctype_id):
    """logMAE restricted to rows whose coupling type equals ``ctype_id``.

    Returns a Python float, or ``0`` when no row has that type.
    """
    mask = ctypes == ctype_id
    if torch.any(mask):
        return logMAE(out[mask], truth[mask]).item()
    return 0

@ct.curry
def logMAE_per_ctype_head(out, ctypes, truth, ctype_id):
    """logMAE of output column ``ctype_id`` on the rows of that coupling type.

    Returns a Python float, or ``0`` when no row has that type.
    """
    mask = ctypes == ctype_id
    if torch.any(mask):
        return logMAE(out[mask, ctype_id], truth[mask]).item()
    return 0


def mae_head_per_ctype(out, ctypes, truth, ctype_id):
    """Mean absolute error of output column ``ctype_id`` on the rows of that type.

    Returns a 0-d tensor. When no row has the type, this is 0.0 on
    ``ctypes.device``.
    """
    rows = ctypes == ctype_id
    if torch.any(rows):
        return F.l1_loss(out[rows, ctype_id], truth[rows])
    return torch.tensor(0.0, device=ctypes.device)

@ct.curry
def mae_head_per_ctypes(out, truth, ctypes, ctype_ids):
    """Unweighted mean of ``mae_head_per_ctype`` over ``ctype_ids``.

    Types absent from the batch contribute 0 to the mean. Note that the
    argument order ``(out, truth, ctypes)`` differs from ``mae_head_per_ctype``.
    """
    per_type = torch.stack([mae_head_per_ctype(out, ctypes, truth, ctype_id) for ctype_id in ctype_ids])
    return per_type.mean()


@ct.curry
def bidir_combine(out, ctype, truth, src, dst, tid):
    """Average the predictions of the two directed edges of each coupling.

    Edges with ``src > dst`` count as forward and edges with ``src < dst`` as
    backward. The returned types, targets, endpoints and ids are those of the
    forward edges.

    NOTE(review): forward and backward rows are paired by position. The k-th
    forward edge is assumed to be the reverse of the k-th backward edge. That
    holds when, within each molecule, all labeled original edges point the
    same way (all ``src > dst`` or all ``src < dst``). Confirm against the
    preprocessing that produced ``mol['src']`` / ``mol['dst']``.
    """
    is_fwd = src > dst
    is_bwd = src < dst
    averaged = (out[is_fwd] + out[is_bwd]) / 2.
    fwd_fields = tuple(t[is_fwd] for t in (ctype, truth, src, dst, tid))
    return (averaged,) + fwd_fields
  
  
def scalar_metrics(out, ctype, truth, src, dst, tid):
    """Per-coupling-type log-MAE metrics after combining both edge directions.

    Returns a dict with 'trueLogMAE/<type>' for each of the 8 coupling types,
    plus 'trueLogMAE/mean', the mean over all types (absent types count as 0).
    When the global ``net`` is in eval mode, the mean is also printed with the
    global ``epoch``.
    """
    with torch.no_grad():
        cout, cctype, ctruth, _, _, _ = combine_fn(out, ctype, truth, src, dst, tid)

        ctype_names = ['1JHC', '1JHN', '2JHC', '2JHH', '2JHN', '3JHC', '3JHH', '3JHN']
        maes = [logMAE_per_ctype_head(cout, cctype, ctruth, ctype_id)
                for ctype_id in range(len(ctype_names))]
        metrics = {f'trueLogMAE/{name}': value for name, value in zip(ctype_names, maes)}
        metrics['trueLogMAE/mean'] = torch.Tensor(maes).mean()
        if not net.training:
            print('Epoch', epoch, '-', 'trueLogMAE/mean:', metrics['trueLogMAE/mean'].item())
        return metrics

# Setting up the model

In [13]:
# Metric/loss plumbing.
combine_fn = bidir_combine                                 # merges both directions of each coupling edge
loss_fn = mae_head_per_ctypes(ctype_ids=TYPES_TO_PROCESS)  # curried; called as loss_fn(out, truth, ctype)

# Model width: per-head embedding size, number of heads, TripletMultiLinear bias.
emb, heads, bias = 48, 24, False

def AttnBlock(in_emb, out_emb):
    """One attention block: project, split into heads, attend over edge triplets, mix, normalize.

    Uses the globals ``emb``, ``heads`` and ``bias``. The MagicAttn and
    TripletMultiLinear sizes are fixed to a per-head width of ``emb``, so
    ``out_emb`` must equal ``emb * heads``.
    """
    layers = [
        EdgeLinear(in_emb, out_emb),
        NodeLinear(in_emb, out_emb),
        # (rows, heads * d) -> (rows, heads, d) for both nodes and edges.
        GraphLambda(lambda x: x.view(x.shape[0], heads, -1)),
        TripletCat(out='triplet'),
        MagicAttn(emb, 3 * emb, heads, attn_key='triplet'),
        TripletMultiLinear(emb, emb, emb, heads, bias=bias),
        GraphLambda(torch.nn.LayerNorm(heads * emb)),
    ]
    return nn.Sequential(*layers)

# Pipeline: embedding -> first attention block -> 7 gated residual attention
# blocks (each followed by PReLU) -> two-layer edge head with 8 outputs per edge.
wide = emb * heads
net = nn.Sequential(
    Embed(emb, emb),
    AttnBlock(emb, wide), GraphLambda(nn.PReLU()),
    *[layer
      for _ in range(7)
      for layer in (GatedResidual(AttnBlock(wide, wide), wide, wide), GraphLambda(nn.PReLU()))],
    EdgeLinear(wide, 512, bias=True), GraphLambda(nn.PReLU(), node_key=None),
    EdgeLinear(512, 8, bias=True)
)


n_params = sum(p.numel() for p in net.parameters())
print('Parameter count: ', hf.format_number(n_params))

use_cuda = True
if use_cuda:
    net = net.cuda()


def noreg(x):
    """True for top-level GraphLambda(PReLU) wrappers, whose parameters get no weight decay."""
    return type(x) is GraphLambda and type(x.fn) is torch.nn.modules.activation.PReLU


# Split the top-level modules' parameters into a weight-decayed group and a
# decay-free group (the PReLU slopes).
regs = [p for module in net if not noreg(module) for p in module.parameters()]
noregs = [p for module in net if noreg(module) for p in module.parameters()]
params = [
    {'params': regs, 'lr': 0.01, 'weight_decay': 5e-2},
    {'params': noregs, 'lr': 0.01, 'weight_decay': 0},
]

optim = LambW(params)

Parameter count:  57,833,611


# Setting up the logging

In [14]:
# Training state; load_checkpoint() overwrites it when resuming.
i = -1            # global iteration counter (incremented before each batch)
start_epoch = 0
batch_size = 80

if START:
    load_checkpoint(START)
# The run name always comes from SETTINGS.json, even when resuming.
prefix = settings['TRAIN']['MODEL_NAME']

if batch_size is None:
    batch_size = 80
precision = 'single'

collate = make_graph_batch(precision=precision)
train_loader = torch.utils.data.DataLoader(
    train_data, drop_last=True, batch_size=batch_size, shuffle=True, num_workers=4, collate_fn=collate)
eval_loader = torch.utils.data.DataLoader(
    eval_data, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=collate)


# TensorBoard log directories for this run.
train_log, eval_log, swa_log = (f'{RUNS}/{prefix}/{split}' for split in ('train', 'eval', 'swa'))
for log_dir in (train_log, eval_log, swa_log):
    Path(log_dir).mkdir(parents=True, exist_ok=True)

# Per-epoch training checkpoints and SWA checkpoints.
checkpoints = Path(CHKPS) / prefix / 'train'
swa_checkpoints = Path(CHKPS) / prefix / 'swa'
for chk_dir in (checkpoints, swa_checkpoints):
    chk_dir.mkdir(parents=True, exist_ok=True)

print(train_log, start_epoch)

train_writer = tb.SummaryWriter(log_dir=train_log, filename_suffix='.train')
eval_writer = tb.SummaryWriter(log_dir=eval_log, filename_suffix='.eval')
# NOTE(review): the SWA writer reuses the '.train' suffix. Its directory is
# different, so files do not collide, but '.swa' may have been intended.
swa_writer = tb.SummaryWriter(log_dir=swa_log, filename_suffix='.train')

gdrive/My Drive/CHAMPS/runs/scaled-truemlkn-43/train 0


# Training loop

In [15]:
# Training schedule (epoch numbers as seen by the loop):
#   * epoch < 31: CyclicLR steps every batch (triangular, 15 epochs up and
#     15 down, lr between 1e-3 and 1e-2);
#   * epochs 31-34: the scheduler is no longer stepped, so the lr keeps its last value;
#   * epoch 35: group-0 weight decay is lowered to 1e-2 and both groups' lr is
#     fixed at 1e-3 (re-applied on every batch of that epoch; the values persist);
#   * epoch >= start_swa: the last n_swa epoch checkpoints are averaged (SWA),
#     evaluated and saved, and checkpoints that are no longer needed are deleted.
epochs = 100          # epochs to run in this session
metric_freq =100      # log training metrics every `metric_freq` iterations
start_swa = 60        # first epoch that produces an SWA checkpoint
n_swa = 25            # number of most recent epoch checkpoints averaged by SWA

iper_epoch = len(train_loader)   # batches per epoch

# last_epoch=i lets a resumed run continue the schedule from the saved iteration.
# NOTE(review): when resuming (i != -1), the base scheduler expects 'initial_lr'
# in each param group; presumably restored via optim.load_state_dict.
scheduler = CyclicLR(optim, 
                     base_lr=0.001, max_lr=0.01, cycle_momentum=False,
                     step_size_up=15*iper_epoch,
                     step_size_down=15*iper_epoch,
                     last_epoch=i)

for epoch in tqdm(range(start_epoch + 1, start_epoch + epochs + 1)):
    train_writer.add_scalar('Optim/Learning rate', optim.param_groups[0]['lr'], i)
    train_writer.add_scalar('Optim/Weight Decay', optim.param_groups[0]['weight_decay'], i)
    net.train()
    for batch in tqdm(train_loader, total=iper_epoch, desc=f'Epoch: {epoch}', leave=False):
        i += 1
        if use_cuda:
            batch.to(next(net.parameters()).device)
        optim.zero_grad()
        g = net(batch)
        # The loss is computed against per-type standardized targets.
        out, ctype, truth, src, dst, tid = get_outputs(g, norm_truth=True)
        loss = loss_fn(out, truth, ctype)
        loss.backward()
        optim.step()
        
        if epoch < 31:
            scheduler.step()
            
        if epoch == 35:
            optim.param_groups[0]['weight_decay'] = 1e-2
            optim.param_groups[0]['lr'] = 1e-3
            optim.param_groups[1]['lr'] = 1e-3    
    
        # Log every metric_freq iterations and on the last batch of each epoch
        # (assuming i counted up from -1 and iper_epoch is constant).
        if (i % metric_freq == 0) or ((i+1) % iper_epoch == 0):            
            train_writer.add_scalar('loss', loss.item(), i)
            # Metrics use outputs mapped back to coupling units.
            outs = get_outputs(g, denorm_out=True)
            outs = [out.detach().cpu() for out in outs]
            out, ctype, truth, src, dst, tid = outs
            scalars = scalar_metrics(out, ctype, truth, src, dst, tid)
            for label, scalar in scalars.items():
                train_writer.add_scalar(label, scalar, i)
            train_writer.file_writer.flush()
    run_eval(net, eval_writer, epoch, i, eval_loader)
    # Per-epoch checkpoint; SWA below reads these files back.
    torch.save({
        'model': net.state_dict(),
        'optim': optim.state_dict(),
        'prefix': prefix,
        'iteration': i,
        'epoch': epoch,
        'batch_size': batch_size,
    }, checkpoints / f'{epoch}.torch')
    
    # start_epoch tracks the last finished epoch, so re-running this cell continues from there.
    if epoch < start_swa:
        start_epoch = epoch
        continue    
    # SWA: average the last n_swa epoch checkpoints (loaded on CPU), evaluate
    # the averaged weights, save them, then restore this epoch's own weights.
    swa_chk = last_n_swa(n_swa, epoch, checkpoints, device='cpu')
    assert swa_chk['epoch'] == epoch
    net.load_state_dict(swa_chk['model'])
    if use_cuda: net.to('cuda')
    run_eval(net, swa_writer, epoch, i, eval_loader, label='SWA')
    torch.save(swa_chk, swa_checkpoints / f'{epoch}.last_{n_swa}.torch')
    net.load_state_dict(torch.load(checkpoints / f'{epoch}.torch')['model'])

    # Keep only the two most recent SWA checkpoints.
    to_del = swa_checkpoints / f'{epoch-2}.last_{n_swa}.torch'
    if to_del.exists():
        to_del.unlink()
    
    # Delete the train checkpoint that has just left the SWA window.
    # NOTE(review): checkpoints of epochs below start_swa - n_swa are never deleted.
    to_del = checkpoints / f'{epoch-n_swa}.torch'
    if to_del.exists():
        to_del.unlink()
    
    start_epoch = epoch
    
train_writer.close()
eval_writer.close()
swa_writer.close()

HBox(children=(IntProgress(value=0), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Epoch: 1', max=954, style=ProgressStyle(description_width='in…

Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe


RuntimeError: ignored