In [1]:
from util_homogeneous import *
from util import *
from training_homogeneous import *
from gat_models import *

import random
import numpy as np
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm
from torch_geometric.loader import DataLoader
from torch.optim import Adam
from torch.nn.modules.loss import TripletMarginLoss

from src.shared.database_wrapper import DatabaseWrapper
from src.shared.graph_schema import *
from src.shared.graph_sampling import GraphSampling

random.seed(40)
np.random.seed(40)
torch.manual_seed(40)
torch.cuda.manual_seed_all(40)

  from tqdm.autonotebook import tqdm, trange


## Configurations

In [2]:
# Graph sampling configurations
node_spec = NodeType.PUBLICATION

edge_spec = EdgeType.SIM_TITLE

node_properties = [
    'id',
    'feature_vec',
]

database = 'dense-graph'
gs = GraphSampling(
    node_spec=[node_spec],
    edge_spec=[edge_spec],
    node_properties=node_properties,
    database=database
)

# Model configurations

config = {
    'experiment': 'GATv2 encoder (with linear layer + dropout) trained on homogeneous dense graph (publication nodes with title and abstract, title edges) using Triplet Loss and full embeddings and low hidden dimensions',
    'max_hops': 3,
    'model_node_feature': 'feature_vec',  # Node feature to use for GAT encoder
    'hidden_channels': 64,
    'out_channels': 3,
    'num_heads': 8,
    'margin': 1.0,
    'optimizer': 'Adam',
    'learning_rate': 0.005,
    'weight_decay': 5e-4,
    'num_epochs': 10,
    'batch_size': 32,
}

model_class = HomoGATEncoderLinearDropout
loss_fn = TripletMarginLoss(margin=config['margin'])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# TODO: Adjust result folder name!
result_folder_name = 'dense (title) full_emb linear_layer dropout low-hidden-dim'
result_folder_path = f'./data/results/{result_folder_name}'
if not os.path.exists(result_folder_path):
    os.mkdir(result_folder_path)

Using default edge type: SimilarTitle for homogeneous graph sampling.


In [3]:
db = DatabaseWrapper(database=database)
data_harvester = TripletDataHarvester(db=db, gs=gs, edge_spec=[edge_spec], config=config, valid_triplets_save_file='valid_triplets_dense_title', transformer_model='sentence-transformers/all-MiniLM-L6-v2')


# Split the pairs into train and test

train_size = int(0.85 * len(data_harvester.triplets))
test_size = int(0.1 * len(data_harvester.triplets))
eval_size = len(data_harvester.triplets) - train_size - test_size

# Harvest the evaluation triplets first, since triplets are ordered by author. This will ensure that the evaluation set has authors not seen in the training set.
eval_triplets = data_harvester.triplets[:eval_size]

train_test_triplets = data_harvester.triplets[eval_size:]
random.shuffle(train_test_triplets)

train_triplets = train_test_triplets[:train_size]
test_triplets = train_test_triplets[train_size:]
config['train_size'] = len(train_triplets)
config['test_size'] = len(test_triplets)
config['eval_size'] = len(eval_triplets)

print(f"Train size: {len(train_triplets)}, Test size: {len(test_triplets)}, Eval size: {len(eval_triplets)}")

# Create the datasets from the pairs (distinct pairs for training and testing)
train_dataset = HomogeneousGraphTripletDataset(train_triplets, gs, config=config)
test_dataset = HomogeneousGraphTripletDataset(test_triplets, gs, config=config)
eval_dataset = HomogeneousGraphTripletDataset(eval_triplets, gs, config=config)

# Create the DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, collate_fn=custom_triplet_collate)
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, collate_fn=custom_triplet_collate)
eval_dataloader = DataLoader(eval_dataset, batch_size=config['batch_size'], shuffle=False, collate_fn=custom_triplet_collate)

# Create model
metadata = (
    node_spec.value,
    edge_spec.value
)
config['node_spec'] = metadata[0]
config['edge_spec'] = metadata[1]
model = model_class(config['hidden_channels'], config['out_channels'], num_heads=config['num_heads']).to(device)
optimizer = Adam(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])

2024-11-26 08:49:40,118 - DatabaseWrapper - INFO - Connecting to the database ...
2024-11-26 08:49:40,119 - DatabaseWrapper - INFO - Database ready.


Preparing triplets...
Loading triplets...
Loaded 9755 triplets.
Train size: 8291, Test size: 975, Eval size: 489


In [4]:
num_epochs = config['num_epochs']
train_losses = []

test_losses = []
test_accuracies = []
test_correct_pos = []
test_correct_neg = []

eval_losses = []
eval_accuracies = []
eval_correct_pos = []
eval_correct_neg = []

current_batch = 1

for epoch in range(1, num_epochs + 1):
    print(f"=== Epoch {epoch}/{num_epochs} ======================")
    epoch_marker_pos = list(range(0, len(train_dataloader) * epoch, len(train_dataloader)))
    current_batch = 1
    for batch_anchor, batch_pos, batch_neg in tqdm(train_dataloader, desc=f"Epoch {epoch}/{num_epochs}"):
        if batch_anchor is None or batch_pos is None or batch_neg is None:
            continue
        
        if current_batch == 1 or current_batch == len(train_dataloader) // 2:
            print(f"___ Current Batch {current_batch}/{len(train_dataloader)} _________________________")
            # Model testing
            print("    Test Results:")
            test_loss, test_num_correct, test_correct_pos_val, test_correct_neg_val = test(
                model=model,
                loss_fn=loss_fn,
                dataloader=test_dataloader,
                margin=config['margin']
            )
            test_losses.append(test_loss)
            test_accuracies.append(test_num_correct)
            test_correct_pos.append(test_correct_pos_val)
            test_correct_neg.append(test_correct_neg_val)
    
            plot_loss(test_losses, epoch_len=2, plot_title='Test Loss', plot_avg=False, plot_file=result_folder_path + '/test_loss.png')
            plot_loss(
                test_accuracies,
                epoch_len=2,
                plot_title='Test Accuracy',
                plot_avg=False, 
                x_label='Test Iterations',
                y_label='Accuracy',
                line_label='Accuracy',
                plot_file=result_folder_path + '/test_accuracy.png'
            )
            
            # Model evaluation
            print("    Eval Results:")
            eval_loss, eval_num_correct, eval_correct_pos_val, eval_correct_neg_val = evaluate(
                model=model,
                loss_fn=loss_fn,
                dataloader=eval_dataloader,
                margin=config['margin']
            )
            eval_losses.append(eval_loss)
            eval_accuracies.append(eval_num_correct)
            eval_correct_pos.append(eval_correct_pos_val)
            eval_correct_neg.append(eval_correct_neg_val)
            
            plot_loss(eval_losses, epoch_len=2, plot_title='Evaluation Loss', plot_avg=False, plot_file=result_folder_path + '/eval_loss.png')
            plot_loss(
                eval_accuracies, 
                epoch_len=2, 
                plot_title='Evaluation Accuracy', 
                plot_avg=False, 
                x_label='Eval Iterations',
                y_label='Accuracy',
                line_label='Accuracy',
                plot_file=result_folder_path + '/eval_accuracy.png'
            )
            
        loss = train(
            model=model,
            loss_fn=loss_fn,
            batch_anchor=batch_anchor,
            batch_pos=batch_pos,
            batch_neg=batch_neg,
            optimizer=optimizer
        )
        train_losses.append(loss)
        
        plot_loss(train_losses, epoch_len=len(train_dataloader), plot_title='Training Loss', plot_avg=True, plot_file=result_folder_path + '/train_loss.png')
        current_batch += 1
        
    # Save config and training results
    eval_results = {
        'eval_losses': eval_losses,
        'eval_accuracies': eval_accuracies,
        'eval_correct_pos': eval_correct_pos,
        'eval_correct_neg': eval_correct_neg
    }
    save_training_results(train_losses, test_losses, eval_results, config, result_folder_path + '/training_data.json')
    
    # Save model if loss has decreased
    if len(test_losses) > 1 and test_losses[-1] < min(test_losses[:-1]):
        print(f"Saving model at epoch {epoch}...")
        torch.save(model.state_dict(), result_folder_path + '/gat_encoder.pt')



Epoch 1/10:   0%|          | 0/260 [00:00<?, ?it/s]

___ Current Batch 1/260 _________________________
    Test Results:
        Correct positive: 975 (100.00%), Correct negative: 0 (0.00%)
        Total correct: 975 (50.00%)
        Test Loss: 1.0043, Test Accuracy: 0.5000
    Eval Results:
        Correct positive: 489 (100.00%), Correct negative: 0 (0.00%)
        Total correct: 489 (50.00%)
        Eval Loss: 1.0016, Eval Accuracy: 0.5000
___ Current Batch 130/260 _________________________
    Test Results:
        Correct positive: 546 (56.00%), Correct negative: 849 (87.08%)
        Total correct: 1395 (71.54%)
        Test Loss: 0.3428, Test Accuracy: 0.7154
    Eval Results:
        Correct positive: 185 (37.83%), Correct negative: 191 (39.06%)
        Total correct: 376 (38.45%)
        Eval Loss: 1.3183, Eval Accuracy: 0.3845
Saving model at epoch 1...


Epoch 2/10:   0%|          | 0/260 [00:00<?, ?it/s]

___ Current Batch 1/260 _________________________
    Test Results:
        Correct positive: 691 (70.87%), Correct negative: 834 (85.54%)
        Total correct: 1525 (78.21%)
        Test Loss: 0.2710, Test Accuracy: 0.7821
    Eval Results:
        Correct positive: 402 (82.21%), Correct negative: 81 (16.56%)
        Total correct: 483 (49.39%)
        Eval Loss: 1.0455, Eval Accuracy: 0.4939
___ Current Batch 130/260 _________________________
    Test Results:
        Correct positive: 669 (68.62%), Correct negative: 845 (86.67%)
        Total correct: 1514 (77.64%)
        Test Loss: 0.2327, Test Accuracy: 0.7764
    Eval Results:
        Correct positive: 278 (56.85%), Correct negative: 201 (41.10%)
        Total correct: 479 (48.98%)
        Eval Loss: 1.1367, Eval Accuracy: 0.4898
Saving model at epoch 2...


Epoch 3/10:   0%|          | 0/260 [00:00<?, ?it/s]

___ Current Batch 1/260 _________________________
    Test Results:
        Correct positive: 693 (71.08%), Correct negative: 856 (87.79%)
        Total correct: 1549 (79.44%)
        Test Loss: 0.2561, Test Accuracy: 0.7944
    Eval Results:
        Correct positive: 280 (57.26%), Correct negative: 228 (46.63%)
        Total correct: 508 (51.94%)
        Eval Loss: 1.0510, Eval Accuracy: 0.5194
___ Current Batch 130/260 _________________________
    Test Results:
        Correct positive: 797 (81.74%), Correct negative: 801 (82.15%)
        Total correct: 1598 (81.95%)
        Test Loss: 0.2606, Test Accuracy: 0.8195
    Eval Results:
        Correct positive: 464 (94.89%), Correct negative: 25 (5.11%)
        Total correct: 489 (50.00%)
        Eval Loss: 1.0662, Eval Accuracy: 0.5000


Epoch 4/10:   0%|          | 0/260 [00:00<?, ?it/s]

___ Current Batch 1/260 _________________________
    Test Results:
        Correct positive: 606 (62.15%), Correct negative: 886 (90.87%)
        Total correct: 1492 (76.51%)
        Test Loss: 0.2632, Test Accuracy: 0.7651
    Eval Results:
        Correct positive: 295 (60.33%), Correct negative: 175 (35.79%)
        Total correct: 470 (48.06%)
        Eval Loss: 1.0342, Eval Accuracy: 0.4806
___ Current Batch 130/260 _________________________
    Test Results:
        Correct positive: 680 (69.74%), Correct negative: 875 (89.74%)
        Total correct: 1555 (79.74%)
        Test Loss: 0.2588, Test Accuracy: 0.7974
    Eval Results:
        Correct positive: 349 (71.37%), Correct negative: 128 (26.18%)
        Total correct: 477 (48.77%)
        Eval Loss: 1.0347, Eval Accuracy: 0.4877


Epoch 5/10:   0%|          | 0/260 [00:00<?, ?it/s]

___ Current Batch 1/260 _________________________
    Test Results:
        Correct positive: 776 (79.59%), Correct negative: 860 (88.21%)
        Total correct: 1636 (83.90%)
        Test Loss: 0.1854, Test Accuracy: 0.8390
    Eval Results:
        Correct positive: 402 (82.21%), Correct negative: 70 (14.31%)
        Total correct: 472 (48.26%)
        Eval Loss: 1.0157, Eval Accuracy: 0.4826
___ Current Batch 130/260 _________________________
    Test Results:
        Correct positive: 794 (81.44%), Correct negative: 836 (85.74%)
        Total correct: 1630 (83.59%)
        Test Loss: 0.2219, Test Accuracy: 0.8359
    Eval Results:
        Correct positive: 378 (77.30%), Correct negative: 146 (29.86%)
        Total correct: 524 (53.58%)
        Eval Loss: 0.9575, Eval Accuracy: 0.5358


Epoch 6/10:   0%|          | 0/260 [00:00<?, ?it/s]

___ Current Batch 1/260 _________________________
    Test Results:
        Correct positive: 661 (67.79%), Correct negative: 884 (90.67%)
        Total correct: 1545 (79.23%)
        Test Loss: 0.2160, Test Accuracy: 0.7923
    Eval Results:
        Correct positive: 280 (57.26%), Correct negative: 100 (20.45%)
        Total correct: 380 (38.85%)
        Eval Loss: 1.3504, Eval Accuracy: 0.3885
___ Current Batch 130/260 _________________________
    Test Results:
        Correct positive: 712 (73.03%), Correct negative: 890 (91.28%)
        Total correct: 1602 (82.15%)
        Test Loss: 0.1985, Test Accuracy: 0.8215
    Eval Results:
        Correct positive: 385 (78.73%), Correct negative: 115 (23.52%)
        Total correct: 500 (51.12%)
        Eval Loss: 1.0461, Eval Accuracy: 0.5112


Epoch 7/10:   0%|          | 0/260 [00:00<?, ?it/s]

___ Current Batch 1/260 _________________________
    Test Results:
        Correct positive: 727 (74.56%), Correct negative: 903 (92.62%)
        Total correct: 1630 (83.59%)
        Test Loss: 0.1765, Test Accuracy: 0.8359
    Eval Results:
        Correct positive: 346 (70.76%), Correct negative: 151 (30.88%)
        Total correct: 497 (50.82%)
        Eval Loss: 0.9720, Eval Accuracy: 0.5082
___ Current Batch 130/260 _________________________
    Test Results:
        Correct positive: 757 (77.64%), Correct negative: 885 (90.77%)
        Total correct: 1642 (84.21%)
        Test Loss: 0.1947, Test Accuracy: 0.8421
    Eval Results:
        Correct positive: 272 (55.62%), Correct negative: 191 (39.06%)
        Total correct: 463 (47.34%)
        Eval Loss: 1.1555, Eval Accuracy: 0.4734


Epoch 8/10:   0%|          | 0/260 [00:00<?, ?it/s]

___ Current Batch 1/260 _________________________
    Test Results:
        Correct positive: 579 (59.38%), Correct negative: 911 (93.44%)
        Total correct: 1490 (76.41%)
        Test Loss: 0.1892, Test Accuracy: 0.7641
    Eval Results:
        Correct positive: 195 (39.88%), Correct negative: 293 (59.92%)
        Total correct: 488 (49.90%)
        Eval Loss: 1.2052, Eval Accuracy: 0.4990
___ Current Batch 130/260 _________________________
    Test Results:
        Correct positive: 822 (84.31%), Correct negative: 875 (89.74%)
        Total correct: 1697 (87.03%)
        Test Loss: 0.1636, Test Accuracy: 0.8703
    Eval Results:
        Correct positive: 360 (73.62%), Correct negative: 90 (18.40%)
        Total correct: 450 (46.01%)
        Eval Loss: 1.1673, Eval Accuracy: 0.4601
Saving model at epoch 8...


Epoch 9/10:   0%|          | 0/260 [00:00<?, ?it/s]

___ Current Batch 1/260 _________________________
    Test Results:
        Correct positive: 770 (78.97%), Correct negative: 895 (91.79%)
        Total correct: 1665 (85.38%)
        Test Loss: 0.1652, Test Accuracy: 0.8538
    Eval Results:
        Correct positive: 362 (74.03%), Correct negative: 74 (15.13%)
        Total correct: 436 (44.58%)
        Eval Loss: 1.2920, Eval Accuracy: 0.4458
___ Current Batch 130/260 _________________________
    Test Results:
        Correct positive: 769 (78.87%), Correct negative: 907 (93.03%)
        Total correct: 1676 (85.95%)
        Test Loss: 0.1522, Test Accuracy: 0.8595
    Eval Results:
        Correct positive: 355 (72.60%), Correct negative: 124 (25.36%)
        Total correct: 479 (48.98%)
        Eval Loss: 1.0464, Eval Accuracy: 0.4898
Saving model at epoch 9...


Epoch 10/10:   0%|          | 0/260 [00:00<?, ?it/s]

___ Current Batch 1/260 _________________________
    Test Results:
        Correct positive: 714 (73.23%), Correct negative: 912 (93.54%)
        Total correct: 1626 (83.38%)
        Test Loss: 0.1991, Test Accuracy: 0.8338
    Eval Results:
        Correct positive: 258 (52.76%), Correct negative: 216 (44.17%)
        Total correct: 474 (48.47%)
        Eval Loss: 1.0538, Eval Accuracy: 0.4847
___ Current Batch 130/260 _________________________
    Test Results:
        Correct positive: 803 (82.36%), Correct negative: 883 (90.56%)
        Total correct: 1686 (86.46%)
        Test Loss: 0.1830, Test Accuracy: 0.8646
    Eval Results:
        Correct positive: 359 (73.42%), Correct negative: 182 (37.22%)
        Total correct: 541 (55.32%)
        Eval Loss: 0.8530, Eval Accuracy: 0.5532
