In [None]:
import argparse
import os
import torch
from torch_geometric.loader import DataLoader
from src.loadData import GraphDataset
from src.utils import set_seed
import pandas as pd
import matplotlib.pyplot as plt
import logging
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler

from src.models import GNN, MLP, CompleteModel

# Set the random seed
set_seed()

def add_zeros(data):
    data.x = torch.zeros(data.num_nodes, dtype=torch.long)
    return data

import torch
import torch.nn.functional as F


def train(data_loader, model, optimizer, criterion, device, save_checkpoints, checkpoint_path, current_epoch):
    model.train()
    total_loss = 0
    for data in tqdm(data_loader, desc="Iterating training graphs", unit="batch"):
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Save checkpoints if required
    if save_checkpoints:
        checkpoint_file = f"{checkpoint_path}_epoch_{current_epoch + 1}.pth"
        torch.save(model.state_dict(), checkpoint_file)
        print(f"Checkpoint saved at {checkpoint_file}")

    return total_loss / len(data_loader)

def evaluate(data_loader, model, device, calculate_accuracy=False):
    model.eval()
    correct = 0
    total = 0
    predictions = []
    with torch.no_grad():
        for data in tqdm(data_loader, desc="Iterating eval graphs", unit="batch"):
            data = data.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            predictions.extend(pred.cpu().numpy())
            if calculate_accuracy:
                correct += (pred == data.y).sum().item()
                total += data.y.size(0)
    if calculate_accuracy:
        accuracy = correct / total
        return accuracy, predictions
    return predictions

def save_predictions(predictions, test_path):
    script_dir = os.path.dirname(os.path.abspath(__file__))
    submission_folder = os.path.join(script_dir, "submission")
    test_dir_name = os.path.basename(os.path.dirname(test_path))
    
    os.makedirs(submission_folder, exist_ok=True)
    
    output_csv_path = os.path.join(submission_folder, f"testset_{test_dir_name}.csv")
    
    test_graph_ids = list(range(len(predictions)))
    output_df = pd.DataFrame({
        "id": test_graph_ids,
        "pred": predictions
    })
    
    output_df.to_csv(output_csv_path, index=False)
    print(f"Predictions saved to {output_csv_path}")


def plot_training_progress(train_losses, train_accuracies, output_dir):
    epochs = range(1, len(train_losses) + 1)
    plt.figure(figsize=(12, 6))

    # Plot loss
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_losses, label="Training Loss", color='blue')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss per Epoch')

    # Plot accuracy
    plt.subplot(1, 2, 2)
    plt.plot(epochs, train_accuracies, label="Training Accuracy", color='green')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Training Accuracy per Epoch')

    # Save plots in the current directory
    os.makedirs(output_dir, exist_ok=True)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "training_progress.png"))
    plt.close()

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch
import torch.nn.functional as F


In [8]:
test_path = "./datasets/C/test.json.gz"
train_path = "./datasets/C/train.json.gz"
batch_size = 32

# Prepare test dataset and loader
test_dataset = GraphDataset(test_path, transform=add_zeros)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

train_dataset = GraphDataset(train_path, transform=add_zeros)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [8]:
print(test_dataset.get(0).edge_attr[10])


tensor([0.0000, 0.0000, 0.4910, 0.0000, 0.0000, 0.0990, 0.0540])


In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import NNConv, global_mean_pool
from src.conv import GINConv

class VGAE_MessagePassing(nn.Module):
    def __init__(self, in_channels, edge_attr_dim, hidden_dim, latent_dim):
        super().__init__()
        # Edge network to produce convolution weights from edge attributes        
        self.conv1 = GINConv(hidden_dim)
        self.conv2 = GINConv(hidden_dim)

        self.mu_layer = torch.nn.Linear(hidden_dim, latent_dim)
        self.logvar_layer = torch.nn.Linear(hidden_dim, latent_dim)

        self.edge_attr_decoder = nn.Sequential(
            nn.Linear(latent_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, edge_attr_dim)
        )

    def forward(self, x, edge_index, edge_attr):
        x = F.relu(self.conv1(x, edge_index, edge_attr))
        x = F.relu(self.conv2(x, edge_index, edge_attr))

        mu = self.mu_layer(x)
        logvar = self.logvar_layer(x)
        return mu, logvar

class VGAE(nn.Module):
    def __init__(self, in_channels, edge_attr_dim, hidden_dim, latent_dim, num_classes):
        super().__init__()
        self.encoder = VGAE_MessagePassing(in_channels, edge_attr_dim, hidden_dim, latent_dim)
        self.classifier = nn.Linear(latent_dim, num_classes)
        self.latent_dim = latent_dim

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, edge_index):
        adj_pred = torch.sigmoid(torch.mm(z, z.t()))
        
        edge_feat_input = torch.cat([z[edge_index[0]], z[edge_index[1]]], dim=1)
        edge_attr_pred = torch.sigmoid(self.edge_attr_decoder(edge_feat_input))
        
        return adj_pred, edge_attr_pred

    def forward(self, data, inference=False):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        print(x.shape, edge_index.shape, edge_attr.shape, batch.shape)
        # x = torch.ones()
        x = x.reshape((-1, 1))
        mu, logvar = self.encoder(x, edge_index, edge_attr)
        if not inference:
            z = self.reparameterize(mu, logvar)
        
        # Graph-level embedding via mean pooling of latent node embeddings
        graph_emb = global_mean_pool(z, batch)
        class_logits= self.classifier(graph_emb)
        return z, mu, logvar, class_logits
    
    def loss(self, z, mu, logvar, class_logits, data, alpha=1, beta=0.1, gamma=0.1, delta=0.1):
        classification_loss = F.cross_entropy(class_logits, data.y)

        adj_pred, edge_attr_pred = self.decode(z, data.edge_index)
        adj_true = torch.zeros_like(adj_pred)
        adj_true[data.edge_index[0], data.edge_index[1]]

        adj_loss = F.binary_cross_entropy(adj_pred, adj_true)
        edge_attr_loss = F.mse_loss(edge_attr_pred, data.edge_attr)

        kl_loss = -0.5 * torch.mean(torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))

        loss = (
            alpha * classification_loss +
            beta * adj_loss +
            gamma * edge_attr_loss +
            delta * kl_loss
        )

        return loss





model = VGAE(1, edge_attr_dim=7, hidden_dim=128, latent_dim=16, num_classes=6)
for data in train_loader:
    res = model(data)
    print(res)

torch.Size([8260]) torch.Size([2, 141262]) torch.Size([141262, 7]) torch.Size([8260])
(tensor([[ 1.0827, -0.2646,  0.8171,  ...,  0.7688, -0.3303, -0.7539],
        [ 0.3374,  0.4915,  0.5873,  ..., -0.2523, -1.4078,  0.1018],
        [ 1.5704, -0.5685, -0.1056,  ..., -0.3058,  0.7888,  1.5209],
        ...,
        [-0.6141, -0.4504,  0.6907,  ...,  0.6537, -0.1442,  0.2400],
        [-0.1318, -0.1694,  0.0414,  ..., -0.2727,  1.0073, -0.1876],
        [-0.4725, -0.1344,  0.3407,  ..., -0.2777, -0.9140,  0.3938]],
       grad_fn=<AddBackward0>), tensor([[ 0.2524, -0.2335,  0.0638,  ...,  0.5993,  0.1568, -0.4196],
        [ 0.0683, -0.0586,  0.0170,  ...,  0.1022,  0.0387, -0.0323],
        [ 0.6592, -0.3210, -0.0240,  ...,  1.2888,  0.1719, -0.7789],
        ...,
        [ 0.0686, -0.0879, -0.0084,  ...,  0.1451,  0.0221, -0.0109],
        [ 0.0678, -0.0865, -0.0071,  ...,  0.1467,  0.0222, -0.0095],
        [ 0.0682, -0.0897, -0.0070,  ...,  0.1430,  0.0254, -0.0117]],
       grad_f

KeyboardInterrupt: 

In [4]:
class GCELoss(nn.Module):
    def __init__(self, q=0.7, smoothing=0.1, temperature=2.0, num_classes=6):
        """
        Generalized Cross Entropy Loss
        Args:
            q: exponent hyperparameter, controls sensitivity to noise. 0.7 is a good default.
            num_classes: number of classes in your classification problem.
            reduction: 'mean' or 'sum'
        """
        super(GCELoss, self).__init__()
        assert q > 0 and q <= 1, "q should be in (0, 1]"
        self.q = q
        self.num_classes = num_classes
        self.smoothing = smoothing
        self.temperature = temperature

    def forward(self, logits, targets, reduction="mean"):
        """
        logits: [batch_size, num_classes] (raw output from the model)
        targets: [batch_size] (ground-truth labels)
        """
        probs = F.softmax(logits / self.temperature, dim=1)
        targets_one_hot = F.one_hot(targets, num_classes=self.num_classes).float()
        targets_one_hot = (1 - self.smoothing) * targets_one_hot + self.smoothing / self.num_classes

        # Get p_y^q for each sample
        pt = (probs * targets_one_hot).sum(dim=1)
        loss = (1 - pt ** self.q) / self.q

        if reduction == 'mean':
            return loss.mean()
        elif reduction == 'sum':
            return loss.sum()
        else:
            return loss
        
class EntropyLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = F.softmax(x, dim=1)
        x = x.mean(dim=0)
        return -(x * torch.log(x + 1e-6)).sum()

In [66]:
def compute_embeddings_and_preds(dataloader, embedding_model, classifier, device='cuda'):
    embedding_model.eval();  classifier.eval()

    all_embeddings = []
    all_pred_probs = []
    original_labels = []
    with torch.no_grad():
        for data in tqdm(dataloader):
            data = data.to(device)
            emb = embedding_model(data) 
            logits = classifier(emb)
            probs = torch.softmax(logits, dim=1)

            all_embeddings.append(emb.cpu())
            all_pred_probs.append(probs.cpu())
            original_labels.append(data.cpu().y)

    all_embeddings = torch.cat(all_embeddings, dim=0)
    all_pred_probs = torch.cat(all_pred_probs, dim=0) 
    original_labels = torch.cat(original_labels, dim=0)

    return all_embeddings, all_pred_probs, original_labels

def neighbor_aware_label_correction(embeddings, pred_probs, original_labels, mu=0.8, K=5, delta_prime=0.9, c=1, cmax=10, tau=0.5):
    print(f"Avg max pred {pred_probs.max(dim=1).values.mean().item()}")
    emb_norm = F.normalize(embeddings, p=2, dim=1) 
    sim_matrix = torch.mm(emb_norm, emb_norm.T) 
    sim_matrix = torch.exp(sim_matrix / tau) 

    topk_vals, topk_idx = torch.topk(sim_matrix, K+1, dim=1)
    topk_vals = topk_vals[:, 1:] 
    topk_idx = topk_idx[:, 1:]


    neighbors_preds = pred_probs[topk_idx] 
    weights = topk_vals / topk_vals.sum(dim=1, keepdim=True)
    weighted_neighbors_preds = (weights.unsqueeze(2) * neighbors_preds).sum(dim=1)
    qi = mu * pred_probs + (1 - mu) * weighted_neighbors_preds  
    print(f"Avg qi {qi.max(dim=1).values.mean().item()}")

    delta_c = delta_prime * (c / cmax)
    qi_max_class = qi.argmax(dim=1)

    clean_mask = (qi_max_class == original_labels)  
    max_qi_vals, _ = qi.max(dim=1)
    confident_mask = max_qi_vals > delta_c
    print(f"Updating {((confident_mask & (~clean_mask))*1.0).mean()} of the labels")
    updated_mask = clean_mask | confident_mask
    new_labels = torch.where(clean_mask, original_labels, qi_max_class)

    return updated_mask, new_labels, qi

def filter_dataset_with_label_correction(train_dataset, updated_mask, new_labels):
    filtered_dataset = []
    for i, sample in tqdm(enumerate(train_dataset)):
        if updated_mask[i].item():
            sample.y = new_labels[i].unsqueeze(0) 
            filtered_dataset.append(sample)
    return filtered_dataset

def add_signed_noise(h, gamma=0.01):
    raw_noise = torch.randn_like(h)
    normed_noise = F.normalize(raw_noise, p=2, dim=1)
    scaled_noise = gamma * normed_noise
    signed_noise = torch.abs(scaled_noise) * torch.sign(h)
    return h + signed_noise


def mixup_embeddings(embeddings, labels, alpha=0.1, beta=0.1):
    batch_size = embeddings.size(0)
    device = embeddings.device
    lam = torch.distributions.Beta(alpha, beta).sample([batch_size]).to(device).view(-1, 1)
    index = torch.randperm(batch_size).to(device)
    mixed_embeddings = lam * embeddings + (1 - lam) * embeddings[index]
    mixed_labels = lam * labels + (1 - lam) * labels[index]
    return mixed_embeddings, mixed_labels

def get_positive_mask(mixed_labels, threshold=0.05):
    diff = mixed_labels.unsqueeze(1) - mixed_labels.unsqueeze(0)
    l2_dist = torch.norm(diff, p=2, dim=2)
    pos_mask = (l2_dist < threshold).float()
    pos_mask.fill_diagonal_(0.0) 
    return pos_mask

def get_negatives(z):
    z_norm = F.normalize(z, p=2, dim=1)
    cosine_sim = torch.matmul(z_norm, z_norm.T) 
    mask = 1.0 - torch.eye(z.size(0), device=z.device)
    masked_sim = cosine_sim.masked_fill(mask == 0, float('-inf'))
    weights = F.softmax(masked_sim, dim=1)
    soft_negatives = torch.matmul(weights, z)
    return soft_negatives

class OmgContrastiveLoss(nn.Module):
    def __init__(self):
        super(OmgContrastiveLoss, self).__init__()

    def forward(self, pred, positive_mask, neg):
        pred = F.normalize(pred, dim=1)
        neg = F.normalize(neg, dim=1)

        dists = torch.norm(pred.unsqueeze(1) - pred.unsqueeze(0), p=2, dim=2)
        pos_sum = (dists * positive_mask).sum(dim=1) 
        pos_counts = torch.clamp(positive_mask.sum(dim=1), min=1.0)
        pos_avg = pos_sum / pos_counts

        neg_dists = torch.norm(pred - neg, p=2, dim=1)  

        # print(pos_avg.sum().item(), neg_dists.sum().item())
        loss_per_sample = pos_avg - neg_dists
        return loss_per_sample.mean()

class SupervisedSoftLabelLoss(nn.Module):
    def __init__(self):
        super(SupervisedSoftLabelLoss, self).__init__()

    def forward(self, y, y_pred):
        log_probs = F.log_softmax(y_pred, dim=1)
        loss = -(y*log_probs).sum(dim=1).mean() 
        return loss
    

In [67]:
test_dir_name = os.path.basename(os.path.dirname(test_path))


num_layer = 3
emb_dim = 64
drop_ratio = 0.5
num_epochs = 5
warm_up_epochs = 10
device = "cuda:0"
script_dir = "./"
num_checkpoints = 5
num_class = 6

embedding_model = GNN(gnn_type="gin", num_class=num_class, num_layer = num_layer, emb_dim = emb_dim, drop_ratio = drop_ratio, graph_pooling = "mean").to(device)
embedding_projector = MLP([emb_dim, 32, 16]).to(device)
classifier = MLP([emb_dim, num_class]).to(device)
model = CompleteModel(embedding_model, classifier)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0005)
contrastive_optimizer = torch.optim.Adam(list(embedding_model.parameters()) + list(embedding_projector.parameters()) + list(classifier.parameters()), lr=0.001, weight_decay=0.0005)

contrastive_loss = OmgContrastiveLoss()
supervised_label_loss = SupervisedSoftLabelLoss()

scaler = GradScaler()

In [68]:
criterion1 = GCELoss(q=0.7, num_classes=num_class)
criterion2 = EntropyLoss()

for epoch in range(warm_up_epochs):
    model.train(); embedding_projector.train()
    total_loss = 0.0
    for data in tqdm(train_loader):
        optimizer.zero_grad()
        data = data.to(device)
        logits = model(data)
        loss1, loss2 = criterion1(logits, data.y), criterion2(logits)
        # print(logits)
        # print(F.one_hot(logits.argmax(dim=1), num_classes=num_class).float().cpu().mean(dim=0), loss1, loss2)
        loss = loss1 - loss2
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    total_loss = total_loss / len(train_loader)
    model.eval();  embedding_projector.eval()
    p_dist = torch.zeros((num_class, ))
    with torch.no_grad():
        correct = 0
        total = 0
        for data in tqdm(train_loader):
            data = data.to(device)
            pred = model(data)
            pred = pred.argmax(dim=1)
            p_dist += F.one_hot(pred, num_classes=num_class).float().cpu().mean(dim=0)
            correct += (pred == data.y).sum().item()
            total += data.y.size(0)
        accuracy = correct / total

    print(f"Epoch {epoch + 1}/{warm_up_epochs}, Loss: {total_loss:.4f}, Train Acc: {accuracy:.4f}")
    print(f"Pred distribution: {(p_dist / len(train_loader)).tolist()}")

100%|██████████| 240/240 [00:13<00:00, 18.31it/s]
100%|██████████| 240/240 [00:11<00:00, 21.12it/s]


Epoch 1/10, Loss: -0.7829, Train Acc: 0.3594
Pred distribution: [0.11380208283662796, 0.14049479365348816, 0.6451823115348816, 0.04401041567325592, 0.005078124813735485, 0.0514322929084301]


100%|██████████| 240/240 [00:13<00:00, 17.67it/s]
100%|██████████| 240/240 [00:11<00:00, 20.66it/s]


Epoch 2/10, Loss: -0.8029, Train Acc: 0.1803
Pred distribution: [0.4455729126930237, 0.23033854365348816, 0.15078124403953552, 0.15026041865348816, 0.0005208333604969084, 0.02252604253590107]


100%|██████████| 240/240 [00:13<00:00, 17.72it/s]
100%|██████████| 240/240 [00:11<00:00, 20.66it/s]


Epoch 3/10, Loss: -0.8192, Train Acc: 0.1743
Pred distribution: [0.0054687499068677425, 0.12708333134651184, 0.19700521230697632, 0.009244791232049465, 0.04843749850988388, 0.6127604246139526]


100%|██████████| 240/240 [00:13<00:00, 17.72it/s]
100%|██████████| 240/240 [00:11<00:00, 20.29it/s]


Epoch 4/10, Loss: -0.8305, Train Acc: 0.1428
Pred distribution: [0.4123697876930237, 0.15781250596046448, 0.1087239608168602, 0.2652343809604645, 0.0403645820915699, 0.01549479179084301]


100%|██████████| 240/240 [00:13<00:00, 17.48it/s]
100%|██████████| 240/240 [00:11<00:00, 20.52it/s]


Epoch 5/10, Loss: -0.8390, Train Acc: 0.1711
Pred distribution: [0.359375, 0.2808593809604645, 0.07135416567325592, 0.23736979067325592, 0.03697916492819786, 0.01406249962747097]


100%|██████████| 240/240 [00:13<00:00, 17.62it/s]
100%|██████████| 240/240 [00:11<00:00, 20.45it/s]


Epoch 6/10, Loss: -0.8447, Train Acc: 0.3257
Pred distribution: [0.04830729216337204, 0.25924479961395264, 0.22643229365348816, 0.10820312798023224, 0.07109375298023224, 0.2867187559604645]


100%|██████████| 240/240 [00:13<00:00, 17.29it/s]
100%|██████████| 240/240 [00:11<00:00, 20.37it/s]


Epoch 7/10, Loss: -0.8571, Train Acc: 0.2706
Pred distribution: [0.24908854067325592, 0.23046875, 0.13854166865348816, 0.22604165971279144, 0.0572916679084301, 0.09856770932674408]


100%|██████████| 240/240 [00:13<00:00, 17.44it/s]
100%|██████████| 240/240 [00:11<00:00, 20.27it/s]


Epoch 8/10, Loss: -0.8634, Train Acc: 0.3418
Pred distribution: [0.041015625, 0.19869790971279144, 0.2604166567325592, 0.0559895820915699, 0.06549479067325592, 0.37838542461395264]


100%|██████████| 240/240 [00:13<00:00, 17.53it/s]
100%|██████████| 240/240 [00:11<00:00, 20.07it/s]


Epoch 9/10, Loss: -0.8767, Train Acc: 0.0960
Pred distribution: [0.7755208611488342, 0.02291666716337204, 0.03281249850988388, 0.12708333134651184, 0.02903645858168602, 0.012630208395421505]


100%|██████████| 240/240 [00:13<00:00, 17.45it/s]
100%|██████████| 240/240 [00:12<00:00, 19.90it/s]

Epoch 10/10, Loss: -0.8806, Train Acc: 0.4286
Pred distribution: [0.06145833432674408, 0.09934895485639572, 0.5854166746139526, 0.11861979216337204, 0.04869791492819786, 0.08645833283662796]





In [69]:
torch.save(model.state_dict(), "./develop_checkpoint/warmup_model.pth")

In [39]:
embedding_model = GNN(gnn_type="gin", num_class=num_class, num_layer = num_layer, emb_dim = emb_dim, drop_ratio = drop_ratio, graph_pooling = "mean").to(device)
embedding_projector = MLP([emb_dim, 32, 16]).to(device)
classifier = MLP([emb_dim, num_class]).to(device)
model = CompleteModel(embedding_model, classifier)
model.load_state_dict(torch.load("./develop_checkpoint/warmup_model.pth"))

<All keys matched successfully>

In [74]:
beta = 0.05

best_accuracy = 0.0
train_losses = []
train_accuracies = []

filtered_train_dataset = train_dataset

In [80]:
c_max = 1
for c in range(c_max):
    filtered_train_loader = DataLoader(filtered_train_dataset, batch_size=batch_size, shuffle=True)
    # embeddings, pred_probs, original_labels = compute_embeddings_and_preds(filtered_train_loader, embedding_model, classifier)
    # updated_mask, new_labels, _ = neighbor_aware_label_correction(embeddings, pred_probs, original_labels, c=c, cmax=c_max, K=10, mu=0.8)
    # print(f"Keeping {(updated_mask*1.).mean()} of previous data")
    # filtered_train_dataset = filter_dataset_with_label_correction(filtered_train_dataset, updated_mask, new_labels)
    # filtered_train_loader = DataLoader(filtered_train_dataset, batch_size=batch_size, shuffle=True)


    for epoch in range(num_epochs): 
        embedding_model.train();  embedding_projector.train(); classifier.train()
        total_loss = 0.0
        for data in tqdm(filtered_train_loader):
            contrastive_optimizer.zero_grad()
            data = data.to(device)
            label = F.one_hot(data.y, num_classes=num_class).float()
            embeddings = embedding_model(data)
            embeddings = add_signed_noise(embeddings)
            embeddings = F.normalize(embeddings, p=2, dim=1)
            z, y = mixup_embeddings(embeddings, label)
            positive_mask = get_positive_mask(y)
            negatives = get_negatives(z)

            emb_proj, neg_proj = embedding_projector(z), embedding_projector(negatives)
            y_pred = classifier(z)
            cl_loss, sup_loss = contrastive_loss(emb_proj, positive_mask, neg_proj), supervised_label_loss(y, y_pred)
            loss = beta * cl_loss + sup_loss
            # print(cl_loss, sup_loss)
            loss.backward()
            contrastive_optimizer.step()
            total_loss += loss.item()
        
        embedding_model.eval();  classifier.eval(); embedding_projector.train()
        p_dist = torch.zeros((num_class, ))
        with torch.no_grad():
            correct = 0
            total = 0
            for data in tqdm(train_loader):
                data = data.to(device)
                output = embedding_model(data)
                output = F.normalize(output, p=2, dim=1)
                pred = classifier(output)
                pred = pred.argmax(dim=1)
                p_dist += F.one_hot(pred, num_classes=num_class).float().cpu().mean(dim=0)
                correct += (pred == data.y).sum().item()
                total += data.y.size(0)
            accuracy = correct / total

        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss:.4f}, Train Acc: {accuracy:.4f}")
        print(f"Pred distribution: {(p_dist / len(train_loader)).tolist()}")

100%|██████████| 240/240 [00:13<00:00, 17.75it/s]
100%|██████████| 240/240 [00:11<00:00, 20.70it/s]


Epoch 1/5, Loss: 184.4079, Train Acc: 0.4576
Pred distribution: [0.02356770820915699, 0.39049479365348816, 0.26249998807907104, 0.02174479141831398, 0.02317708358168602, 0.27851563692092896]


100%|██████████| 240/240 [00:13<00:00, 17.46it/s]
100%|██████████| 240/240 [00:11<00:00, 20.62it/s]


Epoch 2/5, Loss: 176.3153, Train Acc: 0.6654
Pred distribution: [0.10208333283662796, 0.2727864682674408, 0.4216145873069763, 0.10078124701976776, 0.009114583022892475, 0.09361979365348816]


100%|██████████| 240/240 [00:13<00:00, 17.63it/s]
100%|██████████| 240/240 [00:11<00:00, 20.21it/s]


Epoch 3/5, Loss: 173.9217, Train Acc: 0.7362
Pred distribution: [0.0377604179084301, 0.19934895634651184, 0.5625, 0.07369791716337204, 0.05820312350988388, 0.06848958134651184]


100%|██████████| 240/240 [00:14<00:00, 16.92it/s]
100%|██████████| 240/240 [00:11<00:00, 20.16it/s]


Epoch 4/5, Loss: 174.8986, Train Acc: 0.7312
Pred distribution: [0.04374999925494194, 0.1744791716337204, 0.5354166626930237, 0.10533854365348816, 0.07434895634651184, 0.06666667014360428]


100%|██████████| 240/240 [00:13<00:00, 17.52it/s]
100%|██████████| 240/240 [00:11<00:00, 20.76it/s]

Epoch 5/5, Loss: 175.0320, Train Acc: 0.7453
Pred distribution: [0.05429687350988388, 0.24700520932674408, 0.46601563692092896, 0.09648437798023224, 0.03789062425494194, 0.0983072891831398]





In [57]:
torch.save(model.state_dict(), "./develop_checkpoint/contrastive_trained_model.pth")
torch.save(embedding_projector.state_dict(), "./develop_checkpoint/contrastive_trained_projector_model.pth")


In [79]:
predictions = evaluate(test_loader, model, device, calculate_accuracy=False)
submission_folder = os.path.join(script_dir, "submission")
test_dir_name = os.path.basename(os.path.dirname(test_path))

os.makedirs(submission_folder, exist_ok=True)

output_csv_path = os.path.join(submission_folder, f"testset_{test_dir_name}.csv")

test_graph_ids = list(range(len(predictions)))
output_df = pd.DataFrame({
    "id": test_graph_ids,
    "pred": predictions
})

output_df.to_csv(output_csv_path, index=False)
print(f"Predictions saved to {output_csv_path}")

Iterating eval graphs: 100%|██████████| 48/48 [00:02<00:00, 21.72batch/s]

Predictions saved to ./submission/testset_C.csv



