In [1]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
import random
import numpy as np

import os
import pandas as pd
import sys
import torch

sys.path.append('/home/mrahma56/cs519/SSL_LLM_Node_Classification')
from TAGLAS import get_dataset



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from src.models import *

In [3]:
# Reproducibility helper: seeds every RNG this notebook touches.
def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to each random number generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning speed for run-to-run determinism.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False





In [4]:
# Run configuration. Seeds used across the reported runs: 1234, 4567, 7890.
SEED = 7890
set_seed(SEED)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Dataset short name (see dataset_key_dict: 'cora' or 'wikics') and the LLM whose
# generated labels are used for relabeling.
DATASET = 'wikics'
LLM_ID = "Llama-3"

In [5]:
# Map the short dataset names used in this notebook to their TAGLAS dataset keys.
dataset_key_dict = {
    'cora': 'cora_node',
    'wikics': 'wikics'
}


# Project root. Override it with the SSL_LLM_ROOT environment variable to run outside the
# original machine; os.path.join(..., "") guarantees the trailing separator that the plain
# string concatenations below rely on.
root_dir = os.path.join(
    os.environ.get("SSL_LLM_ROOT", "/home/mrahma56/cs519/SSL_LLM_Node_Classification/"), ""
)
taglas_dir = root_dir + "TAGLAS/"
llm_gen_dir = root_dir + "llm_gen_data/"
saved_model_dir = root_dir + "saved_models/"
saved_embedding_dir = root_dir + "saved_embeddings/"
embedding_model = "nvidia/NV-Embed-v2"
# Cached node embeddings are stored as <taglas key>_<model name without org prefix>.pt
embedding_path = saved_embedding_dir + f"{dataset_key_dict[DATASET]}_{embedding_model.split('/')[-1]}.pt"
print(embedding_path)

/home/mrahma56/cs519/SSL_LLM_Node_Classification/saved_embeddings/wikics_NV-Embed-v2.pt


In [6]:
def load_taglas_dataset(dataset_key="cora_node", unlabel_ratio=None, embedding_path=None, print_info=True):
    """Load a TAGLAS node-classification dataset and build labeled/unlabeled train masks.

    Args:
        dataset_key: TAGLAS dataset key. Only "cora_node" and "wikics" have split
            handling here.
        unlabel_ratio: fraction in [0, 1] of each class's training nodes to move from
            ``train_lb_mask`` into ``train_ulb_mask`` (stratified per class, chosen with
            ``torch.randperm``). ``None`` or 0 keeps every training node labeled.
        embedding_path: optional path to a saved feature tensor. When the file exists,
            it replaces ``data.x``.
        print_info: print labeled/unlabeled counts, overall and per class (only when
            ``unlabel_ratio`` is non-zero).

    Returns:
        The dataset's data object, reduced to the keys listed in ``required_keys``.

    Raises:
        ValueError: if ``dataset_key`` is unsupported or ``unlabel_ratio`` lies
            outside [0, 1].
    """
    if dataset_key not in ("cora_node", "wikics"):
        raise ValueError(
            f"Unsupported dataset_key {dataset_key!r}; expected 'cora_node' or 'wikics'"
        )
    if unlabel_ratio is not None and not 0 <= unlabel_ratio <= 1:
        raise ValueError(f"unlabel_ratio must be within [0, 1], got {unlabel_ratio}")

    # Load the dataset from TAGLAS
    dataset = get_dataset(dataset_key, root=taglas_dir)
    data = dataset._data
    splits = dataset.side_data['node_split']

    # The split tensors are laid out differently per dataset: cora_node uses row 0, while
    # wikics uses column 0 for train/val (presumably the first of several split variants
    # — TODO confirm) and its test split as-is.
    if dataset_key == "cora_node":
        data.train_lb_mask = splits['train'][0].clone()
        data.val_mask = splits['val'][0].clone()
        data.test_mask = splits['test'][0].clone()
    else:  # "wikics"
        data.train_lb_mask = splits['train'][:, 0].clone()
        data.val_mask = splits['val'][:, 0].clone()
        data.test_mask = splits['test'].clone()

    # Targets come from label_map. The original x is kept as x_text (presumably the node
    # text) and x_original becomes the feature matrix, unless a cached embedding exists.
    data.y = data.label_map
    data.x_text = data.x
    data.x = data.x_original
    if embedding_path is not None and os.path.exists(embedding_path):
        print("Loading embedding from: ", embedding_path)
        data.x = torch.load(embedding_path)

    data.num_classes = dataset.num_classes
    data.train_ulb_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
    # (class, num_labeled, num_unlabeled) per class. Defined up front so the reporting
    # below never hits an undefined name.
    class_label_counts = []

    if unlabel_ratio is not None and unlabel_ratio > 0:
        train_indices = data.train_lb_mask.nonzero(as_tuple=True)[0]
        train_labels = data.y[train_indices]

        for cls in range(data.num_classes):
            class_indices = train_indices[train_labels == cls]
            num_class_nodes = len(class_indices)

            # Hide int(unlabel_ratio * n) randomly chosen nodes of this class; the rest
            # keep their label.
            nodes_to_unlabel = int(unlabel_ratio * num_class_nodes)
            nodes_to_label = num_class_nodes - nodes_to_unlabel
            unlabeled_indices = class_indices[torch.randperm(num_class_nodes)[:nodes_to_unlabel]]

            data.train_ulb_mask[unlabeled_indices] = True
            class_label_counts.append((cls, nodes_to_label, nodes_to_unlabel))

        # Nodes moved to the unlabeled pool are no longer labeled training nodes.
        data.train_lb_mask[data.train_ulb_mask] = False

    if print_info and unlabel_ratio:
        print(f"Unlabeled ratio: {unlabel_ratio}")
        print(f"Labeled training nodes: {data.train_lb_mask.sum().item()}")
        print(f"Unlabeled training nodes: {data.train_ulb_mask.sum().item()}")

        print("\nClass-wise labeled and unlabeled counts:")
        for cls, num_labeled, num_unlabeled in class_label_counts:
            print(f"Class {cls}: Labeled = {num_labeled}, Unlabeled = {num_unlabeled}")

    # Retain only the required keys in the data object
    required_keys = [
        'x', 'y', 'train_lb_mask', 'train_ulb_mask',
        'val_mask', 'test_mask', 'num_classes',
        'num_features', 'x_text', 'edge_index', 'edge_attr'
    ]
    for k in list(data.keys()):
        if k not in required_keys:
            data.pop(k)

    return data

In [7]:
# Simulated noisy annotator for low-confidence samples (alternative to LLM labels).
def relabel_samples(low_conf_indices, data, gold_label_prob=0.7):
    """Return noisy labels for the given nodes.

    Each node keeps its gold label from ``data.y`` with probability ``gold_label_prob``.
    Otherwise it gets a class drawn uniformly from ``range(data.num_classes)``, which may
    coincide with the gold label.

    Args:
        low_conf_indices: 1-D index tensor of nodes to relabel.
        data: object exposing ``y`` (label tensor) and ``num_classes``.
        gold_label_prob: probability of returning the gold label (default 0.7).

    Returns:
        Label tensor with one entry per index, on the same device as ``data.y``.
    """
    num_classes = data.num_classes
    gold_labels = data.y[low_conf_indices]
    # Draw the random tensors on the labels' device; mixing CPU and CUDA tensors in
    # torch.where raises.
    device = gold_labels.device
    num_samples = len(low_conf_indices)

    random_probs = torch.rand(num_samples, device=device)
    random_labels = torch.randint(0, num_classes, (num_samples,), device=device)

    return torch.where(random_probs < gold_label_prob, gold_labels, random_labels)



In [8]:
# Look up LLM-generated labels for low-confidence samples.
def llm_label_samples(low_conf_indices, data, dataset="cora", llm_id="Llama-3", gen_dir=None):
    """Return the LLM-generated labels for the given node indices.

    Reads ``<gen_dir>/<dataset>_<llm_id>.tsv`` (``gen_dir`` defaults to the notebook-level
    ``llm_gen_dir``). Its ``llm_label`` column is assumed to hold one predicted class per
    node in node-index order — TODO confirm against the generation script.

    Labels that are missing, non-numeric, or outside ``[0, data.num_classes)`` fall back
    to class 0.

    Args:
        low_conf_indices: 1-D index tensor (or sequence) of nodes to look up.
        data: object exposing ``num_classes``.
        dataset: dataset short name used in the TSV file name.
        llm_id: LLM identifier used in the TSV file name.
        gen_dir: optional directory override for the TSV file.

    Returns:
        int64 tensor of labels, on the same device as ``low_conf_indices``.
    """
    num_classes = data.num_classes
    directory = llm_gen_dir if gen_dir is None else gen_dir
    llm_gen_file = os.path.join(directory, f"{dataset}_{llm_id}.tsv")
    df = pd.read_csv(llm_gen_file, sep='\t')

    # Coerce unparsable / missing entries to -1 so they hit the same invalid-label
    # fallback below, and force an integer dtype (a NaN would otherwise make the
    # column float and the labels unusable as class ids).
    raw_labels = pd.to_numeric(df['llm_label'], errors='coerce').fillna(-1).astype('int64')
    y_gen = torch.tensor(raw_labels.to_numpy(), dtype=torch.long)
    y_gen = torch.where((y_gen >= 0) & (y_gen < num_classes), y_gen, torch.zeros_like(y_gen))

    # The lookup table lives on the CPU; return the result where the indices live.
    idx = torch.as_tensor(low_conf_indices)
    return y_gen[idx.cpu()].to(idx.device)


In [9]:
def train_step_semi_supervised(model, data, optimizer, x, y, alpha=0.1, th=0.5, llm_label=False):
    """Run one optimisation step of the LLM-assisted self-training objective.

    loss = CE(labeled nodes) + alpha * CE(confident unlabeled nodes, their argmax pseudo-labels)

    The unlabeled branch runs only when ``llm_label`` is True and unlabeled nodes remain.
    In that branch, unlabeled nodes whose max softmax probability is <= ``th`` receive
    labels from ``llm_label_samples`` and move from ``train_ulb_mask`` to
    ``train_lb_mask``.

    Args:
        model: module called as ``model(x, data.edge_index)`` returning class logits.
        data: object with ``edge_index``, ``train_lb_mask`` and ``train_ulb_mask``; the
            two masks are replaced (not modified in place) when nodes are relabeled.
        optimizer: optimizer over ``model``'s parameters.
        x: node features.
        y: node labels. Never modified; a relabeled copy is returned instead.
        alpha: weight of the pseudo-label (consistency) loss.
        th: confidence threshold separating pseudo-labeled from LLM-labeled nodes.
        llm_label: enable the unlabeled-node branch for this step.

    Returns:
        (num_low_conf_samples, loss, labeled_loss, consistency_loss, y,
         train_lb_mask, train_ulb_mask) — the losses as Python floats, ``y`` as the
        (possibly relabeled) label tensor, and the current masks.
    """
    model.train()
    optimizer.zero_grad()

    logits = model(x, data.edge_index)
    out_prob = F.softmax(logits, dim=1)

    labeled_loss = torch.tensor(0.0, device=x.device)
    consistency_loss = torch.tensor(0.0, device=x.device)
    num_low_conf_samples = 0

    # Supervised loss on currently labeled nodes
    if data.train_lb_mask.sum() > 0:
        labeled_loss = F.cross_entropy(logits[data.train_lb_mask], y[data.train_lb_mask])

    # NOTE(review): the consistency term is gated by llm_label too, so with the caller's
    # schedule it only contributes on relabeling epochs — confirm this is intended.
    if data.train_ulb_mask.sum() > 0 and llm_label:
        ulb_indices = data.train_ulb_mask.nonzero(as_tuple=True)[0]
        ulb_prob = out_prob[ulb_indices]
        pseudo_labels = ulb_prob.argmax(dim=1)
        confidence_scores = ulb_prob.max(dim=1).values
        confident_mask = confidence_scores > th

        confident_indices = ulb_indices[confident_mask]
        low_conf_indices = ulb_indices[~confident_mask]
        num_low_conf_samples = len(low_conf_indices)

        # Consistency loss: fit confident unlabeled nodes to their own predictions
        if len(confident_indices) > 0:
            consistency_loss = F.cross_entropy(logits[confident_indices], pseudo_labels[confident_mask])

        # Low-confidence nodes get LLM labels and join the labeled set
        if num_low_conf_samples > 0:
            new_labels = llm_label_samples(low_conf_indices, data, dataset=DATASET, llm_id=LLM_ID)

            # Copy before writing: callers pass data.y here, so writing in place would
            # overwrite the ground-truth labels stored on ``data``.
            y = y.clone()
            y[low_conf_indices] = new_labels.to(device=y.device, dtype=y.dtype)

            new_train_lb_mask = data.train_lb_mask.clone()
            new_train_ulb_mask = data.train_ulb_mask.clone()
            new_train_lb_mask[low_conf_indices] = True
            new_train_ulb_mask[low_conf_indices] = False
            data.train_lb_mask = new_train_lb_mask
            data.train_ulb_mask = new_train_ulb_mask

    loss = labeled_loss + alpha * consistency_loss
    # With no labeled and no confident nodes both terms are grad-less constants:
    # there is nothing to optimise and backward() would raise.
    if loss.requires_grad:
        loss.backward()
        optimizer.step()

    return num_low_conf_samples, loss.item(), labeled_loss.item(), consistency_loss.item(), y, data.train_lb_mask, data.train_ulb_mask


In [10]:
# Validation step: loss and accuracy on the validation nodes.
def validation_step(model, data, x, y):
    """Evaluate ``model`` on ``data.val_mask`` without tracking gradients.

    Returns:
        (val_loss, val_acc) as Python floats.
    """
    model.eval()
    mask = data.val_mask
    with torch.no_grad():
        logits = model(x, data.edge_index)
        val_logits, val_targets = logits[mask], y[mask]
        loss = F.cross_entropy(val_logits, val_targets)
        n_correct = (val_logits.argmax(dim=1) == val_targets).sum()
        acc = n_correct / mask.sum()
    return loss.item(), acc.item()


In [11]:
# Test step
def test_step(model, data, x, y):
    model.eval()
    with torch.no_grad():
        out = model(x, data.edge_index)
        pred = out[data.test_mask].argmax(dim=1)
        acc = (pred == y[data.test_mask]).sum() / data.test_mask.sum()
    return acc.item()


In [12]:
# Main training loop (semi-supervised, LLM-assisted)
def train_model_semi_supervised(data, num_epochs=250, lr=0.01, hidden_channels=16, alpha=0.1, th=0.5, print_logs=True, llm_label_every=50):
    """Train a GCC model with LLM-assisted self-training and report test accuracy.

    Every ``llm_label_every`` epochs, the train step runs its unlabeled-node branch
    (pseudo-labels plus LLM relabeling; see ``train_step_semi_supervised``). The
    checkpoint with the best validation accuracy is saved to ``best_model_ssl.pth``
    under ``saved_model_dir`` and reloaded for testing.

    ``data.y``, ``data.train_lb_mask`` and ``data.train_ulb_mask`` are restored to their
    values at call time before returning, so the same ``data`` can be reused.

    Args:
        data: graph data with features, labels, masks, ``num_features``, ``num_classes``.
        num_epochs: number of training epochs (must be >= 1).
        lr: Adam learning rate.
        hidden_channels: hidden width of the GCC model.
        alpha: weight of the consistency loss.
        th: confidence threshold for pseudo-labeling vs. LLM relabeling.
        print_logs: print progress every 10 epochs and on the last epoch.
        llm_label_every: period, in epochs, of the LLM-labeling step; <= 0 disables it.

    Returns:
        (best_val_acc, test_acc).

    Raises:
        ValueError: if ``num_epochs < 1`` (no checkpoint would ever be written).
    """
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")

    model = GCC(num_features=data.num_features, hidden_channels=hidden_channels, num_classes=data.num_classes)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)

    # Training relabels nodes and swaps the train masks on ``data``; snapshot that state
    # so the caller's data object is left exactly as it was passed in.
    saved_y = data.y.clone()
    saved_lb_mask = data.train_lb_mask.clone()
    saved_ulb_mask = data.train_ulb_mask.clone()

    x, y = data.x, data.y

    # Start below any achievable accuracy so epoch 1 always writes a checkpoint. With 0,
    # a run that never beat 0 val accuracy would reload a stale file from an earlier run.
    best_val_acc = float('-inf')
    best_model_path = os.path.join(saved_model_dir, 'best_model_ssl.pth')

    try:
        for epoch in range(1, num_epochs + 1):
            llm_label = llm_label_every > 0 and epoch % llm_label_every == 0
            step_out = train_step_semi_supervised(model, data, optimizer, x, y, alpha, th, llm_label=llm_label)
            num_low_conf_samples, _, _, _, y, data.train_lb_mask, data.train_ulb_mask = step_out
            _, val_acc = validation_step(model, data, x, y)

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(), best_model_path)

            if print_logs and (epoch % 10 == 0 or epoch == num_epochs):
                print(f'Epoch: {epoch:03d},Low Conf Samples: {num_low_conf_samples}')

        model.load_state_dict(torch.load(best_model_path, weights_only=True))
        test_acc = test_step(model, data, x, y)
    finally:
        data.y = saved_y
        data.train_lb_mask = saved_lb_mask
        data.train_ulb_mask = saved_ulb_mask

    return best_val_acc, test_acc



In [13]:
def train_model_supervised(data, num_epochs=200, learning_rate=0.01, hidden_channels=16, print_logs=True):
    """Train a GCC model on the labeled training nodes only and report test accuracy.

    The checkpoint with the highest validation accuracy is saved to
    ``best_model_sup.pth`` under ``saved_model_dir`` and reloaded before testing.

    Args:
        data: graph data with ``x``, ``y``, ``edge_index``, ``train_lb_mask``,
            ``val_mask``, ``test_mask``, ``num_features`` and ``num_classes``.
        num_epochs: number of training epochs (must be >= 1).
        learning_rate: Adam learning rate.
        hidden_channels: hidden width of the GCC model.
        print_logs: print loss/validation metrics every 10 epochs.

    Returns:
        (best_val_acc, test_acc).

    Raises:
        ValueError: if ``num_epochs < 1`` (no checkpoint would ever be written).
    """
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")

    model = GCC(num_features=data.num_features, hidden_channels=hidden_channels, num_classes=data.num_classes)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    x, y = data.x, data.y
    train_mask = data.train_lb_mask

    # Start below any achievable accuracy so epoch 1 always writes a checkpoint. With 0.0,
    # a run that never beat 0 val accuracy would reload a stale file from an earlier run.
    best_val_acc = float('-inf')
    best_model_path = os.path.join(saved_model_dir, 'best_model_sup.pth')

    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()

        out = model(x, data.edge_index)
        loss = F.cross_entropy(out[train_mask], y[train_mask])
        loss.backward()
        optimizer.step()

        val_loss, val_acc = validation_step(model, data, x, y)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), best_model_path)

        if (epoch + 1) % 10 == 0 and print_logs:
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    # Evaluate the best-validation checkpoint, not the last epoch's weights.
    model.load_state_dict(torch.load(best_model_path, weights_only=True))
    test_acc = test_step(model, data, x, y)

    return best_val_acc, test_acc


In [14]:
from train_eval import run_ssl, run_supervised

In [15]:
# Sweep the fraction of training labels hidden from the model. For every ratio the
# dataset is reloaded via load_taglas_dataset, and both the supervised baseline and the
# SSL pipeline (run_supervised / run_ssl from train_eval) are evaluated on it.
# To benchmark the local GCC models instead, use
#   _, test_acc_sup = train_model_supervised(data_ssl, print_logs=False)
#   _, test_acc_ssl = train_model_semi_supervised(data_ssl, alpha=0.1, th=0.7, print_logs=False)
dataset_key = dataset_key_dict[DATASET]

results = []
for ulb_ratio in (0, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95):
    # Passing embedding_path=None falls back to the dataset's original features.
    data_ssl = load_taglas_dataset(
        dataset_key=dataset_key,
        unlabel_ratio=ulb_ratio,
        embedding_path=embedding_path,
        print_info=False,
    )

    test_acc_sup = run_supervised(data_ssl, device, dataset=DATASET, print_logs=False)
    test_acc_ssl = run_ssl(data_ssl, device, dataset=DATASET, print_logs=False)

    results.append(dict(ulb_ratio=ulb_ratio, supervised_acc=test_acc_sup, ssl_acc=test_acc_ssl))


  self.data, self.slices = torch.load(self.processed_paths[0])
  self.side_data = torch.load(self.processed_paths[1])


In [16]:
# Save the sweep results to a CSV file named after the dataset, LLM and seed so that
# runs with different seeds can be aggregated later.
results_dir = os.path.join(root_dir, "results")
# DataFrame.to_csv does not create missing parent directories.
os.makedirs(results_dir, exist_ok=True)
results_df = pd.DataFrame(results)
results_df.to_csv(os.path.join(results_dir, f"results_{DATASET}_{LLM_ID}_{SEED}.csv"), index=False)

In [17]:
# One summary line per unlabeled ratio: the accuracies returned by run_supervised
# (supervised baseline) and run_ssl (semi-supervised pipeline).
for r in results:
    print(f"Unlabeled Ratio: {r['ulb_ratio']:.2f}, Supervised Acc: {r['supervised_acc']:.4f}, SSL Acc: {r['ssl_acc']:.4f}")

Unlabeled Ratio: 0.00, Supervised Acc: 0.7866, SSL Acc: 0.7867
Unlabeled Ratio: 0.70, Supervised Acc: 0.7688, SSL Acc: 0.7580
Unlabeled Ratio: 0.75, Supervised Acc: 0.7347, SSL Acc: 0.7305
Unlabeled Ratio: 0.80, Supervised Acc: 0.7149, SSL Acc: 0.7346
Unlabeled Ratio: 0.85, Supervised Acc: 0.7407, SSL Acc: 0.7445
Unlabeled Ratio: 0.90, Supervised Acc: 0.7404, SSL Acc: 0.7077
Unlabeled Ratio: 0.95, Supervised Acc: 0.6658, SSL Acc: 0.6319
