In [1]:
import pandas as pd
import numpy as np
import os
import torch
import torch_geometric as tg
from torch_geometric.loader import DataLoader
from torch.utils.data import Subset
import sys
from tqdm import tqdm

# Directory the process was started from; the interactive configuration uses
# it as the project root (`parent`).
wd = os.getcwd()

def is_interactive():
    """Tell whether the code runs without a script file behind ``__main__``.

    Returns:
        bool: True when the ``__main__`` module has no ``__file__`` attribute,
        False when it has one.
    """
    import __main__ as entry_module
    try:
        entry_module.__file__
    except AttributeError:
        return True
    return False
    
# Run configuration.
# In an interactive session the hard-coded values below are used; otherwise
# the five settings are read, in this order, from the command line:
#   MODEL TRAIN_INDEX TRAIN_RATIO LAYERS PARENT
if is_interactive():
    model, train_index, train_ratio, layers, parent = "MACE", "all-10", "1", "4", wd
    print("Interactive session")
else:
    cli_args = sys.argv[1:]
    if len(cli_args) != 5:
        # Fail with a usage hint rather than the opaque tuple-unpacking error.
        raise ValueError(
            "expected 5 arguments: model train_index train_ratio layers parent "
            f"(got {len(cli_args)}: {cli_args})"
        )
    model, train_index, train_ratio, layers, parent = cli_args

# All settings arrive as strings; convert the numeric ones.
train_ratio = float(train_ratio)
layers = int(layers)

# Single-precision tensors by default and a fixed torch seed for repeatable runs.
torch.set_default_dtype(torch.float32)
torch.manual_seed(0)

# Prefer the first CUDA device when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
print('torch device:' , device)

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


Interactive session
torch device: cuda:0


In [2]:
import sys

# Put the project root and its `geometricgnn` subdirectory on the module
# search path; this must happen before the project-local imports below.
sys.path.extend([f'{parent}', f'{parent}/geometricgnn'])

from geometricgnn.src import models as gnn_models
from geometricgnn.src import data as data_func
from geometricgnn.src.utils.train_utility import save_ckp, PearsonCC, train, validate

In [3]:
# Instantiate the project's GlassDynDataset with `parent` as its root directory.
dataset_root = f"{parent}"
dataset = data_func.GlassDynDataset(root=dataset_root)

In [4]:
# Build the train / validation / test splits.
#
# The dataset is treated as 8 consecutive, equally sized groups (the `/ 8`
# below).  NOTE(review): what a group represents is not visible in this
# notebook -- confirm against GlassDynDataset.
#   "all"    : shuffle every sample with a fixed seed, then split 80/10/10.
#   "all-10" : groups 1..7 (group 0 is skipped), each split 80/10/10 in order.
#   <int>    : one single group, split 80/10/10 in order.
# `train_ratio` shrinks only the training part; validation and test always
# keep their full 10% share.

def _per_group_indices(group_len, first, last, groups=range(1, 8)):
    """Return a flat index array covering [first, last) inside each group.

    Args:
        group_len: number of samples per group (offset between groups).
        first: start offset inside a group (inclusive).
        last: end offset inside a group (exclusive).
        groups: group numbers to include.

    Returns:
        1-D integer numpy array, ordered group by group.
    """
    return np.array(
        [list(range(first + group_len*g, last + group_len*g)) for g in groups],
        dtype=int,  # keeps an integer dtype even when the ranges are empty
    ).flatten()


if train_index in ("all", "all-10"):
    if train_index == "all":
        shuffle_index = np.arange(len(dataset))
        shuffle_index = np.random.default_rng(seed=0).permutation(shuffle_index)
        n_samples = len(shuffle_index)
        index_train = shuffle_index[:round(0.8*n_samples*train_ratio)]
        index_val = shuffle_index[round(0.8*n_samples):round(0.9*n_samples)]
        index_test = shuffle_index[round(0.9*n_samples):]
    else:
        num_config = len(dataset) / 8
        group_len = int(num_config)
        val_start = int(num_config*0.8)
        test_start = int(num_config*0.9)
        # Honour `train_ratio` here as the other two modes do; it used to be
        # silently ignored in this branch.  Identical result for train_ratio == 1.
        train_stop = int(num_config*0.8*train_ratio)

        index_train = _per_group_indices(group_len, 0, train_stop)
        index_val = _per_group_indices(group_len, val_start, test_start)
        index_test = _per_group_indices(group_len, test_start, group_len)

    train_dataset = Subset(dataset, index_train)
    val_dataset = Subset(dataset, index_val)
    test_dataset = Subset(dataset, index_test)

    # NOTE(review): multi-group modes use one more input feature and enable
    # time_features -- presumably the extra feature identifies the group; confirm.
    in_dim = 4
    time_features = True

else:
    train_index = int(train_index)
    n_samples = len(dataset)
    # Fractional position where the selected group starts.
    group_start = train_index * n_samples / 8
    train_dataset = dataset[round(group_start) : round(group_start + train_ratio * 0.8 * n_samples / 8)]
    val_dataset = dataset[round(group_start + 0.8* n_samples / 8) : round(group_start + 0.9* n_samples / 8)]
    test_dataset = dataset[round(group_start + 0.9* n_samples / 8) : round(group_start + n_samples / 8)]
    in_dim = 3
    time_features = False
    
# One loader per split.  Only the training loader follows `batch_size`;
# validation and test are always iterated one graph at a time.
# NOTE(review): no `shuffle=True` is passed, so training batches come in
# dataset order -- confirm this is intended.
batch_size = 1
train_loader = DataLoader(train_dataset, batch_size=batch_size)
val_loader, test_loader = (
    DataLoader(split, batch_size=1) for split in (val_dataset, test_dataset)
)

# Look at the first test batch only and record, per chemical symbol, the
# positions at which that symbol occurs in the flattened `symbol` attribute.
for data in test_loader:
    symbols = np.array(data.symbol).flatten()
    Na_ind = np.where(symbols == "Na")
    Si_ind = np.where(symbols == "Si")
    O_ind = np.where(symbols == "O")
    break

def _tile_over_batch(first_graph_idx):
    """Repeat per-graph positions for every graph in a training batch.

    NOTE(review): the offset 3000 is presumably the number of atoms per
    configuration -- confirm against the dataset.
    """
    return np.hstack([first_graph_idx + b*3000 for b in range(batch_size)])

Na_ind_batch = _tile_over_batch(Na_ind[0])
Si_ind_batch = _tile_over_batch(Si_ind[0])
O_ind_batch = _tile_over_batch(O_ind[0])

# Number of output values per prediction.
out_dim = 1

In [5]:
# Replace the model *name* held in `model` by the corresponding network.
# All variants share `in_dim`, `layers` and `time_features` from the cells
# above; the remaining hyper-parameters are fixed here.
if model == "TFN":
    model = gnn_models.TFNModel(
        in_dim=in_dim,
        emb_dim=16,
        out_dim=1,
        max_ell=2,
        num_layers=layers,
        r_max=5,
        time_features=time_features,
        avg_num_neighbors = 38
    )

elif model == "Schnet":
    model = gnn_models.SchNetModel(
        in_dim=in_dim,
        out_dim=1,
        num_layers=layers,
        hidden_channels=16,
        num_filters = 16, 
        num_gaussians = 16, 
        cutoff=5,
        time_features=time_features,
    )

elif model == "MACE":
    model = gnn_models.MACEModel(
        in_dim=in_dim,
        emb_dim=16,
        out_dim=1,
        max_ell=2,
        correlation=3,
        num_layers=layers,
        r_max=5,
        time_features=time_features,
        avg_num_neighbors = 38
    )

elif model == "EGNN":
    model = gnn_models.EGNNModel(
        num_layers=layers,
        emb_dim=16,
        in_dim=in_dim,
        out_dim=1,
        activation="softplus",
        norm="layer",
        aggr="sum",
        residual=True,
        time_features=time_features,
    )

elif isinstance(model, str):
    # An unrecognised name would otherwise leave `model` as a plain string and
    # only fail later with an unrelated AttributeError.  A non-string `model`
    # (this cell re-run after the network was built) is left untouched.
    raise ValueError(
        f"Unknown model {model!r}; expected one of 'TFN', 'Schnet', 'MACE', 'EGNN'"
    )



In [6]:
# Mean-squared-error objective.
loss_fn = torch.nn.MSELoss()

# Device selection (same rule as in the first cell).
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
print('torch device:' , device)

# Adam (AMSGrad variant) with a small weight decay.
# NOTE(review): 10e-8 equals 1e-7 -- confirm that 1e-8 was not intended.
opt = torch.optim.Adam(
    model.parameters(),
    lr=0.005,
    betas=(0.90, 0.999),
    weight_decay=10e-8,
    amsgrad=True,
)
# Multiply the learning rate by 0.75 once the monitored value has stopped
# decreasing for more than 10 scheduler steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.75, patience=10)

torch device: cuda:0


In [None]:
# Training loop with checkpointing, a plain-text loss log and early stopping.
model = model.to(device)

MAX_EPOCHS = 1000          # hard upper bound on the number of epochs
EARLY_STOP_PATIENCE = 100  # stop once this many epochs pass without a new best

# Start from +inf so the first epoch always becomes the initial best, and bind
# `best_epoch` up front: previously it was only assigned once the validation
# loss dropped below 10, so a larger first loss raised NameError at the
# early-stopping check.
lowest_loss = float("inf")
best_epoch = 0

for e in tqdm(range(MAX_EPOCHS)):
    train_loss = train(model, opt, loss_fn, train_loader, device)
    val_loss = validate(model, loss_fn, val_loader, device)
    # The plateau scheduler monitors the validation loss.
    scheduler.step(val_loss)

    state = {
        'epoch': e + 1,
        'state_dict': model.state_dict(),
        'optimizer': opt.state_dict(),
        'scheduler': scheduler.state_dict()
    }

    # A checkpoint is written every epoch; the flag marks a new best
    # validation loss (presumably save_ckp keeps a separate "best" copy --
    # confirm in train_utility).
    is_best = bool(val_loss < lowest_loss)
    save_ckp(state, is_best)
    if is_best:
        lowest_loss = val_loss
        best_epoch = e

    # Append after every epoch so the log survives an interrupted run.
    with open('log.dat', mode='a') as file:
        file.write(f"{train_loss:.4f}, {val_loss:.4f}\n")

    if e > best_epoch + EARLY_STOP_PATIENCE:
        break

    print(f"Epoch: {e}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

  0%|▏                                                                                                                                                                                                | 1/1000 [16:07<268:33:28, 967.78s/it]

Epoch: 0, Train Loss: 0.2874, Val Loss: 4.8935
