In [27]:
import os
import json
from time import sleep
from typing import Any

from sklearn.metrics import accuracy_score
from torch.distributions.constraints import positive
from torch_geometric.nn.models.dimenet import triplets
from tqdm.notebook import tqdm
import pandas as pd
from graphdatascience import GraphDataScience
from neo4j import GraphDatabase
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, GATv2Conv, global_mean_pool, HeteroConv
from torch_geometric.transforms import RandomNodeSplit
from torch_geometric.data import HeteroData
import random
import numpy as np
from torch.nn.modules.loss import TripletMarginLoss
import matplotlib.pyplot as plt
from torch_geometric.data import Batch
from torch.optim import Adam
from src.shared.database_wrapper import DatabaseWrapper
from src.datasets.who_is_who import WhoIsWhoDataset
from src.model.GAT.gat_encoder import GATv2Encoder
from src.model.GAT.gat_decoder import GATv2Decoder
from src.shared.graph_schema import *
from src.model.loss.triplet_loss import TripletLoss
from src.shared import config
from torch.utils.data import Dataset
import networkx as nx
import plotly.graph_objects as go
from itertools import combinations, product
from src.shared.graph_sampling import GraphSampling

random.seed(40)
np.random.seed(40)
torch.manual_seed(40)
torch.cuda.manual_seed_all(40)

## Configurations

In [28]:
# Graph sampling configurations
node_spec = [
    NodeType.PUBLICATION,
]

edge_spec = [
    EdgeType.SIM_TITLE,
    EdgeType.SIM_ABSTRACT,
    EdgeType.SIM_VENUE,
    EdgeType.SIM_AUTHOR,
]

node_properties = [
    'id',
    'title',
    'abstract',
    'venue',
    'title_emb',
    'abstract_emb',
    'venue_emb',
    'true_author',
]

gs = GraphSampling(
    node_spec=node_spec,
    edge_spec=edge_spec,
    node_properties=node_properties,
    database='homogeneous-graph',
)

# Model configurations
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

config = {
    'experiment': 'GATv2 encoder trained on graph (publication nodes, similarity and co-author edges) using Pairwise Contrastive Loss',
    'max_hops': 2,
    'model_node_feature': 'abstract_emb',  # Node feature to use for GAT encoder
    'hidden_channels': 128,
    'out_channels': 32,
    'num_heads': 8,
    'margin': 1.0,
    'optimizer': 'Adam',
    'learning_rate': 0.005,
    'weight_decay': 5e-4,
    'num_epochs': 5,
}

## Graph Pair Data Handling

In [29]:
class GraphPairDataset(Dataset):
    def __init__(self, pairs, gs):
        self.pairs = pairs  # List of tuples: (paper_id1, paper_id2, label)
        self.gs = gs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        paper_id1, paper_id2, label = self.pairs[idx]
        try:
            #print(f"Processing pair ({paper_id1}, {paper_id2})")
            # Get n-hop neighbourhood for each paper
            graph1 = self.gs.n_hop_neighbourhood(NodeType.PUBLICATION, paper_id1, max_level=config['max_hops'])
            graph2 = self.gs.n_hop_neighbourhood(NodeType.PUBLICATION, paper_id2, max_level=config['max_hops'])

            # Convert to PyG Data objects
            data1, node_map_1 = neo_to_pyg_hetero_edges(graph1, config['model_node_feature'])
            data1.central_node_id = torch.tensor([node_map_1[paper_id1]])
            
            data2, node_map_2 = neo_to_pyg_hetero_edges(graph2, config['model_node_feature'])
            data2.central_node_id = torch.tensor([node_map_2[paper_id2]])
            
            # Return data and label
            return data1, data2, torch.tensor(label, dtype=torch.float)
        except Exception as e:
            print(f"Error processing pair ({paper_id1}, {paper_id2}): {e}")
            return None
        

# This is required for the PyG DataLoader in order to handle the custom mini-batching during training 
class PairData(HeteroData):
    def __cat_dim__(self, key, value, *args, **kwargs):
        if key == 'central_node_id':
            return 0  # Concat along batch dim
        else:
            return super().__cat_dim__(key, value, *args, **kwargs)
        
    def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any:
        if key == 'central_node_id':
            return self.num_nodes
        else:
            return super().__inc__(key, value, *args, **kwargs)

In [30]:
def graph_data_valid(data: Data):
    try:
        node_type_val = NodeType.PUBLICATION.value
        assert data is not None, "Data object is None."
        assert data.num_nodes > 0, "Number of nodes must be greater than 0."
        assert data.num_edges > 0, "Number of edges must be greater than 0."
        assert data[node_type_val].x is not None, "Node features 'x' are missing."
        assert data[node_type_val].x.size(0) == data.num_nodes, "Mismatch between 'x' size and 'num_nodes'."
        assert data[node_type_val].x.dtype in (torch.float32, torch.float64), "Node features 'x' must be floating point."
        for key in [edge_pyg_key_vals[r] for r in edge_spec]:
            if key not in data:
                continue
            assert data[key].edge_index.size(0) == 2, f"'edge_index' for '{key}' should have shape [2, num_edges]."
            assert data[key].edge_index.size(1) == data[key].num_edges, f"Mismatch between 'edge_index' and 'num_edges' for '{key}'."
            assert data[key].edge_index is not None, f"Edge index for '{key}' is missing."
            assert data[key].edge_index.max() < data.num_nodes, f"'edge_index' for '{key}' contains invalid node indices."
        return True
    except AssertionError as e:
        print(f"Data check failed: {e}")
        return False

In [31]:
def neo_to_pyg_hetero_edges(
        data,
        node_attr: str,
):
    if not data:
        return None, None

    nodes = data["nodes"]
    relationships = data["relationships"]

    #print(f"Nodes: {len(nodes)}, Relationships: {len(relationships)}")

    # Create a PyG Data object
    pyg_data = PairData()

    node_features = []
    node_ids = []
    node_id_map = {}

    for node in nodes:
        node_id = node.get("id")
        node_feature = node.get(node_attr, None)
        if node_feature is None:
            print(f"Node {node_id} has no attribute {node_attr}")
            continue

        # Map node id to its index in the list
        idx = len(node_ids)
        node_id_map[node_id] = idx
        node_ids.append(node_id)

        # Convert node features to tensors
        node_feature_tensor = torch.tensor(node_feature, dtype=torch.float32)
        node_features.append(node_feature_tensor)

    # Convert list of features to tensor
    if node_features:
        pyg_data[NodeType.PUBLICATION.value].x = torch.vstack(node_features)
        pyg_data[NodeType.PUBLICATION.value].num_nodes = pyg_data[NodeType.PUBLICATION.value].x.size(0)
    else:
        print("No node features available.")
        return None, None

    # Process relationships
    edge_dict = {}

    for rel in relationships:
        key = edge_val_to_pyg_key_vals[rel.type]
        if key not in edge_dict:
            edge_dict[key] = [[], []]

        source_id = rel.start_node.get("id")
        target_id = rel.end_node.get("id")

        # Append the indices of the source and target nodes
        edge_dict[key][0].append(node_id_map[source_id])
        edge_dict[key][1].append(node_id_map[target_id])

    # Convert edge lists to tensors
    for key in edge_dict:
        pyg_data[key[0], key[1], key[2]].edge_index = torch.vstack([
            torch.tensor(edge_dict[key][0], dtype=torch.long),
            torch.tensor(edge_dict[key][1], dtype=torch.long)
        ])

        pyg_data[key[0], key[1], key[2]].edge_attr = torch.vstack(
            [edge_one_hot[key[1]] for _ in range(len(edge_dict[key][0]))])

    return pyg_data, node_id_map

### Harvest Positive and Negative tuples from the graph database

In [32]:
class DataHarvester:
    def __init__(self, db: DatabaseWrapper, gs: GraphSampling):
        self.db = db
        self.gs = gs
        self.pairs = []
        self.prepare_pairs()

    def prepare_pairs(self):
        print("Preparing pairs...")
        file_path = './data/hetero-pairs.json'
        
        try:
            print("Loading pairs...")
            self.load_pairs(file_path)
            print(f"Loaded {len(self.pairs)} pairs.")
        except FileNotFoundError:
            print("Could not load pairs from file. Generating pairs...")
            self.generate_pairs()
            print(f"Generated {len(self.pairs)} pairs.")
            print("Saving pairs...")
            self.save_pairs(file_path)
            print("Pairs saved.")
            
    def load_pairs(self, file_path):
        with open(file_path, 'r') as f:
            self.pairs = json.load(f)
    
    def save_pairs(self, file_path):
        with open(file_path, 'w') as f:
            json.dump(self.pairs, f)
                
    def generate_pairs(self):
        # Filter out the papers that are not present in the graph or have less than 2 edges
        paper_ids = []
        print("Checking data validity...")
        for nodes in self.db.iter_nodes_with_edge_count(NodeType.PUBLICATION, edge_spec, ['id', 'true_author']):
            for node in nodes:
                data = gs.n_hop_neighbourhood(NodeType.PUBLICATION, node['id'], max_level=1)
                data = neo_to_pyg_hetero_edges(data, config['model_node_feature'])[0]
                if not graph_data_valid(data):
                    continue
                paper_ids.append(node['id'])
        
        print(f"Total papers: {len(paper_ids)}")
        print("Preparing pairs...")
        paper_set = set(paper_ids)
        
        author_data = WhoIsWhoDataset.parse_train()
        for author_id, data in author_data.items():
            for key in data:
                data[key] = [p_id for p_id in data[key] if p_id in paper_set]
        
        # Generate pairs with labels
        pairs = []
        
        for author_id, data in author_data.items():
            normal_data = data.get('normal_data', [])
            outliers = data.get('outliers', [])
                    
            # Positive pairs: combinations of normal_data
            pos_pairs = list(combinations(normal_data, 2))
            if len(pos_pairs) > 50:
                pos_pairs = random.sample(pos_pairs, 50)
            for pair in pos_pairs:
                pairs.append((pair[0], pair[1], 1))
            
            # Negative pairs: product of normal_data and outliers
            neg_pairs = list(product(normal_data, outliers))
            if len(neg_pairs) > 50:
                neg_pairs = random.sample(neg_pairs, 50)
            elif len(neg_pairs) < len(pos_pairs):
                # Sample random paper ids from other authors
                while len(neg_pairs) < len(pos_pairs):
                    p1 = random.choice(normal_data)
                    p2 = random.choice(paper_ids)
                    if p2 not in normal_data:
                        neg_pairs.append((p1, p2))
            for pair in neg_pairs:
                pairs.append((pair[0], pair[1], 0))
        
        print(f"Total pairs: {len(pairs)}. Done.")
        self.pairs = pairs
                

## GAT Encoder

In [33]:

class HeteroGATEncoder(torch.nn.Module):
    def __init__(self, metadata, hidden_channels, out_channels, num_heads=8):
        super().__init__()

        self.conv1 = HeteroConv({
            edge_type: GATv2Conv(
                (-1, -1), hidden_channels, heads=num_heads, concat=True)
            for edge_type in metadata[1]
        }, aggr='sum')

        self.conv2 = HeteroConv({
            edge_type: GATv2Conv(
                (-1, -1), out_channels, heads=1, concat=False)
            for edge_type in metadata[1]
        }, aggr='sum')

    def forward(self, data):
        x_dict = data.x_dict
        edge_index_dict = data.edge_index_dict
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {key: F.elu(x) for key, x in x_dict.items()}

        x_dict = self.conv2(x_dict, edge_index_dict)
        return x_dict

In [34]:
def contrastive_loss(embeddings1, embeddings2, labels, margin=1.0):
    # Compute Euclidean distances between embeddings
    distances = F.pairwise_distance(embeddings1, embeddings2)
    
    # Loss
    loss_pos = labels * distances.pow(2)  # For positive pairs
    loss_neg = (1 - labels) * F.relu(margin - distances).pow(2)  # For negative pairs
    loss = loss_pos + loss_neg
    return loss.mean(), distances


In [35]:
# custom collate function adjusted for GraphPairDataset
def custom_collate(batch):
    batch = [item for item in batch if item is not None]
    if len(batch) == 0:
        return None  # Skip empty batches

    data1_list = [item[0] for item in batch]
    data2_list = [item[1] for item in batch]
    
    labels = torch.stack([item[2] for item in batch])
    
    batch1 = Batch.from_data_list(data1_list)
    batch2 = Batch.from_data_list(data2_list)

    return batch1, batch2, labels

In [36]:
# Plot loss
def plot_loss(losses, epoch_marker_pos=None, plot_title="Loss", plot_file=None):
    if plot_file is None:
        plot_file = f'./data/losses/loss.png'
    if epoch_marker_pos is None:
        epoch_marker_pos = []
        
    plt.figure(figsize=(10, 6))
    plt.plot(losses, label=f'Loss')
    plt.xlabel('Iteration')
    plt.ylabel('Loss')
    
    for ix, x_pos in enumerate(epoch_marker_pos):
        plt.axvline(x=x_pos, color='red', linestyle='dotted', linewidth=1)
        plt.text(
        x_pos,
        max(losses),
        f'Epoch {ix}',
        rotation=90,
        verticalalignment='top',
        horizontalalignment='right',
        fontsize=10,
        color='red'
    )

    plt.title(plot_title)
    plt.legend()
    plt.grid(True)
    plt.savefig(plot_file)
    plt.close()
    
def save_training_results(train_loss, test_loss, eval_results, config, file_path):
    results = {
        'train_loss': train_loss,
        'test_loss': test_loss,
        'eval_results': eval_results,
        'config': config,
    }
    with open(file_path, 'w') as f:
        json.dump(results, f, indent=4)

In [37]:
def train(model, batch1, batch2, labels, optimizer):
    model.train()
    
    optimizer.zero_grad()
    
    batch1 = batch1.to(device)
    batch2 = batch2.to(device)
    labels = labels.to(device)

    embeddings1 = model(batch1)
    embeddings2 = model(batch2)

    embeddings1_central = embeddings1[NodeType.PUBLICATION.value][batch1.central_node_id]
    embeddings2_central = embeddings2[NodeType.PUBLICATION.value][batch2.central_node_id]

    loss, _ = contrastive_loss(embeddings1_central, embeddings2_central, labels, config['margin'])
    loss.backward()
    optimizer.step()
        
    batch_loss = loss.item()
    #print(f"Batch loss: {batch_loss:.4f}")
    return batch_loss

def test(model, dataloader):
    model.eval()
    total_loss = 0
    distances = []
    labels_list = []
    with torch.no_grad():
        for batch1, batch2, labels in dataloader:
            batch1 = batch1.to(device)
            batch2 = batch2.to(device)
            labels = labels.to(device)
    
            embeddings1 = model(batch1)
            embeddings2 = model(batch2)
            
            embeddings1_central = embeddings1[NodeType.PUBLICATION.value][batch1.central_node_id]
            embeddings2_central = embeddings2[NodeType.PUBLICATION.value][batch2.central_node_id]
            
            loss, dist = contrastive_loss(embeddings1_central, embeddings2_central, labels, config['margin'])
            
            total_loss += loss.item()
            distances.extend(dist.cpu().numpy())
            labels_list.extend(labels.cpu().numpy())
        
    # Compute accuracy
    distances = np.array(distances)
    labels_list = np.array(labels_list).astype(int)
    predictions = (distances <= config['margin']).astype(int)
    accuracy = accuracy_score(labels_list, predictions)
    
    # Compute average loss    
    avg_loss = total_loss / len(dataloader)
    print(f"Test Loss: {avg_loss:.4f}, Test Accuracy: {accuracy:.4f}")
    return avg_loss, accuracy
    

## Training 

In [38]:
db = DatabaseWrapper(database='homogeneous-graph')
data_harvester = DataHarvester(db=db, gs=gs)


# Split the pairs into train and test
random.shuffle(data_harvester.pairs)
train_size = int(0.95 * len(data_harvester.pairs))
train_pairs = data_harvester.pairs[:train_size]
test_pairs = data_harvester.pairs[train_size:]
config['train_size'] = train_size
config['test_size'] = len(data_harvester.pairs) - train_size

# Create the datasets from the pairs (distinct pairs for training and testing)
train_dataset = GraphPairDataset(train_pairs, gs)
test_dataset = GraphPairDataset(test_pairs, gs)

# Create the DataLoader
batch_size = 32
config['batch_size'] = batch_size

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate)

# Create model
metadata = (
    [n.value for n in node_spec],
    [edge_pyg_key_vals[r] for r in edge_spec]
)
config['node_spec'] = metadata[0]
config['edge_spec'] = metadata[1]
model = HeteroGATEncoder(metadata, config['hidden_channels'], config['out_channels'], num_heads=config['num_heads']).to(device)
optimizer = Adam(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])

2024-10-13 12:18:55,147 - DatabaseWrapper - INFO - Connecting to the database ...
2024-10-13 12:18:55,147 - DatabaseWrapper - INFO - Database ready.


Preparing pairs...
Loading pairs...
Loaded 6308 pairs.


In [39]:
num_epochs = config['num_epochs']
margin = 1.0  # Margin for contrastive loss
train_losses = []
test_losses = []
test_accuracies = []

for epoch in range(1, num_epochs + 1):
    epoch_marker_pos = list(range(0, len(train_dataloader) * epoch, len(train_dataloader)))
    
    for data1, data2, labels in tqdm(train_dataloader, desc=f"Epoch {epoch}/{num_epochs}"):
        if data1 is None or data2 is None:
            continue
        
        if len(train_losses) % 10 == 0:
            test_loss, test_accuracy = test(model, test_dataloader)
            test_losses.append(test_loss)
            test_accuracies.append(test_accuracy)
            test_epoch_marker_pos = [marker/10 for marker in epoch_marker_pos if marker != 0]
            plot_loss(test_losses, epoch_marker_pos=test_epoch_marker_pos, plot_title='Test Loss', plot_file=f'./data/losses/test_loss_hetero_edges.png')
            plot_loss(test_accuracies, epoch_marker_pos=test_epoch_marker_pos, plot_title='Test Accuracy', plot_file=f'./data/losses/test_accuracy_hetero_edges.png')
            
        loss = train(model, data1, data2, labels, optimizer)
        train_losses.append(loss)
        
        plot_loss(train_losses, epoch_marker_pos=epoch_marker_pos, plot_title='Training Loss', plot_file=f'./data/losses/train_loss_hetero_edges.png')
    
    # Save config and training results
    save_training_results(train_losses, test_losses, None, config, f'./data/results/training_results_hetero_edges.json')
    
    # Save model if loss has decreased
    if len(test_losses) > 1 and test_losses[-1] < min(test_losses[:-1]):
        print(f"Saving model at epoch {epoch}...")
        torch.save(model.state_dict(), f'./data/models/gat_encoder_hetero_edges.pt')

Epoch 1/5:   0%|          | 0/188 [00:00<?, ?it/s]

Test Loss: 0.7251, Test Accuracy: 0.5190
Test Loss: 9.3339, Test Accuracy: 0.4937
Test Loss: 10.0982, Test Accuracy: 0.4905
Test Loss: 6.3531, Test Accuracy: 0.4937
Test Loss: 4.9800, Test Accuracy: 0.4937
Test Loss: 3.4738, Test Accuracy: 0.4905
Test Loss: 3.1983, Test Accuracy: 0.4905
Test Loss: 1.9555, Test Accuracy: 0.4905
Test Loss: 1.5979, Test Accuracy: 0.4968
Test Loss: 1.7625, Test Accuracy: 0.5095
Test Loss: 1.1818, Test Accuracy: 0.4842
Test Loss: 1.0938, Test Accuracy: 0.5285
Test Loss: 0.8690, Test Accuracy: 0.5222
Test Loss: 0.7008, Test Accuracy: 0.5095
Test Loss: 0.6986, Test Accuracy: 0.5095
Test Loss: 0.5995, Test Accuracy: 0.4905
Test Loss: 0.5819, Test Accuracy: 0.5380
Test Loss: 0.5540, Test Accuracy: 0.5475
Test Loss: 0.5174, Test Accuracy: 0.5032
Saving model at epoch 1...


Epoch 2/5:   0%|          | 0/188 [00:00<?, ?it/s]

Test Loss: 0.4558, Test Accuracy: 0.5158
Test Loss: 0.4119, Test Accuracy: 0.5285
Test Loss: 0.3697, Test Accuracy: 0.5316
Test Loss: 0.3957, Test Accuracy: 0.5570
Test Loss: 0.3494, Test Accuracy: 0.5411
Test Loss: 0.3332, Test Accuracy: 0.5127
Test Loss: 0.3211, Test Accuracy: 0.5095
Test Loss: 0.3388, Test Accuracy: 0.5348
Test Loss: 0.3266, Test Accuracy: 0.5063
Test Loss: 0.3225, Test Accuracy: 0.5411
Test Loss: 0.3020, Test Accuracy: 0.5222
Test Loss: 0.3082, Test Accuracy: 0.5253
Test Loss: 0.2852, Test Accuracy: 0.5095
Test Loss: 0.2947, Test Accuracy: 0.5063
Test Loss: 0.2996, Test Accuracy: 0.5348
Test Loss: 0.3336, Test Accuracy: 0.5316
Test Loss: 0.3486, Test Accuracy: 0.5158
Test Loss: 0.3105, Test Accuracy: 0.5316
Test Loss: 0.3127, Test Accuracy: 0.5063


Epoch 3/5:   0%|          | 0/188 [00:00<?, ?it/s]

Test Loss: 0.3147, Test Accuracy: 0.5127
Test Loss: 0.2967, Test Accuracy: 0.5222
Test Loss: 0.2954, Test Accuracy: 0.5158
Test Loss: 0.2859, Test Accuracy: 0.5222
Test Loss: 0.2797, Test Accuracy: 0.5222
Test Loss: 0.3085, Test Accuracy: 0.5190
Test Loss: 0.3040, Test Accuracy: 0.5316
Test Loss: 0.3363, Test Accuracy: 0.5348
Test Loss: 0.2772, Test Accuracy: 0.5127
Test Loss: 0.3070, Test Accuracy: 0.5380
Test Loss: 0.2937, Test Accuracy: 0.5348
Test Loss: 0.2941, Test Accuracy: 0.5222
Test Loss: 0.2558, Test Accuracy: 0.5127
Test Loss: 0.2749, Test Accuracy: 0.5222
Test Loss: 0.3039, Test Accuracy: 0.5411
Test Loss: 0.2807, Test Accuracy: 0.5190
Test Loss: 0.2795, Test Accuracy: 0.5222
Test Loss: 0.2690, Test Accuracy: 0.5095
Test Loss: 0.2752, Test Accuracy: 0.5190


Epoch 4/5:   0%|          | 0/188 [00:00<?, ?it/s]

Test Loss: 0.3052, Test Accuracy: 0.5158
Test Loss: 0.3002, Test Accuracy: 0.5222
Test Loss: 0.3297, Test Accuracy: 0.5380
Test Loss: 0.3688, Test Accuracy: 0.4937
Test Loss: 0.2641, Test Accuracy: 0.5000
Test Loss: 0.3141, Test Accuracy: 0.5316
Test Loss: 0.3519, Test Accuracy: 0.5316
Test Loss: 0.2671, Test Accuracy: 0.5285
Test Loss: 0.2619, Test Accuracy: 0.5285
Test Loss: 0.2572, Test Accuracy: 0.5285
Test Loss: 0.2551, Test Accuracy: 0.5348
Test Loss: 0.2849, Test Accuracy: 0.5411
Test Loss: 0.2720, Test Accuracy: 0.5285
Test Loss: 0.2895, Test Accuracy: 0.5222
Test Loss: 0.2771, Test Accuracy: 0.5063
Test Loss: 0.2683, Test Accuracy: 0.5032
Test Loss: 0.2842, Test Accuracy: 0.5158
Test Loss: 0.2589, Test Accuracy: 0.5032
Test Loss: 0.2677, Test Accuracy: 0.5158


Epoch 5/5:   0%|          | 0/188 [00:00<?, ?it/s]

Test Loss: 0.2700, Test Accuracy: 0.5222
Test Loss: 0.2864, Test Accuracy: 0.5348
Test Loss: 0.2776, Test Accuracy: 0.5506
Test Loss: 0.2528, Test Accuracy: 0.5285
Test Loss: 0.2785, Test Accuracy: 0.5411
Test Loss: 0.2705, Test Accuracy: 0.5190
Test Loss: 0.2642, Test Accuracy: 0.5411
Test Loss: 0.2606, Test Accuracy: 0.5253
Test Loss: 0.2521, Test Accuracy: 0.5158
Test Loss: 0.2707, Test Accuracy: 0.5190
Test Loss: 0.2869, Test Accuracy: 0.5190
Test Loss: 0.2629, Test Accuracy: 0.5316
Test Loss: 0.2410, Test Accuracy: 0.5158
Test Loss: 0.2599, Test Accuracy: 0.5063
Test Loss: 0.2650, Test Accuracy: 0.5348
Test Loss: 0.2596, Test Accuracy: 0.5222
Test Loss: 0.3143, Test Accuracy: 0.5253
Test Loss: 0.2565, Test Accuracy: 0.5222


In [40]:
# Evaluation function
def evaluate(model, dataloader):
    model.eval()
    distances = []
    labels_list = []
    with torch.no_grad():
        for batch1, batch2, labels in dataloader:
            batch1 = batch1.to(device)
            batch2 = batch2.to(device)
            embeddings1 = model(batch1)
            embeddings2 = model(batch2)
            dist = F.pairwise_distance(embeddings1, embeddings2).cpu().numpy()
            distances.extend(dist)
            labels_list.extend(labels.numpy())
    return distances, labels_list

# After training
from sklearn.metrics import roc_auc_score

distances, labels_list = evaluate(model, dataloader)
roc_auc = roc_auc_score(labels_list, -np.array(distances))  # Negative distances for similarity
print(f"ROC AUC Score: {roc_auc:.4f}")


NameError: name 'dataloader' is not defined