In [9]:
# Project star imports stay first so the explicit imports below keep overriding any same-named
# objects they export (e.g. torch_geometric's DataLoader).
from util_homogeneous import *
from util import *
from training_homogeneous import *
from gat_models import *

import os
import random
from collections import defaultdict

import numpy as np
import torch
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm
from torch_geometric.loader import DataLoader
from torch.optim import Adam, SGD
from torch.nn.modules.loss import TripletMarginLoss

from src.shared.database_wrapper import DatabaseWrapper
from src.shared.graph_schema import *
from src.shared.graph_sampling import GraphSampling

# Seed every RNG the notebook relies on (Python, NumPy, PyTorch CPU and all CUDA devices)
# from a single constant so runs are reproducible and the seed is changed in one place.
SEED = 40
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # silently ignored when CUDA is not available

## Configurations

In [10]:
# Graph sampling configurations: a homogeneous graph of PUBLICATION nodes connected by
# SIM_TITLE edges, exposing only the node properties listed below.
node_spec = NodeType.PUBLICATION
edge_spec = EdgeType.SIM_TITLE
node_properties = ['id', 'feature_vec']
database = 'small-graph'

gs = GraphSampling(
    node_spec=[node_spec],
    edge_spec=[edge_spec],
    node_properties=node_properties,
    database=database,
)

# Model configurations: encoder and training hyper-parameters. This dict is also serialized
# into training_data.json during training, so its key order is kept stable.
config = dict(
    experiment='GATv2 encoder (with linear layer + dropout) trained on small graph (publication nodes with title and abstract, title edges) using Triplet Loss and full embeddings',
    max_hops=2,
    model_node_feature='feature_vec',  # Node feature to use for GAT encoder
    hidden_channels=64,
    out_channels=16,
    num_heads=8,
    dropout_p=0.4,
    margin=1.0,
    optimizer='Adam',
    learning_rate=0.005,
    weight_decay=5e-4,
    num_epochs=10,
    batch_size=32,
)

model_class = HomoGATv2Encoder
# The loss uses the same margin that is later passed to the test/eval metric computation.
loss_fn = TripletMarginLoss(margin=config['margin'])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# TODO: Adjust result folder name!
result_folder_name = 'homogeneous (title) full_emb linear_layer dropout'
result_folder_path = f'./data/results/{result_folder_name}'
# makedirs (unlike mkdir) also creates missing parents such as ./data/results and is a no-op
# when the folder already exists, so this cell works on a fresh checkout and on re-runs.
os.makedirs(result_folder_path, exist_ok=True)

Using default edge type: SimilarTitle for homogeneous graph sampling.


In [11]:
# Connect to the graph database and build the triplet harvester; valid triplets are
# persisted under the given save-file name and embedded with the named sentence transformer.
db = DatabaseWrapper(database=database)
data_harvester = TripletDataHarvester(
    db=db,
    gs=gs,
    edge_spec=[edge_spec],
    config=config,
    valid_triplets_save_file='valid_triplets_homogeneous_title_small_graph',
    transformer_model='sentence-transformers/all-MiniLM-L6-v2',
)


# Split the triplets into an author-disjoint evaluation set and shuffled train/test sets.

# Harvest the evaluation triplets first, since triplets are ordered by author. This will ensure
# that the evaluation set has authors not seen in the training set.
num_eval_authors = 3   # the first N authors of the WhoIsWho training data form the eval set
train_fraction = 0.85  # share of the remaining triplets used for training; the rest is the test set

train_data = WhoIsWhoDataset.parse_train()
eval_papers = set()
for author_data in list(train_data.values())[:num_eval_authors]:
    eval_papers.update(author_data['normal_data'])


def _touches_eval_papers(triplet):
    """Return True if the anchor, positive or negative paper of `triplet` is an eval-author paper."""
    return triplet[0] in eval_papers or triplet[1] in eval_papers or triplet[2] in eval_papers


# Partition in a single pass: every triplet lands in exactly one list, and we avoid the
# O(n * m) list scan that `triplet not in eval_triplets` performed for every triplet.
all_triplets = data_harvester.triplets
eval_triplets = []
train_test_triplets = []
for triplet in all_triplets:
    if _touches_eval_papers(triplet):
        eval_triplets.append(triplet)
    else:
        train_test_triplets.append(triplet)

random.shuffle(train_test_triplets)

train_size = int(train_fraction * len(train_test_triplets))
train_triplets = train_test_triplets[:train_size]
test_triplets = train_test_triplets[train_size:]
test_size = len(test_triplets)
assert len(train_triplets) + len(test_triplets) + len(eval_triplets) == len(all_triplets)

config['train_size'] = len(train_triplets)
config['test_size'] = len(test_triplets)
config['eval_size'] = len(eval_triplets)

print(f"Train size: {len(train_triplets)}, Test size: {len(test_triplets)}, Eval size: {len(eval_triplets)}")

# Wrap each split in its own triplet dataset and DataLoader; only the training split is shuffled.
train_dataset = HomogeneousGraphTripletDataset(train_triplets, gs, config=config)
test_dataset = HomogeneousGraphTripletDataset(test_triplets, gs, config=config)
eval_dataset = HomogeneousGraphTripletDataset(eval_triplets, gs, config=config)

loader_options = dict(batch_size=config['batch_size'], collate_fn=custom_triplet_collate)
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_options)
test_dataloader = DataLoader(test_dataset, shuffle=False, **loader_options)
eval_dataloader = DataLoader(eval_dataset, shuffle=False, **loader_options)

# Create model
metadata = (
    node_spec.value,
    edge_spec.value
)
config['node_spec'] = metadata[0]
config['edge_spec'] = metadata[1]
model = model_class(config['hidden_channels'], config['out_channels'], num_heads=config['num_heads'], dropout_p=config['dropout_p']).to(device)

# Build the optimizer named in config['optimizer'] so the config saved with the results always
# matches what was actually trained (Adam used to be hard-coded regardless of the config value).
# Alternative previously tried here: SGD(lr=0.1, momentum=0.9) -- momentum is not configurable below.
optimizer_classes = {'Adam': Adam, 'SGD': SGD}
if config['optimizer'] not in optimizer_classes:
    raise ValueError(f"Unsupported optimizer {config['optimizer']!r}; expected one of {sorted(optimizer_classes)}")
optimizer = optimizer_classes[config['optimizer']](
    model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay']
)

2024-12-06 09:04:34,756 - DatabaseWrapper - INFO - Connecting to the database ...
2024-12-06 09:04:34,756 - DatabaseWrapper - INFO - Database ready.


Preparing triplets...
Loading triplets...
Loaded 8202 triplets.
Train size: 6442, Test size: 1137, Eval size: 623


In [12]:
num_epochs = config['num_epochs']
results = defaultdict(list)
num_batches = len(train_dataloader)
# Test/eval checkpoints run at the first and the middle batch of every epoch, i.e. twice per
# epoch -- which is why the test/eval plots use epoch_len=2.
checkpoint_batches = {1, num_batches // 2}


def _evaluate_split(prefix, short_name, long_name, dataloader):
    """Evaluate `model` on one split, append its metrics to `results` and refresh its plots.

    Args:
        prefix: prefix for result keys and plot files, e.g. 'test' -> results['test_total_loss'],
            test_loss.png, test_accuracy.png.
        short_name: name used in the console header and x-axis label ('Test' / 'Eval').
        long_name: name used in the plot titles ('Test' / 'Evaluation').
        dataloader: DataLoader over the split's triplets.
    """
    print(f"    {short_name} Results:")
    loss, num_correct, correct_pos, correct_neg, precision, recall, f1 = test_and_eval(
        model=model,
        loss_fn=loss_fn,
        dataloader=dataloader,
        margin=config['margin']
    )
    results[f'{prefix}_total_loss'].append(loss)
    results[f'{prefix}_accuracies'].append(num_correct)
    results[f'{prefix}_accuracies_correct_pos'].append(correct_pos)
    results[f'{prefix}_accuracies_correct_neg'].append(correct_neg)
    results[f'{prefix}_precision'].append(precision)
    results[f'{prefix}_recall'].append(recall)
    results[f'{prefix}_F1'].append(f1)

    plot_losses(
        losses=[results[f'{prefix}_total_loss']],
        epoch_len=2,
        plot_title=f'{long_name} Loss',
        plot_file=result_folder_path + f'/{prefix}_loss.png',
        line_labels=["Triplet Loss"]
    )
    plot_losses(
        [results[f'{prefix}_accuracies'], results[f'{prefix}_accuracies_correct_pos'], results[f'{prefix}_accuracies_correct_neg']],
        epoch_len=2,
        plot_title=f'{long_name} Accuracy',
        x_label=f'{short_name} Iterations',
        y_label='Accuracy',
        line_labels=['Total Accuracy', 'Correct Pos', 'Correct Neg'],
        plot_file=result_folder_path + f'/{prefix}_accuracy.png'
    )


def _plot_train_loss():
    """Refresh the training-loss plot; skipped while no training loss has been recorded yet."""
    if results['train_loss']:
        plot_loss(results['train_loss'], epoch_len=num_batches, plot_title='Training Loss', plot_avg=True, plot_file=result_folder_path + '/train_loss.png')


def _evaluate_and_checkpoint(header, epoch):
    """Evaluate on test and eval splits, persist config + metrics, and save the model when
    the latest eval accuracy beats every previous evaluation."""
    print(f"___ {header} _________________________")
    _evaluate_split('test', 'Test', 'Test', test_dataloader)
    _evaluate_split('eval', 'Eval', 'Evaluation', eval_dataloader)
    _plot_train_loss()

    # Save config and training results
    config['results'] = results
    save_dict_to_json(config, result_folder_path + '/training_data.json')

    # Save model if eval accuracy improved over all previous evaluations. The very first
    # evaluation only establishes the baseline and is never saved on its own.
    eval_accuracies = results['eval_accuracies']
    if len(eval_accuracies) > 1 and eval_accuracies[-1] > max(eval_accuracies[:-1]):
        print(f"Saving model at epoch {epoch}...")
        torch.save(model.state_dict(), result_folder_path + '/gat_encoder.pt')


for epoch in range(1, num_epochs + 1):
    print(f"=== Epoch {epoch}/{num_epochs} ======================")
    batches = tqdm(train_dataloader, desc=f"Epoch {epoch}/{num_epochs}")
    # enumerate() keeps current_batch equal to the batch's position even when a batch is
    # skipped below, so checkpoints always happen at the intended positions.
    for current_batch, (batch_anchor, batch_pos, batch_neg) in enumerate(batches, start=1):
        if current_batch in checkpoint_batches:
            _evaluate_and_checkpoint(f"Current Batch {current_batch}/{num_batches}", epoch)

        # Batches with a missing component (presumably produced by custom_triplet_collate) are not trained on.
        if batch_anchor is None or batch_pos is None or batch_neg is None:
            continue

        loss = train(
            model=model,
            loss_fn=loss_fn,
            batch_anchor=batch_anchor,
            batch_pos=batch_pos,
            batch_neg=batch_neg,
            optimizer=optimizer
        )
        results['train_loss'].append(loss)

    # Re-plot once per epoch instead of after every batch (avoids rewriting the PNG each step).
    _plot_train_loss()

# In-loop checkpoints only run at the start/middle of an epoch, so without this the second
# half of the last epoch would never be evaluated or considered for saving.
_evaluate_and_checkpoint(f"After final epoch {num_epochs}", num_epochs)



Epoch 1/10:   0%|          | 0/202 [00:00<?, ?it/s]

___ Current Batch 1/202 _________________________
    Test Results:
        Correct positive: 1137 (100.00%), Correct negative: 1 (0.09%)
        Total correct: 1138 (50.04%)
        Test/Eval Loss: 0.9917, Test/Eval Accuracy: 0.5004
        Precision: 0.5002, Recall: 1.0000, F1: 0.6669
    Eval Results:
        Correct positive: 623 (100.00%), Correct negative: 0 (0.00%)
        Total correct: 623 (50.00%)
        Test/Eval Loss: 0.9460, Test/Eval Accuracy: 0.5000
        Precision: 0.5000, Recall: 1.0000, F1: 0.6667
___ Current Batch 101/202 _________________________
    Test Results:
        Correct positive: 984 (86.54%), Correct negative: 930 (81.79%)
        Total correct: 1914 (84.17%)
        Test/Eval Loss: 0.2916, Test/Eval Accuracy: 0.8417
        Precision: 0.8262, Recall: 0.8654, F1: 0.8454
    Eval Results:
        Correct positive: 320 (51.36%), Correct negative: 337 (54.09%)
        Total correct: 657 (52.73%)
        Test/Eval Loss: 0.9967, Test/Eval Accuracy: 0.5273
 

Epoch 2/10:   0%|          | 0/202 [00:00<?, ?it/s]

___ Current Batch 1/202 _________________________
    Test Results:
        Correct positive: 1019 (89.62%), Correct negative: 931 (81.88%)
        Total correct: 1950 (85.75%)
        Test/Eval Loss: 0.2390, Test/Eval Accuracy: 0.8575
        Precision: 0.8318, Recall: 0.8962, F1: 0.8628
    Eval Results:
        Correct positive: 390 (62.60%), Correct negative: 275 (44.14%)
        Total correct: 665 (53.37%)
        Test/Eval Loss: 1.1495, Test/Eval Accuracy: 0.5337
        Precision: 0.5285, Recall: 0.6260, F1: 0.5731
___ Current Batch 101/202 _________________________
    Test Results:
        Correct positive: 1018 (89.53%), Correct negative: 891 (78.36%)
        Total correct: 1909 (83.95%)
        Test/Eval Loss: 0.2756, Test/Eval Accuracy: 0.8395
        Precision: 0.8054, Recall: 0.8953, F1: 0.8480
    Eval Results:
        Correct positive: 348 (55.86%), Correct negative: 275 (44.14%)
        Total correct: 623 (50.00%)
        Test/Eval Loss: 1.1323, Test/Eval Accuracy: 0.5

Epoch 3/10:   0%|          | 0/202 [00:00<?, ?it/s]

___ Current Batch 1/202 _________________________
    Test Results:
        Correct positive: 1044 (91.82%), Correct negative: 918 (80.74%)
        Total correct: 1962 (86.28%)
        Test/Eval Loss: 0.2250, Test/Eval Accuracy: 0.8628
        Precision: 0.8266, Recall: 0.9182, F1: 0.8700
    Eval Results:
        Correct positive: 364 (58.43%), Correct negative: 272 (43.66%)
        Total correct: 636 (51.04%)
        Test/Eval Loss: 1.1332, Test/Eval Accuracy: 0.5104
        Precision: 0.5091, Recall: 0.5843, F1: 0.5441
___ Current Batch 101/202 _________________________
    Test Results:
        Correct positive: 966 (84.96%), Correct negative: 988 (86.90%)
        Total correct: 1954 (85.93%)
        Test/Eval Loss: 0.2563, Test/Eval Accuracy: 0.8593
        Precision: 0.8664, Recall: 0.8496, F1: 0.8579
    Eval Results:
        Correct positive: 398 (63.88%), Correct negative: 92 (14.77%)
        Total correct: 490 (39.33%)
        Test/Eval Loss: 1.2407, Test/Eval Accuracy: 0.393

Epoch 4/10:   0%|          | 0/202 [00:00<?, ?it/s]

___ Current Batch 1/202 _________________________
    Test Results:
        Correct positive: 1000 (87.95%), Correct negative: 903 (79.42%)
        Total correct: 1903 (83.69%)
        Test/Eval Loss: 0.2908, Test/Eval Accuracy: 0.8369
        Precision: 0.8104, Recall: 0.8795, F1: 0.8435
    Eval Results:
        Correct positive: 375 (60.19%), Correct negative: 252 (40.45%)
        Total correct: 627 (50.32%)
        Test/Eval Loss: 1.1059, Test/Eval Accuracy: 0.5032
        Precision: 0.5027, Recall: 0.6019, F1: 0.5478
___ Current Batch 101/202 _________________________
    Test Results:
        Correct positive: 1016 (89.36%), Correct negative: 962 (84.61%)
        Total correct: 1978 (86.98%)
        Test/Eval Loss: 0.2328, Test/Eval Accuracy: 0.8698
        Precision: 0.8531, Recall: 0.8936, F1: 0.8729
    Eval Results:
        Correct positive: 416 (66.77%), Correct negative: 264 (42.38%)
        Total correct: 680 (54.57%)
        Test/Eval Loss: 1.1149, Test/Eval Accuracy: 0.5

Epoch 5/10:   0%|          | 0/202 [00:00<?, ?it/s]

___ Current Batch 1/202 _________________________
    Test Results:
        Correct positive: 1013 (89.09%), Correct negative: 863 (75.90%)
        Total correct: 1876 (82.50%)
        Test/Eval Loss: 0.3184, Test/Eval Accuracy: 0.8250
        Precision: 0.7871, Recall: 0.8909, F1: 0.8358
    Eval Results:
        Correct positive: 312 (50.08%), Correct negative: 279 (44.78%)
        Total correct: 591 (47.43%)
        Test/Eval Loss: 1.0975, Test/Eval Accuracy: 0.4743
        Precision: 0.4756, Recall: 0.5008, F1: 0.4879
___ Current Batch 101/202 _________________________
    Test Results:
        Correct positive: 1012 (89.01%), Correct negative: 960 (84.43%)
        Total correct: 1972 (86.72%)
        Test/Eval Loss: 0.2384, Test/Eval Accuracy: 0.8672
        Precision: 0.8511, Recall: 0.8901, F1: 0.8702
    Eval Results:
        Correct positive: 350 (56.18%), Correct negative: 91 (14.61%)
        Total correct: 441 (35.39%)
        Test/Eval Loss: 1.2988, Test/Eval Accuracy: 0.35

Epoch 6/10:   0%|          | 0/202 [00:00<?, ?it/s]

___ Current Batch 1/202 _________________________
    Test Results:
        Correct positive: 1062 (93.40%), Correct negative: 899 (79.07%)
        Total correct: 1961 (86.24%)
        Test/Eval Loss: 0.2685, Test/Eval Accuracy: 0.8624
        Precision: 0.8169, Recall: 0.9340, F1: 0.8716
    Eval Results:
        Correct positive: 346 (55.54%), Correct negative: 265 (42.54%)
        Total correct: 611 (49.04%)
        Test/Eval Loss: 1.1801, Test/Eval Accuracy: 0.4904
        Precision: 0.4915, Recall: 0.5554, F1: 0.5215
___ Current Batch 101/202 _________________________
    Test Results:
        Correct positive: 1025 (90.15%), Correct negative: 909 (79.95%)
        Total correct: 1934 (85.05%)
        Test/Eval Loss: 0.2840, Test/Eval Accuracy: 0.8505
        Precision: 0.8180, Recall: 0.9015, F1: 0.8577
    Eval Results:
        Correct positive: 341 (54.74%), Correct negative: 174 (27.93%)
        Total correct: 515 (41.33%)
        Test/Eval Loss: 1.2865, Test/Eval Accuracy: 0.4

Epoch 7/10:   0%|          | 0/202 [00:00<?, ?it/s]

___ Current Batch 1/202 _________________________
    Test Results:
        Correct positive: 1033 (90.85%), Correct negative: 910 (80.04%)
        Total correct: 1943 (85.44%)
        Test/Eval Loss: 0.2606, Test/Eval Accuracy: 0.8544
        Precision: 0.8198, Recall: 0.9085, F1: 0.8619
    Eval Results:
        Correct positive: 343 (55.06%), Correct negative: 282 (45.26%)
        Total correct: 625 (50.16%)
        Test/Eval Loss: 1.0944, Test/Eval Accuracy: 0.5016
        Precision: 0.5015, Recall: 0.5506, F1: 0.5249
___ Current Batch 101/202 _________________________
    Test Results:
        Correct positive: 1009 (88.74%), Correct negative: 962 (84.61%)
        Total correct: 1971 (86.68%)
        Test/Eval Loss: 0.2444, Test/Eval Accuracy: 0.8668
        Precision: 0.8522, Recall: 0.8874, F1: 0.8695
    Eval Results:
        Correct positive: 378 (60.67%), Correct negative: 266 (42.70%)
        Total correct: 644 (51.69%)
        Test/Eval Loss: 1.1064, Test/Eval Accuracy: 0.5

Epoch 8/10:   0%|          | 0/202 [00:00<?, ?it/s]

___ Current Batch 1/202 _________________________
    Test Results:
        Correct positive: 1049 (92.26%), Correct negative: 953 (83.82%)
        Total correct: 2002 (88.04%)
        Test/Eval Loss: 0.2172, Test/Eval Accuracy: 0.8804
        Precision: 0.8508, Recall: 0.9226, F1: 0.8852
    Eval Results:
        Correct positive: 392 (62.92%), Correct negative: 261 (41.89%)
        Total correct: 653 (52.41%)
        Test/Eval Loss: 1.1136, Test/Eval Accuracy: 0.5241
        Precision: 0.5199, Recall: 0.6292, F1: 0.5694
___ Current Batch 101/202 _________________________
    Test Results:
        Correct positive: 1059 (93.14%), Correct negative: 899 (79.07%)
        Total correct: 1958 (86.10%)
        Test/Eval Loss: 0.2363, Test/Eval Accuracy: 0.8610
        Precision: 0.8165, Recall: 0.9314, F1: 0.8702
    Eval Results:
        Correct positive: 380 (61.00%), Correct negative: 272 (43.66%)
        Total correct: 652 (52.33%)
        Test/Eval Loss: 1.1743, Test/Eval Accuracy: 0.5

Epoch 9/10:   0%|          | 0/202 [00:00<?, ?it/s]

___ Current Batch 1/202 _________________________
    Test Results:
        Correct positive: 1019 (89.62%), Correct negative: 972 (85.49%)
        Total correct: 1991 (87.55%)
        Test/Eval Loss: 0.2293, Test/Eval Accuracy: 0.8755
        Precision: 0.8606, Recall: 0.8962, F1: 0.8781
    Eval Results:
        Correct positive: 378 (60.67%), Correct negative: 86 (13.80%)
        Total correct: 464 (37.24%)
        Test/Eval Loss: 1.3462, Test/Eval Accuracy: 0.3724
        Precision: 0.4131, Recall: 0.6067, F1: 0.4915
___ Current Batch 101/202 _________________________
    Test Results:
        Correct positive: 1049 (92.26%), Correct negative: 917 (80.65%)
        Total correct: 1966 (86.46%)
        Test/Eval Loss: 0.2454, Test/Eval Accuracy: 0.8646
        Precision: 0.8266, Recall: 0.9226, F1: 0.8720
    Eval Results:
        Correct positive: 381 (61.16%), Correct negative: 269 (43.18%)
        Total correct: 650 (52.17%)
        Test/Eval Loss: 1.1781, Test/Eval Accuracy: 0.52

Epoch 10/10:   0%|          | 0/202 [00:00<?, ?it/s]

___ Current Batch 1/202 _________________________
    Test Results:
        Correct positive: 1032 (90.77%), Correct negative: 927 (81.53%)
        Total correct: 1959 (86.15%)
        Test/Eval Loss: 0.2376, Test/Eval Accuracy: 0.8615
        Precision: 0.8309, Recall: 0.9077, F1: 0.8676
    Eval Results:
        Correct positive: 361 (57.95%), Correct negative: 259 (41.57%)
        Total correct: 620 (49.76%)
        Test/Eval Loss: 1.1482, Test/Eval Accuracy: 0.4976
        Precision: 0.4979, Recall: 0.5795, F1: 0.5356
___ Current Batch 101/202 _________________________
    Test Results:
        Correct positive: 1062 (93.40%), Correct negative: 879 (77.31%)
        Total correct: 1941 (85.36%)
        Test/Eval Loss: 0.2741, Test/Eval Accuracy: 0.8536
        Precision: 0.8045, Recall: 0.9340, F1: 0.8645
    Eval Results:
        Correct positive: 386 (61.96%), Correct negative: 232 (37.24%)
        Total correct: 618 (49.60%)
        Test/Eval Loss: 1.2063, Test/Eval Accuracy: 0.4