In [1]:
from util_homogeneous import *
from util import *
from training_homogeneous import *
from gat_models import *

from collections import defaultdict
import random
import numpy as np
from tqdm.notebook import tqdm
from torch_geometric.loader import DataLoader
from torch.optim import Adam, SGD
from torch.nn.modules.loss import TripletMarginLoss

from src.shared.database_wrapper import DatabaseWrapper
from src.shared.graph_schema import *
from src.shared.graph_sampling import GraphSampling

# Seed every RNG the notebook touches (Python, NumPy, PyTorch CPU and all CUDA devices)
# so that the triplet shuffle, dropout and weight init are reproducible.
SEED = 40
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

  from tqdm.autonotebook import tqdm, trange


## Configurations

In [2]:
# --- Graph sampling configuration -------------------------------------------
# Homogeneous graph: a single node type connected by a single edge type.
node_spec = NodeType.PUBLICATION
edge_spec = EdgeType.SIM_AUTHOR

# Properties fetched for every sampled node ('feature_vec' is the GAT input feature).
node_properties = ['id', 'feature_vec']

database = 'small-graph'
gs = GraphSampling(
    database=database,
    node_spec=[node_spec],
    edge_spec=[edge_spec],
    node_properties=node_properties,
)

# --- Model / training configuration ---------------------------------------
# The whole dict (plus split sizes and results added later) is written to training_data.json,
# so key order here is also the order in that file.
config = dict(
    experiment='GATv2 encoder (with linear layer + dropout) trained on small graph (publication nodes with title and abstract, sim_author edges) using Triplet Loss and full embeddings',
    max_hops=2,
    model_node_feature='feature_vec',  # Node feature to use for GAT encoder
    hidden_channels=32,
    out_channels=8,
    num_heads=8,
    dropout_p=0.4,
    margin=1.0,
    optimizer='Adam',
    learning_rate=0.005,
    weight_decay=5e-4,
    num_epochs=10,
    batch_size=32,
)

model_class = HomoGATv2Encoder
loss_fn = TripletMarginLoss(margin=config['margin'])
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# TODO: Adjust result folder name!
result_folder_name = 'homogeneous (similar co-authors) full_emb linear_layer dropout small_graph low_dim'
result_folder_path = f'./data/results/{result_folder_name}'
# makedirs also creates ./data/results when it is missing (os.mkdir would raise
# FileNotFoundError on a fresh checkout) and is a no-op when the folder already exists,
# so re-running this cell is safe.
os.makedirs(result_folder_path, exist_ok=True)

Using default edge type: SimilarAuthor for homogeneous graph sampling.


In [3]:
# Connect to the graph database and set up the triplet harvester; the harvested
# (anchor, positive, negative) triplets are read from `data_harvester.triplets` below.
db = DatabaseWrapper(database=database)
data_harvester = TripletDataHarvester(
    db=db,
    gs=gs,
    edge_spec=[edge_spec],
    config=config,
    valid_triplets_save_file='valid_triplets_small_graph_simauthor',
    transformer_model='sentence-transformers/all-MiniLM-L6-v2',
)


# Split the triplets into train / test / eval sets.
#
# The eval set is harvested first: it receives every triplet whose anchor, positive or
# negative is one of the papers (the 'normal_data' entries) of the first `num_eval_authors`
# authors in the WhoIsWho training data. The intent is an eval set built from authors that
# are not seen during training. The remaining triplets are shuffled and split into train/test.
num_eval_authors = 3
train_fraction = 0.85  # share of the non-eval triplets used for training; the rest is test

train_data = WhoIsWhoDataset.parse_train()
eval_authors = list(train_data.values())[:num_eval_authors]
eval_papers = {paper for author in eval_authors for paper in author['normal_data']}

# Partition in a single pass (set lookups, O(n)); relative order is preserved, so the
# shuffle below sees exactly the same list as a filter-based approach would.
eval_triplets = []
train_test_triplets = []
for triplet in data_harvester.triplets:
    if triplet[0] in eval_papers or triplet[1] in eval_papers or triplet[2] in eval_papers:
        eval_triplets.append(triplet)
    else:
        train_test_triplets.append(triplet)

random.shuffle(train_test_triplets)

train_size = int(train_fraction * len(train_test_triplets))
test_size = len(train_test_triplets) - train_size
eval_size = len(eval_triplets)

train_triplets = train_test_triplets[:train_size]
test_triplets = train_test_triplets[train_size:]

# Every triplet must land in exactly one split.
assert len(train_triplets) + len(test_triplets) + len(eval_triplets) == len(data_harvester.triplets)

config['train_size'] = len(train_triplets)
config['test_size'] = len(test_triplets)
config['eval_size'] = len(eval_triplets)

print(f"Train size: {len(train_triplets)}, Test size: {len(test_triplets)}, Eval size: {len(eval_triplets)}")

# Wrap each disjoint triplet split in a graph dataset; only the training loader shuffles.
train_dataset, test_dataset, eval_dataset = (
    HomogeneousGraphTripletDataset(split_triplets, gs, config=config)
    for split_triplets in (train_triplets, test_triplets, eval_triplets)
)

loader_kwargs = dict(batch_size=config['batch_size'], collate_fn=custom_triplet_collate)
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_dataloader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)
eval_dataloader = DataLoader(eval_dataset, shuffle=False, **loader_kwargs)

# Create model.
# Record the node/edge types the model is trained on (they are saved with `config`).
metadata = (
    node_spec.value,
    edge_spec.value
)
config['node_spec'] = metadata[0]
config['edge_spec'] = metadata[1]

model = model_class(
    config['hidden_channels'],
    config['out_channels'],
    num_heads=config['num_heads'],
    dropout_p=config['dropout_p'],
).to(device)

# config['optimizer'] is persisted in training_data.json, so honour it instead of
# hard-coding Adam (which would make the saved config lie if it were changed to 'SGD').
optimizer_classes = {'Adam': Adam, 'SGD': SGD}
if config['optimizer'] not in optimizer_classes:
    raise ValueError(
        f"Unsupported optimizer {config['optimizer']!r}; expected one of {sorted(optimizer_classes)}"
    )
optimizer = optimizer_classes[config['optimizer']](
    model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay']
)

2024-12-11 19:37:18,567 - DatabaseWrapper - INFO - Connecting to the database ...
2024-12-11 19:37:18,567 - DatabaseWrapper - INFO - Database ready.


Preparing triplets...
Loading triplets...
Loaded 5026 triplets.
Train size: 3995, Test size: 706, Eval size: 325


In [4]:
# Training loop.
# At a fixed set of batches per epoch (the first and the middle one) the model is scored on
# the test split and on the eval split. Metrics are appended to `results`, plotted, and
# written to training_data.json together with `config`. One more scoring pass runs after the
# last epoch, so the training done after the final mid-epoch check is also evaluated and can
# be checkpointed.
num_epochs = config['num_epochs']
results = defaultdict(list)

# 1-based batch positions at which the model is scored. max() keeps the two positions
# distinct-or-merged correctly for dataloaders with fewer than two batches.
eval_batches = {1, max(1, len(train_dataloader) // 2)}
evals_per_epoch = len(eval_batches)  # x-axis spacing of the test/eval plots


def _score_split(prefix, label, long_label, dataloader):
    """Score `model` on one split, append its metrics to `results` and refresh its plots.

    prefix: key and file-name prefix ('test' or 'eval').
    label: short split name used in the printed header and the x-axis label.
    long_label: split name used in plot titles.
    """
    print(f"    {label} Results:")
    # Put dropout into inference mode; harmless if test_and_eval already does this itself.
    model.eval()
    loss, num_correct, correct_pos, correct_neg, precision, recall, f1 = test_and_eval(
        model=model,
        loss_fn=loss_fn,
        dataloader=dataloader,
        margin=config['margin']
    )
    results[f'{prefix}_total_loss'].append(loss)
    results[f'{prefix}_accuracies'].append(num_correct)
    results[f'{prefix}_accuracies_correct_pos'].append(correct_pos)
    results[f'{prefix}_accuracies_correct_neg'].append(correct_neg)
    results[f'{prefix}_precision'].append(precision)
    results[f'{prefix}_recall'].append(recall)
    results[f'{prefix}_F1'].append(f1)

    plot_losses(
        losses=[results[f'{prefix}_total_loss']],
        epoch_len=evals_per_epoch,
        plot_title=f'{long_label} Loss',
        plot_file=result_folder_path + f'/{prefix}_loss.png',
        line_labels=["Triplet Loss"]
    )

    plot_losses(
        [
            results[f'{prefix}_accuracies'],
            results[f'{prefix}_accuracies_correct_pos'],
            results[f'{prefix}_accuracies_correct_neg'],
        ],
        epoch_len=evals_per_epoch,
        plot_title=f'{long_label} Accuracy',
        x_label=f'{label} Iterations',
        y_label='Accuracy',
        line_labels=['Total Accuracy', 'Correct Pos', 'Correct Neg'],
        plot_file=result_folder_path + f'/{prefix}_accuracy.png'
    )


def _evaluate_and_checkpoint(epoch):
    """Score test and eval splits, persist config + results, and save the model on a new best eval accuracy."""
    _score_split('test', 'Test', 'Test', test_dataloader)
    _score_split('eval', 'Eval', 'Evaluation', eval_dataloader)

    # Save config and training results
    config['results'] = results
    save_dict_to_json(config, result_folder_path + '/training_data.json')

    # Save the model whenever eval *accuracy* reaches a new best. The first evaluation always
    # counts, so gat_encoder.pt exists even if accuracy never improves afterwards.
    # NOTE(review): selecting the checkpoint on the eval split (unseen authors) means that
    # split is no longer an untouched hold-out; consider selecting on the test split instead.
    eval_accuracies = results['eval_accuracies']
    if len(eval_accuracies) == 1 or eval_accuracies[-1] > max(eval_accuracies[:-1]):
        print(f"Saving model at epoch {epoch}...")
        torch.save(model.state_dict(), result_folder_path + '/gat_encoder.pt')


for epoch in range(1, num_epochs + 1):
    print(f"=== Epoch {epoch}/{num_epochs} ======================")
    current_batch = 1
    for batch_anchor, batch_pos, batch_neg in tqdm(train_dataloader, desc=f"Epoch {epoch}/{num_epochs}"):
        # Skip batches the collate function returned as None (presumably triplets whose
        # subgraphs could not be built). Skipped batches do not advance `current_batch`.
        if batch_anchor is None or batch_pos is None or batch_neg is None:
            continue

        if current_batch in eval_batches:
            print(f"___ Current Batch {current_batch}/{len(train_dataloader)} _________________________")
            _evaluate_and_checkpoint(epoch)

        # Re-enable dropout after scoring; harmless if `train` already switches modes itself.
        model.train()
        loss = train(
            model=model,
            loss_fn=loss_fn,
            batch_anchor=batch_anchor,
            batch_pos=batch_pos,
            batch_neg=batch_neg,
            optimizer=optimizer
        )
        results['train_loss'].append(loss)
        current_batch += 1

    # Refresh the training-loss plot once per epoch rather than re-rendering it after every batch.
    if results['train_loss']:
        plot_loss(results['train_loss'], epoch_len=len(train_dataloader), plot_title='Training Loss', plot_avg=True, plot_file=result_folder_path + '/train_loss.png')

# Score (and possibly checkpoint) the model once more after the final optimisation step.
if num_epochs > 0:
    print("___ Final evaluation _________________________")
    _evaluate_and_checkpoint(num_epochs)



Epoch 1/10:   0%|          | 0/125 [00:00<?, ?it/s]

___ Current Batch 1/125 _________________________
    Test Results:
        Correct positive: 706 (100.00%), Correct negative: 0 (0.00%)
        Total correct: 706 (50.00%)
        Test/Eval Loss: 0.9895, Test/Eval Accuracy: 0.5000
        Precision: 0.5000, Recall: 1.0000, F1: 0.6667
    Eval Results:
        Correct positive: 325 (100.00%), Correct negative: 0 (0.00%)
        Total correct: 325 (50.00%)
        Test/Eval Loss: 1.0156, Test/Eval Accuracy: 0.5000
        Precision: 0.5000, Recall: 1.0000, F1: 0.6667
___ Current Batch 62/125 _________________________
    Test Results:
        Correct positive: 516 (73.09%), Correct negative: 481 (68.13%)
        Total correct: 997 (70.61%)
        Test/Eval Loss: 0.5849, Test/Eval Accuracy: 0.7061
        Precision: 0.6964, Recall: 0.7309, F1: 0.7132
    Eval Results:
        Correct positive: 193 (59.38%), Correct negative: 211 (64.92%)
        Total correct: 404 (62.15%)
        Test/Eval Loss: 0.6696, Test/Eval Accuracy: 0.6215
     

Epoch 2/10:   0%|          | 0/125 [00:00<?, ?it/s]

___ Current Batch 1/125 _________________________
    Test Results:
        Correct positive: 589 (83.43%), Correct negative: 442 (62.61%)
        Total correct: 1031 (73.02%)
        Test/Eval Loss: 0.5270, Test/Eval Accuracy: 0.7302
        Precision: 0.6905, Recall: 0.8343, F1: 0.7556
    Eval Results:
        Correct positive: 267 (82.15%), Correct negative: 119 (36.62%)
        Total correct: 386 (59.38%)
        Test/Eval Loss: 0.8170, Test/Eval Accuracy: 0.5938
        Precision: 0.5645, Recall: 0.8215, F1: 0.6692
___ Current Batch 62/125 _________________________
    Test Results:
        Correct positive: 613 (86.83%), Correct negative: 411 (58.22%)
        Total correct: 1024 (72.52%)
        Test/Eval Loss: 0.4987, Test/Eval Accuracy: 0.7252
        Precision: 0.6751, Recall: 0.8683, F1: 0.7596
    Eval Results:
        Correct positive: 237 (72.92%), Correct negative: 115 (35.38%)
        Total correct: 352 (54.15%)
        Test/Eval Loss: 0.8813, Test/Eval Accuracy: 0.5415

Epoch 3/10:   0%|          | 0/125 [00:00<?, ?it/s]

___ Current Batch 1/125 _________________________
    Test Results:
        Correct positive: 628 (88.95%), Correct negative: 454 (64.31%)
        Total correct: 1082 (76.63%)
        Test/Eval Loss: 0.4431, Test/Eval Accuracy: 0.7663
        Precision: 0.7136, Recall: 0.8895, F1: 0.7919
    Eval Results:
        Correct positive: 226 (69.54%), Correct negative: 139 (42.77%)
        Total correct: 365 (56.15%)
        Test/Eval Loss: 0.7953, Test/Eval Accuracy: 0.5615
        Precision: 0.5485, Recall: 0.6954, F1: 0.6133
___ Current Batch 62/125 _________________________
    Test Results:
        Correct positive: 622 (88.10%), Correct negative: 409 (57.93%)
        Total correct: 1031 (73.02%)
        Test/Eval Loss: 0.4881, Test/Eval Accuracy: 0.7302
        Precision: 0.6768, Recall: 0.8810, F1: 0.7655
    Eval Results:
        Correct positive: 233 (71.69%), Correct negative: 142 (43.69%)
        Total correct: 375 (57.69%)
        Test/Eval Loss: 0.8324, Test/Eval Accuracy: 0.5769

Epoch 4/10:   0%|          | 0/125 [00:00<?, ?it/s]

___ Current Batch 1/125 _________________________
    Test Results:
        Correct positive: 627 (88.81%), Correct negative: 402 (56.94%)
        Total correct: 1029 (72.88%)
        Test/Eval Loss: 0.5003, Test/Eval Accuracy: 0.7288
        Precision: 0.6735, Recall: 0.8881, F1: 0.7660
    Eval Results:
        Correct positive: 265 (81.54%), Correct negative: 128 (39.38%)
        Total correct: 393 (60.46%)
        Test/Eval Loss: 0.6966, Test/Eval Accuracy: 0.6046
        Precision: 0.5736, Recall: 0.8154, F1: 0.6734
___ Current Batch 62/125 _________________________
    Test Results:
        Correct positive: 639 (90.51%), Correct negative: 432 (61.19%)
        Total correct: 1071 (75.85%)
        Test/Eval Loss: 0.4391, Test/Eval Accuracy: 0.7585
        Precision: 0.6999, Recall: 0.9051, F1: 0.7894
    Eval Results:
        Correct positive: 209 (64.31%), Correct negative: 132 (40.62%)
        Total correct: 341 (52.46%)
        Test/Eval Loss: 0.9238, Test/Eval Accuracy: 0.5246

Epoch 5/10:   0%|          | 0/125 [00:00<?, ?it/s]

___ Current Batch 1/125 _________________________
    Test Results:
        Correct positive: 648 (91.78%), Correct negative: 411 (58.22%)
        Total correct: 1059 (75.00%)
        Test/Eval Loss: 0.4442, Test/Eval Accuracy: 0.7500
        Precision: 0.6872, Recall: 0.9178, F1: 0.7859
    Eval Results:
        Correct positive: 245 (75.38%), Correct negative: 108 (33.23%)
        Total correct: 353 (54.31%)
        Test/Eval Loss: 0.8873, Test/Eval Accuracy: 0.5431
        Precision: 0.5303, Recall: 0.7538, F1: 0.6226
___ Current Batch 62/125 _________________________
    Test Results:
        Correct positive: 643 (91.08%), Correct negative: 386 (54.67%)
        Total correct: 1029 (72.88%)
        Test/Eval Loss: 0.4768, Test/Eval Accuracy: 0.7288
        Precision: 0.6677, Recall: 0.9108, F1: 0.7705
    Eval Results:
        Correct positive: 245 (75.38%), Correct negative: 95 (29.23%)
        Total correct: 340 (52.31%)
        Test/Eval Loss: 0.8740, Test/Eval Accuracy: 0.5231


Epoch 6/10:   0%|          | 0/125 [00:00<?, ?it/s]

___ Current Batch 1/125 _________________________
    Test Results:
        Correct positive: 655 (92.78%), Correct negative: 389 (55.10%)
        Total correct: 1044 (73.94%)
        Test/Eval Loss: 0.4804, Test/Eval Accuracy: 0.7394
        Precision: 0.6739, Recall: 0.9278, F1: 0.7807
    Eval Results:
        Correct positive: 177 (54.46%), Correct negative: 83 (25.54%)
        Total correct: 260 (40.00%)
        Test/Eval Loss: 1.1309, Test/Eval Accuracy: 0.4000
        Precision: 0.4224, Recall: 0.5446, F1: 0.4758
___ Current Batch 62/125 _________________________
    Test Results:
        Correct positive: 641 (90.79%), Correct negative: 431 (61.05%)
        Total correct: 1072 (75.92%)
        Test/Eval Loss: 0.4357, Test/Eval Accuracy: 0.7592
        Precision: 0.6998, Recall: 0.9079, F1: 0.7904
    Eval Results:
        Correct positive: 124 (38.15%), Correct negative: 153 (47.08%)
        Total correct: 277 (42.62%)
        Test/Eval Loss: 1.2343, Test/Eval Accuracy: 0.4262


Epoch 7/10:   0%|          | 0/125 [00:00<?, ?it/s]

___ Current Batch 1/125 _________________________
    Test Results:
        Correct positive: 601 (85.13%), Correct negative: 447 (63.31%)
        Total correct: 1048 (74.22%)
        Test/Eval Loss: 0.4475, Test/Eval Accuracy: 0.7422
        Precision: 0.6988, Recall: 0.8513, F1: 0.7676
    Eval Results:
        Correct positive: 175 (53.85%), Correct negative: 131 (40.31%)
        Total correct: 306 (47.08%)
        Test/Eval Loss: 1.0200, Test/Eval Accuracy: 0.4708
        Precision: 0.4743, Recall: 0.5385, F1: 0.5043
___ Current Batch 62/125 _________________________
    Test Results:
        Correct positive: 643 (91.08%), Correct negative: 341 (48.30%)
        Total correct: 984 (69.69%)
        Test/Eval Loss: 0.5344, Test/Eval Accuracy: 0.6969
        Precision: 0.6379, Recall: 0.9108, F1: 0.7503
    Eval Results:
        Correct positive: 116 (35.69%), Correct negative: 82 (25.23%)
        Total correct: 198 (30.46%)
        Test/Eval Loss: 1.4329, Test/Eval Accuracy: 0.3046
 

Epoch 8/10:   0%|          | 0/125 [00:00<?, ?it/s]

___ Current Batch 1/125 _________________________
    Test Results:
        Correct positive: 673 (95.33%), Correct negative: 261 (36.97%)
        Total correct: 934 (66.15%)
        Test/Eval Loss: 0.5753, Test/Eval Accuracy: 0.6615
        Precision: 0.6020, Recall: 0.9533, F1: 0.7379
    Eval Results:
        Correct positive: 264 (81.23%), Correct negative: 54 (16.62%)
        Total correct: 318 (48.92%)
        Test/Eval Loss: 1.0323, Test/Eval Accuracy: 0.4892
        Precision: 0.4935, Recall: 0.8123, F1: 0.6140
___ Current Batch 62/125 _________________________
    Test Results:
        Correct positive: 660 (93.48%), Correct negative: 321 (45.47%)
        Total correct: 981 (69.48%)
        Test/Eval Loss: 0.5196, Test/Eval Accuracy: 0.6948
        Precision: 0.6316, Recall: 0.9348, F1: 0.7539
    Eval Results:
        Correct positive: 284 (87.38%), Correct negative: 66 (20.31%)
        Total correct: 350 (53.85%)
        Test/Eval Loss: 0.8382, Test/Eval Accuracy: 0.5385
   

Epoch 9/10:   0%|          | 0/125 [00:00<?, ?it/s]

___ Current Batch 1/125 _________________________
    Test Results:
        Correct positive: 661 (93.63%), Correct negative: 359 (50.85%)
        Total correct: 1020 (72.24%)
        Test/Eval Loss: 0.4966, Test/Eval Accuracy: 0.7224
        Precision: 0.6558, Recall: 0.9363, F1: 0.7713
    Eval Results:
        Correct positive: 208 (64.00%), Correct negative: 115 (35.38%)
        Total correct: 323 (49.69%)
        Test/Eval Loss: 1.0020, Test/Eval Accuracy: 0.4969
        Precision: 0.4976, Recall: 0.6400, F1: 0.5599
___ Current Batch 62/125 _________________________
    Test Results:
        Correct positive: 672 (95.18%), Correct negative: 362 (51.27%)
        Total correct: 1034 (73.23%)
        Test/Eval Loss: 0.4781, Test/Eval Accuracy: 0.7323
        Precision: 0.6614, Recall: 0.9518, F1: 0.7805
    Eval Results:
        Correct positive: 220 (67.69%), Correct negative: 103 (31.69%)
        Total correct: 323 (49.69%)
        Test/Eval Loss: 1.0426, Test/Eval Accuracy: 0.4969

Epoch 10/10:   0%|          | 0/125 [00:00<?, ?it/s]

___ Current Batch 1/125 _________________________
    Test Results:
        Correct positive: 609 (86.26%), Correct negative: 466 (66.01%)
        Total correct: 1075 (76.13%)
        Test/Eval Loss: 0.4461, Test/Eval Accuracy: 0.7613
        Precision: 0.7173, Recall: 0.8626, F1: 0.7833
    Eval Results:
        Correct positive: 176 (54.15%), Correct negative: 124 (38.15%)
        Total correct: 300 (46.15%)
        Test/Eval Loss: 0.9620, Test/Eval Accuracy: 0.4615
        Precision: 0.4668, Recall: 0.5415, F1: 0.5014
___ Current Batch 62/125 _________________________
    Test Results:
        Correct positive: 639 (90.51%), Correct negative: 439 (62.18%)
        Total correct: 1078 (76.35%)
        Test/Eval Loss: 0.4259, Test/Eval Accuracy: 0.7635
        Precision: 0.7053, Recall: 0.9051, F1: 0.7928
    Eval Results:
        Correct positive: 145 (44.62%), Correct negative: 100 (30.77%)
        Total correct: 245 (37.69%)
        Test/Eval Loss: 1.2443, Test/Eval Accuracy: 0.3769