In [41]:
# from google.colab import drive
# drive.mount('/content/drive')

In [42]:
# cd /content/drive/MyDrive/3D Computer Vision/final_project

In [43]:
!pip install torch_geometric



In [44]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINConv, JumpingKnowledge, GCNConv
from torch_geometric.datasets import PPI
from torch_geometric.loader import DataLoader
from torch_geometric.utils import dense_to_sparse

In [45]:
torch.manual_seed(0)

<torch._C.Generator at 0x7f721ee80bb0>

In [46]:
# -------------------------
# Label Learner Component
# -------------------------
class LabelGraphLearner(nn.Module):
    def __init__(self, num_labels):
        super(LabelGraphLearner, self).__init__()
        self.label_adj = nn.Parameter(torch.randn(num_labels, num_labels))  # Learnable adjacency matrix

    def forward(self):
        """
        Learn the label adjacency matrix as edge probabilities.
        :return: Learned label adjacency matrix (softmax-normalized).
        """
        adj = torch.sigmoid(self.label_adj)  # Map to [0, 1] as edge probabilities
        return adj

class LabelGCN(nn.Module):
    def __init__(self, num_labels, hidden_channels):
        super(LabelGCN, self).__init__()
        self.conv1 = GCNConv(num_labels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)

    def forward(self, label_adj):
        """
        Propagate through the label graph using GCN.
        :param label_adj: Adjacency matrix for labels.
        :return: Refined label embeddings.
        """
        x = torch.eye(label_adj.size(0)).to(label_adj.device)  # One-hot encoding for labels
        edge_index, _ = dense_to_sparse(label_adj)  # Convert adjacency matrix to edge_index format
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x


In [47]:
# -------------------------
# Graph Augmentation
# -------------------------

def edge_manipulation(edge_index, num_nodes, ratio=0.1):
    """
    Perform edge manipulation by randomly replacing a percentage of edges.
    :param edge_index: Original edge index (2 x num_edges).
    :param num_nodes: Number of nodes in the graph.
    :param ratio: Percentage of edges to replace.
    :return: Modified edge index.
    """
    device = edge_index.device  # Ensure all tensors are created on the same device
    num_edges = edge_index.size(1)
    num_replacements = int(ratio * num_edges)

    # Randomly select edges to replace
    replacement_indices = torch.randint(0, num_edges, (num_replacements,), device=device)
    new_edges = torch.randint(0, num_nodes, (2, num_replacements), device=device)  # Random new edges

    # Replace selected edges
    modified_edge_index = edge_index.clone()  # Clone to avoid modifying the original tensor
    modified_edge_index[:, replacement_indices] = new_edges
    return modified_edge_index


def node_manipulation(x, ratio=0.1):
    """
    Perform node manipulation by randomly masking a percentage of node features.
    :param x: Node features (num_nodes x feature_dim).
    :param ratio: Percentage of node features to mask.
    :return: Modified node features.
    """
    num_nodes, feature_dim = x.size()
    num_masked = int(ratio * num_nodes)

    # Randomly select nodes to mask
    mask_indices = torch.randint(0, num_nodes, (num_masked,))
    x[mask_indices] = 0  # Mask selected nodes
    return x


In [48]:
# -------------------------
# Student Model
# -------------------------
class StudentModel(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, use_jk=False):
        super(StudentModel, self).__init__()
        self.use_jk = use_jk

        # GIN Layers for node embeddings
        self.gin_convs = nn.ModuleList()
        self.gin_convs.append(
            GINConv(nn.Sequential(
                nn.Linear(in_channels, hidden_channels),
                nn.ReLU(),
                nn.Linear(hidden_channels, hidden_channels),
                nn.ReLU(),
                nn.BatchNorm1d(hidden_channels)
            ))
        )
        for _ in range(num_layers - 1):
            self.gin_convs.append(
                GINConv(nn.Sequential(
                    nn.Linear(hidden_channels, hidden_channels),
                    nn.ReLU(),
                    nn.Linear(hidden_channels, hidden_channels),
                    nn.ReLU(),
                    nn.BatchNorm1d(hidden_channels)
                ))
            )

        if self.use_jk:
            self.jk = JumpingKnowledge(mode='cat')
            self.lin1 = nn.Linear(num_layers * hidden_channels, hidden_channels)
        else:
            self.lin1 = nn.Linear(hidden_channels, hidden_channels)

        # Label learning components
        self.label_learner = LabelGraphLearner(out_channels)
        self.label_gcn = LabelGCN(out_channels, hidden_channels)

        # Graph-based interaction layer
        self.node_label_interaction = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """
        Forward pass for the student model.
        :param x: Node features.
        :param edge_index: Graph structure.
        :return: Node-label logits and label adjacency matrix.
        """
        # Node embeddings
        embeddings = []
        for conv in self.gin_convs:
            x = conv(x, edge_index)
            embeddings.append(x)

        if self.use_jk:
            x = self.jk(embeddings)
        x = F.relu(self.lin1(x))  # Node embeddings (shape: [num_nodes, hidden_channels])

        # Label embeddings
        label_adj = self.label_learner()  # Learn the label graph
        label_embeddings = self.label_gcn(label_adj)  # Refine label embeddings (shape: [num_labels, hidden_channels])

        # Combine node and label embeddings into a joint graph
        # Create a node-label interaction graph
        num_nodes = x.size(0)
        num_labels = label_embeddings.size(0)

        # Concatenate node and label embeddings
        combined_embeddings = torch.cat([x, label_embeddings], dim=0)  # Shape: [num_nodes + num_labels, hidden_channels]

        # Create interaction edges between all nodes and labels
        node_indices = torch.arange(num_nodes)
        label_indices = torch.arange(num_labels) + num_nodes
        interaction_edges = torch.stack(torch.meshgrid(node_indices, label_indices)).reshape(2, -1).to(edge_index.device)

        # Pass through GCN for node-label interaction
        interaction_logits = self.node_label_interaction(combined_embeddings, interaction_edges)  # Shape: [num_nodes + num_labels, out_channels]

        # Extract node-label logits
        node_label_logits = interaction_logits[:num_nodes]  # Shape: [num_nodes, num_labels]

        return node_label_logits, x


In [49]:
# -------------------------
# Teacher Model
# -------------------------
class TeacherModel(nn.Module):
    def __init__(self, student_model):
        super(TeacherModel, self).__init__()
        # Clone student model structure
        self.model = StudentModel(
            in_channels=student_model.gin_convs[0].nn[0].in_features,
            hidden_channels=student_model.gin_convs[0].nn[0].out_features,
            out_channels=student_model.label_learner.label_adj.shape[0],
            num_layers=len(student_model.gin_convs),
            use_jk=student_model.use_jk
        )
        self.model.load_state_dict(student_model.state_dict())  # Initialize with student weights

    @torch.no_grad()
    def update(self, student_model, alpha=0.99):
        """Update teacher parameters using EMA."""
        for teacher_param, student_param in zip(self.model.parameters(), student_model.parameters()):
            teacher_param.data = alpha * teacher_param.data + (1 - alpha) * student_param.data

    def forward(self, x, edge_index):
        return self.model(x, edge_index)

In [50]:
# -------------------------
# Semi-Supervised Loss Function
# -------------------------

def supervised_loss(student_out, labels, labeled_mask):
    """
    Supervised loss for labeled data using Binary Cross-Entropy with logits.
    """
    bce_loss = F.binary_cross_entropy_with_logits(
        student_out[labeled_mask], labels[labeled_mask]
    )
    return bce_loss


def consistency_loss(student_out, teacher_out, mask):
    """Calculate consistency loss using adjacency similarity."""
    student_probs = F.softmax(student_out[mask], dim=1)
    teacher_probs = F.softmax(teacher_out[mask], dim=1)

    student_adj = torch.mm(student_probs, student_probs.T)
    teacher_adj = torch.mm(teacher_probs, teacher_probs.T)

    return F.mse_loss(student_adj, teacher_adj)

def calculate_adjacency_matrix(embeddings):
    """
    Compute the adjacency matrix using Pearson's Correlation Coefficient (PCC).
    :param embeddings: Node embeddings (tensor of shape [num_nodes, hidden_dim]).
    :return: Adjacency matrix (tensor of shape [num_nodes, num_nodes]).
    """
    embeddings = embeddings - embeddings.mean(dim=0, keepdim=True)
    norm = embeddings.norm(dim=1, keepdim=True)
    normalized_embeddings = embeddings / (norm + 1e-8)  # Prevent division by zero
    adj_matrix = torch.mm(normalized_embeddings, normalized_embeddings.T)
    return adj_matrix

def edge_matching_loss(student_embeddings, teacher_embeddings):
    """
    Compute the edge loss using adjacency matrices.
    :param student_embeddings: Student node embeddings (tensor of shape [num_nodes, hidden_dim]).
    :param teacher_embeddings: Teacher node embeddings (tensor of shape [num_nodes, hidden_dim]).
    :return: Edge loss (scalar).
    """
    adj_student = calculate_adjacency_matrix(student_embeddings)
    adj_teacher = calculate_adjacency_matrix(teacher_embeddings)
    return F.mse_loss(adj_student, adj_teacher)

def node_matching_loss(student_embeddings, teacher_embeddings):
    """
    Compute the node loss by aligning diagonal elements of the cross-embedding adjacency matrix.
    :param student_embeddings: Student node embeddings (tensor of shape [num_nodes, hidden_dim]).
    :param teacher_embeddings: Teacher node embeddings (tensor of shape [num_nodes, hidden_dim]).
    :return: Node loss (scalar).
    """
    # Cross-embedding adjacency matrix
    student_normalized = student_embeddings / (student_embeddings.norm(dim=1, keepdim=True) + 1e-8)
    teacher_normalized = teacher_embeddings / (teacher_embeddings.norm(dim=1, keepdim=True) + 1e-8)
    cross_adj_matrix = torch.mm(student_normalized, teacher_normalized.T)

    # Diagonal elements should align with identity
    diagonal_elements = torch.diagonal(cross_adj_matrix)
    identity = torch.ones_like(diagonal_elements)
    return F.mse_loss(diagonal_elements, identity)

def semignn_loss(
    student_out, teacher_out, labels, labeled_mask, unlabeled_mask,
    student_embeddings, teacher_embeddings,
    lambda_con=0.1, lambda_edge=0.1, lambda_node=0.1
):
    """
    Combined loss function for SemiGNN-PPI.
    """
    # Compute individual loss components
    sup_loss = supervised_loss(student_out, labels, labeled_mask)
    con_loss = consistency_loss(student_out, teacher_out, unlabeled_mask)
    edge_loss = edge_matching_loss(student_embeddings, teacher_embeddings)
    node_loss = node_matching_loss(student_embeddings, teacher_embeddings)

    # Combine losses with scaling factors
    total_loss = (
        sup_loss
        + lambda_con * con_loss
        + lambda_edge * edge_loss
        + lambda_node * node_loss
    )
    return total_loss

In [51]:
# -------------------------
# Training Framework
# -------------------------
class SemiGNNFrameworkPPI:
    def __init__(self, student_model, teacher_model, device):
        self.student = student_model.to(device)
        self.teacher = teacher_model.to(device)
        self.device = device
        self.optimizer = torch.optim.Adam(self.student.parameters(), lr=0.001)
        self.lambda_con = 0.02

    def train_step(self, loader):
        self.student.train()
        self.teacher.eval()
        total_loss = 0

        for data in loader:
            data = data.to(self.device)
            labeled_mask = data.train_mask
            unlabeled_mask = ~data.train_mask

            # Apply edge manipulation
            edge_index_student = edge_manipulation(data.edge_index, data.x.size(0), ratio=0.1)  # 10% for student
            edge_index_teacher = edge_manipulation(data.edge_index, data.x.size(0), ratio=0.05)  # 5% for teacher
    
            # Apply node manipulation
            x_student = node_manipulation(data.x.clone(), ratio=0.1)  # 10% for student
            x_teacher = node_manipulation(data.x.clone(), ratio=0.05)  # 5% for teacher

            # Student forward pass
            student_out, student_embeddings = self.student(x_student, edge_index_student)

            # Teacher forward pass
            with torch.no_grad():
                teacher_out, teacher_embeddings = self.teacher(x_teacher, edge_index_teacher)

            # Compute loss
            loss = semignn_loss(
                student_out, teacher_out, data.y, labeled_mask, unlabeled_mask,
                student_embeddings, teacher_embeddings,
                lambda_con=0.02, lambda_edge=0.01, lambda_node=0.003  # Loss scaling factors based on SEMIGNN-PPI paper
            )

            # Backpropagation
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()

        # Update teacher model
        self.teacher.update(self.student)
        return total_loss / len(loader)


    def evaluate(self, loader):
      """Evaluate the student model for multi-label node classification."""
      self.student.eval()
      total_correct, total_labels = 0, 0

      for data in loader:
          data = data.to(self.device)
          with torch.no_grad():
              out, _ = self.student(data.x, data.edge_index)  # Raw logits
              prob = torch.sigmoid(out)                   # Probabilities (0 to 1)
              pred = (prob > 0.5).long()                  # Binarize predictions
              # print(data.val_mask.shape)
              # print(pred.shape)
              # print(data.y.shape)

              # Validation mask and multi-label accuracy
              correct = (pred == data.y).sum().item()
              total_correct += correct
              total_labels += data.y.size(0) * data.y.size(1)  # Total labels (nodes * 121)

      accuracy = total_correct / total_labels
      return accuracy



In [52]:
# -------------------------
# Main Script
# -------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Load PPI dataset
train_dataset = PPI(root='data/PPI', split='train')
val_dataset = PPI(root='data/PPI', split='val')
test_dataset = PPI(root='data/PPI', split='test')


def create_train_mask(data, train_ratio=0.9):
    """
    Create a train_mask for nodes in a graph.
    :param data: A graph in the PPI dataset (torch_geometric.data.Data).
    :param train_ratio: Fraction of nodes to use for training.
    :return: train_mask (torch.Tensor)
    """
    num_nodes = data.x.size(0)  # Total number of nodes
    num_train = int(train_ratio * num_nodes)  # Number of training nodes

    # Randomly permute indices and select train nodes
    perm = torch.randperm(num_nodes)
    train_indices = perm[:num_train]

    # Initialize train_mask
    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    train_mask[train_indices] = True
    return train_mask

# Update train_dataset with train_mask
updated_train_dataset = []
for graph in train_dataset:
    train_mask = create_train_mask(graph, train_ratio=0.1)
    updated_graph = Data(
        x=graph.x,
        edge_index=graph.edge_index,
        y=graph.y,
        train_mask=train_mask
    )
    updated_train_dataset.append(updated_graph)


# print(f"Train_mask example: {updated_train_dataset[0].train_mask}")

train_loader = DataLoader(updated_train_dataset, batch_size=2, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)

# Initialize models
student_model = StudentModel(in_channels=train_dataset.num_features, hidden_channels=512,
                             out_channels=train_dataset.num_classes, num_layers=3, use_jk=True)
teacher_model = TeacherModel(student_model)

# Initialize framework
framework = SemiGNNFrameworkPPI(student_model, teacher_model, device)

# Training loop
for epoch in range(50):
    train_loss = framework.train_step(train_loader)
    val_accuracy = framework.evaluate(val_loader)
    print(f"Epoch {epoch+1:02d}: Train Loss: {train_loss: .4f}, Validation Accuracy: {val_accuracy:.4f}")

# Testing
test_accuracy = framework.evaluate(test_loader)
print(f"Test Accuracy: {test_accuracy:.4f}")


Epoch 01: Train Loss:  0.6065, Validation Accuracy: 0.6821
Epoch 02: Train Loss:  0.5298, Validation Accuracy: 0.7188
Epoch 03: Train Loss:  0.4988, Validation Accuracy: 0.7639
Epoch 04: Train Loss:  0.4762, Validation Accuracy: 0.7793
Epoch 05: Train Loss:  0.4575, Validation Accuracy: 0.7924
Epoch 06: Train Loss:  0.4389, Validation Accuracy: 0.7927
Epoch 07: Train Loss:  0.4203, Validation Accuracy: 0.7972
Epoch 08: Train Loss:  0.4042, Validation Accuracy: 0.8099
Epoch 09: Train Loss:  0.3896, Validation Accuracy: 0.8114
Epoch 10: Train Loss:  0.3740, Validation Accuracy: 0.8061
Epoch 11: Train Loss:  0.3631, Validation Accuracy: 0.8200
Epoch 12: Train Loss:  0.3496, Validation Accuracy: 0.8101
Epoch 13: Train Loss:  0.3449, Validation Accuracy: 0.7966
Epoch 14: Train Loss:  0.3288, Validation Accuracy: 0.8258
Epoch 15: Train Loss:  0.3158, Validation Accuracy: 0.8297
Epoch 16: Train Loss:  0.3090, Validation Accuracy: 0.8237
Epoch 17: Train Loss:  0.2993, Validation Accuracy: 0.83