In [1]:
from util_homogeneous import *
from util import *
from training_homogeneous import *
from gat_models import *

from collections import defaultdict
import random
import numpy as np
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm
from torch_geometric.loader import DataLoader
from torch.optim import Adam, SGD
from torch.nn.modules.loss import TripletMarginLoss

from src.shared.database_wrapper import DatabaseWrapper
from src.shared.graph_schema import *
from src.shared.graph_sampling import GraphSampling

random.seed(40)
np.random.seed(40)
torch.manual_seed(40)
torch.cuda.manual_seed_all(40)

  from tqdm.autonotebook import tqdm, trange


## Configurations

In [2]:
# Graph sampling configurations
node_spec = NodeType.PUBLICATION

edge_spec = EdgeType.SIM_ABSTRACT

node_properties = [
    'id',
    'feature_vec',
]

database = 'small-graph'
gs = GraphSampling(
    node_spec=[node_spec],
    edge_spec=[edge_spec],
    node_properties=node_properties,
    database=database
)

# Model configurations

config = {
    'experiment': 'GATv2 encoder (with linear layer + dropout) trained on small graph (publication nodes with title and abstract, abstract edges) using Triplet Loss and full embeddings',
    'max_hops': 2,
    'model_node_feature': 'feature_vec',  # Node feature to use for GAT encoder
    'hidden_channels': 64,
    'out_channels': 16,
    'num_heads': 8,
    'dropout_p': 0.4,
    'margin': 1.0,
    'optimizer': 'Adam',
    'learning_rate': 0.005,
    'weight_decay': 5e-4,
    'num_epochs': 10,
    'batch_size': 32,
}

model_class = HomoGATv2Encoder
loss_fn = TripletMarginLoss(margin=config['margin'])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# TODO: Adjust result folder name!
result_folder_name = 'homogeneous (abstract) full_emb linear_layer dropout'
result_folder_path = f'./data/results/{result_folder_name}'
if not os.path.exists(result_folder_path):
    os.mkdir(result_folder_path)

Using default edge type: SimilarAbstract for homogeneous graph sampling.


In [3]:
db = DatabaseWrapper(database=database)
data_harvester = TripletDataHarvester(db=db, gs=gs, edge_spec=[edge_spec], config=config, valid_triplets_save_file='valid_triplets_homogeneous_abstract_small_graph', transformer_model='sentence-transformers/all-MiniLM-L6-v2')


# Split the pairs into train and test

train_size = int(0.85 * len(data_harvester.triplets))
test_size = int(0.1 * len(data_harvester.triplets))
eval_size = len(data_harvester.triplets) - train_size - test_size

# Harvest the evaluation triplets first, since triplets are ordered by author. This will ensure that the evaluation set has authors not seen in the training set.
eval_triplets = data_harvester.triplets[:eval_size]

train_test_triplets = data_harvester.triplets[eval_size:]
random.shuffle(train_test_triplets)

train_triplets = train_test_triplets[:train_size]
test_triplets = train_test_triplets[train_size:]
config['train_size'] = len(train_triplets)
config['test_size'] = len(test_triplets)
config['eval_size'] = len(eval_triplets)

print(f"Train size: {len(train_triplets)}, Test size: {len(test_triplets)}, Eval size: {len(eval_triplets)}")

# Create the datasets from the pairs (distinct pairs for training and testing)
train_dataset = HomogeneousGraphTripletDataset(train_triplets, gs, config=config)
test_dataset = HomogeneousGraphTripletDataset(test_triplets, gs, config=config)
eval_dataset = HomogeneousGraphTripletDataset(eval_triplets, gs, config=config)

# Create the DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, collate_fn=custom_triplet_collate)
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, collate_fn=custom_triplet_collate)
eval_dataloader = DataLoader(eval_dataset, batch_size=config['batch_size'], shuffle=False, collate_fn=custom_triplet_collate)

# Create model
metadata = (
    node_spec.value,
    edge_spec.value
)
config['node_spec'] = metadata[0]
config['edge_spec'] = metadata[1]
model = model_class(config['hidden_channels'], config['out_channels'], num_heads=config['num_heads'], dropout_p=config['dropout_p']).to(device)
#optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer = Adam(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])

2024-12-05 08:23:52,705 - DatabaseWrapper - INFO - Connecting to the database ...
2024-12-05 08:23:52,706 - DatabaseWrapper - INFO - Database ready.


Preparing triplets...
Loading triplets...
Could not load triplets from file. Generating triplets...
Checking data validity...
Out of 8865 checked papers, 3301 are valid and 5564 are invalid.
Generating hard triplets ...




Too few triplets generated: 2575. Generating more triplets...
Total triplets generated: 5150. Done.
Generated 5150 triplets.
Saving triplets...
Triplets saved.
Train size: 4377, Test size: 515, Eval size: 258


In [4]:
num_epochs = config['num_epochs']
results = defaultdict(list)

current_batch = 1

for epoch in range(1, num_epochs + 1):
    print(f"=== Epoch {epoch}/{num_epochs} ======================")
    epoch_marker_pos = list(range(0, len(train_dataloader) * epoch, len(train_dataloader)))
    current_batch = 1
    for batch_anchor, batch_pos, batch_neg in tqdm(train_dataloader, desc=f"Epoch {epoch}/{num_epochs}"):
        if batch_anchor is None or batch_pos is None or batch_neg is None:
            continue
        
        if current_batch == 1 or current_batch == len(train_dataloader) // 2:
            print(f"___ Current Batch {current_batch}/{len(train_dataloader)} _________________________")
            # Model testing
            print("    Test Results:")
            test_loss, test_num_correct, test_correct_pos_val, test_correct_neg_val, test_precision, test_recall, test_F1 = test_and_eval(
                model=model,
                loss_fn=loss_fn,
                dataloader=test_dataloader,
                margin=config['margin']
            )
            results['test_total_loss'].append(test_loss)
            results['test_accuracies'].append(test_num_correct)
            results['test_accuracies_correct_pos'].append(test_correct_pos_val)
            results['test_accuracies_correct_neg'].append(test_correct_neg_val)
            results['test_precision'].append(test_precision)
            results['test_recall'].append(test_recall)
            results['test_F1'].append(test_F1)
    
            plot_losses(
                losses=[results['test_total_loss']], 
                epoch_len=2, 
                plot_title='Test Loss', 
                plot_file=result_folder_path + '/test_loss.png', 
                line_labels=["Triplet Loss"]
            )
            
            plot_losses(
                [results['test_accuracies'], results['test_accuracies_correct_pos'], results['test_accuracies_correct_neg']],
                epoch_len=2,
                plot_title='Test Accuracy',
                x_label='Test Iterations',
                y_label='Accuracy',
                line_labels=['Total Accuracy', 'Correct Pos', 'Correct Neg'],
                plot_file=result_folder_path + '/test_accuracy.png'
            )
            
            # Model evaluation
            print("    Eval Results:")
            eval_loss, eval_num_correct, eval_correct_pos_val, eval_correct_neg_val, eval_precision, eval_recall, eval_F1 = test_and_eval(
                model=model,
                loss_fn=loss_fn,
                dataloader=eval_dataloader,
                margin=config['margin']
            )
            results['eval_total_loss'].append(eval_loss)
            results['eval_accuracies'].append(eval_num_correct)
            results['eval_accuracies_correct_pos'].append(eval_correct_pos_val)
            results['eval_accuracies_correct_neg'].append(eval_correct_neg_val)
            results['eval_precision'].append(eval_precision)
            results['eval_recall'].append(eval_recall)
            results['eval_F1'].append(eval_F1)
            
            plot_losses(
                losses=[results['eval_total_loss']], 
                epoch_len=2, 
                plot_title='Evaluation Loss', 
                plot_file=result_folder_path + '/eval_loss.png', 
                line_labels=["Triplet Loss"]
            )
            
            plot_losses(
                [results['eval_accuracies'], results['eval_accuracies_correct_pos'], results['eval_accuracies_correct_neg']], 
                epoch_len=2, 
                plot_title='Evaluation Accuracy', 
                x_label='Eval Iterations',
                y_label='Accuracy',
                line_labels=['Total Accuracy', 'Correct Pos', 'Correct Neg'],
                plot_file=result_folder_path + '/eval_accuracy.png'
            )
            
            # Save config and training results
            config['results'] = results
            save_dict_to_json(config, result_folder_path + '/training_data.json')
            
        loss = train(
            model=model,
            loss_fn=loss_fn,
            batch_anchor=batch_anchor,
            batch_pos=batch_pos,
            batch_neg=batch_neg,
            optimizer=optimizer
        )
        results['train_loss'].append(loss)
        
        plot_loss(results['train_loss'], epoch_len=len(train_dataloader), plot_title='Training Loss', plot_avg=True, plot_file=result_folder_path + '/train_loss.png')
        current_batch += 1
        
    # Save model if loss has decreased
    if len(results['test_total_loss']) > 1 and results['test_total_loss'][-1] < min(results['test_total_loss'][:-1]):
        print(f"Saving model at epoch {epoch}...")
        torch.save(model.state_dict(), result_folder_path + '/gat_encoder.pt')



Epoch 1/10:   0%|          | 0/137 [00:00<?, ?it/s]

___ Current Batch 1/137 _________________________
    Test Results:
        Correct positive: 515 (100.00%), Correct negative: 0 (0.00%)
        Total correct: 515 (50.00%)
        Test/Eval Loss: 1.0172, Test/Eval Accuracy: 0.5000
        Precision: 0.5000, Recall: 1.0000, F1: 0.6667
    Eval Results:
        Correct positive: 258 (100.00%), Correct negative: 0 (0.00%)
        Total correct: 258 (50.00%)
        Test/Eval Loss: 1.0544, Test/Eval Accuracy: 0.5000
        Precision: 0.5000, Recall: 1.0000, F1: 0.6667
___ Current Batch 68/137 _________________________
    Test Results:
        Correct positive: 434 (84.27%), Correct negative: 391 (75.92%)
        Total correct: 825 (80.10%)
        Test/Eval Loss: 0.3927, Test/Eval Accuracy: 0.8010
        Precision: 0.7778, Recall: 0.8427, F1: 0.8089
    Eval Results:
        Correct positive: 181 (70.16%), Correct negative: 150 (58.14%)
        Total correct: 331 (64.15%)
        Test/Eval Loss: 0.5969, Test/Eval Accuracy: 0.6415
     

Epoch 2/10:   0%|          | 0/137 [00:00<?, ?it/s]

___ Current Batch 1/137 _________________________
    Test Results:
        Correct positive: 467 (90.68%), Correct negative: 313 (60.78%)
        Total correct: 780 (75.73%)
        Test/Eval Loss: 0.4335, Test/Eval Accuracy: 0.7573
        Precision: 0.6981, Recall: 0.9068, F1: 0.7889
    Eval Results:
        Correct positive: 180 (69.77%), Correct negative: 165 (63.95%)
        Total correct: 345 (66.86%)
        Test/Eval Loss: 0.5700, Test/Eval Accuracy: 0.6686
        Precision: 0.6593, Recall: 0.6977, F1: 0.6780
___ Current Batch 68/137 _________________________
    Test Results:
        Correct positive: 472 (91.65%), Correct negative: 332 (64.47%)
        Total correct: 804 (78.06%)
        Test/Eval Loss: 0.3971, Test/Eval Accuracy: 0.7806
        Precision: 0.7206, Recall: 0.9165, F1: 0.8068
    Eval Results:
        Correct positive: 183 (70.93%), Correct negative: 210 (81.40%)
        Total correct: 393 (76.16%)
        Test/Eval Loss: 0.4407, Test/Eval Accuracy: 0.7616
 

Epoch 3/10:   0%|          | 0/137 [00:00<?, ?it/s]

___ Current Batch 1/137 _________________________
    Test Results:
        Correct positive: 481 (93.40%), Correct negative: 328 (63.69%)
        Total correct: 809 (78.54%)
        Test/Eval Loss: 0.3793, Test/Eval Accuracy: 0.7854
        Precision: 0.7201, Recall: 0.9340, F1: 0.8132
    Eval Results:
        Correct positive: 193 (74.81%), Correct negative: 165 (63.95%)
        Total correct: 358 (69.38%)
        Test/Eval Loss: 0.5137, Test/Eval Accuracy: 0.6938
        Precision: 0.6748, Recall: 0.7481, F1: 0.7096
___ Current Batch 68/137 _________________________
    Test Results:
        Correct positive: 489 (94.95%), Correct negative: 282 (54.76%)
        Total correct: 771 (74.85%)
        Test/Eval Loss: 0.4602, Test/Eval Accuracy: 0.7485
        Precision: 0.6773, Recall: 0.9495, F1: 0.7906
    Eval Results:
        Correct positive: 201 (77.91%), Correct negative: 163 (63.18%)
        Total correct: 364 (70.54%)
        Test/Eval Loss: 0.4741, Test/Eval Accuracy: 0.7054
 

Epoch 4/10:   0%|          | 0/137 [00:00<?, ?it/s]

___ Current Batch 1/137 _________________________
    Test Results:
        Correct positive: 482 (93.59%), Correct negative: 313 (60.78%)
        Total correct: 795 (77.18%)
        Test/Eval Loss: 0.4454, Test/Eval Accuracy: 0.7718
        Precision: 0.7047, Recall: 0.9359, F1: 0.8040
    Eval Results:
        Correct positive: 208 (80.62%), Correct negative: 166 (64.34%)
        Total correct: 374 (72.48%)
        Test/Eval Loss: 0.4610, Test/Eval Accuracy: 0.7248
        Precision: 0.6933, Recall: 0.8062, F1: 0.7455
___ Current Batch 68/137 _________________________
    Test Results:
        Correct positive: 477 (92.62%), Correct negative: 367 (71.26%)
        Total correct: 844 (81.94%)
        Test/Eval Loss: 0.3147, Test/Eval Accuracy: 0.8194
        Precision: 0.7632, Recall: 0.9262, F1: 0.8368
    Eval Results:
        Correct positive: 185 (71.71%), Correct negative: 203 (78.68%)
        Total correct: 388 (75.19%)
        Test/Eval Loss: 0.4513, Test/Eval Accuracy: 0.7519
 

Epoch 5/10:   0%|          | 0/137 [00:00<?, ?it/s]

___ Current Batch 1/137 _________________________
    Test Results:
        Correct positive: 491 (95.34%), Correct negative: 303 (58.83%)
        Total correct: 794 (77.09%)
        Test/Eval Loss: 0.3865, Test/Eval Accuracy: 0.7709
        Precision: 0.6984, Recall: 0.9534, F1: 0.8062
    Eval Results:
        Correct positive: 199 (77.13%), Correct negative: 201 (77.91%)
        Total correct: 400 (77.52%)
        Test/Eval Loss: 0.3691, Test/Eval Accuracy: 0.7752
        Precision: 0.7773, Recall: 0.7713, F1: 0.7743
___ Current Batch 68/137 _________________________
    Test Results:
        Correct positive: 490 (95.15%), Correct negative: 299 (58.06%)
        Total correct: 789 (76.60%)
        Test/Eval Loss: 0.4170, Test/Eval Accuracy: 0.7660
        Precision: 0.6941, Recall: 0.9515, F1: 0.8026
    Eval Results:
        Correct positive: 183 (70.93%), Correct negative: 208 (80.62%)
        Total correct: 391 (75.78%)
        Test/Eval Loss: 0.4727, Test/Eval Accuracy: 0.7578
 

Epoch 6/10:   0%|          | 0/137 [00:00<?, ?it/s]

___ Current Batch 1/137 _________________________
    Test Results:
        Correct positive: 497 (96.50%), Correct negative: 280 (54.37%)
        Total correct: 777 (75.44%)
        Test/Eval Loss: 0.4225, Test/Eval Accuracy: 0.7544
        Precision: 0.6790, Recall: 0.9650, F1: 0.7971
    Eval Results:
        Correct positive: 220 (85.27%), Correct negative: 191 (74.03%)
        Total correct: 411 (79.65%)
        Test/Eval Loss: 0.4008, Test/Eval Accuracy: 0.7965
        Precision: 0.7666, Recall: 0.8527, F1: 0.8073
___ Current Batch 68/137 _________________________
    Test Results:
        Correct positive: 483 (93.79%), Correct negative: 354 (68.74%)
        Total correct: 837 (81.26%)
        Test/Eval Loss: 0.3330, Test/Eval Accuracy: 0.8126
        Precision: 0.7500, Recall: 0.9379, F1: 0.8335
    Eval Results:
        Correct positive: 221 (85.66%), Correct negative: 183 (70.93%)
        Total correct: 404 (78.29%)
        Test/Eval Loss: 0.2988, Test/Eval Accuracy: 0.7829
 

Epoch 7/10:   0%|          | 0/137 [00:00<?, ?it/s]

___ Current Batch 1/137 _________________________
    Test Results:
        Correct positive: 483 (93.79%), Correct negative: 305 (59.22%)
        Total correct: 788 (76.50%)
        Test/Eval Loss: 0.4062, Test/Eval Accuracy: 0.7650
        Precision: 0.6970, Recall: 0.9379, F1: 0.7997
    Eval Results:
        Correct positive: 224 (86.82%), Correct negative: 174 (67.44%)
        Total correct: 398 (77.13%)
        Test/Eval Loss: 0.4194, Test/Eval Accuracy: 0.7713
        Precision: 0.7273, Recall: 0.8682, F1: 0.7915
___ Current Batch 68/137 _________________________
    Test Results:
        Correct positive: 475 (92.23%), Correct negative: 304 (59.03%)
        Total correct: 779 (75.63%)
        Test/Eval Loss: 0.4108, Test/Eval Accuracy: 0.7563
        Precision: 0.6924, Recall: 0.9223, F1: 0.7910
    Eval Results:
        Correct positive: 209 (81.01%), Correct negative: 170 (65.89%)
        Total correct: 379 (73.45%)
        Test/Eval Loss: 0.5138, Test/Eval Accuracy: 0.7345
 

Epoch 8/10:   0%|          | 0/137 [00:00<?, ?it/s]

___ Current Batch 1/137 _________________________
    Test Results:
        Correct positive: 476 (92.43%), Correct negative: 349 (67.77%)
        Total correct: 825 (80.10%)
        Test/Eval Loss: 0.3316, Test/Eval Accuracy: 0.8010
        Precision: 0.7414, Recall: 0.9243, F1: 0.8228
    Eval Results:
        Correct positive: 193 (74.81%), Correct negative: 163 (63.18%)
        Total correct: 356 (68.99%)
        Test/Eval Loss: 0.5075, Test/Eval Accuracy: 0.6899
        Precision: 0.6701, Recall: 0.7481, F1: 0.7070
___ Current Batch 68/137 _________________________
    Test Results:
        Correct positive: 485 (94.17%), Correct negative: 291 (56.50%)
        Total correct: 776 (75.34%)
        Test/Eval Loss: 0.3882, Test/Eval Accuracy: 0.7534
        Precision: 0.6841, Recall: 0.9417, F1: 0.7925
    Eval Results:
        Correct positive: 199 (77.13%), Correct negative: 162 (62.79%)
        Total correct: 361 (69.96%)
        Test/Eval Loss: 0.5287, Test/Eval Accuracy: 0.6996
 

Epoch 9/10:   0%|          | 0/137 [00:00<?, ?it/s]

___ Current Batch 1/137 _________________________
    Test Results:
        Correct positive: 468 (90.87%), Correct negative: 356 (69.13%)
        Total correct: 824 (80.00%)
        Test/Eval Loss: 0.4126, Test/Eval Accuracy: 0.8000
        Precision: 0.7464, Recall: 0.9087, F1: 0.8196
    Eval Results:
        Correct positive: 191 (74.03%), Correct negative: 192 (74.42%)
        Total correct: 383 (74.22%)
        Test/Eval Loss: 0.5208, Test/Eval Accuracy: 0.7422
        Precision: 0.7432, Recall: 0.7403, F1: 0.7417
___ Current Batch 68/137 _________________________
    Test Results:
        Correct positive: 482 (93.59%), Correct negative: 347 (67.38%)
        Total correct: 829 (80.49%)
        Test/Eval Loss: 0.3318, Test/Eval Accuracy: 0.8049
        Precision: 0.7415, Recall: 0.9359, F1: 0.8275
    Eval Results:
        Correct positive: 210 (81.40%), Correct negative: 165 (63.95%)
        Total correct: 375 (72.67%)
        Test/Eval Loss: 0.4602, Test/Eval Accuracy: 0.7267
 

Epoch 10/10:   0%|          | 0/137 [00:00<?, ?it/s]

___ Current Batch 1/137 _________________________
    Test Results:
        Correct positive: 499 (96.89%), Correct negative: 296 (57.48%)
        Total correct: 795 (77.18%)
        Test/Eval Loss: 0.4240, Test/Eval Accuracy: 0.7718
        Precision: 0.6950, Recall: 0.9689, F1: 0.8094
    Eval Results:
        Correct positive: 230 (89.15%), Correct negative: 158 (61.24%)
        Total correct: 388 (75.19%)
        Test/Eval Loss: 0.4837, Test/Eval Accuracy: 0.7519
        Precision: 0.6970, Recall: 0.8915, F1: 0.7823
___ Current Batch 68/137 _________________________
    Test Results:
        Correct positive: 485 (94.17%), Correct negative: 355 (68.93%)
        Total correct: 840 (81.55%)
        Test/Eval Loss: 0.3399, Test/Eval Accuracy: 0.8155
        Precision: 0.7519, Recall: 0.9417, F1: 0.8362
    Eval Results:
        Correct positive: 235 (91.09%), Correct negative: 171 (66.28%)
        Total correct: 406 (78.68%)
        Test/Eval Loss: 0.4231, Test/Eval Accuracy: 0.7868
 