In [1]:
# Enable IPython's autoreload extension in mode 2 so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from tqdm.auto import trange

In [3]:
from gnnboundary import *
from gnnboundary.datasets.msrc_dataset import MSRCDataset
from lib.gcn_classifier import MultiGCNClassifier

# Motif

In [13]:
# Motif dataset with a fixed seed, split into training and validation parts.
# NOTE(review): k=10 is forwarded to train_test_split; its exact meaning
# (fold count vs. hold-out fraction) should be confirmed in the dataset class.
motif = MotifDataset(seed=12345)
motif_train, motif_val = motif.train_test_split(k=10)

# Classifier whose input/output sizes come from the dataset's class lists.
motif_model = GCNClassifier(
    node_features=len(motif.NODE_CLS),
    num_classes=len(motif.GRAPH_CLS),
    hidden_channels=6,
    num_layers=3,
)

In [None]:
# Fit the Motif classifier for 128 iterations. Each iteration calls model_fit
# once, then evaluates on both splits and prints a one-line summary.
# The "Test" labels report metrics computed on motif_val.
for epoch in trange(128):
    train_loss = motif_train.model_fit(motif_model, lr=0.001)
    train_metrics = motif_train.model_evaluate(motif_model)
    val_metrics = motif_val.model_evaluate(motif_model)
    log_line = "".join((
        f"Epoch: {epoch:03d}, ",
        f"Train Loss: {train_loss:.4f}, ",
        f"Train Acc: {train_metrics['acc']:.4f}, ",
        f"Test Acc: {val_metrics['acc']:.4f}, ",
        f"Train F1: {train_metrics['f1']}, ",
        f"Test F1: {val_metrics['f1']}",
    ))
    print(log_line)

In [None]:
# Disabled checkpoint save: uncommenting the line below writes the current
# motif_model state_dict to ckpts/motif.pt (the path the next cell loads).
# torch.save(motif_model.state_dict(), 'ckpts/motif.pt')

In [None]:
# Restore the Motif classifier weights from disk.
# map_location='cpu' remaps tensors that were saved from a GPU run, so the
# checkpoint also loads on CPU-only machines. This section creates
# motif_model without moving it to another device, so CPU is the right target.
motif_model.load_state_dict(torch.load('ckpts/motif.pt', map_location='cpu'))

# ENZYMES

In [10]:
# ENZYMES dataset with a fixed seed, split into training and validation parts.
# NOTE(review): k=10 is forwarded to train_test_split; its exact meaning
# (fold count vs. hold-out fraction) should be confirmed in the dataset class.
enzymes = ENZYMESDataset(seed=12345)
enzymes_train, enzymes_val = enzymes.train_test_split(k=10)

# Classifier whose input/output sizes come from the dataset's class lists.
enzymes_model = GCNClassifier(
    node_features=len(enzymes.NODE_CLS),
    num_classes=len(enzymes.GRAPH_CLS),
    hidden_channels=32,
    num_layers=3,
)

In [None]:
# Restore the ENZYMES classifier weights from disk (this runs before the
# training loop below, so training continues from the saved weights).
# map_location='cpu' remaps tensors that were saved from a GPU run, so the
# checkpoint also loads on CPU-only machines. This section creates
# enzymes_model without moving it to another device.
enzymes_model.load_state_dict(torch.load('ckpts/enzymes.pt', map_location='cpu'))

In [None]:
# Fit the ENZYMES classifier for 4096 iterations. Each iteration calls
# model_fit once, then evaluates on both splits and prints a one-line summary.
# trange (instead of bare range) matches the other training loops in this
# notebook and gives this longest run a progress bar as well.
# The "Test" labels report metrics computed on enzymes_val.
for epoch in trange(4096):
    train_loss = enzymes_train.model_fit(enzymes_model, lr=0.0001)
    train_metrics = enzymes_train.model_evaluate(enzymes_model)
    val_metrics = enzymes_val.model_evaluate(enzymes_model)
    print(f"Epoch: {epoch:03d}, "
          f"Train Loss: {train_loss:.4f}, "
          f"Train Acc: {train_metrics['acc']:.4f}, "
          f"Test Acc: {val_metrics['acc']:.4f}, "
          f"Train F1: {train_metrics['f1']}, "
          f"Test F1: {val_metrics['f1']}")

In [8]:
# Persist the trained ENZYMES weights.
# The checkpoint directory is created first so a long training run is not
# lost to a missing-folder error at save time. The path is a plain string
# because it contains no placeholders.
import os

os.makedirs("ckpts", exist_ok=True)
torch.save(enzymes_model.state_dict(), "ckpts/enzymes.pt")

# COLLAB

In [8]:
# COLLAB dataset with a fixed seed, split into training and validation parts.
# NOTE(review): k=10 is forwarded to train_test_split; its exact meaning
# (fold count vs. hold-out fraction) should be confirmed in the dataset class.
collab = CollabDataset(seed=12345)
collab_train, collab_val = collab.train_test_split(k=10)

# Classifier whose input/output sizes come from the dataset's class lists.
collab_model = GCNClassifier(
    node_features=len(collab.NODE_CLS),
    num_classes=len(collab.GRAPH_CLS),
    hidden_channels=64,
    num_layers=5,
)

In [None]:
# Fit the COLLAB classifier for 1024 iterations. Each iteration calls
# model_fit once, then evaluates on both splits and prints a one-line summary.
# The "Test" labels report metrics computed on collab_val.
for epoch in trange(1024):
    train_loss = collab_train.model_fit(collab_model, lr=0.001)
    train_metrics = collab_train.model_evaluate(collab_model)
    val_metrics = collab_val.model_evaluate(collab_model)
    log_line = "".join((
        f"Epoch: {epoch:03d}, ",
        f"Train Loss: {train_loss:.4f}, ",
        f"Train Acc: {train_metrics['acc']:.4f}, ",
        f"Test Acc: {val_metrics['acc']:.4f}, ",
        f"Train F1: {train_metrics['f1']}, ",
        f"Test F1: {val_metrics['f1']}",
    ))
    print(log_line)

In [None]:
# Disabled checkpoint save: uncommenting the line below writes the current
# collab_model state_dict to ckpts/collab.pt (the path the next cell loads).
# torch.save(collab_model.state_dict(), f"ckpts/collab.pt")

In [None]:
# Restore the COLLAB classifier weights from disk.
# map_location='cpu' remaps tensors that were saved from a GPU run, so the
# checkpoint also loads on CPU-only machines. This section creates
# collab_model without moving it to another device, so CPU is the right target.
collab_model.load_state_dict(torch.load('ckpts/collab.pt', map_location='cpu'))

# MSRC_9

In [6]:
import torch
from tqdm.auto import trange
from gnnboundary import *
from gnnboundary.datasets.msrc_dataset import MSRCDataset
from lib.gcn_classifier import MultiGCNClassifier

# Prefer the first CUDA device when one is available; otherwise stay on CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# MSRC_9 dataset with a fixed seed, moved to the chosen device via .to(device),
# then split into training and validation parts.
# NOTE(review): k=10 is forwarded to train_test_split; confirm its meaning.
msrc9 = MSRCDataset(seed=12345).to(device)
msrc9_train, msrc9_val = msrc9.train_test_split(k=10)

# Classifier sized from the dataset's class lists and moved to the same device.
msrc9_model = MultiGCNClassifier(
    node_features=len(msrc9.NODE_CLS),
    num_classes=len(msrc9.GRAPH_CLS),
    hidden_channels=8,
    num_layers=3,
).to(device)

print(device)

cuda:0


Using existing file MSRC_9.zip
Extracting data/MSRC_9/raw/MSRC_9.zip


In [7]:
# Fit the MSRC_9 classifier for 256 iterations. Each iteration calls
# model_fit once, then evaluates on both splits and prints a one-line summary.
# The "Test" label reports accuracy computed on msrc9_val.
for epoch in trange(256):
    train_loss = msrc9_train.model_fit(msrc9_model, lr=0.001)
    train_metrics = msrc9_train.model_evaluate(msrc9_model)
    val_metrics = msrc9_val.model_evaluate(msrc9_model)
    log_line = "".join((
        f"Epoch: {epoch:03d}, ",
        f"Train Loss: {train_loss:.4f}, ",
        f"Train Acc: {train_metrics['acc']:.4f}, ",
        # The last active field keeps its trailing ", " so the printed line
        # is unchanged; the F1 fields below are disabled for this dataset.
        f"Test Acc: {val_metrics['acc']:.4f}, ",
        # f"Train F1: {train_metrics['f1']}, ",
        # f"Test F1: {val_metrics['f1']}",
    ))
    print(log_line)


  0%|          | 0/256 [00:00<?, ?it/s]

Epoch: 000, Train Loss: 2.2735, Train Acc: 0.1457, Test Acc: 0.1364, 
Epoch: 001, Train Loss: 2.1090, Train Acc: 0.2111, Test Acc: 0.1364, 
Epoch: 002, Train Loss: 2.0197, Train Acc: 0.3266, Test Acc: 0.1818, 
Epoch: 003, Train Loss: 1.9923, Train Acc: 0.3015, Test Acc: 0.2273, 
Epoch: 004, Train Loss: 1.9144, Train Acc: 0.3116, Test Acc: 0.2273, 
Epoch: 005, Train Loss: 1.8734, Train Acc: 0.2915, Test Acc: 0.2273, 
Epoch: 006, Train Loss: 1.8573, Train Acc: 0.2915, Test Acc: 0.2273, 
Epoch: 007, Train Loss: 1.7680, Train Acc: 0.2714, Test Acc: 0.2273, 
Epoch: 008, Train Loss: 1.8182, Train Acc: 0.2714, Test Acc: 0.2273, 
Epoch: 009, Train Loss: 1.7589, Train Acc: 0.2714, Test Acc: 0.2273, 
Epoch: 010, Train Loss: 1.7078, Train Acc: 0.2714, Test Acc: 0.2273, 
Epoch: 011, Train Loss: 1.7163, Train Acc: 0.2613, Test Acc: 0.2273, 
Epoch: 012, Train Loss: 1.5546, Train Acc: 0.2613, Test Acc: 0.2273, 
Epoch: 013, Train Loss: 1.6112, Train Acc: 0.2814, Test Acc: 0.2273, 
Epoch: 014, Train Lo

In [8]:
# Persist the trained MSRC_9 weights.
# The checkpoint directory is created first so the training run is not lost
# to a missing-folder error at save time. The path is a plain string because
# it contains no placeholders.
import os

os.makedirs("ckpts", exist_ok=True)
torch.save(msrc9_model.state_dict(), "ckpts/msrc_9.pt")

In [None]:
# Restore the MSRC_9 classifier weights from disk.
# msrc9_model lives on `device`, and the checkpoint is saved from whatever
# device the training run used (cuda:0 in the recorded output). Passing
# map_location=device remaps the stored tensors so the same checkpoint also
# loads on a CPU-only machine.
msrc9_model.load_state_dict(torch.load('ckpts/msrc_9.pt', map_location=device))