<a href="https://colab.research.google.com/github/Tineeee2002/Knowledge-Distillation-on-Graph/blob/main/Knowledge%20Distillation%20on%20Graph.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
!pip install torch_geometric



In [60]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
import torch.nn as nn
from torch_geometric.nn import GCNConv
import numpy as np
from torch_geometric.utils import k_hop_subgraph
# Load the Cora Planetoid dataset (with the NormalizeFeatures transform) and
# move its single graph onto the GPU when CUDA is available, else the CPU.
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=NormalizeFeatures())
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = dataset[0].to(device)

class TeacherGCN(nn.Module):
    """Teacher network: two hidden GCNConv layers followed by a GCNConv head.

    Args:
        num_features: Size of each input node feature vector.
        hidden_dim: Width of both hidden layers.
        num_classes: Size of the output produced for every node.
    """

    def __init__(self, num_features, hidden_dim, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x, edge_index):
        """Run the network on a graph.

        Returns:
            A pair ``(out, hidden_states)`` where ``out`` is the output of
            ``conv3`` and ``hidden_states`` is the list of the two hidden
            activations (after ReLU and dropout), in layer order.
        """
        hidden_states = []
        activation = x
        # Both hidden layers share the same conv -> ReLU -> dropout recipe.
        for conv in (self.conv1, self.conv2):
            activation = self.dropout(F.relu(conv(activation, edge_index)))
            hidden_states.append(activation)
        return self.conv3(activation, edge_index), hidden_states

class StudentGCN(nn.Module):
    """Student network: one hidden GCNConv layer followed by a GCNConv head.

    Args:
        num_features: Size of each input node feature vector.
        hidden_dim: Width of the hidden layer.
        num_classes: Size of the output produced for every node.
    """

    def __init__(self, num_features, hidden_dim, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x, edge_index):
        """Run the network on a graph.

        Returns:
            A pair ``(out, [hidden])`` where ``out`` is the output of
            ``conv2`` and ``hidden`` is the hidden activation after ReLU and
            dropout.
        """
        hidden = self.dropout(F.relu(self.conv1(x, edge_index)))
        logits = self.conv2(hidden, edge_index)
        return logits, [hidden]

class LocalStructurePreservingDistillation:
    """Distillation objective: task loss + soft-label KD + local-structure loss.

    The total loss returned by :meth:`compute_distillation_loss` is
    ``task_loss + kd_weight * kd_loss + structure_weight * structure_loss``.

    Args:
        teacher_model: Module whose ``forward(x, edge_index)`` returns
            ``(logits, [hidden_states...])``. It is only run under
            ``torch.no_grad()`` and is switched to eval mode.
        student_model: Module with the same ``forward`` contract; gradients
            flow through it.
        temperature: Softmax temperature used for the KD term.
        kd_weight: Multiplier of the KD term in the total loss.
        structure_weight: Multiplier of the structure term in the total loss.
        num_samples: Number of masked nodes drawn (with replacement) per call
            for the structure term. ``0`` disables the structure term.
        num_hops: Neighbourhood radius used for the structure term.
    """

    def __init__(self, teacher_model, student_model, temperature=4.0,
                 kd_weight=0.6, structure_weight=0.3, num_samples=10,
                 num_hops=4):
        self.teacher = teacher_model
        self.student = student_model
        self.temperature = temperature
        self.kd_weight = kd_weight
        self.structure_weight = structure_weight
        self.num_samples = num_samples
        self.num_hops = num_hops

    @staticmethod
    def _k_hop_nodes(node_idx, edge_index, num_hops, num_nodes):
        """Return the (long) indices of all nodes within ``num_hops`` of ``node_idx``."""
        if isinstance(node_idx, int):
            # A plain Python int becomes a 1-D index tensor next to the graph.
            node_idx = torch.tensor(
                [node_idx], dtype=torch.long, device=edge_index.device)
        elif node_idx.dim() == 0:
            # A 0-dim tensor (e.g. from iterating a 1-D tensor) becomes 1-D.
            node_idx = node_idx.unsqueeze(0)

        # Explicit exceptions instead of ``assert`` so the checks survive -O.
        if node_idx.dim() != 1:
            raise ValueError(
                f"node_idx must be a 1D tensor, got {node_idx.dim()} dimension.")
        if edge_index.dim() != 2 or edge_index.size(0) != 2:
            raise ValueError("edge_index must be a tensor of shape (2, E)")

        # Only the node subset is needed, so the relabelled edges, mapping and
        # edge mask are discarded. ``num_nodes`` is passed explicitly so the
        # node count is not inferred from ``edge_index``.
        subset, _, _, _ = k_hop_subgraph(
            node_idx, num_hops, edge_index, num_nodes=num_nodes)
        return subset.to(torch.long)

    @staticmethod
    def _similarity(embeddings, subset):
        """Return the dot-product similarity matrix of ``embeddings[subset]``."""
        local_embeddings = embeddings[subset]
        return torch.mm(local_embeddings, local_embeddings.t())

    def get_local_structure(self, node_idx, embeddings, edge_index, num_hops=4):
        """Describe the neighbourhood of ``node_idx`` by pairwise similarities.

        Args:
            node_idx: Centre node(s): a Python int, a 0-dim tensor or a 1-D
                tensor of node indices.
            embeddings: Per-node embedding matrix, indexed by node id along
                dimension 0.
            edge_index: Graph connectivity of shape ``(2, E)``.
            num_hops: Neighbourhood radius.

        Returns:
            ``(sim_matrix, subset)`` where ``subset`` holds the node indices of
            the k-hop neighbourhood and ``sim_matrix`` is the matrix of dot
            products between the embeddings of those nodes.

        Raises:
            ValueError: If ``node_idx`` is not 1-D after normalisation or
                ``edge_index`` is not of shape ``(2, E)``.
        """
        subset = self._k_hop_nodes(
            node_idx, edge_index, num_hops, embeddings.size(0))
        return self._similarity(embeddings, subset), subset

    def _structure_loss(self, teacher_emb, student_emb, edge_index, mask):
        """Mean MSE between teacher and student similarity matrices.

        The k-hop neighbourhood depends only on the graph, so it is extracted
        once per sampled node and shared by both models.
        """
        if self.num_samples <= 0:
            return student_emb.new_zeros(())

        candidates = torch.where(mask)[0]
        picks = torch.randint(0, len(candidates), (self.num_samples,))
        num_nodes = teacher_emb.size(0)

        accumulated = student_emb.new_zeros(())
        for node_idx in candidates[picks]:
            subset = self._k_hop_nodes(
                node_idx, edge_index, self.num_hops, num_nodes)
            accumulated = accumulated + F.mse_loss(
                self._similarity(student_emb, subset),
                self._similarity(teacher_emb, subset))
        return accumulated / self.num_samples

    def compute_distillation_loss(self, x, edge_index, labels, mask):
        """Compute the combined distillation loss for the masked nodes.

        Side effect: the teacher is put into eval mode so that its dropout
        layers do not perturb the distillation targets. The student's mode is
        left untouched.

        Args:
            x: Node feature matrix passed to both models.
            edge_index: Graph connectivity of shape ``(2, E)``.
            labels: Ground-truth class index per node.
            mask: Boolean node mask selecting the nodes the losses use.

        Returns:
            ``(total_loss, task_loss, kd_loss, structure_loss)``.

        Raises:
            ValueError: If ``mask`` selects no nodes.
        """
        if not bool(mask.any()):
            raise ValueError("mask must select at least one node")

        # Teacher forward pass: frozen, deterministic targets.
        self.teacher.eval()
        with torch.no_grad():
            teacher_logits, teacher_features = self.teacher(x, edge_index)
            teacher_probs = F.softmax(teacher_logits / self.temperature, dim=1)

        # Student forward pass.
        student_logits, student_features = self.student(x, edge_index)

        # 1. Knowledge-distillation loss on temperature-softened distributions,
        #    rescaled by T^2.
        kd_loss = F.kl_div(
            F.log_softmax(student_logits[mask] / self.temperature, dim=1),
            teacher_probs[mask],
            reduction='batchmean',
        ) * (self.temperature ** 2)

        # 2. Local-structure preservation loss on the last hidden layer of
        #    each model.
        structure_loss = self._structure_loss(
            teacher_features[-1], student_features[-1], edge_index, mask)

        # 3. Task-specific (supervised) loss.
        task_loss = F.cross_entropy(student_logits[mask], labels[mask])

        # Weighted sum of the three terms.
        total_loss = (task_loss
                      + self.kd_weight * kd_loss
                      + self.structure_weight * structure_loss)

        return total_loss, task_loss, kd_loss, structure_loss

def train_teacher(model, data, optimizer):
    """Perform one full-graph optimisation step on the training nodes.

    Args:
        model: Module whose ``forward(x, edge_index)`` returns
            ``(logits, extras)``; it is switched to train mode.
        data: Object exposing ``x``, ``edge_index``, ``y`` and ``train_mask``.
        optimizer: Optimizer holding ``model``'s parameters.

    Returns:
        The cross-entropy loss tensor computed before the parameter update.
    """
    optimizer.zero_grad()
    model.train()

    logits, _ = model(data.x, data.edge_index)
    train_nodes = data.train_mask
    loss = F.cross_entropy(logits[train_nodes], data.y[train_nodes])

    loss.backward()
    optimizer.step()
    return loss

def evaluate(model, data, mask):
    """Return the classification accuracy of ``model`` on the masked nodes.

    The model is switched to eval mode and run without gradient tracking.

    Args:
        model: Module whose ``forward(x, edge_index)`` returns
            ``(logits, extras)``.
        data: Object exposing ``x``, ``edge_index`` and ``y``.
        mask: Boolean node mask selecting the nodes to score.

    Returns:
        Fraction of masked nodes whose arg-max prediction equals the label.
    """
    model.eval()
    with torch.no_grad():
        logits, _ = model(data.x, data.edge_index)
    predicted = logits.argmax(dim=1)
    num_correct = int((predicted[mask] == data.y[mask]).sum())
    num_selected = int(mask.sum())
    return num_correct / num_selected

# Build the models
num_features = dataset.num_features
hidden_dim = 64
num_classes = dataset.num_classes

teacher = TeacherGCN(num_features, hidden_dim, num_classes).to(device)
student = StudentGCN(num_features, hidden_dim, num_classes).to(device)


def _snapshot_state(model):
    """Return a detached copy of ``model``'s state dict.

    ``state_dict()`` hands back references to the live parameter tensors, so
    storing it directly would let later optimizer steps overwrite the
    checkpoint in place. Cloning freezes the values at snapshot time.
    """
    return {name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()}


# Train the teacher first
teacher_optimizer = torch.optim.Adam(teacher.parameters(), lr=0.001)
best_val_acc = 0
best_teacher_state = None

print("Training Teacher Model...")
for epoch in range(200):
    loss = train_teacher(teacher, data, teacher_optimizer)

    # Evaluate (and checkpoint) every 20 epochs.
    if epoch % 20 == 0:
        train_acc = evaluate(teacher, data, data.train_mask)
        val_acc = evaluate(teacher, data, data.val_mask)
        test_acc = evaluate(teacher, data, data.test_mask)
        print(f'Epoch {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '
              f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_teacher_state = _snapshot_state(teacher)

# Restore the best teacher checkpoint (if one was recorded) and freeze its
# behaviour: the last training step left the module in train mode, which would
# keep dropout active while it produces distillation targets.
if best_teacher_state is not None:
    teacher.load_state_dict(best_teacher_state)
teacher.eval()

# Build the distiller and the student's optimizer
distiller = LocalStructurePreservingDistillation(teacher, student)
student_optimizer = torch.optim.Adam(student.parameters(), lr=0.01)

# Train the student with distillation
print("\nTraining Student Model with Distillation...")
best_val_acc = 0
best_student_state = None

for epoch in range(200):
    student.train()
    total_loss, task_loss, kd_loss, structure_loss = distiller.compute_distillation_loss(
        data.x, data.edge_index, data.y, data.train_mask)

    student_optimizer.zero_grad()
    total_loss.backward()
    student_optimizer.step()

    # Evaluate (and checkpoint) every 20 epochs.
    if epoch % 20 == 0:
        train_acc = evaluate(student, data, data.train_mask)
        val_acc = evaluate(student, data, data.val_mask)
        test_acc = evaluate(student, data, data.test_mask)
        print(f'Epoch {epoch:03d}, Loss: {total_loss:.4f}, Train: {train_acc:.4f}, '
              f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')
        print(f'Task Loss: {task_loss:.4f}, KD Loss: {kd_loss:.4f}, '
              f'Structure Loss: {structure_loss:.4f}')

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_student_state = _snapshot_state(student)

# Restore the best student checkpoint (if one was recorded) before the final
# evaluation.
if best_student_state is not None:
    student.load_state_dict(best_student_state)

# Final evaluation
print("\nFinal Results:")
print("Teacher Model:")
print(f"Train accuracy: {evaluate(teacher, data, data.train_mask):.4f}")
print(f"Val accuracy: {evaluate(teacher, data, data.val_mask):.4f}")
print(f"Test accuracy: {evaluate(teacher, data, data.test_mask):.4f}")

print("\nStudent Model:")
print(f"Train accuracy: {evaluate(student, data, data.train_mask):.4f}")
print(f"Val accuracy: {evaluate(student, data, data.val_mask):.4f}")
print(f"Test accuracy: {evaluate(student, data, data.test_mask):.4f}")

Training Teacher Model...
Epoch 000, Loss: 1.9468, Train: 0.2500, Val: 0.1760, Test: 0.1830
Epoch 020, Loss: 1.9054, Train: 0.8571, Val: 0.6260, Test: 0.6070
Epoch 040, Loss: 1.7951, Train: 0.8857, Val: 0.6620, Test: 0.6560
Epoch 060, Loss: 1.5959, Train: 0.9000, Val: 0.6700, Test: 0.6790
Epoch 080, Loss: 1.2238, Train: 0.9357, Val: 0.7160, Test: 0.7250
Epoch 100, Loss: 0.8771, Train: 0.9571, Val: 0.7580, Test: 0.7720
Epoch 120, Loss: 0.5771, Train: 0.9786, Val: 0.7780, Test: 0.8020
Epoch 140, Loss: 0.3888, Train: 0.9857, Val: 0.7800, Test: 0.8090
Epoch 160, Loss: 0.2968, Train: 0.9857, Val: 0.7860, Test: 0.8090
Epoch 180, Loss: 0.2216, Train: 0.9857, Val: 0.7880, Test: 0.8080

Training Student Model with Distillation...
Epoch 000, Loss: 42.9325, Train: 0.2571, Val: 0.1440, Test: 0.1740
Task Loss: 1.9464, KD Loss: 3.4547, Structure Loss: 129.7110
Epoch 020, Loss: 16.9481, Train: 0.5357, Val: 0.2440, Test: 0.2850
Task Loss: 1.7885, KD Loss: 3.4832, Structure Loss: 43.5654
Epoch 040, Los