<a href="https://colab.research.google.com/github/Raheelkhan117/Using-Node-Classification-in-GNNs-using-Transfer-Learning/blob/main/MAGsource_MAGtarget.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install torch torchvision torchaudio
!pip install torch-geometric

Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch)
  Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.2.106 (from torch)
  Using cached nvidia_curand_cu12-10.3.2.106-py3-

In [2]:
pip install "git+https://github.com/tqdm/tqdm.git@devel#egg=tqdm"

Collecting tqdm
  Cloning https://github.com/tqdm/tqdm.git (to revision devel) to /tmp/pip-install-5rxv1gd4/tqdm_d461279f517d45f984b37016fa3c84e0
  Running command git clone --filter=blob:none --quiet https://github.com/tqdm/tqdm.git /tmp/pip-install-5rxv1gd4/tqdm_d461279f517d45f984b37016fa3c84e0
  Running command git checkout -b devel --track origin/devel
  Switched to a new branch 'devel'
  Branch 'devel' set up to track remote branch 'devel' from 'origin'.
  Resolved https://github.com/tqdm/tqdm.git to commit 729db6c1b52f44c01b06b2338d0688ab83b00f01
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: tqdm
  Building wheel for tqdm (pyproject.toml) ... [?25l[?25hdone
  Created wheel for tqdm: filename=tqdm-4.66.6.dev1+g729db6c-py3-none-any.whl size=78565 sha256=c7b3ede9b6d80d58ae8cbe71556df693a2ffaf901a1763b688ca3d88fca4dba2


In [3]:
pip install ogb

Collecting ogb
  Downloading ogb-1.3.6-py3-none-any.whl.metadata (6.2 kB)
Collecting outdated>=0.2.0 (from ogb)
  Downloading outdated-0.2.2-py2.py3-none-any.whl.metadata (4.7 kB)
Collecting littleutils (from outdated>=0.2.0->ogb)
  Downloading littleutils-0.2.4-py3-none-any.whl.metadata (679 bytes)
Downloading ogb-1.3.6-py3-none-any.whl (78 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.8/78.8 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading outdated-0.2.2-py2.py3-none-any.whl (7.5 kB)
Downloading littleutils-0.2.4-py3-none-any.whl (8.1 kB)
Installing collected packages: littleutils, outdated, ogb
Successfully installed littleutils-0.2.4 ogb-1.3.6 outdated-0.2.2


step 3: Models

In [None]:
import torch
import torch.nn.functional as F
import torch_geometric
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import GCNConv, SAGEConv, GATConv, GINConv

class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout=0.5):
        super(GCN, self).__init__()

        self.convs = torch.nn.ModuleList()
        self.convs.append(GCNConv(in_channels, hidden_channels)) # input layer

        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_channels, hidden_channels)) # hidden layers

        self.convs.append(GCNConv(hidden_channels, out_channels)) # output layer

        self.dropout = dropout


    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()


    def forward(self, x, adj_t):
        for conv in self.convs[:-1]:
            x = conv(x, adj_t)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.convs[-1](x, adj_t)
        return x.log_softmax(dim=-1)



class SAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout=0.5):
        super(SAGE, self).__init__()

        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels)) # input layer

        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels)) # hidden layers

        self.convs.append(SAGEConv(hidden_channels, out_channels)) # output layer

        self.dropout = dropout


    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()


    def forward(self, x, adj_t):
        for conv in self.convs[:-1]:
            x = conv(x, adj_t)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.convs[-1](x, adj_t)
        return x.log_softmax(dim=-1)



class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout=0.5):
        super(GAT, self).__init__()

        self.convs = torch.nn.ModuleList()
        self.convs.append(GATConv(in_channels, hidden_channels)) # input layer

        for _ in range(num_layers - 2):
            self.convs.append(GATConv(hidden_channels, hidden_channels)) # hidden layers

        self.convs.append(GATConv(hidden_channels, out_channels)) # output layer

        self.dropout = dropout


    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()


    def forward(self, x, adj_t):
        for conv in self.convs[:-1]:
            x = conv(x, adj_t)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.convs[-1](x, adj_t)
        return x.log_softmax(dim=-1)



class GIN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout=0.5):
        super(GIN, self).__init__()

        self.convs = torch.nn.ModuleList()

        # input layer
        self.convs.append(
            GINConv(Linear(in_channels, hidden_channels), train_eps=True)
        )

        # hidden layers
        for _ in range(num_layers - 2):
            self.convs.append(
                GINConv(Linear(hidden_channels, hidden_channels), train_eps=True)
            )

        # output layer
        self.convs.append(
            GINConv(Linear(hidden_channels, out_channels), train_eps=True)
        )

        self.dropout = dropout


    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()


    def forward(self, x, adj_t):
        for conv in self.convs[:-1]:
            x = conv(x, adj_t)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.convs[-1](x, adj_t)
        return x.log_softmax(dim=-1)


**Testing below**

Loading data and split into source and target(node_year_dic issue solved)

In [5]:
import torch
import torch.nn.functional as F
import torch_geometric
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator
import torch_geometric.transforms as T
from torch_geometric.data import Data
from torch_geometric.utils import subgraph

# ---------------------------------------------------
# DEVICE
# ---------------------------------------------------
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# ---------------------------------------------------
# Data
# ---------------------------------------------------

# LOAD Arxiv
arxiv_dataset = PygNodePropPredDataset(name='ogbn-arxiv')
arxiv_data = arxiv_dataset[0]

arxiv_data = T.ToSparseTensor()(arxiv_data)
##arxiv_data.adj_t = arxiv_data.adj_t.to_symmetric()
arxiv_data = arxiv_data.to(device)

arxiv_split_idx = arxiv_dataset.get_idx_split()
arxiv_evaluator = Evaluator(name='ogbn-arxiv')
# ---------------------------------------------------

# LOAD MAG
dataset = PygNodePropPredDataset(name="ogbn-mag")
rel_data = dataset[0]
#print(dir(rel_data))
#print(rel_data)



data = Data(
    x=rel_data.x_dict['paper'],
    edge_index=rel_data.edge_index_dict[('paper', 'cites', 'paper')],
    y=rel_data.y_dict['paper']
).to(device)

# SPLIT INTO SOURCE & TARGET SET
years = rel_data.node_year['paper'].unique()
source_years = years[:5]
target_years = years[5:]

source_nodes = torch.cat([
                    torch.where(rel_data.node_year['paper'] == year)[0]
                    for year in source_years
                ])

target_nodes = torch.cat([
                    torch.where(rel_data.node_year['paper'] == year)[0]
                    for year in target_years
                ])

source_nodes, _ = source_nodes.sort()
target_nodes, _ = target_nodes.sort()

source_edge_index, _ = subgraph(source_nodes, data.edge_index, relabel_nodes=True)
target_edge_index, _ = subgraph(target_nodes, data.edge_index, relabel_nodes=True)

source_data = Data(
                x=rel_data.x_dict['paper'][source_nodes],
                edge_index=source_edge_index,
                y=rel_data.y_dict['paper'][source_nodes]
            )

target_data = Data(
                x=rel_data.x_dict['paper'][target_nodes],
                edge_index=target_edge_index,
                y=rel_data.y_dict['paper'][target_nodes]
            )

data = target_data.to(device) # Train on Target split

# MAG EVALUATOR
evaluator = Evaluator(name="ogbn-mag")
# ---------------------------------------------------

Downloading http://snap.stanford.edu/ogb/data/nodeproppred/arxiv.zip


Downloaded 0.08 GB: 100%|██████████| 81/81 [00:01<00:00, 46.91it/s]


Extracting dataset/arxiv.zip


Processing...


Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 1/1 [00:00<00:00, 2325.00it/s]


Converting graphs into PyG objects...


100%|██████████| 1/1 [00:00<00:00, 2303.30it/s]

Saving...



Done!
  adj = torch.sparse_csr_tensor(


Downloading http://snap.stanford.edu/ogb/data/nodeproppred/mag.zip


Downloaded 0.40 GB: 100%|██████████| 413/413 [00:07<00:00, 53.14it/s]


Extracting dataset/mag.zip


Processing...


Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 1/1 [00:00<00:00, 1987.82it/s]


Converting graphs into PyG objects...


100%|██████████| 1/1 [00:00<00:00, 2585.88it/s]

Saving...



Done!


define training and **pretraining model for MAG and Arxiv**

In [None]:
import torch
from torch import optim, nn
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import Dataset
from torch_geometric.nn import GCNConv, SAGEConv, GATConv, GINConv
from models import *
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator
from tqdm import tqdm
from copy import deepcopy
import numpy as np


def train(model, optimiser, data):
    # TRAIN
    model.train()
    optimiser.zero_grad()

    try:
        out = model(data.x, data.adj_t)
    except:
        out = model(data.x, data.edge_index)

    loss = F.nll_loss(out, data.y.squeeze(1))

    loss.backward()
    optimiser.step()

    # EVAL
    y_pred = out.argmax(dim=-1, keepdim=True)
    acc = evaluator.eval({
            'y_true': data.y,
            'y_pred': y_pred,
        })['acc']

    return loss.item(), acc


def pretrain_mag_source(model, optimiser, data, model_name, epochs=1000):
    best_acc = 0.0

    for epoch in tqdm(range(epochs)):
        # TRAIN
        model.train()
        optimiser.zero_grad()

        out = model(data.x, data.edge_index)
        loss = F.nll_loss(out, data.y.squeeze(1))

        loss.backward()
        optimiser.step()

        # EVAL
        y_pred = out.argmax(dim=-1, keepdim=True)
        acc = evaluator.eval({
                'y_true': data.y,
                'y_pred': y_pred,
            })['acc']

        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), 'source/{}_mag_source.pth'.format(model_name))

    return best_acc



pretraining on source split MAG **GCN**

In [None]:
# PRETRAIN ON SOURCE SPLIT
from models import *

model = GCN(
        in_channels=data.num_features,
        hidden_channels=256,
        out_channels=dataset.num_classes,
        num_layers=3
    ).to(device)


model.reset_parameters()
source_optimiser = torch.optim.Adam(model.parameters(), lr=0.001)
print('Pretraining model on MAG Source split')
best_acc = pretrain_mag_source(model, source_optimiser, source_data.to(device), "GCN")
print('Best accuracy: {:.3}'.format(best_acc))
model.load_state_dict(
        torch.load( '/content/source/{}_mag_source.pth'.format("GCN") )
    )



Pretraining model on MAG Source split


100%|██████████| 2/2 [01:50<00:00, 55.12s/it]

Best accuracy: 0.00556





<All keys matched successfully>

**Training on Target MAG using Pretrained source MAG(Fine Tuning)**

In [None]:
# ---------------------------------------------------
    # MODEL
    # ---------------------------------------------------
    # USE MAG MODEL
model = GCN(
        in_channels=data.num_features,
        hidden_channels=253,
        out_channels=dataset.num_classes,
        num_layers=3
    ).to(device)

#MAG TARGET TRAINING
print('Training on MAG')
optimiser = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in tqdm(range(2)):
  train_loss, acc = train(model, optimiser, data)
  print('Epoch: {:03d}, Train Loss: {:.3f}, Train Acc: {:.3f}'.format(epoch, train_loss, acc))

Training on MAG


 50%|█████     | 1/2 [00:51<00:51, 51.26s/it]

Epoch: 000, Train Loss: 5.885, Train Acc: 0.008


100%|██████████| 2/2 [01:37<00:00, 48.76s/it]

Epoch: 001, Train Loss: 5.810, Train Acc: 0.017





**Feature Extraction**

step 1: split target into train test and validation sets

In [None]:
from sklearn.model_selection import train_test_split

# Ensure that indices are in the correct format
target_indices = torch.arange(target_data.num_nodes)

# Split indices into train (80%) and temp (20%)
train_indices, temp_indices = train_test_split(target_indices, test_size=0.2, random_state=42)

# Split temp into validation (10%) and test (10%)
val_indices, test_indices = train_test_split(temp_indices, test_size=0.5, random_state=42)

train_indices = torch.tensor(train_indices, dtype=torch.long).to(device)
val_indices = torch.tensor(val_indices, dtype=torch.long).to(device)
test_indices = torch.tensor(test_indices, dtype=torch.long).to(device)

print(f'Train indices: {train_indices.size(0)}')
print(f'Validation indices: {val_indices.size(0)}')
print(f'Test indices: {test_indices.size(0)}')


Train indices: 267032
Validation indices: 33379
Test indices: 33380


  train_indices = torch.tensor(train_indices, dtype=torch.long).to(device)
  val_indices = torch.tensor(val_indices, dtype=torch.long).to(device)
  test_indices = torch.tensor(test_indices, dtype=torch.long).to(device)


step 2: feature extraction using the pretrained model

In [None]:
def extract_features(model, data):
    model.eval()
    with torch.no_grad():
        features = model.convs[0](data.x, data.edge_index)
        for conv in model.convs[1:]:
            features = conv(features, data.edge_index)
            features = F.relu(features)
        return features

def train_classifier(features, labels, train_idx, val_idx, test_idx, num_classes):
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import accuracy_score

##ajust iterations
    clf = LogisticRegression(max_iter=10)
    clf.fit(features[train_idx.cpu()], labels[train_idx.cpu()].squeeze(1))

    val_preds = clf.predict(features[val_idx.cpu()])
    test_preds = clf.predict(features[test_idx.cpu()])

    val_acc = accuracy_score(labels[val_idx.cpu()].cpu(), val_preds)
    test_acc = accuracy_score(labels[test_idx.cpu()].cpu(), test_preds)

    return val_acc, test_acc

# Load the pretrained model
model.load_state_dict(torch.load('/content/source/GCN_mag_source.pth'))

# Extract features from the target data using the pretrained model
target_features = extract_features(model, target_data)

# Define the number of classes
num_classes = target_data.y.max().item() + 1

# Train a classifier on the extracted features
feat_val_acc, feat_test_acc = train_classifier(target_features.cpu(), target_data.y.cpu(), train_indices, val_indices, test_indices, num_classes)

print(f'Validation Accuracy: {feat_val_acc:.4f}, Test Accuracy: {feat_test_acc:.4f}')


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


Validation Accuracy: 0.1381, Test Accuracy: 0.1385


**Baseline GCN training without transfer learning**

In [None]:
def train_baseline_gcn(model, optimizer, data, train_idx, val_idx, epochs=10):
    best_val_acc = 0.0

    for epoch in tqdm(range(epochs)):
        # TRAIN
        model.train()
        optimizer.zero_grad()

        out = model(data.x, data.edge_index)
        loss = F.nll_loss(out[train_idx], data.y[train_idx].squeeze(1))

        loss.backward()
        optimizer.step()

        # EVAL
        model.eval()
        with torch.no_grad():
            out = model(data.x, data.edge_index)
            val_loss = F.nll_loss(out[val_idx], data.y[val_idx].squeeze(1)).item()
            val_acc = (out[val_idx].argmax(dim=1) == data.y[val_idx].squeeze(1)).sum().item() / val_idx.size(0)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'baseline/gcn_baseline.pth')

        print(f'Epoch {epoch:03d}, Loss: {loss.item():.4f}, Val Acc: {val_acc:.4f}')

    return best_val_acc

def evaluate_model(model, data, test_idx):
    model.eval()
    with torch.no_grad():
        out = model(data.x, data.edge_index)
        #will change test_acc name to base_test_acc
        base_test_acc = (out[test_idx].argmax(dim=1) == data.y[test_idx].squeeze(1)).sum().item() / test_idx.size(0)
    return test_acc


In [None]:
# Use predefined target MAG dataset split
target_train_indices = train_indices.to(device)
target_val_indices = val_indices.to(device)
target_test_indices = test_indices.to(device)

# Initialize the model, optimizer, and other parameters
baseline_model = GCN(in_channels=128, hidden_channels=253, out_channels=num_classes, num_layers=3, dropout=0.5).to(device)
baseline_optimizer = torch.optim.Adam(baseline_model.parameters(), lr=0.01, weight_decay=5e-4)

# Train the baseline GCN model
best_val_acc = train_baseline_gcn(baseline_model, baseline_optimizer, target_data, target_train_indices, target_val_indices, epochs=2)

print(f'Best Validation Accuracy (Baseline GCN): {best_val_acc:.4f}')

# Load the best model
baseline_model.load_state_dict(torch.load('baseline/gcn_baseline.pth'))

# Evaluate the model on the test set
base_test_acc = evaluate_model(baseline_model, target_data, target_test_indices)

print(f'Test Accuracy (Baseline GCN): {base_test_acc:.4f}')


  0%|          | 0/2 [00:51<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# Compare the models
print("Baseline Model:")
print(f"Validation Accuracy: {best_val_acc:.4f}, Test Accuracy: {test_acc:.4f}")

print("Fine-Tuned Model:")
# Assume fine_tune_val_acc and fine_tune_test_acc are obtained during fine-tuning
#print(f"Validation Accuracy: {fine_tune_val_acc:.4f}, Test Accuracy: {fine_tune_test_acc:.4f}")
print(f"Validation Accuracy: {acc:.4f}")

print("Feature Extraction Model:")
print(f"Validation Accuracy: {val_acc:.4f}, Test Accuracy: {test_acc:.4f}")


Baseline Model:
Validation Accuracy: 0.0474, Test Accuracy: 0.0473
Fine-Tuned Model:
Validation Accuracy: 0.0174
Feature Extraction Model:
Validation Accuracy: 0.1381, Test Accuracy: 0.0473
