In [11]:
import argparse
import networkx as nx
import torch
import torch.nn as nn
import torch.nn.functional as F
!pip install torch-geometric
from torch_geometric.nn import GCNConv, SAGEConv



TODO:
Implement teacher network architecture and training along with flags for datasets and teacher model architecture (Chanikya and Nithin)
eg: python3 train_teacher.py --dataset=cora --model=SAGE --epochs-100 --lr=0.01 . Add flags for other hyperparameters if necessary (Chanikya and Nithin)
Other teacher model architectures - GCN, GAT, APPNP (Chanikya and Nithin + others based on availability)

In [12]:
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='./Cora', name='Cora')

data = dataset[0]
print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

Dataset: Cora():
Number of graphs: 1
Number of features: 1433
Number of classes: 7


In [13]:
# data.edge_index.t()

In [14]:
def load_data(dataset):
    if dataset == "cora":
        dataset = Planetoid(root='./Cora', name='Cora')
        data = dataset[0]
        graph_nx = nx.Graph()
        graph_nx.add_edges_from(data.edge_index.t().tolist())

        # Adding self-loops
        # graph_nx.add_edges_from((n, n) for n in graph_nx.nodes())
        
        # adj_tensor = torch.tensor(nx.to_numpy_array(graph_nx), dtype=torch.float).to('cuda')
        features = data.x
        labels = data.y

        train_idx = data.train_mask.nonzero(as_tuple=True)[0]
        val_idx = data.val_mask.nonzero(as_tuple=True)[0]
        test_idx = data.test_mask.nonzero(as_tuple=True)[0]
        
        return data.edge_index, features, labels, data.train_mask, data.val_mask, data.test_mask

In [15]:
cnt = 0
140+1000+500
2708
for x in data.val_mask:
    cnt += (x==True)
print(cnt)

tensor(500)


GCN => num layers, hidden, input dim, out, dp, activa,  

In [16]:
class GCN(nn.Module):
    def __init__(
        self,
        num_layers,
        input_dim,
        hidden_dim,
        output_dim,
        dropout_ratio,
        activation,
        norm_type="none"
    ):
        super().__init__()
        self.num_layers = num_layers
        self.norm_type = norm_type
        self.dropout = nn.Dropout(dropout_ratio)
        self.activation = activation

        self.layers = nn.ModuleList()
        self.norms = nn.ModuleList()

        if num_layers == 1:
            self.layers.append(GCNConv(input_dim, output_dim))
        else:
            self.layers.append(GCNConv(input_dim, hidden_dim))
            if norm_type == "batch":
                self.norms.append(nn.BatchNorm1d(hidden_dim))
            elif norm_type == "layer":
                self.norms.append(nn.LayerNorm(hidden_dim))

            for _ in range(num_layers - 2):
                self.layers.append(GCNConv(hidden_dim, hidden_dim))
                if norm_type == "batch":
                    self.norms.append(nn.BatchNorm1d(hidden_dim))
                elif norm_type == "layer":
                    self.norms.append(nn.LayerNorm(hidden_dim))

            self.layers.append(GCNConv(hidden_dim, output_dim))

    def forward(self, x, edge_index):
        h_list = []
        h = x
        for l, layer in enumerate(self.layers):
            h = layer(h, edge_index)
            if l != self.num_layers - 1:
                if self.norm_type != "none":
                    h = self.norms[l](h)
                h = self.activation(h)
                h = self.dropout(h)
                h_list.append(h)
        return h

In [45]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

class SAGE(nn.Module):
    def __init__(
        self,
        num_layers,
        input_dim,
        hidden_dim,
        output_dim,
        dropout_ratio,
        activation,
        norm_type="none",
    ):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.norm_type = norm_type
        self.activation = activation
        self.dropout = nn.Dropout(dropout_ratio)
        self.layers = nn.ModuleList()
        self.norms = nn.ModuleList()

        if num_layers == 1:
            self.layers.append(SAGEConv(input_dim, output_dim))
        else:
            self.layers.append(SAGEConv(input_dim, hidden_dim))
            if self.norm_type == "batch":
                self.norms.append(nn.BatchNorm1d(hidden_dim))
            elif self.norm_type == "layer":
                self.norms.append(nn.LayerNorm(hidden_dim))

            for _ in range(num_layers - 2):
                self.layers.append(SAGEConv(hidden_dim, hidden_dim))
                if self.norm_type == "batch":
                    self.norms.append(nn.BatchNorm1d(hidden_dim))
                elif self.norm_type == "layer":
                    self.norms.append(nn.LayerNorm(hidden_dim))

            self.layers.append(SAGEConv(hidden_dim, output_dim))

    def forward(self, x, edge_index):
        h = x
        h_list = []
        for l, layer in enumerate(self.layers):
            h = layer(h, edge_index)
            if l != self.num_layers - 1:
                h_list.append(h)
                if self.norm_type != "none":
                    h = self.norms[l](h)
                h = self.activation(h)
                h = self.dropout(h)
        return h


    def inference(self, x_all, edge_index, batch_size=1024, device="cuda"):
        """
        Full-graph inference using mini-batches (for large graphs).
        """
        from torch_geometric.loader import NeighborLoader

        x = x_all.to(device)
        for l, layer in enumerate(self.layers):
            new_x = torch.zeros(
                x_all.size(0),
                self.hidden_dim if l != self.num_layers - 1 else self.output_dim,
            ).to(device)

            loader = NeighborLoader(
                data=(x_all, edge_index),
                input_nodes=torch.arange(x_all.size(0)),
                num_neighbors=[-1],  # full neighbors
                batch_size=batch_size,
                shuffle=False
            )

            for batch in loader:
                batch = batch.to(device)
                h = x[batch.n_id]
                h = layer(h, batch.edge_index)

                if l != self.num_layers - 1:
                    if self.norm_type != "none":
                        h = self.norms[l](h)
                    h = self.activation(h)
                    h = self.dropout(h)

                new_x[batch.n_id[:batch.batch_size]] = h

            x = new_x
        return x

In [46]:
# For small, medium datasets few thousands, use model() in eval
# For large ones like 100k or millions, use model

In [52]:
def train_sage(model, loader, optimizer, criterion, device, homo=True):
    model.train()
    total_loss = 0

    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        x = batch.x
        y = batch.y[:batch.batch_size]  # Only use input nodes

        if homo:
            edge_index = batch.edge_index
        else:
            rel = list(batch.edge_index_dict.keys())[0]
            edge_index = batch.edge_index_dict[rel]

        out = model(x, edge_index)
        out = out[:batch.batch_size]  # Only use predictions for input nodes

        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(loader)


@torch.no_grad()
def evaluate_sage(model, loader, device, homo=True):
    model.eval()
    correct = 0
    total = 0

    for batch in loader:
        batch = batch.to(device)
        x = batch.x
        y = batch.y[:batch.batch_size]  # Only input nodes

        if homo:
            edge_index = batch.edge_index
        else:
            rel = list(batch.edge_index_dict.keys())[0]
            edge_index = batch.edge_index_dict[rel]

        out = model(x, edge_index)
        out = out[:batch.batch_size]  # Only predictions for input nodes

        pred = out.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.size(0)

    return correct / total


In [64]:
from torch_geometric.loader import NeighborLoader

train_loader = NeighborLoader(
    data,
    input_nodes=data.train_mask,
    num_neighbors=[15, 10],
    batch_size=128,
    shuffle=True
)

val_loader = NeighborLoader(
    data,
    input_nodes=data.val_mask,
    num_neighbors=[-1, -1],
    batch_size=128
)

test_loader = NeighborLoader(
    data,
    input_nodes=data.test_mask,
    num_neighbors=[-1, -1],
    batch_size=128
)


model = SAGE(
    num_layers=2,
    input_dim=dataset.num_node_features,
    hidden_dim=128,
    output_dim=dataset.num_classes,
    dropout_ratio=0,
    activation=nn.functional.relu,
    norm_type="batch"
)

device = 'cuda'
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(1, 201):
    loss = train_sage(model, train_loader, optimizer, criterion, device)
    val_acc = evaluate_sage(model, val_loader, device)
    print(f"Epoch {epoch:02d} | Loss: {loss:.4f} | Val Acc: {val_acc:.4f}")

test_acc = evaluate_sage(model, test_loader, device)
print(f"Test Acc: {test_acc:.4f}")


Epoch 01 | Loss: 1.9665 | Val Acc: 0.3540
Epoch 02 | Loss: 0.6244 | Val Acc: 0.4660
Epoch 03 | Loss: 0.2881 | Val Acc: 0.5620
Epoch 04 | Loss: 0.1405 | Val Acc: 0.6460
Epoch 05 | Loss: 0.0713 | Val Acc: 0.7040
Epoch 06 | Loss: 0.0565 | Val Acc: 0.7260
Epoch 07 | Loss: 0.0291 | Val Acc: 0.7340
Epoch 08 | Loss: 0.0259 | Val Acc: 0.7460
Epoch 09 | Loss: 0.0148 | Val Acc: 0.7480
Epoch 10 | Loss: 0.0132 | Val Acc: 0.7460
Epoch 11 | Loss: 0.0103 | Val Acc: 0.7480
Epoch 12 | Loss: 0.0062 | Val Acc: 0.7460
Epoch 13 | Loss: 0.0080 | Val Acc: 0.7380
Epoch 14 | Loss: 0.0061 | Val Acc: 0.7420
Epoch 15 | Loss: 0.0049 | Val Acc: 0.7400
Epoch 16 | Loss: 0.0032 | Val Acc: 0.7460
Epoch 17 | Loss: 0.0026 | Val Acc: 0.7440
Epoch 18 | Loss: 0.0036 | Val Acc: 0.7440
Epoch 19 | Loss: 0.0028 | Val Acc: 0.7440
Epoch 20 | Loss: 0.0023 | Val Acc: 0.7440
Epoch 21 | Loss: 0.0023 | Val Acc: 0.7440
Epoch 22 | Loss: 0.0029 | Val Acc: 0.7460
Epoch 23 | Loss: 0.0026 | Val Acc: 0.7460
Epoch 24 | Loss: 0.0025 | Val Acc:

In [60]:
def train(model, data, edge_index, labels, train_idx, optimizer, criterion):
    model.train()
    optimizer.zero_grad()
    out = model(data, edge_index)
    loss = criterion(out[train_idx], labels[train_idx])
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def evaluate(model, data, edge_index, labels, idx):
    model.eval()
    out = model(data, edge_index)
    pred = out[idx].argmax(dim=1)
    correct = (pred == labels[idx]).sum().item()
    acc = correct / sum(idx)
    return acc

In [None]:
edge_index, features, labels, train_idx, val_idx, test_idx = load_data("cora")
model = GCN(
    num_layers=3,
    input_dim=dataset.num_node_features,
    hidden_dim=64,
    output_dim=dataset.num_classes,
    dropout_ratio=0.8,
    activation=nn.functional.relu,
    norm_type="batch"
)
# model = GCN1(dataset.num_node_features, 64, dataset.num_classes)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(1, 150):
    loss = train(model, features, edge_index, labels, train_idx, optimizer, criterion)
    val_acc = evaluate(model, features, edge_index, labels, val_idx)
    if epoch % 10 == 0 or epoch == 1:
        test_acc = evaluate(model, features, edge_index, labels, test_idx)
        print(f"Epoch {epoch:03d} | Loss: {loss:.4f} | Val Acc: {val_acc:.4f} | Test Acc: {test_acc:.4f}")


In [None]:
model(data.x, data.edge_index)

In [None]:
class Teacher:
    def __init__(self, args):
        self.args = args
        pass
    def graph_split():
        pass
    def train_transductive():
        pass
    def train_inductive():
        pass
    

In [None]:
def main():
    parser = argparse.ArgumentParser(description="Teacher implementation")
    parser.add_argument('--num_runs', type=int, default=1, help='Number of runs')
    parser.add_argument('--setting', type=str, choices=['trans', 'ind'], required=True, help='Setting type: trans or ind')
    parser.add_argument('--data_path', type=str, required=True, help='Path to the dataset')
    parser.add_argument('--model_name', type=str, default='SAGE', help='Name of the model(SAGE, GCN, GAT, APPNP)')
    parser.add_argument('--num_layers', type=int, default=2, help='Number of layers in the model')
    parser.add_argument('--hidden_dim', type=int, default=128, help='Hidden dimension size')
    parser.add_argument('--drop_out', type=float, default=0, help='Dropout rate')
    parser.add_argument('--batch_sz', type=int, default=512, help='Batch size')
    parser.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate')
    parser.add_argument('--output_path', type=str, default='./output', help='Path to save output')
    
    args = parser.parse_args()

    