In [1]:
# Optional clean uninstall, in case incompatible builds are already installed:
# %pip uninstall -y torch-scatter torch-sparse torch-geometric pyg-lib

# Install the PyG extension wheels built for torch 2.5.1 + CUDA 12.1.
# `%pip` (instead of `!pip`) installs into the environment of the running kernel.
# NOTE(review): torch-geometric is not pinned to a version here; pin it once a
# known-good release has been confirmed for this notebook.
%pip install -q torch-scatter -f https://data.pyg.org/whl/torch-2.5.1+cu121.html
%pip install -q torch-sparse -f https://data.pyg.org/whl/torch-2.5.1+cu121.html
%pip install -q pyg-lib -f https://data.pyg.org/whl/torch-2.5.1+cu121.html
%pip install -q torch-geometric

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.9/10.9 MB[0m [31m46.8 MB/s[0m eta [36m0:00:00[0m00:01[0m0:01[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.1/5.1 MB[0m [31m39.2 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m29.6 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25h

In [2]:
import argparse
import copy
import os

import networkx as nx
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import GCNConv, SAGEConv, GATConv, APPNP

TODO:
- Implement the teacher network architecture and training, along with flags for the dataset and the teacher model architecture (Chanikya and Nithin).
  Example: `python3 train_teacher.py --dataset=cora --model=SAGE --epochs=100 --lr=0.01`. Add flags for other hyperparameters if necessary (Chanikya and Nithin).
- Other teacher model architectures: GCN, GAT, APPNP (Chanikya and Nithin, plus others based on availability).

In [53]:
def add_inductive_settings(data, spr=0.2, generator=None):
    """Split the test nodes into an observed part and an inductive (held-out) part.

    A random fraction ``spr`` of the nodes selected by ``data.test_mask`` is
    marked as inductive; the remaining test nodes are marked as observed.
    Every edge with an inductive endpoint is removed to build
    ``data.ind_edge_index``.

    Args:
        data: Graph object exposing ``test_mask`` (boolean tensor),
            ``num_nodes`` and ``edge_index`` (tensor that unpacks into a source
            row and a target row). It is modified in place.
        spr: Fraction of the test nodes to hold out as inductive; must lie in
            ``[0, 1]``.
        generator: Optional ``torch.Generator`` used for the random
            permutation, so the split can be reproduced. ``None`` uses the
            global RNG.

    Returns:
        The same ``data`` object with ``observed_mask``, ``inductive_mask``
        and ``ind_edge_index`` attached.

    Raises:
        ValueError: If ``spr`` is outside ``[0, 1]``.
    """
    # A negative ratio would silently slice from the wrong end of the
    # permutation, so reject out-of-range values up front.
    if not 0.0 <= spr <= 1.0:
        raise ValueError(f"spr must be in [0, 1], got {spr}")

    device = data.test_mask.device
    unlabeled_indices = torch.where(data.test_mask)[0]

    num_unlabeled = unlabeled_indices.numel()
    num_inductive = int(spr * num_unlabeled)
    # The permutation is drawn on CPU and then moved next to the indices.
    perm = torch.randperm(num_unlabeled, generator=generator).to(device)
    inductive_indices = unlabeled_indices[perm[:num_inductive]]
    observed_indices = unlabeled_indices[perm[num_inductive:]]

    # Allocate the masks on the same device as the indices that fill them.
    data.observed_mask = torch.zeros(data.num_nodes, dtype=torch.bool, device=device)
    data.inductive_mask = torch.zeros(data.num_nodes, dtype=torch.bool, device=device)
    data.observed_mask[observed_indices] = True
    data.inductive_mask[inductive_indices] = True

    edge_index = data.edge_index
    src, dst = edge_index

    # Keep an edge only when neither endpoint is an inductive node.
    inductive = data.inductive_mask.to(edge_index.device)
    keep = ~(inductive[src] | inductive[dst])
    data.ind_edge_index = edge_index[:, keep]

    return data

def load_data(dataset, setting="tran"):
    """Load a dataset and return its pieces as a flat tuple.

    Only ``"cora"`` (the Planetoid Cora graph, stored under ``./Cora``) is
    supported.

    Args:
        dataset: Name of the dataset to load; must be ``"cora"``.
        setting: ``"ind"`` applies ``add_inductive_settings`` and reports the
            inductive mask as the test mask; any other value keeps the
            dataset's own test mask.

    Returns:
        A 9-tuple ``(num_features, num_classes, x, y, edge_index,
        ind_edge_index, train_mask, val_mask, test_mask)``. Unless ``setting``
        is ``"ind"``, ``ind_edge_index`` is an empty list placeholder.

    Raises:
        ValueError: If ``dataset`` is not a supported name. Previously the
            function fell through and returned ``None``, which only surfaced
            as an unpacking error in the caller.
    """
    if dataset != "cora":
        raise ValueError(f"Unsupported dataset {dataset!r}; only 'cora' is available")

    planetoid = Planetoid(root='./Cora', name='Cora')
    data = planetoid[0]
    # Placeholders so these attributes exist in both settings.
    data.ind_edge_index = []
    data.observed_mask = []
    data.inductive_mask = []
    test_mask = data.test_mask
    if setting == "ind":
        data = add_inductive_settings(data)
        test_mask = data.inductive_mask

    return (
        planetoid.num_features,
        planetoid.num_classes,
        data.x,
        data.y,
        data.edge_index,
        data.ind_edge_index,
        data.train_mask,
        data.val_mask,
        test_mask,
    )

In [28]:
class GCN(nn.Module):
    """Multi-layer graph convolutional network built from ``GCNConv`` layers.

    With ``num_layers == 1`` the model is a single ``input_dim -> output_dim``
    convolution. Otherwise it is ``input_dim -> hidden_dim``, then
    ``num_layers - 2`` convolutions of ``hidden_dim -> hidden_dim`` and a final
    ``hidden_dim -> output_dim`` convolution. Every layer except the last is
    followed by the optional normalisation, the activation and dropout.

    Args:
        num_layers: Number of convolution layers; must be at least 1.
        input_dim: Size of the input node features.
        hidden_dim: Size of the hidden representations.
        output_dim: Size of the final output per node.
        dropout_ratio: Dropout probability applied after each hidden layer.
        activation: Callable applied after each hidden layer.
        norm_type: ``"batch"`` adds ``BatchNorm1d`` and ``"layer"`` adds
            ``LayerNorm`` after each hidden layer; ``"none"`` adds nothing.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(
        self,
        num_layers,
        input_dim,
        hidden_dim,
        output_dim,
        dropout_ratio,
        activation,
        norm_type="none"
    ):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        self.num_layers = num_layers
        self.norm_type = norm_type
        self.dropout = nn.Dropout(dropout_ratio)
        self.activation = activation

        self.layers = nn.ModuleList()
        self.norms = nn.ModuleList()

        if num_layers == 1:
            self.layers.append(GCNConv(input_dim, output_dim))
        else:
            self.layers.append(GCNConv(input_dim, hidden_dim))
            if norm_type == "batch":
                self.norms.append(nn.BatchNorm1d(hidden_dim))
            elif norm_type == "layer":
                self.norms.append(nn.LayerNorm(hidden_dim))

            for _ in range(num_layers - 2):
                self.layers.append(GCNConv(hidden_dim, hidden_dim))
                if norm_type == "batch":
                    self.norms.append(nn.BatchNorm1d(hidden_dim))
                elif norm_type == "layer":
                    self.norms.append(nn.LayerNorm(hidden_dim))

            self.layers.append(GCNConv(hidden_dim, output_dim))

    def forward(self, x, edge_index):
        """Run the network.

        Args:
            x: Node feature tensor passed to the first layer.
            edge_index: Edge index passed unchanged to every layer.

        Returns:
            A pair ``(embedding, out)``. ``out`` is the output of the last
            layer. ``embedding`` is the last hidden representation after
            normalisation, activation and dropout; for a single-layer model
            there is no hidden layer, so ``out`` is returned in both slots
            instead of raising ``IndexError``.
        """
        last_hidden = None
        h = x
        for l, layer in enumerate(self.layers):
            h = layer(h, edge_index)
            if l != self.num_layers - 1:
                if self.norm_type != "none":
                    h = self.norms[l](h)
                h = self.activation(h)
                h = self.dropout(h)
                last_hidden = h
        if last_hidden is None:
            # Single-layer model: no hidden representation exists.
            last_hidden = h
        return last_hidden, h

In [29]:
class SAGE(nn.Module):
    """Multi-layer GraphSAGE network built from ``SAGEConv`` layers.

    With ``num_layers == 1`` the model is a single ``input_dim -> output_dim``
    convolution. Otherwise it is ``input_dim -> hidden_dim``, then
    ``num_layers - 2`` convolutions of ``hidden_dim -> hidden_dim`` and a final
    ``hidden_dim -> output_dim`` convolution. Every layer except the last is
    followed by the optional normalisation, the activation and dropout.

    Args:
        num_layers: Number of convolution layers; must be at least 1.
        input_dim: Size of the input node features.
        hidden_dim: Size of the hidden representations.
        output_dim: Size of the final output per node.
        dropout_ratio: Dropout probability applied after each hidden layer.
        activation: Callable applied after each hidden layer.
        norm_type: ``"batch"`` adds ``BatchNorm1d`` and ``"layer"`` adds
            ``LayerNorm`` after each hidden layer; ``"none"`` adds nothing.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(
        self,
        num_layers,
        input_dim,
        hidden_dim,
        output_dim,
        dropout_ratio,
        activation,
        norm_type="none",
    ):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.norm_type = norm_type
        self.activation = activation
        self.dropout = nn.Dropout(dropout_ratio)
        self.layers = nn.ModuleList()
        self.norms = nn.ModuleList()

        if num_layers == 1:
            self.layers.append(SAGEConv(input_dim, output_dim))
        else:
            self.layers.append(SAGEConv(input_dim, hidden_dim))
            if self.norm_type == "batch":
                self.norms.append(nn.BatchNorm1d(hidden_dim))
            elif self.norm_type == "layer":
                self.norms.append(nn.LayerNorm(hidden_dim))

            for _ in range(num_layers - 2):
                self.layers.append(SAGEConv(hidden_dim, hidden_dim))
                if self.norm_type == "batch":
                    self.norms.append(nn.BatchNorm1d(hidden_dim))
                elif self.norm_type == "layer":
                    self.norms.append(nn.LayerNorm(hidden_dim))

            self.layers.append(SAGEConv(hidden_dim, output_dim))

    def forward(self, x, edge_index):
        """Run the network on a (sub)graph.

        Args:
            x: Node feature tensor passed to the first layer.
            edge_index: Edge index passed unchanged to every layer.

        Returns:
            A pair ``(embedding, out)``. ``out`` is the output of the last
            layer. ``embedding`` is the raw output of the last hidden
            convolution (taken before normalisation, activation and dropout);
            for a single-layer model there is no hidden layer, so ``out`` is
            returned in both slots instead of raising ``IndexError``.
        """
        h = x
        last_hidden = None
        for l, layer in enumerate(self.layers):
            h = layer(h, edge_index)
            if l != self.num_layers - 1:
                last_hidden = h
                if self.norm_type != "none":
                    h = self.norms[l](h)
                h = self.activation(h)
                h = self.dropout(h)
        if last_hidden is None:
            # Single-layer model: no hidden representation exists.
            last_hidden = h
        return last_hidden, h

    @torch.no_grad()
    def inference(self, x_all, edge_index, batch_size=1024, device="cuda"):
        """Layer-wise full-graph inference using mini-batches (for large graphs).

        Each layer is applied to every node before the next layer starts. For
        each mini-batch of seed nodes the full one-hop neighbourhood is
        sampled, the layer is applied to that subgraph, and only the rows that
        belong to the seed nodes are kept.

        The module is switched to eval mode for the duration of the call (so
        dropout is inactive and batch norm uses running statistics) and its
        previous mode is restored afterwards.

        Args:
            x_all: Feature tensor with one row per node.
            edge_index: Edge index of the whole graph.
            batch_size: Number of seed nodes per mini-batch.
            device: Device used for the computation; the model's layers are
                expected to already live on it.

        Returns:
            Tensor with one row per node holding the output of the last layer.
        """
        num_nodes = x_all.size(0)
        # NeighborLoader needs a graph object rather than a tuple of tensors.
        # Only the structure is stored here; features are gathered by node id
        # below because they change after every layer.
        graph = Data(edge_index=edge_index.cpu(), num_nodes=num_nodes)
        loader = NeighborLoader(
            graph,
            input_nodes=torch.arange(num_nodes),
            num_neighbors=[-1],  # full neighbors
            batch_size=batch_size,
            shuffle=False,
        )

        was_training = self.training
        self.eval()
        try:
            x = x_all.to(device)
            for l, layer in enumerate(self.layers):
                is_last = l == self.num_layers - 1
                new_x = torch.zeros(
                    num_nodes,
                    self.output_dim if is_last else self.hidden_dim,
                    device=device,
                )

                for batch in loader:
                    n_id = batch.n_id.to(device)
                    h = layer(x[n_id], batch.edge_index.to(device))
                    # The seed nodes are the first `batch_size` rows; the
                    # remaining rows are neighbours that lack their own
                    # neighbourhood in this subgraph.
                    h = h[:batch.batch_size]

                    if not is_last:
                        if self.norm_type != "none":
                            h = self.norms[l](h)
                        h = self.activation(h)
                        h = self.dropout(h)

                    new_x[n_id[:batch.batch_size]] = h

                x = new_x
        finally:
            self.train(was_training)
        return x
# For small and medium datasets (a few thousand nodes), call model() in eval mode.
# For large ones (100k to millions of nodes), use inference().

In [30]:
class GAT(nn.Module):
    """Multi-layer graph attention network built from ``GATConv`` layers.

    ``hidden_dim`` is integer-divided by ``num_heads`` to obtain the per-head
    width. Every layer except the last uses ``num_heads`` heads whose outputs
    are concatenated; the last layer uses a single head without concatenation
    and produces ``output_dim`` values per node.

    Args:
        num_layers: Number of attention layers; must be greater than 1.
        input_dim: Size of the input node features.
        hidden_dim: Total hidden width, shared between the heads.
        output_dim: Size of the final output per node.
        dropout_ratio: Dropout probability applied to the input of every layer.
        activation: Callable applied after every layer except the last.
        num_heads: Number of attention heads in the non-final layers.
        attn_drop: Value passed to ``GATConv`` as ``dropout``.
        negative_slope: Value passed to ``GATConv`` as ``negative_slope``.
        residual: Accepted for interface compatibility; this implementation
            does not read it.
    """

    def __init__(
            self,
            num_layers,
            input_dim,
            hidden_dim,
            output_dim,
            dropout_ratio,
            activation=F.relu,
            num_heads=8,
            attn_drop=0.3,
            negative_slope=0.2,
            residual=False,
    ):
        super().__init__()

        assert num_layers > 1

        per_head_dim = hidden_dim // num_heads
        self.num_layers = num_layers
        self.layers = nn.ModuleList()
        self.activation = activation
        self.dropout = nn.Dropout(dropout_ratio)

        # One head count per layer: multi-head everywhere except the output layer.
        head_counts = [num_heads] * (num_layers - 1) + [1]

        in_channels = input_dim
        for depth, n_heads in enumerate(head_counts):
            is_output = depth == num_layers - 1
            self.layers.append(
                GATConv(
                    in_channels=in_channels,
                    out_channels=output_dim if is_output else per_head_dim,
                    heads=n_heads,
                    dropout=attn_drop,
                    negative_slope=negative_slope,
                    concat=not is_output,
                )
            )
            # Concatenated heads determine the width seen by the next layer.
            in_channels = per_head_dim * n_heads

    def forward(self, x, edge_index):
        """Run the network.

        Dropout is applied to the input of every layer and the activation
        follows every layer except the last.

        Returns:
            A pair ``(embedding, out)``: the activated output of the last
            hidden layer and the output of the final layer.
        """
        hidden_states = []
        h = x
        last = self.num_layers - 1
        for depth, conv in enumerate(self.layers):
            h = conv(self.dropout(h), edge_index)
            if depth == last:
                continue
            h = self.activation(h)
            hidden_states.append(h)
        return hidden_states[-1], h

In [31]:
class APPNP_Model(nn.Module):
    """MLP followed by a single APPNP propagation step.

    The node features first pass through ``num_layers`` linear layers (with
    optional normalisation, activation and dropout after every layer but the
    last); the result is then propagated over the graph by ``APPNP``.

    Args:
        num_layers: Number of linear layers; must be at least 1.
        input_dim: Size of the input node features.
        hidden_dim: Size of the hidden representations.
        output_dim: Size of the final output per node.
        dropout_ratio: Dropout probability applied after each hidden layer.
        activation: Callable applied after each hidden layer.
        norm_type: ``"batch"`` adds ``BatchNorm1d`` and ``"layer"`` adds
            ``LayerNorm`` after each hidden layer; ``"none"`` adds nothing.
        edge_drop: Value passed to ``APPNP`` as ``dropout``.
        alpha: Value passed to ``APPNP`` as ``alpha``.
        k: Value passed to ``APPNP`` as ``K``.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(
            self,
            num_layers,
            input_dim,
            hidden_dim,
            output_dim,
            dropout_ratio,
            activation=F.relu,
            norm_type="none",
            edge_drop=0,
            alpha=0.1,
            k=10,
    ):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        self.num_layers = num_layers
        self.norm_type = norm_type
        self.activation = activation
        self.dropout = nn.Dropout(dropout_ratio)
        self.layers = nn.ModuleList()
        self.norms = nn.ModuleList()

        # Input layer
        if num_layers == 1:
            self.layers.append(nn.Linear(input_dim, output_dim))
        else:
            self.layers.append(nn.Linear(input_dim, hidden_dim))
            if self.norm_type == "batch":
                self.norms.append(nn.BatchNorm1d(hidden_dim))
            elif self.norm_type == "layer":
                self.norms.append(nn.LayerNorm(hidden_dim))

            # Hidden layers
            for _ in range(num_layers - 2):
                self.layers.append(nn.Linear(hidden_dim, hidden_dim))
                if self.norm_type == "batch":
                    self.norms.append(nn.BatchNorm1d(hidden_dim))
                elif self.norm_type == "layer":
                    self.norms.append(nn.LayerNorm(hidden_dim))

            # Output layer
            self.layers.append(nn.Linear(hidden_dim, output_dim))

        self.propagate = APPNP(K=k, alpha=alpha, dropout=edge_drop)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every linear layer that exposes ``reset_parameters``."""
        for layer in self.layers:
            if hasattr(layer, "reset_parameters"):
                layer.reset_parameters()

    def forward(self, x, edge_index):
        """Run the MLP and propagate its output over the graph.

        Args:
            x: Node feature tensor passed to the first linear layer.
            edge_index: Edge index used only by the propagation step.

        Returns:
            A pair ``(embedding, out)``. ``out`` is the propagated output.
            ``embedding`` is the raw output of the last hidden linear layer
            (taken before normalisation, activation and dropout); for a
            single-layer model there is no hidden layer, so ``out`` is returned
            in both slots instead of raising ``IndexError``.
        """
        last_hidden = None
        h = x

        for l, layer in enumerate(self.layers):
            h = layer(h)

            if l != self.num_layers - 1:
                last_hidden = h
                if self.norm_type != "none":
                    h = self.norms[l](h)
                h = self.activation(h)
                h = self.dropout(h)

        h = self.propagate(h, edge_index)
        if last_hidden is None:
            # Single-layer model: no hidden representation exists.
            last_hidden = h
        return last_hidden, h

In [35]:
def train_sage(model, loader, optimizer, criterion, device, homo=True):
    """Train ``model`` for one pass over a mini-batch loader.

    For every batch the model is run on the sampled subgraph and the loss is
    computed only on the first ``batch.batch_size`` rows (the seed nodes).

    Args:
        model: Module returning a pair whose second item holds the predictions.
        loader: Iterable of batches exposing ``x``, ``y``, ``batch_size`` and
            either ``edge_index`` or ``edge_index_dict``.
        optimizer: Optimizer stepping the model parameters.
        criterion: Loss taking ``(predictions, targets)``.
        device: Device each batch is moved to.
        homo: When true use ``batch.edge_index``; otherwise use the first
            entry of ``batch.edge_index_dict``.

    Returns:
        Sum of the per-batch losses divided by ``len(loader)``.
    """
    model.train()
    batch_losses = []

    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        features = batch.x
        num_seeds = batch.batch_size
        targets = batch.y[:num_seeds]  # seed nodes come first in the batch

        if not homo:
            first_relation = list(batch.edge_index_dict.keys())[0]
            edges = batch.edge_index_dict[first_relation]
        else:
            edges = batch.edge_index

        _, scores = model(features, edges)
        batch_loss = criterion(scores[:batch.batch_size], targets)

        batch_loss.backward()
        optimizer.step()
        batch_losses.append(batch_loss.item())

    return sum(batch_losses) / len(loader)


@torch.no_grad()
def evaluate_sage(model, loader, device, homo=True):
    """Compute accuracy over a mini-batch loader, counting seed nodes only.

    Args:
        model: Module returning a pair whose second item holds the predictions.
        loader: Iterable of batches exposing ``x``, ``y``, ``batch_size`` and
            either ``edge_index`` or ``edge_index_dict``.
        device: Device each batch is moved to.
        homo: When true use ``batch.edge_index``; otherwise use the first
            entry of ``batch.edge_index_dict``.

    Returns:
        Number of correctly classified seed nodes divided by the number of
        seed nodes seen.
    """
    model.eval()
    num_correct, num_seen = 0, 0

    for batch in loader:
        batch = batch.to(device)
        features = batch.x
        targets = batch.y[:batch.batch_size]  # seed nodes come first in the batch

        if not homo:
            first_relation = list(batch.edge_index_dict.keys())[0]
            edges = batch.edge_index_dict[first_relation]
        else:
            edges = batch.edge_index

        _, scores = model(features, edges)
        predicted = scores[:batch.batch_size].argmax(dim=1)

        num_correct += (predicted == targets).sum().item()
        num_seen += targets.size(0)

    return num_correct / num_seen

def train(model, data, edge_index, labels, train_mask, optimizer, criterion):
    """Perform one full-batch optimisation step.

    The model is run on all nodes, but the loss is computed only on the
    entries selected by ``train_mask``.

    Returns:
        The loss of this step as a Python number.
    """
    model.train()
    optimizer.zero_grad()

    _, scores = model(data, edge_index)
    supervised_scores = scores[train_mask]
    supervised_labels = labels[train_mask]
    step_loss = criterion(supervised_scores, supervised_labels)

    step_loss.backward()
    optimizer.step()
    return step_loss.item()

@torch.no_grad()
def evaluate(model, data, edge_index, labels, idx):
    """Compute full-batch accuracy on the nodes selected by ``idx``.

    Args:
        model: Module returning a pair whose second item holds the predictions.
        data: Node features passed to the model.
        edge_index: Edge index passed to the model.
        labels: Ground-truth label tensor, indexable by ``idx``.
        idx: Boolean mask or index tensor selecting the nodes to score.

    Returns:
        Accuracy as a Python ``float``; ``nan`` when ``idx`` selects no node.
    """
    model.eval()
    _, out = model(data, edge_index)
    targets = labels[idx]
    # Count the selected nodes from the selection itself: this works for both
    # masks and index tensors and avoids iterating over the tensor in Python.
    num_selected = targets.numel()
    if num_selected == 0:
        return float("nan")
    pred = out[idx].argmax(dim=1)
    correct = (pred == targets).sum().item()
    return correct / num_selected
    
def save_tensors(emb_t, z_soft, output_dir):
    """Save the teacher's embeddings and label scores into ``output_dir``.

    The directory is created if it does not exist yet.

    Args:
        emb_t: Object written to ``embeddings.pt``.
        z_soft: Object written to ``label_scores.pt``.
        output_dir: Directory receiving both files.
    """
    os.makedirs(output_dir, exist_ok=True)
    torch.save(emb_t, os.path.join(output_dir, "embeddings.pt"))
    torch.save(z_soft, os.path.join(output_dir, "label_scores.pt"))
# Directory intended for the teacher outputs; it is created up front.
output_dir = "./teacher_outputs"
import os
# exist_ok avoids the separate existence check (and the race between check and create).
os.makedirs(output_dir, exist_ok=True)

In [36]:
def run_SAGE(data, num_features, num_classes, setting="tran"):
    """Train a two-layer GraphSAGE teacher with neighbour sampling.

    Training samples up to 5 neighbours per hop with a batch size of 32;
    validation and test use full two-hop neighbourhoods. The model is trained
    for 200 epochs with Adam (lr 0.001, weight decay 5e-4); test accuracy is
    reported on the first epoch and every tenth one.

    Args:
        data: Graph object exposing ``x``, ``edge_index``, ``train_mask``,
            ``val_mask`` and ``test_mask``; in the inductive setting also
            ``inductive_mask`` and ``ind_edge_index``.
        num_features: Size of the input node features.
        num_classes: Number of output classes.
        setting: ``"ind"`` tests on ``data.inductive_mask`` and trains on
            ``data.ind_edge_index``; any other value uses the full graph and
            ``data.test_mask``.

    Returns:
        A pair ``(embeddings, scores)`` computed on the full graph without
        gradient tracking: the model's hidden embedding and its raw class
        scores (no softmax is applied here).
    """
    train_mask, val_mask, test_mask = data.train_mask, data.val_mask, data.test_mask
    train_graph = data
    if setting == "ind":
        test_mask = data.inductive_mask
        # Train on the observed graph only. Sampling from the full edge set
        # would pull held-out inductive nodes into the training subgraphs.
        train_graph = copy.copy(data)
        train_graph.edge_index = data.ind_edge_index
        # The copy no longer needs the duplicate edge set as an attribute.
        del train_graph.ind_edge_index

    train_loader = NeighborLoader(
        train_graph,
        input_nodes=train_mask,
        num_neighbors=[5, 5],
        batch_size=32,
        shuffle=True
    )

    #TODO: don't know what's val mask in ind setting
    val_loader = NeighborLoader(
        data,
        input_nodes=val_mask,
        num_neighbors=[-1, -1],
        batch_size=32
    )

    test_loader = NeighborLoader(
        data,
        input_nodes=test_mask,
        num_neighbors=[-1, -1],
        batch_size=32
    )

    model = SAGE(
        num_layers=2,
        input_dim=num_features,
        hidden_dim=128,
        output_dim=num_classes,
        dropout_ratio=0,
        activation=nn.functional.relu,
        norm_type="batch"
    )

    # Fall back to CPU so the notebook also runs on machines without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(1, 201):
        loss = train_sage(model, train_loader, optimizer, criterion, device)
        val_acc = evaluate_sage(model, val_loader, device)
        if epoch % 10 == 0 or epoch == 1:
            test_acc = evaluate_sage(model, test_loader, device)
            print(f"Epoch {epoch:03d} | Loss: {loss:.4f} | Val Acc: {val_acc:.4f} | Test Acc: {test_acc:.4f}")

    # Final full-graph pass: eval mode and no autograd graph, since the
    # outputs are only meant to be stored or inspected.
    model.eval()
    with torch.no_grad():
        mb_t, z_soft = model(data.x.to(device), data.edge_index.to(device))
    return mb_t, z_soft

In [75]:
def run_model(model, features, edge_index, labels, train_mask, val_mask, test_mask, setting="tran", orig_edge_index=[], epochs=200):
    """Train a full-batch model and print its progress.

    The model is optimised with Adam (lr 0.01, weight decay 5e-4) and
    cross-entropy. Validation accuracy is computed every epoch; test accuracy
    is computed and printed on the first epoch and every tenth one.

    Args:
        model: Module returning a pair whose second item holds the predictions.
        features: Node features passed to the model.
        edge_index: Edge index used for training.
        labels: Ground-truth labels.
        train_mask: Selection of training nodes.
        val_mask: Selection of validation nodes.
        test_mask: Selection of test nodes.
        setting: ``"ind"`` evaluates on ``orig_edge_index``; any other value
            evaluates on ``edge_index``.
        orig_edge_index: Edge index used for evaluation in the ``"ind"``
            setting. The default is never mutated.
        epochs: Number of training epochs. The loop previously stopped after
            199 epochs, so the last reported epoch was 190.

    Raises:
        ValueError: If ``setting`` is ``"ind"`` but no ``orig_edge_index`` was
            supplied.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()

    edge_index_eval = edge_index
    if setting == "ind":
        if orig_edge_index is None or len(orig_edge_index) == 0:
            raise ValueError('setting="ind" requires orig_edge_index for evaluation')
        edge_index_eval = orig_edge_index

    for epoch in range(1, epochs + 1):
        loss = train(model, features, edge_index, labels, train_mask, optimizer, criterion)
        val_acc = evaluate(model, features, edge_index_eval, labels, val_mask)
        if epoch % 10 == 0 or epoch == 1:
            test_acc = evaluate(model, features, edge_index_eval, labels, test_mask)
            print(f"Epoch {epoch:03d} | Loss: {loss:.4f} | Val Acc: {val_acc:.4f} | Test Acc: {test_acc:.4f}")

In [65]:
def run_GCN(num_features, num_classes, features, edge_index, labels, train_mask, val_mask, test_mask, setting="tran",  orig_edge_index=[]):
    """Build the notebook's GCN teacher and train it with ``run_model``.

    The model has 3 layers, 64 hidden units, dropout 0.8, ReLU activation and
    batch normalisation. All remaining arguments are forwarded unchanged to
    ``run_model``.
    """
    gcn_config = dict(
        num_layers=3,
        input_dim=num_features,
        hidden_dim=64,
        output_dim=num_classes,
        dropout_ratio=0.8,
        activation=nn.functional.relu,
        norm_type="batch",
    )
    run_model(
        GCN(**gcn_config),
        features, edge_index, labels,
        train_mask, val_mask, test_mask,
        setting, orig_edge_index,
    )

def run_APPNP(num_features, num_classes, features, edge_index, labels, train_mask, val_mask, test_mask,  setting="tran", orig_edge_index=[]):
    """Build the notebook's APPNP teacher and train it with ``run_model``.

    The model has 2 layers, 128 hidden units, dropout 0.5 and ReLU activation.
    All remaining arguments are forwarded unchanged to ``run_model``.
    """
    appnp_config = dict(
        num_layers=2,
        input_dim=num_features,
        hidden_dim=128,
        output_dim=num_classes,
        dropout_ratio=0.5,
        activation=F.relu,
    )
    run_model(
        APPNP_Model(**appnp_config),
        features, edge_index, labels,
        train_mask, val_mask, test_mask,
        setting, orig_edge_index,
    )

def run_GAT(num_features, num_classes, features, edge_index, labels, train_mask, val_mask, test_mask, setting="tran", orig_edge_index=[]):
    """Build the notebook's GAT teacher and train it with ``run_model``.

    The model has 2 layers, a total hidden width of 128 split over 8 heads,
    dropout 0.6, attention dropout 0.3, negative slope 0.2 and ReLU
    activation. All remaining arguments are forwarded unchanged to
    ``run_model``.
    """
    # NOTE(review): `residual=True` is passed through, but the GAT class in
    # this notebook accepts that argument without reading it.
    gat_config = dict(
        num_layers=2,
        input_dim=num_features,
        hidden_dim=128,
        output_dim=num_classes,
        dropout_ratio=0.6,
        activation=F.relu,
        num_heads=8,
        attn_drop=0.3,
        negative_slope=0.2,
        residual=True,
    )
    run_model(
        GAT(**gat_config),
        features, edge_index, labels,
        train_mask, val_mask, test_mask,
        setting, orig_edge_index,
    )

In [80]:
# Train one of the full-batch teachers (GCN, GAT, APPNP) on Cora.
setting = "ind"
(
    num_features, num_classes, features, labels,
    edge_index, ind_edge_index,
    train_mask, val_mask, test_mask,
) = load_data("cora", setting)
# In the inductive setting the model trains on the reduced edge set and is
# evaluated on the original one, which is passed as the last argument.
edges = ind_edge_index if setting == "ind" else edge_index
run_GAT(num_features, num_classes, features, edges, labels, train_mask, val_mask, test_mask, setting, edge_index)

In [None]:
# Train the GraphSAGE teacher on Cora.
setting = "tran"
dataset = Planetoid(root='./Cora', name='Cora')
data = dataset[0]
# Placeholders so these attributes exist in both settings.
data.ind_edge_index = []
data.observed_mask = []
data.inductive_mask = []
test_mask = data.test_mask
if setting=="ind":
    data = add_inductive_settings(data)
    test_mask = data.inductive_mask
# Keep the returned tensors in named variables instead of letting the cell
# echo two full tensors as its output.
sage_embeddings, sage_scores = run_SAGE(data, dataset.num_node_features, dataset.num_classes, setting)

In [None]:
def main():
    """Parse the command-line options for the teacher implementation.

    Only argument parsing is visible in this cell; no visible statement uses
    the parsed ``args`` afterwards.
    """
    parser = argparse.ArgumentParser(description="Teacher implementation")
    parser.add_argument('--num_runs', type=int, default=1, help='Number of runs')
    # NOTE(review): the accepted spelling here is 'trans', while the other
    # functions in this notebook default to setting="tran" and only compare
    # against "ind" - confirm which spelling is intended.
    parser.add_argument('--setting', type=str, choices=['trans', 'ind'], required=True, help='Setting type: trans or ind')
    parser.add_argument('--data_path', type=str, required=True, help='Path to the dataset')
    parser.add_argument('--model_name', type=str, default='SAGE', help='Name of the model(SAGE, GCN, GAT, APPNP)')
    parser.add_argument('--num_layers', type=int, default=2, help='Number of layers in the model')
    parser.add_argument('--hidden_dim', type=int, default=128, help='Hidden dimension size')
    parser.add_argument('--drop_out', type=float, default=0, help='Dropout rate')
    parser.add_argument('--batch_sz', type=int, default=512, help='Batch size')
    parser.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate')
    parser.add_argument('--output_path', type=str, default='./output', help='Path to save output')
    
    # NOTE(review): parse_args() with no argument reads sys.argv. When called
    # from a notebook kernel, the required flags are presumably absent -
    # confirm how this is meant to be invoked.
    args = parser.parse_args()

    