In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to the imported
# project modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2
import os

# Switch to the project root, two levels above the kernel's starting directory, so the
# relative 'modules/...' and 'datasets/...' references in later cells resolve.
# The starting directory is recorded on the first run only, which keeps this cell
# idempotent: re-executing it returns to the same root instead of climbing two more
# levels each time (which a bare os.chdir("../..") would do).
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, "..", ".."))
print(os.getcwd())
import torch
import numpy as np
import random

# NOTE(review): this cuBLAS workspace setting is the one PyTorch's reproducibility
# notes ask for when deterministic algorithms are enabled. Determinism is switched off
# below, so it presumably has no effect here -- confirm whether it is still wanted.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

# Seed every RNG in use (NumPy, Python's `random`, torch CPU and all CUDA devices)
# with the same value.
seed = 21
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Deterministic algorithms are explicitly left off and cuDNN benchmarking is on, so
# GPU runs may still differ between executions despite the fixed seeds.
torch.use_deterministic_algorithms(False)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

In [None]:
from modules.data_pipeline import DataPipeline

# Input files, given relative to the current working directory.
components_csv_path = 'datasets/components.csv'
raw_csv_path = 'datasets/dataset.csv'

# Construct the pipeline from the components CSV, then run it on the raw dataset CSV.
# run_pipeline's result is unpacked as the pair (canonical_data, graph_list).
pipeline = DataPipeline(components_csv=components_csv_path)
pipeline_outputs = pipeline.run_pipeline(raw_csv=raw_csv_path)
canonical_data, graph_list = pipeline_outputs

In [None]:
import modules.datasplit_module as dsm
# --- Split graphs ---
# Shuffle a *copy* of graph_list with a dedicated RNG seeded from `seed`, instead of
# shuffling graph_list in place with the global RNG. The in-place version compounded
# the permutation on every re-execution of this cell, so the resulting split depended
# on how many times the cell had been run. Now graph_list is left untouched and the
# cell yields the same ordering on every run.
shuffle_rng = random.Random(seed)
sampled_graph_list = list(graph_list)
shuffle_rng.shuffle(sampled_graph_list)

# Split the shuffled graphs into train / validation / test sets.
train, val, test = dsm.system_disjoint_split(
    sampled_graph_list,
    random_state=seed,
    stratify_by_components=True,
)

In [None]:
from torch_geometric.loader import DataLoader


def _build_loader(graphs, shuffle):
    """Wrap the first 100 entries of `graphs` in a DataLoader with batch size 1024.

    `shuffle` is forwarded unchanged; every loader follows 'component_batch'.
    """
    return DataLoader(
        dataset=graphs[:100],
        batch_size=1024,
        shuffle=shuffle,
        follow_batch=['component_batch'],
    )


# Only the training loader shuffles; validation and test keep their given order.
train_loader = _build_loader(train, shuffle=True)
val_loader = _build_loader(val, shuffle=False)
test_loader = _build_loader(test, shuffle=False)

In [None]:
import modules.trainer_module as tm
import modules.dtmpnn as gm
device = 'cuda' if torch.cuda.is_available() else 'cpu'

#=======Architecture setting=======
constraint_type = 'soft'  # 'hard' or 'soft'
include_gd = True
gd_weight = 1

# With the hard constraint include_gd is always switched off; for any other
# constraint type the value chosen above is kept as is.
if constraint_type == 'hard':
    include_gd = False
#==================================

# Single output directory for this run; save_dir, the log file and the history plot
# are all derived from it so the three paths cannot drift apart.
run_dir = f'notebooks/training_phase/__debug__/{constraint_type}_constraint/GD_backprop_{include_gd}'

# Create model
# Node and edge feature widths are read from the first training graph.
model = gm.DTMPNN(
    node_dim=train[0].x.shape[1],
    edge_dim=train[0].edge_attr.shape[1],
    graph_hidden_dim=64,
    latent_dim=64,
    context_dim=64,
    graph_layers=3,
    constraint_type=constraint_type
).to(device)

# Initialize trainer
trainer = tm.DTMPNNTrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    include_gd=include_gd,
    device=device,
    lr=0.002,
    weight_decay=1e-07,
    gd_weight=gd_weight
)

# Train the model
# NOTE(review): save_every (25) is larger than epochs (10); if it is a periodic
# save interval it never triggers in this run -- confirm that is intended.
history = trainer.train(
    epochs=10,
    save_dir=run_dir,
    log_file_path=f'{run_dir}/log.txt',
    save_best=True,
    save_every=25
)

# Plot training curves
trainer.plot_history(save_path=f'{run_dir}/history.png')
# Release cached CUDA memory once training and plotting are done.
torch.cuda.empty_cache()