In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score

from perforatedai import globals_perforatedai as GPA
from perforatedai import utils_perforatedai as UPA

# ==========================================================
# CONFIG
# ==========================================================

DATA_PATH = "../data/processed/modis_firms_train_val_test_dataset.npz"

BATCH_SIZE = 128
TEST_BATCH_SIZE = 512
EPOCHS = 20
LR = 0.01
WEIGHT_DECAY = 0.0
GAMMA = 0.7
USE_DENDRITES = True   

MAX_TRAIN = 20000
MAX_VAL = 5000
MAX_TEST = 5000

device = "cpu"
print(f"[DEBUG] Using device: {device}")

np.random.seed(0)
torch.manual_seed(0)

# ==========================================================
# MODEL
# ==========================================================

class FireNN(nn.Module):
    def __init__(self, input_dim, hidden_dim=64):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))
        return x

# ==========================================================
# LOAD and BALANCE DATA
# ==========================================================

print("[DEBUG] Loading dataset...")

with np.load(DATA_PATH) as f:
    X_train = f["X_train"]
    y_train = f["y_train"]
    X_val = f["X_val"]
    y_val = f["y_val"]
    X_test = f["X_test"]
    y_test = f["y_test"]

print(f"[DEBUG] Original Train Size: {len(X_train):,}")

def balanced_sample(X, y, max_samples):
    pos_idx = np.where(y == 1)[0]
    neg_idx = np.where(y == 0)[0]

    n_each = max_samples // 2
    n_each = min(n_each, len(pos_idx), len(neg_idx))

    pos_sample = np.random.choice(pos_idx, n_each, replace=False)
    neg_sample = np.random.choice(neg_idx, n_each, replace=False)

    idx = np.concatenate([pos_sample, neg_sample])
    np.random.shuffle(idx)

    return X[idx], y[idx]

# Since the dataset is extremely imbalanced, applying balancing
X_train, y_train = balanced_sample(X_train, y_train, MAX_TRAIN)
X_val, y_val = balanced_sample(X_val, y_val, MAX_VAL)
X_test, y_test = balanced_sample(X_test, y_test, MAX_TEST)

# Debug class balance
print("[DEBUG] Train positive ratio:", np.mean(y_train))
print("[DEBUG] Val positive ratio:", np.mean(y_val))
print("[DEBUG] Test positive ratio:", np.mean(y_test))

# ==========================================================
# DATALOADERS
# ==========================================================

train_loader = DataLoader(
    TensorDataset(torch.tensor(X_train).float(),
                torch.tensor(y_train).float()),
    batch_size=BATCH_SIZE,
    shuffle=True
)

val_loader = DataLoader(
    TensorDataset(torch.tensor(X_val).float(),
                torch.tensor(y_val).float()),
    batch_size=TEST_BATCH_SIZE
)

test_loader = DataLoader(
    TensorDataset(torch.tensor(X_test).float(),
                torch.tensor(y_test).float()),
    batch_size=TEST_BATCH_SIZE
)

# ==========================================================
# TRAIN / EVAL
# ==========================================================

def train_epoch():
    model.train()
    correct = 0

    for data, target in train_loader:
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = F.binary_cross_entropy(output, target.unsqueeze(1))
        loss.backward()
        optimizer.step()

        preds = (output > 0.5).float()
        correct += preds.eq(target.unsqueeze(1)).sum().item()

    return 100.0 * correct / len(train_loader.dataset)


def evaluate(loader):
    model.eval()
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for data, target in loader:
            data = data.to(device)
            output = model(data)

            preds = (output > 0.5).cpu().numpy().flatten()
            all_preds.extend(preds)
            all_targets.extend(target.numpy())

    acc = np.mean(np.array(all_preds) == np.array(all_targets))
    precision = precision_score(all_targets, all_preds)
    recall = recall_score(all_targets, all_preds)
    f1 = f1_score(all_targets, all_preds)

    return acc * 100, precision, recall, f1

# ==========================================================
# RUN EXPERIMENTS (BASELINE + DENDRITES)
# ==========================================================

for is_dendrite in [False, True]:

    print("\n" + "="*60)
    print(f"RUNNING WITH is_dendrite = {is_dendrite}")
    print("="*60)

    # ------------------------------------------------------
    # Reinitialize model each run
    # ------------------------------------------------------

    input_dim = X_train.shape[1]
    base_model = FireNN(input_dim)

    if is_dendrite:
        print("[DEBUG] Initializing WITH dendrites...")
        model = UPA.initialize_pai(base_model, save_name="fire_model")
    else:
        print("[DEBUG] Running BASELINE...")
        model = base_model

    model.to(device)

    print(f"[DEBUG] Parameter count: {sum(p.numel() for p in model.parameters()):,}")

    # ------------------------------------------------------
    # OPTIMIZER
    # ------------------------------------------------------

    if is_dendrite:
        GPA.pc.set_testing_dendrite_capacity(False)
        GPA.pc.set_weight_decay_accepted(True)
        GPA.pc.set_verbose(False)

        GPA.pai_tracker.set_optimizer(optim.Adam)
        GPA.pai_tracker.set_scheduler(StepLR)

        optim_args = {"params": model.parameters(), "lr": LR, "weight_decay": WEIGHT_DECAY}
        sched_args = {"step_size": 1, "gamma": GAMMA}

        optimizer, scheduler = GPA.pai_tracker.setup_optimizer(model, optim_args, sched_args)
    else:
        optimizer = optim.Adam(model.parameters(), lr=LR)
        scheduler = StepLR(optimizer, step_size=1, gamma=GAMMA)

    # ------------------------------------------------------
    # TRAIN LOOP
    # ------------------------------------------------------

    for epoch in range(1, EPOCHS + 1):

        train_acc = train_epoch()
        val_acc, val_prec, val_rec, val_f1 = evaluate(val_loader)

        print(f"Epoch {epoch}")
        print(f"  Train Acc: {train_acc:.2f}%")
        print(f"  Val Acc: {val_acc:.2f}%")
        print(f"  Precision: {val_prec:.4f}")
        print(f"  Recall:    {val_rec:.4f}")
        print(f"  F1 Score:  {val_f1:.4f}")

        if is_dendrite:
            model, restructured, training_complete = \
                GPA.pai_tracker.add_validation_score(val_acc, model)

            if restructured and not training_complete:
                print("[DEBUG] Restructured Reset optimizer")
                optim_args = {"params": model.parameters(), "lr": LR, "weight_decay": WEIGHT_DECAY}
                sched_args = {"step_size": 1, "gamma": GAMMA}
                optimizer, scheduler = GPA.pai_tracker.setup_optimizer(model, optim_args, sched_args)

            if training_complete:
                print("[DEBUG] Training ended early by PAI.")
                break

        scheduler.step()

    # ------------------------------------------------------
    # TEST
    # ------------------------------------------------------

    test_acc, test_prec, test_rec, test_f1 = evaluate(test_loader)

    print("\n========== FINAL TEST ==========")
    print(f"Test Acc: {test_acc:.2f}%")
    print(f"Precision: {test_prec:.4f}")
    print(f"Recall:    {test_rec:.4f}")
    print(f"F1 Score:  {test_f1:.4f}")

    if is_dendrite:
        print("Total dendrites added:",
            GPA.pai_tracker.member_vars["num_dendrites_added"])


Building dendrites without Perforated Backpropagation
[DEBUG] Using device: cpu
[DEBUG] Loading dataset...
[DEBUG] Original Train Size: 12,774,667
[DEBUG] Train positive ratio: 0.5
[DEBUG] Val positive ratio: 0.5
[DEBUG] Test positive ratio: 0.5

RUNNING WITH is_dendrite = False
[DEBUG] Running BASELINE...
[DEBUG] Parameter count: 4,545
Epoch 1
  Train Acc: 53.48%
  Val Acc: 55.72%
  Precision: 0.5389
  Recall:    0.7914
  F1 Score:  0.6412
Epoch 2
  Train Acc: 57.44%
  Val Acc: 57.38%
  Precision: 0.5517
  Recall:    0.7876
  F1 Score:  0.6489
Epoch 3
  Train Acc: 58.64%
  Val Acc: 57.81%
  Precision: 0.5825
  Recall:    0.5512
  F1 Score:  0.5664
Epoch 4
  Train Acc: 58.90%
  Val Acc: 56.85%
  Precision: 0.5454
  Recall:    0.8222
  F1 Score:  0.6558
Epoch 5
  Train Acc: 59.01%
  Val Acc: 58.27%
  Precision: 0.5683
  Recall:    0.6877
  F1 Score:  0.6223
Epoch 6
  Train Acc: 59.17%
  Val Acc: 58.17%
  Precision: 0.5697
  Recall:    0.6679
  F1 Score:  0.6149
Epoch 7
  Train Acc: 59.3



Epoch 2
  Train Acc: 57.96%
  Val Acc: 56.80%
  Precision: 0.5833
  Recall:    0.4762
  F1 Score:  0.5243
Adding validation score 56.79961557
Checking PAI switch with mode n, switch mode DOING_SWITCH_EVERY_TIME, epoch 1, last improved epoch 1, total epochs 1, n: 10, num_cycles: 2
Returning True - switching every time
Importing best Model for switch to PA...
Switching back to N...
Resetting committed to initial rate to False
Saving model before starting normal training to retain PBNodes regardless of next N Phase results
[DEBUG] Restructured Reset optimizer




Epoch 3
  Train Acc: 58.71%
  Val Acc: 58.24%
  Precision: 0.5715
  Recall:    0.6588
  F1 Score:  0.6121
Adding validation score 58.24123018
Checking PAI switch with mode n, switch mode DOING_SWITCH_EVERY_TIME, epoch 2, last improved epoch 2, total epochs 2, n: 10, num_cycles: 4
Returning True - switching every time
Importing best Model for switch to PA...
Switching back to N...
Resetting committed to initial rate to False
Saving model before starting normal training to retain PBNodes regardless of next N Phase results
[DEBUG] Restructured Reset optimizer




Epoch 4
  Train Acc: 59.05%
  Val Acc: 58.34%
  Precision: 0.5818
  Recall:    0.5930
  F1 Score:  0.5873
Adding validation score 58.33733782
Checking PAI switch with mode n, switch mode DOING_SWITCH_EVERY_TIME, epoch 3, last improved epoch 3, total epochs 3, n: 10, num_cycles: 6
Returning True - switching every time
Importing best Model for switch to PA...
Switching back to N...
Resetting committed to initial rate to False
Saving model before starting normal training to retain PBNodes regardless of next N Phase results
[DEBUG] Restructured Reset optimizer




Epoch 5
  Train Acc: 60.05%
  Val Acc: 59.39%
  Precision: 0.5709
  Recall:    0.7568
  F1 Score:  0.6508
Adding validation score 59.39452186
Checking PAI switch with mode n, switch mode DOING_SWITCH_EVERY_TIME, epoch 4, last improved epoch 4, total epochs 4, n: 10, num_cycles: 8
Returning True - switching every time
Importing best Model for switch to PA...
Switching back to N...
Resetting committed to initial rate to False
Saving model before starting normal training to retain PBNodes regardless of next N Phase results
[DEBUG] Restructured Reset optimizer




Epoch 6
  Train Acc: 61.81%
  Val Acc: 62.28%
  Precision: 0.6323
  Recall:    0.5867
  F1 Score:  0.6087
Adding validation score 62.27775108
Checking PAI switch with mode n, switch mode DOING_SWITCH_EVERY_TIME, epoch 5, last improved epoch 5, total epochs 5, n: 10, num_cycles: 10
Returning True - switching every time
Importing best Model for switch to PA...
Switching back to N...
Resetting committed to initial rate to False
Saving model before starting normal training to retain PBNodes regardless of next N Phase results
[DEBUG] Restructured Reset optimizer




Epoch 7
  Train Acc: 62.68%
  Val Acc: 62.35%
  Precision: 0.6169
  Recall:    0.6516
  F1 Score:  0.6338
Adding validation score 62.34983181
Checking PAI switch with mode n, switch mode DOING_SWITCH_EVERY_TIME, epoch 6, last improved epoch 6, total epochs 6, n: 10, num_cycles: 12
Returning True - switching every time
Importing best Model for switch to PA...
Switching back to N...
Resetting committed to initial rate to False
Saving model before starting normal training to retain PBNodes regardless of next N Phase results
[DEBUG] Restructured Reset optimizer




Epoch 8
  Train Acc: 63.08%
  Val Acc: 62.16%
  Precision: 0.6164
  Recall:    0.6439
  F1 Score:  0.6298
Adding validation score 62.15761653
Checking PAI switch with mode n, switch mode DOING_SWITCH_EVERY_TIME, epoch 7, last improved epoch 7, total epochs 7, n: 10, num_cycles: 14
Returning True - switching every time
Importing best Model for switch to PA...
Switching back to N...
Resetting committed to initial rate to False
Saving model before starting normal training to retain PBNodes regardless of next N Phase results
[DEBUG] Restructured Reset optimizer




Epoch 9
  Train Acc: 63.55%
  Val Acc: 62.42%
  Precision: 0.6407
  Recall:    0.5656
  F1 Score:  0.6008
Adding validation score 62.42191254
Checking PAI switch with mode n, switch mode DOING_SWITCH_EVERY_TIME, epoch 8, last improved epoch 8, total epochs 8, n: 10, num_cycles: 16
Returning True - switching every time
Importing best Model for switch to PA...
Switching back to N...
Resetting committed to initial rate to False
Saving model before starting normal training to retain PBNodes regardless of next N Phase results
[DEBUG] Restructured Reset optimizer




Epoch 10
  Train Acc: 64.07%
  Val Acc: 62.25%
  Precision: 0.6197
  Recall:    0.6343
  F1 Score:  0.6269
Adding validation score 62.25372417
Checking PAI switch with mode n, switch mode DOING_SWITCH_EVERY_TIME, epoch 9, last improved epoch 9, total epochs 9, n: 10, num_cycles: 18
Returning True - switching every time
Importing best Model for switch to PA...
Switching back to N...
Resetting committed to initial rate to False
Saving model before starting normal training to retain PBNodes regardless of next N Phase results
[DEBUG] Restructured Reset optimizer




Epoch 11
  Train Acc: 63.50%
  Val Acc: 62.21%
  Precision: 0.6113
  Recall:    0.6704
  F1 Score:  0.6395
Adding validation score 62.20567035
Checking PAI switch with mode n, switch mode DOING_SWITCH_EVERY_TIME, epoch 10, last improved epoch 10, total epochs 10, n: 10, num_cycles: 20
Returning True - switching every time
Importing best Model for switch to PA...
Switching back to N...
Resetting committed to initial rate to False
Saving model before starting normal training to retain PBNodes regardless of next N Phase results
[DEBUG] Restructured Reset optimizer




Epoch 12
  Train Acc: 63.71%
  Val Acc: 62.52%
  Precision: 0.6468
  Recall:    0.5517
  F1 Score:  0.5954
Adding validation score 62.51802018
Checking PAI switch with mode n, switch mode DOING_SWITCH_EVERY_TIME, epoch 11, last improved epoch 11, total epochs 11, n: 10, num_cycles: 22
Returning True - switching every time
Importing best Model for switch to PA...
Switching back to N...
Resetting committed to initial rate to False
Saving model before starting normal training to retain PBNodes regardless of next N Phase results
[DEBUG] Restructured Reset optimizer




Epoch 13
  Train Acc: 64.07%
  Val Acc: 62.97%
  Precision: 0.6405
  Recall:    0.5915
  F1 Score:  0.6150
Adding validation score 62.97453148
Checking PAI switch with mode n, switch mode DOING_SWITCH_EVERY_TIME, epoch 12, last improved epoch 12, total epochs 12, n: 10, num_cycles: 24
Returning True - switching every time
Importing best Model for switch to PA...
Switching back to N...
Resetting committed to initial rate to False
Saving model before starting normal training to retain PBNodes regardless of next N Phase results
[DEBUG] Restructured Reset optimizer




Epoch 14
  Train Acc: 64.23%
  Val Acc: 63.48%
  Precision: 0.6477
  Recall:    0.5911
  F1 Score:  0.6181
Adding validation score 63.47909659
Checking PAI switch with mode n, switch mode DOING_SWITCH_EVERY_TIME, epoch 13, last improved epoch 13, total epochs 13, n: 10, num_cycles: 26
Returning True - switching every time
Importing best Model for switch to PA...
Switching back to N...
Resetting committed to initial rate to False
Saving model before starting normal training to retain PBNodes regardless of next N Phase results
[DEBUG] Restructured Reset optimizer




Epoch 15
  Train Acc: 64.17%
  Val Acc: 62.90%
  Precision: 0.6221
  Recall:    0.6574
  F1 Score:  0.6393
Adding validation score 62.90245074
Checking PAI switch with mode n, switch mode DOING_SWITCH_EVERY_TIME, epoch 14, last improved epoch 14, total epochs 14, n: 10, num_cycles: 28
Returning True - switching every time
Importing best Model for switch to PA...
Switching back to N...
Resetting committed to initial rate to False
Saving model before starting normal training to retain PBNodes regardless of next N Phase results
[DEBUG] Restructured Reset optimizer




Epoch 16
  Train Acc: 64.19%
  Val Acc: 62.85%
  Precision: 0.6266
  Recall:    0.6362
  F1 Score:  0.6314
Adding validation score 62.85439692
Checking PAI switch with mode n, switch mode DOING_SWITCH_EVERY_TIME, epoch 15, last improved epoch 15, total epochs 15, n: 10, num_cycles: 30
Returning True - switching every time
Importing best Model for switch to PA...
Switching back to N...
Resetting committed to initial rate to False
Saving model before starting normal training to retain PBNodes regardless of next N Phase results
[DEBUG] Restructured Reset optimizer




Epoch 17
  Train Acc: 64.44%
  Val Acc: 64.13%
  Precision: 0.6167
  Recall:    0.7468
  F1 Score:  0.6755
Adding validation score 64.12782316
Checking PAI switch with mode n, switch mode DOING_SWITCH_EVERY_TIME, epoch 16, last improved epoch 16, total epochs 16, n: 10, num_cycles: 32
Returning True - switching every time
Importing best Model for switch to PA...
Switching back to N...
Resetting committed to initial rate to False
Saving model before starting normal training to retain PBNodes regardless of next N Phase results
[DEBUG] Restructured Reset optimizer




Epoch 18
  Train Acc: 64.70%
  Val Acc: 63.74%
  Precision: 0.6328
  Recall:    0.6550
  F1 Score:  0.6437
Adding validation score 63.74339260
Checking PAI switch with mode n, switch mode DOING_SWITCH_EVERY_TIME, epoch 17, last improved epoch 17, total epochs 17, n: 10, num_cycles: 34
Returning True - switching every time
Importing best Model for switch to PA...
Switching back to N...
Resetting committed to initial rate to False
Saving model before starting normal training to retain PBNodes regardless of next N Phase results
[DEBUG] Restructured Reset optimizer




Epoch 19
  Train Acc: 64.46%
  Val Acc: 63.48%
  Precision: 0.6387
  Recall:    0.6209
  F1 Score:  0.6296
Adding validation score 63.47909659
Checking PAI switch with mode n, switch mode DOING_SWITCH_EVERY_TIME, epoch 18, last improved epoch 18, total epochs 18, n: 10, num_cycles: 36
Returning True - switching every time
Importing best Model for switch to PA...
Switching back to N...
Resetting committed to initial rate to False
Saving model before starting normal training to retain PBNodes regardless of next N Phase results
[DEBUG] Restructured Reset optimizer




Epoch 20
  Train Acc: 64.72%
  Val Acc: 62.69%
  Precision: 0.6155
  Recall:    0.6761
  F1 Score:  0.6444
Adding validation score 62.68620855
Checking PAI switch with mode n, switch mode DOING_SWITCH_EVERY_TIME, epoch 19, last improved epoch 19, total epochs 19, n: 10, num_cycles: 38
Returning True - switching every time
Importing best Model for switch to PA...
Switching back to N...
Resetting committed to initial rate to False
Saving model before starting normal training to retain PBNodes regardless of next N Phase results
[DEBUG] Restructured Reset optimizer





Test Acc: 63.32%
Precision: 0.6222
Recall:    0.6779
F1 Score:  0.6489
Total dendrites added: 20
