In [1]:
import os
import torch
print("Using torch", torch.__version__)

Using torch 2.1.0+cu118


In [2]:
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-2.1.0+cu118.html
!pip install pyg-library

Looking in links: https://data.pyg.org/whl/torch-2.1.0+cu118.html


In [3]:
from torch_geometric.data import Data
from torch_geometric.datasets import MovieLens100K
from torch_geometric import nn
import torch_geometric.transforms as T

In [4]:
dataset = MovieLens100K(root='/tmp/movielens')

In [5]:
movielens_raw = dataset[0]
movielens_raw

HeteroData(
  movie={ x=[1682, 18] },
  user={ x=[943, 24] },
  (user, rates, movie)={
    edge_index=[2, 80000],
    rating=[80000],
    time=[80000],
    edge_label_index=[2, 20000],
    edge_label=[20000],
  },
  (movie, rated_by, user)={
    edge_index=[2, 80000],
    rating=[80000],
    time=[80000],
  }
)

In [6]:
del movielens_raw[("user", "rates", "movie")].edge_label_index
del movielens_raw[("user", "rates", "movie")].edge_label

In [7]:
movielens_raw

HeteroData(
  movie={ x=[1682, 18] },
  user={ x=[943, 24] },
  (user, rates, movie)={
    edge_index=[2, 80000],
    rating=[80000],
    time=[80000],
  },
  (movie, rated_by, user)={
    edge_index=[2, 80000],
    rating=[80000],
    time=[80000],
  }
)

In [8]:
node_types, edge_types = movielens_raw.metadata()
print(node_types)
print(edge_types)

['movie', 'user']
[('user', 'rates', 'movie'), ('movie', 'rated_by', 'user')]


In [9]:
transform = T.Compose([
    T.RandomLinkSplit(
        num_val=0.05,
        num_test=0.1,
        disjoint_train_ratio=0.2,   # supervision
        add_negative_train_samples=False,
        neg_sampling_ratio=1.0,
        edge_types=("user", "rates", "movie"),
        rev_edge_types=('movie', 'rated_by', 'user')
    )
])

In [10]:
train_data, val_data, test_data = transform(movielens_raw)

In [11]:
def user_importance(num_users, edge_index):
    """
    Calculates and normalizes the importance score for each user based on the average number of edges they are involved in.

    Parameters:
    num_users (int): The number of users.
    edge_index (Tensor): 2D tensor of edge indices, where the first element of each pair is a user label.

    Returns:
    Tensor: 1D tensor of importance scores for each user.
    """
    # Create a range of user labels based on the number of users
    user_labels = torch.arange(num_users)

    # Count the edges for each user
    edge_counts = torch.zeros(num_users, dtype=torch.float32)
    for user in edge_index[0, :]:
        if user < num_users:
            edge_counts[user] += 1


    # Calculate the importance score for each user
    # Users closer to the average have a higher score
    importance_scores = 1 / (1 + torch.abs(edge_counts - torch.mean(edge_counts))/torch.std(edge_counts))


    return importance_scores

user_score=user_importance(movielens_raw["user"].num_nodes, movielens_raw[("user", "rates", "movie")].edge_index)

In [12]:
print(user_score)

tensor([0.6256, 0.6515, 0.5959, 0.5419, 0.9315, 0.7691, 0.3954, 0.6045, 0.5350,
        0.9014, 0.8061, 0.5834, 0.2253, 0.6566, 0.7714, 0.8411, 0.5600, 0.5305,
        0.5283, 0.5875, 0.8918, 0.8496, 0.9636, 0.6566, 0.6566, 0.8165, 0.5283,
        0.6464, 0.5527, 0.5834, 0.5754, 0.5715, 0.5419, 0.5283, 0.5527, 0.5250,
        0.6366, 0.7933, 0.5385, 0.5715, 0.6089, 0.9636, 0.7552, 0.9454, 0.6002,
        0.5316, 0.5490, 0.6178, 0.7908, 0.5385, 0.5490, 0.6271, 0.5490, 0.6045,
        0.5283, 0.9212, 0.7310, 0.9454, 0.3917, 0.7104, 0.5350, 0.6871, 0.6947,
        0.7762, 0.6947, 0.5676, 0.5350, 0.5563, 0.6415, 0.9145, 0.5715, 0.8763,
        0.6224, 0.5676, 0.6724, 0.6778, 0.6464, 0.5419, 0.5917, 0.5419, 0.6045,
        0.9863, 0.9246, 0.6670, 0.5305, 0.5316, 0.7485, 0.5316, 0.6515, 0.5142,
        0.7064, 0.4129, 0.5250, 0.3810, 0.5373, 0.5959, 0.6271, 0.5490, 0.7933,
        0.6224, 0.6415, 0.7552, 0.5455, 0.7933, 0.5490, 0.5959, 0.5385, 0.5527,
        0.6303, 0.9673, 0.5350, 0.5794, 

In [13]:
def edge_importance(user_importance, edge_index,labels):
    """
    Calculates the importance score for each edge based on the user importance scores.

    Parameters:
    user_importance_scores (Tensor): 1D tensor of user importance scores.
    edge_index (Tensor): 2D tensor of edge indices with shape [2, num_edges].

    Returns:
    Tensor: 1D tensor of importance scores for each edge.
    """
    # Number of edges is the size of the second dimension of edge_index
    num_edges = edge_index.size(1)

    # Initialize a tensor to hold the importance scores for each edge
    edge_scores = torch.zeros(num_edges, dtype=torch.float32)

    # Assign each edge the importance score of the user it involves
    for i in range(num_edges):
        if(labels[i]==1):
            user = edge_index[0, i]
            edge_scores[i] = user_importance[user]
        else:
            edge_scores[i] = user_importance.max()



    return edge_scores

In [14]:
import sys
import torch_geometric
from torch_geometric.nn import GCNConv, SAGEConv, GATConv, GINConv
from torch_geometric.nn import BatchNorm, LayerNorm, HeteroBatchNorm, HeteroLayerNorm
from torch_geometric.nn import to_hetero

class SAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()

        self.relu = torch.nn.ReLU()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.conv3 = SAGEConv(hidden_channels, out_channels)

    def forward(self, node_feature, edge_index):

        output1 = self.relu(self.conv1(node_feature, edge_index))
        output2 = self.relu(self.conv2(output1, edge_index))
        output = self.conv3(output2, edge_index)
        return output

class SAGE_RES(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()

        self.relu = torch.nn.ReLU()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.conv3 = SAGEConv(hidden_channels, out_channels)
        self.res = torch.nn.Linear(in_channels, out_channels)
        # self.norm1 = BatchNorm(hidden_channels, 2)
        # self.norm2 = BatchNorm(hidden_channels, 2)

    def forward(self, node_feature, edge_index):

        output1 = self.relu(self.conv1(node_feature, edge_index))
        # output1 = self.norm1(output1)
        output2 = self.relu(self.conv2(output1, edge_index))
        # output2 = self.norm2(output2)
        output3 = self.conv3(output2, edge_index)
        output_res = self.res(node_feature)
        return (output3 + output_res) * 0.5

class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()

        self.relu = torch.nn.ReLU()
        self.conv1 = GATConv(in_channels, hidden_channels, add_self_loops=False)
        self.conv2 = GATConv(hidden_channels, hidden_channels, add_self_loops=False)
        self.conv3 = GATConv(hidden_channels, out_channels, add_self_loops=False)

    def forward(self, node_feature, edge_index):

        output1 = self.relu(self.conv1(node_feature, edge_index))
        output2 = self.relu(self.conv2(output1, edge_index))
        output = self.conv3(output2, edge_index)
        return output

class GAT_RES(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()

        self.relu = torch.nn.ReLU()
        self.conv1 = GATConv(in_channels, hidden_channels, add_self_loops=False)
        self.conv2 = GATConv(hidden_channels, hidden_channels, add_self_loops=False)
        self.conv3 = GATConv(hidden_channels, out_channels, add_self_loops=False)
        self.res = torch.nn.Linear(in_channels, out_channels)

    def forward(self, node_feature, edge_index):

        output1 = self.relu(self.conv1(node_feature, edge_index))
        output2 = self.relu(self.conv2(output1, edge_index))
        output3 = self.conv3(output2, edge_index)
        output_res = self.res(node_feature)
        return (output3 + output_res) * 0.5

class Embedder(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # [0]: movie, feature dim: 18
        # [1]: user, feature dim: 24
        movies = movielens_raw[node_types[0]]
        users = movielens_raw[node_types[1]]
        self.sage = SAGE(in_channels, hidden_channels, out_channels)
        self.sage = to_hetero(self.sage, metadata=movielens_raw.metadata())
        self.sage_res = SAGE_RES(in_channels, hidden_channels, out_channels)
        self.sage_res = to_hetero(self.sage_res, metadata=movielens_raw.metadata())
        self.gat = GAT(in_channels, hidden_channels, out_channels)
        self.gat = to_hetero(self.gat, metadata=movielens_raw.metadata())
        self.gat_res = GAT_RES(in_channels, hidden_channels, out_channels)
        self.gat_res = to_hetero(self.gat_res, metadata=movielens_raw.metadata())
        self.linear_movie = torch.nn.Linear(movies.num_node_features, in_channels)
        self.linear_user = torch.nn.Linear(users.num_node_features, in_channels)
        self.gnn = self.sage_res

    def forward(self, hetero_data):
        features = {
            node_types[0]: self.linear_movie(hetero_data[node_types[0]].x),
            node_types[1]: self.linear_user(hetero_data[node_types[1]].x)
        }
        embeddings = self.gnn(features, hetero_data.edge_index_dict)
        return embeddings

In [15]:
def calc_emb_similarity(node_embs, edge_index, method="cosine"):
    # node_types[1] = "user"
    # node_types[0] = "movie"
    if method == "cosine":
        return torch.sum(node_embs[node_types[1]][edge_index[0]] * node_embs[node_types[0]][edge_index[1]], 1)

In [16]:
from torch_geometric.loader import LinkNeighborLoader
relation_rate = edge_types[0]
train_loader = LinkNeighborLoader(
    data=train_data,
    num_neighbors=[40, 20, 10],
    neg_sampling="binary",
    neg_sampling_ratio=1.0,
    edge_label_index=(relation_rate, train_data[relation_rate].edge_label_index),
    edge_label=train_data[relation_rate].edge_label,
    batch_size=256,
    shuffle=True
)



In [17]:
def get_loss(scores, labels,loss_fn,edge_weights,mode):

    alpha=0.75
    gamma=5
    loss = loss_fn(scores, labels.float())

    p_t = torch.exp(-loss)
    alpha_tensor = (1 - alpha) + labels * (2 * alpha - 1)
    # alpha_tensor = torch.where(labels == -1, 1 - alpha, alpha)
    # print("alpha",alpha_tensor.shape)
    # print("p_t",p_t.shape)
    # print("loss",loss.shape)
    # print("edge",edge_weights.shape)
    if mode == 0:
      f_loss = alpha_tensor * (1 - p_t) ** gamma * loss*edge_weights
    elif mode == 1:
      f_loss = alpha_tensor * (1 - p_t) ** gamma * loss+edge_weights*loss
    elif mode == 2:
      f_loss = alpha_tensor * (1 - p_t) ** gamma * loss
    elif mode == 3:
      f_loss = edge_weights*loss
    elif mode ==4:
      f_loss = loss

    return f_loss.mean()

In [18]:
from tqdm import tqdm
from torch_geometric.utils import negative_sampling

def train(model, dataloader, optimizer, loss_fn,user_importance,mode):
    correct_count = 0
    all_count = 0
    loss = 0
    model.train()
    for batch in tqdm(dataloader):
        optimizer.zero_grad()

        node_embeddings = model(batch)
        labels = batch[edge_types[0]].edge_label
        similarities = calc_emb_similarity(node_embeddings, batch[edge_types[0]].edge_label_index)

        predictions = similarities.sigmoid() > 0.5
        correct_count += torch.sum(predictions == labels)
        all_count += len(labels)

        edge_weights = edge_importance(user_importance,batch[edge_types[0]].edge_label_index,labels)
        loss = get_loss(similarities, labels,loss_fn,edge_weights,mode)

        loss.backward()
        optimizer.step()
    return model, (float(correct_count) / float(all_count))


In [19]:
from sklearn.metrics import roc_auc_score

@torch.no_grad()
def test(model, hetero_data):
    model.eval()
    node_embs = model(hetero_data)
    node_embs = calc_emb_similarity(node_embs, hetero_data[edge_types[0]].edge_label_index).view(-1).sigmoid()
    return roc_auc_score(hetero_data[edge_types[0]].edge_label.cpu().numpy(), node_embs.cpu().numpy())

In [20]:
# import pandas as pd
# import matplotlib.pyplot as plt

In [21]:
# columns = ['Training Acc', 'Validation Acc', 'Test Acc']
# results_df = pd.DataFrame(columns=columns)

In [22]:
epochs = 30

val_acc = []
test_acc = []
for i in range(3):
    best_val_auc = final_test_auc = 0
    model = Embedder(movielens_raw[node_types[0]].num_node_features, hidden_channels=128, out_channels=64)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-2)
    loss_fn = torch.nn.BCEWithLogitsLoss(reduction='none')
    for epoch in range(1, epochs + 1):

      model, acc = train(model, train_loader, optimizer, loss_fn,user_score,0)
      val_auc = test(model, val_data)
      test_auc = test(model, test_data)
      if val_auc > best_val_auc:
          best_val_auc = val_auc
      final_test_auc = test_auc
      print(f'Epoch: {epoch:03d}, Training Accuracy: {acc:.4f}, Val AUC: {val_auc:.4f}, Test AUC: {test_auc:.4f}')
      # results_df = results_df.append({'Training Acc': acc,
      #                                 'Validation Acc': val_auc,
      #                                 'Test Acc': test_auc},
      #                                ignore_index=True)
    val_acc.append(best_val_auc)
    test_acc.append(final_test_auc)



100%|██████████| 54/54 [00:05<00:00,  9.92it/s]


Epoch: 001, Training Accuracy: 0.5354, Val AUC: 0.7042, Test AUC: 0.6954


100%|██████████| 54/54 [00:06<00:00,  8.86it/s]


Epoch: 002, Training Accuracy: 0.6334, Val AUC: 0.7877, Test AUC: 0.7797


100%|██████████| 54/54 [00:05<00:00,  9.43it/s]


Epoch: 003, Training Accuracy: 0.6844, Val AUC: 0.8343, Test AUC: 0.8252


100%|██████████| 54/54 [00:04<00:00, 10.90it/s]


Epoch: 004, Training Accuracy: 0.7066, Val AUC: 0.8446, Test AUC: 0.8378


100%|██████████| 54/54 [00:05<00:00, 10.55it/s]


Epoch: 005, Training Accuracy: 0.7235, Val AUC: 0.8574, Test AUC: 0.8552


100%|██████████| 54/54 [00:05<00:00, 10.68it/s]


Epoch: 006, Training Accuracy: 0.7371, Val AUC: 0.8574, Test AUC: 0.8554


100%|██████████| 54/54 [00:05<00:00, 10.09it/s]


Epoch: 007, Training Accuracy: 0.7447, Val AUC: 0.8650, Test AUC: 0.8658


100%|██████████| 54/54 [00:04<00:00, 10.90it/s]


Epoch: 008, Training Accuracy: 0.7513, Val AUC: 0.8699, Test AUC: 0.8708


100%|██████████| 54/54 [00:05<00:00, 10.29it/s]


Epoch: 009, Training Accuracy: 0.7507, Val AUC: 0.8792, Test AUC: 0.8781


100%|██████████| 54/54 [00:05<00:00, 10.50it/s]


Epoch: 010, Training Accuracy: 0.7608, Val AUC: 0.8621, Test AUC: 0.8614


100%|██████████| 54/54 [00:05<00:00, 10.45it/s]


Epoch: 011, Training Accuracy: 0.7669, Val AUC: 0.8881, Test AUC: 0.8844


100%|██████████| 54/54 [00:05<00:00, 10.18it/s]


Epoch: 012, Training Accuracy: 0.7741, Val AUC: 0.8783, Test AUC: 0.8768


100%|██████████| 54/54 [00:04<00:00, 10.81it/s]


Epoch: 013, Training Accuracy: 0.7740, Val AUC: 0.8735, Test AUC: 0.8717


100%|██████████| 54/54 [00:05<00:00, 10.15it/s]


Epoch: 014, Training Accuracy: 0.7856, Val AUC: 0.8938, Test AUC: 0.8931


100%|██████████| 54/54 [00:05<00:00, 10.30it/s]


Epoch: 015, Training Accuracy: 0.7761, Val AUC: 0.8885, Test AUC: 0.8863


100%|██████████| 54/54 [00:05<00:00, 10.25it/s]


Epoch: 016, Training Accuracy: 0.7811, Val AUC: 0.8925, Test AUC: 0.8920


100%|██████████| 54/54 [00:05<00:00,  9.78it/s]


Epoch: 017, Training Accuracy: 0.7787, Val AUC: 0.8903, Test AUC: 0.8869


100%|██████████| 54/54 [00:05<00:00, 10.03it/s]


Epoch: 018, Training Accuracy: 0.7812, Val AUC: 0.8948, Test AUC: 0.8938


100%|██████████| 54/54 [00:05<00:00, 10.35it/s]


Epoch: 019, Training Accuracy: 0.7892, Val AUC: 0.8928, Test AUC: 0.8905


100%|██████████| 54/54 [00:04<00:00, 10.88it/s]


Epoch: 020, Training Accuracy: 0.7862, Val AUC: 0.8947, Test AUC: 0.8938


100%|██████████| 54/54 [00:05<00:00, 10.15it/s]


Epoch: 021, Training Accuracy: 0.7860, Val AUC: 0.8930, Test AUC: 0.8893


100%|██████████| 54/54 [00:05<00:00, 10.18it/s]


Epoch: 022, Training Accuracy: 0.7827, Val AUC: 0.8957, Test AUC: 0.8961


100%|██████████| 54/54 [00:05<00:00, 10.45it/s]


Epoch: 023, Training Accuracy: 0.7835, Val AUC: 0.8898, Test AUC: 0.8932


100%|██████████| 54/54 [00:05<00:00, 10.54it/s]


Epoch: 024, Training Accuracy: 0.7855, Val AUC: 0.8950, Test AUC: 0.8929


100%|██████████| 54/54 [00:05<00:00, 10.40it/s]


Epoch: 025, Training Accuracy: 0.7913, Val AUC: 0.8961, Test AUC: 0.8970


100%|██████████| 54/54 [00:05<00:00,  9.92it/s]


Epoch: 026, Training Accuracy: 0.7908, Val AUC: 0.8941, Test AUC: 0.8933


100%|██████████| 54/54 [00:05<00:00, 10.41it/s]


Epoch: 027, Training Accuracy: 0.7900, Val AUC: 0.8921, Test AUC: 0.8937


100%|██████████| 54/54 [00:05<00:00,  9.65it/s]


Epoch: 028, Training Accuracy: 0.7924, Val AUC: 0.8972, Test AUC: 0.8962


100%|██████████| 54/54 [00:05<00:00, 10.27it/s]


Epoch: 029, Training Accuracy: 0.7943, Val AUC: 0.8941, Test AUC: 0.8923


100%|██████████| 54/54 [00:05<00:00, 10.18it/s]


Epoch: 030, Training Accuracy: 0.7956, Val AUC: 0.9022, Test AUC: 0.9012


100%|██████████| 54/54 [00:05<00:00, 10.59it/s]


Epoch: 001, Training Accuracy: 0.5887, Val AUC: 0.7967, Test AUC: 0.7945


100%|██████████| 54/54 [00:05<00:00, 10.60it/s]


Epoch: 002, Training Accuracy: 0.6947, Val AUC: 0.8147, Test AUC: 0.8134


100%|██████████| 54/54 [00:05<00:00, 10.23it/s]


Epoch: 003, Training Accuracy: 0.7190, Val AUC: 0.8541, Test AUC: 0.8510


100%|██████████| 54/54 [00:04<00:00, 10.95it/s]


Epoch: 004, Training Accuracy: 0.7374, Val AUC: 0.8577, Test AUC: 0.8582


100%|██████████| 54/54 [00:05<00:00, 10.65it/s]


Epoch: 005, Training Accuracy: 0.7512, Val AUC: 0.8720, Test AUC: 0.8688


100%|██████████| 54/54 [00:05<00:00, 10.33it/s]


Epoch: 006, Training Accuracy: 0.7558, Val AUC: 0.8765, Test AUC: 0.8737


100%|██████████| 54/54 [00:05<00:00, 10.05it/s]


Epoch: 007, Training Accuracy: 0.7619, Val AUC: 0.8778, Test AUC: 0.8777


100%|██████████| 54/54 [00:05<00:00, 10.53it/s]


Epoch: 008, Training Accuracy: 0.7663, Val AUC: 0.8855, Test AUC: 0.8828


100%|██████████| 54/54 [00:05<00:00, 10.67it/s]


Epoch: 009, Training Accuracy: 0.7665, Val AUC: 0.8830, Test AUC: 0.8794


100%|██████████| 54/54 [00:05<00:00,  9.99it/s]


Epoch: 010, Training Accuracy: 0.7682, Val AUC: 0.8812, Test AUC: 0.8761


100%|██████████| 54/54 [00:04<00:00, 10.88it/s]


Epoch: 011, Training Accuracy: 0.7739, Val AUC: 0.8810, Test AUC: 0.8781


100%|██████████| 54/54 [00:05<00:00, 10.23it/s]


Epoch: 012, Training Accuracy: 0.7781, Val AUC: 0.8863, Test AUC: 0.8834


100%|██████████| 54/54 [00:05<00:00, 10.49it/s]


Epoch: 013, Training Accuracy: 0.7769, Val AUC: 0.8892, Test AUC: 0.8873


100%|██████████| 54/54 [00:05<00:00, 10.40it/s]


Epoch: 014, Training Accuracy: 0.7773, Val AUC: 0.8896, Test AUC: 0.8872


100%|██████████| 54/54 [00:05<00:00, 10.25it/s]


Epoch: 015, Training Accuracy: 0.7777, Val AUC: 0.8857, Test AUC: 0.8827


100%|██████████| 54/54 [00:05<00:00, 10.15it/s]


Epoch: 016, Training Accuracy: 0.7829, Val AUC: 0.8933, Test AUC: 0.8898


100%|██████████| 54/54 [00:05<00:00,  9.72it/s]


Epoch: 017, Training Accuracy: 0.7844, Val AUC: 0.8888, Test AUC: 0.8887


100%|██████████| 54/54 [00:05<00:00,  9.93it/s]


Epoch: 018, Training Accuracy: 0.7856, Val AUC: 0.8939, Test AUC: 0.8923


100%|██████████| 54/54 [00:05<00:00,  9.95it/s]


Epoch: 019, Training Accuracy: 0.7838, Val AUC: 0.8924, Test AUC: 0.8922


100%|██████████| 54/54 [00:04<00:00, 10.84it/s]


Epoch: 020, Training Accuracy: 0.7888, Val AUC: 0.8906, Test AUC: 0.8872


100%|██████████| 54/54 [00:05<00:00, 10.09it/s]


Epoch: 021, Training Accuracy: 0.7880, Val AUC: 0.8856, Test AUC: 0.8848


100%|██████████| 54/54 [00:05<00:00, 10.06it/s]


Epoch: 022, Training Accuracy: 0.7932, Val AUC: 0.8935, Test AUC: 0.8920


100%|██████████| 54/54 [00:05<00:00, 10.44it/s]


Epoch: 023, Training Accuracy: 0.7893, Val AUC: 0.8879, Test AUC: 0.8880


100%|██████████| 54/54 [00:05<00:00, 10.57it/s]


Epoch: 024, Training Accuracy: 0.7892, Val AUC: 0.8864, Test AUC: 0.8871


100%|██████████| 54/54 [00:05<00:00, 10.76it/s]


Epoch: 025, Training Accuracy: 0.7888, Val AUC: 0.9015, Test AUC: 0.8977


100%|██████████| 54/54 [00:05<00:00, 10.57it/s]


Epoch: 026, Training Accuracy: 0.7953, Val AUC: 0.8962, Test AUC: 0.8956


100%|██████████| 54/54 [00:04<00:00, 10.90it/s]


Epoch: 027, Training Accuracy: 0.7969, Val AUC: 0.8935, Test AUC: 0.8957


100%|██████████| 54/54 [00:05<00:00, 10.25it/s]


Epoch: 028, Training Accuracy: 0.7934, Val AUC: 0.8943, Test AUC: 0.8960


100%|██████████| 54/54 [00:05<00:00, 10.52it/s]


Epoch: 029, Training Accuracy: 0.7928, Val AUC: 0.9017, Test AUC: 0.8993


100%|██████████| 54/54 [00:05<00:00,  9.91it/s]


Epoch: 030, Training Accuracy: 0.7947, Val AUC: 0.8922, Test AUC: 0.8922


100%|██████████| 54/54 [00:05<00:00, 10.52it/s]


Epoch: 001, Training Accuracy: 0.6028, Val AUC: 0.8004, Test AUC: 0.7933


100%|██████████| 54/54 [00:05<00:00, 10.72it/s]


Epoch: 002, Training Accuracy: 0.6960, Val AUC: 0.8112, Test AUC: 0.8079


100%|██████████| 54/54 [00:05<00:00, 10.70it/s]


Epoch: 003, Training Accuracy: 0.7153, Val AUC: 0.8461, Test AUC: 0.8367


100%|██████████| 54/54 [00:04<00:00, 10.82it/s]


Epoch: 004, Training Accuracy: 0.7286, Val AUC: 0.8549, Test AUC: 0.8485


100%|██████████| 54/54 [00:05<00:00, 10.14it/s]


Epoch: 005, Training Accuracy: 0.7356, Val AUC: 0.8663, Test AUC: 0.8582


100%|██████████| 54/54 [00:05<00:00, 10.50it/s]


Epoch: 006, Training Accuracy: 0.7559, Val AUC: 0.8718, Test AUC: 0.8684


100%|██████████| 54/54 [00:04<00:00, 10.81it/s]


Epoch: 007, Training Accuracy: 0.7527, Val AUC: 0.8729, Test AUC: 0.8700


100%|██████████| 54/54 [00:05<00:00, 10.29it/s]


Epoch: 008, Training Accuracy: 0.7628, Val AUC: 0.8800, Test AUC: 0.8761


100%|██████████| 54/54 [00:05<00:00, 10.59it/s]


Epoch: 009, Training Accuracy: 0.7733, Val AUC: 0.8791, Test AUC: 0.8778


100%|██████████| 54/54 [00:05<00:00, 10.63it/s]


Epoch: 010, Training Accuracy: 0.7724, Val AUC: 0.8807, Test AUC: 0.8779


100%|██████████| 54/54 [00:04<00:00, 10.93it/s]


Epoch: 011, Training Accuracy: 0.7721, Val AUC: 0.8840, Test AUC: 0.8841


100%|██████████| 54/54 [00:05<00:00, 10.24it/s]


Epoch: 012, Training Accuracy: 0.7788, Val AUC: 0.8828, Test AUC: 0.8825


100%|██████████| 54/54 [00:05<00:00, 10.70it/s]


Epoch: 013, Training Accuracy: 0.7774, Val AUC: 0.8917, Test AUC: 0.8889


100%|██████████| 54/54 [00:05<00:00, 10.72it/s]


Epoch: 014, Training Accuracy: 0.7779, Val AUC: 0.8858, Test AUC: 0.8838


100%|██████████| 54/54 [00:05<00:00, 10.49it/s]


Epoch: 015, Training Accuracy: 0.7781, Val AUC: 0.8865, Test AUC: 0.8851


100%|██████████| 54/54 [00:05<00:00, 10.47it/s]


Epoch: 016, Training Accuracy: 0.7829, Val AUC: 0.8827, Test AUC: 0.8818


100%|██████████| 54/54 [00:05<00:00, 10.34it/s]


Epoch: 017, Training Accuracy: 0.7815, Val AUC: 0.8812, Test AUC: 0.8831


100%|██████████| 54/54 [00:05<00:00, 10.80it/s]


Epoch: 018, Training Accuracy: 0.7831, Val AUC: 0.8904, Test AUC: 0.8907


100%|██████████| 54/54 [00:05<00:00, 10.59it/s]


Epoch: 019, Training Accuracy: 0.7851, Val AUC: 0.8941, Test AUC: 0.8924


100%|██████████| 54/54 [00:04<00:00, 10.80it/s]


Epoch: 020, Training Accuracy: 0.7918, Val AUC: 0.8938, Test AUC: 0.8956


100%|██████████| 54/54 [00:05<00:00, 10.51it/s]


Epoch: 021, Training Accuracy: 0.7885, Val AUC: 0.8895, Test AUC: 0.8899


100%|██████████| 54/54 [00:05<00:00, 10.59it/s]


Epoch: 022, Training Accuracy: 0.7900, Val AUC: 0.8934, Test AUC: 0.8910


100%|██████████| 54/54 [00:04<00:00, 10.81it/s]


Epoch: 023, Training Accuracy: 0.7933, Val AUC: 0.8966, Test AUC: 0.8983


100%|██████████| 54/54 [00:05<00:00, 10.38it/s]


Epoch: 024, Training Accuracy: 0.7875, Val AUC: 0.8905, Test AUC: 0.8911


100%|██████████| 54/54 [00:05<00:00,  9.82it/s]


Epoch: 025, Training Accuracy: 0.7852, Val AUC: 0.8922, Test AUC: 0.8929


100%|██████████| 54/54 [00:05<00:00, 10.58it/s]


Epoch: 026, Training Accuracy: 0.7944, Val AUC: 0.8882, Test AUC: 0.8923


100%|██████████| 54/54 [00:05<00:00, 10.35it/s]


Epoch: 027, Training Accuracy: 0.7925, Val AUC: 0.8884, Test AUC: 0.8879


100%|██████████| 54/54 [00:05<00:00, 10.35it/s]


Epoch: 028, Training Accuracy: 0.7866, Val AUC: 0.8953, Test AUC: 0.8984


100%|██████████| 54/54 [00:05<00:00, 10.70it/s]


Epoch: 029, Training Accuracy: 0.7858, Val AUC: 0.8878, Test AUC: 0.8871


100%|██████████| 54/54 [00:05<00:00, 10.67it/s]


Epoch: 030, Training Accuracy: 0.7877, Val AUC: 0.8924, Test AUC: 0.8944


In [23]:
def calculate_standard_error(tensor):
    mean = torch.mean(tensor)
    std_dev = torch.std(tensor, unbiased=True)
    n = torch.numel(tensor)
    standard_error = std_dev / torch.sqrt(torch.tensor(n, dtype=torch.float))

    return standard_error

In [24]:
print(f'Mean Validation AUC: {torch.mean(torch.tensor(val_acc)):.4f} +/- {calculate_standard_error(torch.tensor(val_acc)):.4f}')
print(f'Mean Test AUC: {torch.mean(torch.tensor(test_acc)):.4f} +/- {calculate_standard_error(torch.tensor(test_acc)):.4f}')

Mean Validation AUC: 0.9002 +/- 0.0018
Mean Test AUC: 0.8959 +/- 0.0027
