In [1]:
from notebooks.util import GraphTripletDataset
from util import *
from gat_models import *

import random
import numpy as np
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm
from torch_geometric.loader import DataLoader
from torch.optim import Adam
from torch.nn.modules.loss import TripletMarginLoss

from src.shared.database_wrapper import DatabaseWrapper
from src.shared.graph_schema import *
from src.shared.graph_sampling import GraphSampling

random.seed(40)
np.random.seed(40)
torch.manual_seed(40)
torch.cuda.manual_seed_all(40)

  from tqdm.autonotebook import tqdm, trange


### Configurations

In [2]:
# Graph sampling configurations
node_spec = [
    NodeType.PUBLICATION,
]

edge_spec = [
    #EdgeType.SIM_VENUE,
    EdgeType.SIM_AUTHOR,
]

node_properties = [
    'id',
    'title',
    'abstract',
    'title_emb',
    'abstract_emb',
    'feature_vec',
]

database = 'homogeneous-graph-compressed-emb'
gs = GraphSampling(
    node_spec=node_spec,
    edge_spec=edge_spec,
    node_properties=node_properties,
    database=database
)

# Model configurations

config = {
    'experiment': 'GATv2 encoder (with linear layer) trained on homogeneous graph (publication nodes with title and abstract, co-author edges) using Triplet Loss and dimension reduced embeddings',
    'max_hops': 3,
    'model_node_feature': 'feature_vec',  # Node feature to use for GAT encoder
    'hidden_channels': 64,
    'out_channels': 16,
    'num_heads': 8,
    'margin': 1.0,
    'optimizer': 'Adam',
    'learning_rate': 0.005,
    'weight_decay': 5e-4,
    'num_epochs': 20,
    'batch_size': 32,
}

model_class = HeteroGATEncoderLinear
loss_fn = TripletMarginLoss(margin=config['margin'])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

save_file_postfix = "triplets_homo_edges_compressed_emb_linear_layer"

### Training functions

In [3]:
def train(model, batch_anchor, batch_pos, batch_neg, optimizer):
    model.train()
    
    optimizer.zero_grad()
    
    batch_anchor = batch_anchor.to(device)
    batch_pos = batch_pos.to(device)
    batch_neg = batch_neg.to(device)

    emb_a = model(batch_anchor)
    emb_p = model(batch_pos)
    emb_n = model(batch_neg)
    
    emb_a_central = emb_a[NodeType.PUBLICATION.value][batch_anchor.central_node_id]
    emb_p_central = emb_p[NodeType.PUBLICATION.value][batch_pos.central_node_id]
    emb_n_central = emb_n[NodeType.PUBLICATION.value][batch_neg.central_node_id]
    
    loss = loss_fn(emb_a_central, emb_p_central, emb_n_central)
    
    loss.backward()
    optimizer.step()
        
    batch_loss = loss.item()
    #print(f"Batch loss: {batch_loss:.4f}")
    return batch_loss

def test(model, dataloader):
    model.eval()
    total_loss = 0
    
    with torch.no_grad():
        for batch_anchor, batch_pos, batch_neg in dataloader:
            batch_anchor = batch_anchor.to(device)
            batch_pos = batch_pos.to(device)
            batch_neg = batch_neg.to(device)
    
            emb_a = model(batch_anchor)
            emb_p = model(batch_pos)
            emb_n = model(batch_neg)
            
            emb_a_central = emb_a[NodeType.PUBLICATION.value][batch_anchor.central_node_id]
            emb_p_central = emb_p[NodeType.PUBLICATION.value][batch_pos.central_node_id]
            emb_n_central = emb_n[NodeType.PUBLICATION.value][batch_neg.central_node_id]
            
            loss = loss_fn(emb_a_central, emb_p_central, emb_n_central)
            total_loss += loss.item()
    
    # Compute average loss    
    avg_loss = total_loss / len(dataloader)
    print(f"Test Loss: {avg_loss:.4f}")
    return avg_loss


def evaluate(model, dataloader):
    model.eval()
    total_loss = 0
    total_num_correct = 0
    total_pos_correct = 0
    total_neg_correct = 0
    total_num_samples = 0
    with torch.no_grad():
        for batch_anchor, batch_pos, batch_neg in dataloader:
            batch_anchor = batch_anchor.to(device)
            batch_pos = batch_pos.to(device)
            batch_neg = batch_neg.to(device)

            emb_a = model(batch_anchor)
            emb_p = model(batch_pos)
            emb_n = model(batch_neg)

            emb_a_central = emb_a[NodeType.PUBLICATION.value][batch_anchor.central_node_id]
            emb_p_central = emb_p[NodeType.PUBLICATION.value][batch_pos.central_node_id]
            emb_n_central = emb_n[NodeType.PUBLICATION.value][batch_neg.central_node_id]

            # Compute loss
            loss = loss_fn(emb_a_central, emb_p_central, emb_n_central)
            total_loss += loss.item()

            # Compute distances
            d_ap = F.pairwise_distance(emb_a_central, emb_p_central)
            d_an = F.pairwise_distance(emb_a_central, emb_n_central)

            # Determine correct predictions based on margin
            correct_pos = (d_ap < config['margin']).cpu()
            correct_neg = (d_an > config['margin']).cpu()

            # Sum up correct predictions
            num_correct_pos = correct_pos.sum().item()
            num_correct_neg = correct_neg.sum().item()
            num_correct = num_correct_pos + num_correct_neg

            total_num_correct += num_correct
            total_pos_correct += num_correct_pos
            total_neg_correct += num_correct_neg
            total_num_samples += len(batch_anchor)

    # Compute averages
    avg_loss = total_loss / len(dataloader)
    avg_correct_pos = total_pos_correct / total_num_samples
    avg_correct_neg = total_neg_correct / total_num_samples
    avg_num_correct = total_num_correct / (2 * total_num_samples)  # Since we have two conditions

    print(f"Correct positive: {total_pos_correct} ({avg_correct_pos * 100:.2f}%), Correct negative: {total_neg_correct} ({avg_correct_neg * 100:.2f}%)")
    print(f"Total correct: {total_num_correct} ({avg_num_correct * 100:.2f}%)")
    print(f"Eval Loss: {avg_loss:.4f}, Eval Accuracy: {avg_num_correct:.4f}")

    return avg_loss, avg_num_correct, avg_correct_pos, avg_correct_neg
            

### Training Configuration

In [4]:
db = DatabaseWrapper(database=database)
data_harvester = TripletDataHarvester(db=db, gs=gs, edge_spec=edge_spec, config=config, save_file_postfix=save_file_postfix)


# Split the pairs into train and test

train_size = int(0.85 * len(data_harvester.triplets))
test_size = int(0.1 * len(data_harvester.triplets))
eval_size = len(data_harvester.triplets) - train_size - test_size

# Harvest the evaluation triplets first, since triplets are ordered by author. This will ensure that the evaluation set has authors not seen in the training set.
eval_triplets = data_harvester.triplets[:eval_size]

train_test_triplets = data_harvester.triplets[eval_size:]
random.shuffle(train_test_triplets)

train_triplets = train_test_triplets[:train_size]
test_triplets = train_test_triplets[train_size:]
config['train_size'] = len(train_triplets)
config['test_size'] = len(test_triplets)
config['eval_size'] = len(eval_triplets)

print(f"Train size: {len(train_triplets)}, Test size: {len(test_triplets)}, Eval size: {len(eval_triplets)}")

# Create the datasets from the pairs (distinct pairs for training and testing)
train_dataset = GraphTripletDataset(train_triplets, gs, config=config)
test_dataset = GraphTripletDataset(test_triplets, gs, config=config)
eval_dataset = GraphTripletDataset(eval_triplets, gs, config=config)

# Create the DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, collate_fn=custom_triplet_collate)
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, collate_fn=custom_triplet_collate)
eval_dataloader = DataLoader(eval_dataset, batch_size=config['batch_size'], shuffle=False, collate_fn=custom_triplet_collate)

# Create model
metadata = (
    [n.value for n in node_spec],
    [edge_pyg_key_vals[r] for r in edge_spec]
)
config['node_spec'] = metadata[0]
config['edge_spec'] = metadata[1]
model = model_class(metadata, config['hidden_channels'], config['out_channels'], num_heads=config['num_heads']).to(device)
optimizer = Adam(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])

2024-10-18 12:02:42,845 - DatabaseWrapper - INFO - Connecting to the database ...
2024-10-18 12:02:42,846 - DatabaseWrapper - INFO - Database ready.


Preparing triplets...
Loading triplets...
Loaded 11755 triplets.
Train size: 9991, Test size: 1175, Eval size: 589


### Training Loop

In [5]:
num_epochs = config['num_epochs']
train_losses = []
test_losses = []

eval_losses = []
eval_accuracies = []
eval_correct_pos = []
eval_correct_neg = []

current_batch = 1

for epoch in range(1, num_epochs + 1):
    print(f"=== Epoch {epoch}/{num_epochs} ======================")
    epoch_marker_pos = list(range(0, len(train_dataloader) * epoch, len(train_dataloader)))
    current_batch = 1
    for batch_anchor, batch_pos, batch_neg in tqdm(train_dataloader, desc=f"Epoch {epoch}/{num_epochs}"):
        if batch_anchor is None or batch_pos is None or batch_neg is None:
            continue
        
        if len(train_losses) % 10 == 0:
            print(f"___ Batch {current_batch}/{len(train_dataloader)} _________________________")
            # Model testing
            test_loss = test(model, test_dataloader)
            test_losses.append(test_loss)
            test_epoch_marker_pos = [marker/10 for marker in epoch_marker_pos if marker != 0]
            plot_loss(test_losses, epoch_marker_pos=test_epoch_marker_pos, plot_title='Test Loss', plot_avg=True, plot_file=f'./data/losses/test_loss_{save_file_postfix}.png')
            
            # Model evaluation
            eval_loss, eval_num_correct, eval_correct_pos_val, eval_correct_neg_val = evaluate(model, eval_dataloader)
            eval_losses.append(eval_loss)
            eval_accuracies.append(eval_num_correct)
            eval_correct_pos.append(eval_correct_pos_val)
            eval_correct_neg.append(eval_correct_neg_val)
            
            plot_loss(eval_losses, epoch_marker_pos=test_epoch_marker_pos, plot_title='Evaluation Loss', plot_avg=True, plot_file=f'./data/losses/eval_loss_{save_file_postfix}.png')
            plot_loss(eval_accuracies, epoch_marker_pos=test_epoch_marker_pos, plot_title='Evaluation Accuracy', plot_avg=False, plot_file=f'./data/losses/eval_accuracy_{save_file_postfix}.png')
            
        loss = train(model, batch_anchor, batch_pos, batch_neg, optimizer)
        train_losses.append(loss)
        
        plot_loss(train_losses, epoch_marker_pos=epoch_marker_pos, plot_title='Training Loss', plot_avg=True, plot_file=f'./data/losses/train_loss_{save_file_postfix}.png')
        current_batch += 1
        
    # Save config and training results
    eval_results = {
        'eval_losses': eval_losses,
        'eval_accuracies': eval_accuracies,
        'eval_correct_pos': eval_correct_pos,
        'eval_correct_neg': eval_correct_neg
    }
    save_training_results(train_losses, test_losses, eval_results, config, f'./data/results/training_results_{save_file_postfix}.json')
    
    # Save model if loss has decreased
    if len(test_losses) > 1 and test_losses[-1] < min(test_losses[:-1]):
        print(f"Saving model at epoch {epoch}...")
        torch.save(model.state_dict(), f'./data/models/gat_encoder_{save_file_postfix}.pt')

Epoch 1/20:   0%|          | 0/313 [00:00<?, ?it/s]

Test Loss: 1.0108
Correct positive: 589 (100.00%), Correct negative: 0 (0.00%)
Total correct: 589 (50.00%)
Test Loss: 1.0078, Test Accuracy: 0.5000
[1.007752161276968]
[0.5]
[1.0]
Test Loss: 0.8915
Correct positive: 344 (58.40%), Correct negative: 238 (40.41%)
Total correct: 582 (49.41%)
Test Loss: 0.9695, Test Accuracy: 0.4941
[1.007752161276968, 0.9694569236353824]
[0.5, 0.4940577249575552]
[1.0, 0.5840407470288624]
Test Loss: 0.8325
Correct positive: 329 (55.86%), Correct negative: 368 (62.48%)
Total correct: 697 (59.17%)
Test Loss: 0.7815, Test Accuracy: 0.5917
[1.007752161276968, 0.9694569236353824, 0.7814873519696688]
[0.5, 0.4940577249575552, 0.5916808149405772]
[1.0, 0.5840407470288624, 0.5585738539898133]
Test Loss: 0.7824
Correct positive: 197 (33.45%), Correct negative: 426 (72.33%)
Total correct: 623 (52.89%)
Test Loss: 0.8853, Test Accuracy: 0.5289
[1.007752161276968, 0.9694569236353824, 0.7814873519696688, 0.8852982740653189]
[0.5, 0.4940577249575552, 0.5916808149405772, 

Epoch 2/20:   0%|          | 0/313 [00:00<?, ?it/s]

Test Loss: 0.4430
Correct positive: 193 (32.77%), Correct negative: 469 (79.63%)
Total correct: 662 (56.20%)
Test Loss: 0.9079, Test Accuracy: 0.5620
[1.007752161276968, 0.9694569236353824, 0.7814873519696688, 0.8852982740653189, 1.11994563278399, 0.9634218780617965, 0.908100143859261, 1.0203485394779004, 1.1037348822543496, 1.1150309945407666, 1.1107662288766158, 1.1828121982122723, 0.8873049208992406, 0.8147836531463423, 0.8038509237138849, 0.8274695339955782, 0.8420323933425703, 0.7430547726781744, 0.8054928999198111, 0.7464950194484309, 0.7356210272563132, 0.7588038773913133, 0.9531898780872947, 0.8453337158027449, 0.8082870511632216, 0.976501558956347, 0.78934509032651, 0.8008268679443159, 0.8779859338936052, 1.0290871168437756, 0.8407635077049858, 0.7673401675726238, 0.9079269923661885]
[0.5, 0.4940577249575552, 0.5916808149405772, 0.5288624787775892, 0.5008488964346349, 0.5161290322580645, 0.5178268251273345, 0.48047538200339557, 0.48217317487266553, 0.5076400679117148, 0.504244

Epoch 3/20:   0%|          | 0/313 [00:00<?, ?it/s]

Test Loss: 0.3245
Correct positive: 143 (24.28%), Correct negative: 516 (87.61%)
Total correct: 659 (55.94%)
Test Loss: 0.7065, Test Accuracy: 0.5594
[1.007752161276968, 0.9694569236353824, 0.7814873519696688, 0.8852982740653189, 1.11994563278399, 0.9634218780617965, 0.908100143859261, 1.0203485394779004, 1.1037348822543496, 1.1150309945407666, 1.1107662288766158, 1.1828121982122723, 0.8873049208992406, 0.8147836531463423, 0.8038509237138849, 0.8274695339955782, 0.8420323933425703, 0.7430547726781744, 0.8054928999198111, 0.7464950194484309, 0.7356210272563132, 0.7588038773913133, 0.9531898780872947, 0.8453337158027449, 0.8082870511632216, 0.976501558956347, 0.78934509032651, 0.8008268679443159, 0.8779859338936052, 1.0290871168437756, 0.8407635077049858, 0.7673401675726238, 0.9079269923661885, 0.9288153428780405, 1.01186330067484, 0.7917888682139548, 0.7789235679726851, 0.8634588373334784, 0.9404982516640111, 0.9623015350417087, 0.8267532869389183, 0.864992900898582, 0.8358168100055895,

KeyboardInterrupt: 