In [1]:
# from google.colab import drive
# drive.mount('/content/drive')

In [2]:
# cd /content/drive/MyDrive/3D Computer Vision/final_project

In [1]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m28.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINConv, JumpingKnowledge, GCNConv
from torch_geometric.datasets import PPI
from torch_geometric.loader import DataLoader
from torch_geometric.utils import dense_to_sparse

In [5]:
# -------------------------
# Student Model
# -------------------------
class StudentModel(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, use_jk=False):
        super(StudentModel, self).__init__()
        self.use_jk = use_jk

        # GIN Layers
        self.gin_convs = nn.ModuleList()
        self.gin_convs.append(
            GINConv(nn.Sequential(
                nn.Linear(in_channels, hidden_channels),
                nn.ReLU(),
                nn.Linear(hidden_channels, hidden_channels),
                nn.ReLU(),
                nn.BatchNorm1d(hidden_channels)
            ))
        )
        for _ in range(num_layers - 1):
            self.gin_convs.append(
                GINConv(nn.Sequential(
                    nn.Linear(hidden_channels, hidden_channels),
                    nn.ReLU(),
                    nn.Linear(hidden_channels, hidden_channels),
                    nn.ReLU(),
                    nn.BatchNorm1d(hidden_channels)
                ))
            )

        # Jumping Knowledge (optional)
        if self.use_jk:
            self.jk = JumpingKnowledge(mode='cat')
            self.lin1 = nn.Linear(num_layers * hidden_channels, hidden_channels)
        else:
            self.lin1 = nn.Linear(hidden_channels, hidden_channels)

        self.lin2 = nn.Linear(hidden_channels, out_channels)

        # Label embeddings
        self.label_embeddings = nn.Parameter(torch.randn(out_channels, hidden_channels))  # Learnable label embeddings


    def forward(self, x, edge_index):
        embeddings = []
        for conv in self.gin_convs:
            x = conv(x, edge_index)
            embeddings.append(x)

        if self.use_jk:
            x = self.jk(embeddings)
        x = F.relu(self.lin1(x))
        x = self.lin2(x)  # Output node classification logits
        return x

In [6]:
# -------------------------
# Teacher Model
# -------------------------
class TeacherModel(nn.Module):
    def __init__(self, student_model):
        super(TeacherModel, self).__init__()
        # Clone student model structure
        self.model = StudentModel(
            in_channels=student_model.gin_convs[0].nn[0].in_features,
            hidden_channels=student_model.gin_convs[0].nn[0].out_features,
            out_channels=student_model.lin2.out_features,
            num_layers=len(student_model.gin_convs),
            use_jk=student_model.use_jk
        )
        self.model.load_state_dict(student_model.state_dict())  # Initialize with student weights

    @torch.no_grad()
    def update(self, student_model, alpha=0.99):
        """Update teacher parameters using EMA."""
        for teacher_param, student_param in zip(self.model.parameters(), student_model.parameters()):
            teacher_param.data = alpha * teacher_param.data + (1 - alpha) * student_param.data

    def forward(self, x, edge_index):
        return self.model(x, edge_index)

In [7]:
# -------------------------
# Semi-Supervised Loss Function
# -------------------------
def graph_consistency_loss(student_out, teacher_out, mask):
    """Calculate consistency loss using adjacency similarity."""
    student_probs = F.softmax(student_out[mask], dim=1)
    teacher_probs = F.softmax(teacher_out[mask], dim=1)

    student_adj = torch.mm(student_probs, student_probs.T)
    teacher_adj = torch.mm(teacher_probs, teacher_probs.T)

    return F.mse_loss(student_adj, teacher_adj)

def semi_supervised_loss(student_out, teacher_out, labels, labeled_mask, unlabeled_mask, lambda_con=0.1):
    """Combined supervised and consistency loss."""
    # sup_loss = F.cross_entropy(student_out[labeled_mask], labels[labeled_mask])
    # sup_loss = torch.nn.BCEWithLogitsLoss(student_out[labeled_mask], labels[labeled_mask])
    bce_loss = nn.BCEWithLogitsLoss()
    sup_loss = bce_loss(student_out[labeled_mask], labels[labeled_mask])

    con_loss = graph_consistency_loss(student_out, teacher_out, unlabeled_mask)
    return sup_loss + lambda_con * con_loss

In [8]:
# -------------------------
# Training Framework
# -------------------------
class SemiGNNFrameworkPPI:
    def __init__(self, student_model, teacher_model, device):
        self.student = student_model.to(device)
        self.teacher = teacher_model.to(device)
        self.device = device
        self.optimizer = torch.optim.Adam(self.student.parameters(), lr=0.005)
        self.lambda_con = 0.1

    def train_step(self, loader):
        """Train student and update teacher model."""
        self.student.train()
        self.teacher.eval()
        total_loss = 0

        for data in loader:
            data = data.to(self.device)
            labeled_mask = data.train_mask
            unlabeled_mask = ~data.train_mask

            # Forward pass
            student_out = self.student(data.x, data.edge_index)
            with torch.no_grad():
                teacher_out = self.teacher(data.x, data.edge_index)

            # Compute loss
            loss = semi_supervised_loss(student_out, teacher_out, data.y, labeled_mask, unlabeled_mask, self.lambda_con)

            # Backpropagation
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()

        # Update teacher model
        self.teacher.update(self.student)
        return total_loss / len(loader)


    def evaluate(self, loader):
      """Evaluate the student model for multi-label node classification."""
      self.student.eval()
      total_correct, total_labels = 0, 0

      for data in loader:
          data = data.to(self.device)
          with torch.no_grad():
              out = self.student(data.x, data.edge_index)  # Raw logits
              prob = torch.sigmoid(out)                   # Probabilities (0 to 1)
              pred = (prob > 0.5).long()                  # Binarize predictions
              # print(data.val_mask.shape)
              # print(pred.shape)
              # print(data.y.shape)

              # Validation mask and multi-label accuracy
              correct = (pred == data.y).sum().item()
              total_correct += correct
              total_labels += data.y.size(0) * data.y.size(1)  # Total labels (nodes * 121)

      accuracy = total_correct / total_labels
      return accuracy



In [12]:
# -------------------------
# Main Script
# -------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Load PPI dataset
train_dataset = PPI(root='data/PPI', split='train')
val_dataset = PPI(root='data/PPI', split='val')
test_dataset = PPI(root='data/PPI', split='test')

from torch_geometric.data import Data

def create_train_mask(data, train_ratio=0.9):
    """
    Create a train_mask for nodes in a graph.
    :param data: A graph in the PPI dataset (torch_geometric.data.Data).
    :param train_ratio: Fraction of nodes to use for training.
    :return: train_mask (torch.Tensor)
    """
    num_nodes = data.x.size(0)  # Total number of nodes
    num_train = int(train_ratio * num_nodes)  # Number of training nodes

    # Randomly permute indices and select train nodes
    perm = torch.randperm(num_nodes)
    train_indices = perm[:num_train]

    # Initialize train_mask
    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    train_mask[train_indices] = True
    return train_mask

# Update train_dataset with train_mask
updated_train_dataset = []
for graph in train_dataset:
    train_mask = create_train_mask(graph, train_ratio=0.05)
    updated_graph = Data(
        x=graph.x,
        edge_index=graph.edge_index,
        y=graph.y,
        train_mask=train_mask
    )
    updated_train_dataset.append(updated_graph)


# print(f"Train_mask example: {updated_train_dataset[0].train_mask}")

train_loader = DataLoader(updated_train_dataset, batch_size=2, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)

# Initialize models
student_model = StudentModel(in_channels=train_dataset.num_features, hidden_channels=512,
                             out_channels=train_dataset.num_classes, num_layers=3, use_jk=True)
teacher_model = TeacherModel(student_model)

# Initialize framework
framework = SemiGNNFrameworkPPI(student_model, teacher_model, device)

# Training loop
for epoch in range(50):
    train_loss = framework.train_step(train_loader)
    val_accuracy = framework.evaluate(val_loader)
    print(f"Epoch {epoch+1:02d}: Train Loss: {train_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")

# Testing
test_accuracy = framework.evaluate(test_loader)
print(f"Test Accuracy: {test_accuracy:.4f}")


Epoch 01: Train Loss: 0.5895, Validation Accuracy: 0.7262
Epoch 02: Train Loss: 0.5047, Validation Accuracy: 0.7517
Epoch 03: Train Loss: 0.4434, Validation Accuracy: 0.7785
Epoch 04: Train Loss: 0.3836, Validation Accuracy: 0.7764
Epoch 05: Train Loss: 0.3259, Validation Accuracy: 0.7845
Epoch 06: Train Loss: 0.2781, Validation Accuracy: 0.7979
Epoch 07: Train Loss: 0.2381, Validation Accuracy: 0.7968
Epoch 08: Train Loss: 0.2147, Validation Accuracy: 0.7870
Epoch 09: Train Loss: 0.1820, Validation Accuracy: 0.8094
Epoch 10: Train Loss: 0.1541, Validation Accuracy: 0.8099
Epoch 11: Train Loss: 0.1301, Validation Accuracy: 0.8150
Epoch 12: Train Loss: 0.1169, Validation Accuracy: 0.8087
Epoch 13: Train Loss: 0.1108, Validation Accuracy: 0.8122
Epoch 14: Train Loss: 0.0981, Validation Accuracy: 0.8224
Epoch 15: Train Loss: 0.0809, Validation Accuracy: 0.8235
Epoch 16: Train Loss: 0.0666, Validation Accuracy: 0.8253
Epoch 17: Train Loss: 0.0550, Validation Accuracy: 0.8295
Epoch 18: Trai