In [17]:
from collections import defaultdict

from util_homogeneous import *
from util import *
from training_homogeneous_classification_loss import *
from gat_models import *

import random
import numpy as np
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm
from torch_geometric.loader import DataLoader
from torch.optim import Adam
from torch.nn.modules.loss import TripletMarginLoss

from src.shared.database_wrapper import DatabaseWrapper
from src.shared.graph_schema import *
from src.shared.graph_sampling import GraphSampling

# One seed shared by every random number generator this notebook relies on,
# so the values cannot drift apart when the seed is changed.
SEED = 40

random.seed(SEED)                 # Python's global RNG (drives random.shuffle on the triplets)
np.random.seed(SEED)              # NumPy's legacy global RNG
torch.manual_seed(SEED)           # PyTorch default generator
torch.cuda.manual_seed_all(SEED)  # every CUDA device; ignored when CUDA is unavailable

## Configurations

In [18]:
# Graph sampling configurations
#
# A homogeneous graph is sampled: one node type, one edge type. The same
# node/edge specs are reused further down for the triplet harvester and for
# the metadata stored in `config`.
node_spec, edge_spec = NodeType.PUBLICATION, EdgeType.SIM_ABSTRACT

# Node properties requested from the sampler.
node_properties = ['id', 'feature_vec']

# Name of the database; also handed to DatabaseWrapper later in the notebook.
database = 'small-graph'

gs = GraphSampling(
    node_spec=[node_spec],
    edge_spec=[edge_spec],
    node_properties=node_properties,
    database=database,
)

# Model configurations
#
# Single dictionary holding the hyperparameters of this run. It is passed to
# the triplet harvester and the datasets, extended with split sizes / specs /
# results later on, and finally written to `training_data.json`.
config = dict(
    # NOTE(review): this description mentions "org edges" and "Triplet Loss",
    # while the notebook samples EdgeType.SIM_ABSTRACT edges and trains with
    # train_dual_objective -- confirm the text still describes this run.
    experiment='GATv2 encoder (with linear layer + dropout) trained on small graph (publication nodes with title and abstract, org edges) using Triplet Loss and full embeddings',
    max_hops=2,  # not read directly in this notebook; presumably consumed by the dataset/sampler -- TODO confirm
    model_node_feature='feature_vec',  # Node feature to use for GAT encoder
    # Encoder architecture (forwarded to the model constructor).
    hidden_channels=64,
    out_channels=16,
    num_heads=8,
    dropout_p=0.4,
    # Margin used by TripletMarginLoss and forwarded to test_and_eval.
    margin=1.0,
    # Optimisation settings (learning rate and weight decay go to Adam).
    optimizer='Adam',
    learning_rate=0.005,
    weight_decay=5e-4,
    num_epochs=10,
    batch_size=32,
)

# Encoder class to instantiate, the triplet loss (margin taken from `config`)
# and the device the model is moved to.
model_class = HomoGATv2Encoder
loss_fn = TripletMarginLoss(margin=config['margin'])

# Prefer the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# TODO: Adjust result folder name!
# Folder that receives the plots, the saved config/results JSON and the model checkpoint.
result_folder_name = 'homogeneous (abstract) full_emb linear_layer dropout dual_loss'
result_folder_path = f'./data/results/{result_folder_name}'

# makedirs(exist_ok=True) also creates missing parent folders (./data/results)
# and is a no-op when the folder already exists, so re-running the cell is safe.
os.makedirs(result_folder_path, exist_ok=True)

Using default edge type: SimilarAbstract for homogeneous graph sampling.


In [19]:
# Database handle for the graph selected in the configuration cell.
db = DatabaseWrapper(database=database)

# Harvester exposing the triplets (`data_harvester.triplets`) that are split
# into train / test / eval sets below.
data_harvester = TripletDataHarvester(
    db=db,
    gs=gs,
    edge_spec=[edge_spec],
    config=config,
    valid_triplets_save_file='valid_triplets_homogeneous_abstract_small_graph',
    transformer_model='sentence-transformers/all-MiniLM-L6-v2',
)


# Split the triplets into train, test and evaluation sets

# Hold out the papers of the first `num_eval_authors` authors of the WhoIsWho
# training data. Every triplet that touches one of these papers goes to the
# evaluation set, so the evaluation authors are never seen during training.
num_eval_authors = 3
train_data = WhoIsWhoDataset.parse_train()

eval_papers = set()
for author_idx, author_record in enumerate(train_data.values()):
    # Stop once enough authors were collected (also correct for num_eval_authors == 0).
    if author_idx >= num_eval_authors:
        break
    eval_papers.update(author_record['normal_data'])

# Partition the harvested triplets in a single pass: a triplet belongs to the
# evaluation set as soon as its anchor, positive or negative is an eval paper;
# everything else is available for training/testing.
eval_triplets = []
train_test_triplets = []
for triplet in data_harvester.triplets:
    touches_eval_paper = (
        triplet[0] in eval_papers
        or triplet[1] in eval_papers
        or triplet[2] in eval_papers
    )
    if touches_eval_paper:
        eval_triplets.append(triplet)
    else:
        train_test_triplets.append(triplet)

# Shuffle (seeded in the first cell) before cutting the train/test split.
random.shuffle(train_test_triplets)

train_fraction = 0.85
train_size = int(train_fraction * len(train_test_triplets))
test_size = len(train_test_triplets) - train_size

train_triplets = train_test_triplets[:train_size]
test_triplets = train_test_triplets[train_size:]

# Record the split sizes alongside the hyperparameters.
config['train_size'] = len(train_triplets)
config['test_size'] = len(test_triplets)
config['eval_size'] = len(eval_triplets)

print(f"Train size: {len(train_triplets)}, Test size: {len(test_triplets)}, Eval size: {len(eval_triplets)}")

# Create the datasets from the triplets (separate triplet lists for training, testing and evaluation)
train_dataset = HomogeneousGraphTripletDataset(train_triplets, gs, config=config)
test_dataset = HomogeneousGraphTripletDataset(test_triplets, gs, config=config)
eval_dataset = HomogeneousGraphTripletDataset(eval_triplets, gs, config=config)

# Create the DataLoaders
#
# torch_geometric.loader.DataLoader drops any user-supplied `collate_fn` and
# always uses its own Collater, so `custom_triplet_collate` would silently
# never run. The plain PyTorch DataLoader honours `collate_fn`, which is what
# the (anchor, positive, negative) unpacking and the `None` guard in the
# training loop rely on.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=config['batch_size'],
    shuffle=True,  # reshuffle the training triplets every epoch
    collate_fn=custom_triplet_collate,
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=config['batch_size'],
    shuffle=False,
    collate_fn=custom_triplet_collate,
)
eval_dataloader = torch.utils.data.DataLoader(
    eval_dataset,
    batch_size=config['batch_size'],
    shuffle=False,
    collate_fn=custom_triplet_collate,
)

# Create model
#
# Store the node/edge type values next to the hyperparameters so the saved
# config documents which graph the model was trained on.
metadata = (node_spec.value, edge_spec.value)
config['node_spec'], config['edge_spec'] = metadata

# Encoder built from the configured sizes and moved to the selected device.
model = model_class(
    config['hidden_channels'],
    config['out_channels'],
    num_heads=config['num_heads'],
    dropout_p=config['dropout_p'],
).to(device)

# Adam over all encoder parameters, using the configured learning rate and weight decay.
optimizer = Adam(
    model.parameters(),
    lr=config['learning_rate'],
    weight_decay=config['weight_decay'],
)

2024-12-10 10:25:27,381 - DatabaseWrapper - INFO - Connecting to the database ...
2024-12-10 10:25:27,382 - DatabaseWrapper - INFO - Database ready.


Preparing triplets...
Loading triplets...
Loaded 5150 triplets.
Train size: 4130, Test size: 729, Eval size: 291


In [20]:
num_epochs = config['num_epochs']
results = defaultdict(list)

# The held-out splits are evaluated twice per epoch (before the first batch
# and before the mid-epoch batch); plot_losses receives this as its epoch length.
EVALS_PER_EPOCH = 2
batches_per_epoch = len(train_dataloader)
checkpoint_path = result_folder_path + '/gat_encoder.pt'


def evaluate_split(split_key, split_title, dataloader):
    """Run ``test_and_eval`` on one held-out split, log its metrics and refresh its plots.

    Args:
        split_key: Prefix of the ``results`` keys and plot file names ('test' or 'eval').
        split_title: Split name used in the plot titles ('Test' or 'Evaluation').
        dataloader: DataLoader of the split to evaluate.

    Returns:
        The first value returned by ``test_and_eval``, which is logged as the total loss.
    """
    total_loss, triplet_loss, cross_entropy_loss, num_correct, correct_pos_val, correct_neg_val = test_and_eval(
        model=model,
        loss_fn=loss_fn,
        dataloader=dataloader,
        margin=config['margin']
    )
    results[f'{split_key}_total_loss'].append(total_loss)
    results[f'{split_key}_triplet_loss'].append(triplet_loss)
    # Key name kept so existing result files stay comparable; the stored value
    # is the third return value of test_and_eval.
    results[f'{split_key}_neg_contrastive_loss'].append(cross_entropy_loss)
    results[f'{split_key}_accuracies'].append(num_correct)
    results[f'{split_key}_accuracies_correct_pos'].append(correct_pos_val)
    results[f'{split_key}_accuracies_correct_neg'].append(correct_neg_val)

    plot_losses(
        losses=[
            results[f'{split_key}_total_loss'],
            results[f'{split_key}_triplet_loss'],
            results[f'{split_key}_neg_contrastive_loss'],
        ],
        epoch_len=EVALS_PER_EPOCH,
        plot_title=f'{split_title} Loss',
        plot_file=f'{result_folder_path}/{split_key}_loss.png',
        line_labels=["Total Loss", "Triplet Loss", "Neg. Contrastive Loss"]
    )

    plot_losses(
        [
            results[f'{split_key}_accuracies'],
            results[f'{split_key}_accuracies_correct_pos'],
            results[f'{split_key}_accuracies_correct_neg'],
        ],
        epoch_len=EVALS_PER_EPOCH,
        plot_title=f'{split_title} Accuracy',
        x_label=f'{split_key.capitalize()} Iterations',
        y_label='Accuracy',
        line_labels=['Total Accuracy', 'Correct Pos', 'Correct Neg'],
        plot_file=f'{result_folder_path}/{split_key}_accuracy.png'
    )
    return total_loss


for epoch in range(1, num_epochs + 1):
    print(f"=== Epoch {epoch}/{num_epochs} ======================")
    # Counts the batches actually trained on in this epoch (skipped batches do not advance it).
    current_batch = 1
    for batch_anchor, batch_pos, batch_neg in tqdm(train_dataloader, desc=f"Epoch {epoch}/{num_epochs}"):
        # Skip batches for which the collate function produced nothing usable.
        if batch_anchor is None or batch_pos is None or batch_neg is None:
            continue

        if current_batch == 1 or current_batch == batches_per_epoch // 2:
            print(f"___ Current Batch {current_batch}/{batches_per_epoch} _________________________")
            # Model testing
            print("    Test Results:")
            test_loss = evaluate_split('test', 'Test', test_dataloader)

            # Model evaluation
            print("    Eval Results:")
            evaluate_split('eval', 'Evaluation', eval_dataloader)

            # Save the model if the test loss has decreased. This happens here,
            # before the next training step, so the checkpoint holds exactly
            # the weights that produced the new best loss. The very first
            # measurement (untrained model) is never saved.
            previous_test_losses = results['test_total_loss'][:-1]
            if previous_test_losses and test_loss < min(previous_test_losses):
                print(f"Saving model at epoch {epoch}...")
                torch.save(model.state_dict(), checkpoint_path)

            # Save config and training results
            config['results'] = results
            save_dict_to_json(config, result_folder_path + '/training_data.json')

        loss = train_dual_objective(
            model=model,
            loss_fn=loss_fn,
            batch_anchor=batch_anchor,
            batch_pos=batch_pos,
            batch_neg=batch_neg,
            optimizer=optimizer
        )
        results['train_loss'].append(loss)

        plot_loss(results['train_loss'], epoch_len=batches_per_epoch, plot_title='Training Loss', plot_avg=True, plot_file=result_folder_path + '/train_loss.png')
        current_batch += 1



Epoch 1/10:   0%|          | 0/130 [00:00<?, ?it/s]

___ Current Batch 1/130 _________________________
    Test Results:
        Correct positive: 729 (100.00%), Correct negative: 0 (0.00%)
        Total correct: 729 (50.00%)
        Test/Eval Loss: 1.2658, Test/Eval Accuracy: 0.5000
        Triplet Loss: 0.9929, Cross Entropy Loss: 0.2729
    Eval Results:
        Correct positive: 291 (100.00%), Correct negative: 0 (0.00%)
        Total correct: 291 (50.00%)
        Test/Eval Loss: 1.1611, Test/Eval Accuracy: 0.5000
        Triplet Loss: 0.9403, Cross Entropy Loss: 0.2208
___ Current Batch 65/130 _________________________
    Test Results:
        Correct positive: 580 (79.56%), Correct negative: 503 (69.00%)
        Total correct: 1083 (74.28%)
        Test/Eval Loss: 0.5558, Test/Eval Accuracy: 0.7428
        Triplet Loss: 0.4634, Cross Entropy Loss: 0.0923
    Eval Results:
        Correct positive: 153 (52.58%), Correct negative: 272 (93.47%)
        Total correct: 425 (73.02%)
        Test/Eval Loss: 0.4271, Test/Eval Accuracy: 0.

Epoch 2/10:   0%|          | 0/130 [00:00<?, ?it/s]

___ Current Batch 1/130 _________________________
    Test Results:
        Correct positive: 642 (88.07%), Correct negative: 443 (60.77%)
        Total correct: 1085 (74.42%)
        Test/Eval Loss: 0.5568, Test/Eval Accuracy: 0.7442
        Triplet Loss: 0.4379, Cross Entropy Loss: 0.1189
    Eval Results:
        Correct positive: 184 (63.23%), Correct negative: 272 (93.47%)
        Total correct: 456 (78.35%)
        Test/Eval Loss: 0.3978, Test/Eval Accuracy: 0.7835
        Triplet Loss: 0.3693, Cross Entropy Loss: 0.0285
___ Current Batch 65/130 _________________________
    Test Results:
        Correct positive: 614 (84.22%), Correct negative: 534 (73.25%)
        Total correct: 1148 (78.74%)
        Test/Eval Loss: 0.4966, Test/Eval Accuracy: 0.7874
        Triplet Loss: 0.3949, Cross Entropy Loss: 0.1017
    Eval Results:
        Correct positive: 157 (53.95%), Correct negative: 266 (91.41%)
        Total correct: 423 (72.68%)
        Test/Eval Loss: 0.4940, Test/Eval Accurac

Epoch 3/10:   0%|          | 0/130 [00:00<?, ?it/s]

___ Current Batch 1/130 _________________________
    Test Results:
        Correct positive: 641 (87.93%), Correct negative: 473 (64.88%)
        Total correct: 1114 (76.41%)
        Test/Eval Loss: 0.5379, Test/Eval Accuracy: 0.7641
        Triplet Loss: 0.4087, Cross Entropy Loss: 0.1292
    Eval Results:
        Correct positive: 162 (55.67%), Correct negative: 275 (94.50%)
        Total correct: 437 (75.09%)
        Test/Eval Loss: 0.4044, Test/Eval Accuracy: 0.7509
        Triplet Loss: 0.3915, Cross Entropy Loss: 0.0130
___ Current Batch 65/130 _________________________
    Test Results:
        Correct positive: 621 (85.19%), Correct negative: 554 (75.99%)
        Total correct: 1175 (80.59%)
        Test/Eval Loss: 0.4040, Test/Eval Accuracy: 0.8059
        Triplet Loss: 0.3522, Cross Entropy Loss: 0.0518
    Eval Results:
        Correct positive: 183 (62.89%), Correct negative: 280 (96.22%)
        Total correct: 463 (79.55%)
        Test/Eval Loss: 0.2837, Test/Eval Accurac

Epoch 4/10:   0%|          | 0/130 [00:00<?, ?it/s]

___ Current Batch 1/130 _________________________
    Test Results:
        Correct positive: 634 (86.97%), Correct negative: 574 (78.74%)
        Total correct: 1208 (82.85%)
        Test/Eval Loss: 0.3869, Test/Eval Accuracy: 0.8285
        Triplet Loss: 0.3177, Cross Entropy Loss: 0.0693
    Eval Results:
        Correct positive: 179 (61.51%), Correct negative: 284 (97.59%)
        Total correct: 463 (79.55%)
        Test/Eval Loss: 0.3107, Test/Eval Accuracy: 0.7955
        Triplet Loss: 0.2949, Cross Entropy Loss: 0.0158
___ Current Batch 65/130 _________________________
    Test Results:
        Correct positive: 656 (89.99%), Correct negative: 494 (67.76%)
        Total correct: 1150 (78.88%)
        Test/Eval Loss: 0.4660, Test/Eval Accuracy: 0.7888
        Triplet Loss: 0.3655, Cross Entropy Loss: 0.1006
    Eval Results:
        Correct positive: 156 (53.61%), Correct negative: 254 (87.29%)
        Total correct: 410 (70.45%)
        Test/Eval Loss: 0.5173, Test/Eval Accurac

Epoch 5/10:   0%|          | 0/130 [00:00<?, ?it/s]

___ Current Batch 1/130 _________________________
    Test Results:
        Correct positive: 661 (90.67%), Correct negative: 529 (72.57%)
        Total correct: 1190 (81.62%)
        Test/Eval Loss: 0.4161, Test/Eval Accuracy: 0.8162
        Triplet Loss: 0.3317, Cross Entropy Loss: 0.0844
    Eval Results:
        Correct positive: 217 (74.57%), Correct negative: 261 (89.69%)
        Total correct: 478 (82.13%)
        Test/Eval Loss: 0.3801, Test/Eval Accuracy: 0.8213
        Triplet Loss: 0.3633, Cross Entropy Loss: 0.0168
___ Current Batch 65/130 _________________________
    Test Results:
        Correct positive: 647 (88.75%), Correct negative: 510 (69.96%)
        Total correct: 1157 (79.36%)
        Test/Eval Loss: 0.4368, Test/Eval Accuracy: 0.7936
        Triplet Loss: 0.3493, Cross Entropy Loss: 0.0875
    Eval Results:
        Correct positive: 223 (76.63%), Correct negative: 275 (94.50%)
        Total correct: 498 (85.57%)
        Test/Eval Loss: 0.3159, Test/Eval Accurac

Epoch 6/10:   0%|          | 0/130 [00:00<?, ?it/s]

___ Current Batch 1/130 _________________________
    Test Results:
        Correct positive: 647 (88.75%), Correct negative: 506 (69.41%)
        Total correct: 1153 (79.08%)
        Test/Eval Loss: 0.4717, Test/Eval Accuracy: 0.7908
        Triplet Loss: 0.3754, Cross Entropy Loss: 0.0963
    Eval Results:
        Correct positive: 235 (80.76%), Correct negative: 205 (70.45%)
        Total correct: 440 (75.60%)
        Test/Eval Loss: 0.5243, Test/Eval Accuracy: 0.7560
        Triplet Loss: 0.4881, Cross Entropy Loss: 0.0362
___ Current Batch 65/130 _________________________
    Test Results:
        Correct positive: 631 (86.56%), Correct negative: 531 (72.84%)
        Total correct: 1162 (79.70%)
        Test/Eval Loss: 0.4577, Test/Eval Accuracy: 0.7970
        Triplet Loss: 0.3662, Cross Entropy Loss: 0.0915
    Eval Results:
        Correct positive: 199 (68.38%), Correct negative: 250 (85.91%)
        Total correct: 449 (77.15%)
        Test/Eval Loss: 0.5241, Test/Eval Accurac

Epoch 7/10:   0%|          | 0/130 [00:00<?, ?it/s]

___ Current Batch 1/130 _________________________
    Test Results:
        Correct positive: 647 (88.75%), Correct negative: 495 (67.90%)
        Total correct: 1142 (78.33%)
        Test/Eval Loss: 0.4648, Test/Eval Accuracy: 0.7833
        Triplet Loss: 0.3692, Cross Entropy Loss: 0.0956
    Eval Results:
        Correct positive: 223 (76.63%), Correct negative: 269 (92.44%)
        Total correct: 492 (84.54%)
        Test/Eval Loss: 0.3912, Test/Eval Accuracy: 0.8454
        Triplet Loss: 0.3714, Cross Entropy Loss: 0.0199
___ Current Batch 65/130 _________________________
    Test Results:
        Correct positive: 660 (90.53%), Correct negative: 512 (70.23%)
        Total correct: 1172 (80.38%)
        Test/Eval Loss: 0.4195, Test/Eval Accuracy: 0.8038
        Triplet Loss: 0.3384, Cross Entropy Loss: 0.0811
    Eval Results:
        Correct positive: 245 (84.19%), Correct negative: 131 (45.02%)
        Total correct: 376 (64.60%)
        Test/Eval Loss: 0.6345, Test/Eval Accurac

Epoch 8/10:   0%|          | 0/130 [00:00<?, ?it/s]

___ Current Batch 1/130 _________________________
    Test Results:
        Correct positive: 651 (89.30%), Correct negative: 524 (71.88%)
        Total correct: 1175 (80.59%)
        Test/Eval Loss: 0.4438, Test/Eval Accuracy: 0.8059
        Triplet Loss: 0.3668, Cross Entropy Loss: 0.0771
    Eval Results:
        Correct positive: 239 (82.13%), Correct negative: 184 (63.23%)
        Total correct: 423 (72.68%)
        Test/Eval Loss: 0.5401, Test/Eval Accuracy: 0.7268
        Triplet Loss: 0.4790, Cross Entropy Loss: 0.0612
___ Current Batch 65/130 _________________________
    Test Results:
        Correct positive: 657 (90.12%), Correct negative: 523 (71.74%)
        Total correct: 1180 (80.93%)
        Test/Eval Loss: 0.4732, Test/Eval Accuracy: 0.8093
        Triplet Loss: 0.3618, Cross Entropy Loss: 0.1114
    Eval Results:
        Correct positive: 226 (77.66%), Correct negative: 156 (53.61%)
        Total correct: 382 (65.64%)
        Test/Eval Loss: 0.6636, Test/Eval Accurac

Epoch 9/10:   0%|          | 0/130 [00:00<?, ?it/s]

___ Current Batch 1/130 _________________________
    Test Results:
        Correct positive: 661 (90.67%), Correct negative: 478 (65.57%)
        Total correct: 1139 (78.12%)
        Test/Eval Loss: 0.5455, Test/Eval Accuracy: 0.7812
        Triplet Loss: 0.4145, Cross Entropy Loss: 0.1310
    Eval Results:
        Correct positive: 179 (61.51%), Correct negative: 258 (88.66%)
        Total correct: 437 (75.09%)
        Test/Eval Loss: 0.4957, Test/Eval Accuracy: 0.7509
        Triplet Loss: 0.4556, Cross Entropy Loss: 0.0400
___ Current Batch 65/130 _________________________
    Test Results:
        Correct positive: 667 (91.50%), Correct negative: 456 (62.55%)
        Total correct: 1123 (77.02%)
        Test/Eval Loss: 0.5010, Test/Eval Accuracy: 0.7702
        Triplet Loss: 0.3907, Cross Entropy Loss: 0.1104
    Eval Results:
        Correct positive: 197 (67.70%), Correct negative: 271 (93.13%)
        Total correct: 468 (80.41%)
        Test/Eval Loss: 0.4377, Test/Eval Accurac

Epoch 10/10:   0%|          | 0/130 [00:00<?, ?it/s]

___ Current Batch 1/130 _________________________
    Test Results:
        Correct positive: 664 (91.08%), Correct negative: 491 (67.35%)
        Total correct: 1155 (79.22%)
        Test/Eval Loss: 0.4853, Test/Eval Accuracy: 0.7922
        Triplet Loss: 0.3800, Cross Entropy Loss: 0.1053
    Eval Results:
        Correct positive: 215 (73.88%), Correct negative: 254 (87.29%)
        Total correct: 469 (80.58%)
        Test/Eval Loss: 0.4116, Test/Eval Accuracy: 0.8058
        Triplet Loss: 0.3873, Cross Entropy Loss: 0.0243
___ Current Batch 65/130 _________________________
    Test Results:
        Correct positive: 661 (90.67%), Correct negative: 539 (73.94%)
        Total correct: 1200 (82.30%)
        Test/Eval Loss: 0.3969, Test/Eval Accuracy: 0.8230
        Triplet Loss: 0.3231, Cross Entropy Loss: 0.0738
    Eval Results:
        Correct positive: 230 (79.04%), Correct negative: 207 (71.13%)
        Total correct: 437 (75.09%)
        Test/Eval Loss: 0.4880, Test/Eval Accurac