In [42]:
import pandas as pd
import matplotlib.pyplot as plt
import wandb

import numpy as np
import seaborn as sns
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

import sys
from pathlib import Path

# Make the repository's `src` package importable from this notebook.
# BASE_PATH is the parent of the kernel's starting directory: IPython keeps its
# directory history in `_dh`, and `_dh[0]` is where the kernel started. This
# presumably means the notebook lives one level below the repo root — TODO confirm.
# - Path(...) makes this work whether `_dh` holds str or pathlib.Path entries.
# - When `_dh` is absent (e.g. the notebook was exported and run as a plain
#   script), fall back to the current working directory instead of raising KeyError.
BASE_PATH = Path(globals().get('_dh', [Path.cwd()])[0]).parent.absolute()
# Insert only once so re-running this cell does not stack duplicate entries.
if str(BASE_PATH) not in sys.path:
    sys.path.insert(1, str(BASE_PATH))

from src.utils.utils import count_parameters
from src.models.models import GCN, GAT
from src.models.iterativeModels import iterativeGCN, iterativeGAT
from src.models.variantModels import iterativeGCN_variant
from src.utils.wandb_analysis import get_metrics, get_clean_sweep_runs, get_sweep_info
wandb.login()

True

In [53]:
# Dataset used by every model cell below. The parameter counts depend on the
# input feature width and the number of classes, so re-run all model cells
# after changing DATASET_NAME.
DATASET_ROOT = 'data/Planetoid/'  # relative path, resolved against the kernel's working directory
DATASET_NAME = 'CiteSeer'

dataset = Planetoid(root=DATASET_ROOT,
                    name=DATASET_NAME,
                    transform=NormalizeFeatures())
num_features, num_classes = dataset.num_features, dataset.num_classes

In [34]:
# Baseline GCN (num_layers=4, 32 hidden units). Its parameter count is presumably
# the reference budget the iterative variants below are compared against.
# Reads num_features / num_classes from the dataset cell, so run that cell first.
# NOTE(review): the saved output was produced before the dataset cell ran
# (execution count 34 < 53), so it likely reflects a different dataset.
# Re-run the notebook top to bottom.
gcn = GCN(
    input_dim=num_features,
    hidden_dim=32,
    output_dim=num_classes,
    num_layers=4,
    dropout=0.5,
)
count_parameters(gcn)

18243

In [35]:
# Iterative GCN with the same width (32) and dropout (0.5) as the baseline GCN.
# No num_layers is passed, so depth (if configurable) comes from the class
# default — TODO confirm. train_schedule=None: presumably no schedule is needed
# just to count parameters.
# NOTE(review): the saved output predates the dataset cell (execution count
# 35 < 53). Re-run top to bottom.
igcn = iterativeGCN(
    input_dim=num_features,
    hidden_dim=32,
    output_dim=num_classes,
    train_schedule=None,
    dropout=0.5,
)
count_parameters(igcn)

17187

In [54]:
# Baseline GAT: 3 layers, 8 attention heads, hidden_dim=8 (presumably per head,
# i.e. 64 features when heads are concatenated — TODO confirm in GAT).
# Note that GAT names its input argument num_node_features, not input_dim.
gat = GAT(
    num_node_features=num_features,
    output_dim=num_classes,
    hidden_dim=8,
    heads=8,
    num_layers=3,
    dropout=0.6,
    attn_dropout_rate=0.6,
)
count_parameters(gat)

271442

In [55]:
# Iterative GAT with the same head count and dropout rates as the baseline GAT.
# hidden_dim=64 appears chosen so the parameter budget roughly matches GAT:
# the saved outputs are 271302 here vs 271442 for GAT.
igat = iterativeGAT(
    input_dim=num_features,
    output_dim=num_classes,
    hidden_dim=64,
    heads=8,
    dropout=0.6,
    attn_dropout_rate=0.6,
    train_schedule=None,
)
count_parameters(igat)

271302

In [48]:
# Variant of the iterative GCN with the same hyperparameters as `igcn`.
# Its saved parameter count equals igcn's (17187).
# NOTE(review): the saved output predates the dataset cell (execution count
# 48 < 53). Re-run top to bottom.
igcnv = iterativeGCN_variant(
    input_dim=num_features,
    hidden_dim=32,
    output_dim=num_classes,
    dropout=0.5,
    train_schedule=None,
)
count_parameters(igcnv)

17187