In [1]:
import torch
from torch.utils.tensorboard import SummaryWriter
import torch
import torch_geometric as pyg
import torchmetrics
import tqdm
from pairing.data import PairData, Dataset, loader
import pairing.data
import copy

import scipy
import scipy.stats

import uuid

# Seed torch's global RNG so that any layers initialised after this point get
# reproducible weights from run to run.
torch.manual_seed(42)

def make_sequential(num_layers, input_dim, output_dim, is_last=False):
    """Build a stack of fully connected blocks wrapped in ``torch.nn.Sequential``.

    Each block is itself a ``Sequential`` of ``Linear -> ReLU -> Dropout(p=0)``.
    The first block maps ``input_dim -> output_dim``; every further block maps
    ``output_dim -> output_dim``. At least one block is always created, even
    when ``num_layers`` is below 1.

    Args:
        num_layers: Number of blocks to stack.
        input_dim: Feature size consumed by the first block.
        output_dim: Feature size produced by every block.
        is_last: When true, the final block is replaced by a bare ``Linear``
            (no activation, no dropout) so the stack can emit raw outputs.

    Returns:
        A ``torch.nn.Sequential`` whose children are the per-layer blocks.
    """

    def activated_block(in_features):
        # Dropout with p=0 is a no-op; it is kept so the module layout (and
        # therefore state-dict keys / pickled checkpoints) stays unchanged.
        return torch.nn.Sequential(
            torch.nn.Linear(in_features, output_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0),
        )

    blocks = [activated_block(input_dim)]
    blocks.extend(activated_block(output_dim) for _ in range(num_layers - 1))

    if is_last:
        # The replaced block is still constructed above on purpose: skipping it
        # would change how much of the RNG stream is consumed during init.
        head_in = input_dim if num_layers == 1 else output_dim
        blocks[-1] = torch.nn.Sequential(torch.nn.Linear(head_in, output_dim))

    return torch.nn.Sequential(*blocks)


class GCN(torch.nn.Module):
    """Graph encoder: repeated message passing followed by a graph-level readout.

    Node features are zero-padded up to ``embedding_size``, passed
    ``num_convs`` times through one shared convolution module, pooled into a
    per-graph vector of width ``2 * embedding_size`` and finally projected to
    ``embedding_size`` by an MLP.

    Args:
        num_convs: How many times the (single, weight-shared) convolution is
            applied in ``forward``.
        num_linear: Number of linear blocks in the GIN inner MLP and in the
            post-readout MLP.
        embedding_size: Width of the node embeddings and of the final output.
        aggr_steps: Processing steps for the Set2Set readout. When it is not
            positive, concatenated add/mean pooling is used instead.
        architecture: One of ``"GCN"``, ``"GIN"`` or ``"NNConv"``.

    Raises:
        KeyError: If ``architecture`` is not one of the supported names.

    NOTE(review): ``__init__`` reads a module-level ``device`` name that is
    not defined in this class; it must exist before the class is instantiated.
    """

    def __init__(self, num_convs, num_linear, embedding_size, aggr_steps, architecture):
        super().__init__()

        self.layers = []
        self.task = "graph"

        # Right-pad the feature (last) dimension with zeros so raw node
        # features are widened to ``embedding_size``.
        self.pad = torch.nn.ZeroPad2d(
            (0, embedding_size - Dataset.num_node_features(), 0, 0))

        self.gcn = self.make_conv(architecture, num_linear, embedding_size)
        self.gcn.to(device)
        self.num_convs = num_convs

        self.architecture = architecture
        self.aggr_steps = aggr_steps
        self.readout = pyg.nn.aggr.Set2Set(embedding_size, aggr_steps)
        self.readout.to(device)

        # Both readout variants produce 2 * embedding_size features per graph.
        self.post_mp = make_sequential(num_linear, 2 * embedding_size, embedding_size, is_last=True)
        self.post_mp.to(device)

    def make_conv(self, architecture, num_linear, embedding_size):
        """Return the message-passing layer for ``architecture``.

        Raises:
            KeyError: If ``architecture`` is not ``"GCN"``, ``"GIN"`` or
                ``"NNConv"``.
        """
        if architecture == "GCN":
            return pyg.nn.GCNConv(embedding_size, embedding_size)
        elif architecture == "GIN":
            return pyg.nn.GINConv(make_sequential(num_linear, embedding_size, embedding_size))
        elif architecture == "NNConv":
            # The edge network maps edge features to a flattened
            # embedding_size x embedding_size weight matrix per edge.
            mpfn = make_sequential(1, Dataset.num_edge_features(), embedding_size ** 2)
            return pyg.nn.NNConv(embedding_size, embedding_size, mpfn)
        else:
            raise KeyError(f"Received invalid architecture = {architecture}.")

    def forward(self, x, edge_index, edge_attr, batch_index):
        """Encode a batch of graphs into one embedding per graph.

        Args:
            x: Node feature matrix for the whole batch.
            edge_index: Edge connectivity passed to the convolution.
            edge_attr: Edge features; only used by the ``"NNConv"`` variant.
            batch_index: Graph assignment of every node, used for pooling.

        Returns:
            The output of ``post_mp`` applied to the pooled node embeddings.
        """
        x = self.pad(x)
        uses_edge_attr = self.architecture == "NNConv"
        for _ in range(self.num_convs):
            if uses_edge_attr:
                x = self.gcn(x, edge_index, edge_attr)
            else:
                x = self.gcn(x, edge_index)

        # Only compute the readout that is actually used. Previously the
        # add/mean pooling was always evaluated and then thrown away whenever
        # the Set2Set readout was enabled.
        if self.aggr_steps > 0:
            pooled = self.readout(x, index=batch_index)
        else:
            pooled = torch.cat(
                [
                    pyg.nn.pool.global_add_pool(x, batch_index),
                    pyg.nn.pool.global_mean_pool(x, batch_index),
                ],
                dim=1,
            )
        return self.post_mp(pooled)


class MixturePredictor(torch.nn.Module):
    """Predict 33 outputs for a pair of graphs using one shared ``GCN`` encoder.

    Both graphs of a pair are embedded by the same encoder; the two embeddings
    are concatenated and fed to an MLP head with 33 output units.
    """

    def __init__(self, num_convs, num_linear, embedding_size, aggr_steps, architecture):
        super().__init__()

        # Encoder first, head second: keeps submodule registration order and
        # weight-initialisation order as before.
        encoder = GCN(num_convs, num_linear, embedding_size, aggr_steps, architecture)
        head = make_sequential(num_linear, 2 * embedding_size, 33, is_last=True)
        self.gcn = encoder
        self.out = head

    def forward(self, x_s, edge_index_s, edge_attr_s, x_s_batch, x_t, edge_index_t, edge_attr_t, x_t_batch, y=None):
        """Return the head's output for a batch of (source, target) graph pairs.

        The ``*_s`` arguments describe the first graph of each pair and the
        ``*_t`` arguments the second. ``y`` is accepted but not used.
        """
        graph_inputs = (
            (x_s, edge_index_s, edge_attr_s, x_s_batch),
            (x_t, edge_index_t, edge_attr_t, x_t_batch),
        )
        # Same encoder (shared weights) for both sides, source embedded first.
        pair_embeddings = [self.gcn(*inputs) for inputs in graph_inputs]
        return self.out(torch.cat(pair_embeddings, dim=1))


def load_model(model_path, map_location=None):
    """Load a fully pickled model from ``model_path`` and put it in eval mode.

    Args:
        model_path: Path (or file-like object) of a checkpoint that contains a
            whole pickled module rather than just a ``state_dict``.
        map_location: Optional device remapping forwarded to ``torch.load``,
            e.g. ``"cpu"`` to open a checkpoint saved on a GPU on a CPU-only
            machine. ``None`` keeps the storages on their saved devices.

    Returns:
        The unpickled model, switched to evaluation mode.
    """
    # SECURITY: loading a whole pickled module executes arbitrary code embedded
    # in the file. Only open checkpoints that come from a trusted source.
    try:
        # Recent torch releases default to ``weights_only=True``, which refuses
        # to unpickle full modules; request the legacy behaviour explicitly.
        model = torch.load(model_path, map_location=map_location, weights_only=False)
    except TypeError:
        # Older torch releases do not know the ``weights_only`` keyword.
        model = torch.load(model_path, map_location=map_location)
    model.eval()
    return model


def evaluate_model(model, test_loader):
    """Compute one binary AUROC score per label over the whole test loader.

    Args:
        model: Callable taking the eight paired-graph tensors of a batch
            (``x_s, edge_index_s, edge_attr_s, x_s_batch`` and the ``*_t``
            counterparts) and returning one score column per label.
        test_loader: Iterable of batches exposing those attributes plus ``y``.

    Returns:
        A list of ``Dataset.num_classes()`` floats, the AUROC of each label
        column in order.
    """
    num_classes = Dataset.num_classes()

    preds, targets = [], []

    # Inference only: without no_grad every batch would keep its autograd
    # graph alive inside ``preds`` until the function returns.
    with torch.no_grad():
        for batch_data in test_loader:
            outputs = model(batch_data.x_s, batch_data.edge_index_s, batch_data.edge_attr_s, batch_data.x_s_batch,
                            batch_data.x_t, batch_data.edge_index_t, batch_data.edge_attr_t, batch_data.x_t_batch)
            preds.append(outputs)
            targets.append(batch_data.y)

    preds = torch.cat(preds, dim=0)
    targets = torch.cat(targets, dim=0)

    # Score each label column independently with a fresh metric object.
    auroc_scores = []
    for label in range(num_classes):
        metric = torchmetrics.classification.BinaryAUROC()
        metric.update(preds[:, label], targets[:, label].int())
        auroc_scores.append(metric.compute().item())

    return auroc_scores


# Script entry point: load the saved model, evaluate it on the held-out split
# and print one AUROC score per label.
if __name__ == "__main__":
    # ``device`` is deliberately a module-level name: GCN.__init__ reads it.
    device = "cpu"  # or "cuda" if GPU is available and model was trained on GPU
    test = Dataset(is_train=False)
    test_loader = loader(test, batch_size=164)  # Adjust batch size according to your system
    # The checkpoint path is relative to the current working directory.
    model = load_model("model_33.pt")
    model = model.to(device)
    
    auroc_scores = evaluate_model(model, test_loader)
    for i, score in enumerate(auroc_scores):
        print(f"AUROC Score for label {i}: {score}")


  from .autonotebook import tqdm as notebook_tqdm


AUROC Score for label 0: 0.9998523592948914
AUROC Score for label 1: 0.9998905062675476
AUROC Score for label 2: 0.9836894273757935
AUROC Score for label 3: 0.9636465311050415
AUROC Score for label 4: 0.9883670806884766
AUROC Score for label 5: 0.9198623299598694
AUROC Score for label 6: 0.9997385740280151
AUROC Score for label 7: 0.9995074272155762
AUROC Score for label 8: 0.994845986366272
AUROC Score for label 9: 0.9991586804389954
AUROC Score for label 10: 0.32822728157043457
AUROC Score for label 11: 0.6136036515235901
AUROC Score for label 12: 0.8472949862480164
AUROC Score for label 13: 0.653875470161438
AUROC Score for label 14: 0.6470919251441956
AUROC Score for label 15: 0.6985824704170227
AUROC Score for label 16: 0.6534140110015869
AUROC Score for label 17: 0.47521963715553284
AUROC Score for label 18: 0.6675715446472168
AUROC Score for label 19: 0.651344895362854
AUROC Score for label 20: 0.4697982966899872
AUROC Score for label 21: 0.37342214584350586
AUROC Score for labe