In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from deepwalk.deepwalk import DeepWalk

In [95]:
class MLP(nn.Module):
    def __init__(self, num_layers, input_feat_dim, position_emb_dim, hidden_dim, output_dim):
        super(MLP, self).__init__()
        self.num_layers = num_layers
        self.input_feat_dim = input_feat_dim
        self.position_emb_dim = position_emb_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        self.dropout = nn.Dropout(p=0.5)
        self.layers = nn.ModuleList()
        self.rsd_encoder = nn.Linear(hidden_dim, hidden_dim)
        self.layers.append(nn.Linear(input_feat_dim + position_emb_dim, hidden_dim))

        for _ in range(num_layers - 2):
            self.layers.append(nn.Linear(hidden_dim, hidden_dim))

        self.layers.append(nn.Linear(hidden_dim, output_dim))
    
    def forward(self, inp):
        for layer in self.layers[:-1]:
            inp = F.relu(layer(inp))
            inp = self.dropout(inp)
        out = self.layers[-1](inp)
        return inp, out
    
    def MLP_RSD(self, mlp_emb):
        return F.relu(self.rsd_encoder(mlp_emb))

In [17]:
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='./Cora', name='Cora')

data = dataset[0]

In [18]:
import networkx as nx

graph_nx = nx.Graph()
graph_nx.add_edges_from(data.edge_index.t().tolist())

In [5]:
deepwalk_model = DeepWalk(graph_nx, walk_length=80, walks_per_vertex=10)
deepwalk_model.train()
position_embeddings = deepwalk_model.get_embeddings()

Read 2M words
Number of words:  2709
Number of labels: 0
Progress: 100.0% words/sec/thread:  126582 lr:  0.000000 avg.loss:  3.680919 ETA:   0h 0m 0s
  return torch.tensor(embeddings)


In [10]:
data.x.shape

torch.Size([2708, 1433])

In [116]:
mlp = MLP(num_layers=3,
          input_feat_dim=data.x.shape[1],
          position_emb_dim=position_embeddings.shape[1],
          hidden_dim=64,
          output_dim=dataset.num_classes)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
mlp = mlp.to(device)
data = data.to(device)
optimizer = torch.optim.Adam(mlp.parameters(), lr=0.01, weight_decay=5e-4)

In [117]:
teacher_emb = torch.load('teacher_outputs/embeddings.pt')
teacher_out = torch.load('teacher_outputs/label_scores.pt')

In [118]:
kl_loss_fn = nn.KLDivLoss(reduction='batchmean')

In [119]:
def pgd_delta(model, X, Y, epsilon=0.01, step_size=0.001, num_iterations=10):
    delta = torch.zeros_like(X, requires_grad=True)
    
    for t in range(num_iterations):
        # Forward pass with current perturbation
        X_adv = X + delta
        _, logits = model(X_adv)
        
        # Calculate loss (negative, since we're maximizing the loss for adversarial examples)
        loss = F.cross_entropy(logits, Y)
        
        # Backward pass to get gradients
        loss.backward()
        
        # Update delta using the sign of gradient (Eq. 6)
        with torch.no_grad():
            delta_grad = delta.grad.data
            delta_update = step_size * torch.sign(delta_grad)
            delta.data = delta.data + delta_update
            
            # Project delta back to epsilon-ball (L_p norm constraint)
            if torch.norm(delta.data, p=2) > epsilon:
                delta.data = epsilon * delta.data / torch.norm(delta.data, p=2)
            
        # Reset gradients for next iteration
        if delta.grad is not None:
            delta.grad.zero_()
    
    return delta.detach()

In [120]:
def train():
    mlp.train()
    optimizer.zero_grad()
    inp = torch.cat([data.x, position_embeddings.to(device)], dim=-1)  
    mlp_emb, mlp_out = mlp(inp)

    # GROUND TRUTH Cross Entropy Loss
    GT_loss = F.cross_entropy(mlp_out[data.train_mask], data.y[data.train_mask])

    # SOFT LABELS KL Divergence Loss
    temp = 1
    SL_Loss = kl_loss_fn(F.log_softmax(mlp_out/temp, dim=1), 
                         F.softmax(teacher_out/temp, dim=1)) * (temp**2)
    
    # Representational Similarity Distillation Loss
    teacher_mat = teacher_emb @ teacher_emb.t()
    encoded_mlp_meb = mlp.MLP_RSD(mlp_emb)
    mlp_mat = encoded_mlp_meb @ encoded_mlp_meb.t()
    RSD_Loss = torch.mean((mlp_mat - teacher_mat) ** 2)

    # Adversarial Feature Augmentation Loss
    inp = torch.cat([data.x, position_embeddings.to(device)], dim=-1)  
    delta = pgd_delta(mlp, inp, data.y, epsilon=0.01, step_size=0.001, num_iterations=10)
    _, mlp_out = mlp(inp + delta)
    ADV_loss = F.cross_entropy(mlp_out[data.train_mask], data.y[data.train_mask]) + F.cross_entropy(mlp_out, F.softmax(teacher_out, dim=1))
    
    loss = 1.0*GT_loss + 0.5*SL_Loss + 0.0*RSD_Loss + 0.0*ADV_loss
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def test():
    mlp.eval()
    inp = torch.cat([data.x, position_embeddings.to(device)], dim=-1)  
    _, out = mlp(inp)
    pred = out.argmax(dim=1)

    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        correct = pred[mask] == data.y[mask]
        acc = int(correct.sum()) / int(mask.sum())
        accs.append(acc)
    return accs  # train_acc, val_acc, test_acc

# Training for 200 epochs
for epoch in range(1, 201):
    loss = train()
    train_acc, val_acc, test_acc = test()
    if epoch % 10 == 0 or epoch == 1:
        print(f'Epoch {epoch:03d}, Loss: {loss:.4f}, '
              f'Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')


Epoch 001, Loss: 2.8204, Train Acc: 0.1429, Val Acc: 0.3160, Test Acc: 0.3210
Epoch 010, Loss: 1.0083, Train Acc: 0.9429, Val Acc: 0.9060, Test Acc: 0.9080
Epoch 020, Loss: 0.7623, Train Acc: 1.0000, Val Acc: 0.9860, Test Acc: 0.9920
Epoch 030, Loss: 0.6638, Train Acc: 1.0000, Val Acc: 1.0000, Test Acc: 0.9990
Epoch 040, Loss: 0.6458, Train Acc: 1.0000, Val Acc: 1.0000, Test Acc: 1.0000
Epoch 050, Loss: 0.6003, Train Acc: 1.0000, Val Acc: 1.0000, Test Acc: 1.0000
Epoch 060, Loss: 0.5426, Train Acc: 1.0000, Val Acc: 1.0000, Test Acc: 1.0000
Epoch 070, Loss: 0.5193, Train Acc: 1.0000, Val Acc: 1.0000, Test Acc: 1.0000
Epoch 080, Loss: 0.5034, Train Acc: 1.0000, Val Acc: 1.0000, Test Acc: 1.0000
Epoch 090, Loss: 0.4855, Train Acc: 1.0000, Val Acc: 1.0000, Test Acc: 1.0000
Epoch 100, Loss: 0.4852, Train Acc: 1.0000, Val Acc: 1.0000, Test Acc: 1.0000
Epoch 110, Loss: 0.4746, Train Acc: 1.0000, Val Acc: 1.0000, Test Acc: 1.0000
Epoch 120, Loss: 0.4791, Train Acc: 1.0000, Val Acc: 1.0000, Tes