In [10]:
from notebooks.util import GraphTripletDataset
from util import *
from gat_models import *

import random
import numpy as np
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm
from torch_geometric.loader import DataLoader
from torch.optim import Adam
from torch.nn.modules.loss import TripletMarginLoss

from src.shared.database_wrapper import DatabaseWrapper
from src.shared.graph_schema import *
from src.shared.graph_sampling import GraphSampling

random.seed(40)
np.random.seed(40)
torch.manual_seed(40)
torch.cuda.manual_seed_all(40)

### Configurations

In [11]:
# Graph sampling configurations
node_spec = [
    NodeType.PUBLICATION,
]

edge_spec = [
    #EdgeType.SIM_VENUE,
    EdgeType.SIM_AUTHOR,
]

node_properties = [
    'id',
    'title',
    'abstract',
    'title_emb',
    'abstract_emb',
    'feature_vec',
]

database = 'homogeneous-graph-compressed-emb'
gs = GraphSampling(
    node_spec=node_spec,
    edge_spec=edge_spec,
    node_properties=node_properties,
    database=database
)

# Model configurations

config = {
    'experiment': 'GATv2 encoder (with linear layer + dropout) trained on graph (publication nodes with title and abstract, similarity and co-author edges) using Pairwise Contrastive Loss and dimension reduced embeddings',
    'max_hops': 3,
    'model_node_feature': 'feature_vec',  # Node feature to use for GAT encoder
    'hidden_channels': 64,
    'out_channels': 16,
    'num_heads': 8,
    'margin': 1.0,
    'optimizer': 'Adam',
    'learning_rate': 0.005,
    'weight_decay': 5e-4,
    'num_epochs': 20,
    'batch_size': 32,
}

model_class = HeteroGATEncoderLinear
loss_fn = TripletMarginLoss(margin=config['margin'])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

save_file_postfix = "triplets_hetero_edges_compressed_emb_linear_layer_dropout"

### Training functions

In [12]:
def train(model, batch_anchor, batch_pos, batch_neg, optimizer):
    model.train()
    
    optimizer.zero_grad()
    
    batch_anchor = batch_anchor.to(device)
    batch_pos = batch_pos.to(device)
    batch_neg = batch_neg.to(device)

    emb_a = model(batch_anchor)
    emb_p = model(batch_pos)
    emb_n = model(batch_neg)
    
    emb_a_central = emb_a[NodeType.PUBLICATION.value][batch_anchor.central_node_id]
    emb_p_central = emb_p[NodeType.PUBLICATION.value][batch_pos.central_node_id]
    emb_n_central = emb_n[NodeType.PUBLICATION.value][batch_neg.central_node_id]
    
    loss = loss_fn(emb_a_central, emb_p_central, emb_n_central)
    
    loss.backward()
    optimizer.step()
        
    batch_loss = loss.item()
    #print(f"Batch loss: {batch_loss:.4f}")
    return batch_loss

def test(model, dataloader):
    model.eval()
    total_loss = 0
    distances = []
    
    with torch.no_grad():
        for batch_anchor, batch_pos, batch_neg in dataloader:
            batch_anchor = batch_anchor.to(device)
            batch_pos = batch_pos.to(device)
            batch_neg = batch_neg.to(device)
    
            emb_a = model(batch_anchor)
            emb_p = model(batch_pos)
            emb_n = model(batch_neg)
            
            emb_a_central = emb_a[NodeType.PUBLICATION.value][batch_anchor.central_node_id]
            emb_p_central = emb_p[NodeType.PUBLICATION.value][batch_pos.central_node_id]
            emb_n_central = emb_n[NodeType.PUBLICATION.value][batch_neg.central_node_id]
            
            loss = loss_fn(emb_a_central, emb_p_central, emb_n_central)
            #loss, dist = contrastive_loss(embeddings1_central, embeddings2_central, labels, config['margin'])
            total_loss += loss.item()
            #distances.extend(dist.cpu().numpy())
            #labels_list.extend(labels.cpu().numpy())
        
    # Compute accuracy
    # distances = np.array(distances)
    #labels_list = np.array(labels_list).astype(int)
    #predictions = (distances <= config['margin']).astype(int)
    #accuracy = accuracy_score(labels_list, predictions)
    
    # Compute average loss    
    avg_loss = total_loss / len(dataloader)
    print(f"Test Loss: {avg_loss:.4f}, Test Accuracy: {0:.4f}")
    return avg_loss, 0 #accuracy
    

### Training Configuration

In [13]:
db = DatabaseWrapper(database=database)
data_harvester = TripletDataHarvester(db=db, gs=gs, edge_spec=edge_spec, config=config, save_file_postfix=save_file_postfix)


# Split the pairs into train and test
random.shuffle(data_harvester.triplets)
train_size = int(0.95 * len(data_harvester.triplets))
train_triplets = data_harvester.triplets[:train_size]
test_triplets = data_harvester.triplets[train_size:]
config['train_size'] = train_size
config['test_size'] = len(data_harvester.triplets) - train_size

# Create the datasets from the pairs (distinct pairs for training and testing)
train_dataset = GraphTripletDataset(train_triplets, gs, config=config)
test_dataset = GraphTripletDataset(test_triplets, gs, config=config)

# Create the DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, collate_fn=custom_triplet_collate)
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, collate_fn=custom_triplet_collate)

# Create model
metadata = (
    [n.value for n in node_spec],
    [edge_pyg_key_vals[r] for r in edge_spec]
)
config['node_spec'] = metadata[0]
config['edge_spec'] = metadata[1]
model = model_class(metadata, config['hidden_channels'], config['out_channels'], num_heads=config['num_heads']).to(device)
optimizer = Adam(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])

2024-10-16 15:42:06,715 - DatabaseWrapper - INFO - Connecting to the database ...
2024-10-16 15:42:06,716 - DatabaseWrapper - INFO - Database ready.


Preparing triplets...
Loading triplets...
Loaded 11755 triplets.


### Training Loop

In [None]:
num_epochs = config['num_epochs']
train_losses = []
test_losses = []
test_accuracies = []

for epoch in range(1, num_epochs + 1):
    epoch_marker_pos = list(range(0, len(train_dataloader) * epoch, len(train_dataloader)))
    for batch_anchor, batch_pos, batch_neg in tqdm(train_dataloader, desc=f"Epoch {epoch}/{num_epochs}"):
        if batch_anchor is None or batch_pos is None or batch_neg is None:
            continue
        
        if len(train_losses) % 10 == 0:
            test_loss, test_accuracy = test(model, test_dataloader)
            test_losses.append(test_loss)
            test_accuracies.append(test_accuracy)
            test_epoch_marker_pos = [marker/10 for marker in epoch_marker_pos if marker != 0]
            plot_loss(test_losses, epoch_marker_pos=test_epoch_marker_pos, plot_title='Test Loss', plot_avg=True, plot_file=f'./data/losses/test_loss_{save_file_postfix}.png')
            plot_loss(test_accuracies, epoch_marker_pos=test_epoch_marker_pos, plot_title='Test Accuracy', plot_avg=True, plot_file=f'./data/losses/test_accuracy_{save_file_postfix}.png')
            
        loss = train(model, batch_anchor, batch_pos, batch_neg, optimizer)
        train_losses.append(loss)
        
        plot_loss(train_losses, epoch_marker_pos=epoch_marker_pos, plot_title='Training Loss', plot_file=f'./data/losses/train_loss_{save_file_postfix}.png')
    
    # Save config and training results
    save_training_results(train_losses, test_losses, None, config, f'./data/results/training_results_{save_file_postfix}.json')
    
    # Save model if loss has decreased
    if len(test_losses) > 1 and test_losses[-1] < min(test_losses[:-1]):
        print(f"Saving model at epoch {epoch}...")
        torch.save(model.state_dict(), f'./data/models/gat_encoder_{save_file_postfix}.pt')

Epoch 1/20:   0%|          | 0/349 [00:00<?, ?it/s]

Test Loss: 1.0024, Test Accuracy: 0.0000
Test Loss: 0.8456, Test Accuracy: 0.0000
