In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score

from perforatedai import globals_perforatedai as GPA
from perforatedai import utils_perforatedai as UPA

# ==========================================================
# CONFIG
# ==========================================================

# Path to the .npz archive that provides the X_*/y_* train, val and test arrays.
DATA_PATH = "../data/processed/modis_firms_train_val_test_dataset.npz"

# Mini-batch sizes: BATCH_SIZE feeds the training loader, TEST_BATCH_SIZE the
# validation and test loaders.
BATCH_SIZE = 128
TEST_BATCH_SIZE = 512
# Upper bound on the number of epochs in each run.
EPOCHS = 20
# Optimizer learning rate and weight decay.
LR = 0.001
WEIGHT_DECAY = 0.0
# Multiplicative learning-rate factor handed to StepLR (applied with step_size=1).
GAMMA = 0.95

# Sampling caps: half of MAX_TRAIN bounds the negatives kept for training;
# MAX_VAL / MAX_TEST bound the balanced validation and test subsets.
MAX_TRAIN = 20000
MAX_VAL = 5000
MAX_TEST = 5000

# Device that the model and input batches are moved to.
device = "cpu"
print(f"[DEBUG] Using device: {device}")

# Seed NumPy and PyTorch once, before any sampling or model construction.
np.random.seed(0)
torch.manual_seed(0)

# ==========================================================
# MODEL
# ==========================================================

class FireNN(nn.Module):
    """Fully connected binary classifier that emits one raw logit per sample.

    Args:
        input_dim: Number of features in each input row.
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, input_dim, hidden_dim=64):
        super().__init__()
        # Creation order fixes the order in which initialisation draws
        # from the RNG, so it is kept input -> hidden -> output.
        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.fc3 = nn.Linear(in_features=hidden_dim, out_features=1)

    def forward(self, x):
        """Return the logits for ``x``; no sigmoid is applied here."""
        hidden = torch.relu(self.fc1(x))
        hidden = torch.relu(self.fc2(hidden))
        return self.fc3(hidden)

# ==========================================================
# LOAD DATA
# ==========================================================

print("[DEBUG] Loading dataset...")

# Read every split while the archive is still open; each split is a
# features/labels pair stored under the X_<split> / y_<split> keys.
with np.load(DATA_PATH) as f:
    X_train, y_train = f["X_train"], f["y_train"]
    X_val, y_val = f["X_val"], f["y_val"]
    X_test, y_test = f["X_test"], f["y_test"]

print(f"[DEBUG] Original Train Size: {len(X_train):,}")

# ==========================================================
# CLASS SAMPLING (TRAIN: SUBSAMPLE NEGATIVES; VAL/TEST: BALANCE)
# ==========================================================

def balanced_sample(X, y, max_samples):
    """Draw an equal number of positive and negative rows, in shuffled order.

    Args:
        X: Array indexed along its first axis by sample.
        y: Labels aligned with ``X``; entries equal to 1 are positives and
            entries equal to 0 are negatives.
        max_samples: Upper bound on the number of returned rows (half per class).

    Returns:
        ``(X_subset, y_subset)`` holding the same number of rows from each
        class: the smaller of ``max_samples // 2`` and the rarer class size.
    """
    positives = np.where(y == 1)[0]
    negatives = np.where(y == 0)[0]

    per_class = min(max_samples // 2, len(positives), len(negatives))

    # Positives are drawn before negatives, then the combined index is
    # shuffled; this order determines how the global NumPy RNG is consumed.
    chosen = np.concatenate([
        np.random.choice(positives, per_class, replace=False),
        np.random.choice(negatives, per_class, replace=False),
    ])
    np.random.shuffle(chosen)

    return X[chosen], y[chosen]

def subsample_training(X, y, max_negatives=400_000):
    """Retain every positive row and at most ``max_negatives`` random negatives.

    Args:
        X: Array indexed along its first axis by sample.
        y: Labels aligned with ``X``; 1 marks positives, 0 marks negatives.
        max_negatives: Cap on the number of negative rows kept.

    Returns:
        ``(X_subset, y_subset)`` with the kept rows in shuffled order.
    """
    positives = np.where(y == 1)[0]
    negatives = np.where(y == 0)[0]

    # Shuffling before slicing makes the kept prefix a random subset.
    np.random.shuffle(negatives)
    kept = np.concatenate([positives, negatives[:max_negatives]])
    np.random.shuffle(kept)

    return X[kept], y[kept]

# Training split: every positive plus a capped random draw of negatives.
X_train, y_train = subsample_training(
    X_train,
    y_train,
    max_negatives=MAX_TRAIN // 2,
)

# Validation and test splits: equal class counts, capped per split.
X_val, y_val = balanced_sample(X_val, y_val, MAX_VAL)
X_test, y_test = balanced_sample(X_test, y_test, MAX_TEST)

# Report the fraction of positive labels left in each split.
print("[DEBUG] Train positive ratio:", y_train.mean())
print("[DEBUG] Val positive ratio:", y_val.mean())
print("[DEBUG] Test positive ratio:", y_test.mean())


# ==========================================================
# DATALOADERS
# ==========================================================

def _to_dataset(features, labels):
    """Wrap a feature/label array pair as a float TensorDataset."""
    return TensorDataset(
        torch.tensor(features).float(),
        torch.tensor(labels).float(),
    )


# Only the training loader reshuffles; the evaluation loaders keep dataset order.
train_loader = DataLoader(
    _to_dataset(X_train, y_train), batch_size=BATCH_SIZE, shuffle=True
)
val_loader = DataLoader(_to_dataset(X_val, y_val), batch_size=TEST_BATCH_SIZE)
test_loader = DataLoader(_to_dataset(X_test, y_test), batch_size=TEST_BATCH_SIZE)

# ==========================================================
# LOSS FUNCTION WITH DYNAMIC pos_weight
# ==========================================================

# Weight the positive class by the negative:positive ratio of the training
# labels so both classes contribute comparably to the loss.
pos_count = y_train.sum()
neg_count = len(y_train) - pos_count
pos_weight_tensor = torch.tensor(
    [neg_count / pos_count], dtype=torch.float32, device=device
)
loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor)
print(f"[DEBUG] pos_weight for BCE: {pos_weight_tensor.item():.2f}")

# ==========================================================
# TRAIN / EVAL
# ==========================================================

def train_epoch():
    """Run one optimisation pass over ``train_loader``.

    Operates on the module-level ``model``, ``optimizer``, ``loss_fn``,
    ``train_loader`` and ``device``.

    Returns:
        ``(accuracy_percent, mean_batch_loss)``: accuracy over all training
        samples using a logit threshold of 0, and the sum of per-batch losses
        divided by the number of batches.
    """
    model.train()
    running_loss = 0
    n_correct = 0

    for features, labels in train_loader:
        features = features.to(device)
        # unsqueeze adds a trailing axis so labels line up with the model output.
        labels = labels.to(device).unsqueeze(1)

        optimizer.zero_grad()
        logits = model(features)
        batch_loss = loss_fn(logits, labels)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

        # A logit above 0 corresponds to a sigmoid probability above 0.5.
        hits = (logits > 0).float().eq(labels)
        n_correct += hits.sum().item()

    n_samples = len(train_loader.dataset)
    n_batches = len(train_loader)
    return 100.0 * n_correct / n_samples, running_loss / n_batches


def evaluate(loader):
    """Score the module-level ``model`` on every batch of ``loader``.

    Args:
        loader: Iterable yielding ``(features, labels)`` tensor batches.

    Returns:
        ``(accuracy_percent, precision, recall, f1)`` computed with a logit
        threshold of 0; the three sklearn metrics use ``zero_division=0``.
    """
    model.eval()
    predicted = []
    expected = []

    # Gradients are not needed for scoring.
    with torch.no_grad():
        for features, labels in loader:
            logits = model(features.to(device))

            # A logit above 0 is counted as a positive prediction.
            batch_preds = (logits > 0).cpu().numpy().flatten()

            predicted.extend(batch_preds)
            expected.extend(labels.numpy())

    matches = np.array(predicted) == np.array(expected)
    accuracy = np.mean(matches)

    precision = precision_score(expected, predicted, zero_division=0)
    recall = recall_score(expected, predicted, zero_division=0)
    f1 = f1_score(expected, predicted, zero_division=0)

    return accuracy * 100, precision, recall, f1

# ==========================================================
# RUN EXPERIMENTS
# ==========================================================

# Seed applied at the start of every run (same value the script seeds with
# at start-up) so the baseline and dendrite runs begin from identical RNG state.
RUN_SEED = 0


def _setup_optimizer(net, use_pai):
    """Build the Adam optimizer and per-epoch StepLR scheduler for one run.

    Both variants use the same LR, WEIGHT_DECAY and GAMMA so the comparison
    between the baseline and the dendrite run is not skewed by optimizer
    settings.

    Args:
        net: Model whose parameters are optimised.
        use_pai: When True the pair is created through the PAI tracker,
            otherwise plain torch objects are returned.

    Returns:
        ``(optimizer, scheduler)``.
    """
    if use_pai:
        optim_args = {"params": net.parameters(), "lr": LR, "weight_decay": WEIGHT_DECAY}
        sched_args = {"step_size": 1, "gamma": GAMMA}
        return GPA.pai_tracker.setup_optimizer(net, optim_args, sched_args)

    plain_optimizer = optim.Adam(net.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    return plain_optimizer, StepLR(plain_optimizer, step_size=1, gamma=GAMMA)


for is_dendrite in [False, True]:

    print("\n" + "="*60)
    print(f"RUNNING WITH is_dendrite = {is_dendrite}")
    print("="*60)

    # Reseed so both runs get the same initial weights and the same
    # training shuffle order; otherwise the second run inherits whatever
    # RNG state the first run left behind.
    np.random.seed(RUN_SEED)
    torch.manual_seed(RUN_SEED)

    input_dim = X_train.shape[1]
    base_model = FireNN(input_dim)

    if is_dendrite:
        print("[DEBUG] Initializing WITH dendrites...")

        GPA.pc.set_testing_dendrite_capacity(False)
        GPA.pc.set_weight_decay_accepted(True)
        GPA.pc.set_verbose(False)

        model = UPA.initialize_pai(base_model, save_name="fire_model")

        # The tracker needs the optimizer/scheduler classes before
        # setup_optimizer is called.
        GPA.pai_tracker.set_optimizer(optim.Adam)
        GPA.pai_tracker.set_scheduler(StepLR)
    else:
        print("[DEBUG] Running BASELINE...")
        model = base_model

    model.to(device)

    print(f"[DEBUG] Parameter count: {sum(p.numel() for p in model.parameters()):,}")

    optimizer, scheduler = _setup_optimizer(model, is_dendrite)

    # ---------------- TRAIN ----------------

    for epoch in range(1, EPOCHS + 1):

        train_acc, train_loss = train_epoch()
        scheduler.step()
        val_acc, val_prec, val_rec, val_f1 = evaluate(val_loader)

        print(f"Epoch {epoch}")
        print(f"  Train Acc: {train_acc:.2f}%")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Val Acc: {val_acc:.2f}%")
        print(f"  Precision: {val_prec:.4f}")
        print(f"  Recall:    {val_rec:.4f}")
        print(f"  F1 Score:  {val_f1:.4f}")

        if is_dendrite:
            # NOTE(review): the tracker is driven by validation recall alone;
            # a model that predicts everything positive maximises recall --
            # confirm this is intended rather than F1 or accuracy.
            model, restructured, training_complete = \
                GPA.pai_tracker.add_validation_score(val_rec, model)
            # The tracker may hand back a different module; make sure it
            # lives on the training device before it is used again.
            model.to(device)

            if restructured and not training_complete:
                print("[DEBUG] Restructured Reset optimizer")
                # The previous optimizer refers to the old parameter set.
                optimizer, scheduler = _setup_optimizer(model, True)

            if training_complete:
                print("[DEBUG] Training ended early by PAI.")
                break


    # ---------------- TEST ----------------

    test_acc, test_prec, test_rec, test_f1 = evaluate(test_loader)

    print("\n========== FINAL TEST ==========")
    print(f"Test Acc: {test_acc:.2f}%")
    print(f"Precision: {test_prec:.4f}")
    print(f"Recall:    {test_rec:.4f}")
    print(f"F1 Score:  {test_f1:.4f}")

    if is_dendrite:
        print("Total dendrites added:",
            GPA.pai_tracker.member_vars["num_dendrites_added"])


[DEBUG] Using device: cpu
[DEBUG] Loading dataset...
[DEBUG] Original Train Size: 12,774,667
[DEBUG] Train positive ratio: 0.4926433282597666
[DEBUG] Val positive ratio: 0.5
[DEBUG] Test positive ratio: 0.5
[DEBUG] pos_weight for BCE: 1.03

RUNNING WITH is_dendrite = False
[DEBUG] Running BASELINE...
[DEBUG] Parameter count: 4,545
Epoch 1
  Train Acc: 53.29%
  Train Loss: 0.7071
  Val Acc: 56.94%
  Precision: 0.5533
  Recall:    0.7208
  F1 Score:  0.6260
Epoch 2
  Train Acc: 56.86%
  Train Loss: 0.6861
  Val Acc: 55.62%
  Precision: 0.5353
  Recall:    0.8530
  F1 Score:  0.6578
Epoch 3
  Train Acc: 57.69%
  Train Loss: 0.6845
  Val Acc: 56.70%
  Precision: 0.5469
  Recall:    0.7823
  F1 Score:  0.6437
Epoch 4
  Train Acc: 58.50%
  Train Loss: 0.6806
  Val Acc: 58.60%
  Precision: 0.5872
  Recall:    0.5790
  F1 Score:  0.5831
Epoch 5
  Train Acc: 57.88%
  Train Loss: 0.6819
  Val Acc: 59.03%
  Precision: 0.5761
  Recall:    0.6843
  F1 Score:  0.6255
Epoch 6
  Train Acc: 58.75%
  Tr