In [1]:
from notebooks.util_homogeneous import HomogeneousGraphTripletDataset
from util_homogeneous import *
from util_heterogeneous import plot_loss, save_training_results
from training_homogeneous import *
from gat_models import *

import random
import numpy as np
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm
from torch_geometric.loader import DataLoader
from torch.optim import Adam
from torch.nn.modules.loss import TripletMarginLoss

from src.shared.database_wrapper import DatabaseWrapper
from src.shared.graph_schema import *
from src.shared.graph_sampling import GraphSampling

random.seed(40)
np.random.seed(40)
torch.manual_seed(40)
torch.cuda.manual_seed_all(40)

  from tqdm.autonotebook import tqdm, trange


## Configurations

In [2]:
# Graph sampling configurations
node_spec = NodeType.PUBLICATION

edge_spec = EdgeType.SIM_TITLE

node_properties = [
    'id',
    'feature_vec',
]

database = 'homogeneous-graph'
gs = GraphSampling(
    node_spec=[node_spec],
    edge_spec=[edge_spec],
    node_properties=node_properties,
    database=database
)

# Model configurations

config = {
    'experiment': 'GATv2 encoder (with linear layer + dropout) trained on homogeneous graph (publication nodes with title and abstract, title edges) using Triplet Loss and full embeddings',
    'max_hops': 3,
    'model_node_feature': 'feature_vec',  # Node feature to use for GAT encoder
    'hidden_channels': 32,
    'out_channels': 16,
    'num_heads': 8,
    'margin': 1.0,
    'optimizer': 'Adam',
    'learning_rate': 0.005,
    'weight_decay': 5e-4,
    'num_epochs': 10,
    'batch_size': 32,
}

model_class = HomoGATEncoderLinearDropout
loss_fn = TripletMarginLoss(margin=config['margin'])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# TODO: Adjust result folder name!
result_folder_name = 'homogeneous (title) full_emb linear_layer dropout'
result_folder_path = f'./data/results/{result_folder_name}'
if not os.path.exists(result_folder_path):
    os.mkdir(result_folder_path)

Using default edge type: SimilarTitle for homogeneous graph sampling.


In [3]:
db = DatabaseWrapper(database=database)
data_harvester = TripletDataHarvester(db=db, gs=gs, edge_spec=[edge_spec], config=config, valid_triplets_save_file='valid_triplets_homogeneous_title', transformer_model='sentence-transformers/all-MiniLM-L6-v2')


# Split the pairs into train and test

train_size = int(0.85 * len(data_harvester.triplets))
test_size = int(0.1 * len(data_harvester.triplets))
eval_size = len(data_harvester.triplets) - train_size - test_size

# Harvest the evaluation triplets first, since triplets are ordered by author. This will ensure that the evaluation set has authors not seen in the training set.
eval_triplets = data_harvester.triplets[:eval_size]

train_test_triplets = data_harvester.triplets[eval_size:]
random.shuffle(train_test_triplets)

train_triplets = train_test_triplets[:train_size]
test_triplets = train_test_triplets[train_size:]
config['train_size'] = len(train_triplets)
config['test_size'] = len(test_triplets)
config['eval_size'] = len(eval_triplets)

print(f"Train size: {len(train_triplets)}, Test size: {len(test_triplets)}, Eval size: {len(eval_triplets)}")

# Create the datasets from the pairs (distinct pairs for training and testing)
train_dataset = HomogeneousGraphTripletDataset(train_triplets, gs, config=config)
test_dataset = HomogeneousGraphTripletDataset(test_triplets, gs, config=config)
eval_dataset = HomogeneousGraphTripletDataset(eval_triplets, gs, config=config)

# Create the DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, collate_fn=custom_triplet_collate)
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, collate_fn=custom_triplet_collate)
eval_dataloader = DataLoader(eval_dataset, batch_size=config['batch_size'], shuffle=False, collate_fn=custom_triplet_collate)

# Create model
metadata = (
    node_spec.value,
    edge_spec.value
)
config['node_spec'] = metadata[0]
config['edge_spec'] = metadata[1]
model = model_class(config['hidden_channels'], config['out_channels'], num_heads=config['num_heads']).to(device)
optimizer = Adam(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])

2024-11-17 14:12:26,699 - DatabaseWrapper - INFO - Connecting to the database ...
2024-11-17 14:12:26,700 - DatabaseWrapper - INFO - Database ready.


Preparing triplets...
Loading triplets...
Could not load triplets from file. Generating triplets...
Checking data validity...
Out of 20034 checked papers, 10566 are valid and 9468 are invalid.
Generating hard triplets ...




Total triplets generated: 9597. Done.
Generated 9597 triplets.
Saving triplets...
Triplets saved.
Train size: 8157, Test size: 959, Eval size: 481


In [4]:
num_epochs = config['num_epochs']
train_losses = []

test_losses = []
test_accuracies = []
test_correct_pos = []
test_correct_neg = []

eval_losses = []
eval_accuracies = []
eval_correct_pos = []
eval_correct_neg = []

current_batch = 1

for epoch in range(1, num_epochs + 1):
    print(f"=== Epoch {epoch}/{num_epochs} ======================")
    epoch_marker_pos = list(range(0, len(train_dataloader) * epoch, len(train_dataloader)))
    current_batch = 1
    for batch_anchor, batch_pos, batch_neg in tqdm(train_dataloader, desc=f"Epoch {epoch}/{num_epochs}"):
        if batch_anchor is None or batch_pos is None or batch_neg is None:
            continue
        
        if current_batch == 1 or current_batch == len(train_dataloader) // 2:
            print(f"___ Current Batch {current_batch}/{len(train_dataloader)} _________________________")
            # Model testing
            print("    Test Results:")
            test_loss, test_num_correct, test_correct_pos_val, test_correct_neg_val = test(
                model=model,
                loss_fn=loss_fn,
                dataloader=test_dataloader,
                margin=config['margin']
            )
            test_losses.append(test_loss)
            test_accuracies.append(test_num_correct)
            test_correct_pos.append(test_correct_pos_val)
            test_correct_neg.append(test_correct_neg_val)
    
            plot_loss(test_losses, epoch_len=2, plot_title='Test Loss', plot_avg=False, plot_file=result_folder_path + '/test_loss.png')
            plot_loss(
                test_accuracies,
                epoch_len=2,
                plot_title='Test Accuracy',
                plot_avg=False, 
                x_label='Test Iterations',
                y_label='Accuracy',
                line_label='Accuracy',
                plot_file=result_folder_path + '/test_accuracy.png'
            )
            
            # Model evaluation
            print("    Eval Results:")
            eval_loss, eval_num_correct, eval_correct_pos_val, eval_correct_neg_val = evaluate(
                model=model,
                loss_fn=loss_fn,
                dataloader=eval_dataloader,
                margin=config['margin']
            )
            eval_losses.append(eval_loss)
            eval_accuracies.append(eval_num_correct)
            eval_correct_pos.append(eval_correct_pos_val)
            eval_correct_neg.append(eval_correct_neg_val)
            
            plot_loss(eval_losses, epoch_len=2, plot_title='Evaluation Loss', plot_avg=False, plot_file=result_folder_path + '/eval_loss.png')
            plot_loss(
                eval_accuracies, 
                epoch_len=2, 
                plot_title='Evaluation Accuracy', 
                plot_avg=False, 
                x_label='Eval Iterations',
                y_label='Accuracy',
                line_label='Accuracy',
                plot_file=result_folder_path + '/eval_accuracy.png'
            )
            
        loss = train(
            model=model,
            loss_fn=loss_fn,
            batch_anchor=batch_anchor,
            batch_pos=batch_pos,
            batch_neg=batch_neg,
            optimizer=optimizer
        )
        train_losses.append(loss)
        
        plot_loss(train_losses, epoch_len=len(train_dataloader), plot_title='Training Loss', plot_avg=True, plot_file=result_folder_path + '/train_loss.png')
        current_batch += 1
        
    # Save config and training results
    eval_results = {
        'eval_losses': eval_losses,
        'eval_accuracies': eval_accuracies,
        'eval_correct_pos': eval_correct_pos,
        'eval_correct_neg': eval_correct_neg
    }
    save_training_results(train_losses, test_losses, eval_results, config, result_folder_path + '/training_data.json')
    
    # Save model if loss has decreased
    if len(test_losses) > 1 and test_losses[-1] < min(test_losses[:-1]):
        print(f"Saving model at epoch {epoch}...")
        torch.save(model.state_dict(), result_folder_path + '/gat_encoder.pt')



Epoch 1/10:   0%|          | 0/255 [00:00<?, ?it/s]

___ Current Batch 1/255 _________________________
    Test Results:
        Correct positive: 959 (100.00%), Correct negative: 0 (0.00%)
        Total correct: 959 (50.00%)
        Test Loss: 1.0276, Test Accuracy: 0.5000
    Eval Results:
        Correct positive: 481 (100.00%), Correct negative: 0 (0.00%)
        Total correct: 481 (50.00%)
        Eval Loss: 0.9973, Eval Accuracy: 0.5000
___ Current Batch 127/255 _________________________
    Test Results:
        Correct positive: 656 (68.40%), Correct negative: 869 (90.62%)
        Total correct: 1525 (79.51%)
        Test Loss: 0.2479, Test Accuracy: 0.7951
    Eval Results:
        Correct positive: 285 (59.25%), Correct negative: 152 (31.60%)
        Total correct: 437 (45.43%)
        Eval Loss: 1.1198, Eval Accuracy: 0.4543
Saving model at epoch 1...


Epoch 2/10:   0%|          | 0/255 [00:00<?, ?it/s]

___ Current Batch 1/255 _________________________
    Test Results:
        Correct positive: 670 (69.86%), Correct negative: 858 (89.47%)
        Total correct: 1528 (79.67%)
        Test Loss: 0.2144, Test Accuracy: 0.7967
    Eval Results:
        Correct positive: 257 (53.43%), Correct negative: 185 (38.46%)
        Total correct: 442 (45.95%)
        Eval Loss: 1.2449, Eval Accuracy: 0.4595
___ Current Batch 127/255 _________________________
    Test Results:
        Correct positive: 791 (82.48%), Correct negative: 840 (87.59%)
        Total correct: 1631 (85.04%)
        Test Loss: 0.2009, Test Accuracy: 0.8504
    Eval Results:
        Correct positive: 347 (72.14%), Correct negative: 128 (26.61%)
        Total correct: 475 (49.38%)
        Eval Loss: 1.0519, Eval Accuracy: 0.4938
Saving model at epoch 2...


Epoch 3/10:   0%|          | 0/255 [00:00<?, ?it/s]

___ Current Batch 1/255 _________________________
    Test Results:
        Correct positive: 717 (74.77%), Correct negative: 864 (90.09%)
        Total correct: 1581 (82.43%)
        Test Loss: 0.2016, Test Accuracy: 0.8243
    Eval Results:
        Correct positive: 311 (64.66%), Correct negative: 209 (43.45%)
        Total correct: 520 (54.05%)
        Eval Loss: 1.0383, Eval Accuracy: 0.5405
___ Current Batch 127/255 _________________________
    Test Results:
        Correct positive: 698 (72.78%), Correct negative: 874 (91.14%)
        Total correct: 1572 (81.96%)
        Test Loss: 0.1958, Test Accuracy: 0.8196
    Eval Results:
        Correct positive: 382 (79.42%), Correct negative: 78 (16.22%)
        Total correct: 460 (47.82%)
        Eval Loss: 1.1402, Eval Accuracy: 0.4782
Saving model at epoch 3...


Epoch 4/10:   0%|          | 0/255 [00:00<?, ?it/s]

___ Current Batch 1/255 _________________________
    Test Results:
        Correct positive: 623 (64.96%), Correct negative: 902 (94.06%)
        Total correct: 1525 (79.51%)
        Test Loss: 0.1621, Test Accuracy: 0.7951
    Eval Results:
        Correct positive: 322 (66.94%), Correct negative: 141 (29.31%)
        Total correct: 463 (48.13%)
        Eval Loss: 1.0928, Eval Accuracy: 0.4813
___ Current Batch 127/255 _________________________
    Test Results:
        Correct positive: 595 (62.04%), Correct negative: 897 (93.53%)
        Total correct: 1492 (77.79%)
        Test Loss: 0.1912, Test Accuracy: 0.7779
    Eval Results:
        Correct positive: 267 (55.51%), Correct negative: 345 (71.73%)
        Total correct: 612 (63.62%)
        Eval Loss: 0.7728, Eval Accuracy: 0.6362


Epoch 5/10:   0%|          | 0/255 [00:00<?, ?it/s]

___ Current Batch 1/255 _________________________
    Test Results:
        Correct positive: 694 (72.37%), Correct negative: 900 (93.85%)
        Total correct: 1594 (83.11%)
        Test Loss: 0.1550, Test Accuracy: 0.8311
    Eval Results:
        Correct positive: 318 (66.11%), Correct negative: 227 (47.19%)
        Total correct: 545 (56.65%)
        Eval Loss: 0.8974, Eval Accuracy: 0.5665
___ Current Batch 127/255 _________________________
    Test Results:
        Correct positive: 817 (85.19%), Correct negative: 857 (89.36%)
        Total correct: 1674 (87.28%)
        Test Loss: 0.1692, Test Accuracy: 0.8728
    Eval Results:
        Correct positive: 351 (72.97%), Correct negative: 157 (32.64%)
        Total correct: 508 (52.81%)
        Eval Loss: 0.9523, Eval Accuracy: 0.5281


Epoch 6/10:   0%|          | 0/255 [00:00<?, ?it/s]

___ Current Batch 1/255 _________________________
    Test Results:
        Correct positive: 722 (75.29%), Correct negative: 883 (92.08%)
        Total correct: 1605 (83.68%)
        Test Loss: 0.1780, Test Accuracy: 0.8368
    Eval Results:
        Correct positive: 328 (68.19%), Correct negative: 187 (38.88%)
        Total correct: 515 (53.53%)
        Eval Loss: 0.9175, Eval Accuracy: 0.5353
___ Current Batch 127/255 _________________________
    Test Results:
        Correct positive: 746 (77.79%), Correct negative: 896 (93.43%)
        Total correct: 1642 (85.61%)
        Test Loss: 0.1503, Test Accuracy: 0.8561
    Eval Results:
        Correct positive: 345 (71.73%), Correct negative: 173 (35.97%)
        Total correct: 518 (53.85%)
        Eval Loss: 0.9553, Eval Accuracy: 0.5385
Saving model at epoch 6...


Epoch 7/10:   0%|          | 0/255 [00:00<?, ?it/s]

___ Current Batch 1/255 _________________________
    Test Results:
        Correct positive: 785 (81.86%), Correct negative: 876 (91.35%)
        Total correct: 1661 (86.60%)
        Test Loss: 0.1615, Test Accuracy: 0.8660
    Eval Results:
        Correct positive: 344 (71.52%), Correct negative: 209 (43.45%)
        Total correct: 553 (57.48%)
        Eval Loss: 0.8875, Eval Accuracy: 0.5748
___ Current Batch 127/255 _________________________
    Test Results:
        Correct positive: 714 (74.45%), Correct negative: 891 (92.91%)
        Total correct: 1605 (83.68%)
        Test Loss: 0.1417, Test Accuracy: 0.8368
    Eval Results:
        Correct positive: 332 (69.02%), Correct negative: 176 (36.59%)
        Total correct: 508 (52.81%)
        Eval Loss: 1.0107, Eval Accuracy: 0.5281
Saving model at epoch 7...


Epoch 8/10:   0%|          | 0/255 [00:00<?, ?it/s]

___ Current Batch 1/255 _________________________
    Test Results:
        Correct positive: 589 (61.42%), Correct negative: 921 (96.04%)
        Total correct: 1510 (78.73%)
        Test Loss: 0.1330, Test Accuracy: 0.7873
    Eval Results:
        Correct positive: 248 (51.56%), Correct negative: 354 (73.60%)
        Total correct: 602 (62.58%)
        Eval Loss: 0.6703, Eval Accuracy: 0.6258
___ Current Batch 127/255 _________________________
    Test Results:
        Correct positive: 645 (67.26%), Correct negative: 913 (95.20%)
        Total correct: 1558 (81.23%)
        Test Loss: 0.1322, Test Accuracy: 0.8123
    Eval Results:
        Correct positive: 341 (70.89%), Correct negative: 321 (66.74%)
        Total correct: 662 (68.81%)
        Eval Loss: 0.6035, Eval Accuracy: 0.6881
Saving model at epoch 8...


Epoch 9/10:   0%|          | 0/255 [00:00<?, ?it/s]

___ Current Batch 1/255 _________________________
    Test Results:
        Correct positive: 756 (78.83%), Correct negative: 899 (93.74%)
        Total correct: 1655 (86.29%)
        Test Loss: 0.1409, Test Accuracy: 0.8629
    Eval Results:
        Correct positive: 366 (76.09%), Correct negative: 202 (42.00%)
        Total correct: 568 (59.04%)
        Eval Loss: 0.8568, Eval Accuracy: 0.5904
___ Current Batch 127/255 _________________________
    Test Results:
        Correct positive: 716 (74.66%), Correct negative: 907 (94.58%)
        Total correct: 1623 (84.62%)
        Test Loss: 0.1199, Test Accuracy: 0.8462
    Eval Results:
        Correct positive: 330 (68.61%), Correct negative: 330 (68.61%)
        Total correct: 660 (68.61%)
        Eval Loss: 0.6637, Eval Accuracy: 0.6861
Saving model at epoch 9...


Epoch 10/10:   0%|          | 0/255 [00:00<?, ?it/s]

___ Current Batch 1/255 _________________________
    Test Results:
        Correct positive: 671 (69.97%), Correct negative: 904 (94.26%)
        Total correct: 1575 (82.12%)
        Test Loss: 0.1361, Test Accuracy: 0.8212
    Eval Results:
        Correct positive: 281 (58.42%), Correct negative: 372 (77.34%)
        Total correct: 653 (67.88%)
        Eval Loss: 0.5338, Eval Accuracy: 0.6788
___ Current Batch 127/255 _________________________
    Test Results:
        Correct positive: 707 (73.72%), Correct negative: 912 (95.10%)
        Total correct: 1619 (84.41%)
        Test Loss: 0.1254, Test Accuracy: 0.8441
    Eval Results:
        Correct positive: 399 (82.95%), Correct negative: 306 (63.62%)
        Total correct: 705 (73.28%)
        Eval Loss: 0.5624, Eval Accuracy: 0.7328
