In [1]:
from util_homogeneous import *
from util import *
from training_homogeneous import *
from gat_models import *

from collections import defaultdict
import random
import numpy as np
from tqdm.notebook import tqdm
from torch_geometric.loader import DataLoader
from torch.optim import Adam, SGD
from torch.nn.modules.loss import TripletMarginLoss

from src.shared.database_wrapper import DatabaseWrapper
from src.shared.graph_schema import *
from src.shared.graph_sampling import GraphSampling

# Seed Python's `random`, NumPy and PyTorch (CPU generator and all CUDA devices)
# with one shared value so repeated runs start from the same RNG state.
SEED = 40
for seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_rng(SEED)

  from tqdm.autonotebook import tqdm, trange


## Configurations

In [2]:
# Graph sampling configurations
# A homogeneous graph is sampled: one node type, one edge type.
database = 'small-graph'
node_spec = NodeType.PUBLICATION
edge_spec = EdgeType.SIM_ORG

# Node properties requested from the sampler.
node_properties = ['id', 'feature_vec']

# GraphSampling takes lists of node/edge types, hence the single-element lists.
gs = GraphSampling(
    node_spec=[node_spec],
    edge_spec=[edge_spec],
    node_properties=node_properties,
    database=database,
)

# Model configurations

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Encoder class instantiated later in the notebook.
model_class = HomoGATv2Encoder

# Hyper-parameters of the run. The whole dict is later written next to the
# results, so it doubles as the experiment record.
config = {
    # Free-text description of the experiment.
    'experiment': 'GATv2 encoder (with linear layer + dropout) trained on small graph (publication nodes with title and abstract, org edges) using Triplet Loss and full embeddings',
    'max_hops': 2,
    'model_node_feature': 'feature_vec',  # Node feature to use for GAT encoder
    # Encoder architecture.
    'hidden_channels': 64,
    'out_channels': 16,
    'num_heads': 8,
    'dropout_p': 0.4,
    # Triplet loss margin.
    'margin': 1.0,
    # Optimisation settings. NOTE(review): 'optimizer' is a label only; the
    # optimizer class itself is chosen where the optimizer is constructed.
    'optimizer': 'Adam',
    'learning_rate': 0.005,
    'weight_decay': 5e-4,
    'num_epochs': 10,
    'batch_size': 32,
}

# Triplet margin loss using the margin recorded in the config.
loss_fn = TripletMarginLoss(margin=config['margin'])

# TODO: Adjust result folder name!
result_folder_name = 'homogeneous (org) full_emb linear_layer dropout'
result_folder_path = f'./data/results/{result_folder_name}'

# Create the result folder together with any missing parent directories.
# `exist_ok=True` makes re-running the cell a no-op and avoids the race between
# an existence check and the creation; a plain `os.mkdir` would raise
# FileNotFoundError on a fresh checkout where './data/results' does not exist yet.
os.makedirs(result_folder_path, exist_ok=True)

Using default edge type: SimilarOrg for homogeneous graph sampling.


In [3]:
# Open the database connection for the configured graph database.
db = DatabaseWrapper(database=database)

# Build the triplet harvester. Judging by the logged output of this cell, the
# constructor tries to load previously saved triplets from the save file and
# generates (and saves) new ones when that fails.
data_harvester = TripletDataHarvester(
    db=db,
    gs=gs,
    edge_spec=[edge_spec],
    config=config,
    valid_triplets_save_file='valid_triplets_homogeneous_org_small_graph',
    transformer_model='sentence-transformers/all-MiniLM-L6-v2',
)


# Split the triplets into train / test / eval sets: 85% train, 10% test and
# whatever remains after integer truncation goes to eval.
num_triplets = len(data_harvester.triplets)
train_size = int(0.85 * num_triplets)
test_size = int(0.1 * num_triplets)
eval_size = num_triplets - (train_size + test_size)

# The eval set is taken from the unshuffled head of the list: triplets are
# ordered by author, so this is meant to give eval authors that never occur in
# training.
# NOTE(review): the cell log shows the harvester doubling the triplet count
# ("Too few triplets generated ... Generating more"); if the extra triplets are
# appended after the first pass, the head may not be author-disjoint - confirm.
eval_triplets = data_harvester.triplets[:eval_size]

# Everything else is shuffled (on a sliced copy) and cut into train and test.
train_test_triplets = data_harvester.triplets[eval_size:]
random.shuffle(train_test_triplets)
train_triplets, test_triplets = train_test_triplets[:train_size], train_test_triplets[train_size:]

# Record the effective split sizes alongside the other run settings.
config.update({
    'train_size': len(train_triplets),
    'test_size': len(test_triplets),
    'eval_size': len(eval_triplets),
})

print(f"Train size: {len(train_triplets)}, Test size: {len(test_triplets)}, Eval size: {len(eval_triplets)}")

# Wrap each split in its own dataset (the splits hold distinct triplets).
train_dataset, test_dataset, eval_dataset = (
    HomogeneousGraphTripletDataset(split_triplets, gs, config=config)
    for split_triplets in (train_triplets, test_triplets, eval_triplets)
)

# One loader per split; only the training loader reshuffles its samples.
# NOTE(review): `DataLoader` here is torch_geometric.loader.DataLoader, which
# (in the PyG releases checked) discards a user-supplied `collate_fn` and
# installs its own Collater - verify against the installed version. If
# `custom_triplet_collate` really has to run (e.g. to produce the None batches
# the training loop checks for), torch.utils.data.DataLoader would be needed.
train_dataloader, test_dataloader, eval_dataloader = (
    DataLoader(
        split_dataset,
        batch_size=config['batch_size'],
        shuffle=reshuffle,
        collate_fn=custom_triplet_collate,
    )
    for split_dataset, reshuffle in (
        (train_dataset, True),
        (test_dataset, False),
        (eval_dataset, False),
    )
)

# Create model
# Record which node/edge type the run used (their enum values) in the config.
metadata = (node_spec.value, edge_spec.value)
config['node_spec'], config['edge_spec'] = metadata

# Instantiate the encoder from the configured sizes and move it to the device.
model = model_class(
    config['hidden_channels'],
    config['out_channels'],
    num_heads=config['num_heads'],
    dropout_p=config['dropout_p'],
).to(device)

# Alternative optimizer (disabled): SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer = Adam(
    model.parameters(),
    lr=config['learning_rate'],
    weight_decay=config['weight_decay'],
)

2024-12-05 10:37:35,003 - DatabaseWrapper - INFO - Connecting to the database ...
2024-12-05 10:37:35,004 - DatabaseWrapper - INFO - Database ready.


Preparing triplets...
Loading triplets...
Could not load triplets from file. Generating triplets...
Checking data validity...
Out of 8865 checked papers, 2854 are valid and 6011 are invalid.
Generating hard triplets ...




Too few triplets generated: 1842. Generating more triplets...
Total triplets generated: 3684. Done.
Generated 3684 triplets.
Saving triplets...
Triplets saved.
Train size: 3131, Test size: 368, Eval size: 185


In [None]:
num_epochs = config['num_epochs']
results = defaultdict(list)

batches_per_epoch = len(train_dataloader)
results_file = result_folder_path + '/training_data.json'
checkpoint_file = result_folder_path + '/gat_encoder.pt'


def _evaluate_split(split_key, split_title, dataloader):
    """Evaluate the current model on one held-out split and refresh its plots.

    Calls `test_and_eval`, appends the seven values it returns to `results`
    under `<split_key>_total_loss`, `<split_key>_accuracies`,
    `<split_key>_accuracies_correct_pos`, `<split_key>_accuracies_correct_neg`,
    `<split_key>_precision`, `<split_key>_recall` and `<split_key>_F1`, then
    redraws `<split_key>_loss.png` and `<split_key>_accuracy.png` in the result
    folder.

    Args:
        split_key: Prefix for result keys and plot file names ('test' or 'eval').
        split_title: Split name used in the plot titles ('Test' / 'Evaluation').
        dataloader: Loader handed to `test_and_eval`.
    """
    print(f"    {split_key.capitalize()} Results:")
    loss, accuracy, correct_pos, correct_neg, precision, recall, f1 = test_and_eval(
        model=model,
        loss_fn=loss_fn,
        dataloader=dataloader,
        margin=config['margin']
    )
    metric_values = {
        'total_loss': loss,
        'accuracies': accuracy,
        'accuracies_correct_pos': correct_pos,
        'accuracies_correct_neg': correct_neg,
        'precision': precision,
        'recall': recall,
        'F1': f1,
    }
    for metric_name, metric_value in metric_values.items():
        results[f'{split_key}_{metric_name}'].append(metric_value)

    # epoch_len=2: the schedule below evaluates twice per epoch.
    plot_losses(
        losses=[results[f'{split_key}_total_loss']],
        epoch_len=2,
        plot_title=f'{split_title} Loss',
        plot_file=result_folder_path + f'/{split_key}_loss.png',
        line_labels=["Triplet Loss"]
    )
    plot_losses(
        [
            results[f'{split_key}_accuracies'],
            results[f'{split_key}_accuracies_correct_pos'],
            results[f'{split_key}_accuracies_correct_neg'],
        ],
        epoch_len=2,
        plot_title=f'{split_title} Accuracy',
        x_label=f'{split_key.capitalize()} Iterations',
        y_label='Accuracy',
        line_labels=['Total Accuracy', 'Correct Pos', 'Correct Neg'],
        plot_file=result_folder_path + f'/{split_key}_accuracy.png'
    )


def _save_results():
    """Attach the collected results to the config and write both to the results JSON."""
    config['results'] = results
    save_dict_to_json(config, results_file)


for epoch in range(1, num_epochs + 1):
    print(f"=== Epoch {epoch}/{num_epochs} ======================")
    # Not read anywhere in this cell; kept in case later cells rely on it.
    epoch_marker_pos = list(range(0, batches_per_epoch * epoch, batches_per_epoch))
    # Counts batches that are actually processed; skipped (None) batches do not advance it.
    current_batch = 1
    for batch_anchor, batch_pos, batch_neg in tqdm(train_dataloader, desc=f"Epoch {epoch}/{num_epochs}"):
        if batch_anchor is None or batch_pos is None or batch_neg is None:
            continue

        # Evaluate twice per epoch: before the first processed batch and once the
        # counter reaches half the loader length.
        if current_batch == 1 or current_batch == batches_per_epoch // 2:
            print(f"___ Current Batch {current_batch}/{batches_per_epoch} _________________________")
            # Model testing
            _evaluate_split('test', 'Test', test_dataloader)
            # Model evaluation
            _evaluate_split('eval', 'Evaluation', eval_dataloader)

            # Save config and training results
            _save_results()

            # Checkpoint as soon as a new best test loss is measured, so the stored
            # weights are exactly the ones that produced that loss. (Deciding at the
            # end of the epoch would store weights that have already been trained
            # further since the measurement.)
            test_losses = results['test_total_loss']
            if len(test_losses) > 1 and test_losses[-1] < min(test_losses[:-1]):
                print(f"Saving model at epoch {epoch}...")
                torch.save(model.state_dict(), checkpoint_file)

        loss = train(
            model=model,
            loss_fn=loss_fn,
            batch_anchor=batch_anchor,
            batch_pos=batch_pos,
            batch_neg=batch_neg,
            optimizer=optimizer
        )
        results['train_loss'].append(loss)

        # NOTE(review): the training-loss plot is redrawn after every batch; if this
        # turns out to be slow, redrawing once per evaluation would be enough.
        plot_loss(results['train_loss'], epoch_len=batches_per_epoch, plot_title='Training Loss', plot_avg=True, plot_file=result_folder_path + '/train_loss.png')
        current_batch += 1

    # Persist once more at the end of the epoch so the training losses recorded after
    # the last mid-epoch evaluation also reach the results file.
    _save_results()



Epoch 1/10:   0%|          | 0/98 [00:00<?, ?it/s]

___ Current Batch 1/98 _________________________
    Test Results:
        Correct positive: 368 (100.00%), Correct negative: 0 (0.00%)
        Total correct: 368 (50.00%)
        Test/Eval Loss: 0.8982, Test/Eval Accuracy: 0.5000
        Precision: 0.5000, Recall: 1.0000, F1: 0.6667
    Eval Results:
        Correct positive: 185 (100.00%), Correct negative: 0 (0.00%)
        Total correct: 185 (50.00%)
        Test/Eval Loss: 0.8931, Test/Eval Accuracy: 0.5000
        Precision: 0.5000, Recall: 1.0000, F1: 0.6667
___ Current Batch 49/98 _________________________
    Test Results:
        Correct positive: 323 (87.77%), Correct negative: 308 (83.70%)
        Total correct: 631 (85.73%)
        Test/Eval Loss: 0.2615, Test/Eval Accuracy: 0.8573
        Precision: 0.8433, Recall: 0.8777, F1: 0.8602
    Eval Results:
        Correct positive: 136 (73.51%), Correct negative: 166 (89.73%)
        Total correct: 302 (81.62%)
        Test/Eval Loss: 0.3672, Test/Eval Accuracy: 0.8162
       

Epoch 2/10:   0%|          | 0/98 [00:00<?, ?it/s]

___ Current Batch 1/98 _________________________
    Test Results:
        Correct positive: 316 (85.87%), Correct negative: 323 (87.77%)
        Total correct: 639 (86.82%)
        Test/Eval Loss: 0.2538, Test/Eval Accuracy: 0.8682
        Precision: 0.8753, Recall: 0.8587, F1: 0.8669
    Eval Results:
        Correct positive: 131 (70.81%), Correct negative: 141 (76.22%)
        Total correct: 272 (73.51%)
        Test/Eval Loss: 0.4207, Test/Eval Accuracy: 0.7351
        Precision: 0.7486, Recall: 0.7081, F1: 0.7278
___ Current Batch 49/98 _________________________
    Test Results:
        Correct positive: 331 (89.95%), Correct negative: 302 (82.07%)
        Total correct: 633 (86.01%)
        Test/Eval Loss: 0.2489, Test/Eval Accuracy: 0.8601
        Precision: 0.8338, Recall: 0.8995, F1: 0.8654
    Eval Results:
        Correct positive: 149 (80.54%), Correct negative: 135 (72.97%)
        Total correct: 284 (76.76%)
        Test/Eval Loss: 0.3929, Test/Eval Accuracy: 0.7676
   

Epoch 3/10:   0%|          | 0/98 [00:00<?, ?it/s]

___ Current Batch 1/98 _________________________
    Test Results:
        Correct positive: 322 (87.50%), Correct negative: 313 (85.05%)
        Total correct: 635 (86.28%)
        Test/Eval Loss: 0.2415, Test/Eval Accuracy: 0.8628
        Precision: 0.8541, Recall: 0.8750, F1: 0.8644
    Eval Results:
        Correct positive: 135 (72.97%), Correct negative: 161 (87.03%)
        Total correct: 296 (80.00%)
        Test/Eval Loss: 0.3793, Test/Eval Accuracy: 0.8000
        Precision: 0.8491, Recall: 0.7297, F1: 0.7849
___ Current Batch 49/98 _________________________
    Test Results:
        Correct positive: 319 (86.68%), Correct negative: 318 (86.41%)
        Total correct: 637 (86.55%)
        Test/Eval Loss: 0.2447, Test/Eval Accuracy: 0.8655
        Precision: 0.8645, Recall: 0.8668, F1: 0.8657
    Eval Results:
        Correct positive: 146 (78.92%), Correct negative: 166 (89.73%)
        Total correct: 312 (84.32%)
        Test/Eval Loss: 0.2945, Test/Eval Accuracy: 0.8432
   