In [16]:
from util import *
from training_heterogeneous import *
from gat_models import *

import os
import random
import numpy as np
import torch
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader as TorchDataLoader
from torch_geometric.loader import DataLoader
from torch.optim import Adam
from torch.nn.modules.loss import TripletMarginLoss

from src.shared.database_wrapper import DatabaseWrapper
from src.shared.graph_schema import *
from src.shared.graph_sampling import GraphSampling

# Seed every RNG the notebook relies on (Python `random`, NumPy, PyTorch CPU and all
# CUDA devices) from one constant, so the train/test shuffle and torch-side randomness
# (e.g. weight initialisation, DataLoader shuffling) are reproducible across runs.
SEED = 40
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

### Configurations

In [17]:
# Graph sampling configuration.
# Only publication nodes are sampled. This experiment uses co-author similarity edges
# only (SIM_VENUE is deliberately disabled below).
# NOTE: this cell previously also defined a heterogeneous-edge variant (SIM_VENUE +
# SIM_ABSTRACT + SIM_AUTHOR edges, HeteroGATEncoderLinear, 10 epochs, result folder
# 'hetero_edges compressed_emb linear_layer'). It was overwritten before use, so only
# the active configuration is kept. To rerun that variant, change edge_spec,
# model_class, config and result_folder_name here.
node_spec = [
    NodeType.PUBLICATION,
]

edge_spec = [
    # EdgeType.SIM_VENUE,
    EdgeType.SIM_AUTHOR,
]

node_properties = [
    'id',
    'title',
    'abstract',
    'title_emb',
    'abstract_emb',
    'feature_vec',
]

database = 'homogeneous-graph-compressed-emb'
gs = GraphSampling(
    node_spec=node_spec,
    edge_spec=edge_spec,
    node_properties=node_properties,
    database=database
)

# Model / training configuration. This dict is later extended with split sizes and
# metadata and written to disk by save_training_results() in the training loop.
config = {
    'experiment': 'GATv2 encoder (with linear layer + dropout) trained on homogeneous graph (publication nodes with title and abstract, co-author edges) using Triplet Loss and dimension reduced embeddings',
    'max_hops': 3,
    'model_node_feature': 'feature_vec',  # Node feature to use for GAT encoder
    'hidden_channels': 64,
    'out_channels': 16,
    'num_heads': 8,
    'margin': 1.0,
    'optimizer': 'Adam',
    'learning_rate': 0.005,
    'weight_decay': 5e-4,
    'num_epochs': 20,
    'batch_size': 32,
}

model_class = HeteroGATEncoderLinearDropout
loss_fn = TripletMarginLoss(margin=config['margin'])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# TODO: Adjust result folder name!
result_folder_name = 'homo_edges compressed_emb linear_layer dropout'
result_folder_path = f'./data/results/{result_folder_name}'
# makedirs also creates missing parents (./data/results) and is a no-op if the
# folder already exists, unlike the exists()+mkdir() check-then-create pattern.
os.makedirs(result_folder_path, exist_ok=True)

### Data Split and Model Setup

In [14]:
db = DatabaseWrapper(database=database)
data_harvester = TripletDataHarvester(db=db, gs=gs, edge_spec=edge_spec, config=config)
all_triplets = data_harvester.triplets

# Split the triplets into train (85%), test (10%) and eval (the remaining ~5%) sets.
train_size = int(0.85 * len(all_triplets))
test_size = int(0.1 * len(all_triplets))
eval_size = len(all_triplets) - train_size - test_size

# Harvest the evaluation triplets first. The triplets are ordered by author, so a
# contiguous prefix keeps the eval authors (almost entirely) out of train/test.
# NOTE(review): the single author at the boundary index may still straddle eval and
# train/test; move the cut to an author boundary if strict separation is required.
eval_triplets = all_triplets[:eval_size]

# Shuffle only the train/test part so both sets share the same author distribution.
# Assumes `triplets` is a list, so the slice is a copy and the harvester is not reordered.
train_test_triplets = all_triplets[eval_size:]
random.shuffle(train_test_triplets)

train_triplets = train_test_triplets[:train_size]
test_triplets = train_test_triplets[train_size:]

# Sanity checks: the three splits partition the data and none of them is empty.
assert len(train_triplets) + len(test_triplets) + len(eval_triplets) == len(all_triplets)
assert train_triplets and test_triplets and eval_triplets, "every split must be non-empty"

config['train_size'] = len(train_triplets)
config['test_size'] = len(test_triplets)
config['eval_size'] = len(eval_triplets)

print(f"Train size: {len(train_triplets)}, Test size: {len(test_triplets)}, Eval size: {len(eval_triplets)}")

# Create one dataset per (disjoint) split.
train_dataset = GraphTripletDataset(train_triplets, gs, config=config)
test_dataset = GraphTripletDataset(test_triplets, gs, config=config)
eval_dataset = GraphTripletDataset(eval_triplets, gs, config=config)

# Use the plain PyTorch DataLoader. torch_geometric.loader.DataLoader discards any
# user-supplied `collate_fn` and always uses its own Collater, so
# custom_triplet_collate would never be called through it.
train_dataloader = TorchDataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, collate_fn=custom_triplet_collate)
test_dataloader = TorchDataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, collate_fn=custom_triplet_collate)
eval_dataloader = TorchDataLoader(eval_dataset, batch_size=config['batch_size'], shuffle=False, collate_fn=custom_triplet_collate)

# Create the model from a (node type names, edge type keys) metadata tuple and record
# both lists in the config so they are saved with the training results.
metadata = (
    [n.value for n in node_spec],
    [edge_pyg_key_vals[r] for r in edge_spec]
)
config['node_spec'] = metadata[0]
config['edge_spec'] = metadata[1]
model = model_class(metadata, config['hidden_channels'], config['out_channels'], num_heads=config['num_heads']).to(device)
optimizer = Adam(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])

2024-11-04 18:17:34,251 - DatabaseWrapper - INFO - Connecting to the database ...
2024-11-04 18:17:34,252 - DatabaseWrapper - INFO - Database ready.


Preparing triplets...
Loading triplets...
Could not load triplets from file. Generating triplets...
Checking data validity...
Out of 20034 checked papers, 12937 are valid and 7097 are invalid.
Preparing pairs...
Total triplets: 11755. Done.
Generated 11755 triplets.
Saving triplets...
Triplets saved.
Train size: 9991, Test size: 1175, Eval size: 589


### Training Loop

In [15]:
num_epochs = config['num_epochs']
model_file_path = result_folder_path + '/gat_encoder.pt'

# Per-batch training losses.
train_losses = []

# Test-set metrics, appended at every checkpoint batch (see `checkpoint_batches`).
test_losses = []
test_accuracies = []
test_correct_pos = []
test_correct_neg = []

# Eval-set metrics (split with held-out authors), appended at the same checkpoint batches.
eval_losses = []
eval_accuracies = []
eval_correct_pos = []
eval_correct_neg = []

# Lowest test loss seen so far; the model is saved each time a new minimum is reached.
best_test_loss = float('inf')

for epoch in range(1, num_epochs + 1):
    print(f"=== Epoch {epoch}/{num_epochs} ======================")
    num_batches = len(train_dataloader)
    # Test/eval on the first batch and at mid-epoch. The test/eval plots use
    # epoch_len=2, i.e. they assume exactly these two checkpoints per epoch.
    checkpoint_batches = {1, num_batches // 2}

    # enumerate() counts every batch, including skipped ones, so the checkpoint
    # schedule cannot drift when the collate function yields an empty (None) batch.
    progress = tqdm(train_dataloader, desc=f"Epoch {epoch}/{num_epochs}")
    for current_batch, (batch_anchor, batch_pos, batch_neg) in enumerate(progress, start=1):
        if current_batch in checkpoint_batches:
            print(f"___ Current Batch {current_batch}/{num_batches} _________________________")
            # Model testing
            print("    Test Results:")
            test_loss, test_num_correct, test_correct_pos_val, test_correct_neg_val = test(
                model=model,
                loss_fn=loss_fn,
                dataloader=test_dataloader,
                margin=config['margin']
            )
            test_losses.append(test_loss)
            test_accuracies.append(test_num_correct)
            test_correct_pos.append(test_correct_pos_val)
            test_correct_neg.append(test_correct_neg_val)

            plot_loss(test_losses, epoch_len=2, plot_title='Test Loss', plot_avg=False, plot_file=result_folder_path + '/test_loss.png')
            plot_loss(test_accuracies, epoch_len=2, plot_title='Test Accuracy', plot_avg=False, plot_file=result_folder_path + '/test_accuracy.png')

            # Checkpoint exactly the weights that produced this test loss. Saving at the
            # end of the epoch would store weights trained ~half an epoch further.
            if test_loss < best_test_loss:
                best_test_loss = test_loss
                print(f"Saving model at epoch {epoch}, batch {current_batch}...")
                torch.save(model.state_dict(), model_file_path)

            # Model evaluation
            print("    Eval Results:")
            eval_loss, eval_num_correct, eval_correct_pos_val, eval_correct_neg_val = evaluate(
                model=model,
                loss_fn=loss_fn,
                dataloader=eval_dataloader,
                margin=config['margin']
            )
            eval_losses.append(eval_loss)
            eval_accuracies.append(eval_num_correct)
            eval_correct_pos.append(eval_correct_pos_val)
            eval_correct_neg.append(eval_correct_neg_val)

            plot_loss(eval_losses, epoch_len=2, plot_title='Evaluation Loss', plot_avg=False, plot_file=result_folder_path + '/eval_loss.png')
            plot_loss(eval_accuracies, epoch_len=2, plot_title='Evaluation Accuracy', plot_avg=False, plot_file=result_folder_path + '/eval_accuracy.png')

            # NOTE(review): test()/evaluate() presumably switch the model to eval mode.
            # Restore training mode so dropout is active for the following train steps
            # (a no-op if train() already does this).
            model.train()

        # Skip batches for which the collate function produced no data.
        if batch_anchor is None or batch_pos is None or batch_neg is None:
            continue

        loss = train(
            model=model,
            loss_fn=loss_fn,
            batch_anchor=batch_anchor,
            batch_pos=batch_pos,
            batch_neg=batch_neg,
            optimizer=optimizer
        )
        train_losses.append(loss)

        # NOTE(review): re-renders the plot after every batch; throttle if this is slow.
        plot_loss(train_losses, epoch_len=num_batches, plot_title='Training Loss', plot_avg=True, plot_file=result_folder_path + '/train_loss.png')

    # Save config and training results after every epoch.
    eval_results = {
        'eval_losses': eval_losses,
        'eval_accuracies': eval_accuracies,
        'eval_correct_pos': eval_correct_pos,
        'eval_correct_neg': eval_correct_neg
    }
    save_training_results(train_losses, test_losses, eval_results, config, result_folder_path + '/training_data.json')



Epoch 1/20:   0%|          | 0/313 [00:00<?, ?it/s]

___ Batch 1/313 _________________________
Test Loss: 1.0114
Correct positive: 589 (100.00%), Correct negative: 0 (0.00%)
Total correct: 589 (50.00%)
Eval Loss: 1.0111, Eval Accuracy: 0.5000
___ Batch 11/313 _________________________
Test Loss: 0.8908
Correct positive: 311 (52.80%), Correct negative: 263 (44.65%)
Total correct: 574 (48.73%)
Eval Loss: 1.0244, Eval Accuracy: 0.4873
___ Batch 21/313 _________________________
Test Loss: 0.8515
Correct positive: 350 (59.42%), Correct negative: 285 (48.39%)
Total correct: 635 (53.90%)
Eval Loss: 0.8974, Eval Accuracy: 0.5390
___ Batch 31/313 _________________________
Test Loss: 0.7738
Correct positive: 346 (58.74%), Correct negative: 273 (46.35%)
Total correct: 619 (52.55%)
Eval Loss: 0.9289, Eval Accuracy: 0.5255
___ Batch 41/313 _________________________
Test Loss: 0.7711
Correct positive: 191 (32.43%), Correct negative: 338 (57.39%)
Total correct: 529 (44.91%)
Eval Loss: 1.1852, Eval Accuracy: 0.4491
___ Batch 51/313 _____________________

Epoch 2/20:   0%|          | 0/313 [00:00<?, ?it/s]

___ Batch 8/313 _________________________
Test Loss: 0.4258
Correct positive: 63 (10.70%), Correct negative: 485 (82.34%)
Total correct: 548 (46.52%)
Eval Loss: 1.1769, Eval Accuracy: 0.4652
___ Batch 18/313 _________________________
Test Loss: 0.4137
Correct positive: 154 (26.15%), Correct negative: 447 (75.89%)
Total correct: 601 (51.02%)
Eval Loss: 1.0103, Eval Accuracy: 0.5102
___ Batch 28/313 _________________________
Test Loss: 0.4248
Correct positive: 87 (14.77%), Correct negative: 522 (88.62%)
Total correct: 609 (51.70%)
Eval Loss: 1.0619, Eval Accuracy: 0.5170
___ Batch 38/313 _________________________
Test Loss: 0.4078
Correct positive: 240 (40.75%), Correct negative: 439 (74.53%)
Total correct: 679 (57.64%)
Eval Loss: 0.9016, Eval Accuracy: 0.5764
___ Batch 48/313 _________________________
Test Loss: 0.4057
Correct positive: 199 (33.79%), Correct negative: 431 (73.17%)
Total correct: 630 (53.48%)
Eval Loss: 0.9022, Eval Accuracy: 0.5348
___ Batch 58/313 _____________________

Epoch 3/20:   0%|          | 0/313 [00:00<?, ?it/s]

___ Batch 5/313 _________________________
Test Loss: 0.3018
Correct positive: 172 (29.20%), Correct negative: 438 (74.36%)
Total correct: 610 (51.78%)
Eval Loss: 1.0593, Eval Accuracy: 0.5178
___ Batch 15/313 _________________________
Test Loss: 0.3225
Correct positive: 69 (11.71%), Correct negative: 535 (90.83%)
Total correct: 604 (51.27%)
Eval Loss: 0.8903, Eval Accuracy: 0.5127
___ Batch 25/313 _________________________
Test Loss: 0.3390
Correct positive: 68 (11.54%), Correct negative: 545 (92.53%)
Total correct: 613 (52.04%)
Eval Loss: 0.8572, Eval Accuracy: 0.5204
___ Batch 35/313 _________________________
Test Loss: 0.3515
Correct positive: 210 (35.65%), Correct negative: 505 (85.74%)
Total correct: 715 (60.70%)
Eval Loss: 0.8292, Eval Accuracy: 0.6070
___ Batch 45/313 _________________________
Test Loss: 0.3390
Correct positive: 154 (26.15%), Correct negative: 520 (88.29%)
Total correct: 674 (57.22%)
Eval Loss: 0.8119, Eval Accuracy: 0.5722
___ Batch 55/313 _____________________

Epoch 4/20:   0%|          | 0/313 [00:00<?, ?it/s]

___ Batch 2/313 _________________________
Test Loss: 0.3391
Correct positive: 166 (28.18%), Correct negative: 523 (88.79%)
Total correct: 689 (58.49%)
Eval Loss: 0.7968, Eval Accuracy: 0.5849
___ Batch 12/313 _________________________
Test Loss: 0.2721
Correct positive: 217 (36.84%), Correct negative: 487 (82.68%)
Total correct: 704 (59.76%)
Eval Loss: 0.8195, Eval Accuracy: 0.5976
___ Batch 22/313 _________________________
Test Loss: 0.2603
Correct positive: 199 (33.79%), Correct negative: 456 (77.42%)
Total correct: 655 (55.60%)
Eval Loss: 0.9251, Eval Accuracy: 0.5560
___ Batch 32/313 _________________________
Test Loss: 0.2701
Correct positive: 178 (30.22%), Correct negative: 448 (76.06%)
Total correct: 626 (53.14%)
Eval Loss: 0.9090, Eval Accuracy: 0.5314
___ Batch 42/313 _________________________
Test Loss: 0.2613
Correct positive: 186 (31.58%), Correct negative: 440 (74.70%)
Total correct: 626 (53.14%)
Eval Loss: 0.8080, Eval Accuracy: 0.5314
___ Batch 52/313 ___________________

Epoch 5/20:   0%|          | 0/313 [00:00<?, ?it/s]

___ Batch 9/313 _________________________
Test Loss: 0.2553
Correct positive: 112 (19.02%), Correct negative: 517 (87.78%)
Total correct: 629 (53.40%)
Eval Loss: 0.9185, Eval Accuracy: 0.5340
___ Batch 19/313 _________________________
Test Loss: 0.2478
Correct positive: 72 (12.22%), Correct negative: 524 (88.96%)
Total correct: 596 (50.59%)
Eval Loss: 0.9292, Eval Accuracy: 0.5059
___ Batch 29/313 _________________________
Test Loss: 0.2354
Correct positive: 103 (17.49%), Correct negative: 526 (89.30%)
Total correct: 629 (53.40%)
Eval Loss: 0.8320, Eval Accuracy: 0.5340
___ Batch 39/313 _________________________
Test Loss: 0.2698
Correct positive: 60 (10.19%), Correct negative: 525 (89.13%)
Total correct: 585 (49.66%)
Eval Loss: 0.9532, Eval Accuracy: 0.4966
___ Batch 49/313 _________________________
Test Loss: 0.2400
Correct positive: 55 (9.34%), Correct negative: 501 (85.06%)
Total correct: 556 (47.20%)
Eval Loss: 1.2363, Eval Accuracy: 0.4720
___ Batch 59/313 _______________________

Epoch 6/20:   0%|          | 0/313 [00:00<?, ?it/s]

___ Batch 6/313 _________________________
Test Loss: 0.2443
Correct positive: 111 (18.85%), Correct negative: 524 (88.96%)
Total correct: 635 (53.90%)
Eval Loss: 0.8782, Eval Accuracy: 0.5390
___ Batch 16/313 _________________________
Test Loss: 0.2467
Correct positive: 135 (22.92%), Correct negative: 536 (91.00%)
Total correct: 671 (56.96%)
Eval Loss: 0.7620, Eval Accuracy: 0.5696
___ Batch 26/313 _________________________
Test Loss: 0.2386
Correct positive: 143 (24.28%), Correct negative: 516 (87.61%)
Total correct: 659 (55.94%)
Eval Loss: 0.9698, Eval Accuracy: 0.5594
___ Batch 36/313 _________________________
Test Loss: 0.2300
Correct positive: 207 (35.14%), Correct negative: 482 (81.83%)
Total correct: 689 (58.49%)
Eval Loss: 0.8735, Eval Accuracy: 0.5849
___ Batch 46/313 _________________________
Test Loss: 0.2243
Correct positive: 154 (26.15%), Correct negative: 493 (83.70%)
Total correct: 647 (54.92%)
Eval Loss: 0.8428, Eval Accuracy: 0.5492
___ Batch 56/313 ___________________

Epoch 7/20:   0%|          | 0/313 [00:00<?, ?it/s]

___ Batch 3/313 _________________________
Test Loss: 0.2201
Correct positive: 114 (19.35%), Correct negative: 477 (80.98%)
Total correct: 591 (50.17%)
Eval Loss: 0.9898, Eval Accuracy: 0.5017
___ Batch 13/313 _________________________
Test Loss: 0.2576
Correct positive: 72 (12.22%), Correct negative: 527 (89.47%)
Total correct: 599 (50.85%)
Eval Loss: 1.4135, Eval Accuracy: 0.5085
___ Batch 23/313 _________________________
Test Loss: 0.2199
Correct positive: 110 (18.68%), Correct negative: 524 (88.96%)
Total correct: 634 (53.82%)
Eval Loss: 0.9170, Eval Accuracy: 0.5382
___ Batch 33/313 _________________________
Test Loss: 0.2180
Correct positive: 92 (15.62%), Correct negative: 512 (86.93%)
Total correct: 604 (51.27%)
Eval Loss: 1.0765, Eval Accuracy: 0.5127
___ Batch 43/313 _________________________
Test Loss: 0.2234
Correct positive: 109 (18.51%), Correct negative: 529 (89.81%)
Total correct: 638 (54.16%)
Eval Loss: 0.8159, Eval Accuracy: 0.5416
___ Batch 53/313 _____________________

Epoch 8/20:   0%|          | 0/313 [00:00<?, ?it/s]

___ Batch 10/313 _________________________
Test Loss: 0.2142
Correct positive: 163 (27.67%), Correct negative: 483 (82.00%)
Total correct: 646 (54.84%)
Eval Loss: 0.9249, Eval Accuracy: 0.5484
___ Batch 20/313 _________________________
Test Loss: 0.2041
Correct positive: 107 (18.17%), Correct negative: 512 (86.93%)
Total correct: 619 (52.55%)
Eval Loss: 0.9365, Eval Accuracy: 0.5255
___ Batch 30/313 _________________________
Test Loss: 0.1943
Correct positive: 110 (18.68%), Correct negative: 494 (83.87%)
Total correct: 604 (51.27%)
Eval Loss: 0.9279, Eval Accuracy: 0.5127
___ Batch 40/313 _________________________
Test Loss: 0.1913
Correct positive: 119 (20.20%), Correct negative: 513 (87.10%)
Total correct: 632 (53.65%)
Eval Loss: 0.9512, Eval Accuracy: 0.5365
___ Batch 50/313 _________________________
Test Loss: 0.1933
Correct positive: 103 (17.49%), Correct negative: 493 (83.70%)
Total correct: 596 (50.59%)
Eval Loss: 1.0239, Eval Accuracy: 0.5059
___ Batch 60/313 __________________

Exception ignored in: <function Workspace.__del__ at 0x7cf8b72c8a60>
Traceback (most recent call last):
  File "/home/vincie/.anaconda3/envs/master/lib/python3.9/site-packages/neo4j/_sync/work/workspace.py", line 62, in __del__
    def __del__(self):
KeyboardInterrupt: 


KeyboardInterrupt: 