In [1]:
import sys
import os
import random
import torch
import torch.nn.functional as F
import git
import numpy as np
from tqdm import tqdm

sys.path.append("../")

from src.models.disc_models import DiscreteDiagSheafDiffusion, DiscreteBundleSheafDiffusion, DiscreteGeneralSheafDiffusion
from src.utils.parser import get_parser
from src.utils.heterophilic import get_dataset, get_fixed_splits

In [2]:
def reset_wandb_env():
    """Delete all ``WANDB_*`` environment variables except a small keep-list.

    ``WANDB_PROJECT``, ``WANDB_ENTITY`` and ``WANDB_API_KEY`` are preserved;
    every other variable whose name starts with ``WANDB_`` is removed from
    ``os.environ``. Variables without that prefix are left untouched.
    """
    exclude = {
        "WANDB_PROJECT",
        "WANDB_ENTITY",
        "WANDB_API_KEY",
    }
    # Iterate over a snapshot of the keys: deleting entries from a mapping
    # while iterating over the live mapping is only safe by implementation
    # detail (older Python 3 releases raise "dictionary changed size during
    # iteration" here).
    for key in list(os.environ):
        if key.startswith("WANDB_") and key not in exclude:
            del os.environ[key]


def train(model, optimizer, data):
    """Perform a single full-batch optimisation step on the training nodes.

    Args:
        model: Callable module; invoked as ``model(data.x)``. Its output is
            indexed with ``data.train_mask`` and passed straight to
            ``F.nll_loss`` (so it is presumably log-probabilities -- confirm
            against the model definition).
        optimizer: Optimizer whose gradients are zeroed before the backward
            pass and which is stepped once afterwards.
        data: Object exposing ``x``, ``y`` and ``train_mask``.

    Returns:
        None. The model parameters are updated in place.
    """
    model.train()
    optimizer.zero_grad()

    train_mask = data.train_mask
    train_out = model(data.x)[train_mask]
    loss = F.nll_loss(train_out, data.y[train_mask])

    loss.backward()
    optimizer.step()

def test(model, data):
    model.eval()
    with torch.no_grad():
        logits, accs, losses, preds = model(data.x), [], [], []
        for _, mask in data('train_mask', 'val_mask', 'test_mask'):
            pred = logits[mask].max(1)[1]
            acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()

            loss = F.nll_loss(logits[mask], data.y[mask])

            preds.append(pred.detach().cpu())
            accs.append(acc)
            losses.append(loss.detach().cpu())
        return accs, preds, losses

def run_exp(args, dataset, model_cls, fold):
    """Train and evaluate one model on a single fixed split.

    Args:
        args: Mapping-style configuration read with ``args[...]``. Keys used
            here: 'dataset', 'device', 'sheaf_decay', 'weight_decay', 'lr',
            'epochs', 'stop_strategy', 'early_stopping', 'model', 'layers'
            and 'min_acc'. It is also forwarded unchanged to ``model_cls``.
        dataset: Indexable dataset; only ``dataset[0]`` is used.
        model_cls: Called as ``model_cls(data.edge_index, args)``. The
            resulting model must provide ``grouped_parameters()`` returning
            ``(sheaf_learner_params, other_params)``.
        fold: Split index passed to ``get_fixed_splits``. Per-epoch metrics
            are logged to wandb only when ``fold == 0``.

    Returns:
        ``(test_acc, best_val_acc, keep_running)`` where ``test_acc`` and
        ``best_val_acc`` are the test/validation accuracies recorded at the
        epoch that last improved the stopping criterion, and
        ``keep_running`` is False when ``test_acc < args['min_acc']``.
    """
    device = args['device']
    data = get_fixed_splits(dataset[0], args['dataset'], fold).to(device)
    model = model_cls(data.edge_index, args).to(device)

    # The sheaf learners get their own weight decay.
    sheaf_learner_params, other_params = model.grouped_parameters()
    param_groups = [
        {'params': sheaf_learner_params, 'weight_decay': args['sheaf_decay']},
        {'params': other_params, 'weight_decay': args['weight_decay']},
    ]
    optimizer = torch.optim.Adam(param_groups, lr=args['lr'])

    # `epoch` is pre-set so the summary print below works even if
    # args['epochs'] is 0 and the loop body never runs.
    epoch = 0
    best_epoch = 0
    bad_counter = 0
    best_val_acc = 0
    test_acc = 0
    best_val_loss = float('inf')

    for epoch in range(args['epochs']):
        train(model, optimizer, data)

        accs, _, losses = test(model, data)
        train_acc, val_acc, tmp_test_acc = accs
        train_loss, val_loss, tmp_test_loss = losses

        if fold == 0:
            wandb.log(
                {
                    f'fold{fold}_train_acc': train_acc,
                    f'fold{fold}_train_loss': train_loss,
                    f'fold{fold}_val_acc': val_acc,
                    f'fold{fold}_val_loss': val_loss,
                    f'fold{fold}_tmp_test_acc': tmp_test_acc,
                    f'fold{fold}_tmp_test_loss': tmp_test_loss,
                },
                step=epoch,
            )

        # Model selection on validation accuracy or validation loss.
        if args['stop_strategy'] == 'acc':
            new_best_trigger = val_acc > best_val_acc
        else:
            new_best_trigger = val_loss < best_val_loss

        if new_best_trigger:
            # Record all metrics of this epoch, whichever criterion fired.
            best_val_acc, best_val_loss = val_acc, val_loss
            test_acc = tmp_test_acc
            best_epoch = epoch
            bad_counter = 0
        else:
            bad_counter += 1

        # Early stopping after a fixed number of non-improving epochs.
        if bad_counter == args['early_stopping']:
            break

    print(f"Fold {fold} | Epochs: {epoch} | Best epoch: {best_epoch}")
    print(f"Test acc: {test_acc:.4f}")
    print(f"Best val acc: {best_val_acc:.4f}")

    if "ODE" not in args['model']:
        # Debugging for discrete models: summary statistics of each layer's
        # `L` attribute and the learned epsilons.
        for i in range(len(model.sheaf_learners)):
            laplacian = model.sheaf_learners[i].L.detach()
            L_max = laplacian.max().item()
            L_min = laplacian.min().item()
            L_avg = laplacian.mean().item()
            L_abs_avg = laplacian.abs().mean().item()
            print(f"Laplacian {i}: Max: {L_max:.4f}, Min: {L_min:.4f}, Avg: {L_avg:.4f}, Abs avg: {L_abs_avg:.4f}")

        with np.printoptions(precision=3, suppress=True):
            for i in range(args['layers']):
                print(f"Epsilons {i}: {model.epsilons[i].detach().cpu().numpy().flatten()}")

    wandb.log({'best_test_acc': test_acc, 'best_val_acc': best_val_acc, 'best_epoch': best_epoch})
    keep_running = not (test_acc < args['min_acc'])

    return test_acc, best_val_acc, keep_running


In [6]:
#first loss

# setup the parameters
parser = get_parser()
args = parser.parse_args("")

repo = git.Repo(search_parent_directories=True)
sha = repo.head.object.hexsha

#rewrite and add some parameters
args.d = 3
args.layers = 4
args.dropout = 0.7
args.model = "BundleSheaf"
args.entity = "sheafnn"

model_cls = DiscreteBundleSheafDiffusion
dataset = get_dataset(args.dataset)

args.graph_size = dataset[0].x.size(0)
args.input_dim = dataset.num_features
args.output_dim = dataset.num_classes
args.device = torch.device(f'cuda:{args.cuda}' if torch.cuda.is_available() else 'cpu')
assert args.normalised or args.deg_normalised
if args.sheaf_decay is None:
    args.sheaf_decay = args.weight_decay


if args.model == 'DiagSheafODE':
    model_cls = DiagSheafDiffusion
elif args.model == 'BundleSheafODE':
    model_cls = BundleSheafDiffusion
elif args.model == 'GeneralSheafODE':
    model_cls = GeneralSheafDiffusion
elif args.model == 'DiagSheaf':
    model_cls = DiscreteDiagSheafDiffusion
elif args.model == 'BundleSheaf':
    model_cls = DiscreteBundleSheafDiffusion
elif args.model == 'GeneralSheaf':
    model_cls = DiscreteGeneralSheafDiffusion
else:
    raise ValueError(f'Unknown model')


# Set the seed for everything
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

results = []

!pip install wandb -qqq
import wandb
wandb.init(project="sheafnn", config=vars(args), entity=args.entity)
    
#train the diffusion process


for fold in tqdm(range(args.folds)):
        test_acc, best_val_acc, keep_running = run_exp(wandb.config, dataset, model_cls, fold)
        results.append([test_acc, best_val_acc])
        if not keep_running:
            break




[31mERROR: sentry-sdk 1.17.0 has requirement urllib3>=1.26.11; python_version >= "3.6", but you'll have urllib3 1.24.2 which is incompatible.[0m


 10%|█         | 1/10 [00:06<00:56,  6.28s/it]

Fold 0 | Epochs: 274 | Best epoch: 74
Test acc: 0.7838
Best val acc: 0.7797
Laplacian 0: Max: 0.0208, Min: -0.0714, Avg: -0.0219, Abs avg: 0.0237
Laplacian 1: Max: 0.0332, Min: -0.0744, Avg: -0.0211, Abs avg: 0.0254
Laplacian 2: Max: 0.0691, Min: -0.1484, Avg: -0.0268, Abs avg: 0.0383
Laplacian 3: Max: 0.2540, Min: -0.2322, Avg: -0.0009, Abs avg: 0.0807
Epsilons 0: [0.971 0.948 0.81 ]
Epsilons 1: [0.984 0.959 0.842]
Epsilons 2: [0.985 0.977 0.848]
Epsilons 3: [1.    1.022 0.861]


 20%|██        | 2/10 [00:16<01:10,  8.78s/it]

Fold 1 | Epochs: 480 | Best epoch: 280
Test acc: 0.9189
Best val acc: 0.8305
Laplacian 0: Max: 0.0120, Min: -0.0884, Avg: -0.0262, Abs avg: 0.0284
Laplacian 1: Max: 0.0290, Min: -0.0806, Avg: -0.0243, Abs avg: 0.0294
Laplacian 2: Max: 0.0324, Min: -0.0799, Avg: -0.0210, Abs avg: 0.0270
Laplacian 3: Max: 0.4184, Min: -0.4446, Avg: -0.0127, Abs avg: 0.0996
Epsilons 0: [1.166 1.154 1.063]
Epsilons 1: [1.136 1.177 1.044]
Epsilons 2: [1.162 1.162 1.069]
Epsilons 3: [1.19  1.177 1.098]


 30%|███       | 3/10 [00:24<00:56,  8.09s/it]

Fold 2 | Epochs: 331 | Best epoch: 131
Test acc: 0.8108
Best val acc: 0.8983
Laplacian 0: Max: 0.0432, Min: -0.1063, Avg: -0.0287, Abs avg: 0.0339
Laplacian 1: Max: 0.0685, Min: -0.1029, Avg: -0.0165, Abs avg: 0.0374
Laplacian 2: Max: 0.3174, Min: -0.3817, Avg: -0.0253, Abs avg: 0.0763
Laplacian 3: Max: 0.1129, Min: -0.1132, Avg: -0.0060, Abs avg: 0.0410
Epsilons 0: [1.022 0.982 0.97 ]
Epsilons 1: [1.038 0.983 0.99 ]
Epsilons 2: [1.077 0.992 1.   ]
Epsilons 3: [1.055 0.991 0.977]


 40%|████      | 4/10 [00:29<00:42,  7.16s/it]

Fold 3 | Epochs: 257 | Best epoch: 57
Test acc: 0.8919
Best val acc: 0.8305
Laplacian 0: Max: 0.0271, Min: -0.0842, Avg: -0.0229, Abs avg: 0.0272
Laplacian 1: Max: 0.0337, Min: -0.0752, Avg: -0.0181, Abs avg: 0.0249
Laplacian 2: Max: 0.2009, Min: -0.3842, Avg: -0.0444, Abs avg: 0.0801
Laplacian 3: Max: 0.0655, Min: -0.0695, Avg: -0.0047, Abs avg: 0.0230
Epsilons 0: [0.937 0.895 0.95 ]
Epsilons 1: [0.935 0.942 0.946]
Epsilons 2: [0.98  1.032 0.966]
Epsilons 3: [0.939 1.004 0.963]


 50%|█████     | 5/10 [00:35<00:33,  6.60s/it]

Fold 4 | Epochs: 254 | Best epoch: 54
Test acc: 0.8378
Best val acc: 0.9492
Laplacian 0: Max: 0.0378, Min: -0.0796, Avg: -0.0226, Abs avg: 0.0278
Laplacian 1: Max: 0.0675, Min: -0.1061, Avg: -0.0274, Abs avg: 0.0414
Laplacian 2: Max: 0.1733, Min: -0.1978, Avg: -0.0101, Abs avg: 0.0544
Laplacian 3: Max: 0.2629, Min: -0.2793, Avg: -0.0199, Abs avg: 0.0633
Epsilons 0: [0.958 0.985 0.898]
Epsilons 1: [0.956 0.984 0.912]
Epsilons 2: [0.973 1.02  0.935]
Epsilons 3: [0.974 1.01  0.936]


 60%|██████    | 6/10 [00:43<00:28,  7.03s/it]

Fold 5 | Epochs: 348 | Best epoch: 148
Test acc: 0.8378
Best val acc: 0.9322
Laplacian 0: Max: 0.0244, Min: -0.0847, Avg: -0.0230, Abs avg: 0.0277
Laplacian 1: Max: 0.0273, Min: -0.0893, Avg: -0.0204, Abs avg: 0.0251
Laplacian 2: Max: 0.0790, Min: -0.0993, Avg: -0.0063, Abs avg: 0.0308
Laplacian 3: Max: 0.1982, Min: -0.6358, Avg: -0.0780, Abs avg: 0.0969
Epsilons 0: [1.088 1.061 1.101]
Epsilons 1: [1.108 1.061 1.099]
Epsilons 2: [1.101 1.044 1.077]
Epsilons 3: [1.101 1.083 1.128]


 70%|███████   | 7/10 [00:51<00:21,  7.26s/it]

Fold 6 | Epochs: 350 | Best epoch: 150
Test acc: 0.9459
Best val acc: 0.8814
Laplacian 0: Max: 0.0192, Min: -0.0819, Avg: -0.0237, Abs avg: 0.0256
Laplacian 1: Max: 0.0271, Min: -0.0794, Avg: -0.0227, Abs avg: 0.0284
Laplacian 2: Max: 0.1084, Min: -0.2020, Avg: -0.0328, Abs avg: 0.0497
Laplacian 3: Max: 0.5875, Min: -0.8511, Avg: -0.0885, Abs avg: 0.1592
Epsilons 0: [1.036 1.071 0.994]
Epsilons 1: [1.045 1.059 1.008]
Epsilons 2: [1.075 1.097 1.072]
Epsilons 3: [1.059 1.081 1.084]


 80%|████████  | 8/10 [00:56<00:13,  6.72s/it]

Fold 7 | Epochs: 247 | Best epoch: 47
Test acc: 0.7027
Best val acc: 0.8305
Laplacian 0: Max: 0.0418, Min: -0.0636, Avg: -0.0193, Abs avg: 0.0247
Laplacian 1: Max: 0.0490, Min: -0.0903, Avg: -0.0212, Abs avg: 0.0330
Laplacian 2: Max: 0.0610, Min: -0.1384, Avg: -0.0155, Abs avg: 0.0226
Laplacian 3: Max: 0.0930, Min: -0.1898, Avg: -0.0223, Abs avg: 0.0379
Epsilons 0: [0.995 0.89  0.923]
Epsilons 1: [1.008 0.923 0.925]
Epsilons 2: [1.007 0.916 0.974]
Epsilons 3: [1.007 0.933 0.984]


 90%|█████████ | 9/10 [01:08<00:08,  8.24s/it]

Fold 8 | Epochs: 526 | Best epoch: 326
Test acc: 0.7838
Best val acc: 0.8983
Laplacian 0: Max: 0.0249, Min: -0.0675, Avg: -0.0214, Abs avg: 0.0260
Laplacian 1: Max: 0.0600, Min: -0.0805, Avg: -0.0200, Abs avg: 0.0358
Laplacian 2: Max: 0.0930, Min: -0.0960, Avg: -0.0125, Abs avg: 0.0317
Laplacian 3: Max: 0.5287, Min: -0.4440, Avg: -0.0018, Abs avg: 0.1121
Epsilons 0: [1.107 1.227 1.091]
Epsilons 1: [1.116 1.216 1.091]
Epsilons 2: [1.11  1.228 1.106]
Epsilons 3: [1.156 1.234 1.154]


100%|██████████| 10/10 [01:17<00:00,  7.73s/it]

Fold 9 | Epochs: 419 | Best epoch: 219
Test acc: 0.8919
Best val acc: 0.7966
Laplacian 0: Max: 0.0272, Min: -0.0770, Avg: -0.0230, Abs avg: 0.0260
Laplacian 1: Max: 0.0411, Min: -0.0919, Avg: -0.0231, Abs avg: 0.0286
Laplacian 2: Max: 0.0472, Min: -0.0799, Avg: -0.0176, Abs avg: 0.0283
Laplacian 3: Max: 0.1199, Min: -0.1435, Avg: -0.0174, Abs avg: 0.0399
Epsilons 0: [1.056 1.077 1.069]
Epsilons 1: [1.077 1.096 1.117]
Epsilons 2: [1.064 1.118 1.104]
Epsilons 3: [1.097 1.115 1.124]





In [7]:
# Aggregate the per-fold [test_acc, best_val_acc] pairs into percentages:
# column 0 is test accuracy, column 1 is validation accuracy.
test_acc_mean, val_acc_mean = 100 * np.mean(results, axis=0)
# Population standard deviation (ddof=0) of the test accuracy across folds.
test_acc_std = 100 * np.std(results, axis=0)[0]

wandb_results = dict(
    test_acc=test_acc_mean,
    val_acc=val_acc_mean,
    test_acc_std=test_acc_std,
)
wandb.log(wandb_results)
wandb.finish()


VBox(children=(Label(value='0.005 MB of 0.005 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
best_epoch,▂▇▃▁▁▄▄▁█▅
best_test_acc,▃▇▄▆▅▅█▁▃▆
best_val_acc,▁▃▆▃█▇▅▃▆▂
fold0_tmp_test_acc,▁▂▂▂▄▇█▇▇█▇▆▇▆██▇▇▇███▇▇▆▇▇▇▇█▆▇▇▇▇▇██▇█
fold0_tmp_test_loss,█▆▄▄▃▂▁▁▁▁▁▂▂▂▂▂▂▂▂▃▃▃▃▃▃▃▃▄▃▃▃▅▄▃▃▃▄▄▃▄
fold0_train_acc,▁▁▁▄▄▇██████████████████████████████████
fold0_train_loss,█▇▆▅▄▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
fold0_val_acc,▁▁▁▂▂▅▆▆▆▆▆▆▇▇▇▇▆▆▆▇▇▇▇▇▇▇▆▇▇▇▇▇▇▇█▇▇▇▆▇
fold0_val_loss,█▇▆▅▄▃▂▁▁▁▁▁▁▁▁▁▂▂▂▁▂▁▁▂▂▂▁▂▂▁▂▂▂▂▁▁▂▁▁▂
test_acc,▁

0,1
best_epoch,219.0
best_test_acc,0.89189
best_val_acc,0.79661
fold0_tmp_test_acc,0.81081
fold0_tmp_test_loss,0.87521
fold0_train_acc,1.0
fold0_train_loss,0.00082
fold0_val_acc,0.79661
fold0_val_loss,0.63614
test_acc,84.05405


In [8]:
# Label the run; a "+LP<k>" suffix is appended when args.evectors is non-zero.
if args.evectors == 0:
    model_name = args.model
else:
    model_name = f"{args.model}+LP{args.evectors}"

# Final summary: model / dataset / commit, then the aggregated accuracies.
print(f'{model_name} on {args.dataset} | SHA: {sha}')
print(f'Test acc: {test_acc_mean:.4f} +/- {test_acc_std:.4f} | Val acc: {val_acc_mean:.4f}')


BundleSheaf on texas | SHA: 055f7b1c80e5028eb8f58d3f62de531e68dc908d
Test acc: 84.0541 +/- 6.9905 | Val acc: 86.2712
