In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from deepwalk.deepwalk import DeepWalk

In [2]:
class MLP(nn.Module):
    def __init__(self, num_layers, input_feat_dim, position_emb_dim, hidden_dim, output_dim):
        super(MLP, self).__init__()
        self.num_layers = num_layers
        self.input_feat_dim = input_feat_dim
        self.position_emb_dim = position_emb_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        self.dropout = nn.Dropout(p=0.5)
        self.layers = nn.ModuleList()
        self.rsd_encoder = nn.Linear(hidden_dim, hidden_dim)
        self.layers.append(nn.Linear(input_feat_dim + position_emb_dim, hidden_dim))

        for _ in range(num_layers - 2):
            self.layers.append(nn.Linear(hidden_dim, hidden_dim))

        self.layers.append(nn.Linear(hidden_dim, output_dim))
    
    def forward(self, inp):
        for layer in self.layers[:-1]:
            inp = F.relu(layer(inp))
            inp = self.dropout(inp)
        out = self.layers[-1](inp)
        return inp, out
    
    def MLP_RSD(self, mlp_emb):
        return F.relu(self.rsd_encoder(mlp_emb))

In [3]:
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='./Cora', name='Cora')

data = dataset[0]

In [4]:
data.x.shape

torch.Size([2708, 1433])

torch_geometric.data.data.Data

In [5]:
sum(data.test_mask), sum(data.val_mask), sum(data.train_mask)

(tensor(1000), tensor(500), tensor(140))

In [6]:
l = (list(data.y))
min(l), max(l)

(tensor(0), tensor(6))

In [7]:
import networkx as nx

graph_nx = nx.Graph()
graph_nx.add_edges_from(data.edge_index.t().tolist())

In [8]:
deepwalk_model = DeepWalk(graph_nx, walk_length=80, walks_per_vertex=10)
deepwalk_model.train()
position_embeddings = deepwalk_model.get_embeddings()

Read 2M words
Number of words:  2709
Number of labels: 0
Progress: 100.0% words/sec/thread:  100742 lr:  0.000000 avg.loss:  3.680137 ETA:   0h 0m 0s
  return torch.tensor(embeddings)


In [10]:
data.x.shape

torch.Size([2708, 1433])

In [136]:
mlp = MLP(num_layers=3,
          input_feat_dim=data.x.shape[1],
          position_emb_dim=position_embeddings.shape[1],
          hidden_dim=64,
          output_dim=dataset.num_classes)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
mlp = mlp.to(device)
data = data.to(device)
optimizer = torch.optim.Adam(mlp.parameters(), lr=0.01, weight_decay=5e-4)

In [137]:
teacher_emb = torch.load('teacher_outputs/embeddings.pt')
teacher_out = torch.load('teacher_outputs/label_scores.pt')

In [138]:
kl_loss_fn = nn.KLDivLoss(reduction='batchmean')

In [139]:
def pgd_delta(model, feats, labels, train_mask):
    iters = 5
    eps = 0.05
    alpha = eps / 4

    # init
    delta = torch.rand(feats.shape) * eps * 2 - eps
    delta = delta.to(feats.device)
    delta = torch.nn.Parameter(delta)

    for i in range(iters):
        p_feats = feats + delta

        _, logits = model(p_feats)
        # out = logits.log_softmax(dim=1)

        loss = F.cross_entropy(logits[train_mask], labels[train_mask])
        loss.backward()

        # delta update
        delta.data = delta.data + alpha * delta.grad.sign()
        delta.grad = None
        delta.data = torch.clamp(delta.data, min=-eps, max=eps)

    output = delta.detach()
    return output

In [140]:
def train():
    mlp.train()
    optimizer.zero_grad()
    inp = torch.cat([data.x, position_embeddings.to(device)], dim=-1)  
    mlp_emb, mlp_out = mlp(inp)

    # GROUND TRUTH Cross Entropy Loss
    GT_loss = F.cross_entropy(mlp_out[data.train_mask], data.y[data.train_mask])

    # SOFT LABELS KL Divergence Loss
    temp = 1
    SL_Loss = kl_loss_fn(F.log_softmax(mlp_out/temp, dim=1), 
                         F.softmax(teacher_out/temp, dim=1)) * (temp**2)
    
    # Representational Similarity Distillation Loss
    teacher_mat = teacher_emb @ teacher_emb.t()
    encoded_mlp_meb = mlp.MLP_RSD(mlp_emb)
    mlp_mat = encoded_mlp_meb @ encoded_mlp_meb.t()
    RSD_Loss = torch.mean((mlp_mat - teacher_mat) ** 2)

    # Adversarial Feature Augmentation Loss
    inp = torch.cat([data.x, position_embeddings.to(device)], dim=-1)  
    delta = pgd_delta(mlp, inp, data.y, data.train_mask)
    _, mlp_out = mlp(inp + delta)
    ADV_loss = F.cross_entropy(mlp_out[data.train_mask], data.y[data.train_mask]) + F.cross_entropy(mlp_out, F.softmax(teacher_out, dim=1))
    
    loss = 1.0*GT_loss + 0.5*SL_Loss + 0.1*RSD_Loss + 0.3*ADV_loss
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def test():
    mlp.eval()
    inp = torch.cat([data.x, position_embeddings.to(device)], dim=-1)  
    _, out = mlp(inp)
    pred = out.argmax(dim=1)

    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        correct = pred[mask] == data.y[mask]
        acc = int(correct.sum()) / int(mask.sum())
        accs.append(acc)
    return accs  # train_acc, val_acc, test_acc

# Training for 200 epochs
for epoch in range(1, 201):
    loss = train()
    train_acc, val_acc, test_acc = test()
    if epoch % 10 == 0 or epoch == 1:
        print(f'Epoch {epoch:03d}, Loss: {loss:.4f}, '
              f'Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')


Epoch 001, Loss: 53.3479, Train Acc: 0.3929, Val Acc: 0.3440, Test Acc: 0.3470
Epoch 010, Loss: 48.9328, Train Acc: 0.4429, Val Acc: 0.1360, Test Acc: 0.1530
Epoch 020, Loss: 29.8446, Train Acc: 0.7714, Val Acc: 0.5840, Test Acc: 0.5850
Epoch 030, Loss: 23.8138, Train Acc: 0.9286, Val Acc: 0.6480, Test Acc: 0.6810
Epoch 040, Loss: 19.9465, Train Acc: 0.9786, Val Acc: 0.7080, Test Acc: 0.7190
Epoch 050, Loss: 17.0074, Train Acc: 0.9929, Val Acc: 0.7500, Test Acc: 0.7460
Epoch 060, Loss: 15.0716, Train Acc: 1.0000, Val Acc: 0.7600, Test Acc: 0.7630
Epoch 070, Loss: 13.4363, Train Acc: 1.0000, Val Acc: 0.7640, Test Acc: 0.7840
Epoch 080, Loss: 12.2490, Train Acc: 1.0000, Val Acc: 0.7780, Test Acc: 0.7990
Epoch 090, Loss: 11.4245, Train Acc: 1.0000, Val Acc: 0.7680, Test Acc: 0.7970
Epoch 100, Loss: 10.6364, Train Acc: 1.0000, Val Acc: 0.7760, Test Acc: 0.7890
Epoch 110, Loss: 9.7586, Train Acc: 1.0000, Val Acc: 0.7820, Test Acc: 0.7910
Epoch 120, Loss: 9.7697, Train Acc: 1.0000, Val Acc: 

In [125]:
data.y

tensor([3, 4, 4,  ..., 3, 3, 3])