## GCN - Attribute Masking

In [22]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
from sklearn.metrics import accuracy_score
import numpy as np

def attribute_masking(features, mask_rate=0.15):
    mask_indices = np.random.choice(features.shape[0], int(features.shape[0]*mask_rate), replace=False)
    masked_features = features.clone()
    masked_features[mask_indices] = 0
    
    return masked_features, mask_indices

class GCNSelfSupervised(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GCNSelfSupervised, self).__init__()
        self.conv1 = GCNConv(in_channels, 2*in_channels)
        self.conv2 = GCNConv(2*in_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x) 
        x = self.conv2(x, edge_index)
        return x

class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = self.lin(x)
        return F.log_softmax(x, dim=1)


cora_dataset = Planetoid(root='/tmp/Cora', name='Cora')
citeseer_dataset = Planetoid(root='/tmp/CiteSeer', name='CiteSeer')

cora_data = cora_dataset[0]
citeseer_data = citeseer_dataset[0]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cora_data = cora_data.to(device)  
citeseer_data = citeseer_data.to(device)

# Attribute Masking Pretraining on Cora
model_pretrain = GCNSelfSupervised(cora_dataset.num_node_features, cora_dataset.num_node_features).to(device)
criterion = torch.nn.MSELoss()  
optimizer = torch.optim.Adam(model_pretrain.parameters(), lr=0.01)

for epoch in range(50):
    model_pretrain.train()
    masked_features, mask_indices = attribute_masking(cora_data.x)
    masked_features = masked_features.to(device)
    mask_indices = torch.tensor(mask_indices).to(device)

    optimizer.zero_grad() 
    out = model_pretrain(masked_features, cora_data.edge_index)
    loss = criterion(out[mask_indices], cora_data.x[mask_indices])  
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f"Pretraining Epoch {epoch}, Loss: {loss.item()}")

# Finetuning on CiteSeer with pretrained model
model_pretrained = GCN(citeseer_dataset.num_node_features, cora_dataset.num_node_features, citeseer_dataset.num_classes).to(device)

pretrained_dict = model_pretrain.state_dict()
model_dict = model_pretrained.state_dict()

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and v.size() == model_dict[k].size()}
model_dict.update(pretrained_dict)
model_pretrained.load_state_dict(model_dict)

optimizer_pretrained = torch.optim.Adam(model_pretrained.parameters(), lr=0.0001, weight_decay=5e-4)

# Training on CiteSeer without pretraining (control model)
model_control = GCN(citeseer_dataset.num_node_features, cora_dataset.num_node_features, citeseer_dataset.num_classes).to(device)
optimizer_control = torch.optim.Adam(model_control.parameters(), lr=0.0001, weight_decay=5e-4)

def train(model, optimizer):
    model.train()
    optimizer.zero_grad()
    out = model(citeseer_data.x, citeseer_data.edge_index)  
    loss = F.nll_loss(out[citeseer_data.train_mask], citeseer_data.y[citeseer_data.train_mask]) 
    loss.backward()
    optimizer.step()
    return loss.item()

def test(model):
    model.eval()
    _, pred = model(citeseer_data.x, citeseer_data.edge_index).max(dim=1)
    correct = int(pred[citeseer_data.test_mask].eq(citeseer_data.y[citeseer_data.test_mask]).sum().item())
    acc = correct / int(citeseer_data.test_mask.sum()) 
    return acc

for epoch in range(30):
    loss_pretrained = train(model_pretrained, optimizer_pretrained)
    loss_control = train(model_control, optimizer_control)
    
    if epoch % 10 == 0:
        print(f'Finetuning Epoch: {epoch:03d}, Pretrained Loss: {loss_pretrained:.4f}, Control Loss: {loss_control:.4f}')
        test_acc_pretrained = test(model_pretrained)
        test_acc_control = test(model_control)
        print(f'Pretrained Test Accuracy: {test_acc_pretrained:.4f}, Control Test Accuracy: {test_acc_control:.4f}')

Pretraining Epoch 0, Loss: 0.013630217872560024
Pretraining Epoch 10, Loss: 0.01294267550110817
Pretraining Epoch 20, Loss: 0.01275724545121193
Pretraining Epoch 30, Loss: 0.01226820144802332
Pretraining Epoch 40, Loss: 0.01170811802148819
Finetuning Epoch: 000, Pretrained Loss: 1.7914, Control Loss: 1.7887
Pretrained Test Accuracy: 0.2530, Control Test Accuracy: 0.2140
Finetuning Epoch: 010, Pretrained Loss: 1.6684, Control Loss: 1.6822
Pretrained Test Accuracy: 0.5970, Control Test Accuracy: 0.5150
Finetuning Epoch: 020, Pretrained Loss: 1.5028, Control Loss: 1.5205
Pretrained Test Accuracy: 0.6320, Control Test Accuracy: 0.6050


## GIN - Attribute Masking

In [21]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GINConv
from sklearn.metrics import accuracy_score
import numpy as np

def attribute_masking(features, mask_rate=0.15):
    mask_indices = np.random.choice(features.shape[0], int(features.shape[0]*mask_rate), replace=False)
    masked_features = features.clone()
    masked_features[mask_indices] = 0
    
    return masked_features, mask_indices

class GINSelfSupervised(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GINSelfSupervised, self).__init__()
        self.mlp1 = torch.nn.Sequential(
            torch.nn.Linear(in_channels, 2*in_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(2*in_channels, 2*in_channels)
        )
        self.conv1 = GINConv(self.mlp1)
        
        self.mlp2 = torch.nn.Sequential(
            torch.nn.Linear(2*in_channels, out_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(out_channels, out_channels)
        )
        self.conv2 = GINConv(self.mlp2)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = self.conv2(x, edge_index)
        return x

class GIN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GIN, self).__init__()
        self.mlp1 = torch.nn.Sequential(
            torch.nn.Linear(in_channels, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, hidden_channels)
        )
        self.conv1 = GINConv(self.mlp1)
        
        self.mlp2 = torch.nn.Sequential(
            torch.nn.Linear(hidden_channels, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, hidden_channels)
        )
        self.conv2 = GINConv(self.mlp2)
        
        self.lin = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = self.lin(x)
        return F.log_softmax(x, dim=1)


cora_dataset = Planetoid(root='/tmp/Cora', name='Cora')
citeseer_dataset = Planetoid(root='/tmp/CiteSeer', name='CiteSeer')

cora_data = cora_dataset[0]
citeseer_data = citeseer_dataset[0]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cora_data = cora_data.to(device)  
citeseer_data = citeseer_data.to(device)

# Attribute Masking Pretraining on Cora
model_pretrain = GINSelfSupervised(cora_dataset.num_node_features, cora_dataset.num_node_features).to(device)
criterion = torch.nn.MSELoss()  
optimizer = torch.optim.Adam(model_pretrain.parameters(), lr=0.0001)

for epoch in range(50):
    model_pretrain.train()
    masked_features, mask_indices = attribute_masking(cora_data.x)
    masked_features = masked_features.to(device)
    mask_indices = torch.tensor(mask_indices).to(device)

    optimizer.zero_grad() 
    out = model_pretrain(masked_features, cora_data.edge_index)
    loss = criterion(out[mask_indices], cora_data.x[mask_indices])  
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f"Pretraining Epoch {epoch}, Loss: {loss.item()}")

# Finetuning on CiteSeer with pretrained model
model_pretrained = GIN(citeseer_dataset.num_node_features, cora_dataset.num_node_features, citeseer_dataset.num_classes).to(device)

pretrained_dict = model_pretrain.state_dict()
model_dict = model_pretrained.state_dict()

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and v.size() == model_dict[k].size()}
model_dict.update(pretrained_dict)
model_pretrained.load_state_dict(model_dict)

optimizer_pretrained = torch.optim.Adam(model_pretrained.parameters(), lr=0.001, weight_decay=5e-4)

# Training on CiteSeer without pretraining (control model)
model_control = GIN(citeseer_dataset.num_node_features, cora_dataset.num_node_features, citeseer_dataset.num_classes).to(device)
optimizer_control = torch.optim.Adam(model_control.parameters(), lr=0.001, weight_decay=5e-4)

def train(model, optimizer):
    model.train()
    optimizer.zero_grad()
    out = model(citeseer_data.x, citeseer_data.edge_index)  
    loss = F.nll_loss(out[citeseer_data.train_mask], citeseer_data.y[citeseer_data.train_mask]) 
    loss.backward()
    optimizer.step()
    return loss.item()

def test(model):
    model.eval()
    _, pred = model(citeseer_data.x, citeseer_data.edge_index).max(dim=1)
    correct = int(pred[citeseer_data.test_mask].eq(citeseer_data.y[citeseer_data.test_mask]).sum().item())
    acc = correct / int(citeseer_data.test_mask.sum()) 
    return acc

for epoch in range(30):
    loss_pretrained = train(model_pretrained, optimizer_pretrained)
    loss_control = train(model_control, optimizer_control)
    
    if epoch % 10 == 0:
        print(f'Finetuning Epoch: {epoch:03d}, Pretrained Loss: {loss_pretrained:.4f}, Control Loss: {loss_control:.4f}')
        test_acc_pretrained = test(model_pretrained)
        test_acc_control = test(model_control)
        print(f'Pretrained Test Accuracy: {test_acc_pretrained:.4f}, Control Test Accuracy: {test_acc_control:.4f}')

Pretraining Epoch 0, Loss: 0.03316158056259155
Pretraining Epoch 10, Loss: 0.01279956754297018
Pretraining Epoch 20, Loss: 0.012602251023054123
Pretraining Epoch 30, Loss: 0.012775646522641182
Pretraining Epoch 40, Loss: 0.012420615181326866
Finetuning Epoch: 000, Pretrained Loss: 1.7845, Control Loss: 1.7865
Pretrained Test Accuracy: 0.1810, Control Test Accuracy: 0.1810
Finetuning Epoch: 010, Pretrained Loss: 0.4317, Control Loss: 0.2197
Pretrained Test Accuracy: 0.5550, Control Test Accuracy: 0.6510
Finetuning Epoch: 020, Pretrained Loss: 0.0093, Control Loss: 0.1456
Pretrained Test Accuracy: 0.5880, Control Test Accuracy: 0.6540


## GCN - Context Prediction

In [11]:
#pretty slow on my machine, so i run it on colab
#Pretrained Test Accuracy: 0.5860
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
from torch_geometric.utils import k_hop_subgraph, to_undirected
from sklearn.metrics import accuracy_score
import numpy as np
from torch_geometric.utils import subgraph

def context_prediction(data, K=2, r1=1, r2=4):
    edge_index = to_undirected(data.edge_index) 
    
    contexts = []
    neighborhoods = []
    for i in range(data.x.size(0)):
        neighborhood, _, _, _ = k_hop_subgraph(i, K, edge_index)
        neighborhood = torch.unique(neighborhood)
        neighborhoods.append(neighborhood)
        
        context_mask = torch.zeros(data.x.size(0), dtype=torch.bool)
        for j in range(r1, r2+1):
            context, _, _, _ = k_hop_subgraph(i, j, edge_index)
            context = torch.unique(context)
            context_mask[context] = 1
        context_mask[neighborhood] = 0
        context = torch.where(context_mask)[0]
        contexts.append(context)
    
    return neighborhoods, contexts

class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return x

class GCNContextPrediction(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GCNContextPrediction, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
        self.context_gnn = GCN(in_channels, hidden_channels, out_channels)

    def forward(self, x, edge_index, neighborhoods, contexts):
        neighborhood_embeddings = []
        context_embeddings = []
        for neighborhood, context in zip(neighborhoods, contexts):
            neighborhood_edge_index, _ = subgraph(neighborhood, edge_index, relabel_nodes=True)
            neighborhood_embedding = self.forward_once(x[neighborhood], neighborhood_edge_index)
            if neighborhood_embedding.size(0) > 0:
                neighborhood_embedding = neighborhood_embedding.mean(dim=0)
            else:
                neighborhood_embedding = torch.zeros(neighborhood_embedding.size(1)).to(device)
            
            context_edge_index, _ = subgraph(context, edge_index, relabel_nodes=True)
            context_embedding = self.context_gnn(x[context], context_edge_index)
            if context_embedding.size(0) > 0:  
                context_embedding = context_embedding.mean(dim=0)
            else:
                context_embedding = torch.zeros(context_embedding.size(1)).to(device)
            
            neighborhood_embeddings.append(neighborhood_embedding)
            context_embeddings.append(context_embedding)
        
        neighborhood_embeddings = torch.stack(neighborhood_embeddings)
        context_embeddings = torch.stack(context_embeddings)
        return neighborhood_embeddings, context_embeddings

    def forward_once(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x
    
    
class GCNClassifier(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GCNClassifier, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = self.lin(x)
        return F.log_softmax(x, dim=1)

cora_dataset = Planetoid(root='/tmp/Cora', name='Cora')
citeseer_dataset = Planetoid(root='/tmp/CiteSeer', name='CiteSeer')

cora_data = cora_dataset[0]
citeseer_data = citeseer_dataset[0]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cora_data = cora_data.to(device)  
citeseer_data = citeseer_data.to(device)

# Context Prediction Pretraining on Cora

model_pretrain = GCNContextPrediction(cora_dataset.num_node_features, 128, 128).to(device)  
optimizer = torch.optim.Adam(model_pretrain.parameters(), lr=0.0001)

cosine_similarity = torch.nn.CosineSimilarity(dim=1)

for epoch in range(50):
    model_pretrain.train()
    neighborhoods, contexts = context_prediction(cora_data)

    optimizer.zero_grad() 
    neighborhood_embeddings, context_embeddings = model_pretrain(cora_data.x, cora_data.edge_index, neighborhoods, contexts)
    loss = 1 - cosine_similarity(neighborhood_embeddings, context_embeddings).mean()
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f"Pretraining Epoch {epoch}, Loss: {loss.item()}")

# Finetuning on CiteSeer with pretrained model
model_pretrained = GCN(citeseer_dataset.num_node_features, 256, citeseer_dataset.num_classes).to(device)

pretrained_dict = model_pretrain.state_dict()
model_dict = model_pretrained.state_dict()

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and v.size() == model_dict[k].size()}
model_dict.update(pretrained_dict)
model_pretrained.load_state_dict(model_dict)

optimizer_pretrained = torch.optim.Adam(model_pretrained.parameters(), lr=0.001, weight_decay=5e-4)

# Training on CiteSeer without pretraining (control model)
model_control = GCN(citeseer_dataset.num_node_features, 256, citeseer_dataset.num_classes).to(device)
optimizer_control = torch.optim.Adam(model_control.parameters(), lr=0.001, weight_decay=5e-4)

def train(model, optimizer):
    model.train()
    optimizer.zero_grad()
    out = model(citeseer_data.x, citeseer_data.edge_index)  
    loss = F.nll_loss(out[citeseer_data.train_mask], citeseer_data.y[citeseer_data.train_mask]) 
    loss.backward()
    optimizer.step()
    return loss.item()

def test(model):
    model.eval()
    _, pred = model(citeseer_data.x, citeseer_data.edge_index).max(dim=1)
    correct = int(pred[citeseer_data.test_mask].eq(citeseer_data.y[citeseer_data.test_mask]).sum().item())
    acc = correct / int(citeseer_data.test_mask.sum()) 
    return acc

for epoch in range(30):
    loss_pretrained = train(model_pretrained, optimizer_pretrained)
    loss_control = train(model_control, optimizer_control)
    
    if epoch % 10 == 0:
        print(f'Finetuning Epoch: {epoch:03d}, Pretrained Loss: {loss_pretrained:.4f}, Control Loss: {loss_control:.4f}')
        test_acc_pretrained = test(model_pretrained)
        test_acc_control = test(model_control)
        print(f'Pretrained Test Accuracy: {test_acc_pretrained:.4f}, Control Test Accuracy: {test_acc_control:.4f}')

KeyboardInterrupt: 

## GIN - Context Prediction

In [24]:

import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
from torch_geometric.utils import k_hop_subgraph, to_undirected
from sklearn.metrics import accuracy_score
import numpy as np
from torch_geometric.utils import subgraph

def context_prediction(data, K=2, r1=1, r2=4):
    edge_index = to_undirected(data.edge_index) 
    
    contexts = []
    neighborhoods = []
    for i in range(data.x.size(0)):
        neighborhood, _, _, _ = k_hop_subgraph(i, K, edge_index)
        neighborhood = torch.unique(neighborhood)
        neighborhoods.append(neighborhood)
        
        context_mask = torch.zeros(data.x.size(0), dtype=torch.bool)
        for j in range(r1, r2+1):
            context, _, _, _ = k_hop_subgraph(i, j, edge_index)
            context = torch.unique(context)
            context_mask[context] = 1
        context_mask[neighborhood] = 0
        context = torch.where(context_mask)[0]
        contexts.append(context)
    
    return neighborhoods, contexts

class GIN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GIN, self).__init__()
        self.conv1 = GINConv(torch.nn.Sequential(
            torch.nn.Linear(in_channels, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, hidden_channels),
            torch.nn.ReLU()
        ))
        self.conv2 = GINConv(torch.nn.Sequential(
            torch.nn.Linear(hidden_channels, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, out_channels),
            torch.nn.ReLU()
        ))

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return x

class GINContextPrediction(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GINContextPrediction, self).__init__()
        self.conv1 = GINConv(torch.nn.Sequential(
            torch.nn.Linear(in_channels, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, hidden_channels),
            torch.nn.ReLU()
        ))
        self.conv2 = GINConv(torch.nn.Sequential(
            torch.nn.Linear(hidden_channels, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, out_channels),
            torch.nn.ReLU()
        ))
        self.context_gnn = GIN(in_channels, hidden_channels, out_channels)

    def forward(self, x, edge_index, neighborhoods, contexts):
        neighborhood_embeddings = []
        context_embeddings = []
        for neighborhood, context in zip(neighborhoods, contexts):
            neighborhood_edge_index, _ = subgraph(neighborhood, edge_index, relabel_nodes=True)
            neighborhood_embedding = self.forward_once(x[neighborhood], neighborhood_edge_index)
            if neighborhood_embedding.size(0) > 0:
                neighborhood_embedding = neighborhood_embedding.mean(dim=0)
            else:
                neighborhood_embedding = torch.zeros(neighborhood_embedding.size(1)).to(device)
            
            context_edge_index, _ = subgraph(context, edge_index, relabel_nodes=True)
            context_embedding = self.context_gnn(x[context], context_edge_index)
            if context_embedding.size(0) > 0:  
                context_embedding = context_embedding.mean(dim=0)
            else:
                context_embedding = torch.zeros(context_embedding.size(1)).to(device)
            
            neighborhood_embeddings.append(neighborhood_embedding)
            context_embeddings.append(context_embedding)
        
        neighborhood_embeddings = torch.stack(neighborhood_embeddings)
        context_embeddings = torch.stack(context_embeddings)
        return neighborhood_embeddings, context_embeddings
    
    def forward_once(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x
    
class GINClassifier(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GINClassifier, self).__init__()
        self.conv1 = GINConv(torch.nn.Sequential(
            torch.nn.Linear(in_channels, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, hidden_channels),
            torch.nn.ReLU()
        ))
        self.conv2 = GINConv(torch.nn.Sequential(
            torch.nn.Linear(hidden_channels, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, hidden_channels),
            torch.nn.ReLU()
        ))
        self.lin = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        x = self.lin(x)
        return F.log_softmax(x, dim=1)

cora_dataset = Planetoid(root='/tmp/Cora', name='Cora')
citeseer_dataset = Planetoid(root='/tmp/CiteSeer', name='CiteSeer')

cora_data = cora_dataset[0]
citeseer_data = citeseer_dataset[0]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cora_data = cora_data.to(device)  
citeseer_data = citeseer_data.to(device)

# Context Prediction Pretraining on Cora

model_pretrain = GINContextPrediction(cora_dataset.num_node_features, 128, 128).to(device)
optimizer = torch.optim.Adam(model_pretrain.parameters(), lr=0.0001)

cosine_similarity = torch.nn.CosineSimilarity(dim=1)

for epoch in range(50):
    model_pretrain.train()
    neighborhoods, contexts = context_prediction(cora_data)

    optimizer.zero_grad() 
    neighborhood_embeddings, context_embeddings = model_pretrain(cora_data.x, cora_data.edge_index, neighborhoods, contexts)
    loss = 1 - cosine_similarity(neighborhood_embeddings, context_embeddings).mean()
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f"Pretraining Epoch {epoch}, Loss: {loss.item()}")

# Finetuning on CiteSeer with pretrained model
model_pretrained = GINClassifier(citeseer_dataset.num_node_features, 256, citeseer_dataset.num_classes).to(device)

pretrained_dict = model_pretrain.state_dict()
model_dict = model_pretrained.state_dict()

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and v.size() == model_dict[k].size()}
model_dict.update(pretrained_dict)
model_pretrained.load_state_dict(model_dict)

optimizer_pretrained = torch.optim.Adam(model_pretrained.parameters(), lr=0.001, weight_decay=5e-4)

# Training on CiteSeer without pretraining (control model)
model_control = GINClassifier(citeseer_dataset.num_node_features, 256, citeseer_dataset.num_classes).to(device)
optimizer_control = torch.optim.Adam(model_control.parameters(), lr=0.001, weight_decay=5e-4)

def train(model, optimizer):
    model.train()
    optimizer.zero_grad()
    out = model(citeseer_data.x, citeseer_data.edge_index)  
    loss = F.nll_loss(out[citeseer_data.train_mask], citeseer_data.y[citeseer_data.train_mask]) 
    loss.backward()
    optimizer.step()
    return loss.item()

def test(model):
    model.eval()
    _, pred = model(citeseer_data.x, citeseer_data.edge_index).max(dim=1)
    correct = int(pred[citeseer_data.test_mask].eq(citeseer_data.y[citeseer_data.test_mask]).sum().item())
    acc = correct / int(citeseer_data.test_mask.sum()) 
    return acc

for epoch in range(30):
    loss_pretrained = train(model_pretrained, optimizer_pretrained)
    loss_control = train(model_control, optimizer_control)
    
    if epoch % 10 == 0:
        print(f'Finetuning Epoch: {epoch:03d}, Pretrained Loss: {loss_pretrained:.4f}, Control Loss: {loss_control:.4f}')
        test_acc_pretrained = test(model_pretrained)
        test_acc_control = test(model_control)
        print(f'Pretrained Test Accuracy: {test_acc_pretrained:.4f}, Control Test Accuracy: {test_acc_control:.4f}')

Pretraining Epoch 0, Loss: 0.6880784630775452


KeyboardInterrupt: 

## 2 dataset for pretraining

In [20]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
from sklearn.metrics import accuracy_score
import numpy as np

def attribute_masking(features, mask_rate=0.15):
    mask_indices = np.random.choice(features.shape[0], int(features.shape[0]*mask_rate), replace=False)
    masked_features = features.clone()
    masked_features[mask_indices] = 0
    
    return masked_features, mask_indices

class GCNSelfSupervised(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GCNSelfSupervised, self).__init__()
        self.conv1 = GCNConv(in_channels, 2*in_channels)
        self.conv2 = GCNConv(2*in_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x) 
        x = self.conv2(x, edge_index)
        return x

class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = self.lin(x)
        return F.log_softmax(x, dim=1)


cora_dataset = Planetoid(root='/tmp/Cora', name='Cora')
citeseer_dataset = Planetoid(root='/tmp/CiteSeer', name='CiteSeer')

pubmed_dataset = Planetoid(root='/tmp/PubMed', name='PubMed')
pubmed_data = pubmed_dataset[0]
pubmed_data = pubmed_data.to(device)

cora_data = cora_dataset[0]
citeseer_data = citeseer_dataset[0]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cora_data = cora_data.to(device)  
citeseer_data = citeseer_data.to(device)

# Attribute Masking Pretraining on Cora
model_pretrain = GCNSelfSupervised(cora_dataset.num_node_features, cora_dataset.num_node_features).to(device)
criterion = torch.nn.MSELoss()  
optimizer = torch.optim.Adam(model_pretrain.parameters(), lr=0.01)

model_pretrain_pubmed = GCNSelfSupervised(pubmed_dataset.num_node_features, pubmed_dataset.num_node_features).to(device)
optimizer_pubmed = torch.optim.Adam(model_pretrain_pubmed.parameters(), lr=0.01)

for epoch in range(50):
    model_pretrain.train()
    masked_features, mask_indices = attribute_masking(cora_data.x)
    masked_features = masked_features.to(device)
    mask_indices = torch.tensor(mask_indices).to(device)

    optimizer.zero_grad() 
    out = model_pretrain(masked_features, cora_data.edge_index)
    loss = criterion(out[mask_indices], cora_data.x[mask_indices])  
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f"Pretraining Epoch {epoch}, Loss: {loss.item()}")

for epoch in range(50):
    model_pretrain_pubmed.train()
    masked_features, mask_indices = attribute_masking(pubmed_data.x)
    masked_features = masked_features.to(device)
    mask_indices = torch.tensor(mask_indices).to(device)

    optimizer_pubmed.zero_grad()
    out = model_pretrain_pubmed(masked_features, pubmed_data.edge_index)
    loss = criterion(out[mask_indices], pubmed_data.x[mask_indices])
    loss.backward()
    optimizer_pubmed.step()

    if epoch % 10 == 0:
        print(f"PubMed Pretraining Epoch {epoch}, Loss: {loss.item()}")
        
model_pretrained = GCN(citeseer_dataset.num_node_features, max(cora_dataset.num_node_features, pubmed_dataset.num_node_features), citeseer_dataset.num_classes).to(device)

pretrained_dict_cora = model_pretrain.state_dict()
pretrained_dict_pubmed = model_pretrain_pubmed.state_dict()
model_dict = model_pretrained.state_dict()

pretrained_dict_cora = {k: v for k, v in pretrained_dict_cora.items() if k in model_dict and v.size() == model_dict[k].size()}
pretrained_dict_pubmed = {k: v for k, v in pretrained_dict_pubmed.items() if k in model_dict and v.size() == model_dict[k].size()}

model_dict.update(pretrained_dict_cora)
model_dict.update(pretrained_dict_pubmed)
model_pretrained.load_state_dict(model_dict)

optimizer_pretrained = torch.optim.Adam(model_pretrained.parameters(), lr=0.0001, weight_decay=5e-4)

# Training on CiteSeer without pretraining (control model)
model_control = GCN(citeseer_dataset.num_node_features, cora_dataset.num_node_features, citeseer_dataset.num_classes).to(device)
optimizer_control = torch.optim.Adam(model_control.parameters(), lr=0.0001, weight_decay=5e-4)

def train(model, optimizer):
    model.train()
    optimizer.zero_grad()
    out = model(citeseer_data.x, citeseer_data.edge_index)  
    loss = F.nll_loss(out[citeseer_data.train_mask], citeseer_data.y[citeseer_data.train_mask]) 
    loss.backward()
    optimizer.step()
    return loss.item()

def test(model):
    model.eval()
    _, pred = model(citeseer_data.x, citeseer_data.edge_index).max(dim=1)
    correct = int(pred[citeseer_data.test_mask].eq(citeseer_data.y[citeseer_data.test_mask]).sum().item())
    acc = correct / int(citeseer_data.test_mask.sum()) 
    return acc

for epoch in range(30):
    loss_pretrained = train(model_pretrained, optimizer_pretrained)
    loss_control = train(model_control, optimizer_control)
    
    if epoch % 10 == 0:
        print(f'Finetuning Epoch: {epoch:03d}, Pretrained Loss: {loss_pretrained:.4f}, Control Loss: {loss_control:.4f}')
        test_acc_pretrained = test(model_pretrained)
        test_acc_control = test(model_control)
        print(f'Pretrained Test Accuracy: {test_acc_pretrained:.4f}, Control Test Accuracy: {test_acc_control:.4f}')

Pretraining Epoch 0, Loss: 0.013583128340542316
Pretraining Epoch 10, Loss: 0.0132874371483922
Pretraining Epoch 20, Loss: 0.012566559948027134
Pretraining Epoch 30, Loss: 0.012245085090398788
Pretraining Epoch 40, Loss: 0.01204062718898058
PubMed Pretraining Epoch 0, Loss: 0.00034497419255785644
PubMed Pretraining Epoch 10, Loss: 0.0003179232880938798
PubMed Pretraining Epoch 20, Loss: 0.0003094275598414242
PubMed Pretraining Epoch 30, Loss: 0.0003114673891104758
PubMed Pretraining Epoch 40, Loss: 0.00031147923436947167
Finetuning Epoch: 000, Pretrained Loss: 1.7920, Control Loss: 1.7942
Pretrained Test Accuracy: 0.0910, Control Test Accuracy: 0.1780
Finetuning Epoch: 010, Pretrained Loss: 1.6731, Control Loss: 1.6911
Pretrained Test Accuracy: 0.4180, Control Test Accuracy: 0.5200
Finetuning Epoch: 020, Pretrained Loss: 1.5014, Control Loss: 1.5404
Pretrained Test Accuracy: 0.6220, Control Test Accuracy: 0.5870


## Amazon dataset - Attribute Masking

In [19]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Amazon
from torch_geometric.nn import GCNConv
from sklearn.metrics import accuracy_score
import numpy as np

def attribute_masking(features, mask_rate=0.15):
    mask_indices = np.random.choice(features.shape[0], int(features.shape[0]*mask_rate), replace=False)
    masked_features = features.clone()
    masked_features[mask_indices] = 0
    
    return masked_features, mask_indices

class GCNSelfSupervised(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GCNSelfSupervised, self).__init__()
        self.conv1 = GCNConv(in_channels, 2*in_channels)
        self.conv2 = GCNConv(2*in_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x) 
        x = self.conv2(x, edge_index)
        return x

class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = self.lin(x)
        return F.log_softmax(x, dim=1)


computers_dataset = Amazon(root='/tmp/AmazonComputers', name='Computers')
photo_dataset = Amazon(root='/tmp/AmazonPhoto', name='Photo')

computers_data = computers_dataset[0]
photo_data = photo_dataset[0]

train_ratio = 0.6
val_ratio = 0.2
test_ratio = 0.2

num_nodes = photo_data.num_nodes
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
val_mask = torch.zeros(num_nodes, dtype=torch.bool)
test_mask = torch.zeros(num_nodes, dtype=torch.bool)

perm = torch.randperm(num_nodes)
train_index = int(num_nodes * train_ratio)
val_index = int(num_nodes * (train_ratio + val_ratio))

train_mask[perm[:train_index]] = True
val_mask[perm[train_index:val_index]] = True
test_mask[perm[val_index:]] = True

photo_data.train_mask = train_mask
photo_data.val_mask = val_mask
photo_data.test_mask = test_mask

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
computers_data = computers_data.to(device)  
photo_data = photo_data.to(device)

# Attribute Masking Pretraining on Amazon Computers
model_pretrain = GCNSelfSupervised(computers_dataset.num_node_features, computers_dataset.num_node_features).to(device)
criterion = torch.nn.MSELoss()  
optimizer = torch.optim.Adam(model_pretrain.parameters(), lr=0.01)

for epoch in range(50):
    model_pretrain.train()
    masked_features, mask_indices = attribute_masking(computers_data.x)
    masked_features = masked_features.to(device)
    mask_indices = torch.tensor(mask_indices).to(device)

    optimizer.zero_grad() 
    out = model_pretrain(masked_features, computers_data.edge_index)
    loss = criterion(out[mask_indices], computers_data.x[mask_indices])  
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f"Pretraining Epoch {epoch}, Loss: {loss.item()}")

# Finetuning on Amazon Photo with pretrained model
model_pretrained = GCN(photo_dataset.num_node_features, computers_dataset.num_node_features, photo_dataset.num_classes).to(device)

pretrained_dict = model_pretrain.state_dict()
model_dict = model_pretrained.state_dict()

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and v.size() == model_dict[k].size()}
model_dict.update(pretrained_dict)
model_pretrained.load_state_dict(model_dict)

optimizer_pretrained = torch.optim.Adam(model_pretrained.parameters(), lr=0.001, weight_decay=5e-4)

# Training on Amazon Photo without pretraining (control model)
model_control = GCN(photo_dataset.num_node_features, computers_dataset.num_node_features, photo_dataset.num_classes).to(device)
optimizer_control = torch.optim.Adam(model_control.parameters(), lr=0.001, weight_decay=5e-4)

def train(model, optimizer):
    model.train()
    optimizer.zero_grad()
    out = model(photo_data.x, photo_data.edge_index)  
    loss = F.nll_loss(out[photo_data.train_mask], photo_data.y[photo_data.train_mask]) 
    loss.backward()
    optimizer.step()
    return loss.item()

def test(model):
    model.eval()
    _, pred = model(photo_data.x, photo_data.edge_index).max(dim=1)
    correct = int(pred[photo_data.test_mask].eq(photo_data.y[photo_data.test_mask]).sum().item())
    acc = correct / int(photo_data.test_mask.sum()) 
    return acc

for epoch in range(30):
    loss_pretrained = train(model_pretrained, optimizer_pretrained)
    loss_control = train(model_control, optimizer_control)
    
    if epoch % 10 == 0:
        print(f'Finetuning Epoch: {epoch:03d}, Pretrained Loss: {loss_pretrained:.4f}, Control Loss: {loss_control:.4f}')
        test_acc_pretrained = test(model_pretrained)
        test_acc_control = test(model_control)
        print(f'Pretrained Test Accuracy: {test_acc_pretrained:.4f}, Control Test Accuracy: {test_acc_control:.4f}')


Pretraining Epoch 0, Loss: 0.408372163772583
Pretraining Epoch 10, Loss: 0.3597937524318695
Pretraining Epoch 20, Loss: 0.20146320760250092
Pretraining Epoch 30, Loss: 0.20677796006202698
Pretraining Epoch 40, Loss: 0.19507689774036407
Finetuning Epoch: 000, Pretrained Loss: 2.0751, Control Loss: 2.0930
Pretrained Test Accuracy: 0.2242, Control Test Accuracy: 0.2235
Finetuning Epoch: 010, Pretrained Loss: 0.9464, Control Loss: 1.0195
Pretrained Test Accuracy: 0.8229, Control Test Accuracy: 0.8392
Finetuning Epoch: 020, Pretrained Loss: 0.4010, Control Loss: 0.4199
Pretrained Test Accuracy: 0.9098, Control Test Accuracy: 0.9013


## Amazon - Context prediction

In [25]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Amazon
from torch_geometric.nn import GCNConv
from torch_geometric.utils import k_hop_subgraph, to_undirected
from sklearn.metrics import accuracy_score
import numpy as np
from torch_geometric.utils import subgraph

def context_prediction(data, K=2, r1=1, r2=4):
    edge_index = to_undirected(data.edge_index) 
    
    contexts = []
    neighborhoods = []
    for i in range(data.x.size(0)):
        neighborhood, _, _, _ = k_hop_subgraph(i, K, edge_index)
        neighborhood = torch.unique(neighborhood)
        neighborhoods.append(neighborhood)
        
        context_mask = torch.zeros(data.x.size(0), dtype=torch.bool)
        for j in range(r1, r2+1):
            context, _, _, _ = k_hop_subgraph(i, j, edge_index)
            context = torch.unique(context)
            context_mask[context] = 1
        context_mask[neighborhood] = 0
        context = torch.where(context_mask)[0]
        contexts.append(context)
    
    return neighborhoods, contexts

class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return x

class GCNContextPrediction(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GCNContextPrediction, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
        self.context_gnn = GCN(in_channels, hidden_channels, out_channels)

    def forward(self, x, edge_index, neighborhoods, contexts):
        neighborhood_embeddings = []
        context_embeddings = []
        for neighborhood, context in zip(neighborhoods, contexts):
            neighborhood_edge_index, _ = subgraph(neighborhood, edge_index, relabel_nodes=True)
            neighborhood_embedding = self.forward_once(x[neighborhood], neighborhood_edge_index)
            if neighborhood_embedding.size(0) > 0:
                neighborhood_embedding = neighborhood_embedding.mean(dim=0)
            else:
                neighborhood_embedding = torch.zeros(neighborhood_embedding.size(1)).to(device)
            
            context_edge_index, _ = subgraph(context, edge_index, relabel_nodes=True)
            context_embedding = self.context_gnn(x[context], context_edge_index)
            if context_embedding.size(0) > 0:  
                context_embedding = context_embedding.mean(dim=0)
            else:
                context_embedding = torch.zeros(context_embedding.size(1)).to(device)
            
            neighborhood_embeddings.append(neighborhood_embedding)
            context_embeddings.append(context_embedding)
        
        neighborhood_embeddings = torch.stack(neighborhood_embeddings)
        context_embeddings = torch.stack(context_embeddings)
        return neighborhood_embeddings, context_embeddings
    
    def forward_once(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x

class GCNClassifier(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GCNClassifier, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = self.lin(x)
        return F.log_softmax(x, dim=1)

computers_dataset = Amazon(root='/tmp/AmazonComputers', name='Computers')
photo_dataset = Amazon(root='/tmp/AmazonPhoto', name='Photo')

computers_data = computers_dataset[0]
photo_data = photo_dataset[0]

# Create train, validation, and test masks for the Amazon Photo dataset
train_ratio = 0.6
val_ratio = 0.2
test_ratio = 0.2

num_nodes = photo_data.num_nodes
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
val_mask = torch.zeros(num_nodes, dtype=torch.bool)
test_mask = torch.zeros(num_nodes, dtype=torch.bool)

perm = torch.randperm(num_nodes)
train_index = int(num_nodes * train_ratio)
val_index = int(num_nodes * (train_ratio + val_ratio))

train_mask[perm[:train_index]] = True
val_mask[perm[train_index:val_index]] = True
test_mask[perm[val_index:]] = True

photo_data.train_mask = train_mask
photo_data.val_mask = val_mask
photo_data.test_mask = test_mask

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
computers_data = computers_data.to(device)
photo_data = photo_data.to(device)

# Context Prediction Pretraining on Amazon Computers
model_pretrain = GCNContextPrediction(computers_dataset.num_node_features, 128, 128).to(device)  
optimizer = torch.optim.Adam(model_pretrain.parameters(), lr=0.0001)

cosine_similarity = torch.nn.CosineSimilarity(dim=1)

for epoch in range(50):
    model_pretrain.train()
    neighborhoods, contexts = context_prediction(computers_data)

    optimizer.zero_grad() 
    neighborhood_embeddings, context_embeddings = model_pretrain(computers_data.x, computers_data.edge_index, neighborhoods, contexts)
    loss = 1 - cosine_similarity(neighborhood_embeddings, context_embeddings).mean()
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f"Pretraining Epoch {epoch}, Loss: {loss.item()}")

# Finetuning on Amazon Photo with pretrained model
model_pretrained = GCNClassifier(photo_dataset.num_node_features, 256, photo_dataset.num_classes).to(device)

pretrained_dict = model_pretrain.state_dict()
model_dict = model_pretrained.state_dict()

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and v.size() == model_dict[k].size()}
model_dict.update(pretrained_dict)
model_pretrained.load_state_dict(model_dict)

optimizer_pretrained = torch.optim.Adam(model_pretrained.parameters(), lr=0.001, weight_decay=5e-4)

# Training on Amazon Photo without pretraining (control model)
model_control = GCNClassifier(photo_dataset.num_node_features, 256, photo_dataset.num_classes).to(device)
optimizer_control = torch.optim.Adam(model_control.parameters(), lr=0.001, weight_decay=5e-4)

def train(model, optimizer):
    model.train()
    optimizer.zero_grad()
    out = model(photo_data.x, photo_data.edge_index)  
    loss = F.nll_loss(out[photo_data.train_mask], photo_data.y[photo_data.train_mask]) 
    loss.backward()
    optimizer.step()
    return loss.item()

def test(model):
    model.eval()
    _, pred = model(photo_data.x, photo_data.edge_index).max(dim=1)
    correct = int(pred[photo_data.test_mask].eq(photo_data.y[photo_data.test_mask]).sum().item())
    acc = correct / int(photo_data.test_mask.sum()) 
    return acc

for epoch in range(30):
    loss_pretrained = train(model_pretrained, optimizer_pretrained)
    loss_control = train(model_control, optimizer_control)
    
    if epoch % 10 == 0:
        print(f'Finetuning Epoch: {epoch:03d}, Pretrained Loss: {loss_pretrained:.4f}, Control Loss: {loss_control:.4f}')
        test_acc_pretrained = test(model_pretrained)
        test_acc_control = test(model_control)
        print(f'Pretrained Test Accuracy: {test_acc_pretrained:.4f}, Control Test Accuracy: {test_acc_control:.4f}')

KeyboardInterrupt: 