In [1]:
%%HTML
<style>
   div#notebook-container    { width: 95%; }
   div#menubar-container     { width: 65%; }
   div#maintoolbar-container { width: 99%; }
</style>

In [1]:
import os
from datetime import datetime
PROJECT_PATH = os.path.dirname(os.path.realpath('__file__').replace('/lib',''))
IDENTIFIER   = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

import pickle 

#numerical libs
import math
import numpy as np
import random
import PIL
import cv2
import matplotlib
# matplotlib.use('TkAgg')
#matplotlib.use('WXAgg')
#matplotlib.use('Qt4Agg')
#matplotlib.use('Qt5Agg') #Qt4Agg
# print('matplotlib.get_backend : ', matplotlib.get_backend())
#print(matplotlib.__version__)


# torch libs
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from torch.utils.data.sampler import *

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.parallel.data_parallel import data_parallel

from torch.nn.utils.rnn import *


# std libs
import collections
import copy
import numbers
import inspect
import shutil
from timeit import default_timer as timer
import itertools
from collections import OrderedDict
from multiprocessing import Pool
import multiprocessing as mp

#from pprintpp import pprint, pformat
import json
import zipfile



import csv
import pandas as pd
import pickle
import glob
import sys
from distutils.dir_util import copy_tree
import time


import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

from dscribe.descriptors import ACSF
from dscribe.core.system import System

import torch_geometric.nn as gnn

import networkx as nx
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import ChemicalFeatures
from rdkit import RDConfig

import rdkit.Chem.Draw
from rdkit.Chem.Draw.MolDrawing import MolDrawing, DrawingOptions
DrawingOptions.bondLineWidth=1.8

from rdkit.Chem.rdmolops import SanitizeFlags

# constant #
PI  = np.pi
INF = np.inf
EPS = 1e-12

In [2]:
# Folder containing the raw input files; the directory listing is shown as
# this cell's output.
file_folder = '../../data/input'
os.listdir(file_folder)

['sample_submission.csv',
 'magnetic_shielding_tensors.csv',
 'potential_energy.csv',
 'scalar_coupling_contributions.csv',
 'dipole_moments.csv',
 'mulliken_charges.csv',
 'train.csv',
 'test.csv',
 'structures.csv',
 'structures']

In [3]:
# Training table and the list of distinct molecules that appear in it
# (order of first appearance, as returned by Series.unique()).
train = pd.read_csv(f'{file_folder}/train.csv')
train_molecule_names = train['molecule_name'].unique().tolist()

In [4]:
# Test table and the list of distinct molecules that appear in it
# (order of first appearance, as returned by Series.unique()).
test = pd.read_csv(f'{file_folder}/test.csv')
test_molecule_names = test['molecule_name'].unique().tolist()

In [5]:
# Folder of per-molecule graph pickles; ChampsDataset reads
# '<graph_dir>/<molecule_name>.pickle' for every item it serves.
graph_dir='../../data/temp/pytorch_geometric2'

In [6]:
# Chemical elements handled by the descriptor generator.
SYMBOL=['H', 'C', 'N', 'O', 'F']

# Atom-centred symmetry function generator (dscribe), shared by every dataset item.
# NOTE(review): newer dscribe releases appear to rename `rcut` to `r_cut` --
# confirm against the installed version before upgrading.
ACSF_GENERATOR = ACSF(
    species=SYMBOL,
    rcut=6.0,
    g2_params=[[1, 1], [1, 2], [1, 3]],
    g4_params=[[1, 1, 1], [1, 2, 1], [1, 1, -1], [1, 2, -1]],
)

# Flat table with five entries per coupling type: name, mean, std, min, max.
COUPLING_TYPE_STATS=[
    #type   #mean, std, min, max
    '1JHC',  94.9761528641869,   18.27722399839607,   66.6008,   204.8800,
    '2JHC',  -0.2706244378832,    4.52360876732858,  -36.2186,    42.8192,
    '3JHC',   3.6884695895355,    3.07090647005439,  -18.5821,    76.0437,
    '1JHN',  47.4798844844683,   10.92204561670947,   24.3222,    80.4187,
    '2JHN',   3.1247536134185,    3.67345877025737,   -2.6209,    17.7436,
    '3JHN',   0.9907298624944,    1.31538940138001,   -3.1724,    10.9712,
    '2JHH', -10.2866051639817,    3.97960190019757,  -35.1761,    11.8542,
    '3JHH',   4.7710233597359,    3.70498129755812,   -3.0205,    17.4841,
]
NUM_COUPLING_TYPE = len(COUPLING_TYPE_STATS)//5

# Column views of the flat table (stride 5); a type's integer id is its
# position in COUPLING_TYPE.
COUPLING_TYPE      = COUPLING_TYPE_STATS[0:NUM_COUPLING_TYPE*5:5]
COUPLING_TYPE_MEAN = COUPLING_TYPE_STATS[1:NUM_COUPLING_TYPE*5:5]
COUPLING_TYPE_STD  = COUPLING_TYPE_STATS[2:NUM_COUPLING_TYPE*5:5]


#---

# RDKit bond orders in the order used for indexing.
BOND_TYPE = [
    getattr(Chem.rdchem.BondType, bond_name)
    for bond_name in ('SINGLE', 'DOUBLE', 'TRIPLE', 'AROMATIC')
]

# RDKit hybridization states in the order used for indexing;
# S, SP3D and SP3D2 are intentionally left out.
HYBRIDIZATION = [
    getattr(Chem.rdchem.HybridizationType, hybrid_name)
    for hybrid_name in ('SP', 'SP2', 'SP3')
]

In [7]:
def time_to_str(t, mode='min'):
    """Format an elapsed time given in seconds as a short fixed-width string.

    Args:
        t: elapsed time in seconds; fractional seconds are truncated.
        mode: 'min' gives '<hours> hr <minutes> min',
              'sec' gives '<minutes> min <seconds> sec'.

    Returns:
        The formatted string.

    Raises:
        NotImplementedError: if `mode` is neither 'min' nor 'sec'.
    """
    if mode not in ('min', 'sec'):
        raise NotImplementedError

    whole_seconds = int(t)
    if mode == 'sec':
        minutes, seconds = divmod(whole_seconds, 60)
        return '%2d min %02d sec'%(minutes, seconds)

    # Fractional minutes are truncated by the %d conversions below.
    hours, minutes = divmod(whole_seconds/60, 60)
    return '%2d hr %02d min'%(hours, minutes)

def read_pickle_from_file(pickle_file):
    """Load and return the object stored in the pickle file at `pickle_file`.

    Unpickling can execute arbitrary code, so only use this on trusted files.
    """
    with open(pickle_file, mode='rb') as handle:
        return pickle.load(handle)

def write_pickle_to_file(pickle_file, x):
    """Serialize `x` to `pickle_file` using the highest pickle protocol.

    An existing file at that path is overwritten.
    """
    with open(pickle_file, mode='wb') as handle:
        pickle.dump(x, handle, protocol=pickle.HIGHEST_PROTOCOL)

# http://stackoverflow.com/questions/34950201/pycharm-print-end-r-statement-not-working
class Logger(object):
    """Tee-style writer that sends messages to the terminal and, optionally, a log file.

    The terminal stream is whatever `sys.stdout` is when the logger is created.
    A log file is attached with `open()`; until then, `write()` only prints to
    the terminal.
    """

    def __init__(self):
        self.terminal = sys.stdout  #stdout
        self.file = None

    def open(self, file, mode=None):
        """Attach a log file at path `file` (mode defaults to 'w').

        A previously attached log file is closed first so its handle is not leaked.
        """
        if mode is None: mode ='w'
        if self.file is not None:
            self.file.close()
        self.file = open(file, mode)

    def write(self, message, is_terminal=1, is_file=1 ):
        """Write `message` to the terminal and/or the log file.

        Messages containing a carriage return (progress-bar updates) are never
        written to the file. If no log file is attached, the file part is skipped
        rather than raising AttributeError.
        """
        if '\r' in message: is_file=0

        if is_terminal == 1:
            self.terminal.write(message)
            self.terminal.flush()

        if is_file == 1 and self.file is not None:
            self.file.write(message)
            self.file.flush()

    def flush(self):
        # Needed for file-like compatibility on Python 3; every write() already
        # flushes, so there is nothing left to do here.
        pass

    def close(self):
        """Close the attached log file, if any. Safe to call more than once."""
        if self.file is not None:
            self.file.close()
            self.file = None

In [8]:
class ChampsDataset(Dataset):
    """Dataset that serves one pre-built molecular graph per molecule name.

    Each item is loaded from '<graph_file_path>/<molecule_name>.pickle' and
    augmented with ACSF descriptors and atom positions before being returned.

    Args:
        molecule_names: sequence of molecule names; defines dataset order and length.
        graph_file_path: folder that holds the per-molecule graph pickles.
    """

    def __init__(self, molecule_names, graph_file_path):
        self.id   = molecule_names
        self.graph_file_path = graph_file_path

    def __str__(self):
        return 'str'

    def __len__(self):
        return len(self.id)

    def __getitem__(self, index):
        """Load, augment and return the graph of the `index`-th molecule.

        Raises:
            AssertionError: if the pickled graph belongs to a different molecule.
        """
        molecule_name = self.id[index]
        graph_file = self.graph_file_path + '/%s.pickle'%molecule_name
        graph = read_pickle_from_file(graph_file)
        assert(graph.molecule_name==molecule_name)

        # graph.axyz is read as (symbols, positions) -- presumably one entry per
        # atom; verify against the graph-building code.
        atom = System(symbols =graph.axyz[0], positions=graph.axyz[1])
        acsf = ACSF_GENERATOR.create(atom)
        # Append the ACSF descriptors and raw positions to the list of
        # per-atom feature blocks.
        graph.node += [acsf, graph.axyz[1]]

        # Collapse the feature blocks into single matrices along the last axis.
        graph.node = np.concatenate(graph.node,-1)
        graph.edge = np.concatenate(graph.edge,-1)
        return graph

In [9]:
# net ------------------------------------
# https://github.com/pytorch/examples/blob/master/imagenet/main.py ###############
def adjust_learning_rate(optimizer, lr):
    """Overwrite the learning rate of every parameter group of `optimizer` with `lr`."""
    for group in optimizer.param_groups:
        group.update(lr=lr)

def get_learning_rate(optimizer):
    """Return the learning rate of `optimizer`'s single parameter group.

    Raises:
        AssertionError: if the optimizer does not have exactly one parameter group.
    """
    rates = [group['lr'] for group in optimizer.param_groups]
    assert(len(rates)==1) #we support only one param_group
    return rates[0]

def null_collate(batch):
    """Merge a list of molecule graphs into one disjoint batch graph.

    Node-referencing indices (`edge_index`, `coupling.index`) are shifted by the
    number of nodes of the preceding graphs so they address rows of the
    concatenated node matrix.

    Args:
        batch: list of graph objects exposing `node`, `edge`, `edge_index`,
            `molecule_name`, `smiles` and `coupling.{index,type,value,id}`.

    Returns:
        Tuple (node, edge, edge_index, node_batch_index, coupling_index,
        coupling_type, coupling_value, coupling_batch_index, infor): float tensors
        for node/edge/coupling_value, long tensors for the rest, and `infor` as a
        list of (molecule_name, smiles, coupling ids), one entry per graph.
    """
    nodes, edges, edge_indices, node_owner = [], [], [], []
    pair_indices, pair_types, pair_values, pair_owner = [], [], [], []
    infor = []

    node_offset = 0
    for b, graph in enumerate(batch):
        num_node = len(graph.node)

        nodes.append(graph.node)
        edges.append(graph.edge)
        edge_indices.append(graph.edge_index+node_offset)
        node_owner.append([b]*num_node)

        pair_indices.append(graph.coupling.index+node_offset)
        pair_types.append(graph.coupling.type)
        pair_values.append(graph.coupling.value)
        pair_owner.append([b]*len(graph.coupling.value))

        infor.append((graph.molecule_name, graph.smiles, graph.coupling.id))
        node_offset += num_node

    def _to_float(chunks):
        return torch.from_numpy(np.concatenate(chunks)).float()

    def _to_long(chunks):
        return torch.from_numpy(np.concatenate(chunks)).long()

    node = _to_float(nodes)
    edge = _to_float(edges)
    # Edge indices pass through int32 before being widened to long.
    edge_index = torch.from_numpy(np.concatenate(edge_indices).astype(np.int32)).long()
    node_batch_index = _to_long(node_owner)

    coupling_index = _to_long(pair_indices)
    coupling_type  = _to_long(pair_types)
    coupling_value = _to_float(pair_values)
    coupling_batch_index = _to_long(pair_owner)

    return node,edge,edge_index, node_batch_index, \
           coupling_index,coupling_type,coupling_value,coupling_batch_index, infor

In [10]:
class NullScheduler():
    """Learning-rate schedule that always returns the same constant rate."""

    def __init__(self, lr=0.01 ):
        super().__init__()
        self.lr    = lr
        self.cycle = 0

    def __call__(self, time):
        # The step argument is accepted for interface compatibility and ignored.
        return self.lr

    def __str__(self):
        return ''.join(['NullScheduler\n', 'lr=%0.5f '%(self.lr)])

class LinearBn(nn.Module):
    """Bias-free linear layer followed by 1-D batch norm and an optional activation.

    Args:
        in_channel: input feature size.
        out_channel: output feature size.
        act: optional callable applied after batch norm (None means no activation).
    """

    def __init__(self, in_channel, out_channel, act=None):
        super().__init__()
        # No bias on the linear layer: the batch norm that follows has its own shift.
        self.linear = nn.Linear(in_channel, out_channel, bias=False)
        self.bn   = nn.BatchNorm1d(out_channel,eps=1e-05, momentum=0.1)
        self.act  = act

    def forward(self, x):
        out = self.linear(x)
        # Apply batch norm then activation, skipping whichever is disabled.
        for stage in (self.bn, self.act):
            if stage is not None:
                out = stage(out)
        return out


#message passing
class Net(torch.nn.Module):
    """Message-passing network that predicts one scalar per coupling pair.

    Node features are embedded, refined by a fixed number of NNConv + GRU
    message-passing rounds, pooled per molecule with Set2Set, and the pooled
    vector is combined with the embeddings of the coupled atoms to regress the
    coupling constant for each coupling type.

    Args:
        node_dim: size of the input node feature vectors.
        edge_dim: size of the input edge feature vectors.
        num_target: number of coupling types (one output column per type).
    """

    def __init__(self, node_dim=96, edge_dim=16, num_target=8):
        super().__init__()

        self.num_message_passing = 6
        hidden = 128        # node hidden size
        edge_hidden = 128   # edge hidden size

        self.preprocess = nn.Sequential(
            LinearBn(node_dim, 256),
            nn.ReLU(),
            LinearBn(256, hidden),
        )

        # Maps each edge feature vector to a flattened hidden x hidden weight
        # matrix consumed by NNConv.
        edge_net = nn.Sequential(
            LinearBn(edge_dim, 32),
            nn.ReLU(),
            LinearBn(32, 64),
            nn.ReLU(),
            LinearBn(64, edge_hidden),
            nn.ReLU(),
            LinearBn(edge_hidden, hidden * hidden),
        )

        self.conv = gnn.NNConv(hidden, hidden, edge_net, aggr='mean', root_weight=True)
        self.gru  = nn.GRU(hidden, hidden)
        self.set2set = gnn.Set2Set(hidden, processing_steps=6)

        # Predict coupling constant. The input width of 4*hidden presumably is the
        # Set2Set output (2*hidden) plus two atom embeddings -- assumes
        # coupling_index has two columns; TODO confirm.
        self.predict = nn.Sequential(
            LinearBn(4*hidden, 512),
            nn.ReLU(),
            nn.Linear(512, num_target),
        )

    def forward(self, node, edge, edge_index, node_batch_index, coupling_index, coupling_type, coupling_batch_index):
        """Return a 1-D tensor with one prediction per row of `coupling_index`.

        For each coupling, the output column selected by `coupling_type` is returned.
        """
        # The collate function stores one edge per row; transpose before the conv.
        edge_index = edge_index.t().contiguous()

        x = F.relu(self.preprocess(node))
        h = x.unsqueeze(0)

        for _ in range(self.num_message_passing):
            message = F.relu(self.conv(x, edge_index, edge))
            x, h = self.gru(message.unsqueeze(0), h)
            x = x.squeeze(0)
        # x: (num_node, hidden)

        # Per-molecule pooled vector, repeated for every coupling of that molecule.
        pool = self.set2set(x, node_batch_index)
        pool = pool.index_select(0, coupling_batch_index)

        # Embeddings of the atoms in each coupling, flattened to one row per coupling.
        atoms = x.index_select(0, coupling_index.view(-1)).reshape(len(coupling_index),-1)

        scores = self.predict(torch.cat([pool, atoms], -1))

        # Keep only the column that matches each coupling's type.
        return torch.gather(scores, 1, coupling_type.view(-1,1)).view(-1)


# def criterion(predict, coupling_value):
#     predict = predict.view(-1)
#     coupling_value = coupling_value.view(-1)
#     assert(predict.shape==coupling_value.shape)
#
#     loss = F.mse_loss(predict, coupling_value)
#     return loss


def criterion(predict, coupling_value):
    """Training loss: natural log of the mean absolute error over all couplings.

    Both inputs are flattened and must then contain the same number of elements.

    Raises:
        AssertionError: if the flattened shapes differ.
    """
    flat_predict = predict.view(-1)
    flat_target  = coupling_value.view(-1)
    assert(flat_predict.shape==flat_target.shape)

    return (flat_predict-flat_target).abs().mean().log()

def compute_kaggle_metric( predict, coupling_value, coupling_type):
    """Compute per-coupling-type MAE and log-MAE.

    Args:
        predict: array of predicted coupling constants.
        coupling_value: array of true coupling constants (same shape as `predict`).
        coupling_type: array of integer coupling-type ids, one per coupling.

    Returns:
        (mae, log_mae): two lists of length NUM_COUPLING_TYPE. An entry is None
        when no coupling of that type is present; log_mae is log(mae + 1e-8).
    """
    mae     = [None]*NUM_COUPLING_TYPE
    log_mae = [None]*NUM_COUPLING_TYPE
    abs_error = np.fabs(predict-coupling_value)

    for t in range(NUM_COUPLING_TYPE):
        rows = np.where(coupling_type==t)[0]
        if len(rows)==0:
            continue  # type absent from this set: leave None

        type_mae = abs_error[rows].mean()
        mae[t] = type_mae
        log_mae[t] = np.log(type_mae+1e-8)  # epsilon guards log(0)

    return mae, log_mae

In [11]:
# train_molecule_names_valid = random.sample(train_molecule_names, 5000)

In [12]:
# train_molecule_names_train = [ m for m in train_molecule_names if m not in  train_molecule_names_valid]

In [13]:
# with open('train_molecule_names_valid.pickle', 'wb') as f:
#     pickle.dump(train_molecule_names_valid, f)
    
# with open('train_molecule_names_train.pickle', 'wb') as f:
#     pickle.dump(train_molecule_names_train, f)

In [14]:
# Reload the fixed train/validation molecule split so runs are comparable.
with open('train_molecule_names_valid.pickle', 'rb') as f:
    train_molecule_names_valid = pickle.load(f)

with open('train_molecule_names_train.pickle', 'rb') as f:
    train_molecule_names_train = pickle.load(f)

# Guard against a leaked split: validating on molecules the model was trained
# on would make every validation score meaningless.
assert set(train_molecule_names_train).isdisjoint(train_molecule_names_valid), \
    'train and validation molecule lists overlap -- regenerate the split pickles'

In [15]:
def do_valid(net, valid_loader):
    """Evaluate `net` on every batch of `valid_loader` (CUDA required).

    The network is switched to eval mode and left in that mode on return.

    Args:
        net: model called as net(node, edge, edge_index, node_batch_index,
            coupling_index, coupling_type, coupling_batch_index).
        valid_loader: DataLoader yielding the 9-tuple produced by `null_collate`.

    Returns:
        A list: the per-type log-MAE values (0 for types absent from the
        validation set) followed by [mean batch loss weighted by molecules per
        batch, mean MAE over present types, mean log-MAE over present types].

    Raises:
        AssertionError: if the loader did not cover the whole dataset.
    """
    valid_num = 0
    valid_loss = 0
    valid_predict = []
    valid_coupling_type  = []
    valid_coupling_value = []

    # Eval mode is loop-invariant; set it once up front.
    net.eval()
    for (node,edge,edge_index,node_batch_index,
         coupling_index,coupling_type,coupling_value,coupling_batch_index,
         infor) in valid_loader:

        node = node.cuda()
        edge = edge.cuda()
        edge_index  = edge_index.cuda()
        node_batch_index = node_batch_index.cuda()
        coupling_index = coupling_index.cuda()
        coupling_type  = coupling_type.cuda()
        coupling_value = coupling_value.cuda()
        coupling_batch_index = coupling_batch_index.cuda()

        with torch.no_grad():
            predict = net(node,edge,edge_index,node_batch_index, coupling_index,coupling_type,coupling_batch_index)
            loss = criterion(predict, coupling_value)

        #---
        batch_size = len(infor)  # number of molecules in this batch
        valid_predict.append(predict.detach().cpu().numpy())
        valid_coupling_type.append(coupling_type.detach().cpu().numpy())
        valid_coupling_value.append(coupling_value.detach().cpu().numpy())

        valid_loss += batch_size*loss.item()
        valid_num  += batch_size

        print('\r %8d /%8d'%(valid_num, len(valid_loader.dataset)),end='',flush=True)

    assert(valid_num == len(valid_loader.dataset))
    valid_loss = valid_loss/valid_num

    # Metric over the whole validation set, not averaged per batch.
    predict = np.concatenate(valid_predict)
    coupling_value = np.concatenate(valid_coupling_value)
    coupling_type  = np.concatenate(valid_coupling_type).astype(np.int32)
    mae, log_mae   = compute_kaggle_metric( predict, coupling_value, coupling_type,)

    # Types with no samples contribute 0 and are excluded from the averages.
    num_target = NUM_COUPLING_TYPE
    for t in range(NUM_COUPLING_TYPE):
        if mae[t] is None:
            mae[t] = 0
            log_mae[t]  = 0
            num_target -= 1

    mae_mean, log_mae_mean = sum(mae)/num_target, sum(log_mae)/num_target

    valid_loss = log_mae + [valid_loss,mae_mean, log_mae_mean, ]
    return valid_loss

In [16]:
# --- run configuration ---
save_folder = '.'          # checkpoints are read from / written to this folder
net_file = 'net'           # model weights file name (None disables loading)
opt_file = 'optimizer'     # optimizer state file name (None disables loading)
batch_size= 16
EDGE_DIM   =  6
NODE_DIM   = 96 ##  93  13
NUM_TARGET =  8

# --- data ---
train_dataset = ChampsDataset(train_molecule_names_train, graph_dir)
train_loader = DataLoader(
    train_dataset,
    sampler=RandomSampler(train_dataset),
    batch_size=batch_size,
    drop_last=True,
    num_workers=16,
    pin_memory=True,
    collate_fn=null_collate,
)

valid_dataset = ChampsDataset(train_molecule_names_valid, graph_dir)
valid_loader = DataLoader(
    valid_dataset,
    sampler=RandomSampler(valid_dataset),
    batch_size=batch_size,
    drop_last=False,
    num_workers=0,
    pin_memory=True,
    collate_fn=null_collate,
)

# --- model, resuming from a checkpoint when one exists ---
net = Net(node_dim=NODE_DIM,edge_dim=EDGE_DIM, num_target=NUM_TARGET).cuda()
if net_file is not None and os.path.exists(f'{save_folder}/{net_file}'):
    # Load onto CPU storage first; load_state_dict copies into the CUDA parameters.
    net.load_state_dict(torch.load(f'{save_folder}/{net_file}', map_location=lambda storage, loc: storage))
    print('net loaded')

# --- optimizer, resuming its state when a checkpoint exists ---
schduler = NullScheduler(lr=0.001)
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, net.parameters()),
    lr=schduler(0),
)
if opt_file is not None and os.path.exists(f'{save_folder}/{opt_file}'):
    checkpoint  = torch.load(f'{save_folder}/{opt_file}')
    optimizer.load_state_dict(checkpoint['optimizer'])
    print('optimizer loaded')

# schduler = NullScheduler(lr=0.0001)
# net = Net(node_dim=NODE_DIM,edge_dim=EDGE_DIM, num_target=NUM_TARGET).cuda()
# optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()),lr=schduler(0))


# --- loop configuration ---
iter_accum  = 1              # batches accumulated per optimizer step
num_iters   = 3000  *1000
iter_smooth = 50             # window (in iterations) of the running train loss
iter_log    = 500
iter_valid  = 500            # validate and checkpoint every this many iterations
iter_save   = [0, num_iters-1]+ list(range(0, num_iters, 2500))#1*1000

start_iter = 0
start_epoch= 0
rate       = 0


train_loss   = np.zeros(20,np.float32)
valid_loss   = np.zeros(20,np.float32)
batch_loss   = np.zeros(20,np.float32)
iter_ = 0
i    = 0
# Set when the scheduler returns a negative rate; without it the `break` below
# only leaves the inner loop and the outer `while` would spin forever.
stop_training = False

start = timer()
while iter_<num_iters and not stop_training:
    sum_train_loss = np.zeros(20,np.float32)
    sum_ = 0

    optimizer.zero_grad()
    for node,edge,edge_index,node_batch_index, coupling_index,coupling_type,coupling_value,coupling_batch_index, infor in train_loader:

        batch_size = len(infor)
        iter_  = i + start_iter
        epoch = (iter_-start_iter)*batch_size/len(train_dataset) + start_epoch

        # Periodic validation + checkpoint of weights and optimizer state.
        if (iter_ % iter_valid==0):
            valid_loss = do_valid(net, valid_loader) #
            if save_folder is not None:
                torch.save(net.state_dict(),f'{save_folder}/{net_file}')
                torch.save({'optimizer': optimizer.state_dict()}, f'{save_folder}/{opt_file}')

        # learning rate schduler -------------
        # Pass the iteration counter (the original passed the builtin `iter`).
        lr = schduler(iter_)
        if lr<0 :
            stop_training = True
            break
        adjust_learning_rate(optimizer, lr)
        rate = get_learning_rate(optimizer)

        # one iteration update  -------------
        # do_valid leaves the net in eval mode, so switch back every iteration.
        net.train()
        node = node.cuda()
        edge = edge.cuda()
        edge_index = edge_index.cuda()
        node_batch_index = node_batch_index.cuda()
        coupling_index = coupling_index.cuda()
        coupling_type  = coupling_type.cuda()
        coupling_value = coupling_value.cuda()
        coupling_batch_index = coupling_batch_index.cuda()

        predict = net(node,edge,edge_index,node_batch_index, coupling_index,coupling_type,coupling_batch_index)
        loss = criterion(predict, coupling_value)

        (loss/iter_accum).backward()
        if (iter_ % iter_accum)==0:
            optimizer.step()
            optimizer.zero_grad()

        # print statistics  ------------
        batch_loss[:1] = [loss.item()]
        sum_train_loss += batch_loss
        sum_ += 1
        if iter_%iter_smooth == 0:
            train_loss = sum_train_loss/sum_
            sum_train_loss = np.zeros(20,np.float32)
            sum_ = 0

        # valid_loss[8:11] = (validation loss, mean MAE, mean log-MAE) from do_valid.
        print('\r',end='',flush=True)
        print('| %+5.3f  %0.2f %+0.2f | %+5.3f | %s' % (*valid_loss[8:11], batch_loss[0], time_to_str((timer() - start),'min')) , end='',flush=True)
        i=i+1

    #-- end of one data loader --
#-- end of all iterations --

net loaded
optimizer loaded
| -1.093  0.32 -1.23 | -1.099 |  5 hr 50 min

Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/tensorflow_gpu_p36/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/home/ubuntu/anaconda3/envs/tensorflow_gpu_p36/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/home/ubuntu/anaconda3/envs/tensorflow_gpu_p36/lib/python3.6/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/home/ubuntu/anaconda3/envs/tensorflow_gpu_p36/lib/python3.6/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/tensorflow_gpu_p36/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/home/ubuntu/anaconda3/envs/tensorflow_gpu_p36/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send

KeyboardInterrupt: 

In [None]:
def predict(net, test_loader):
    """Run `net` over `test_loader` and collect predictions per coupling id (CUDA required).

    Works for any batch size: the ids of every molecule in a batch are
    concatenated in batch order, which is the order in which `null_collate`
    concatenates the couplings the network predicts.

    Args:
        net: model called as net(node, edge, edge_index, node_batch_index,
            coupling_index, coupling_type, coupling_batch_index).
        test_loader: loader yielding the 9-tuple produced by `null_collate`;
            the `coupling_value` entry is ignored.

    Returns:
        DataFrame with columns 'id' and 'scalar_coupling_constant', one row per
        coupling. The index restarts at 0 for every batch. An empty loader
        yields an empty DataFrame.
    """
    frames = []
    N_ = len(test_loader)

    # Eval mode is loop-invariant; set it once up front.
    net.eval()
    for b, (node,edge,edge_index,node_batch_index,
            coupling_index,coupling_type,coupling_value,coupling_batch_index,
            infor) in enumerate(test_loader):

        node = node.cuda()
        edge = edge.cuda()
        edge_index  = edge_index.cuda()
        node_batch_index = node_batch_index.cuda()
        coupling_index = coupling_index.cuda()
        coupling_type  = coupling_type.cuda()
        coupling_batch_index = coupling_batch_index.cuda()

        with torch.no_grad():
            batch_predict = net(node,edge,edge_index,node_batch_index, coupling_index,coupling_type,coupling_batch_index)

        print(f'{b}/{N_}', end='',flush=True)
        print('\r',end='',flush=True)

        # infor holds (molecule_name, smiles, coupling ids) per molecule; use the
        # ids of ALL molecules in the batch, not just the first one.
        ids = np.concatenate([np.asarray(item[2]).reshape(-1) for item in infor])
        frames.append(pd.DataFrame({
            'id': ids,
            'scalar_coupling_constant': batch_predict.cpu().detach().numpy(),
        }))

    if not frames:
        return pd.DataFrame()
    # Single concat at the end instead of growing a DataFrame inside the loop.
    return pd.concat(frames, axis=0)

In [None]:
# Re-read the test table and its distinct molecule names (repeats the earlier
# test-loading cell so the inference section can be run on its own).
test = pd.read_csv(f'{file_folder}/test.csv')
test_molecule_names = test['molecule_name'].unique().tolist()

In [None]:
test_dataset = ChampsDataset(test_molecule_names, graph_dir)

# Inference must visit every molecule exactly once: iterate sequentially (which
# also makes the output order reproducible) and never drop the last partial
# batch, otherwise raising batch_size would silently lose predictions.
test_loader  = DataLoader(
                test_dataset,
                sampler     = SequentialSampler(test_dataset),
                batch_size  = 1,
                drop_last   = False,
                num_workers = 16,
                pin_memory  = True,
                collate_fn  = null_collate
    )

In [None]:
# Run the current `net` over the test loader; the result has one row per
# coupling with columns 'id' and 'scalar_coupling_constant'.
df_pred = predict(net, test_loader)

In [None]:
# Sanity check: print the (rows, columns) of the predictions and display the
# first rows as the cell output.
print(df_pred.shape)
df_pred.head()

In [None]:
# Tag appended to the submission file name -- presumably the validation score
# of this run; confirm before reusing.
idx=-0.96

# Submission frame: prediction column first, then the coupling id, keeping the
# row index of df_pred.
df_submit = df_pred[['scalar_coupling_constant', 'id']].copy()
df_submit.to_csv('../../data/submission/submission_gnn_{}.csv'.format(idx), index=False)