In [5]:
from util import *
from training_heterogeneous import *
from gat_models import *

import random
import numpy as np
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm
from torch_geometric.loader import DataLoader
from torch.optim import Adam
from torch.nn.modules.loss import TripletMarginLoss

from src.shared.database_wrapper import DatabaseWrapper
from src.shared.graph_schema import *
from src.shared.graph_sampling import GraphSampling

random.seed(40)
np.random.seed(40)
torch.manual_seed(40)
torch.cuda.manual_seed_all(40)

### Configurations

In [6]:
# Graph sampling configurations
node_spec = [
    NodeType.PUBLICATION,
]

edge_spec = [
    EdgeType.SIM_AUTHOR,
]

node_properties = [
    'id',
    'feature_vec',
]

database = 'homogeneous-graph-compressed-emb'
gs = GraphSampling(
    node_spec=node_spec,
    edge_spec=edge_spec,
    node_properties=node_properties,
    database=database
)

# Model configurations

config = {
    'experiment': 'GATv2 encoder (1 Conv layer, 2 linear layers + dropout) trained on homogeneous graph (publication nodes with title and abstract, co-author edges) using Triplet Loss and dimension reduced embeddings',
    'max_hops': 3,
    'model_node_feature': 'feature_vec',  # Node feature to use for GAT encoder
    'hidden_channels': 64,
    'out_channels': 16,
    'num_heads': 8,
    'margin': 1.0,
    'optimizer': 'Adam',
    'learning_rate': 0.005,
    'weight_decay': 5e-4,
    'num_epochs': 10,
    'batch_size': 32,
}

model_class = HeteroGATEncoder1Conv2LinearDropout
loss_fn = TripletMarginLoss(margin=config['margin'])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# TODO: Adjust result folder name!
result_folder_name = 'homo_edges compressed_emb 1Conv 2Linear dropout'
result_folder_path = f'./data/results/{result_folder_name}'
if not os.path.exists(result_folder_path):
    os.mkdir(result_folder_path)

### Training Configuration

In [7]:
db = DatabaseWrapper(database=database)
data_harvester = TripletDataHarvester(db=db, gs=gs, edge_spec=edge_spec, config=config)


# Split the pairs into train and test

train_size = int(0.85 * len(data_harvester.triplets))
test_size = int(0.1 * len(data_harvester.triplets))
eval_size = len(data_harvester.triplets) - train_size - test_size

# Harvest the evaluation triplets first, since triplets are ordered by author. This will ensure that the evaluation set has authors not seen in the training set.
eval_triplets = data_harvester.triplets[:eval_size]

train_test_triplets = data_harvester.triplets[eval_size:]
random.shuffle(train_test_triplets)

train_triplets = train_test_triplets[:train_size]
test_triplets = train_test_triplets[train_size:]
config['train_size'] = len(train_triplets)
config['test_size'] = len(test_triplets)
config['eval_size'] = len(eval_triplets)

print(f"Train size: {len(train_triplets)}, Test size: {len(test_triplets)}, Eval size: {len(eval_triplets)}")

# Create the datasets from the pairs (distinct pairs for training and testing)
train_dataset = GraphTripletDataset(train_triplets, gs, config=config)
test_dataset = GraphTripletDataset(test_triplets, gs, config=config)
eval_dataset = GraphTripletDataset(eval_triplets, gs, config=config)

# Create the DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, collate_fn=custom_triplet_collate)
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, collate_fn=custom_triplet_collate)
eval_dataloader = DataLoader(eval_dataset, batch_size=config['batch_size'], shuffle=False, collate_fn=custom_triplet_collate)

# Create model
metadata = (
    [n.value for n in node_spec],
    [edge_pyg_key_vals[r] for r in edge_spec]
)
config['node_spec'] = metadata[0]
config['edge_spec'] = metadata[1]
model = model_class(metadata, config['hidden_channels'], config['out_channels'], num_heads=config['num_heads']).to(device)
optimizer = Adam(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])

2024-11-06 09:47:00,511 - DatabaseWrapper - INFO - Connecting to the database ...
2024-11-06 09:47:00,512 - DatabaseWrapper - INFO - Database ready.


Preparing triplets...
Loading triplets...
Loaded 11755 triplets.
Train size: 9991, Test size: 1175, Eval size: 589


### Training Loop

In [8]:
num_epochs = config['num_epochs']
train_losses = []

test_losses = []
test_accuracies = []
test_correct_pos = []
test_correct_neg = []

eval_losses = []
eval_accuracies = []
eval_correct_pos = []
eval_correct_neg = []

current_batch = 1

for epoch in range(1, num_epochs + 1):
    print(f"=== Epoch {epoch}/{num_epochs} ======================")
    epoch_marker_pos = list(range(0, len(train_dataloader) * epoch, len(train_dataloader)))
    current_batch = 1
    for batch_anchor, batch_pos, batch_neg in tqdm(train_dataloader, desc=f"Epoch {epoch}/{num_epochs}"):
        if batch_anchor is None or batch_pos is None or batch_neg is None:
            continue
        
        if current_batch == 1 or current_batch == len(train_dataloader) // 2:
            print(f"___ Current Batch {current_batch}/{len(train_dataloader)} _________________________")
            # Model testing
            print("    Test Results:")
            test_loss, test_num_correct, test_correct_pos_val, test_correct_neg_val = test(
                model=model,
                loss_fn=loss_fn,
                dataloader=test_dataloader,
                margin=config['margin']
            )
            test_losses.append(test_loss)
            test_accuracies.append(test_num_correct)
            test_correct_pos.append(test_correct_pos_val)
            test_correct_neg.append(test_correct_neg_val)
    
            plot_loss(test_losses, epoch_len=2, plot_title='Test Loss', plot_avg=False, plot_file=result_folder_path + '/test_loss.png')
            plot_loss(
                test_accuracies,
                epoch_len=2,
                plot_title='Test Accuracy',
                plot_avg=False, 
                x_label='Test Iterations',
                y_label='Accuracy',
                line_label='Accuracy',
                plot_file=result_folder_path + '/test_accuracy.png'
            )
            
            # Model evaluation
            print("    Eval Results:")
            eval_loss, eval_num_correct, eval_correct_pos_val, eval_correct_neg_val = evaluate(
                model=model,
                loss_fn=loss_fn,
                dataloader=eval_dataloader,
                margin=config['margin']
            )
            eval_losses.append(eval_loss)
            eval_accuracies.append(eval_num_correct)
            eval_correct_pos.append(eval_correct_pos_val)
            eval_correct_neg.append(eval_correct_neg_val)
            
            plot_loss(eval_losses, epoch_len=2, plot_title='Evaluation Loss', plot_avg=False, plot_file=result_folder_path + '/eval_loss.png')
            plot_loss(
                eval_accuracies, 
                epoch_len=2, 
                plot_title='Evaluation Accuracy', 
                plot_avg=False, 
                x_label='Eval Iterations',
                y_label='Accuracy',
                line_label='Accuracy',
                plot_file=result_folder_path + '/eval_accuracy.png'
            )
            
        loss = train(
            model=model,
            loss_fn=loss_fn,
            batch_anchor=batch_anchor,
            batch_pos=batch_pos,
            batch_neg=batch_neg,
            optimizer=optimizer
        )
        train_losses.append(loss)
        
        plot_loss(train_losses, epoch_len=len(train_dataloader), plot_title='Training Loss', plot_avg=True, plot_file=result_folder_path + '/train_loss.png')
        current_batch += 1
        
    # Save config and training results
    eval_results = {
        'eval_losses': eval_losses,
        'eval_accuracies': eval_accuracies,
        'eval_correct_pos': eval_correct_pos,
        'eval_correct_neg': eval_correct_neg
    }
    save_training_results(train_losses, test_losses, eval_results, config, result_folder_path + '/training_data.json')
    
    # Save model if loss has decreased
    if len(test_losses) > 1 and test_losses[-1] < min(test_losses[:-1]):
        print(f"Saving model at epoch {epoch}...")
        torch.save(model.state_dict(), result_folder_path + '/gat_encoder.pt')



Epoch 1/10:   0%|          | 0/313 [00:00<?, ?it/s]

___ Current Batch 1/313 _________________________
    Test Results:
Correct positive: 1175 (100.00%), Correct negative: 0 (0.00%)
Total correct: 1175 (50.00%)
Test Loss: 1.0096, Test Accuracy: 0.5000
    Eval Results:
Correct positive: 589 (100.00%), Correct negative: 0 (0.00%)
Total correct: 589 (50.00%)
Eval Loss: 1.0170, Eval Accuracy: 0.5000
___ Current Batch 156/313 _________________________
    Test Results:
Correct positive: 509 (43.32%), Correct negative: 880 (74.89%)
Total correct: 1389 (59.11%)
Test Loss: 0.7189, Test Accuracy: 0.5911
    Eval Results:
Correct positive: 153 (25.98%), Correct negative: 509 (86.42%)
Total correct: 662 (56.20%)
Eval Loss: 0.7596, Eval Accuracy: 0.5620
Saving model at epoch 1...


Epoch 2/10:   0%|          | 0/313 [00:00<?, ?it/s]

___ Current Batch 1/313 _________________________
    Test Results:
Correct positive: 304 (25.87%), Correct negative: 1007 (85.70%)
Total correct: 1311 (55.79%)
Test Loss: 0.6892, Test Accuracy: 0.5579
    Eval Results:
Correct positive: 91 (15.45%), Correct negative: 541 (91.85%)
Total correct: 632 (53.65%)
Eval Loss: 0.7255, Eval Accuracy: 0.5365
___ Current Batch 156/313 _________________________
    Test Results:
Correct positive: 297 (25.28%), Correct negative: 1069 (90.98%)
Total correct: 1366 (58.13%)
Test Loss: 0.5476, Test Accuracy: 0.5813
    Eval Results:
Correct positive: 60 (10.19%), Correct negative: 557 (94.57%)
Total correct: 617 (52.38%)
Eval Loss: 0.9080, Eval Accuracy: 0.5238
Saving model at epoch 2...


Epoch 3/10:   0%|          | 0/313 [00:00<?, ?it/s]

___ Current Batch 1/313 _________________________
    Test Results:
Correct positive: 145 (12.34%), Correct negative: 1112 (94.64%)
Total correct: 1257 (53.49%)
Test Loss: 0.5676, Test Accuracy: 0.5349
    Eval Results:
Correct positive: 29 (4.92%), Correct negative: 564 (95.76%)
Total correct: 593 (50.34%)
Eval Loss: 0.9690, Eval Accuracy: 0.5034
___ Current Batch 156/313 _________________________
    Test Results:
Correct positive: 189 (16.09%), Correct negative: 1126 (95.83%)
Total correct: 1315 (55.96%)
Test Loss: 0.4473, Test Accuracy: 0.5596
    Eval Results:
Correct positive: 49 (8.32%), Correct negative: 562 (95.42%)
Total correct: 611 (51.87%)
Eval Loss: 0.7700, Eval Accuracy: 0.5187
Saving model at epoch 3...


Epoch 4/10:   0%|          | 0/313 [00:00<?, ?it/s]

___ Current Batch 1/313 _________________________
    Test Results:
Correct positive: 175 (14.89%), Correct negative: 1123 (95.57%)
Total correct: 1298 (55.23%)
Test Loss: 0.4407, Test Accuracy: 0.5523
    Eval Results:
Correct positive: 29 (4.92%), Correct negative: 566 (96.10%)
Total correct: 595 (50.51%)
Eval Loss: 0.7324, Eval Accuracy: 0.5051
___ Current Batch 156/313 _________________________
    Test Results:
Correct positive: 162 (13.79%), Correct negative: 1139 (96.94%)
Total correct: 1301 (55.36%)
Test Loss: 0.3942, Test Accuracy: 0.5536
    Eval Results:
Correct positive: 27 (4.58%), Correct negative: 564 (95.76%)
Total correct: 591 (50.17%)
Eval Loss: 0.8220, Eval Accuracy: 0.5017
Saving model at epoch 4...


Epoch 5/10:   0%|          | 0/313 [00:00<?, ?it/s]

___ Current Batch 1/313 _________________________
    Test Results:
Correct positive: 132 (11.23%), Correct negative: 1139 (96.94%)
Total correct: 1271 (54.09%)
Test Loss: 0.4057, Test Accuracy: 0.5409
    Eval Results:
Correct positive: 31 (5.26%), Correct negative: 565 (95.93%)
Total correct: 596 (50.59%)
Eval Loss: 0.8605, Eval Accuracy: 0.5059
___ Current Batch 156/313 _________________________
    Test Results:
Correct positive: 161 (13.70%), Correct negative: 1132 (96.34%)
Total correct: 1293 (55.02%)
Test Loss: 0.3585, Test Accuracy: 0.5502
    Eval Results:
Correct positive: 28 (4.75%), Correct negative: 565 (95.93%)
Total correct: 593 (50.34%)
Eval Loss: 0.8146, Eval Accuracy: 0.5034
Saving model at epoch 5...


Epoch 6/10:   0%|          | 0/313 [00:00<?, ?it/s]

___ Current Batch 1/313 _________________________
    Test Results:
Correct positive: 90 (7.66%), Correct negative: 1140 (97.02%)
Total correct: 1230 (52.34%)
Test Loss: 0.3672, Test Accuracy: 0.5234
    Eval Results:
Correct positive: 30 (5.09%), Correct negative: 572 (97.11%)
Total correct: 602 (51.10%)
Eval Loss: 0.9809, Eval Accuracy: 0.5110
___ Current Batch 156/313 _________________________
    Test Results:
Correct positive: 134 (11.40%), Correct negative: 1135 (96.60%)
Total correct: 1269 (54.00%)
Test Loss: 0.3207, Test Accuracy: 0.5400
    Eval Results:
Correct positive: 29 (4.92%), Correct negative: 561 (95.25%)
Total correct: 590 (50.08%)
Eval Loss: 0.7688, Eval Accuracy: 0.5008


KeyboardInterrupt: 