In [9]:
from notebooks.util_homogeneous import HomogeneousGraphTripletDataset
from util_homogeneous import *
from util import plot_loss, save_training_results
from training_homogeneous import *
from gat_models import *

import random
import numpy as np
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm
from torch_geometric.loader import DataLoader
from torch.optim import Adam
from torch.nn.modules.loss import TripletMarginLoss

from src.shared.database_wrapper import DatabaseWrapper
from src.shared.graph_schema import *
from src.shared.graph_sampling import GraphSampling

random.seed(40)
np.random.seed(40)
torch.manual_seed(40)
torch.cuda.manual_seed_all(40)

## Configurations

In [10]:
# Graph sampling configurations
node_spec = NodeType.PUBLICATION

edge_spec = EdgeType.SIM_TITLE

node_properties = [
    'id',
    'feature_vec',
]

database = 'homogeneous-graph-compressed-emb'
gs = GraphSampling(
    node_spec=[node_spec],
    edge_spec=[edge_spec],
    node_properties=node_properties,
    database=database
)

# Model configurations

config = {
    'experiment': 'GATv2 encoder (with linear layer + dropout) trained on homogeneous graph (publication nodes with title and abstract, co-author edges) using Triplet Loss and dimension reduced embeddings',
    'max_hops': 3,
    'model_node_feature': 'feature_vec',  # Node feature to use for GAT encoder
    'hidden_channels': 32,
    'out_channels': 16,
    'num_heads': 8,
    'margin': 1.0,
    'optimizer': 'Adam',
    'learning_rate': 0.005,
    'weight_decay': 5e-4,
    'num_epochs': 10,
    'batch_size': 32,
}

model_class = HomoGATEncoderLinearDropout
loss_fn = TripletMarginLoss(margin=config['margin'])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# TODO: Adjust result folder name!
result_folder_name = 'homogeneous (title) compressed_emb linear_layer dropout'
result_folder_path = f'./data/results/{result_folder_name}'
if not os.path.exists(result_folder_path):
    os.mkdir(result_folder_path)

In [11]:
db = DatabaseWrapper(database=database)
data_harvester = TripletDataHarvester(db=db, gs=gs, edge_spec=edge_spec, config=config)


# Split the pairs into train and test

train_size = int(0.85 * len(data_harvester.triplets))
test_size = int(0.1 * len(data_harvester.triplets))
eval_size = len(data_harvester.triplets) - train_size - test_size

# Harvest the evaluation triplets first, since triplets are ordered by author. This will ensure that the evaluation set has authors not seen in the training set.
eval_triplets = data_harvester.triplets[:eval_size]

train_test_triplets = data_harvester.triplets[eval_size:]
random.shuffle(train_test_triplets)

train_triplets = train_test_triplets[:train_size]
test_triplets = train_test_triplets[train_size:]
config['train_size'] = len(train_triplets)
config['test_size'] = len(test_triplets)
config['eval_size'] = len(eval_triplets)

print(f"Train size: {len(train_triplets)}, Test size: {len(test_triplets)}, Eval size: {len(eval_triplets)}")

# Create the datasets from the pairs (distinct pairs for training and testing)
train_dataset = HomogeneousGraphTripletDataset(train_triplets, gs, config=config)
test_dataset = HomogeneousGraphTripletDataset(test_triplets, gs, config=config)
eval_dataset = HomogeneousGraphTripletDataset(eval_triplets, gs, config=config)

# Create the DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, collate_fn=custom_triplet_collate)
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, collate_fn=custom_triplet_collate)
eval_dataloader = DataLoader(eval_dataset, batch_size=config['batch_size'], shuffle=False, collate_fn=custom_triplet_collate)

# Create model
metadata = (
    node_spec.value,
    edge_spec.value
)
config['node_spec'] = metadata[0]
config['edge_spec'] = metadata[1]
model = model_class(config['hidden_channels'], config['out_channels'], num_heads=config['num_heads']).to(device)
optimizer = Adam(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])

2024-11-06 13:04:03,782 - DatabaseWrapper - INFO - Connecting to the database ...
2024-11-06 13:04:03,782 - DatabaseWrapper - INFO - Database ready.


Preparing triplets...
Loading triplets...
Loaded 11755 triplets.
Train size: 9991, Test size: 1175, Eval size: 589


In [12]:
num_epochs = config['num_epochs']
train_losses = []

test_losses = []
test_accuracies = []
test_correct_pos = []
test_correct_neg = []

eval_losses = []
eval_accuracies = []
eval_correct_pos = []
eval_correct_neg = []

current_batch = 1

for epoch in range(1, num_epochs + 1):
    print(f"=== Epoch {epoch}/{num_epochs} ======================")
    epoch_marker_pos = list(range(0, len(train_dataloader) * epoch, len(train_dataloader)))
    current_batch = 1
    for batch_anchor, batch_pos, batch_neg in tqdm(train_dataloader, desc=f"Epoch {epoch}/{num_epochs}"):
        if batch_anchor is None or batch_pos is None or batch_neg is None:
            continue
        
        if current_batch == 1 or current_batch == len(train_dataloader) // 2:
            print(f"___ Current Batch {current_batch}/{len(train_dataloader)} _________________________")
            # Model testing
            print("    Test Results:")
            test_loss, test_num_correct, test_correct_pos_val, test_correct_neg_val = test(
                model=model,
                loss_fn=loss_fn,
                dataloader=test_dataloader,
                margin=config['margin']
            )
            test_losses.append(test_loss)
            test_accuracies.append(test_num_correct)
            test_correct_pos.append(test_correct_pos_val)
            test_correct_neg.append(test_correct_neg_val)
    
            plot_loss(test_losses, epoch_len=2, plot_title='Test Loss', plot_avg=False, plot_file=result_folder_path + '/test_loss.png')
            plot_loss(
                test_accuracies,
                epoch_len=2,
                plot_title='Test Accuracy',
                plot_avg=False, 
                x_label='Test Iterations',
                y_label='Accuracy',
                line_label='Accuracy',
                plot_file=result_folder_path + '/test_accuracy.png'
            )
            
            # Model evaluation
            print("    Eval Results:")
            eval_loss, eval_num_correct, eval_correct_pos_val, eval_correct_neg_val = evaluate(
                model=model,
                loss_fn=loss_fn,
                dataloader=eval_dataloader,
                margin=config['margin']
            )
            eval_losses.append(eval_loss)
            eval_accuracies.append(eval_num_correct)
            eval_correct_pos.append(eval_correct_pos_val)
            eval_correct_neg.append(eval_correct_neg_val)
            
            plot_loss(eval_losses, epoch_len=2, plot_title='Evaluation Loss', plot_avg=False, plot_file=result_folder_path + '/eval_loss.png')
            plot_loss(
                eval_accuracies, 
                epoch_len=2, 
                plot_title='Evaluation Accuracy', 
                plot_avg=False, 
                x_label='Eval Iterations',
                y_label='Accuracy',
                line_label='Accuracy',
                plot_file=result_folder_path + '/eval_accuracy.png'
            )
            
        loss = train(
            model=model,
            loss_fn=loss_fn,
            batch_anchor=batch_anchor,
            batch_pos=batch_pos,
            batch_neg=batch_neg,
            optimizer=optimizer
        )
        train_losses.append(loss)
        
        plot_loss(train_losses, epoch_len=len(train_dataloader), plot_title='Training Loss', plot_avg=True, plot_file=result_folder_path + '/train_loss.png')
        current_batch += 1
        
    # Save config and training results
    eval_results = {
        'eval_losses': eval_losses,
        'eval_accuracies': eval_accuracies,
        'eval_correct_pos': eval_correct_pos,
        'eval_correct_neg': eval_correct_neg
    }
    save_training_results(train_losses, test_losses, eval_results, config, result_folder_path + '/training_data.json')
    
    # Save model if loss has decreased
    if len(test_losses) > 1 and test_losses[-1] < min(test_losses[:-1]):
        print(f"Saving model at epoch {epoch}...")
        torch.save(model.state_dict(), result_folder_path + '/gat_encoder.pt')



Epoch 1/10:   0%|          | 0/313 [00:00<?, ?it/s]

___ Current Batch 1/313 _________________________
    Test Results:
        Correct positive: 1175 (100.00%), Correct negative: 0 (0.00%)
        Total correct: 1175 (50.00%)
        Test Loss: 1.0668, Test Accuracy: 0.5000
    Eval Results:
        Correct positive: 589 (100.00%), Correct negative: 0 (0.00%)
        Total correct: 589 (50.00%)
        Eval Loss: 1.0338, Eval Accuracy: 0.5000
___ Current Batch 156/313 _________________________
    Test Results:
        Correct positive: 981 (83.49%), Correct negative: 502 (42.72%)
        Total correct: 1483 (63.11%)
        Test Loss: 0.7659, Test Accuracy: 0.6311
    Eval Results:
        Correct positive: 552 (93.72%), Correct negative: 28 (4.75%)
        Total correct: 580 (49.24%)
        Eval Loss: 0.9873, Eval Accuracy: 0.4924
Saving model at epoch 1...


Epoch 2/10:   0%|          | 0/313 [00:00<?, ?it/s]

___ Current Batch 1/313 _________________________
    Test Results:
        Correct positive: 887 (75.49%), Correct negative: 693 (58.98%)
        Total correct: 1580 (67.23%)
        Test Loss: 0.6219, Test Accuracy: 0.6723
    Eval Results:
        Correct positive: 517 (87.78%), Correct negative: 109 (18.51%)
        Total correct: 626 (53.14%)
        Eval Loss: 0.8970, Eval Accuracy: 0.5314
___ Current Batch 156/313 _________________________
    Test Results:
        Correct positive: 954 (81.19%), Correct negative: 631 (53.70%)
        Total correct: 1585 (67.45%)
        Test Loss: 0.5915, Test Accuracy: 0.6745
    Eval Results:
        Correct positive: 499 (84.72%), Correct negative: 156 (26.49%)
        Total correct: 655 (55.60%)
        Eval Loss: 0.8561, Eval Accuracy: 0.5560
Saving model at epoch 2...


Epoch 3/10:   0%|          | 0/313 [00:00<?, ?it/s]

___ Current Batch 1/313 _________________________
    Test Results:
        Correct positive: 930 (79.15%), Correct negative: 752 (64.00%)
        Total correct: 1682 (71.57%)
        Test Loss: 0.5325, Test Accuracy: 0.7157
    Eval Results:
        Correct positive: 528 (89.64%), Correct negative: 86 (14.60%)
        Total correct: 614 (52.12%)
        Eval Loss: 0.8942, Eval Accuracy: 0.5212
___ Current Batch 156/313 _________________________
    Test Results:
        Correct positive: 951 (80.94%), Correct negative: 789 (67.15%)
        Total correct: 1740 (74.04%)
        Test Loss: 0.4609, Test Accuracy: 0.7404
    Eval Results:
        Correct positive: 516 (87.61%), Correct negative: 168 (28.52%)
        Total correct: 684 (58.06%)
        Eval Loss: 0.8367, Eval Accuracy: 0.5806
Saving model at epoch 3...


Epoch 4/10:   0%|          | 0/313 [00:00<?, ?it/s]

___ Current Batch 1/313 _________________________
    Test Results:
        Correct positive: 826 (70.30%), Correct negative: 890 (75.74%)
        Total correct: 1716 (73.02%)
        Test Loss: 0.4364, Test Accuracy: 0.7302
    Eval Results:
        Correct positive: 414 (70.29%), Correct negative: 217 (36.84%)
        Total correct: 631 (53.57%)
        Eval Loss: 0.8787, Eval Accuracy: 0.5357
___ Current Batch 156/313 _________________________
    Test Results:
        Correct positive: 940 (80.00%), Correct negative: 779 (66.30%)
        Total correct: 1719 (73.15%)
        Test Loss: 0.4695, Test Accuracy: 0.7315
    Eval Results:
        Correct positive: 530 (89.98%), Correct negative: 218 (37.01%)
        Total correct: 748 (63.50%)
        Eval Loss: 0.7733, Eval Accuracy: 0.6350


Epoch 5/10:   0%|          | 0/313 [00:00<?, ?it/s]

___ Current Batch 1/313 _________________________
    Test Results:
        Correct positive: 856 (72.85%), Correct negative: 910 (77.45%)
        Total correct: 1766 (75.15%)
        Test Loss: 0.4038, Test Accuracy: 0.7515
    Eval Results:
        Correct positive: 508 (86.25%), Correct negative: 131 (22.24%)
        Total correct: 639 (54.24%)
        Eval Loss: 0.9136, Eval Accuracy: 0.5424
___ Current Batch 156/313 _________________________
    Test Results:
        Correct positive: 982 (83.57%), Correct negative: 790 (67.23%)
        Total correct: 1772 (75.40%)
        Test Loss: 0.4138, Test Accuracy: 0.7540
    Eval Results:
        Correct positive: 500 (84.89%), Correct negative: 120 (20.37%)
        Total correct: 620 (52.63%)
        Eval Loss: 0.9694, Eval Accuracy: 0.5263


Epoch 6/10:   0%|          | 0/313 [00:00<?, ?it/s]

___ Current Batch 1/313 _________________________
    Test Results:
        Correct positive: 908 (77.28%), Correct negative: 887 (75.49%)
        Total correct: 1795 (76.38%)
        Test Loss: 0.3787, Test Accuracy: 0.7638
    Eval Results:
        Correct positive: 435 (73.85%), Correct negative: 128 (21.73%)
        Total correct: 563 (47.79%)
        Eval Loss: 0.9634, Eval Accuracy: 0.4779
___ Current Batch 156/313 _________________________
    Test Results:
        Correct positive: 927 (78.89%), Correct negative: 865 (73.62%)
        Total correct: 1792 (76.26%)
        Test Loss: 0.3945, Test Accuracy: 0.7626
    Eval Results:
        Correct positive: 511 (86.76%), Correct negative: 229 (38.88%)
        Total correct: 740 (62.82%)
        Eval Loss: 0.7496, Eval Accuracy: 0.6282


Epoch 7/10:   0%|          | 0/313 [00:00<?, ?it/s]

___ Current Batch 1/313 _________________________
    Test Results:
        Correct positive: 855 (72.77%), Correct negative: 978 (83.23%)
        Total correct: 1833 (78.00%)
        Test Loss: 0.3226, Test Accuracy: 0.7800
    Eval Results:
        Correct positive: 407 (69.10%), Correct negative: 106 (18.00%)
        Total correct: 513 (43.55%)
        Eval Loss: 1.1342, Eval Accuracy: 0.4355
___ Current Batch 156/313 _________________________
    Test Results:
        Correct positive: 954 (81.19%), Correct negative: 864 (73.53%)
        Total correct: 1818 (77.36%)
        Test Loss: 0.3522, Test Accuracy: 0.7736
    Eval Results:
        Correct positive: 443 (75.21%), Correct negative: 65 (11.04%)
        Total correct: 508 (43.12%)
        Eval Loss: 1.1560, Eval Accuracy: 0.4312


Epoch 8/10:   0%|          | 0/313 [00:00<?, ?it/s]

___ Current Batch 1/313 _________________________
    Test Results:
        Correct positive: 929 (79.06%), Correct negative: 939 (79.91%)
        Total correct: 1868 (79.49%)
        Test Loss: 0.3141, Test Accuracy: 0.7949
    Eval Results:
        Correct positive: 505 (85.74%), Correct negative: 55 (9.34%)
        Total correct: 560 (47.54%)
        Eval Loss: 1.0793, Eval Accuracy: 0.4754


KeyboardInterrupt: 