In [1]:
# Project helpers are star-imported first, so the explicit imports below take
# precedence should these modules re-export any of the same names.
from util_homogeneous import *
from util import *
from training_homogeneous import *
from gat_models import *

import os
import random
from collections import defaultdict

import numpy as np
import torch
from tqdm.notebook import tqdm
from torch_geometric.loader import DataLoader
from torch.optim import Adam, SGD
from torch.nn.modules.loss import TripletMarginLoss

from src.shared.database_wrapper import DatabaseWrapper
from src.shared.graph_schema import *
from src.shared.graph_sampling import GraphSampling

# Seed every RNG the notebook relies on (Python `random`, NumPy, PyTorch CPU and all
# CUDA devices), so that stochastic steps such as the train/test shuffle are
# reproducible under Restart Kernel -> Run All.
SEED = 40
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(SEED)

  from tqdm.autonotebook import tqdm, trange


## Configurations

In [2]:
# Graph sampling configurations
node_spec = NodeType.PUBLICATION

edge_spec = EdgeType.SIM_ORG

node_properties = [
    'id',
    'feature_vec',
]

database = 'small-graph'
gs = GraphSampling(
    node_spec=[node_spec],
    edge_spec=[edge_spec],
    node_properties=node_properties,
    database=database
)

# Model configurations

# Hyperparameters and run metadata. This dict is also written to training_data.json
# during training, so everything needed to reproduce the run belongs here.
config = {
    'experiment': 'GATv2 encoder (with linear layer + dropout) trained on small graph (publication nodes with title and abstract, org edges) using Triplet Loss and full embeddings',
    'max_hops': 2,
    'model_node_feature': 'feature_vec',  # Node feature to use for GAT encoder
    'hidden_channels': 32,
    'out_channels': 8,
    'num_heads': 8,
    'dropout_p': 0.4,
    'margin': 1.0,
    'optimizer': 'Adam',
    'learning_rate': 0.005,
    'weight_decay': 5e-4,
    'num_epochs': 10,
    'batch_size': 32,
}

model_class = HomoGATv2Encoder
loss_fn = TripletMarginLoss(margin=config['margin'])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# TODO: Adjust the descriptive prefix of the result folder name per experiment!
# The architecture suffix is derived from `config` so it cannot drift out of sync
# with the actual hyperparameters.
result_folder_name = (
    'homogeneous (org) full_emb linear_layer dropout '
    f"{config['hidden_channels']}h {config['out_channels']}out"
)
result_folder_path = f'./data/results/{result_folder_name}'
# makedirs also creates missing parents (e.g. ./data/results on a fresh checkout), where
# os.mkdir would raise FileNotFoundError; exist_ok avoids the exists/mkdir race.
os.makedirs(result_folder_path, exist_ok=True)

Using default edge type: SimilarOrg for homogeneous graph sampling.


In [3]:
# Connect to the graph database and obtain the (anchor, positive, negative) triplets
# for this graph via the harvester.
db = DatabaseWrapper(database=database)
data_harvester = TripletDataHarvester(
    db=db,
    gs=gs,
    edge_spec=[edge_spec],
    config=config,
    valid_triplets_save_file='valid_triplets_homogeneous_org_small_graph',
    transformer_model='sentence-transformers/all-MiniLM-L6-v2',
)


# --- Train / test / eval split ---------------------------------------------------
# The eval split holds out whole authors. Every triplet that touches a paper of one of
# the first `num_eval_authors` authors of the WhoIsWho training data goes to eval, so
# eval measures generalisation to unseen authors. The remaining triplets are shuffled
# and split randomly into train and test (test therefore shares authors with train).
num_eval_authors = 3
train_fraction = 0.85

train_data = WhoIsWhoDataset.parse_train()
eval_author_records = list(train_data.values())[:num_eval_authors]
eval_papers = {paper for record in eval_author_records for paper in record['normal_data']}


def _touches_eval_papers(triplet):
    """Return True if the anchor, positive or negative paper belongs to an eval author."""
    return any(triplet[i] in eval_papers for i in (0, 1, 2))


# Single-pass partition. This replaces a `triplet not in eval_triplets` list scan that
# was O(len(triplets) * len(eval_triplets)). Order is preserved, so the seeded shuffle
# below produces the same split as before.
eval_triplets = []
train_test_triplets = []
for triplet in data_harvester.triplets:
    if _touches_eval_papers(triplet):
        eval_triplets.append(triplet)
    else:
        train_test_triplets.append(triplet)

random.shuffle(train_test_triplets)

train_size = int(train_fraction * len(train_test_triplets))
test_size = len(train_test_triplets) - train_size
eval_size = len(eval_triplets)

train_triplets = train_test_triplets[:train_size]
test_triplets = train_test_triplets[train_size:]
assert len(train_triplets) + len(test_triplets) + len(eval_triplets) == len(data_harvester.triplets)

config['num_eval_authors'] = num_eval_authors
config['train_size'] = len(train_triplets)
config['test_size'] = len(test_triplets)
config['eval_size'] = len(eval_triplets)

print(f"Train size: {len(train_triplets)}, Test size: {len(test_triplets)}, Eval size: {len(eval_triplets)}")

# Wrap each split in its own triplet dataset (built in train, test, eval order).
train_dataset, test_dataset, eval_dataset = (
    HomogeneousGraphTripletDataset(split_triplets, gs, config=config)
    for split_triplets in (train_triplets, test_triplets, eval_triplets)
)


def _triplet_loader(dataset, shuffle):
    """Batch `dataset` with the shared batch size and `custom_triplet_collate`."""
    return DataLoader(dataset, batch_size=config['batch_size'], shuffle=shuffle, collate_fn=custom_triplet_collate)


# Only the training batches are shuffled; test/eval keep a fixed order.
train_dataloader = _triplet_loader(train_dataset, shuffle=True)
test_dataloader = _triplet_loader(test_dataset, shuffle=False)
eval_dataloader = _triplet_loader(eval_dataset, shuffle=False)

# Record the node/edge types in `config` so they are saved alongside the results.
metadata = (node_spec.value, edge_spec.value)
config['node_spec'], config['edge_spec'] = metadata

# Build the encoder on the target device and its Adam optimizer (matches config['optimizer']).
model = model_class(
    config['hidden_channels'],
    config['out_channels'],
    num_heads=config['num_heads'],
    dropout_p=config['dropout_p'],
).to(device)
optimizer = Adam(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])

2024-12-10 15:40:58,095 - DatabaseWrapper - INFO - Connecting to the database ...
2024-12-10 15:40:58,096 - DatabaseWrapper - INFO - Database ready.


Preparing triplets...
Loading triplets...
Loaded 3684 triplets.
Train size: 2850, Test size: 503, Eval size: 331


In [4]:
# --- Training ----------------------------------------------------------------------
# Test and eval metrics are computed on a fixed schedule (before the first batch and
# before the middle batch of every epoch), plus once more after the last epoch. The
# final evaluation ensures the updates from the last half-epoch are also evaluated and
# can be checkpointed.
num_epochs = config['num_epochs']
results = defaultdict(list)

num_batches = len(train_dataloader)
# 1-based batch positions at which to evaluate; max(1, ...) guards very small loaders.
eval_batches = {1, max(1, num_batches // 2)}
evals_per_epoch = len(eval_batches)  # x-axis epoch length for the test/eval plots


def evaluate_split(split, dataloader, label, plot_name, iteration_label):
    """Evaluate `model` on one split, append its metrics to `results` and refresh its plots.

    Args:
        split: key prefix used in `results` and in plot file names ('test' or 'eval').
        dataloader: DataLoader over the split's triplets.
        label: name printed in the console header, e.g. 'Test'.
        plot_name: name used in plot titles, e.g. 'Evaluation'.
        iteration_label: x-axis label of the accuracy plot.
    """
    print(f"    {label} Results:")
    loss, accuracy, correct_pos, correct_neg, precision, recall, f1 = test_and_eval(
        model=model,
        loss_fn=loss_fn,
        dataloader=dataloader,
        margin=config['margin'],
    )
    metrics = {
        'total_loss': loss,
        'accuracies': accuracy,
        'accuracies_correct_pos': correct_pos,
        'accuracies_correct_neg': correct_neg,
        'precision': precision,
        'recall': recall,
        'F1': f1,
    }
    for metric_name, value in metrics.items():
        results[f'{split}_{metric_name}'].append(value)

    plot_losses(
        losses=[results[f'{split}_total_loss']],
        epoch_len=evals_per_epoch,
        plot_title=f'{plot_name} Loss',
        plot_file=f'{result_folder_path}/{split}_loss.png',
        line_labels=["Triplet Loss"],
    )
    plot_losses(
        [
            results[f'{split}_accuracies'],
            results[f'{split}_accuracies_correct_pos'],
            results[f'{split}_accuracies_correct_neg'],
        ],
        epoch_len=evals_per_epoch,
        plot_title=f'{plot_name} Accuracy',
        x_label=iteration_label,
        y_label='Accuracy',
        line_labels=['Total Accuracy', 'Correct Pos', 'Correct Neg'],
        plot_file=f'{result_folder_path}/{split}_accuracy.png',
    )


def run_evaluation(position):
    """Evaluate the test and eval splits, save config + results, and checkpoint the best model.

    Args:
        position: human-readable training position used in log messages.
    """
    evaluate_split('test', test_dataloader, 'Test', 'Test', 'Test Iterations')
    evaluate_split('eval', eval_dataloader, 'Eval', 'Evaluation', 'Eval Iterations')

    # Persist config and all metrics so far (the file is overwritten on every evaluation).
    config['results'] = results
    save_dict_to_json(config, result_folder_path + '/training_data.json')

    # Checkpoint whenever eval accuracy (not loss) reaches a new maximum. The first
    # evaluation is saved too, so a checkpoint always exists.
    # NOTE(review): the checkpoint is selected on the eval split (held-out authors), the
    # same split used to judge generalisation, so the saved model's eval numbers are
    # optimistic; consider selecting on `test` instead.
    eval_accuracies = results['eval_accuracies']
    if len(eval_accuracies) == 1 or eval_accuracies[-1] > max(eval_accuracies[:-1]):
        print(f"Saving model at {position}...")
        torch.save(model.state_dict(), result_folder_path + '/gat_encoder.pt')


for epoch in range(1, num_epochs + 1):
    print(f"=== Epoch {epoch}/{num_epochs} ======================")
    batches = tqdm(train_dataloader, desc=f"Epoch {epoch}/{num_epochs}")
    for batch_idx, (batch_anchor, batch_pos, batch_neg) in enumerate(batches, start=1):
        # Evaluate by batch position, before training on this batch. The schedule does not
        # shift when some batches are skipped below.
        if batch_idx in eval_batches:
            print(f"___ Current Batch {batch_idx}/{num_batches} _________________________")
            run_evaluation(f"epoch {epoch}, batch {batch_idx}")

        # Presumably custom_triplet_collate yields None for unusable batches; skip them.
        if batch_anchor is None or batch_pos is None or batch_neg is None:
            continue

        # NOTE(review): if `test_and_eval` leaves the model in eval mode, dropout would
        # stay disabled for all later updates. Re-enabling training mode here is a no-op
        # when `train` already does so.
        model.train()
        loss = train(
            model=model,
            loss_fn=loss_fn,
            batch_anchor=batch_anchor,
            batch_pos=batch_pos,
            batch_neg=batch_neg,
            optimizer=optimizer,
        )
        results['train_loss'].append(loss)
        plot_loss(
            results['train_loss'],
            epoch_len=num_batches,
            plot_title='Training Loss',
            plot_avg=True,
            plot_file=result_folder_path + '/train_loss.png',
        )

# The in-loop evaluations run *before* a batch is trained. Without this final pass, the
# second half of the last epoch would never be evaluated or checkpointed.
print(f"=== Final evaluation after epoch {num_epochs} ======================")
run_evaluation(f"end of epoch {num_epochs}")
        



Epoch 1/10:   0%|          | 0/90 [00:00<?, ?it/s]

___ Current Batch 1/90 _________________________
    Test Results:
        Correct positive: 503 (100.00%), Correct negative: 0 (0.00%)
        Total correct: 503 (50.00%)
        Test/Eval Loss: 0.9552, Test/Eval Accuracy: 0.5000
        Precision: 0.5000, Recall: 1.0000, F1: 0.6667
    Eval Results:
        Correct positive: 331 (100.00%), Correct negative: 0 (0.00%)
        Total correct: 331 (50.00%)
        Test/Eval Loss: 0.9837, Test/Eval Accuracy: 0.5000
        Precision: 0.5000, Recall: 1.0000, F1: 0.6667
___ Current Batch 45/90 _________________________
    Test Results:
        Correct positive: 365 (72.56%), Correct negative: 409 (81.31%)
        Total correct: 774 (76.94%)
        Test/Eval Loss: 0.4227, Test/Eval Accuracy: 0.7694
        Precision: 0.7952, Recall: 0.7256, F1: 0.7588
    Eval Results:
        Correct positive: 275 (83.08%), Correct negative: 113 (34.14%)
        Total correct: 388 (58.61%)
        Test/Eval Loss: 0.5935, Test/Eval Accuracy: 0.5861
       

Epoch 2/10:   0%|          | 0/90 [00:00<?, ?it/s]

___ Current Batch 1/90 _________________________
    Test Results:
        Correct positive: 428 (85.09%), Correct negative: 410 (81.51%)
        Total correct: 838 (83.30%)
        Test/Eval Loss: 0.2902, Test/Eval Accuracy: 0.8330
        Precision: 0.8215, Recall: 0.8509, F1: 0.8359
    Eval Results:
        Correct positive: 282 (85.20%), Correct negative: 54 (16.31%)
        Total correct: 336 (50.76%)
        Test/Eval Loss: 0.7593, Test/Eval Accuracy: 0.5076
        Precision: 0.5045, Recall: 0.8520, F1: 0.6337
___ Current Batch 45/90 _________________________
    Test Results:
        Correct positive: 427 (84.89%), Correct negative: 409 (81.31%)
        Total correct: 836 (83.10%)
        Test/Eval Loss: 0.3039, Test/Eval Accuracy: 0.8310
        Precision: 0.8196, Recall: 0.8489, F1: 0.8340
    Eval Results:
        Correct positive: 214 (64.65%), Correct negative: 250 (75.53%)
        Total correct: 464 (70.09%)
        Test/Eval Loss: 0.6462, Test/Eval Accuracy: 0.7009
    

Epoch 3/10:   0%|          | 0/90 [00:00<?, ?it/s]

___ Current Batch 1/90 _________________________
    Test Results:
        Correct positive: 417 (82.90%), Correct negative: 442 (87.87%)
        Total correct: 859 (85.39%)
        Test/Eval Loss: 0.2789, Test/Eval Accuracy: 0.8539
        Precision: 0.8724, Recall: 0.8290, F1: 0.8502
    Eval Results:
        Correct positive: 242 (73.11%), Correct negative: 246 (74.32%)
        Total correct: 488 (73.72%)
        Test/Eval Loss: 0.6129, Test/Eval Accuracy: 0.7372
        Precision: 0.7401, Recall: 0.7311, F1: 0.7356
Saving model at epoch 3...
___ Current Batch 45/90 _________________________
    Test Results:
        Correct positive: 437 (86.88%), Correct negative: 367 (72.96%)
        Total correct: 804 (79.92%)
        Test/Eval Loss: 0.3605, Test/Eval Accuracy: 0.7992
        Precision: 0.7627, Recall: 0.8688, F1: 0.8123
    Eval Results:
        Correct positive: 231 (69.79%), Correct negative: 252 (76.13%)
        Total correct: 483 (72.96%)
        Test/Eval Loss: 0.5961, Tes

Epoch 4/10:   0%|          | 0/90 [00:00<?, ?it/s]

___ Current Batch 1/90 _________________________
    Test Results:
        Correct positive: 390 (77.53%), Correct negative: 460 (91.45%)
        Total correct: 850 (84.49%)
        Test/Eval Loss: 0.2596, Test/Eval Accuracy: 0.8449
        Precision: 0.9007, Recall: 0.7753, F1: 0.8333
    Eval Results:
        Correct positive: 245 (74.02%), Correct negative: 295 (89.12%)
        Total correct: 540 (81.57%)
        Test/Eval Loss: 0.3924, Test/Eval Accuracy: 0.8157
        Precision: 0.8719, Recall: 0.7402, F1: 0.8007
Saving model at epoch 4...
___ Current Batch 45/90 _________________________
    Test Results:
        Correct positive: 423 (84.10%), Correct negative: 386 (76.74%)
        Total correct: 809 (80.42%)
        Test/Eval Loss: 0.3420, Test/Eval Accuracy: 0.8042
        Precision: 0.7833, Recall: 0.8410, F1: 0.8111
    Eval Results:
        Correct positive: 222 (67.07%), Correct negative: 230 (69.49%)
        Total correct: 452 (68.28%)
        Test/Eval Loss: 0.6352, Tes

Epoch 5/10:   0%|          | 0/90 [00:00<?, ?it/s]

___ Current Batch 1/90 _________________________
    Test Results:
        Correct positive: 425 (84.49%), Correct negative: 423 (84.10%)
        Total correct: 848 (84.29%)
        Test/Eval Loss: 0.2644, Test/Eval Accuracy: 0.8429
        Precision: 0.8416, Recall: 0.8449, F1: 0.8433
    Eval Results:
        Correct positive: 253 (76.44%), Correct negative: 276 (83.38%)
        Total correct: 529 (79.91%)
        Test/Eval Loss: 0.4454, Test/Eval Accuracy: 0.7991
        Precision: 0.8214, Recall: 0.7644, F1: 0.7919
___ Current Batch 45/90 _________________________
    Test Results:
        Correct positive: 458 (91.05%), Correct negative: 372 (73.96%)
        Total correct: 830 (82.50%)
        Test/Eval Loss: 0.3176, Test/Eval Accuracy: 0.8250
        Precision: 0.7776, Recall: 0.9105, F1: 0.8388
    Eval Results:
        Correct positive: 245 (74.02%), Correct negative: 189 (57.10%)
        Total correct: 434 (65.56%)
        Test/Eval Loss: 0.7151, Test/Eval Accuracy: 0.6556
   

Epoch 6/10:   0%|          | 0/90 [00:00<?, ?it/s]

___ Current Batch 1/90 _________________________
    Test Results:
        Correct positive: 424 (84.29%), Correct negative: 407 (80.91%)
        Total correct: 831 (82.60%)
        Test/Eval Loss: 0.3306, Test/Eval Accuracy: 0.8260
        Precision: 0.8154, Recall: 0.8429, F1: 0.8289
    Eval Results:
        Correct positive: 268 (80.97%), Correct negative: 225 (67.98%)
        Total correct: 493 (74.47%)
        Test/Eval Loss: 0.5192, Test/Eval Accuracy: 0.7447
        Precision: 0.7166, Recall: 0.8097, F1: 0.7603
___ Current Batch 45/90 _________________________
    Test Results:
        Correct positive: 440 (87.48%), Correct negative: 402 (79.92%)
        Total correct: 842 (83.70%)
        Test/Eval Loss: 0.2979, Test/Eval Accuracy: 0.8370
        Precision: 0.8133, Recall: 0.8748, F1: 0.8429
    Eval Results:
        Correct positive: 218 (65.86%), Correct negative: 211 (63.75%)
        Total correct: 429 (64.80%)
        Test/Eval Loss: 0.6699, Test/Eval Accuracy: 0.6480
   

Epoch 7/10:   0%|          | 0/90 [00:00<?, ?it/s]

___ Current Batch 1/90 _________________________
    Test Results:
        Correct positive: 409 (81.31%), Correct negative: 433 (86.08%)
        Total correct: 842 (83.70%)
        Test/Eval Loss: 0.3202, Test/Eval Accuracy: 0.8370
        Precision: 0.8539, Recall: 0.8131, F1: 0.8330
    Eval Results:
        Correct positive: 242 (73.11%), Correct negative: 181 (54.68%)
        Total correct: 423 (63.90%)
        Test/Eval Loss: 0.6243, Test/Eval Accuracy: 0.6390
        Precision: 0.6173, Recall: 0.7311, F1: 0.6694
___ Current Batch 45/90 _________________________
    Test Results:
        Correct positive: 462 (91.85%), Correct negative: 343 (68.19%)
        Total correct: 805 (80.02%)
        Test/Eval Loss: 0.3779, Test/Eval Accuracy: 0.8002
        Precision: 0.7428, Recall: 0.9185, F1: 0.8213
    Eval Results:
        Correct positive: 280 (84.59%), Correct negative: 189 (57.10%)
        Total correct: 469 (70.85%)
        Test/Eval Loss: 0.5544, Test/Eval Accuracy: 0.7085
   

Epoch 8/10:   0%|          | 0/90 [00:00<?, ?it/s]

___ Current Batch 1/90 _________________________
    Test Results:
        Correct positive: 414 (82.31%), Correct negative: 400 (79.52%)
        Total correct: 814 (80.91%)
        Test/Eval Loss: 0.3397, Test/Eval Accuracy: 0.8091
        Precision: 0.8008, Recall: 0.8231, F1: 0.8118
    Eval Results:
        Correct positive: 235 (71.00%), Correct negative: 303 (91.54%)
        Total correct: 538 (81.27%)
        Test/Eval Loss: 0.3807, Test/Eval Accuracy: 0.8127
        Precision: 0.8935, Recall: 0.7100, F1: 0.7912
___ Current Batch 45/90 _________________________
    Test Results:
        Correct positive: 432 (85.88%), Correct negative: 402 (79.92%)
        Total correct: 834 (82.90%)
        Test/Eval Loss: 0.3031, Test/Eval Accuracy: 0.8290
        Precision: 0.8105, Recall: 0.8588, F1: 0.8340
    Eval Results:
        Correct positive: 256 (77.34%), Correct negative: 251 (75.83%)
        Total correct: 507 (76.59%)
        Test/Eval Loss: 0.5266, Test/Eval Accuracy: 0.7659
   

Epoch 9/10:   0%|          | 0/90 [00:00<?, ?it/s]

___ Current Batch 1/90 _________________________
    Test Results:
        Correct positive: 431 (85.69%), Correct negative: 365 (72.56%)
        Total correct: 796 (79.13%)
        Test/Eval Loss: 0.3545, Test/Eval Accuracy: 0.7913
        Precision: 0.7575, Recall: 0.8569, F1: 0.8041
    Eval Results:
        Correct positive: 199 (60.12%), Correct negative: 285 (86.10%)
        Total correct: 484 (73.11%)
        Test/Eval Loss: 0.4645, Test/Eval Accuracy: 0.7311
        Precision: 0.8122, Recall: 0.6012, F1: 0.6910
___ Current Batch 45/90 _________________________
    Test Results:
        Correct positive: 441 (87.67%), Correct negative: 403 (80.12%)
        Total correct: 844 (83.90%)
        Test/Eval Loss: 0.3041, Test/Eval Accuracy: 0.8390
        Precision: 0.8152, Recall: 0.8767, F1: 0.8448
    Eval Results:
        Correct positive: 216 (65.26%), Correct negative: 257 (77.64%)
        Total correct: 473 (71.45%)
        Test/Eval Loss: 0.6081, Test/Eval Accuracy: 0.7145
   

Epoch 10/10:   0%|          | 0/90 [00:00<?, ?it/s]

___ Current Batch 1/90 _________________________
    Test Results:
        Correct positive: 441 (87.67%), Correct negative: 414 (82.31%)
        Total correct: 855 (84.99%)
        Test/Eval Loss: 0.2953, Test/Eval Accuracy: 0.8499
        Precision: 0.8321, Recall: 0.8767, F1: 0.8538
    Eval Results:
        Correct positive: 252 (76.13%), Correct negative: 284 (85.80%)
        Total correct: 536 (80.97%)
        Test/Eval Loss: 0.4256, Test/Eval Accuracy: 0.8097
        Precision: 0.8428, Recall: 0.7613, F1: 0.8000
___ Current Batch 45/90 _________________________
    Test Results:
        Correct positive: 436 (86.68%), Correct negative: 435 (86.48%)
        Total correct: 871 (86.58%)
        Test/Eval Loss: 0.2576, Test/Eval Accuracy: 0.8658
        Precision: 0.8651, Recall: 0.8668, F1: 0.8659
    Eval Results:
        Correct positive: 269 (81.27%), Correct negative: 263 (79.46%)
        Total correct: 532 (80.36%)
        Test/Eval Loss: 0.4591, Test/Eval Accuracy: 0.8036
   