In [1]:
from util_homogeneous import *
from util import *
from training_homogeneous import *
from gat_models import *

from collections import defaultdict
import random
import numpy as np
from tqdm.notebook import tqdm
from torch_geometric.loader import DataLoader
from torch.optim import Adam, SGD
from torch.nn.modules.loss import TripletMarginLoss

from src.shared.database_wrapper import DatabaseWrapper
from src.shared.graph_schema import *
from src.shared.graph_sampling import GraphSampling

# Seed every random number generator used in this notebook (Python's `random`,
# NumPy, PyTorch on CPU and on all CUDA devices) from one constant, so the seed
# can be changed in a single place.
SEED = 40
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

  from tqdm.autonotebook import tqdm, trange


## Configurations

In [2]:
# --- Graph sampling configuration ---
# Database to sample from and the node attributes requested for each node.
database = 'small-graph'
node_properties = ['id', 'feature_vec']

# Homogeneous setup: exactly one node type and one edge type are sampled.
node_spec, edge_spec = NodeType.PUBLICATION, EdgeType.SIM_AUTHOR

# GraphSampling receives the specs as lists, so the single types are wrapped.
gs = GraphSampling(
    node_spec=[node_spec],
    edge_spec=[edge_spec],
    node_properties=node_properties,
    database=database,
)

# --- Model / training configuration ---
# Everything in `config` is later written to disk together with the results,
# so the key order below is kept stable.
config = dict(
    experiment='GATv2 encoder (with linear layer + dropout) trained on small graph (publication nodes with title and abstract, sim_author edges) using Triplet Loss and full embeddings',
    max_hops=2,
    model_node_feature='feature_vec',  # Node feature to use for GAT encoder
    hidden_channels=64,
    out_channels=16,
    num_heads=8,
    dropout_p=0.4,
    margin=1.0,
    optimizer='Adam',
    learning_rate=0.005,
    weight_decay=5e-4,
    num_epochs=10,
    batch_size=32,
)

# Encoder class to instantiate and the triplet loss sharing the configured margin.
model_class = HomoGATv2Encoder
loss_fn = TripletMarginLoss(margin=config['margin'])

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# TODO: Adjust result folder name!
result_folder_name = 'homogeneous (similar co-authors) full_emb linear_layer dropout small_graph'
result_folder_path = f'./data/results/{result_folder_name}'
# os.makedirs also creates missing parent folders (./data/results) and, with
# exist_ok=True, does nothing when the folder already exists. A plain os.mkdir
# raises FileNotFoundError on a fresh checkout where ./data/results is absent.
os.makedirs(result_folder_path, exist_ok=True)

Using default edge type: SimilarAuthor for homogeneous graph sampling.


In [3]:
# Open the connection to the database chosen in the configuration cell.
db = DatabaseWrapper(database=database)

# Harvester that provides the (anchor, positive, negative) triplets through its
# `triplets` attribute, which the split below reads.
# NOTE(review): judging by the logged output it first tries to load triplets from
# the save file and generates them otherwise - confirm in TripletDataHarvester.
data_harvester = TripletDataHarvester(
    db=db,
    gs=gs,
    edge_spec=[edge_spec],
    config=config,
    valid_triplets_save_file='valid_triplets_small_graph_simauthor',
    transformer_model='sentence-transformers/all-MiniLM-L6-v2',
)


# Split the triplets into train / test / eval: 85% / 10% / whatever remains.
all_triplets = data_harvester.triplets
num_triplets = len(all_triplets)

train_size = int(0.85 * num_triplets)
test_size = int(0.1 * num_triplets)
eval_size = num_triplets - train_size - test_size

# The evaluation split is cut from the front BEFORE shuffling. According to the
# original author's note the triplets are ordered by author, so this is meant to
# keep the evaluation authors out of the training set.
# NOTE(review): an author whose triplets straddle index `eval_size` would still
# appear on both sides - confirm against the harvester's ordering.
eval_triplets, train_test_triplets = all_triplets[:eval_size], all_triplets[eval_size:]

# Only the train/test remainder is shuffled (in place).
random.shuffle(train_test_triplets)
train_triplets, test_triplets = train_test_triplets[:train_size], train_test_triplets[train_size:]

# Store the actual split sizes in the config.
config.update(
    train_size=len(train_triplets),
    test_size=len(test_triplets),
    eval_size=len(eval_triplets),
)

print(f"Train size: {len(train_triplets)}, Test size: {len(test_triplets)}, Eval size: {len(eval_triplets)}")

# Wrap each (disjoint) triplet split in its own dataset; all three share the
# same graph sampler and config.
train_dataset, test_dataset, eval_dataset = (
    HomogeneousGraphTripletDataset(split_triplets, gs, config=config)
    for split_triplets in (train_triplets, test_triplets, eval_triplets)
)

# Options common to all loaders; only the training loader shuffles.
# NOTE(review): torch_geometric.loader.DataLoader installs its own Collater and,
# in the PyG 2.x sources, drops a user-supplied `collate_fn` keyword. If that
# applies to the installed version, custom_triplet_collate is never called here -
# TODO confirm, and switch to torch.utils.data.DataLoader if it is required.
loader_options = dict(batch_size=config['batch_size'], collate_fn=custom_triplet_collate)
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_options)
test_dataloader = DataLoader(test_dataset, shuffle=False, **loader_options)
eval_dataloader = DataLoader(eval_dataset, shuffle=False, **loader_options)

# Create model
metadata = (
    node_spec.value,
    edge_spec.value
)
# Record the node/edge type values in the config.
config['node_spec'] = metadata[0]
config['edge_spec'] = metadata[1]

# Instantiate the encoder from the configured sizes and move it to the chosen device.
model = model_class(
    config['hidden_channels'],
    config['out_channels'],
    num_heads=config['num_heads'],
    dropout_p=config['dropout_p']
).to(device)

# Build the optimizer named in the config, so the logged 'optimizer' entry always
# matches what was actually used (previously Adam was hard-coded regardless).
optimizer_classes = {'Adam': Adam, 'SGD': SGD}
if config['optimizer'] not in optimizer_classes:
    raise ValueError(
        f"Unsupported optimizer {config['optimizer']!r}; expected one of {sorted(optimizer_classes)}"
    )
optimizer = optimizer_classes[config['optimizer']](
    model.parameters(),
    lr=config['learning_rate'],
    weight_decay=config['weight_decay']
)

2024-12-05 08:20:37,381 - DatabaseWrapper - INFO - Connecting to the database ...
2024-12-05 08:20:37,382 - DatabaseWrapper - INFO - Database ready.


Preparing triplets...
Loading triplets...
Could not load triplets from file. Generating triplets...
Checking data validity...
Out of 8865 checked papers, 5525 are valid and 3340 are invalid.
Generating hard triplets ...




Total triplets generated: 5026. Done.
Generated 5026 triplets.
Saving triplets...
Triplets saved.
Train size: 4272, Test size: 502, Eval size: 252


In [4]:
# Training loop.
#
# For every epoch the model is trained batch by batch. At the first batch and at
# the batch in the middle of the epoch (before the training step of that batch)
# the model is scored on the test and the eval loader, the metric histories in
# `results` are extended, the plots are redrawn and config + results are written
# to training_data.json. A checkpoint is written whenever the test loss of such
# an evaluation is lower than every earlier test loss.
num_epochs = config['num_epochs']
results = defaultdict(list)

# Suffixes of the `results` keys, in the order in which test_and_eval returns
# the corresponding values.
_METRIC_KEYS = (
    'total_loss',
    'accuracies',
    'accuracies_correct_pos',
    'accuracies_correct_neg',
    'precision',
    'recall',
    'F1',
)


def _evaluate_and_record(split, title, dataloader):
    """Score the model on one split, append the metrics to `results` and redraw its plots.

    Args:
        split: Key/file prefix of the split, 'test' or 'eval'.
        title: Human-readable split name used in the plot titles.
        dataloader: Loader passed on to test_and_eval.

    Returns:
        The values returned by test_and_eval, unchanged.
    """
    print(f"    {split.capitalize()} Results:")
    metrics = test_and_eval(
        model=model,
        loss_fn=loss_fn,
        dataloader=dataloader,
        margin=config['margin']
    )
    for key, value in zip(_METRIC_KEYS, metrics):
        results[f'{split}_{key}'].append(value)

    plot_losses(
        losses=[results[f'{split}_total_loss']],
        epoch_len=2,
        plot_title=f'{title} Loss',
        plot_file=result_folder_path + f'/{split}_loss.png',
        line_labels=["Triplet Loss"]
    )

    plot_losses(
        [
            results[f'{split}_accuracies'],
            results[f'{split}_accuracies_correct_pos'],
            results[f'{split}_accuracies_correct_neg'],
        ],
        epoch_len=2,
        plot_title=f'{title} Accuracy',
        x_label=f'{split.capitalize()} Iterations',
        y_label='Accuracy',
        line_labels=['Total Accuracy', 'Correct Pos', 'Correct Neg'],
        plot_file=result_folder_path + f'/{split}_accuracy.png'
    )
    return metrics


current_batch = 1

for epoch in range(1, num_epochs + 1):
    print(f"=== Epoch {epoch}/{num_epochs} ======================")
    # Not read anywhere in this cell; kept in case a later cell relies on it.
    epoch_marker_pos = list(range(0, len(train_dataloader) * epoch, len(train_dataloader)))
    current_batch = 1
    for batch_anchor, batch_pos, batch_neg in tqdm(train_dataloader, desc=f"Epoch {epoch}/{num_epochs}"):
        # Skip incomplete batches; the batch counter is deliberately not advanced,
        # so a pending evaluation is carried over to the next usable batch.
        if batch_anchor is None or batch_pos is None or batch_neg is None:
            continue

        if current_batch == 1 or current_batch == len(train_dataloader) // 2:
            print(f"___ Current Batch {current_batch}/{len(train_dataloader)} _________________________")
            # Model testing
            (test_loss, test_num_correct, test_correct_pos_val, test_correct_neg_val,
             test_precision, test_recall, test_F1) = _evaluate_and_record('test', 'Test', test_dataloader)

            # Model evaluation
            (eval_loss, eval_num_correct, eval_correct_pos_val, eval_correct_neg_val,
             eval_precision, eval_recall, eval_F1) = _evaluate_and_record('eval', 'Evaluation', eval_dataloader)

            # Save config and training results
            config['results'] = results
            save_dict_to_json(config, result_folder_path + '/training_data.json')

            # Save the model if the test loss has decreased. This is done here,
            # before the next training step, so the stored weights are exactly the
            # ones that produced the improved loss. (Saving at the end of the epoch
            # stored weights that had already been trained further.) The very first
            # evaluation only serves as the baseline and is never saved.
            earlier_test_losses = results['test_total_loss'][:-1]
            if earlier_test_losses and test_loss < min(earlier_test_losses):
                print(f"Saving model at epoch {epoch}...")
                torch.save(model.state_dict(), result_folder_path + '/gat_encoder.pt')

        loss = train(
            model=model,
            loss_fn=loss_fn,
            batch_anchor=batch_anchor,
            batch_pos=batch_pos,
            batch_neg=batch_neg,
            optimizer=optimizer
        )
        results['train_loss'].append(loss)

        plot_loss(results['train_loss'], epoch_len=len(train_dataloader), plot_title='Training Loss', plot_avg=True, plot_file=result_folder_path + '/train_loss.png')
        current_batch += 1



Epoch 1/10:   0%|          | 0/134 [00:00<?, ?it/s]

___ Current Batch 1/134 _________________________
    Test Results:
        Correct positive: 502 (100.00%), Correct negative: 0 (0.00%)
        Total correct: 502 (50.00%)
        Test/Eval Loss: 1.0058, Test/Eval Accuracy: 0.5000
        Precision: 0.5000, Recall: 1.0000, F1: 0.6667
    Eval Results:
        Correct positive: 252 (100.00%), Correct negative: 0 (0.00%)
        Total correct: 252 (50.00%)
        Test/Eval Loss: 0.9913, Test/Eval Accuracy: 0.5000
        Precision: 0.5000, Recall: 1.0000, F1: 0.6667
___ Current Batch 67/134 _________________________
    Test Results:
        Correct positive: 410 (81.67%), Correct negative: 303 (60.36%)
        Total correct: 713 (71.02%)
        Test/Eval Loss: 0.5516, Test/Eval Accuracy: 0.7102
        Precision: 0.6732, Recall: 0.8167, F1: 0.7381
    Eval Results:
        Correct positive: 199 (78.97%), Correct negative: 102 (40.48%)
        Total correct: 301 (59.72%)
        Test/Eval Loss: 0.7934, Test/Eval Accuracy: 0.5972
     

Epoch 2/10:   0%|          | 0/134 [00:00<?, ?it/s]

___ Current Batch 1/134 _________________________
    Test Results:
        Correct positive: 405 (80.68%), Correct negative: 315 (62.75%)
        Total correct: 720 (71.71%)
        Test/Eval Loss: 0.5354, Test/Eval Accuracy: 0.7171
        Precision: 0.6841, Recall: 0.8068, F1: 0.7404
    Eval Results:
        Correct positive: 104 (41.27%), Correct negative: 115 (45.63%)
        Total correct: 219 (43.45%)
        Test/Eval Loss: 1.1212, Test/Eval Accuracy: 0.4345
        Precision: 0.4315, Recall: 0.4127, F1: 0.4219
___ Current Batch 67/134 _________________________
    Test Results:
        Correct positive: 416 (82.87%), Correct negative: 353 (70.32%)
        Total correct: 769 (76.59%)
        Test/Eval Loss: 0.4582, Test/Eval Accuracy: 0.7659
        Precision: 0.7363, Recall: 0.8287, F1: 0.7798
    Eval Results:
        Correct positive: 150 (59.52%), Correct negative: 126 (50.00%)
        Total correct: 276 (54.76%)
        Test/Eval Loss: 0.9172, Test/Eval Accuracy: 0.5476
 

Epoch 3/10:   0%|          | 0/134 [00:00<?, ?it/s]

___ Current Batch 1/134 _________________________
    Test Results:
        Correct positive: 419 (83.47%), Correct negative: 336 (66.93%)
        Total correct: 755 (75.20%)
        Test/Eval Loss: 0.4612, Test/Eval Accuracy: 0.7520
        Precision: 0.7162, Recall: 0.8347, F1: 0.7709
    Eval Results:
        Correct positive: 194 (76.98%), Correct negative: 121 (48.02%)
        Total correct: 315 (62.50%)
        Test/Eval Loss: 0.8221, Test/Eval Accuracy: 0.6250
        Precision: 0.5969, Recall: 0.7698, F1: 0.6724
___ Current Batch 67/134 _________________________
    Test Results:
        Correct positive: 427 (85.06%), Correct negative: 302 (60.16%)
        Total correct: 729 (72.61%)
        Test/Eval Loss: 0.5206, Test/Eval Accuracy: 0.7261
        Precision: 0.6810, Recall: 0.8506, F1: 0.7564
    Eval Results:
        Correct positive: 115 (45.63%), Correct negative: 141 (55.95%)
        Total correct: 256 (50.79%)
        Test/Eval Loss: 1.0069, Test/Eval Accuracy: 0.5079
 

Epoch 4/10:   0%|          | 0/134 [00:00<?, ?it/s]

___ Current Batch 1/134 _________________________
    Test Results:
        Correct positive: 443 (88.25%), Correct negative: 312 (62.15%)
        Total correct: 755 (75.20%)
        Test/Eval Loss: 0.5058, Test/Eval Accuracy: 0.7520
        Precision: 0.6998, Recall: 0.8825, F1: 0.7806
    Eval Results:
        Correct positive: 175 (69.44%), Correct negative: 121 (48.02%)
        Total correct: 296 (58.73%)
        Test/Eval Loss: 0.8560, Test/Eval Accuracy: 0.5873
        Precision: 0.5719, Recall: 0.6944, F1: 0.6272
___ Current Batch 67/134 _________________________
    Test Results:
        Correct positive: 453 (90.24%), Correct negative: 316 (62.95%)
        Total correct: 769 (76.59%)
        Test/Eval Loss: 0.4188, Test/Eval Accuracy: 0.7659
        Precision: 0.7089, Recall: 0.9024, F1: 0.7940
    Eval Results:
        Correct positive: 139 (55.16%), Correct negative: 136 (53.97%)
        Total correct: 275 (54.56%)
        Test/Eval Loss: 0.9393, Test/Eval Accuracy: 0.5456
 

Epoch 5/10:   0%|          | 0/134 [00:00<?, ?it/s]

___ Current Batch 1/134 _________________________
    Test Results:
        Correct positive: 449 (89.44%), Correct negative: 326 (64.94%)
        Total correct: 775 (77.19%)
        Test/Eval Loss: 0.4303, Test/Eval Accuracy: 0.7719
        Precision: 0.7184, Recall: 0.8944, F1: 0.7968
    Eval Results:
        Correct positive: 195 (77.38%), Correct negative: 156 (61.90%)
        Total correct: 351 (69.64%)
        Test/Eval Loss: 0.5698, Test/Eval Accuracy: 0.6964
        Precision: 0.6701, Recall: 0.7738, F1: 0.7182
___ Current Batch 67/134 _________________________
    Test Results:
        Correct positive: 445 (88.65%), Correct negative: 344 (68.53%)
        Total correct: 789 (78.59%)
        Test/Eval Loss: 0.3961, Test/Eval Accuracy: 0.7859
        Precision: 0.7380, Recall: 0.8865, F1: 0.8054
    Eval Results:
        Correct positive: 139 (55.16%), Correct negative: 213 (84.52%)
        Total correct: 352 (69.84%)
        Test/Eval Loss: 0.5533, Test/Eval Accuracy: 0.6984
 

Epoch 6/10:   0%|          | 0/134 [00:00<?, ?it/s]

___ Current Batch 1/134 _________________________
    Test Results:
        Correct positive: 462 (92.03%), Correct negative: 286 (56.97%)
        Total correct: 748 (74.50%)
        Test/Eval Loss: 0.4659, Test/Eval Accuracy: 0.7450
        Precision: 0.6814, Recall: 0.9203, F1: 0.7831
    Eval Results:
        Correct positive: 169 (67.06%), Correct negative: 221 (87.70%)
        Total correct: 390 (77.38%)
        Test/Eval Loss: 0.3929, Test/Eval Accuracy: 0.7738
        Precision: 0.8450, Recall: 0.6706, F1: 0.7478
___ Current Batch 67/134 _________________________
    Test Results:
        Correct positive: 442 (88.05%), Correct negative: 302 (60.16%)
        Total correct: 744 (74.10%)
        Test/Eval Loss: 0.4819, Test/Eval Accuracy: 0.7410
        Precision: 0.6885, Recall: 0.8805, F1: 0.7727
    Eval Results:
        Correct positive: 147 (58.33%), Correct negative: 152 (60.32%)
        Total correct: 299 (59.33%)
        Test/Eval Loss: 0.8459, Test/Eval Accuracy: 0.5933
 

Epoch 7/10:   0%|          | 0/134 [00:00<?, ?it/s]

___ Current Batch 1/134 _________________________
    Test Results:
        Correct positive: 466 (92.83%), Correct negative: 294 (58.57%)
        Total correct: 760 (75.70%)
        Test/Eval Loss: 0.4570, Test/Eval Accuracy: 0.7570
        Precision: 0.6914, Recall: 0.9283, F1: 0.7925
    Eval Results:
        Correct positive: 207 (82.14%), Correct negative: 83 (32.94%)
        Total correct: 290 (57.54%)
        Test/Eval Loss: 0.8486, Test/Eval Accuracy: 0.5754
        Precision: 0.5505, Recall: 0.8214, F1: 0.6592
___ Current Batch 67/134 _________________________
    Test Results:
        Correct positive: 456 (90.84%), Correct negative: 284 (56.57%)
        Total correct: 740 (73.71%)
        Test/Eval Loss: 0.4986, Test/Eval Accuracy: 0.7371
        Precision: 0.6766, Recall: 0.9084, F1: 0.7755
    Eval Results:
        Correct positive: 209 (82.94%), Correct negative: 100 (39.68%)
        Total correct: 309 (61.31%)
        Test/Eval Loss: 0.7379, Test/Eval Accuracy: 0.6131
  

Epoch 8/10:   0%|          | 0/134 [00:00<?, ?it/s]

___ Current Batch 1/134 _________________________
    Test Results:
        Correct positive: 463 (92.23%), Correct negative: 303 (60.36%)
        Total correct: 766 (76.29%)
        Test/Eval Loss: 0.4419, Test/Eval Accuracy: 0.7629
        Precision: 0.6994, Recall: 0.9223, F1: 0.7955
    Eval Results:
        Correct positive: 107 (42.46%), Correct negative: 216 (85.71%)
        Total correct: 323 (64.09%)
        Test/Eval Loss: 0.6279, Test/Eval Accuracy: 0.6409
        Precision: 0.7483, Recall: 0.4246, F1: 0.5418
___ Current Batch 67/134 _________________________
    Test Results:
        Correct positive: 468 (93.23%), Correct negative: 293 (58.37%)
        Total correct: 761 (75.80%)
        Test/Eval Loss: 0.4268, Test/Eval Accuracy: 0.7580
        Precision: 0.6913, Recall: 0.9323, F1: 0.7939
    Eval Results:
        Correct positive: 226 (89.68%), Correct negative: 179 (71.03%)
        Total correct: 405 (80.36%)
        Test/Eval Loss: 0.3956, Test/Eval Accuracy: 0.8036
 

Epoch 9/10:   0%|          | 0/134 [00:00<?, ?it/s]

___ Current Batch 1/134 _________________________
    Test Results:
        Correct positive: 473 (94.22%), Correct negative: 267 (53.19%)
        Total correct: 740 (73.71%)
        Test/Eval Loss: 0.4527, Test/Eval Accuracy: 0.7371
        Precision: 0.6681, Recall: 0.9422, F1: 0.7818
    Eval Results:
        Correct positive: 190 (75.40%), Correct negative: 129 (51.19%)
        Total correct: 319 (63.29%)
        Test/Eval Loss: 0.6987, Test/Eval Accuracy: 0.6329
        Precision: 0.6070, Recall: 0.7540, F1: 0.6726
___ Current Batch 67/134 _________________________
    Test Results:
        Correct positive: 459 (91.43%), Correct negative: 297 (59.16%)
        Total correct: 756 (75.30%)
        Test/Eval Loss: 0.4794, Test/Eval Accuracy: 0.7530
        Precision: 0.6913, Recall: 0.9143, F1: 0.7873
    Eval Results:
        Correct positive: 204 (80.95%), Correct negative: 118 (46.83%)
        Total correct: 322 (63.89%)
        Test/Eval Loss: 0.7073, Test/Eval Accuracy: 0.6389
 

Epoch 10/10:   0%|          | 0/134 [00:00<?, ?it/s]

___ Current Batch 1/134 _________________________
    Test Results:
        Correct positive: 454 (90.44%), Correct negative: 331 (65.94%)
        Total correct: 785 (78.19%)
        Test/Eval Loss: 0.4034, Test/Eval Accuracy: 0.7819
        Precision: 0.7264, Recall: 0.9044, F1: 0.8057
    Eval Results:
        Correct positive: 169 (67.06%), Correct negative: 231 (91.67%)
        Total correct: 400 (79.37%)
        Test/Eval Loss: 0.4965, Test/Eval Accuracy: 0.7937
        Precision: 0.8895, Recall: 0.6706, F1: 0.7647
___ Current Batch 67/134 _________________________
    Test Results:
        Correct positive: 462 (92.03%), Correct negative: 316 (62.95%)
        Total correct: 778 (77.49%)
        Test/Eval Loss: 0.4306, Test/Eval Accuracy: 0.7749
        Precision: 0.7130, Recall: 0.9203, F1: 0.8035
    Eval Results:
        Correct positive: 172 (68.25%), Correct negative: 159 (63.10%)
        Total correct: 331 (65.67%)
        Test/Eval Loss: 0.6669, Test/Eval Accuracy: 0.6567
 