# In [None]:
"""
Utility functions and model definitions for building a GNN model with PyTorch.
"""

import numpy as np
from tqdm import tqdm


import rdkit
from rdkit import Chem

import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_scatter import scatter

import torch
import torch.nn as nn
import torch.nn.functional as F

# Helper functions: a torch version of np.unpackbits.
# See https://gist.github.com/vadimkantorov/30ea6d278bc492abf6ad328c6965613a

def tensor_dim_slice(tensor, dim, dim_slice):
    """Index ``tensor`` with ``dim_slice`` along ``dim``, taking every earlier dimension whole.

    Args:
        tensor: Object supporting ``dim()`` and tuple indexing (e.g. a ``torch.Tensor``).
        dim: Dimension to index; a negative value counts back from the last dimension.
        dim_slice: Index applied at position ``dim`` (this file passes ``slice`` objects).

    Returns:
        The result of ``tensor[:, ..., :, dim_slice]`` with ``dim`` leading full slices.
    """
    axis = dim + tensor.dim() if dim < 0 else dim
    # One full slice per dimension that precedes the target, then the requested index.
    index = [slice(None)] * axis
    index.append(dim_slice)
    return tensor[tuple(index)]

# @torch.jit.script
def packshape(shape, dim: int = -1, mask: int = 0b00000001, dtype=torch.uint8, pack=True):
    """Compute the shape that results from packing or unpacking bit fields along ``dim``.

    Args:
        shape: Input shape as a tuple (or ``torch.Size``).
        dim: Dimension that is packed/unpacked; negative values count from the end.
        mask: Bit mask of a single field. ``0b1``, ``0b11``, ``0b1111`` and
            ``0b11111111`` select fields of 1, 2, 4 and 8 bits respectively.
        dtype: Storage dtype of the packed tensor; ``torch.uint8``, ``torch.int16``,
            ``torch.int32`` and ``torch.int64`` are recognised.
        pack: If true, return the packed shape (``dim`` shrinks, rounded up);
            otherwise return the unpacked shape (``dim`` grows).

    Returns:
        ``(new_shape, nibbles, nibble)`` where ``nibble`` is the width of one field
        in bits and ``nibbles`` is the number of fields per storage element.

    Raises:
        AssertionError: If the mask or dtype is not one of the recognised values,
            or the field width does not evenly divide the storage width.
    """
    dim = dim if dim >= 0 else dim + len(shape)
    # Width of one storage element in bits; 0 marks an unsupported dtype.
    # (torch.iinfo(dtype).bits would be shorter but does not JIT compile.)
    bits = (
        8 if dtype is torch.uint8 else
        16 if dtype is torch.int16 else
        32 if dtype is torch.int32 else
        64 if dtype is torch.int64 else
        0
    )
    # Width of one packed field in bits; 0 marks an unsupported mask.
    nibble = (
        1 if mask == 0b00000001 else
        2 if mask == 0b00000011 else
        4 if mask == 0b00001111 else
        8 if mask == 0b11111111 else
        0
    )
    # `0 < nibble` keeps an unsupported mask from reaching the modulo below
    # (which would otherwise fail with a ZeroDivisionError).
    assert 0 < nibble <= bits and bits % nibble == 0, (
        f"unsupported mask/dtype combination: mask={mask!r}, dtype={dtype!r}"
    )
    nibbles = bits // nibble
    if pack:
        # Integer ceil-division: a partially filled storage element still takes a slot.
        # (Avoids the `math` module, which this file never imports.)
        packed_len = -(-shape[dim] // nibbles)
        shape = shape[:dim] + (packed_len,) + shape[1 + dim:]
    else:
        shape = shape[:dim] + (shape[dim] * nibbles,) + shape[1 + dim:]
    return shape, nibbles, nibble

# @torch.jit.script
def F_unpackbits(tensor, dim: int = -1, mask: int = 0b00000001, shape=None, out=None, dtype=torch.uint8):
    """Expand the bit fields packed in ``tensor`` along ``dim`` into separate elements.

    Args:
        tensor: Packed integer tensor; its dtype fixes the storage width (see ``packshape``).
        dim: Dimension to unpack; negative values count from the end.
        mask: Bit mask of a single field (``0b1``, ``0b11``, ``0b1111`` or ``0b11111111``).
        shape: Shape of the result. Defaults to the fully unpacked shape reported by
            ``packshape``. It is compared against ``out.shape`` in an assertion, so
            pass a tuple or ``torch.Size``.
        out: Optional preallocated result tensor; must have shape ``shape``.
        dtype: Dtype of the result tensor when ``out`` is not supplied.

    Returns:
        ``out`` filled with the unpacked field values.

    NOTE(review): when ``shape[dim]`` is a multiple of the fields-per-element count
    the fields of each element are emitted most-significant first (shifts run from
    high to low). Otherwise the fallback loop uses shift ``nibble * i`` for output
    position ``i``, i.e. least-significant first. Confirm this difference is
    intended before relying on the fallback path.
    """
    if dim < 0:
        dim = dim + tensor.dim()
    full_shape, nibbles, nibble = packshape(tensor.shape, dim=dim, mask=mask, dtype=tensor.dtype, pack=False)
    if shape is None:
        shape = full_shape
    if out is None:
        out = torch.empty(shape, device=tensor.device, dtype=dtype)
    assert out.shape == shape

    if shape[dim] % nibbles != 0:
        # Fallback: fill every `nibbles`-th output position starting at offset i,
        # reading only as many packed elements as that strided output slice holds.
        for i in range(nibbles):
            dst = tensor_dim_slice(out, dim, slice(i, None, nibbles))
            src = tensor.narrow(dim, 0, dst.shape[dim])
            torch.bitwise_and(src >> (nibble * i), mask, out=dst)
        return out

    # Vectorised path: add an axis after `dim`, shift each element by every field
    # offset at once, then collapse the new axis into `dim` via view_as(out).
    trailing_dims = tensor.dim() - dim - 1
    shifts = torch.arange((nibbles - 1) * nibble, -1, -nibble, dtype=torch.uint8, device=tensor.device)
    shifts = shifts.view(nibbles, *((1,) * trailing_dims))
    shifted = tensor.unsqueeze(1 + dim) >> shifts
    return torch.bitwise_and(shifted.view_as(out), mask, out=out)

class dotdict(dict):
    """``dict`` subclass whose items can also be read, set and deleted as attributes.

    Attribute assignment and deletion are direct aliases of ``dict.__setitem__``
    and ``dict.__delitem__`` (so deleting a missing attribute raises ``KeyError``).
    Reading a missing attribute raises ``AttributeError``.
    """

    def __getattr__(self, name):
        # Reached only after normal attribute lookup fails; fall back to the item
        # and translate a missing key into the exception attribute access expects.
        try:
            value = self[name]
        except KeyError:
            raise AttributeError(name)
        return value

    __delattr__ = dict.__delitem__
    __setattr__ = dict.__setitem__

            
# Printed at import/cell-execution time once the helper definitions above exist.
print('helper ok!')

# Feature-width hyperparameters.
# PACK_* are the numbers of packed feature columns per node / edge. The model
# input widths are 8x those values (presumably one column per bit of a packed
# uint8 value, as produced by F_unpackbits; confirm against the dataset code).
PACK_NODE_DIM = 9
PACK_EDGE_DIM = 1
NODE_DIM = 8 * PACK_NODE_DIM
EDGE_DIM = 8 * PACK_EDGE_DIM

class MPNNLayer(MessagePassing):
    """Message-passing layer whose message and update functions are small MLPs.

    For every edge, the features of its two endpoint nodes and the edge features
    are concatenated and passed through ``mlp_msg``. Messages are aggregated per
    node with ``aggr``, concatenated with the node's own features and passed
    through ``mlp_upd``.

    Args:
        emb_dim: Width of the node embeddings (input and output).
        edge_dim: Width of the edge feature vectors.
        aggr: Aggregation passed to ``MessagePassing`` and to ``scatter``
            (the model in this file uses ``'add'``).
    """

    def __init__(self, emb_dim=64, edge_dim=4, aggr='add'):
        super().__init__(aggr=aggr)

        self.emb_dim = emb_dim
        self.edge_dim = edge_dim
        # Input: [h_i, h_j, edge_attr] -> 2 * emb_dim + edge_dim features.
        self.mlp_msg = nn.Sequential(
            nn.Linear(2 * emb_dim + edge_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(),
            nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU()
        )
        # Input: [h, aggregated messages] -> 2 * emb_dim features.
        self.mlp_upd = nn.Sequential(
            nn.Linear(2 * emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(),
            nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU()
        )

    def forward(self, h, edge_index, edge_attr):
        """Run one round of message passing and return the updated node features."""
        return self.propagate(edge_index, h=h, edge_attr=edge_attr)

    def message(self, h_i, h_j, edge_attr):
        """Build one message per edge from both endpoint features and the edge features."""
        msg = torch.cat([h_i, h_j, edge_attr], dim=-1)
        return self.mlp_msg(msg)

    def aggregate(self, inputs, index, dim_size=None):
        """Reduce the per-edge messages onto their receiving nodes.

        ``dim_size`` (supplied by ``propagate``) fixes the number of output rows to
        the number of nodes. Without it ``scatter`` sizes the output from
        ``index.max() + 1``, so trailing nodes that receive no message would be
        missing and the concatenation with ``h`` in ``update`` would fail.
        """
        return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)

    def update(self, aggr_out, h):
        """Combine each node's current features with its aggregated messages."""
        upd_out = torch.cat([h, aggr_out], dim=-1)
        return self.mlp_upd(upd_out)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})')

class MPNNModel(nn.Module):
    """Graph encoder: linear input projection, residual MPNN layers, mean pooling.

    Args:
        num_layers: Number of ``MPNNLayer`` blocks.
        emb_dim: Width of the node embeddings and of the returned graph embedding.
        in_dim: Width of the unpacked node feature vectors fed to ``lin_in``.
        edge_dim: Width of the unpacked edge feature vectors.
    """

    def __init__(self, num_layers=4, emb_dim=64, in_dim=9, edge_dim=4):
        super().__init__()
        self.lin_in = nn.Linear(in_dim, emb_dim)
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(MPNNLayer(emb_dim, edge_dim, aggr='add'))
        self.pool = global_mean_pool

    def forward(self, batch):
        """Encode a batch of graphs into one embedding per graph.

        Args:
            batch: Object providing ``x`` and ``edge_attr`` (bit-packed features that
                are expanded with ``F_unpackbits``), ``edge_index`` and ``batch``
                (the per-node graph assignment passed to the pooling function).

        Returns:
            The pooled graph embeddings produced by ``global_mean_pool``.
        """
        h = self.lin_in(F_unpackbits(batch.x, -1).float())
        # Edge tensors do not change between layers, so convert/unpack them once
        # instead of once per layer.
        edge_index = batch.edge_index.long()
        edge_attr = F_unpackbits(batch.edge_attr, -1).float()
        for conv in self.convs:
            h = h + conv(h, edge_index, edge_attr)  # residual update, (n, d) -> (n, d)
        return self.pool(h, batch.batch)

class Net(nn.Module):
    """Binary binding classifier: ``MPNNModel`` graph encoder followed by an MLP head.

    ``output_type`` selects what ``forward`` returns: ``'loss'`` adds the BCE loss
    (which reads ``batch.y``) and ``'infer'`` adds probabilities and hard predictions.
    """

    def __init__(self):
        super().__init__()
        self.output_type = ['infer', 'loss']
        graph_dim = 96
        self.smile_encoder = MPNNModel(
            num_layers=4, emb_dim=graph_dim, in_dim=NODE_DIM, edge_dim=EDGE_DIM,
        )
        # Head: graph_dim -> 1024 -> 1024 -> 512 -> 1, with ReLU + Dropout(0.1)
        # after every hidden Linear layer.
        widths = (graph_dim, 1024, 1024, 512)
        head = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            head += [nn.Linear(fan_in, fan_out), nn.ReLU(inplace=True), nn.Dropout(0.1)]
        head.append(nn.Linear(widths[-1], 1))
        self.bind = nn.Sequential(*head)

    def forward(self, batch):
        """Score a batch of graphs.

        Returns:
            dict with ``'bce_loss'`` when ``'loss'`` is in ``self.output_type``, and
            ``'bind'`` (sigmoid probabilities) plus ``'preds'`` (probabilities
            thresholded at 0.5, as floats) when ``'infer'`` is in ``self.output_type``.
        """
        graph_emb = self.smile_encoder(batch)
        logits = self.bind(graph_emb).squeeze(-1)

        output = {}
        if 'loss' in self.output_type:
            # Flatten the labels so they line up with the squeezed logits.
            target = batch.y.view(-1).float()
            output['bce_loss'] = F.binary_cross_entropy_with_logits(logits, target)
        if 'infer' in self.output_type:
            probs = torch.sigmoid(logits)
            output['bind'] = probs
            output['preds'] = (probs >= 0.5).float()
        return output




