In [1]:
from util_homogeneous import *
from util import *
from training_homogeneous import *
from gat_models import *

from collections import defaultdict
import random
import numpy as np
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm
from torch_geometric.loader import DataLoader
from torch.optim import Adam, SGD
from torch.nn.modules.loss import TripletMarginLoss

from src.shared.database_wrapper import DatabaseWrapper
from src.shared.graph_schema import *
from src.shared.graph_sampling import GraphSampling

# Seed Python's `random`, NumPy and PyTorch (CPU generator and all CUDA
# devices) with one shared value, in that order.
_RNG_SEED = 40
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(_RNG_SEED)

  from tqdm.autonotebook import tqdm, trange


## Configurations

In [2]:
# Graph sampling setup: one node type and one edge type.
node_spec = NodeType.PUBLICATION
edge_spec = EdgeType.SIM_ABSTRACT

# Node properties handed to the sampler.
node_properties = ['id', 'feature_vec']

# Name of the database the sampler is pointed at.
database = 'small-graph'

# The sampler takes lists of node/edge types, so the single specs are wrapped.
gs = GraphSampling(node_spec=[node_spec], edge_spec=[edge_spec], node_properties=node_properties, database=database)

# Run description and hyperparameters. Later cells read these values and add
# further keys to the same dict.
config = dict(
    experiment='GATv2 encoder (with linear layer + dropout) trained on small graph (publication nodes with title and abstract, abstract edges) using Triplet Loss and full embeddings',
    max_hops=2,
    model_node_feature='feature_vec',  # Node feature to use for GAT encoder
    hidden_channels=32,
    out_channels=8,
    num_heads=8,
    dropout_p=0.4,
    margin=1.0,
    optimizer='Adam',
    learning_rate=0.005,
    weight_decay=5e-4,
    num_epochs=10,
    batch_size=32,
)

# Encoder class, training objective and compute device used by the cells below.
model_class = HomoGATv2Encoder1Conv2Linear
loss_fn = TripletMarginLoss(margin=config['margin'])
# Use CUDA when it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# TODO: Adjust result folder name!
result_folder_name = 'homogeneous (abstract) full_emb linear_layer dropout 1_conv_layer 2_linear'
result_folder_path = f'./data/results/{result_folder_name}'
# makedirs also creates missing parents (e.g. ./data/results on a fresh
# checkout), and exist_ok=True makes re-running the cell a no-op instead of
# relying on a separate exists() check.
os.makedirs(result_folder_path, exist_ok=True)

Using default edge type: SimilarAbstract for homogeneous graph sampling.


In [3]:
# Open the database configured above and build the triplet harvester on top
# of it; the harvester's `triplets` are split into train/test/eval below.
db = DatabaseWrapper(database=database)
data_harvester = TripletDataHarvester(
    db=db,
    gs=gs,
    edge_spec=[edge_spec],
    config=config,
    valid_triplets_save_file='valid_triplets_homogeneous_abstract_small_graph',
    transformer_model='sentence-transformers/all-MiniLM-L6-v2',
)


# Hold out an evaluation set by author: every paper listed under
# 'normal_data' for the first `num_eval_authors` entries of the WhoIsWho
# training data is an evaluation paper, and any triplet touching one of those
# papers goes to the evaluation set. The intent is that evaluation authors are
# never seen during training.
num_eval_authors = 3
train_data = WhoIsWhoDataset.parse_train()

# Slice instead of counting `num_eval_authors` down to zero: the setting keeps
# its value after the cell has run, and a count of 0 selects no authors rather
# than all of them.
eval_author_records = list(train_data.values())[:max(num_eval_authors, 0)]
eval_papers = {
    paper
    for author_record in eval_author_records
    for paper in author_record['normal_data']
}

# Partition the harvested triplets in a single pass (original order is kept in
# both lists). This replaces a list-membership filter that rescanned
# `eval_triplets` once per triplet.
eval_triplets = []
train_test_triplets = []
for triplet in data_harvester.triplets:
    if triplet[0] in eval_papers or triplet[1] in eval_papers or triplet[2] in eval_papers:
        eval_triplets.append(triplet)
    else:
        train_test_triplets.append(triplet)

# Shuffle the remaining triplets in place, then cut them 85% / 15% into the
# training and test sets.
random.shuffle(train_test_triplets)

num_train_test = len(train_test_triplets)
train_size = int(0.85 * num_train_test)
test_size = num_train_test - train_size

train_triplets, test_triplets = train_test_triplets[:train_size], train_test_triplets[train_size:]

# Keep the split sizes with the run configuration.
config.update(
    train_size=len(train_triplets),
    test_size=len(test_triplets),
    eval_size=len(eval_triplets),
)

print(f"Train size: {len(train_triplets)}, Test size: {len(test_triplets)}, Eval size: {len(eval_triplets)}")

# Wrap each triplet split (train, test, eval) in its own graph dataset.
train_dataset, test_dataset, eval_dataset = (
    HomogeneousGraphTripletDataset(split_triplets, gs, config=config)
    for split_triplets in (train_triplets, test_triplets, eval_triplets)
)

# All three loaders share batch size and collate function; only the training
# loader shuffles.
# NOTE(review): torch_geometric.loader.DataLoader may replace a caller-supplied
# `collate_fn` with its own collater - confirm that `custom_triplet_collate`
# is actually applied to these batches.
loader_options = {'batch_size': config['batch_size'], 'collate_fn': custom_triplet_collate}
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_options)
test_dataloader = DataLoader(test_dataset, shuffle=False, **loader_options)
eval_dataloader = DataLoader(eval_dataset, shuffle=False, **loader_options)

# Store the node/edge type values with the run configuration.
metadata = (node_spec.value, edge_spec.value)
config['node_spec'], config['edge_spec'] = metadata

# Instantiate the encoder from the configured sizes and move it to the
# selected device.
model = model_class(
    config['hidden_channels'],
    config['out_channels'],
    num_heads=config['num_heads'],
    dropout_p=config['dropout_p'],
)
model = model.to(device)

# Adam with the configured learning rate and weight decay.
# Alternative optimizer, currently disabled:
#optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer = Adam(
    model.parameters(),
    lr=config['learning_rate'],
    weight_decay=config['weight_decay'],
)

2024-12-12 15:03:43,978 - DatabaseWrapper - INFO - Connecting to the database ...
2024-12-12 15:03:43,979 - DatabaseWrapper - INFO - Database ready.


Preparing triplets...
Loading triplets...
Loaded 5150 triplets.
Train size: 4130, Test size: 729, Eval size: 291


In [4]:
num_epochs = config['num_epochs']
results = defaultdict(list)


def _evaluate_and_plot(split, title, dataloader):
    """Score the current model on one held-out split, log the metrics and redraw its plots.

    Appends the seven values returned by `test_and_eval` to `results` under
    `<split>_...` keys, then rewrites `<split>_loss.png` and
    `<split>_accuracy.png` in the result folder.

    Args:
        split: Key and file-name prefix, 'test' or 'eval'.
        title: Split name used in the plot titles ('Test' / 'Evaluation').
        dataloader: Loader handed to `test_and_eval`.
    """
    print(f"    {split.capitalize()} Results:")
    loss, accuracy_total, accuracy_pos, accuracy_neg, precision, recall, f1 = test_and_eval(
        model=model,
        loss_fn=loss_fn,
        dataloader=dataloader,
        margin=config['margin']
    )
    results[f'{split}_total_loss'].append(loss)
    results[f'{split}_accuracies'].append(accuracy_total)
    results[f'{split}_accuracies_correct_pos'].append(accuracy_pos)
    results[f'{split}_accuracies_correct_neg'].append(accuracy_neg)
    results[f'{split}_precision'].append(precision)
    results[f'{split}_recall'].append(recall)
    results[f'{split}_F1'].append(f1)

    # epoch_len=2: checkpoints are taken at the first and the middle batch of
    # an epoch (presumably epoch_len is the number of points per epoch - confirm
    # against plot_losses).
    plot_losses(
        losses=[results[f'{split}_total_loss']],
        epoch_len=2,
        plot_title=f'{title} Loss',
        plot_file=result_folder_path + f'/{split}_loss.png',
        line_labels=["Triplet Loss"]
    )
    plot_losses(
        [
            results[f'{split}_accuracies'],
            results[f'{split}_accuracies_correct_pos'],
            results[f'{split}_accuracies_correct_neg'],
        ],
        epoch_len=2,
        plot_title=f'{title} Accuracy',
        x_label=f'{split.capitalize()} Iterations',
        y_label='Accuracy',
        line_labels=['Total Accuracy', 'Correct Pos', 'Correct Neg'],
        plot_file=result_folder_path + f'/{split}_accuracy.png'
    )


batches_per_epoch = len(train_dataloader)

for epoch in range(1, num_epochs + 1):
    print(f"=== Epoch {epoch}/{num_epochs} ======================")
    # Counts only batches that are actually trained on; skipped batches do not
    # advance it, so the checkpoints below still fire.
    current_batch = 1
    for batch_anchor, batch_pos, batch_neg in tqdm(train_dataloader, desc=f"Epoch {epoch}/{num_epochs}"):
        if batch_anchor is None or batch_pos is None or batch_neg is None:
            continue

        # Checkpoint before training on the first and on the middle batch of
        # the epoch.
        if current_batch == 1 or current_batch == batches_per_epoch // 2:
            print(f"___ Current Batch {current_batch}/{batches_per_epoch} _________________________")
            _evaluate_and_plot('test', 'Test', test_dataloader)
            _evaluate_and_plot('eval', 'Evaluation', eval_dataloader)

            # Save config and training results
            config['results'] = results
            save_dict_to_json(config, result_folder_path + '/training_data.json')

            # Save the model whenever the evaluation accuracy sets a new best.
            # The very first checkpoint (taken before any training step) only
            # establishes the baseline and never writes a file.
            eval_accuracies = results['eval_accuracies']
            if len(eval_accuracies) > 1 and eval_accuracies[-1] > max(eval_accuracies[:-1]):
                print(f"Saving model at epoch {epoch}...")
                torch.save(model.state_dict(), result_folder_path + '/gat_encoder.pt')

        loss = train(
            model=model,
            loss_fn=loss_fn,
            batch_anchor=batch_anchor,
            batch_pos=batch_pos,
            batch_neg=batch_neg,
            optimizer=optimizer
        )
        results['train_loss'].append(loss)

        # The training-loss plot is redrawn after every batch.
        plot_loss(results['train_loss'], epoch_len=batches_per_epoch, plot_title='Training Loss', plot_avg=True, plot_file=result_folder_path + '/train_loss.png')
        current_batch += 1
        
    



Epoch 1/10:   0%|          | 0/130 [00:00<?, ?it/s]

___ Current Batch 1/130 _________________________
    Test Results:
        Correct positive: 729 (100.00%), Correct negative: 0 (0.00%)
        Total correct: 729 (50.00%)
        Test/Eval Loss: 0.9923, Test/Eval Accuracy: 0.5000
        Precision: 0.5000, Recall: 1.0000, F1: 0.6667
    Eval Results:
        Correct positive: 291 (100.00%), Correct negative: 0 (0.00%)
        Total correct: 291 (50.00%)
        Test/Eval Loss: 0.9911, Test/Eval Accuracy: 0.5000
        Precision: 0.5000, Recall: 1.0000, F1: 0.6667
___ Current Batch 65/130 _________________________
    Test Results:
        Correct positive: 572 (78.46%), Correct negative: 422 (57.89%)
        Total correct: 994 (68.18%)
        Test/Eval Loss: 0.5142, Test/Eval Accuracy: 0.6818
        Precision: 0.6507, Recall: 0.7846, F1: 0.7114
    Eval Results:
        Correct positive: 241 (82.82%), Correct negative: 277 (95.19%)
        Total correct: 518 (89.00%)
        Test/Eval Loss: 0.2273, Test/Eval Accuracy: 0.8900
     

Epoch 2/10:   0%|          | 0/130 [00:00<?, ?it/s]

___ Current Batch 1/130 _________________________
    Test Results:
        Correct positive: 603 (82.72%), Correct negative: 574 (78.74%)
        Total correct: 1177 (80.73%)
        Test/Eval Loss: 0.3715, Test/Eval Accuracy: 0.8073
        Precision: 0.7955, Recall: 0.8272, F1: 0.8110
    Eval Results:
        Correct positive: 241 (82.82%), Correct negative: 277 (95.19%)
        Total correct: 518 (89.00%)
        Test/Eval Loss: 0.2295, Test/Eval Accuracy: 0.8900
        Precision: 0.9451, Recall: 0.8282, F1: 0.8828
___ Current Batch 65/130 _________________________
    Test Results:
        Correct positive: 625 (85.73%), Correct negative: 496 (68.04%)
        Total correct: 1121 (76.89%)
        Test/Eval Loss: 0.3548, Test/Eval Accuracy: 0.7689
        Precision: 0.7284, Recall: 0.8573, F1: 0.7876
    Eval Results:
        Correct positive: 240 (82.47%), Correct negative: 281 (96.56%)
        Total correct: 521 (89.52%)
        Test/Eval Loss: 0.2157, Test/Eval Accuracy: 0.8952

Epoch 3/10:   0%|          | 0/130 [00:00<?, ?it/s]

___ Current Batch 1/130 _________________________
    Test Results:
        Correct positive: 636 (87.24%), Correct negative: 474 (65.02%)
        Total correct: 1110 (76.13%)
        Test/Eval Loss: 0.3903, Test/Eval Accuracy: 0.7613
        Precision: 0.7138, Recall: 0.8724, F1: 0.7852
    Eval Results:
        Correct positive: 241 (82.82%), Correct negative: 277 (95.19%)
        Total correct: 518 (89.00%)
        Test/Eval Loss: 0.2318, Test/Eval Accuracy: 0.8900
        Precision: 0.9451, Recall: 0.8282, F1: 0.8828
___ Current Batch 65/130 _________________________
    Test Results:
        Correct positive: 639 (87.65%), Correct negative: 543 (74.49%)
        Total correct: 1182 (81.07%)
        Test/Eval Loss: 0.3448, Test/Eval Accuracy: 0.8107
        Precision: 0.7745, Recall: 0.8765, F1: 0.8224
    Eval Results:
        Correct positive: 243 (83.51%), Correct negative: 149 (51.20%)
        Total correct: 392 (67.35%)
        Test/Eval Loss: 0.3710, Test/Eval Accuracy: 0.6735

Epoch 4/10:   0%|          | 0/130 [00:00<?, ?it/s]

___ Current Batch 1/130 _________________________
    Test Results:
        Correct positive: 639 (87.65%), Correct negative: 493 (67.63%)
        Total correct: 1132 (77.64%)
        Test/Eval Loss: 0.3622, Test/Eval Accuracy: 0.7764
        Precision: 0.7303, Recall: 0.8765, F1: 0.7968
    Eval Results:
        Correct positive: 243 (83.51%), Correct negative: 129 (44.33%)
        Total correct: 372 (63.92%)
        Test/Eval Loss: 0.4182, Test/Eval Accuracy: 0.6392
        Precision: 0.6000, Recall: 0.8351, F1: 0.6983
___ Current Batch 65/130 _________________________
    Test Results:
        Correct positive: 637 (87.38%), Correct negative: 473 (64.88%)
        Total correct: 1110 (76.13%)
        Test/Eval Loss: 0.3877, Test/Eval Accuracy: 0.7613
        Precision: 0.7133, Recall: 0.8738, F1: 0.7855
    Eval Results:
        Correct positive: 242 (83.16%), Correct negative: 255 (87.63%)
        Total correct: 497 (85.40%)
        Test/Eval Loss: 0.2900, Test/Eval Accuracy: 0.8540

Epoch 5/10:   0%|          | 0/130 [00:00<?, ?it/s]

___ Current Batch 1/130 _________________________
    Test Results:
        Correct positive: 680 (93.28%), Correct negative: 416 (57.06%)
        Total correct: 1096 (75.17%)
        Test/Eval Loss: 0.3626, Test/Eval Accuracy: 0.7517
        Precision: 0.6848, Recall: 0.9328, F1: 0.7898
    Eval Results:
        Correct positive: 260 (89.35%), Correct negative: 21 (7.22%)
        Total correct: 281 (48.28%)
        Test/Eval Loss: 0.6699, Test/Eval Accuracy: 0.4828
        Precision: 0.4906, Recall: 0.8935, F1: 0.6334
___ Current Batch 65/130 _________________________
    Test Results:
        Correct positive: 668 (91.63%), Correct negative: 459 (62.96%)
        Total correct: 1127 (77.30%)
        Test/Eval Loss: 0.3523, Test/Eval Accuracy: 0.7730
        Precision: 0.7122, Recall: 0.9163, F1: 0.8014
    Eval Results:
        Correct positive: 243 (83.51%), Correct negative: 183 (62.89%)
        Total correct: 426 (73.20%)
        Test/Eval Loss: 0.3831, Test/Eval Accuracy: 0.7320
 

Epoch 6/10:   0%|          | 0/130 [00:00<?, ?it/s]

___ Current Batch 1/130 _________________________
    Test Results:
        Correct positive: 631 (86.56%), Correct negative: 574 (78.74%)
        Total correct: 1205 (82.65%)
        Test/Eval Loss: 0.3234, Test/Eval Accuracy: 0.8265
        Precision: 0.8028, Recall: 0.8656, F1: 0.8330
    Eval Results:
        Correct positive: 260 (89.35%), Correct negative: 19 (6.53%)
        Total correct: 279 (47.94%)
        Test/Eval Loss: 1.0000, Test/Eval Accuracy: 0.4794
        Precision: 0.4887, Recall: 0.8935, F1: 0.6318
___ Current Batch 65/130 _________________________
    Test Results:
        Correct positive: 683 (93.69%), Correct negative: 332 (45.54%)
        Total correct: 1015 (69.62%)
        Test/Eval Loss: 0.4621, Test/Eval Accuracy: 0.6962
        Precision: 0.6324, Recall: 0.9369, F1: 0.7551
    Eval Results:
        Correct positive: 239 (82.13%), Correct negative: 271 (93.13%)
        Total correct: 510 (87.63%)
        Test/Eval Loss: 0.2895, Test/Eval Accuracy: 0.8763
 

Epoch 7/10:   0%|          | 0/130 [00:00<?, ?it/s]

___ Current Batch 1/130 _________________________
    Test Results:
        Correct positive: 663 (90.95%), Correct negative: 491 (67.35%)
        Total correct: 1154 (79.15%)
        Test/Eval Loss: 0.3522, Test/Eval Accuracy: 0.7915
        Precision: 0.7358, Recall: 0.9095, F1: 0.8135
    Eval Results:
        Correct positive: 260 (89.35%), Correct negative: 19 (6.53%)
        Total correct: 279 (47.94%)
        Test/Eval Loss: 0.6891, Test/Eval Accuracy: 0.4794
        Precision: 0.4887, Recall: 0.8935, F1: 0.6318
___ Current Batch 65/130 _________________________
    Test Results:
        Correct positive: 704 (96.57%), Correct negative: 351 (48.15%)
        Total correct: 1055 (72.36%)
        Test/Eval Loss: 0.4045, Test/Eval Accuracy: 0.7236
        Precision: 0.6506, Recall: 0.9657, F1: 0.7775
    Eval Results:
        Correct positive: 257 (88.32%), Correct negative: 17 (5.84%)
        Total correct: 274 (47.08%)
        Test/Eval Loss: 0.5868, Test/Eval Accuracy: 0.4708
   

Epoch 8/10:   0%|          | 0/130 [00:00<?, ?it/s]

___ Current Batch 1/130 _________________________
    Test Results:
        Correct positive: 687 (94.24%), Correct negative: 389 (53.36%)
        Total correct: 1076 (73.80%)
        Test/Eval Loss: 0.4220, Test/Eval Accuracy: 0.7380
        Precision: 0.6689, Recall: 0.9424, F1: 0.7825
    Eval Results:
        Correct positive: 260 (89.35%), Correct negative: 21 (7.22%)
        Total correct: 281 (48.28%)
        Test/Eval Loss: 0.6714, Test/Eval Accuracy: 0.4828
        Precision: 0.4906, Recall: 0.8935, F1: 0.6334
___ Current Batch 65/130 _________________________
    Test Results:
        Correct positive: 705 (96.71%), Correct negative: 334 (45.82%)
        Total correct: 1039 (71.26%)
        Test/Eval Loss: 0.4380, Test/Eval Accuracy: 0.7126
        Precision: 0.6409, Recall: 0.9671, F1: 0.7709
    Eval Results:
        Correct positive: 260 (89.35%), Correct negative: 13 (4.47%)
        Total correct: 273 (46.91%)
        Test/Eval Loss: 0.6266, Test/Eval Accuracy: 0.4691
   

Epoch 9/10:   0%|          | 0/130 [00:00<?, ?it/s]

___ Current Batch 1/130 _________________________
    Test Results:
        Correct positive: 705 (96.71%), Correct negative: 329 (45.13%)
        Total correct: 1034 (70.92%)
        Test/Eval Loss: 0.4822, Test/Eval Accuracy: 0.7092
        Precision: 0.6380, Recall: 0.9671, F1: 0.7688
    Eval Results:
        Correct positive: 251 (86.25%), Correct negative: 78 (26.80%)
        Total correct: 329 (56.53%)
        Test/Eval Loss: 0.5239, Test/Eval Accuracy: 0.5653
        Precision: 0.5409, Recall: 0.8625, F1: 0.6649
___ Current Batch 65/130 _________________________
    Test Results:
        Correct positive: 686 (94.10%), Correct negative: 442 (60.63%)
        Total correct: 1128 (77.37%)
        Test/Eval Loss: 0.3590, Test/Eval Accuracy: 0.7737
        Precision: 0.7050, Recall: 0.9410, F1: 0.8061
    Eval Results:
        Correct positive: 260 (89.35%), Correct negative: 21 (7.22%)
        Total correct: 281 (48.28%)
        Test/Eval Loss: 0.8794, Test/Eval Accuracy: 0.4828
  

Epoch 10/10:   0%|          | 0/130 [00:00<?, ?it/s]

___ Current Batch 1/130 _________________________
    Test Results:
        Correct positive: 703 (96.43%), Correct negative: 365 (50.07%)
        Total correct: 1068 (73.25%)
        Test/Eval Loss: 0.3928, Test/Eval Accuracy: 0.7325
        Precision: 0.6589, Recall: 0.9643, F1: 0.7829
    Eval Results:
        Correct positive: 261 (89.69%), Correct negative: 13 (4.47%)
        Total correct: 274 (47.08%)
        Test/Eval Loss: 0.9626, Test/Eval Accuracy: 0.4708
        Precision: 0.4842, Recall: 0.8969, F1: 0.6289
___ Current Batch 65/130 _________________________
    Test Results:
        Correct positive: 698 (95.75%), Correct negative: 400 (54.87%)
        Total correct: 1098 (75.31%)
        Test/Eval Loss: 0.4266, Test/Eval Accuracy: 0.7531
        Precision: 0.6796, Recall: 0.9575, F1: 0.7950
    Eval Results:
        Correct positive: 260 (89.35%), Correct negative: 17 (5.84%)
        Total correct: 277 (47.59%)
        Test/Eval Loss: 0.5847, Test/Eval Accuracy: 0.4759
   