In [1]:
import torch
from torch_geometric.data import InMemoryDataset, download_url, Data, Batch
from torch import nn
from torch import functional as F
import os
import pandas as pd
import numpy as np
import pickle
import itertools
import jax
from jax import numpy as jnp
import networkx as nx
from scipy.spatial.distance import pdist, squareform
from sklearn.preprocessing import MinMaxScaler
import mendeleev

In [2]:
from mendeleev import element

In [3]:
C, H, N, O, F = element(["C", "H", "N", "O", "F"])

In [4]:
rows = []
for e in [C, H, N, O, F]:
    row = [e.atomic_radius,
           e.atomic_volume,
           e.atomic_weight,
           e.boiling_point,
           e.covalent_radius_bragg,
           e.dipole_polarizability,
           e.electron_affinity,
           e.en_ghosh,
           e.vdw_radius]
    rows.append(row)

In [5]:
properties=np.array(rows)

In [6]:
properties[properties == None] = 0

In [7]:
emms = MinMaxScaler([-5, 5])
props = properties.astype(np.float32)
props = emms.fit_transform(props)

In [8]:
def get_edge_features(coords, dipole_moment, eps = 0.000001):
    norm_dipole = np.linalg.norm(dipole_moment) # get the norm to normalize the vector and find angles
    distmat = squareform(pdist(coords)) # get dist_mat to find edges
    np.fill_diagonal(distmat, np.nan) # fill to avoid ranking problem
    rankings = distmat.argsort(axis=1).argsort(axis=1) # order distance matrix to get n-neighborhood of each edge
    G = nx.from_numpy_matrix(distmat, create_using=nx.DiGraph())
    G.remove_edges_from(nx.selfloop_edges(G)) # remove to avoid self edges
    edgelist = nx.to_edgelist(G)
    edgelist, edge_features = edge_list_to_numpy(edgelist) #get edges and distances
    G = nx.from_numpy_matrix(rankings+1, create_using=nx.DiGraph())
    G.remove_edges_from(nx.selfloop_edges(G))
    edgelist2 = nx.to_edgelist(G)
    edgelist2, edge_rankings = edge_list_to_numpy(edgelist2) # get rankings
    edge_rankings = edge_rankings - 1
    n_edges = edgelist.shape[0]
    coords_nodes = coords[edgelist] #select nodes in edges
    vectors_edges = coords_nodes.transpose(0, 2, 1)[:, :, 1] - coords_nodes.transpose(0, 2, 1)[:, :, 0] # get vector of each edge
    vectors_edges_normalized = vectors_edges/(np.linalg.norm(vectors_edges, axis=1) + eps)[:,None] 
    dipole_moment_normalized = dipole_moment/(norm_dipole + eps)
    angles_dipole_moment = np.dot(vectors_edges_normalized, dipole_moment) # do dot product to find angles
    ranks = np.zeros([n_edges, 6])
    edge_rankings[edge_rankings > 4] = 5 # replace to impose same dimensionality in all molecules
    ranks[np.arange(n_edges), edge_rankings] = 1
    edge_features = np.concatenate([edge_features[:,None], ranks, angles_dipole_moment[:,None]], axis=1) # concatenate all features
    return edgelist, edge_features

In [9]:
def edge_list_to_numpy(edgelist):
    tail = []
    head = []
    weight = []
    for edge in list(edgelist):
        tail.append(edge[0])
        head.append(edge[1])
        weight.append(edge[2]["weight"])
    tail = np.array(tail)[:,None]
    head = np.array(head)[:,None]
    weight = np.array(weight)
    return np.concatenate([tail, head], axis=1), weight


In [10]:
target = pd.read_csv("raw/train.csv")

In [11]:
edges = target["type"].unique()

Exception during reset or similar
Traceback (most recent call last):
  File "/home/pedro/anaconda3/envs/mpi/lib/python3.9/site-packages/sqlalchemy/pool/base.py", line 682, in _finalize_fairy
    fairy._reset(pool)
  File "/home/pedro/anaconda3/envs/mpi/lib/python3.9/site-packages/sqlalchemy/pool/base.py", line 887, in _reset
    pool._dialect.do_rollback(self)
  File "/home/pedro/anaconda3/envs/mpi/lib/python3.9/site-packages/sqlalchemy/engine/default.py", line 667, in do_rollback
    dbapi_connection.rollback()
sqlite3.ProgrammingError: SQLite objects created in a thread can only be used in that same thread. The object was created in thread id 139799173416768 and this is thread id 139799094974208.
IOStream.flush timed out
Exception closing connection <sqlite3.Connection object at 0x7f23d7b58210>
Traceback (most recent call last):
  File "/home/pedro/anaconda3/envs/mpi/lib/python3.9/site-packages/sqlalchemy/pool/base.py", line 682, in _finalize_fairy
    fairy._reset(pool)
  File "/hom

In [12]:
edge_to_int = {edges[i]:i for i in range(len(edges))}

IOStream.flush timed out
IOStream.flush timed out


In [13]:
def get_dihedral(coords, indices):
    a = coords[indices[0],:] - coords[indices[1],:]
    b = coords[indices[2],:] - coords[indices[1],:]
    return a@b/(np.linalg.norm(b) * np.linalg.norm(a))

In [14]:
class SCDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None, training= True):
        super().__init__(root, transform, pre_transform)
        self.filenames = pd.read_csv("raw/processed_names.csv")
        self.charges = pd.read_csv("raw/mulliken_charges.csv")
        self.magnetic_shieldings = pd.read_csv("raw/magnetic_shielding_tensors.csv")
        self.dipole_moments = pd.read_csv("raw/dipole_moments.csv")
        self.potential_energy = pd.read_csv("raw/potential_energy.csv")
        self.target = pd.read_csv("raw/train.csv")
        self.structures = pd.read_csv("raw/structures.csv")
        self.molecule_names = molecule_names = np.unique(self.potential_energy["molecule_name"])
        if training:
            self.training_mask = np.loadtxt("./raw/training_mask2.csv").astype(bool)
            self.molecule_names = self.molecule_names[self.training_mask]
            
    def len(self) -> int:
        return len(self.molecule_names)
    def standarize(self):
        mms = MinMaxScaler([-4, 4])
        self.charges["mulliken_charge"] = mms.fit_transform(self.charges[["mulliken_charge"]]).squeeze()
        self.magnetic_shieldings.iloc[:, 2:] = mms.fit_transform(self.magnetic_shieldings.iloc[:, 2:])
        self.dipole_moments.iloc[:, 1:] = mms.fit_transform(self.dipole_moments.iloc[:, 1:])
        self.potential_energy["potential_energy"] = mms.fit_transform(self.potential_energy[["potential_energy"]]).squeeze()

    def preprocess(self, k = None):
        charges = self.charges
        magnetic_shieldings = self.magnetic_shieldings
        dipole_moments = self.dipole_moments
        potential_energy = self.potential_energy
        molecule_names = self.molecule_names
        target = self.target
        structures = self.structures
        dfs = [charges, magnetic_shieldings, dipole_moments, potential_energy, target, structures]
        for i in range(len(dfs)):
            dfs[i] = dfs[i].set_index("molecule_name", drop=True)
        charges, magnetic_shieldings, dipole_moments, potential_energy, target, structures = dfs
        atoms = structures["atom"].unique()
        atoms_id = {atoms[i]:i for i in range(len(atoms))}
        training_mask = []
        for x, name in enumerate(list(molecule_names)):
            any_training_edges = True
            coords = structures.loc[name][["x", "y", "z"]].to_numpy()
            n_nodes = coords.shape[0]
            print("{}/{}".format(x + 1, len(molecule_names)), end = "\r")
            # adj_mat
            atom_types = structures.loc[name]["atom"].replace(atoms_id).to_numpy()
            atom_onehot = np.zeros([n_nodes, len(atoms)])
            atom_onehot[np.arange(0, n_nodes), atom_types] = 1
            charge = charges.loc[name]["mulliken_charge"].to_numpy()
            shieldings = magnetic_shieldings.loc[name].iloc[:, 2:].to_numpy()
            node_features = np.concatenate([charge[:, None], shieldings, atom_onehot, coords], axis=1)
            with open("./processed2/{}_node_attr.csv".format(name), "wb") as f:
                np.savetxt(f, node_features)
            try:
                edges_target = target.loc[[name]]
                training_mask.append(True)
            except KeyError:
                training_mask.append(False)
                any_training_edges = False
                
            if any_training_edges:
                edges_target["type"] = edges_target["type"].replace(edge_to_int).astype(np.int64)
                scalar_coupling = edges_target.loc[:, ["atom_index_0", "atom_index_1","type","scalar_coupling_constant"]].to_numpy()
            else:
                scalar_coupling = np.array([-1, -1, -1, 0])
            with open("./processed2/{}_target.csv".format(name), "wb") as f:
                np.savetxt(f, scalar_coupling)
            # Graph features
            dipole_moment = dipole_moments.loc[name]
            norm_dipole = np.array([np.linalg.norm(dipole_moment)])
            potential = potential_energy.loc[name]
            graph_features = (np.concatenate([dipole_moment, norm_dipole, potential, np.array([n_nodes]), atom_onehot.sum(axis=0)]))
            with open("./processed2/{}_graph_features.csv".format(name), "wb") as f:
                np.savetxt(f, graph_features)
            # edge_features
            edgelist, edge_attr = get_edge_features(coords, dipole_moment)
            # metaedge
            second_nhood = edgelist[edge_attr[:, 2] == 1]
            second_nhood = np.unique(np.sort(second_nhood, axis=1), axis=0)
            first_nhood = edgelist[edge_attr[:, 1] == 1]
            first_nhood = np.unique(np.sort(first_nhood, axis=1), axis=0)
            is_in_first = [x in first_nhood.tolist() for x in second_nhood.tolist()]
            second_nhood = second_nhood[~np.array(is_in_first)]
            meta_edges = []
            normal_edges = []
            for i in first_nhood:
                for j in first_nhood:
                    meta_edge = np.unique(np.concatenate([i, j]))
                    if len(meta_edge) == 3:
                        for k in second_nhood:
                            is_in_meta = np.in1d(k, meta_edge)
                            if is_in_meta.all():
                                center_node = (np.setdiff1d(meta_edge, k))
                                meta_edges.append(np.array([k[0], center_node[0], k[1]]))
                                normal_edges.append(k)
            if len(meta_edges) > 0:         
                unique_dihedral_edges, is_unique = np.unique(np.stack(normal_edges), axis=0, return_index=True)
                meta_edges_un = np.stack(meta_edges)[is_unique]
                angles = []
                for me in meta_edges_un.astype(np.int64):
                    angles.append(get_dihedral(coords, me))
                dihedral_angles = np.zeros(edgelist.shape[0])
                for i in range(len(angles)):
                    is_edge = np.where(edgelist == unique_dihedral_edges[i], 1, 0).all(axis=1)
                    dihedral_angles[is_edge] = angles[i]
                    is_edge = np.where(edgelist == unique_dihedral_edges[i][::-1], 1, 0).all(axis=1)
                    dihedral_angles[is_edge] = angles[i]
            else:
                dihedral_angles = np.zeros(edgelist.shape[0])
            with open("./processed2/{}_edge_list.csv".format(name), "wb") as f:
                np.savetxt(f, edgelist)
            with open("./processed2/{}_edgeattr.csv".format(name), "wb") as f:
                np.savetxt(f, np.concatenate([edge_attr, dihedral_angles[:,None]], axis=1))
                
    def mem_load(self):
        self.mem = {}
        for i, molecule in enumerate(self.molecule_names):
            graph_features = pd.read_csv("./processed2/{}_graph_features.csv".format(molecule), sep=" ", header=None).to_numpy()
            node_features = pd.read_csv("./processed2/{}_node_attr.csv".format(molecule), sep=" ", header=None).to_numpy()
            atomtypes = node_features[:,-5:].argmax(axis=1)
            prop_atoms = props[atomtypes,:]
            n_nodes = node_features.shape[0]
            graph_features = np.tile(graph_features, [1, n_nodes]).T
            node_features = np.concatenate([prop_atoms, node_features, graph_features], axis=1)
            target =  pd.read_csv("./processed2/{}_target.csv".format(molecule), sep=" ", header=None).to_numpy()
            edge_type = target[:,2]
            edge_type = np.concatenate([edge_type, edge_type], axis=0)
            edges_target = target[:,0:2]
            target = target[:,3]
            target = np.concatenate([target, target])
            edges_target = np.concatenate([edges_target, edges_target[:,::-1]], axis=0)
            edge_list = pd.read_csv("./processed2/{}_edge_list.csv".format(molecule), sep=" ", header=None).to_numpy()
            edge_attr = pd.read_csv("./processed2/{}_edgeattr.csv".format(molecule), sep=" ", header=None).to_numpy()
            data = Data(x=torch.Tensor(node_features), edge_index = torch.Tensor(edge_list).T, y=torch.Tensor(target), edge_attr = torch.Tensor(edge_attr))
            data.nodes_target = torch.Tensor(edges_target)
            data.nodes = n_nodes
            data.edges = edge_list.shape[0]
            data.types = torch.Tensor(edge_type)
            # data.edge_cross = edgelist
            # data.nodes = node_features.shape[0]
            self.mem[molecule] = data
            print("{}/{}".format(i, len(self.molecule_names)), end = "\r")
            
    def __getitem__(self, idx):
        molecule = self.molecule_names[idx]
        return self.mem[molecule]
        

def get_distance_matrix(X, k=None):
    dist = squareform(pdist(X))
    if k is not None:
        non_k = dist.argsort(axis=1)[:, k+1:]
        dist[np.arange(0, dist.shape[0])[:,None], non_k] = 0
    return dist

def to_batch(list_graphs):
    n_nodes = 0
    for graph in list_graphs:
        graph["nodes_target"] += n_nodes
        n_nodes += graph.nodes
    return Batch.from_data_list(list_graphs) 

In [15]:
dataset = SCDataset('/media/pedro/data/projects/scalar_coupling', training=False)
dataset.standarize()
dataset.preprocess()

In [16]:
dataset = SCDataset('/media/pedro/data/projects/scalar_coupling', training=True)

In [17]:
dataset.mem_load()

In [19]:
train_length = int(len(dataset) * 0.9)
test_length = len(dataset) - train_length
train_set, val_set = torch.utils.data.random_split(dataset, [train_length, test_length])
train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=128, collate_fn=to_batch, shuffle=True, num_workers=20)
test_dataloader = torch.utils.data.DataLoader(val_set, batch_size=128, collate_fn=to_batch, shuffle=True, num_workers=20)

In [20]:
from torch_geometric.nn import GCNConv, GATv2Conv, GATConv, SAGEConv, PDNConv
import torch.nn as nn
import torch.nn.functional as F

def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)


In [21]:
class GATv2EncoderGated(nn.Module):
    def __init__(self, num_node_features, hidden_features, heads, n_layers, p_dropout):
        super().__init__()
        self.p_dropout = p_dropout
        assert n_layers > 1
        self.init_conv = GATv2Conv(num_node_features, hidden_features, heads=n_heads, dropout=p_dropout,  edge_dim=9, concat=False)
        self.layers = nn.ModuleList([GATv2Conv(hidden_features, hidden_features, heads=n_heads, dropout=p_dropout,  edge_dim=9, concat=False) for i in range(n_layers-1)])
        self.gates = nn.Parameter(torch.Tensor(n_layers))
        self.init_conv.apply(init_weights)
        for conv in self.layers:
            conv.apply(init_weights)
    def forward(self, x, edge_index, edge_attr):
        range_gates = torch.sigmoid(self.gates)
        x = self.init_conv(x, edge_index, edge_attr)
        for i, layer in enumerate(self.layers):
            x = F.leaky_relu(x)
            x = (range_gates[i])*layer(x, edge_index, edge_attr) + (1-range_gates[i])*x
        return x
    
class ResNetGated(nn.Module):
    def __init__(self, init_dim, hidden_dim, layers, p_dropout):
        super().__init__()
        self.p_dropout = p_dropout
        assert n_layers > 1
        self.layers = nn.ModuleList([nn.Sequential(nn.Linear(init_dim, hidden_dim),
                             nn.ReLU(),
                             nn.Dropout(p=p_dropout),
                             nn.Linear( hidden_dim, init_dim)) for i in range(layers)])
        self.gates = nn.Parameter(torch.Tensor(n_layers))
        self.layers.apply(init_weights)
    def forward(self, x):
        range_gates = torch.sigmoid(self.gates)
        for i, layer in enumerate(self.layers):
            x = F.relu(x)
            x = (range_gates[i])*layer(x) + (1-range_gates[i])*x
        return x

    
class GCN(torch.nn.Module):
    def __init__(self, num_node_features, out_features, n_heads, n_layers, n_res, p_dropout):
        super().__init__()
        self.conv = GATv2EncoderGated(num_node_features, out_features, heads=n_heads, p_dropout=p_dropout,  n_layers=n_layers)
        self.fcs = nn.ModuleList([nn.Sequential(ResNetGated(out_features*2, out_features*64, n_res, p_dropout),
                               nn.Linear(2 * out_features, 1)) for i in range(8)])
        for fc in self.fcs:
            fc.apply(init_weights)
    def forward(self, x, edge_index, edge_attr, edge_cross, types):
        x = self.conv(x, edge_index, edge_attr)
        x = x[edge_cross]
        shp = x.shape
        x = x.transpose(1, 2).reshape([shp[0], shp[2]*2])
        xs = []
        for i in range(8):
            xs.append(self.fcs[i](x[types == i]))
        x = torch.concat(xs, axis=0)
        return x

In [22]:
### Define the loss function
loss_fn = nn.MSELoss

lr= 0.00078
weight_decay = 1.76755e-08
p_dropout = 0.0343089
conv_features = 128
n_heads = 3
n_layers = 3
n_res = 3
### Set the random seed for reproducible results
torch.manual_seed(0)

gcn = GCN(37, conv_features, n_heads, n_layers, n_res, p_dropout=p_dropout)
params_to_optimize = [
    {'params': gcn.parameters()}
]

optim = torch.optim.Adam(params_to_optimize, lr=lr, weight_decay=weight_decay)
# Check if the GPU is available
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'Selected device: {device}')

# Move both the encoder and the decoder to the selected device
gcn.to(device)

Selected device: cuda


GCN(
  (conv): GATv2EncoderGated(
    (init_conv): GATv2Conv(37, 128, heads=3)
    (layers): ModuleList(
      (0): GATv2Conv(128, 128, heads=3)
      (1): GATv2Conv(128, 128, heads=3)
    )
  )
  (fcs): ModuleList(
    (0): Sequential(
      (0): ResNetGated(
        (layers): ModuleList(
          (0): Sequential(
            (0): Linear(in_features=256, out_features=8192, bias=True)
            (1): ReLU()
            (2): Dropout(p=0.0343089, inplace=False)
            (3): Linear(in_features=8192, out_features=256, bias=True)
          )
          (1): Sequential(
            (0): Linear(in_features=256, out_features=8192, bias=True)
            (1): ReLU()
            (2): Dropout(p=0.0343089, inplace=False)
            (3): Linear(in_features=8192, out_features=256, bias=True)
          )
          (2): Sequential(
            (0): Linear(in_features=256, out_features=8192, bias=True)
            (1): ReLU()
            (2): Dropout(p=0.0343089, inplace=False)
            (3): L

In [23]:
import datetime
date = datetime.datetime.today().strftime('%Y-%m-%d-%H:%M:%S')

RUN = date + "_lr={}_wd={}_p={}_conv_features={}_n_layers={}_n_res={}".format(lr,
                                                            weight_decay,
                                                            p_dropout,
                                                            conv_features,
                                                             n_layers,
                                                            n_res
                                                            )
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("./runs/{}".format(RUN))

In [24]:
def train(model, device, data ,loss_fn, optimizer, batch_acc=2):
    model.train()
    train_losses = []
    optimizer.zero_grad()
    for i, batch in enumerate(data):
        x, edge_index, edge_attr, target, edge_cross, types = (batch["x"],
                                                               batch["edge_index"],
                                                               batch["edge_attr"],
                                                               batch["y"],
                                                               batch["nodes_target"],
                                                               batch["types"])
        types_cpu = types.numpy()
        sort_index = torch.Tensor(types.numpy().argsort(kind="stable")).long()
        target = target[sort_index]
        x, edge_index, edge_attr, target, edge_cross, types = x.to(device), \
                                                            edge_index.long().to(device), \
                                                            edge_attr.to(device), \
                                                            target.to(device),\
                                                            edge_cross.long().to(device), \
                                                            types.long().to(device)
        logits = model(x, edge_index, edge_attr, edge_cross, types)
        loss = loss_fn()
        output=loss(logits.squeeze(), target.squeeze())
        output.backward()
        if ((i+1)%batch_acc) == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 3)
            optimizer.step()
        train_loss = output.data.cpu().numpy()
        train_losses.append(train_loss)
        
    return np.mean(train_losses)

### Testing function
def test(model, device, data, loss_fn):
    # Set evaluation mode for encoder and decoder
    model.eval()
    test_losses = []
    with torch.no_grad(): # No need to track the gradients
        for i, batch in enumerate(data):
            x, edge_index, edge_attr, target, edge_cross, types = (batch["x"],
                                                               batch["edge_index"],
                                                               batch["edge_attr"],
                                                               batch["y"],
                                                               batch["nodes_target"],
                                                               batch["types"])
            types_cpu = types.numpy()
            sort_index = torch.Tensor(types.numpy().argsort(kind="stable")).long()
            target = target[sort_index]
            x, edge_index, edge_attr, target, edge_cross, types = x.to(device), \
                                                                edge_index.long().to(device), \
                                                                edge_attr.to(device), \
                                                                target.to(device),\
                                                                edge_cross.long().to(device), \
                                                                types.long().to(device)
            logits = model(x, edge_index, edge_attr, edge_cross, types)
            loss = loss_fn()
            output=loss(logits.squeeze(), target.squeeze())
            test_loss = output.data.cpu().numpy()
            test_losses.append(test_loss)
    return np.mean(test_losses)

In [25]:
best_loss = 1000
num_epochs = 100
diz_loss = {'train_loss':[],'val_loss':[]}
decay = 0.95
for epoch in range(num_epochs):
#     if epoch == 0:
#         for param in optim.param_groups:
#             param["lr"] = param["lr"]/(10e3)
#     elif epoch <= 3:
#         for param in optim.param_groups:
#             param["lr"] = param["lr"]*10
    train_loss = train(gcn, device, train_dataloader, loss_fn, optim)
    test_loss = test(gcn, device, test_dataloader, loss_fn)
    print('\n EPOCH {}/{} \t train loss {} \t \t val loss {}'.format(epoch + 1, num_epochs, train_loss, test_loss))
    diz_loss['train_loss'].append(train_loss)
    diz_loss['val_loss'].append(test_loss)
    writer.add_scalar('Loss/train', train_loss, epoch)
    writer.add_scalar('Loss/test', test_loss, epoch)
    if test_loss < best_loss:
        best_loss = test_loss
        torch.save(gcn.state_dict(), "./saved_models/with_resnet_{}.pth".format(RUN))


 EPOCH 1/100 	 train loss 354.1801452636719 	 	 val loss 35.28205490112305

 EPOCH 2/100 	 train loss 29.1060791015625 	 	 val loss 27.516782760620117

 EPOCH 3/100 	 train loss 21.970190048217773 	 	 val loss 17.692537307739258

 EPOCH 4/100 	 train loss 18.177515029907227 	 	 val loss 16.3250675201416

 EPOCH 5/100 	 train loss 13.73880386352539 	 	 val loss 18.16177749633789

 EPOCH 6/100 	 train loss 11.553083419799805 	 	 val loss 9.14149284362793

 EPOCH 7/100 	 train loss 9.843188285827637 	 	 val loss 8.704882621765137

 EPOCH 8/100 	 train loss 7.67933464050293 	 	 val loss 6.186276912689209

 EPOCH 9/100 	 train loss 6.731311798095703 	 	 val loss 5.988966941833496

 EPOCH 10/100 	 train loss 5.772371768951416 	 	 val loss 5.999088764190674

 EPOCH 11/100 	 train loss 5.645722389221191 	 	 val loss 4.658374309539795

 EPOCH 12/100 	 train loss 4.833165168762207 	 	 val loss 3.7202491760253906

 EPOCH 13/100 	 train loss 4.319706439971924 	 	 val loss 3.936246871948242

 EPOC

In [26]:
def train(model, device, data ,loss_fn, optimizer, batch_acc=1):
    model.train()
    train_losses = []
    optimizer.zero_grad()
    for i, batch in enumerate(data):
        x, edge_index, edge_attr, target, edge_cross, types = (batch["x"],
                                                               batch["edge_index"],
                                                               batch["edge_attr"],
                                                               batch["y"],
                                                               batch["nodes_target"],
                                                               batch["types"])
        types_cpu = types.numpy()
        sort_index = torch.Tensor(types.numpy().argsort(kind="stable")).long()
        target = target[sort_index]
        x, edge_index, edge_attr, target, edge_cross, types = x.to(device), \
                                                            edge_index.long().to(device), \
                                                            edge_attr.to(device), \
                                                            target.to(device),\
                                                            edge_cross.long().to(device), \
                                                            types.long().to(device)
        logits = model(x, edge_index, edge_attr, edge_cross, types)
        loss = loss_fn()
        output=loss(logits.squeeze(), target.squeeze())
        output.backward()
        if batch_acc == 1:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 3)
            optimizer.step()
        elif (i%batch_acc == 0) and (i != 0):
            torch.nn.utils.clip_grad_norm_(model.parameters(), 3)
            optimizer.step()
        train_loss = output.data.cpu().numpy()
        train_losses.append(train_loss)
        
    return np.mean(train_losses)

In [27]:
batch_acc = 1
for epoch in range(num_epochs):
    if epoch % 20 == 0:
        batch_acc *= 2
        for param in optim.param_groups:
            param["weight_decay"] = param["weight_decay"]/2
    train_loss = train(gcn, device, train_dataloader, loss_fn, optim, batch_acc)
    test_loss = test(gcn, device, test_dataloader, loss_fn)
    print('\n EPOCH {}/{} \t train loss {} \t \t val loss {}'.format(epoch + 1, num_epochs, train_loss, test_loss))
    diz_loss['train_loss'].append(train_loss)
    diz_loss['val_loss'].append(test_loss)
    if test_loss < best_loss:
        best_loss = test_loss
        torch.save(gcn.state_dict(), "./saved_models/with_resnet_{}.pth".format(RUN))


 EPOCH 1/100 	 train loss 0.530784547328949 	 	 val loss 0.6098902225494385

 EPOCH 2/100 	 train loss 0.4895051121711731 	 	 val loss 0.5317975282669067

 EPOCH 3/100 	 train loss 0.44453608989715576 	 	 val loss 0.5469436645507812

 EPOCH 4/100 	 train loss 0.47759273648262024 	 	 val loss 0.6072366237640381

 EPOCH 5/100 	 train loss 0.4598345160484314 	 	 val loss 0.5116789937019348

 EPOCH 6/100 	 train loss 0.5224311351776123 	 	 val loss 0.6302801966667175

 EPOCH 7/100 	 train loss 0.45595717430114746 	 	 val loss 0.545138955116272

 EPOCH 8/100 	 train loss 0.4223570227622986 	 	 val loss 0.4845689833164215

 EPOCH 9/100 	 train loss 0.45652061700820923 	 	 val loss 0.6576845645904541

 EPOCH 10/100 	 train loss 0.4898805022239685 	 	 val loss 0.5497469305992126

 EPOCH 11/100 	 train loss 0.4196653962135315 	 	 val loss 0.4888512194156647

 EPOCH 12/100 	 train loss 0.41594198346138 	 	 val loss 0.4935818612575531

 EPOCH 13/100 	 train loss 0.5658427476882935 	 	 val loss 0

In [28]:
print(RUN)

2022-02-26-23:59:40_lr=0.00078_wd=1.76755e-08_p=0.0343089_conv_features=128_n_layers=3_n_res=3


In [29]:
# with open("./train_dataloader.pkl", "wb") as f:
#     pickle.dump(train_dataloader, f)
# with open("./test_dataloader.pkl", "wb") as f:
#     pickle.dump(test_dataloader, f)