In [None]:
import torch
from torch_geometric.loader import DataLoader
from sklearn.model_selection  import train_test_split

from tools.utils import *
from tools.het_networks import *

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


In [22]:
create_metrics_table

NameError: name 'create_metrics_table' is not defined

In [None]:
# Random
seed = 42

# Model
hidden_channels = 128
num_layers = 2
intra_aggr='sum'
inter_aggr='mean'
dropout = 0.5

# Training
batch_size = 8
epochs = 15
lr = 1e-4
maxlr = 1e-3

In [None]:
# Load dataset
het_dataset = torch.load('data/PSCDB/het_pscdb_graphs.pt', weights_only=False)
len(het_dataset)

856

In [None]:
def normalize_and_recompute_displacement(hetero_data):
    hetero_data = hetero_data.clone()
    for node_type in hetero_data.node_types:
        node_data = hetero_data[node_type]
        if hasattr(node_data, 'x') and node_data.x is not None:
            x = node_data.x
            if x.size(1) >= 6:
                # Normalize free and bound coordinates
                coords = x[:, :6]
                mean = coords.mean(dim=0)
                std = coords.std(dim=0, unbiased=False)
                std[std == 0] = 1.0
                normalized_coords = (coords - mean) / std
                x[:, :6] = normalized_coords
                # Recompute displacement as (bound_normalized - free_normalized)
                free_normalized = normalized_coords[:, :3]
                bound_normalized = normalized_coords[:, 3:6]
                displacement_normalized = bound_normalized - free_normalized
                x[:, 6:9] = displacement_normalized

                hetero_data[node_type].x = x
    return hetero_data

# Normalize features
normalized_het_dataset = [normalize_and_recompute_displacement(het_data) for het_data in het_dataset]

  std = coords.std(dim=0, unbiased=False)


In [None]:
# Create splits
labels = [data.y.item() for data in normalized_het_dataset]

train_set, temp_set = train_test_split(
    normalized_het_dataset, 
    test_size=0.3, 
    stratify=labels,
    random_state=seed
)

temp_labels = [data.y.item() for data in temp_set]
valid_set, test_set = train_test_split(
    temp_set, 
    test_size=0.5, 
    stratify=temp_labels,
    random_state=seed
)

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

In [None]:
edge_types = het_dataset[0].edge_types
node_types = het_dataset[0].node_types
metadata = (node_types, edge_types)

(['A',
  'C',
  'D',
  'E',
  'F',
  'G',
  'H',
  'I',
  'K',
  'L',
  'M',
  'N',
  'P',
  'Q',
  'R',
  'S',
  'T',
  'V',
  'W',
  'Y'],
 [('A', 'edge_index_free', 'A'),
  ('A', 'edge_index_bound', 'A'),
  ('A', 'edge_index_free', 'C'),
  ('A', 'edge_index_bound', 'C'),
  ('A', 'edge_index_free', 'D'),
  ('A', 'edge_index_bound', 'D'),
  ('A', 'edge_index_free', 'E'),
  ('A', 'edge_index_bound', 'E'),
  ('A', 'edge_index_free', 'F'),
  ('A', 'edge_index_bound', 'F'),
  ('A', 'edge_index_free', 'G'),
  ('A', 'edge_index_bound', 'G'),
  ('A', 'edge_index_free', 'H'),
  ('A', 'edge_index_bound', 'H'),
  ('A', 'edge_index_free', 'I'),
  ('A', 'edge_index_bound', 'I'),
  ('A', 'edge_index_free', 'K'),
  ('A', 'edge_index_bound', 'K'),
  ('A', 'edge_index_free', 'L'),
  ('A', 'edge_index_bound', 'L'),
  ('A', 'edge_index_free', 'M'),
  ('A', 'edge_index_bound', 'M'),
  ('A', 'edge_index_free', 'N'),
  ('A', 'edge_index_bound', 'N'),
  ('A', 'edge_index_free', 'P'),
  ('A', 'edge_index_bo

# Exp 1

In [None]:
experiment_name = f"HeteroGNN_GraphConv-{hidden_channels} hidden channels-{num_layers} mlp-{num_layers} conv-{intra_aggr} intra_aggr-{inter_aggr} inter_aggr-{dropout} dropout-{lr} lr-{maxlr} maxlr-OneCylceLR-Adam-CE Loss"

model = HeteroGNN_GraphConv(metadata, hidden_channels, mlp_layers=num_layers, conv_layers=num_layers, intra_aggr=intra_aggr, inter_aggr=inter_aggr, dropout=dropout).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = torch.nn.CrossEntropyLoss()
model

HeteroGNN(
  (conv_blocks): ModuleList(
    (0-1): 2 x ModuleDict(
      (conv): HeteroConv(num_relations=420)
      (post_lin): ModuleDict(
        (A): MLP(
          (layers): ModuleList(
            (0): Linear(-1, 32, bias=True)
            (1): Dropout(p=0.5, inplace=False)
            (2): Linear(32, 32, bias=True)
          )
        )
        (C): MLP(
          (layers): ModuleList(
            (0): Linear(-1, 32, bias=True)
            (1): Dropout(p=0.5, inplace=False)
            (2): Linear(32, 32, bias=True)
          )
        )
        (D): MLP(
          (layers): ModuleList(
            (0): Linear(-1, 32, bias=True)
            (1): Dropout(p=0.5, inplace=False)
            (2): Linear(32, 32, bias=True)
          )
        )
        (E): MLP(
          (layers): ModuleList(
            (0): Linear(-1, 32, bias=True)
            (1): Dropout(p=0.5, inplace=False)
            (2): Linear(32, 32, bias=True)
          )
        )
        (F): MLP(
          (layers): M

In [None]:
batches_per_epoch = len(train_loader)
total_steps = epochs * batches_per_epoch

scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=maxlr, total_steps=total_steps, epochs=epochs, cycle_momentum=False)

In [15]:
metrics = {
    'train_loss': [],
    'valid_loss': [],
    'train_acc': [],
    'valid_acc': [],
    'train_f1': [],
    'valid_f1': []
}

for epoch in range(1, epochs+1):
    train_loss, train_acc, train_f1 = train(model, train_loader, optimizer, criterion, het_predict, scheduler=scheduler, device=device)
    valid_loss, valid_acc, valid_f1 = test(model, test_loader, criterion, het_predict, device=device)
    
    # Update metrics
    metrics['train_loss'].append(train_loss)
    metrics['valid_loss'].append(valid_loss)
    metrics['train_acc'].append(train_acc)
    metrics['valid_acc'].append(valid_acc)
    metrics['train_f1'].append(train_f1)
    metrics['valid_f1'].append(valid_f1)
    
    print(f"Epoch {epoch}/{epochs}")
    print(f"Train Loss: {train_loss:.4f} | Validation Loss: {valid_loss:.4f}")
    print(f"Train Acc: {train_acc:.4f} | Validation Acc: {valid_acc:.4f}")
    print(f"Train F1: {train_f1:.4f} | Validation F1: {valid_f1:.4f}\n")

Epoch 1/15
Train Loss: 3.2948 | Validation Loss: 3.0728
Train Acc: 0.1987 | Validation Acc: 0.3488
Train F1: 0.0626 | Validation F1: 0.0739

Epoch 2/15
Train Loss: 2.1901 | Validation Loss: 1.8609
Train Acc: 0.3356 | Validation Acc: 0.3488
Train F1: 0.0730 | Validation F1: 0.0739

Epoch 3/15
Train Loss: 1.8797 | Validation Loss: 1.9210
Train Acc: 0.3406 | Validation Acc: 0.3101
Train F1: 0.0764 | Validation F1: 0.1036

Epoch 4/15
Train Loss: 1.8533 | Validation Loss: 1.8717
Train Acc: 0.3372 | Validation Acc: 0.3256
Train F1: 0.0987 | Validation F1: 0.0954

Epoch 5/15
Train Loss: 1.8119 | Validation Loss: 1.8570
Train Acc: 0.3372 | Validation Acc: 0.3411
Train F1: 0.0887 | Validation F1: 0.0731

Epoch 6/15
Train Loss: 1.8016 | Validation Loss: 1.8325
Train Acc: 0.3406 | Validation Acc: 0.3488
Train F1: 0.0831 | Validation F1: 0.0739

Epoch 7/15
Train Loss: 1.7642 | Validation Loss: 1.8297
Train Acc: 0.3539 | Validation Acc: 0.3411
Train F1: 0.1154 | Validation F1: 0.0739

Epoch 8/15
Tr

In [None]:
plot_metrics(metrics, experiment_name)
create_metrics_table(metrics, experiment_name)

Unnamed: 0,Epoch,Train Loss,Valid Loss,Train Acc,Valid Acc,Train F1,Valid F1
0,1,3.2948,3.0728,0.1987,0.3488,0.0626,0.0739
1,2,2.1901,1.8609,0.3356,0.3488,0.073,0.0739
2,3,1.8797,1.921,0.3406,0.3101,0.0764,0.1036
3,4,1.8533,1.8717,0.3372,0.3256,0.0987,0.0954
4,5,1.8119,1.857,0.3372,0.3411,0.0887,0.0731
5,6,1.8016,1.8325,0.3406,0.3488,0.0831,0.0739
6,7,1.7642,1.8297,0.3539,0.3411,0.1154,0.0739
7,8,1.7667,1.8237,0.3472,0.3566,0.1138,0.0926
8,9,1.7497,1.8394,0.3523,0.3566,0.1162,0.0922
9,10,1.7206,1.8158,0.3606,0.3333,0.1291,0.0731


# Exp 2

In [None]:
experiment_name = f"HeteroGNN_SAGEConv-{hidden_channels} hidden channels-{num_layers} mlp-{num_layers} conv-{intra_aggr} intra_aggr-{inter_aggr} inter_aggr-{dropout} dropout-{lr} lr-{maxlr} maxlr-OneCylceLR-Adam-CE Loss"

model = HeteroGNN_SAGEConv(metadata, hidden_channels, mlp_layers=num_layers, conv_layers=num_layers, intra_aggr=intra_aggr, inter_aggr=inter_aggr, dropout=dropout).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = torch.nn.CrossEntropyLoss()
model

In [None]:
batches_per_epoch = len(train_loader)
total_steps = epochs * batches_per_epoch

scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=maxlr, total_steps=total_steps, epochs=epochs, cycle_momentum=False)

In [None]:
metrics = {
    'train_loss': [],
    'valid_loss': [],
    'train_acc': [],
    'valid_acc': [],
    'train_f1': [],
    'valid_f1': []
}

for epoch in range(1, epochs+1):
    train_loss, train_acc, train_f1 = train(model, train_loader, optimizer, criterion, het_predict, scheduler=scheduler, device=device)
    valid_loss, valid_acc, valid_f1 = test(model, test_loader, criterion, het_predict, device=device)
    
    # Update metrics
    metrics['train_loss'].append(train_loss)
    metrics['valid_loss'].append(valid_loss)
    metrics['train_acc'].append(train_acc)
    metrics['valid_acc'].append(valid_acc)
    metrics['train_f1'].append(train_f1)
    metrics['valid_f1'].append(valid_f1)
    
    print(f"Epoch {epoch}/{epochs}")
    print(f"Train Loss: {train_loss:.4f} | Validation Loss: {valid_loss:.4f}")
    print(f"Train Acc: {train_acc:.4f} | Validation Acc: {valid_acc:.4f}")
    print(f"Train F1: {train_f1:.4f} | Validation F1: {valid_f1:.4f}\n")

# Exp 3

In [None]:
experiment_name = f"HeteroGNN_GATConv-{hidden_channels} hidden channels-{num_layers} mlp-{num_layers} conv-{intra_aggr} intra_aggr-{inter_aggr} inter_aggr-{dropout} dropout-{lr} lr-{maxlr} maxlr-OneCylceLR-Adam-CE Loss"

model = HeteroGNN_GATConv(metadata, hidden_channels, mlp_layers=num_layers, conv_layers=num_layers, intra_aggr=intra_aggr, inter_aggr=inter_aggr, dropout=dropout).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = torch.nn.CrossEntropyLoss()
model

In [None]:
batches_per_epoch = len(train_loader)
total_steps = epochs * batches_per_epoch

scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=maxlr, total_steps=total_steps, epochs=epochs, cycle_momentum=False)

In [None]:
metrics = {
    'train_loss': [],
    'valid_loss': [],
    'train_acc': [],
    'valid_acc': [],
    'train_f1': [],
    'valid_f1': []
}

for epoch in range(1, epochs+1):
    train_loss, train_acc, train_f1 = train(model, train_loader, optimizer, criterion, het_predict, scheduler=scheduler, device=device)
    valid_loss, valid_acc, valid_f1 = test(model, test_loader, criterion, het_predict, device=device)
    
    # Update metrics
    metrics['train_loss'].append(train_loss)
    metrics['valid_loss'].append(valid_loss)
    metrics['train_acc'].append(train_acc)
    metrics['valid_acc'].append(valid_acc)
    metrics['train_f1'].append(train_f1)
    metrics['valid_f1'].append(valid_f1)
    
    print(f"Epoch {epoch}/{epochs}")
    print(f"Train Loss: {train_loss:.4f} | Validation Loss: {valid_loss:.4f}")
    print(f"Train Acc: {train_acc:.4f} | Validation Acc: {valid_acc:.4f}")
    print(f"Train F1: {train_f1:.4f} | Validation F1: {valid_f1:.4f}\n")