In [17]:
from notebooks.util_homogeneous import HomogeneousGraphTripletDataset
from util_homogeneous import *
from util_heterogeneous import plot_loss, save_training_results
from training_homogeneous import *
from gat_models import *

import random
import numpy as np
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm
from torch_geometric.loader import DataLoader
from torch.optim import Adam
from torch.nn.modules.loss import TripletMarginLoss

from src.shared.database_wrapper import DatabaseWrapper
from src.shared.graph_schema import *
from src.shared.graph_sampling import GraphSampling

random.seed(40)
np.random.seed(40)
torch.manual_seed(40)
torch.cuda.manual_seed_all(40)

## Configurations

In [18]:
# Graph sampling configurations
node_spec = NodeType.PUBLICATION

edge_spec = EdgeType.SIM_ABSTRACT

node_properties = [
    'id',
    'feature_vec',
]

database = 'homogeneous-graph'
gs = GraphSampling(
    node_spec=[node_spec],
    edge_spec=[edge_spec],
    node_properties=node_properties,
    database=database
)

# Model configurations

config = {
    'experiment': 'GATv2 encoder (with linear layer + dropout) trained on homogeneous graph (publication nodes with title and abstract, abstract edges) using Triplet Loss and full embeddings',
    'max_hops': 3,
    'model_node_feature': 'feature_vec',  # Node feature to use for GAT encoder
    'hidden_channels': 32,
    'out_channels': 16,
    'num_heads': 8,
    'margin': 1.0,
    'optimizer': 'Adam',
    'learning_rate': 0.005,
    'weight_decay': 5e-4,
    'num_epochs': 10,
    'batch_size': 32,
}

model_class = HomoGATEncoderLinearDropout
loss_fn = TripletMarginLoss(margin=config['margin'])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# TODO: Adjust result folder name!
result_folder_name = 'homogeneous (abstract) full_emb linear_layer dropout'
result_folder_path = f'./data/results/{result_folder_name}'
if not os.path.exists(result_folder_path):
    os.mkdir(result_folder_path)

In [19]:
db = DatabaseWrapper(database=database)
data_harvester = TripletDataHarvester(db=db, gs=gs, edge_spec=[edge_spec], config=config, valid_triplets_save_file='valid_triplets_homogeneous_abstract', transformer_model='sentence-transformers/all-MiniLM-L6-v2')


# Split the pairs into train and test

train_size = int(0.85 * len(data_harvester.triplets))
test_size = int(0.1 * len(data_harvester.triplets))
eval_size = len(data_harvester.triplets) - train_size - test_size

# Harvest the evaluation triplets first, since triplets are ordered by author. This will ensure that the evaluation set has authors not seen in the training set.
eval_triplets = data_harvester.triplets[:eval_size]

train_test_triplets = data_harvester.triplets[eval_size:]
random.shuffle(train_test_triplets)

train_triplets = train_test_triplets[:train_size]
test_triplets = train_test_triplets[train_size:]
config['train_size'] = len(train_triplets)
config['test_size'] = len(test_triplets)
config['eval_size'] = len(eval_triplets)

print(f"Train size: {len(train_triplets)}, Test size: {len(test_triplets)}, Eval size: {len(eval_triplets)}")

# Create the datasets from the pairs (distinct pairs for training and testing)
train_dataset = HomogeneousGraphTripletDataset(train_triplets, gs, config=config)
test_dataset = HomogeneousGraphTripletDataset(test_triplets, gs, config=config)
eval_dataset = HomogeneousGraphTripletDataset(eval_triplets, gs, config=config)

# Create the DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, collate_fn=custom_triplet_collate)
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, collate_fn=custom_triplet_collate)
eval_dataloader = DataLoader(eval_dataset, batch_size=config['batch_size'], shuffle=False, collate_fn=custom_triplet_collate)

# Create model
metadata = (
    node_spec.value,
    edge_spec.value
)
config['node_spec'] = metadata[0]
config['edge_spec'] = metadata[1]
model = model_class(config['hidden_channels'], config['out_channels'], num_heads=config['num_heads']).to(device)
optimizer = Adam(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])

2024-11-10 12:57:14,857 - DatabaseWrapper - INFO - Connecting to the database ...
2024-11-10 12:57:14,858 - DatabaseWrapper - INFO - Database ready.


Preparing triplets...
Loading triplets...
Could not load triplets from file. Generating triplets...
Checking data validity...
Out of 20034 checked papers, 7665 are valid and 12369 are invalid.
Generating hard triplets ...




Total triplets generated: 6368. Done.
Generated 6368 triplets.
Saving triplets...
Triplets saved.
Train size: 5412, Test size: 636, Eval size: 320


In [20]:
num_epochs = config['num_epochs']
train_losses = []

test_losses = []
test_accuracies = []
test_correct_pos = []
test_correct_neg = []

eval_losses = []
eval_accuracies = []
eval_correct_pos = []
eval_correct_neg = []

current_batch = 1

for epoch in range(1, num_epochs + 1):
    print(f"=== Epoch {epoch}/{num_epochs} ======================")
    epoch_marker_pos = list(range(0, len(train_dataloader) * epoch, len(train_dataloader)))
    current_batch = 1
    for batch_anchor, batch_pos, batch_neg in tqdm(train_dataloader, desc=f"Epoch {epoch}/{num_epochs}"):
        if batch_anchor is None or batch_pos is None or batch_neg is None:
            continue
        
        if current_batch == 1 or current_batch == len(train_dataloader) // 2:
            print(f"___ Current Batch {current_batch}/{len(train_dataloader)} _________________________")
            # Model testing
            print("    Test Results:")
            test_loss, test_num_correct, test_correct_pos_val, test_correct_neg_val = test(
                model=model,
                loss_fn=loss_fn,
                dataloader=test_dataloader,
                margin=config['margin']
            )
            test_losses.append(test_loss)
            test_accuracies.append(test_num_correct)
            test_correct_pos.append(test_correct_pos_val)
            test_correct_neg.append(test_correct_neg_val)
    
            plot_loss(test_losses, epoch_len=2, plot_title='Test Loss', plot_avg=False, plot_file=result_folder_path + '/test_loss.png')
            plot_loss(
                test_accuracies,
                epoch_len=2,
                plot_title='Test Accuracy',
                plot_avg=False, 
                x_label='Test Iterations',
                y_label='Accuracy',
                line_label='Accuracy',
                plot_file=result_folder_path + '/test_accuracy.png'
            )
            
            # Model evaluation
            print("    Eval Results:")
            eval_loss, eval_num_correct, eval_correct_pos_val, eval_correct_neg_val = evaluate(
                model=model,
                loss_fn=loss_fn,
                dataloader=eval_dataloader,
                margin=config['margin']
            )
            eval_losses.append(eval_loss)
            eval_accuracies.append(eval_num_correct)
            eval_correct_pos.append(eval_correct_pos_val)
            eval_correct_neg.append(eval_correct_neg_val)
            
            plot_loss(eval_losses, epoch_len=2, plot_title='Evaluation Loss', plot_avg=False, plot_file=result_folder_path + '/eval_loss.png')
            plot_loss(
                eval_accuracies, 
                epoch_len=2, 
                plot_title='Evaluation Accuracy', 
                plot_avg=False, 
                x_label='Eval Iterations',
                y_label='Accuracy',
                line_label='Accuracy',
                plot_file=result_folder_path + '/eval_accuracy.png'
            )
            
        loss = train(
            model=model,
            loss_fn=loss_fn,
            batch_anchor=batch_anchor,
            batch_pos=batch_pos,
            batch_neg=batch_neg,
            optimizer=optimizer
        )
        train_losses.append(loss)
        
        plot_loss(train_losses, epoch_len=len(train_dataloader), plot_title='Training Loss', plot_avg=True, plot_file=result_folder_path + '/train_loss.png')
        current_batch += 1
        
    # Save config and training results
    eval_results = {
        'eval_losses': eval_losses,
        'eval_accuracies': eval_accuracies,
        'eval_correct_pos': eval_correct_pos,
        'eval_correct_neg': eval_correct_neg
    }
    save_training_results(train_losses, test_losses, eval_results, config, result_folder_path + '/training_data.json')
    
    # Save model if loss has decreased
    if len(test_losses) > 1 and test_losses[-1] < min(test_losses[:-1]):
        print(f"Saving model at epoch {epoch}...")
        torch.save(model.state_dict(), result_folder_path + '/gat_encoder.pt')



Epoch 1/10:   0%|          | 0/170 [00:00<?, ?it/s]

___ Current Batch 1/170 _________________________
    Test Results:
        Correct positive: 636 (100.00%), Correct negative: 0 (0.00%)
        Total correct: 636 (50.00%)
        Test Loss: 1.0256, Test Accuracy: 0.5000
    Eval Results:
        Correct positive: 320 (100.00%), Correct negative: 0 (0.00%)
        Total correct: 320 (50.00%)
        Eval Loss: 0.9901, Eval Accuracy: 0.5000
___ Current Batch 85/170 _________________________
    Test Results:
        Correct positive: 422 (66.35%), Correct negative: 578 (90.88%)
        Total correct: 1000 (78.62%)
        Test Loss: 0.2398, Test Accuracy: 0.7862
    Eval Results:
        Correct positive: 174 (54.37%), Correct negative: 168 (52.50%)
        Total correct: 342 (53.44%)
        Eval Loss: 0.9645, Eval Accuracy: 0.5344
Saving model at epoch 1...


Epoch 2/10:   0%|          | 0/170 [00:00<?, ?it/s]

___ Current Batch 1/170 _________________________
    Test Results:
        Correct positive: 383 (60.22%), Correct negative: 591 (92.92%)
        Total correct: 974 (76.57%)
        Test Loss: 0.2267, Test Accuracy: 0.7657
    Eval Results:
        Correct positive: 121 (37.81%), Correct negative: 153 (47.81%)
        Total correct: 274 (42.81%)
        Eval Loss: 1.2027, Eval Accuracy: 0.4281
___ Current Batch 85/170 _________________________
    Test Results:
        Correct positive: 394 (61.95%), Correct negative: 577 (90.72%)
        Total correct: 971 (76.34%)
        Test Loss: 0.2319, Test Accuracy: 0.7634
    Eval Results:
        Correct positive: 107 (33.44%), Correct negative: 187 (58.44%)
        Total correct: 294 (45.94%)
        Eval Loss: 1.0916, Eval Accuracy: 0.4594


Epoch 3/10:   0%|          | 0/170 [00:00<?, ?it/s]

___ Current Batch 1/170 _________________________
    Test Results:
        Correct positive: 499 (78.46%), Correct negative: 583 (91.67%)
        Total correct: 1082 (85.06%)
        Test Loss: 0.1897, Test Accuracy: 0.8506
    Eval Results:
        Correct positive: 173 (54.06%), Correct negative: 123 (38.44%)
        Total correct: 296 (46.25%)
        Eval Loss: 1.0870, Eval Accuracy: 0.4625
___ Current Batch 85/170 _________________________
    Test Results:
        Correct positive: 458 (72.01%), Correct negative: 575 (90.41%)
        Total correct: 1033 (81.21%)
        Test Loss: 0.2014, Test Accuracy: 0.8121
    Eval Results:
        Correct positive: 178 (55.62%), Correct negative: 83 (25.94%)
        Total correct: 261 (40.78%)
        Eval Loss: 1.1952, Eval Accuracy: 0.4078


Epoch 4/10:   0%|          | 0/170 [00:00<?, ?it/s]

___ Current Batch 1/170 _________________________
    Test Results:
        Correct positive: 455 (71.54%), Correct negative: 589 (92.61%)
        Total correct: 1044 (82.08%)
        Test Loss: 0.1862, Test Accuracy: 0.8208
    Eval Results:
        Correct positive: 202 (63.12%), Correct negative: 108 (33.75%)
        Total correct: 310 (48.44%)
        Eval Loss: 1.1138, Eval Accuracy: 0.4844
___ Current Batch 85/170 _________________________
    Test Results:
        Correct positive: 399 (62.74%), Correct negative: 584 (91.82%)
        Total correct: 983 (77.28%)
        Test Loss: 0.1959, Test Accuracy: 0.7728
    Eval Results:
        Correct positive: 172 (53.75%), Correct negative: 158 (49.38%)
        Total correct: 330 (51.56%)
        Eval Loss: 0.9911, Eval Accuracy: 0.5156


Epoch 5/10:   0%|          | 0/170 [00:00<?, ?it/s]

___ Current Batch 1/170 _________________________
    Test Results:
        Correct positive: 424 (66.67%), Correct negative: 598 (94.03%)
        Total correct: 1022 (80.35%)
        Test Loss: 0.1926, Test Accuracy: 0.8035
    Eval Results:
        Correct positive: 167 (52.19%), Correct negative: 150 (46.88%)
        Total correct: 317 (49.53%)
        Eval Loss: 0.9195, Eval Accuracy: 0.4953
___ Current Batch 85/170 _________________________
    Test Results:
        Correct positive: 458 (72.01%), Correct negative: 593 (93.24%)
        Total correct: 1051 (82.63%)
        Test Loss: 0.1896, Test Accuracy: 0.8263
    Eval Results:
        Correct positive: 170 (53.12%), Correct negative: 176 (55.00%)
        Total correct: 346 (54.06%)
        Eval Loss: 1.0277, Eval Accuracy: 0.5406


Epoch 6/10:   0%|          | 0/170 [00:00<?, ?it/s]

___ Current Batch 1/170 _________________________
    Test Results:
        Correct positive: 529 (83.18%), Correct negative: 582 (91.51%)
        Total correct: 1111 (87.34%)
        Test Loss: 0.1859, Test Accuracy: 0.8734
    Eval Results:
        Correct positive: 228 (71.25%), Correct negative: 65 (20.31%)
        Total correct: 293 (45.78%)
        Eval Loss: 1.1118, Eval Accuracy: 0.4578
___ Current Batch 85/170 _________________________
    Test Results:
        Correct positive: 503 (79.09%), Correct negative: 588 (92.45%)
        Total correct: 1091 (85.77%)
        Test Loss: 0.1623, Test Accuracy: 0.8577
    Eval Results:
        Correct positive: 206 (64.38%), Correct negative: 122 (38.12%)
        Total correct: 328 (51.25%)
        Eval Loss: 1.1098, Eval Accuracy: 0.5125
Saving model at epoch 6...


Epoch 7/10:   0%|          | 0/170 [00:00<?, ?it/s]

___ Current Batch 1/170 _________________________
    Test Results:
        Correct positive: 461 (72.48%), Correct negative: 591 (92.92%)
        Total correct: 1052 (82.70%)
        Test Loss: 0.1482, Test Accuracy: 0.8270
    Eval Results:
        Correct positive: 156 (48.75%), Correct negative: 137 (42.81%)
        Total correct: 293 (45.78%)
        Eval Loss: 1.1837, Eval Accuracy: 0.4578
___ Current Batch 85/170 _________________________
    Test Results:
        Correct positive: 505 (79.40%), Correct negative: 585 (91.98%)
        Total correct: 1090 (85.69%)
        Test Loss: 0.2071, Test Accuracy: 0.8569
    Eval Results:
        Correct positive: 209 (65.31%), Correct negative: 70 (21.88%)
        Total correct: 279 (43.59%)
        Eval Loss: 1.1497, Eval Accuracy: 0.4359


Epoch 8/10:   0%|          | 0/170 [00:00<?, ?it/s]

___ Current Batch 1/170 _________________________
    Test Results:
        Correct positive: 472 (74.21%), Correct negative: 603 (94.81%)
        Total correct: 1075 (84.51%)
        Test Loss: 0.1516, Test Accuracy: 0.8451
    Eval Results:
        Correct positive: 194 (60.62%), Correct negative: 103 (32.19%)
        Total correct: 297 (46.41%)
        Eval Loss: 1.1647, Eval Accuracy: 0.4641
___ Current Batch 85/170 _________________________
    Test Results:
        Correct positive: 426 (66.98%), Correct negative: 605 (95.13%)
        Total correct: 1031 (81.05%)
        Test Loss: 0.1743, Test Accuracy: 0.8105
    Eval Results:
        Correct positive: 159 (49.69%), Correct negative: 121 (37.81%)
        Total correct: 280 (43.75%)
        Eval Loss: 1.1385, Eval Accuracy: 0.4375


Epoch 9/10:   0%|          | 0/170 [00:00<?, ?it/s]

___ Current Batch 1/170 _________________________
    Test Results:
        Correct positive: 448 (70.44%), Correct negative: 602 (94.65%)
        Total correct: 1050 (82.55%)
        Test Loss: 0.1688, Test Accuracy: 0.8255
    Eval Results:
        Correct positive: 175 (54.69%), Correct negative: 138 (43.12%)
        Total correct: 313 (48.91%)
        Eval Loss: 1.0498, Eval Accuracy: 0.4891
___ Current Batch 85/170 _________________________
    Test Results:
        Correct positive: 393 (61.79%), Correct negative: 604 (94.97%)
        Total correct: 997 (78.38%)
        Test Loss: 0.1599, Test Accuracy: 0.7838
    Eval Results:
        Correct positive: 118 (36.88%), Correct negative: 190 (59.38%)
        Total correct: 308 (48.12%)
        Eval Loss: 1.1441, Eval Accuracy: 0.4813


Epoch 10/10:   0%|          | 0/170 [00:00<?, ?it/s]

___ Current Batch 1/170 _________________________
    Test Results:
        Correct positive: 513 (80.66%), Correct negative: 571 (89.78%)
        Total correct: 1084 (85.22%)
        Test Loss: 0.1840, Test Accuracy: 0.8522
    Eval Results:
        Correct positive: 221 (69.06%), Correct negative: 128 (40.00%)
        Total correct: 349 (54.53%)
        Eval Loss: 0.9596, Eval Accuracy: 0.5453
___ Current Batch 85/170 _________________________
    Test Results:
        Correct positive: 355 (55.82%), Correct negative: 598 (94.03%)
        Total correct: 953 (74.92%)
        Test Loss: 0.2008, Test Accuracy: 0.7492
    Eval Results:
        Correct positive: 130 (40.62%), Correct negative: 133 (41.56%)
        Total correct: 263 (41.09%)
        Eval Loss: 1.2492, Eval Accuracy: 0.4109
