In [1]:
import numpy as np
import pandas as pd

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.utils import data

from tqdm.auto import tqdm

from sklearn import model_selection as ms

In [2]:
import os
import utils as u

In [3]:
DATA_DIR = "../data/graph/"

In [4]:
os.listdir(DATA_DIR)

['out_edges_train.npy',
 'nodes_train.npy',
 'in_edges_train.npy',
 'nodes_test.npy',
 'in_edges_test.npy']

In [5]:
train_nodes = np.load(DATA_DIR + "nodes_train.npy")
train_edges = np.load(DATA_DIR + "in_edges_train.npy")

y = np.load(DATA_DIR + "out_edges_train.npy")

test_nodes = np.load(DATA_DIR + "nodes_test.npy")
test_edges = np.load(DATA_DIR + "in_edges_test.npy")

In [6]:
sub = pd.read_csv("../data/raw/sample_submission.csv")

In [7]:
def print_shapes(**kwargs):
    for key, item in dict(**kwargs).items():
        print(f"{key:^11} has shape of {str(item.shape):^10}")

In [8]:
print_shapes(train_nodes=train_nodes, train_edges=train_edges, y=y, test_nodes=test_nodes, test_edges=test_edges)

train_nodes has shape of (85003, 29, 5)
train_edges has shape of (85003, 29, 29, 6)
     y      has shape of (85003, 29, 29, 1)
test_nodes  has shape of (45772, 29, 5)
test_edges  has shape of (45772, 29, 29, 6)


In [9]:
class SimpleNet(nn.Module):
    
    def __init__(self, max_atoms, max_edges, dim_edges,
                 atom_embedding_dim=32, edge_hidden_dim=64, edge_conv_dim=32, projection_dim=32):
        super().__init__()
        
        self.max_edges = max_edges
        self.projection_dim = projection_dim
        self.edge_conv_dim = edge_conv_dim
        
        self.atom_embedding = nn.Linear(in_features=max_atoms, out_features=atom_embedding_dim)
        self.edge_hidden = nn.Linear(in_features=dim_edges, out_features=edge_hidden_dim)
        
        self.edge_convolution = nn.Conv2d(in_channels=edge_hidden_dim, out_channels=edge_conv_dim, kernel_size=5, padding=2, stride=1)
        self.edge_batchnorm = nn.BatchNorm2d(edge_conv_dim)
        
        self.out_projection = nn.Linear(
            in_features=(edge_conv_dim*max_edges)+atom_embedding_dim,
            out_features=projection_dim*max_edges
        )
        self.out_batchnorm = nn.BatchNorm2d(projection_dim)
        
        self.final_layer = nn.Linear(in_features=projection_dim, out_features=1)
        
    def forward(self, atoms, edges):
        e_atoms = F.relu(self.atom_embedding(atoms))
        
        e_edges = F.relu(self.edge_hidden(edges))
        e_edges = F.relu(self.edge_convolution(e_edges.permute(0, 3, 1, 2)))
        e_edges = self.edge_batchnorm(e_edges).permute(0, 2, 3, 1)
        
        e_edges = e_edges.reshape(-1, self.max_edges, self.max_edges*self.edge_conv_dim)
        concat = torch.cat([e_edges, e_atoms], dim=-1)
        #print(concat.size())
        concat = F.relu(self.out_projection(concat))
        #print(concat.size())
        concat = concat.reshape(-1, self.max_edges, self.max_edges, self.projection_dim)
        concat = self.out_batchnorm(concat.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        
        return self.final_layer(concat)
        

In [10]:
valid_losses = []
train_losses = []
n_folds = 5
n_epochs = 100
random_state = 42
device = torch.device("cuda:0")
batch_size = 1024
lr = 0.001
cv = ms.KFold(n_splits=n_folds, random_state=random_state, shuffle=True)

netconfig = {
    'atom_embedding_dim': 64,
    'edge_hidden_dim': 128,
    'edge_conv_dim': 64,
    'projection_dim': 16,
    'max_atoms': 5,
    'max_edges': 29,
    'dim_edges': 6
}

In [11]:
def custom_collate_fn(samples):
    return samples[0], samples[1], samples[2]

In [12]:
split_idx = 0

for train_idx, valid_idx in cv.split(train_nodes):
    print("-"*50)
    print(f"Training for split {split_idx}")
    split_idx += 1
    train_part = data.TensorDataset(
        torch.from_numpy(train_nodes[train_idx]),
        torch.from_numpy(train_edges[train_idx]),
        torch.from_numpy(y[train_idx])
    )
    
    valid_part = data.TensorDataset(
        torch.from_numpy(train_nodes[valid_idx]),
        torch.from_numpy(train_edges[valid_idx]),
        torch.from_numpy(y[valid_idx])
    )
    
    train_loader = data.DataLoader(train_part, batch_size=batch_size, shuffle=False)
    valid_loader = data.DataLoader(valid_part, batch_size=batch_size, shuffle=False)
    
    #s, a, c = train_dataset[train_idx[:128]]
    #print(s.shape, a.shape, c.shape)
    
    model = SimpleNet(**netconfig)
    model = model.to(device)
    
    optimizer = optim.Adam(model.parameters(), lr=lr)
    lr_sched = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=7, threshold=0.001)
    
    criterion = nn.L1Loss()
    step_counter = 0
    split_train_loss = []
    split_valid_loss = []
    
    for i in range(n_epochs):
        
        train_loss = []
        valid_loss = []
        
        print(f"    Epoch {i}")
        model.train()
        for atoms, edges, y_true in tqdm(train_loader, desc='train'):
            atoms = atoms.to(device)
            edges = edges.to(device)
            y_true = y_true.to(device)
            
            predict = model(atoms, edges)
            loss = torch.log(criterion(predict, y_true))
            loss.backward()
            
            optimizer.step()
            optimizer.zero_grad()
            
            train_loss.append(loss.item())
            
        model.eval()
        with torch.no_grad():
            for atoms, edges, y_true in tqdm(valid_loader, desc='valid'):
                atoms = atoms.to(device)
                edges = edges.to(device)
                y_true = y_true.to(device)
            
                predict = model(atoms, edges)
                loss = torch.log(criterion(predict, y_true))
                
                valid_loss.append(loss.item())
                
        lr_sched.step(np.mean(valid_loss))
                
        print(f"    Mean train mae: {np.mean(train_loss)}")
        print(f"    Mean valid mae: {np.mean(valid_loss)}")
        
        split_train_loss += train_loss
        split_valid_loss += valid_loss
        
    print(f"Mean split train loss: {np.mean(split_train_loss)}")
    print(f"Mean split valid loss: {np.mean(split_valid_loss)}")
    
    train_losses.append(np.mean(split_train_loss))
    valid_losses.append(np.mean(split_valid_loss))

--------------------------------------------------
Training for split 0
    Epoch 0


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.8687830582928302
    Mean valid mae: 0.840569291044684
    Epoch 1


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.8224171936067183
    Mean valid mae: 0.8115405370207394
    Epoch 2


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.7787399605583789
    Mean valid mae: 0.7293095658807194
    Epoch 3


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.6997012697271446
    Mean valid mae: 0.6363918465726516
    Epoch 4


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.5932632429184698
    Mean valid mae: 0.42828862193752737
    Epoch 5


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.4814817054725405
    Mean valid mae: 0.3041509679135154
    Epoch 6


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.3287595498250492
    Mean valid mae: 0.324185584397877
    Epoch 7


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.1156814741804751
    Mean valid mae: 0.2757790870964527
    Epoch 8


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: -0.17599685818874347
    Mean valid mae: -0.246809939689496
    Epoch 9


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: -0.4396240502151091
    Mean valid mae: 0.08847612142562866
Mean split train loss: 0.4073206546177754
Mean split valid loss: 0.4191881683600299
--------------------------------------------------
Training for split 1
    Epoch 0


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.8638627193312147
    Mean valid mae: 0.8428405067499947
    Epoch 1


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.8236885302102388
    Mean valid mae: 0.8068692351088804
    Epoch 2


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.7741315298116029
    Mean valid mae: 0.715521118220161
    Epoch 3


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.6962289478240619
    Mean valid mae: 0.7119101566426894
    Epoch 4


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.5974966366427825
    Mean valid mae: 0.5161946860744673
    Epoch 5


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.46364048412486686
    Mean valid mae: 0.3384794321130304
    Epoch 6


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.3017002750125561
    Mean valid mae: 0.19476356784648755
    Epoch 7


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.05908847811506755
    Mean valid mae: -0.05344729619237649
    Epoch 8


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: -0.2584360842593014
    Mean valid mae: -0.20518724795650034
    Epoch 9


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: -0.5074919895077847
    Mean valid mae: -0.19800367513123682
Mean split train loss: 0.38139095273053053
Mean split valid loss: 0.3669940483475597
--------------------------------------------------
Training for split 2
    Epoch 0


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.865646870056195
    Mean valid mae: 0.8446395327063168
    Epoch 1


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.8458778751430227
    Mean valid mae: 0.8375322380486656
    Epoch 2


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.8281516899813467
    Mean valid mae: 0.8101274563985712
    Epoch 3


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.781049844044358
    Mean valid mae: 0.7377345649635091
    Epoch 4


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.6988966697409972
    Mean valid mae: 0.5977348479277947
    Epoch 5


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.6018662679551253
    Mean valid mae: 0.5253872249467189
    Epoch 6


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.5056058996871337
    Mean valid mae: 0.4028078554307713
    Epoch 7


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.36597989904068745
    Mean valid mae: 0.35044512196498756
    Epoch 8


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.22155206499676874
    Mean valid mae: 0.30974602480145064
    Epoch 9


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: -0.0003902820788615786
    Mean valid mae: 0.4776791036128998
Mean split train loss: 0.5714236798566773
Mean split valid loss: 0.5893833970801685
--------------------------------------------------
Training for split 3
    Epoch 0


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.904532804417966
    Mean valid mae: 0.8591085076332092
    Epoch 1


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.8531878615048394
    Mean valid mae: 0.844373007031048
    Epoch 2


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.8454048913806232
    Mean valid mae: 0.8387231116785723
    Epoch 3


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.821402297535939
    Mean valid mae: 0.8065133611945545
    Epoch 4


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.7499988001229158
    Mean valid mae: 0.7087167825768975
    Epoch 5


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.6586196998606867
    Mean valid mae: 0.5763116240063134
    Epoch 6


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.5685284043759552
    Mean valid mae: 0.4723591098890585
    Epoch 7


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.47359734306584544
    Mean valid mae: 0.3576294429161969
    Epoch 8


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.37452563826939955
    Mean valid mae: 0.3170818094821537
    Epoch 9


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.22889136840631044
    Mean valid mae: 0.13184782336739934
Mean split train loss: 0.6478689108940481
Mean split valid loss: 0.5912664579775404
--------------------------------------------------
Training for split 4
    Epoch 0


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.8776167591116322
    Mean valid mae: 0.8447428775184295
    Epoch 1


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.8502527800068926
    Mean valid mae: 0.8428459667107638
    Epoch 2


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.846524287070801
    Mean valid mae: 0.8437264517826193
    Epoch 3


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.8297124398288442
    Mean valid mae: 0.8240408353945788
    Epoch 4


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.7749961425564182
    Mean valid mae: 0.7536335079108968
    Epoch 5


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.6297314199271487
    Mean valid mae: 0.5515051233198713
    Epoch 6


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.4869718029539087
    Mean valid mae: 0.569588848773171
    Epoch 7


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.3321738399898828
    Mean valid mae: 0.23241635937901103
    Epoch 8


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: 0.13678424043664292
    Mean valid mae: 1.055854832424837
    Epoch 9


HBox(children=(IntProgress(value=0, description='train', max=67, style=ProgressStyle(description_width='initia…




HBox(children=(IntProgress(value=0, description='valid', max=17, style=ProgressStyle(description_width='initia…


    Mean train mae: -0.11168387339714525
    Mean valid mae: -0.030742491143053698
Mean split train loss: 0.5653079838485026
Mean split valid loss: 0.6487612312071125


In [14]:
CV_train, train_std = np.mean(train_losses), np.std(train_losses)
CV_valid, valid_std = np.mean(valid_losses), np.std(valid_losses)

print(f"CV train logmae score: {CV_train:.4f}+-{train_std:.6f}")
print(f"CV valid logmae score: {CV_valid:.4f}+-{valid_std:.6f}")

CV train logmae score: 0.5147+-0.102776
CV valid logmae score: 0.5231+-0.109542
