# baseline calculation with cross entropy

In [21]:
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
from sklearn.metrics import balanced_accuracy_score

class PubMedGNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = torch.relu(x)
        x = self.conv2(x, edge_index)
        return x

def train_and_evaluate(data):
    model = PubMedGNN(data.x.shape[1], 64, data.y.max().item() + 1)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

    for epoch in range(200):
        model.train()
        optimizer.zero_grad()
        
        logits = model(data.x, data.edge_index)
        loss = F.cross_entropy(logits[data.train_mask], data.y[data.train_mask])
        
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():
        logits = model(data.x, data.edge_index)
        preds = logits[data.test_mask].argmax(dim=1)
        balanced_acc = balanced_accuracy_score(data.y[data.test_mask].cpu(), preds.cpu())

    return balanced_acc

# Load PubMed dataset
dataset = Planetoid(root='/tmp/PubMed', name='PubMed', transform=NormalizeFeatures())
data = dataset[0]

balanced_acc = train_and_evaluate(data)
print(f"Balanced Accuracy: {balanced_acc:.2f}")

Balanced Accuracy: 0.80


# Gib training on 2 layer architecture 

In [25]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import DataLoader

# Load PubMed dataset
dataset = Planetoid(root='/tmp/PubMed', name='PubMed')

class GIBModel(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, beta=0.1):
        super(GIBModel, self).__init__()
        self.beta = beta
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)
        self.prior_dist = torch.distributions.Normal(0, 1)  # Gaussian prior for Z

    def forward(self, x, edge_index):
        z1 = F.relu(self.conv1(x, edge_index))  # First GCN layer
        z2 = self.conv2(z1, edge_index)  # Second GCN layer
        return z1, z2

    def compute_gib_loss(self, z, y, prior_dist, target_dist, edge_index):
        # Cross-entropy for I(Y; Z_X^(L)) - Task relevance
        prediction_loss = F.cross_entropy(z, y)
        
        # KL divergence for I(D; Z_X^(L)) - Compression term
        kl_div = torch.distributions.kl_divergence(target_dist, prior_dist).mean()
        
        # Combine losses
        gib_loss = prediction_loss + self.beta * kl_div
        return gib_loss
from sklearn.metrics import balanced_accuracy_score

def evaluate_model(model, data):
    model.eval()
    with torch.no_grad():
        z1, z2 = model(data.x, data.edge_index)  # Forward pass
        predictions = z2.argmax(dim=1)  # Get predicted class labels
        true_labels = data.y.cpu().numpy()  # Ground truth labels
        pred_labels = predictions.cpu().numpy()  # Predicted labels

        # Calculate Balanced Accuracy
        balanced_acc = balanced_accuracy_score(true_labels, pred_labels)
        return balanced_acc

def train_gib_model(data, input_dim, hidden_dim, output_dim, beta=0.1, epochs=200, lr=0.01):
    model = GIBModel(input_dim, hidden_dim, output_dim, beta=beta).to(data.x.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()

        # Forward pass
        z1, z2 = model(data.x, data.edge_index)
        prior_dist = model.prior_dist
        target_dist = torch.distributions.Normal(z2.mean(), z2.std())  # Variational posterior

        # Compute GIB loss
        loss = model.compute_gib_loss(z2, data.y, prior_dist, target_dist, data.edge_index)
        loss.backward()
        optimizer.step()

        if epoch % 10 == 0:
            balanced_acc = evaluate_model(model, data)
            print(f'Epoch {epoch}, Loss: {loss.item()}, Balanced Accuracy: {balanced_acc:.4f}')
    
    return model

# Load data and run the training function
data = dataset[0]
input_dim = dataset.num_node_features
hidden_dim = 64
output_dim = dataset.num_classes
beta = 0.1  # Ideal beta value can be tuned
trained_model = train_gib_model(data, input_dim, hidden_dim, output_dim, beta=beta)

# Final Evaluation
final_balanced_accuracy = evaluate_model(trained_model, data)
print(f'Final Balanced Accuracy: {final_balanced_accuracy:.4f}')



Epoch 0, Loss: 1.5026302337646484, Balanced Accuracy: 0.3749
Epoch 10, Loss: 1.1825270652770996, Balanced Accuracy: 0.3726
Epoch 20, Loss: 0.9962648749351501, Balanced Accuracy: 0.5701
Epoch 30, Loss: 0.8259830474853516, Balanced Accuracy: 0.6610
Epoch 40, Loss: 0.6544902324676514, Balanced Accuracy: 0.8120
Epoch 50, Loss: 0.5252289175987244, Balanced Accuracy: 0.8356
Epoch 60, Loss: 0.4669717252254486, Balanced Accuracy: 0.8439
Epoch 70, Loss: 0.44023898243904114, Balanced Accuracy: 0.8525
Epoch 80, Loss: 0.42363640666007996, Balanced Accuracy: 0.8598
Epoch 90, Loss: 0.4130460321903229, Balanced Accuracy: 0.8652
Epoch 100, Loss: 0.4052874743938446, Balanced Accuracy: 0.8692
Epoch 110, Loss: 0.39930373430252075, Balanced Accuracy: 0.8716
Epoch 120, Loss: 0.3943099081516266, Balanced Accuracy: 0.8740
Epoch 130, Loss: 0.3900507688522339, Balanced Accuracy: 0.8765
Epoch 140, Loss: 0.3863217234611511, Balanced Accuracy: 0.8783
Epoch 150, Loss: 0.38299494981765747, Balanced Accuracy: 0.8801

# understanding the change with the tradeoff factor

In [None]:
# value evaluatio