In [2]:
# Paste this cell into your Jupyter notebook and run it.

import sys, pathlib, importlib, traceback, types
import importlib.util

repo_root = pathlib.Path("/home/user/abin_ref_papers/project_structure_demo/dnn_template").resolve()
print("repo_root =", repo_root)
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))
    print("Inserted repo_root into sys.path")

models_dir = repo_root / "models"
losses_dir = repo_root / "losses"

print("models_dir exists:", models_dir.exists(), " losses_dir exists:", losses_dir.exists())

# --- Create an in-memory 'models' package module so relative imports inside models/*.py work ---
if "models" not in sys.modules:
    print("Creating in-memory 'models' package module (so relative imports work)")
    models_pkg = types.ModuleType("models")
    # package must provide __path__ so importlib can locate submodules (models.<name>)
    models_pkg.__path__ = [str(models_dir)]
    # Add a REGISTRY and register decorator (backbone.py expects `from . import register` -> register())
    models_pkg.REGISTRY = {}
    def _register(name):
        def deco(fn):
            models_pkg.REGISTRY[name] = fn
            return fn
        return deco
    models_pkg.register = _register
    # Insert into sys.modules
    sys.modules["models"] = models_pkg
else:
    print("'models' already in sys.modules (skipping in-memory package creation)")

# Now import models.backbone (and other models.* files will be loadable via normal imports)
try:
    m_backbone = importlib.import_module("models.backbone")
    print("Imported models.backbone OK")
except Exception as e:
    print("Failed to import models.backbone via package import. Traceback:")
    traceback.print_exc()
    raise

# After importing backbone and any other modules that used @register decorator,
# the in-memory models.REGISTRY should be populated with factories (if files used @register).
import models as models_pkg  # the module we put in sys.modules
print("Registry keys now:", list(models_pkg.REGISTRY.keys())[:200])

# try to find the factory for daiic_resnet34
factory_name = None
factory_fn = None
for candidate in ("daiic_resnet34", "daiic_resnet", "daiic", "resnet34"):
    if candidate in models_pkg.REGISTRY:
        factory_name = candidate
        factory_fn = models_pkg.REGISTRY[candidate]
        break

# also check attributes on imported modules (sometimes factories are defined as module-level functions)
if factory_fn is None:
    # attempt to find on loaded modules inside 'models' package
    for name, mod in list(sys.modules.items()):
        if name.startswith("models.") and hasattr(mod, "daiic_resnet34"):
            factory_fn = getattr(mod, "daiic_resnet34")
            factory_name = "daiic_resnet34 (module-level)"
            break

if factory_fn is None:
    # try to scan 'models' dir for a plausible file and load it (last resort)
    print("Could not find a registered factory in models.REGISTRY. Scanning models directory for candidate python files...")
    for p in sorted(models_dir.glob("*.py")):
        try:
            spec = importlib.util.spec_from_file_location(f"models_scan_{p.stem}", str(p))
            mod = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(mod)
            if hasattr(mod, "daiic_resnet34"):
                factory_fn = getattr(mod, "daiic_resnet34")
                factory_name = f"daiic_resnet34 (from file {p.name})"
                print("Found factory in", p)
                break
        except Exception:
            continue

if factory_fn is None:
    raise RuntimeError("Failed to locate a daiic_resnet34 factory in models package. Please paste the output of: `ls models` and the first 200 chars of models/backbone.py`.")

print("Using factory:", factory_name, "function:", factory_fn)

# --- Load ALRLoss from losses/alr.py ---
alr_path = losses_dir / "alr.py"
if not alr_path.exists():
    raise FileNotFoundError(f"Couldn't find {alr_path}; adjust path (I saw earlier ALRLoss in losses/alr.py).")
spec = importlib.util.spec_from_file_location("local_losses_alr", str(alr_path))
mod_alr = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod_alr)
if not hasattr(mod_alr, "ALRLoss"):
    raise RuntimeError(f"{alr_path} loaded but ALRLoss symbol not found. Available names: {[n for n in dir(mod_alr) if not n.startswith('_')][:200]}")
ALRLoss = getattr(mod_alr, "ALRLoss")
print("Loaded ALRLoss from", alr_path)

# === instantiate model and run a one-batch diagnostic ===
import torch, torch.nn as nn, torchvision, torchvision.transforms as T
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# instantiate model
model = factory_fn(num_classes=100, pretrained=False, cifar_stem=True, in_channels=3)
model = model.to(device)
print("Instantiated model type:", type(model))

# instantiate ALR and optimizer
criterion_alr = ALRLoss(class_weights=None, reduction="mean")
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)

# tiny dataloader
transform = T.Compose([T.ToTensor()])
train_ds = torchvision.datasets.CIFAR100(root=str(repo_root / "data"), train=True, download=True, transform=transform)
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=2)

# fetch batch
x, y = next(iter(train_loader))
x = x.to(device); y = y.to(device)

model.eval()
with torch.no_grad():
    outputs = model(x)

# normalize outputs as Trainer would
def normalize_outputs(outputs):
    logits = None; W = None
    if isinstance(outputs, dict):
        logits = outputs.get("logits", None)
        if logits is None and "probs" in outputs:
            probs = outputs["probs"].clamp(1e-6, 1-1e-6)
            logits = torch.log(probs / (1.0 - probs))
        W = outputs.get("W", None)
    elif torch.is_tensor(outputs):
        logits = outputs
    else:
        raise RuntimeError("Unsupported outputs type")
    logits = logits.to(device).float()
    if W is not None:
        W = W.to(device).float()
        if not torch.isfinite(W).all():
            W = torch.ones_like(logits)
        W = W.clamp(min=0.0)
        rs = W.sum(dim=1, keepdim=True)
        zero_rows = rs == 0
        if zero_rows.any():
            W[zero_rows.expand_as(W)] = 1.0
            rs = W.sum(dim=1, keepdim=True)
        W = W / (rs + 1e-12)
        if not torch.isfinite(W).all() or (W.abs().max() < 1e-12):
            W = torch.ones_like(logits) / float(logits.size(1))
    return logits, W, outputs

logits, W, full = normalize_outputs(outputs)

print("\n=== Shapes ===")
print("logits:", logits.shape)
print("probs:", torch.sigmoid(logits).shape)
print("W:", None if W is None else W.shape)
if isinstance(full, dict):
    for k in ("F","Fprime"):
        if k in full:
            try:
                print(k, "shape:", full[k].shape)
            except Exception:
                print(k, "present but couldn't read shape")

probs = torch.sigmoid(logits)
print("\n=== Basic stats ===")
print("logits min/max/mean:", float(logits.min()), float(logits.max()), float(logits.mean()))
print("probs min/max/mean:", float(probs.min()), float(probs.max()), float(probs.mean()))
if W is not None:
    print("W min/max/mean:", float(W.min()), float(W.max()), float(W.mean()))
    print("W per-sample sums (first 8):", W.sum(dim=1)[:8].detach().cpu().numpy())

preds = logits.argmax(dim=1)
y_idx = y.argmax(dim=1) if y.dim()>1 else y.view(-1)
print("preds (first 16):", preds[:16].cpu().numpy())
print("labels (first 16):", y_idx[:16].cpu().numpy())

# ALR loss
if W is None:
    print("Model did not return W -> using uniform W")
    W = torch.ones_like(logits) / logits.size(1)

try:
    res = criterion_alr(logits, y, W)
    if isinstance(res, (tuple,list)):
        alr_loss = res[0]; alr_info = res[1] if len(res)>1 else {}
    else:
        alr_loss = res; alr_info = {}
    print("\nALR loss:", float(alr_loss.item()))
    if alr_info:
        for k in ("loss_per_class_mean","W_mean_per_class","pos_per_class"):
            if k in alr_info:
                v = alr_info[k]
                try:
                    print(k, "first8:", v[:8].detach().cpu().numpy())
                except Exception:
                    print(k, ":", v)
except Exception as e:
    print("ALR call failed:", e)
    traceback.print_exc()

# CE comparison
try:
    ce_loss = nn.CrossEntropyLoss()(logits, y_idx)
    print("CrossEntropy loss:", float(ce_loss.item()))
except Exception as e:
    print("CE failed:", e)

# CE grad-flow check
print("\n--- CE grad-flow check ---")
model.train()
optimizer.zero_grad()
outputs2 = model(x)
logits2, W2, _ = normalize_outputs(outputs2)
ce_loss2 = nn.CrossEntropyLoss()(logits2, y_idx)
ce_loss2.backward()
print("CE loss (backward):", float(ce_loss2.item()))
for name,p in model.named_parameters():
    if ('fc' in name or 'classifier' in name or 'head' in name) and p.requires_grad:
        print("grad for", name, "norm:", 0.0 if p.grad is None else float(p.grad.detach().norm()))
        break
optimizer.step()

print("\nDiagnostic run complete.")






repo_root = /home/user/abin_ref_papers/project_structure_demo/dnn_template
Inserted repo_root into sys.path
models_dir exists: True  losses_dir exists: True
Creating in-memory 'models' package module (so relative imports work)
Imported models.backbone OK
Registry keys now: ['resnet34', 'daiic_resnet34', 'resnet50', 'daiic_resnet50', 'densenet121']
Using factory: daiic_resnet34 function: <function daiic_resnet34 at 0x79350d2c0ee0>
Loaded ALRLoss from /home/user/abin_ref_papers/project_structure_demo/dnn_template/losses/alr.py
Device: cuda
Instantiated model type: <class 'models.backbone.DAIICModule'>

=== Shapes ===
logits: torch.Size([32, 100])
probs: torch.Size([32, 100])
W: torch.Size([32, 100])
F shape: torch.Size([32, 512, 4, 4])
Fprime shape: torch.Size([32, 512, 4, 4])

=== Basic stats ===
logits min/max/mean: -26.30319595336914 26.12763023376465 -1.1395820379257202
probs min/max/mean: 3.772828837539377e-12 1.0 0.42649194598197937
W min/max/mean: 5.881404874230611e-08 0.183967769

In [2]:
# Run inside your notebook (after you created `model` via daiic factory)
import torch
import torch.nn as nn

def reinit_cost_attention(model, init_std=1e-3, zero_bias=True, verbose=True):
    """
    Re-initialize CostAttention conv weights small and zero bias.
    - model: the DAIICModule instance (or container holding it)
    - init_std: std dev for normal init (try 1e-3, 1e-4)
    """
    # try common attribute paths
    candidates = []
    if hasattr(model, "cost_att"):
        candidates.append(model.cost_att)
    # sometimes model may be wrapper: try deeper search
    for name, module in model.named_modules():
        # detect by attribute/class name
        if module.__class__.__name__.lower().startswith("costattention") or getattr(module, "conv1x1", None) is not None:
            candidates.append(module)

    touched = 0
    for module in candidates:
        conv = getattr(module, "conv1x1", None)
        if conv is None:
            continue
        if isinstance(conv, nn.Conv2d):
            nn.init.normal_(conv.weight, mean=0.0, std=float(init_std))
            if zero_bias and conv.bias is not None:
                nn.init.constant_(conv.bias, 0.0)
            touched += 1
            if verbose:
                print(f"Re-init {module.__class__.__name__}.conv1x1 -> std={init_std}, zero_bias={zero_bias}")
    if touched == 0 and verbose:
        print("No CostAttention.conv1x1 found on model. Inspect model.named_modules() to find proper path.")
    return touched

# Example usage (choose small std)
reinit_cost_attention(model, init_std=1e-3, zero_bias=True)
# Optionally combine with increased softmax temperature if implemented:
if hasattr(model, "cost_att") and hasattr(model.cost_att, "tau"):
    model.cost_att.tau = 8.0
    print("Set cost_att.tau =", model.cost_att.tau)


Re-init CostAttention.conv1x1 -> std=0.001, zero_bias=True
Re-init CostAttention.conv1x1 -> std=0.001, zero_bias=True


In [4]:
# Run after re-init or after changing file + reloading model
import torch, torch.nn.functional as F, numpy as np
# create or re-create model:
model = models_pkg.REGISTRY["daiic_resnet34"](num_classes=100, pretrained=False, cifar_stem=True, in_channels=3).to('cuda')
# if you used monkeypatch reinit on previous model instance, use that; otherwise it's a fresh new model
model.eval()
x = torch.randn(8,3,32,32).cuda()
with torch.no_grad():
    out = model(x)
W = out.get("W", None)
if W is None:
    print("Model did not return W")
else:
    probs = out.get("probs", None)
    logits = out.get("logits", None)
    # compute entropy
    def entropy(p):
        p = p / (p.sum(dim=1, keepdim=True)+1e-12)
        return -(p * (p.clamp(1e-12).log())).sum(dim=1)
    W_top1_val, W_top1_idx = W.max(dim=1)
    print("W mean:", float(W.mean()), "W top1 mean:", float(W_top1_val.mean()), "W entropy mean:", float(entropy(W).mean()))
    print("W rows (first 4, first 10 cols):")
    print(W[:4,:10].cpu().numpy())


W mean: 0.009999999776482582 W top1 mean: 0.9621791839599609 W entropy mean: 0.17686697840690613
W rows (first 4, first 10 cols):
[[2.8333790e-04 1.7396493e-10 1.7888036e-08 8.3736307e-07 1.1745243e-06
  4.2493301e-10 7.0379946e-07 1.9076278e-09 5.4415160e-08 3.6905321e-09]
 [2.3198091e-04 2.8960329e-10 1.9363863e-08 6.2880554e-07 1.5742845e-06
  3.7805695e-10 7.7095638e-07 1.5517209e-09 4.9369703e-08 4.4073589e-09]
 [2.2140451e-04 1.2493841e-10 1.4881882e-08 4.2759348e-07 7.4848361e-07
  2.5877844e-10 5.5893832e-07 9.8466801e-10 5.5105289e-08 2.9031730e-09]
 [2.6897390e-04 2.8882699e-10 2.7090564e-08 1.1810084e-06 2.5095926e-06
  4.9874666e-10 1.0366307e-06 3.7599381e-09 7.1720621e-08 5.0822404e-09]]


In [5]:
# MONKEYPATCH: reinit conv smaller and set tau larger, then test W
import torch, torch.nn as nn, importlib
import numpy as np

# get fresh model (or reuse 'model' if you have one)
# model = models_pkg.REGISTRY["daiic_resnet34"](num_classes=100, pretrained=False, cifar_stem=True).cuda()
# If you already have `model`, skip creating new one.

def apply_costatt_fix(model, init_std=1e-4, tau=10.0, zero_bias=True, verbose=True):
    touched = 0
    for name, module in model.named_modules():
        if module.__class__.__name__.lower().startswith("costattention") or hasattr(module, "conv1x1"):
            conv = getattr(module, "conv1x1", None)
            if isinstance(conv, nn.Conv2d):
                nn.init.normal_(conv.weight, mean=0.0, std=float(init_std))
                if zero_bias and conv.bias is not None:
                    nn.init.constant_(conv.bias, 0.0)
                # set tau if available
                if hasattr(module, "tau"):
                    module.tau = float(tau)
                # store a flag if needed
                setattr(module, "_patched_init_std", float(init_std))
                touched += 1
                if verbose:
                    print(f"Patched {name} (class {module.__class__.__name__}): init_std={init_std}, tau={getattr(module,'tau',None)}")
    if touched == 0:
        print("No CostAttention-like module found. Inspect model.named_modules().")
    return touched

# Apply to your current model (or create/recreate model then apply)
apply_costatt_fix(model, init_std=1e-4, tau=10.0)

# quick W diagnostic on a random minibatch (or reuse your x,y)
model.eval()
x = torch.randn(8,3,32,32).cuda()
with torch.no_grad():
    out = model(x)
W = out.get("W")
def entropy(p):
    p = p / (p.sum(dim=1, keepdim=True)+1e-12)
    return -(p * (p.clamp(1e-12).log())).sum(dim=1)
W_top1_val = W.max(dim=1).values
print("W mean:", float(W.mean()), "W top1 mean:", float(W_top1_val.mean()), "W entropy mean:", float(entropy(W).mean()))
print("Example W rows (first 4, first 10 cols):")
print(W[:4,:10].cpu().numpy())


Patched cost_att (class CostAttention): init_std=0.0001, tau=None
W mean: 0.009999999776482582 W top1 mean: 0.010759033262729645 W entropy mean: 4.60485315322876
Example W rows (first 4, first 10 cols):
[[0.01024134 0.00991709 0.01001475 0.00944132 0.00979228 0.00963728
  0.00999563 0.00955502 0.00984503 0.00963696]
 [0.01025479 0.00991656 0.01002092 0.00943761 0.00977195 0.00961552
  0.00999963 0.00952794 0.00982326 0.00963264]
 [0.01026538 0.00992124 0.0100203  0.00944282 0.00978252 0.00962164
  0.00999355 0.0095411  0.00985025 0.0096305 ]
 [0.01025944 0.00991534 0.01001347 0.0094259  0.00979426 0.00961505
  0.00998644 0.00952348 0.00983254 0.00962814]]


In [6]:
# Paste & run in your notebook
import time, math
import torch, torch.nn as nn, torchvision, torchvision.transforms as T
from torch.utils.data import DataLoader
import models as models_pkg
import importlib, importlib.util, pathlib

repo_root = pathlib.Path("/home/user/abin_ref_papers/project_structure_demo/dnn_template").resolve()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# factory and ALR
daiic_factory = models_pkg.REGISTRY["daiic_resnet34"]

# load ALRLoss from losses
spec = importlib.util.spec_from_file_location("local_alr", str(repo_root / "losses" / "alr.py"))
mod_alr = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod_alr)
ALRLoss = mod_alr.ALRLoss

# hyperparams (feel free to tune)
num_classes = 100
batch_size = 128
warmup_epochs = 2           # CE warmup
alr_epochs = 6              # ALR training after warmup (increase later)
lr_base = 0.01
lr_head = 0.1
weight_decay = 1e-4
lambda_ent = 0.0            # 0.0 unless you want entropy reg
alr_scale = float(num_classes)  # scale ALR to match CE magnitude (tweak)

# dataloaders (small quick transforms)
transform = T.Compose([T.ToTensor(), T.Normalize((0.5071,0.4867,0.4408),(0.2675,0.2565,0.2761))])
train_ds = torchvision.datasets.CIFAR100(root=str(repo_root / "data"), train=True, download=True, transform=transform)
val_ds = torchvision.datasets.CIFAR100(root=str(repo_root / "data"), train=False, download=True, transform=transform)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=256, shuffle=False, num_workers=2, pin_memory=True)

# utils
def instantiate_model():
    model = daiic_factory(num_classes=num_classes, pretrained=False, cifar_stem=True, in_channels=3).to(device)
    # if model has cost_att and tau attr, you can set initial tau here:
    for name, m in model.named_modules():
        if m.__class__.__name__.lower().startswith("costattention") and hasattr(m, "tau"):
            # start with high tau (very soft), will anneal later
            m.tau = 8.0
    return model

def make_optimizer(model):
    head_params, base_params = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad: continue
        if ("classifier" in n) or ("fc" in n) or ("head" in n):
            head_params.append(p)
        else:
            base_params.append(p)
    groups = []
    if base_params: groups.append({'params': base_params, 'lr': lr_base})
    if head_params: groups.append({'params': head_params, 'lr': lr_head})
    if not groups: groups = [{'params': model.parameters(), 'lr': lr_base}]
    opt = torch.optim.SGD(groups, momentum=0.9, weight_decay=weight_decay)
    return opt

def W_stats(W):
    if W is None: return {}
    top1 = float(W.max(dim=1).values.mean())
    entropy = float((-(W*(W.clamp(1e-12).log())).sum(dim=1)).mean())
    return {'W_mean': float(W.mean()), 'W_top1_mean': top1, 'W_entropy_mean': entropy}

# simple train loop for a given loss mode
def train_one_epoch(model, opt, loss_mode='ce', epoch=0, total_epochs=1, alr_loss_fn=None, scheduler_tau=None):
    model.train()
    running_loss = 0.0
    running_acc = 0.0
    n = 0
    for i, (x,y) in enumerate(train_loader):
        x = x.to(device); y = y.to(device)
        opt.zero_grad()
        out = model(x)
        logits = out.get('logits')
        W = out.get('W')
        # normalize if necessary (but assume W already normalized by module)
        if loss_mode == 'ce':
            y_idx = y.view(-1)
            loss = nn.CrossEntropyLoss()(logits, y_idx)
        elif loss_mode == 'alr':
            # scale ALR to have comparable magnitude to CE
            raw = alr_loss_fn(logits, y, W)
            if isinstance(raw, (tuple, list)):
                raw = raw[0]
            loss = raw * alr_scale
            # optional entropy reg (encourage non-peaky)
            if lambda_ent > 0.0:
                ent = (-(W * (W.clamp(1e-12).log())).sum(dim=1)).mean()
                loss = loss - lambda_ent * ent
        else:
            raise ValueError
        loss.backward()
        opt.step()

        # optionally anneal tau per step if scheduler provided
        if scheduler_tau is not None:
            scheduler_tau(step=(epoch + i/len(train_loader)), total_epochs=total_epochs)

        bs = x.size(0)
        running_loss += float(loss.item()) * bs
        pred = logits.argmax(dim=1)
        y_idx = y.view(-1)
        running_acc += int((pred == y_idx).sum().item())
        n += bs

        if (i+1) % 100 == 0:
            print(f"  step {i+1} | loss {running_loss/n:.4f} acc {running_acc/n:.4f}")

    return running_loss / max(1,n), running_acc / max(1,n)

def eval_one_epoch(model, loss_mode='ce', alr_loss_fn=None, max_batches=50):
    model.eval()
    loss_sum = 0.0; correct = 0; n=0
    with torch.no_grad():
        for i,(x,y) in enumerate(val_loader):
            x = x.to(device); y = y.to(device)
            out = model(x)
            logits = out.get('logits'); W = out.get('W')
            if loss_mode=='ce':
                l = nn.CrossEntropyLoss()(logits, y.view(-1)).item()
            else:
                raw = alr_loss_fn(logits,y,W)
                if isinstance(raw, (tuple,list)): raw = raw[0]
                l = float(raw.item())*alr_scale
            loss_sum += l * x.size(0)
            pred = logits.argmax(dim=1)
            correct += int((pred == y.view(-1)).sum().item())
            n += x.size(0)
            if i >= max_batches: break
    return loss_sum/max(1,n), correct/max(1,n)

# instantiate
model = instantiate_model()
opt = make_optimizer(model)
alr = ALRLoss(class_weights=None, reduction='mean')

# simple tau annealing scheduler function (optional)
def make_tau_scheduler(module, tau_start=8.0, tau_end=1.0, total_epochs=10):
    def scheduler(step, total_epochs=total_epochs):
        # step: fractional epoch (e.g., epoch + iter/len)
        frac = min(1.0, max(0.0, step/float(total_epochs)))
        tau = tau_start + (tau_end - tau_start) * frac
        # apply to all cost_att modules
        for name, m in module.named_modules():
            if m.__class__.__name__.lower().startswith("costattention") and hasattr(m, 'tau'):
                m.tau = tau
    return scheduler

tau_sched = make_tau_scheduler(model, tau_start=8.0, tau_end=1.0, total_epochs=(warmup_epochs+alr_epochs))

# === Warmup (CE) ===
print("Starting CE warmup for", warmup_epochs, "epochs")
for e in range(warmup_epochs):
    t0 = time.time()
    train_loss, train_acc = train_one_epoch(model, opt, loss_mode='ce', epoch=e, total_epochs=warmup_epochs+alr_epochs)
    val_loss, val_acc = eval_one_epoch(model, loss_mode='ce')
    print(f"CE epoch {e+1}/{warmup_epochs} | train_loss {train_loss:.4f} train_acc {train_acc:.4f} | val_loss {val_loss:.4f} val_acc {val_acc:.4f} | time {(time.time()-t0):.1f}s")
    # print W stats
    with torch.no_grad():
        out = model(next(iter(val_loader))[0].to(device))
        W = out.get('W')
        print("  W stats:", W_stats(W))

# === ALR training with tau anneal ===
print("Starting ALR phase for", alr_epochs, "epochs")
for e in range(alr_epochs):
    t0 = time.time()
    train_loss, train_acc = train_one_epoch(model, opt, loss_mode='alr', epoch=warmup_epochs+e, total_epochs=warmup_epochs+alr_epochs, alr_loss_fn=alr, scheduler_tau=tau_sched)
    val_loss, val_acc = eval_one_epoch(model, loss_mode='alr', alr_loss_fn=alr)
    print(f"ALR epoch {e+1}/{alr_epochs} | train_loss {train_loss:.4f} train_acc {train_acc:.4f} | val_loss {val_loss:.4f} val_acc {val_acc:.4f} | time {(time.time()-t0):.1f}s")
    with torch.no_grad():
        out = model(next(iter(val_loader))[0].to(device))
        print("  W stats:", W_stats(out.get('W')))


Device: cuda
Starting CE warmup for 2 epochs
  step 100 | loss 5.2701 acc 0.0148
  step 200 | loss 4.8016 acc 0.0287
  step 300 | loss 4.5323 acc 0.0449
CE epoch 1/2 | train_loss 4.3671 train_acc 0.0586 | val_loss 3.7816 val_acc 0.1060 | time 22.1s
  W stats: {'W_mean': 0.009999999776482582, 'W_top1_mean': 0.010772624984383583, 'W_entropy_mean': 4.604591369628906}
  step 100 | loss 3.6845 acc 0.1230
  step 200 | loss 3.6238 acc 0.1350
  step 300 | loss 3.5465 acc 0.1471
CE epoch 2/2 | train_loss 3.4905 train_acc 0.1575 | val_loss 3.1812 val_acc 0.2135 | time 22.5s
  W stats: {'W_mean': 0.009999999776482582, 'W_top1_mean': 0.010868331417441368, 'W_entropy_mean': 4.60463809967041}
Starting ALR phase for 6 epochs
  step 100 | loss 8.2516 acc 0.0151
  step 200 | loss 6.9176 acc 0.0163
  step 300 | loss 7.1837 acc 0.0156
ALR epoch 1/6 | train_loss 6.8876 train_acc 0.0158 | val_loss 5.2087 val_acc 0.0162 | time 24.0s
  W stats: {'W_mean': 0.009999999776482582, 'W_top1_mean': 0.13355275988578

In [7]:
# === Paste & run this entire cell in your Jupyter notebook ===
import time, math, pathlib, importlib, importlib.util
import torch, torch.nn as nn, torchvision, torchvision.transforms as T
from torch.utils.data import DataLoader

repo_root = pathlib.Path("/home/user/abin_ref_papers/project_structure_demo/dnn_template").resolve()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device, "repo_root:", repo_root)

# === Hyperparameters - tune these ===
num_classes = 100
batch_size = 128
warmup_epochs = 3               # CE warmup epochs (freeze cost-att during these)
alr_epochs = 10                 # ALR training epochs (after warmup)
lr_base = 0.01
lr_head = 0.1
weight_decay = 1e-4

# ALR specifics
alr_scale = float(num_classes)   # scale ALR to make magnitudes comparable to CE
lambda_ent = 0.1                 # entropy regularizer weight (try 0.05 - 0.2)
force_uniform_W_for_alr_epochs = 1  # number of ALR epochs to force uniform W (diagnostic)
initial_tau = 20.0               # starting softmax temperature in CostAttention
final_tau = 1.0                  # final tau after anneal
tau_anneal_epochs = max(1, alr_epochs)  # over how many ALR epochs to anneal tau

# dataset / transforms (CIFAR-100)
transform = T.Compose([T.ToTensor(), T.Normalize((0.5071,0.4867,0.4408),(0.2675,0.2565,0.2761))])
train_ds = torchvision.datasets.CIFAR100(root=str(repo_root / "data"), train=True, download=True, transform=transform)
val_ds = torchvision.datasets.CIFAR100(root=str(repo_root / "data"), train=False, download=True, transform=transform)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=256, shuffle=False, num_workers=2, pin_memory=True)

# === Load model factory & ALR loss ===
import models as models_pkg  # should already be importable in your notebook
assert "daiic_resnet34" in models_pkg.REGISTRY, "daiic_resnet34 not found in models.REGISTRY"
daiic_factory = models_pkg.REGISTRY["daiic_resnet34"]

spec = importlib.util.spec_from_file_location("local_alr", str(repo_root / "losses" / "alr.py"))
mod_alr = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod_alr)
ALRLoss = mod_alr.ALRLoss

# === Helpers ===
def instantiate_model():
    m = daiic_factory(num_classes=num_classes, pretrained=False, cifar_stem=True, in_channels=3).to(device)
    # set initial tau if module supports it
    for _, mm in m.named_modules():
        if mm.__class__.__name__.lower().startswith("costattention") and hasattr(mm, "tau"):
            mm.tau = float(initial_tau)
    return m

def make_optimizer(model):
    head_params, base_params = [], []
    for n,p in model.named_parameters():
        if not p.requires_grad: 
            continue
        if ("classifier" in n) or ("fc" in n) or ("head" in n):
            head_params.append(p)
        else:
            base_params.append(p)
    groups = []
    if base_params: groups.append({'params': base_params, 'lr': lr_base})
    if head_params: groups.append({'params': head_params, 'lr': lr_head})
    if not groups:
        groups = [{'params': model.parameters(), 'lr': lr_base}]
    opt = torch.optim.SGD(groups, momentum=0.9, weight_decay=weight_decay)
    return opt

def set_costatt_requires_grad(model, req=False):
    touched = 0
    for name, m in model.named_modules():
        if m.__class__.__name__.lower().startswith("costattention") or hasattr(m, "conv1x1"):
            conv = getattr(m, "conv1x1", None)
            if conv is not None:
                for p in conv.parameters():
                    p.requires_grad = bool(req)
                touched += 1
    print("set_costatt_requires_grad -> touched:", touched, "req:", req)

def normalize_outputs(outputs):
    # same normalization used earlier
    logits = None; W = None
    if isinstance(outputs, dict):
        logits = outputs.get("logits", None)
        if logits is None and "probs" in outputs:
            probs = outputs["probs"].clamp(1e-6, 1-1e-6)
            logits = torch.log(probs / (1.0 - probs))
        W = outputs.get("W", None)
    elif torch.is_tensor(outputs):
        logits = outputs
    else:
        raise RuntimeError("Unsupported model outputs type")
    logits = logits.to(device).float()
    if W is not None:
        W = W.to(device).float()
        if not torch.isfinite(W).all():
            W = torch.ones_like(logits)
        W = W.clamp(min=0.0)
        rs = W.sum(dim=1, keepdim=True)
        zero_rows = rs == 0
        if zero_rows.any():
            W[zero_rows.expand_as(W)] = 1.0
            rs = W.sum(dim=1, keepdim=True)
        W = W / (rs + 1e-12)
        if not torch.isfinite(W).all() or (W.abs().max() < 1e-12):
            W = torch.ones_like(logits) / float(logits.size(1))
    return logits, W, outputs

def W_stats(W):
    if W is None:
        return {}
    top1 = float(W.max(dim=1).values.mean())
    ent = float((-(W * (W.clamp(1e-12).log())).sum(dim=1)).mean())
    return {'W_mean': float(W.mean()), 'W_top1_mean': top1, 'W_entropy_mean': ent}

# === Training/Eval loops ===
ce_loss_fn = nn.CrossEntropyLoss()
alr_loss_fn = ALRLoss(class_weights=None, reduction="mean")

def train_one_epoch(model, opt, epoch_idx, total_epochs, mode='ce', force_uniform_W=False, lambda_ent=0.0, alr_scale=1.0):
    model.train()
    running_loss = 0.0
    running_correct = 0
    seen = 0
    for step, (x,y) in enumerate(train_loader, 1):
        x = x.to(device); y = y.to(device)
        opt.zero_grad()
        outputs = model(x)
        logits, W, _ = normalize_outputs(outputs)
        if mode == 'ce':
            loss = ce_loss_fn(logits, y.view(-1))
        elif mode == 'alr':
            if force_uniform_W:
                W_use = torch.ones_like(logits, device=logits.device) / float(logits.size(1))
            else:
                W_use = W
            raw = alr_loss_fn(logits, y, W_use)
            if isinstance(raw, (tuple, list)):
                raw = raw[0]
            loss = raw * alr_scale
            if lambda_ent > 0.0:
                ent = (-(W_use * (W_use.clamp(1e-12).log())).sum(dim=1)).mean()
                # subtract ent (we want to maximize entropy) with small weight
                loss = loss - lambda_ent * ent
        else:
            raise ValueError("mode must be 'ce' or 'alr'")

        loss.backward()
        opt.step()

        preds = logits.argmax(dim=1)
        running_correct += int((preds == y.view(-1)).sum().item())
        running_loss += float(loss.item()) * x.size(0)
        seen += x.size(0)

        # optional per-step logging
        if step % 200 == 0:
            print(f"  step {step} | avg_loss {running_loss/seen:.4f} | avg_acc {running_correct/seen:.4f}")

    return running_loss / max(1, seen), running_correct / max(1, seen)

def eval_model(model, mode='ce', alr_scale=1.0, max_batches=50):
    model.eval()
    loss_sum = 0.0; correct = 0; n = 0
    with torch.no_grad():
        for i, (x,y) in enumerate(val_loader):
            x = x.to(device); y = y.to(device)
            outputs = model(x)
            logits, W, _ = normalize_outputs(outputs)
            if mode == 'ce':
                l = float(ce_loss_fn(logits, y.view(-1)).item())
            else:
                raw = alr_loss_fn(logits, y, W)
                if isinstance(raw, (tuple,list)): raw = raw[0]
                l = float(raw.item()) * alr_scale
            loss_sum += l * x.size(0)
            preds = logits.argmax(dim=1)
            correct += int((preds == y.view(-1)).sum().item())
            n += x.size(0)
            if i >= max_batches: break
    return loss_sum / max(1, n), correct / max(1, n)

# === Instantiate model, freeze cost-att, and run CE warmup ===
model = instantiate_model()
set_costatt_requires_grad(model, req=False)  # freeze cost-att during warmup
opt = make_optimizer(model)

print("\n=== CE Warmup ===")
for e in range(warmup_epochs):
    t0 = time.time()
    train_loss, train_acc = train_one_epoch(model, opt, epoch_idx=e, total_epochs=warmup_epochs, mode='ce')
    val_loss, val_acc = eval_model(model, mode='ce')
    # print W stats on a small val batch
    with torch.no_grad():
        out = model(next(iter(val_loader))[0].to(device))
        _, W_val, _ = normalize_outputs(out)
    print(f"CE epoch {e+1}/{warmup_epochs} | train_loss {train_loss:.4f} train_acc {train_acc:.4f} | val_loss {val_loss:.4f} val_acc {val_acc:.4f} | time {(time.time()-t0):.1f}s")
    print("  W stats:", W_stats(W_val))

# === Prepare for ALR phase: unfreeze cost-att, recreate optimizer ===
set_costatt_requires_grad(model, req=True)
opt = make_optimizer(model)

# tau annealing schedule across ALR epochs
def tau_for_alr_epoch(epoch_idx):
    frac = min(1.0, max(0.0, float(epoch_idx) / float(max(1, tau_anneal_epochs-1))))
    return initial_tau + (final_tau - initial_tau) * frac

print("\n=== ALR Phase ===")
for alr_epoch in range(alr_epochs):
    # set tau for this epoch
    tau_val = tau_for_alr_epoch(alr_epoch)
    for _, mm in model.named_modules():
        if mm.__class__.__name__.lower().startswith("costattention") and hasattr(mm, "tau"):
            mm.tau = float(tau_val)
    # choose whether to force uniform W during first few ALR epochs
    force_uniform = (alr_epoch < force_uniform_W_for_alr_epochs)
    t0 = time.time()
    train_loss, train_acc = train_one_epoch(model, opt, epoch_idx=alr_epoch, total_epochs=alr_epochs, mode='alr',
                                            force_uniform_W=force_uniform, lambda_ent=lambda_ent, alr_scale=alr_scale)
    val_loss, val_acc = eval_model(model, mode='alr', alr_scale=alr_scale)
    # sample W stats from val_loader
    with torch.no_grad():
        out = model(next(iter(val_loader))[0].to(device))
        _, W_val, _ = normalize_outputs(out)
    print(f"ALR epoch {alr_epoch+1}/{alr_epochs} | tau {tau_val:.3f} | force_uniform {force_uniform} | train_loss {train_loss:.4f} train_acc {train_acc:.4f} | val_loss {val_loss:.4f} val_acc {val_acc:.4f} | time {(time.time()-t0):.1f}s")
    print("  W stats:", W_stats(W_val))

print("\nTraining run complete. Monitor W_top1_mean and W_entropy_mean; if W_top1 climbs too fast, increase tau or lambda_ent, or force longer uniform W.")


Device: cuda repo_root: /home/user/abin_ref_papers/project_structure_demo/dnn_template
set_costatt_requires_grad -> touched: 1 req: False

=== CE Warmup ===
  step 200 | avg_loss 4.9911 | avg_acc 0.0212
CE epoch 1/3 | train_loss 4.5498 train_acc 0.0417 | val_loss 3.9484 val_acc 0.0855 | time 23.0s
  W stats: {'W_mean': 0.009999999776482582, 'W_top1_mean': 0.010767602361738682, 'W_entropy_mean': 4.604740619659424}
  step 200 | avg_loss 3.7797 | avg_acc 0.1093
CE epoch 2/3 | train_loss 3.6698 train_acc 0.1259 | val_loss 3.5143 val_acc 0.1570 | time 23.2s
  W stats: {'W_mean': 0.009999999776482582, 'W_top1_mean': 0.01086423359811306, 'W_entropy_mean': 4.604649543762207}
  step 200 | avg_loss 3.2888 | avg_acc 0.1921
CE epoch 3/3 | train_loss 3.1943 train_acc 0.2086 | val_loss 3.0100 val_acc 0.2446 | time 24.5s
  W stats: {'W_mean': 0.009999999776482582, 'W_top1_mean': 0.010853853076696396, 'W_entropy_mean': 4.60462760925293}
set_costatt_requires_grad -> touched: 1 req: True

=== ALR Phase 

In [5]:
# Identity-safe optimizer creator: compare params by id() to avoid tensor-equality checks
import torch, torch.nn as nn

def make_optimizer_with_costatt_lr(model, base_lr=1e-2, head_lr=1e-1, costatt_lr=None, weight_decay=1e-4):
    head_params = []
    base_params = []
    costatt_params = []

    # Collect base/head params heuristically by name
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if ("classifier" in n) or ("fc" in n) or ("head" in n):
            head_params.append(p)
        else:
            base_params.append(p)

    # If costatt_lr is specified, collect conv params from cost-att modules
    if costatt_lr is not None:
        conv_params = []
        for name, m in model.named_modules():
            if m.__class__.__name__.lower().startswith("costattention") or hasattr(m, "conv1x1"):
                conv = getattr(m, "conv1x1", None)
                if conv is not None:
                    for p in conv.parameters():
                        conv_params.append(p)

        # use identity comparison via id()
        conv_ids = {id(p) for p in conv_params}
        base_params = [p for p in base_params if id(p) not in conv_ids]
        head_params = [p for p in head_params if id(p) not in conv_ids]
        costatt_params = conv_params

    # Build param groups
    param_groups = []
    if base_params:
        param_groups.append({'params': base_params, 'lr': base_lr})
    if head_params:
        param_groups.append({'params': head_params, 'lr': head_lr})
    if costatt_params:
        param_groups.append({'params': costatt_params, 'lr': costatt_lr})
    if not param_groups:
        param_groups = [{'params': model.parameters(), 'lr': base_lr}]

    opt = torch.optim.SGD(param_groups, momentum=0.9, weight_decay=weight_decay)
    return opt

# Recreate optimizer using your chosen LRs
costatt_low_lr = 1e-6
lr_base = 0.01
lr_head = 0.1
weight_decay = 1e-4

opt = make_optimizer_with_costatt_lr(model, base_lr=lr_base, head_lr=lr_head, costatt_lr=costatt_low_lr, weight_decay=weight_decay)
print("Recreated optimizer with param groups:", [(len(g['params']), g['lr']) for g in opt.param_groups])


Recreated optimizer with param groups: [(108, 0.01), (6, 0.1), (2, 1e-06)]


In [7]:
# Paste & run this cell in your notebook (replacement / tuned training)
import time, math, pathlib, importlib, importlib.util
import torch, torch.nn as nn, torchvision, torchvision.transforms as T
from torch.utils.data import DataLoader

repo_root = pathlib.Path("/home/user/abin_ref_papers/project_structure_demo/dnn_template").resolve()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# === Conservative hyperparams (more cautious than before) ===
num_classes = 100
batch_size = 128
warmup_epochs = 3           # CE warmup epochs
alr_epochs = 12             # ALR epochs after warmup
lr_base = 0.01
lr_head = 0.1
weight_decay = 1e-4

# ALR specifics (more conservative)
alr_scale = float(num_classes)
lambda_ent = 0.2                    # stronger entropy regularizer
force_uniform_alr_epochs = 3        # force uniform W for first 3 ALR epochs
initial_tau = 50.0                  # very soft initial tau
final_tau = 5.0                     # don't anneal too quickly
tau_anneal_epochs = max(1, alr_epochs)  # slow anneal across ALR epochs
costatt_low_lr = 1e-6               # extremely small LR for cost_att (instead of full freeze)
w_top1_alert_thresh = 0.5           # if W_top1_mean exceeds this, we will re-apply conservative measures

# dataset
transform = T.Compose([T.ToTensor(), T.Normalize((0.5071,0.4867,0.4408),(0.2675,0.2565,0.2761))])
train_ds = torchvision.datasets.CIFAR100(root=str(repo_root / "data"), train=True, download=True, transform=transform)
val_ds = torchvision.datasets.CIFAR100(root=str(repo_root / "data"), train=False, download=True, transform=transform)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=256, shuffle=False, num_workers=2, pin_memory=True)

# load factory + ALR
import models as models_pkg
assert "daiic_resnet34" in models_pkg.REGISTRY
daiic_factory = models_pkg.REGISTRY["daiic_resnet34"]
spec = importlib.util.spec_from_file_location("local_alr", str(repo_root / "losses" / "alr.py"))
mod_alr = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod_alr)
ALRLoss = mod_alr.ALRLoss

# helpers
def instantiate_model():
    m = daiic_factory(num_classes=num_classes, pretrained=False, cifar_stem=True, in_channels=3).to(device)
    for _, mm in m.named_modules():
        if mm.__class__.__name__.lower().startswith("costattention") and hasattr(mm, "tau"):
            mm.tau = float(initial_tau)
    return m

def get_costatt_modules(model):
    mods = []
    for name,m in model.named_modules():
        if m.__class__.__name__.lower().startswith("costattention") or hasattr(m, "conv1x1"):
            mods.append((name,m))
    return mods

def make_optimizer_with_costatt_lr(model, base_lr=lr_base, head_lr=lr_head, costatt_lr=None):
    head_params, base_params, costatt_params = [], [], []
    for n,p in model.named_parameters():
        if not p.requires_grad:
            continue
        if ("classifier" in n) or ("fc" in n) or ("head" in n):
            head_params.append(p)
        else:
            base_params.append(p)
    # gather costatt parameters separately if costatt_lr provided
    if costatt_lr is not None:
        # find costatt modules
        for name,m in get_costatt_modules(model):
            conv = getattr(m, "conv1x1", None)
            if conv is not None:
                for p in conv.parameters():
                    # remove p from base/head if accidentally included
                    if p in base_params: base_params.remove(p)
                    if p in head_params: head_params.remove(p)
                    costatt_params.append(p)
    param_groups = []
    if base_params: param_groups.append({'params': base_params, 'lr': base_lr})
    if head_params: param_groups.append({'params': head_params, 'lr': head_lr})
    if costatt_params: param_groups.append({'params': costatt_params, 'lr': costatt_lr})
    if not param_groups:
        param_groups = [{'params': model.parameters(), 'lr': base_lr}]
    opt = torch.optim.SGD(param_groups, momentum=0.9, weight_decay=weight_decay)
    return opt

def normalize_outputs(outputs):
    logits=None; W=None
    if isinstance(outputs, dict):
        logits = outputs.get("logits", None)
        if logits is None and "probs" in outputs:
            probs = outputs["probs"].clamp(1e-6,1-1e-6)
            logits = torch.log(probs/(1-probs))
        W = outputs.get("W", None)
    elif torch.is_tensor(outputs):
        logits = outputs
    logits = logits.to(device).float()
    if W is not None:
        W = W.to(device).float()
        if not torch.isfinite(W).all(): W = torch.ones_like(logits)
        W = W.clamp(min=0.0)
        rs = W.sum(dim=1, keepdim=True)
        zero_rows = rs == 0
        if zero_rows.any():
            W[zero_rows.expand_as(W)] = 1.0
            rs = W.sum(dim=1, keepdim=True)
        W = W / (rs + 1e-12)
        if not torch.isfinite(W).all() or (W.abs().max() < 1e-12):
            W = torch.ones_like(logits)/float(logits.size(1))
    return logits, W, outputs

def W_stats(W):
    if W is None: return {}
    t1 = float(W.max(dim=1).values.mean())
    ent = float((-(W*(W.clamp(1e-12).log())).sum(dim=1)).mean())
    return {'W_mean': float(W.mean()), 'W_top1_mean': t1, 'W_entropy_mean': ent}

# instantiate and initial optimizer: during warmup we don't want costatt to learn -> costatt_lr=costatt_low_lr
model = instantiate_model()
opt = make_optimizer_with_costatt_lr(model, base_lr=lr_base, head_lr=lr_head, costatt_lr=costatt_low_lr)

ce_loss_fn = nn.CrossEntropyLoss()
alr_loss_fn = ALRLoss(class_weights=None, reduction='mean')

# training loops
def train_epoch(model, opt, loader, mode='ce', force_uniform_W=False, lambda_ent=0.0, alr_scale=1.0):
    model.train()
    total_loss=0.0; total_correct=0; total_n=0
    for i,(x,y) in enumerate(loader,1):
        x = x.to(device); y = y.to(device)
        opt.zero_grad()
        out = model(x)
        logits, W, _ = normalize_outputs(out)
        if mode=='ce':
            loss = ce_loss_fn(logits, y.view(-1))
        else:
            if force_uniform_W:
                W_use = torch.ones_like(logits, device=logits.device) / float(logits.size(1))
            else:
                W_use = W
            raw = alr_loss_fn(logits, y, W_use)
            if isinstance(raw, (tuple,list)): raw = raw[0]
            loss = raw * alr_scale
            if lambda_ent > 0.0:
                ent = (-(W_use * (W_use.clamp(1e-12).log())).sum(dim=1)).mean()
                loss = loss - lambda_ent * ent
        loss.backward()
        opt.step()
        preds = logits.argmax(dim=1)
        total_correct += int((preds == y.view(-1)).sum().item())
        total_loss += float(loss.item()) * x.size(0)
        total_n += x.size(0)
    return total_loss/max(1,total_n), total_correct/max(1,total_n)

def eval_model(model, loader, mode='ce', alr_scale=1.0, max_batches=50):
    model.eval()
    loss_sum=0.0; correct=0; n=0
    with torch.no_grad():
        for i,(x,y) in enumerate(loader):
            x = x.to(device); y = y.to(device)
            out = model(x)
            logits, W, _ = normalize_outputs(out)
            if mode=='ce':
                l = float(ce_loss_fn(logits, y.view(-1)).item())
            else:
                raw = alr_loss_fn(logits, y, W)
                if isinstance(raw,(tuple,list)): raw = raw[0]
                l = float(raw.item())*alr_scale
            loss_sum += l * x.size(0)
            preds = logits.argmax(dim=1)
            correct += int((preds == y.view(-1)).sum().item())
            n += x.size(0)
            if i>=max_batches: break
    return loss_sum/max(1,n), correct/max(1,n)

print("\n=== CE warmup (cost_att very low LR) ===")
for e in range(warmup_epochs):
    t0=time.time()
    train_loss, train_acc = train_epoch(model, opt, train_loader, mode='ce')
    val_loss, val_acc = eval_model(model, val_loader, mode='ce')
    # sample W stats
    with torch.no_grad():
        _, W_val, _ = normalize_outputs(model(next(iter(val_loader))[0].to(device)))
    print(f"CE epoch {e+1}/{warmup_epochs} | train_loss {train_loss:.4f} train_acc {train_acc:.4f} | val_loss {val_loss:.4f} val_acc {val_acc:.4f} | time {(time.time()-t0):.1f}s")
    print("  W stats:", W_stats(W_val))

# Now prepare ALR: we'll increase costatt lr slightly (still small) and optionally force uniform first few ALR epochs
print("\nReconfiguring optimizer for ALR phase: costatt lr will be small but nonzero.")
# set costatt lr to small positive value to allow slow adaptation
opt = make_optimizer_with_costatt_lr(model, base_lr=lr_base, head_lr=lr_head, costatt_lr=1e-5)

def tau_for_epoch(idx):
    frac = min(1.0, float(idx) / float(max(1,tau_anneal_epochs-1)))
    return initial_tau + (final_tau - initial_tau) * frac

print("\n=== ALR phase (conservative) ===")
for ae in range(alr_epochs):
    # set tau for costatt modules
    tau_val = tau_for_epoch(ae)
    for _, mm in model.named_modules():
        if mm.__class__.__name__.lower().startswith("costattention") and hasattr(mm, "tau"):
            mm.tau = float(tau_val)
    force_uniform = (ae < force_uniform_alr_epochs)
    t0=time.time()
    train_loss, train_acc = train_epoch(model, opt, train_loader, mode='alr',
                                        force_uniform_W=force_uniform, lambda_ent=lambda_ent, alr_scale=alr_scale)
    val_loss, val_acc = eval_model(model, val_loader, mode='alr', alr_scale=alr_scale)
    # sample W stats
    with torch.no_grad():
        _, W_val, _ = normalize_outputs(model(next(iter(val_loader))[0].to(device)))
    stats = W_stats(W_val)
    print(f"ALR epoch {ae+1}/{alr_epochs} | tau {tau_val:.3f} | force_uniform {force_uniform} | train_loss {train_loss:.4f} train_acc {train_acc:.4f} | val_loss {val_loss:.4f} val_acc {val_acc:.4f} | time {(time.time()-t0):.1f}s")
    print("  W stats:", stats)

    # Auto-safety: if W becomes too peaky early, re-apply conservative measures
    if stats.get('W_top1_mean', 0.0) > w_top1_alert_thresh and ae < (force_uniform_alr_epochs + 2):
        print("  WARNING: W collapsed early (W_top1_mean>%.2f). Re-applying conservative measures: increasing tau and forcing uniform W next epoch." % w_top1_alert_thresh)
        # increase tau to soften more
        for _, mm in model.named_modules():
            if mm.__class__.__name__.lower().startswith("costattention") and hasattr(mm, "tau"):
                mm.tau = float(max(mm.tau * 1.5, initial_tau))
        # enforce that the next ALR epoch uses uniform W by bumping counter
        force_uniform_alr_epochs = max(force_uniform_alr_epochs, ae + 2)
        # optionally reduce costatt lr further for stability
        opt = make_optimizer_with_costatt_lr(model, base_lr=lr_base, head_lr=lr_head, costatt_lr=1e-6)

print("\nConservative ALR training complete. Inspect W stats and accuracies; if W still collapses, increase force_uniform_alr_epochs or lambda_ent and re-run.")


Device: cuda


RuntimeError: The size of tensor a (3) must match the size of tensor b (512) at non-singleton dimension 1

In [9]:
from losses.alr import ALRLoss

# ALR loss (module version with diagnostics)
alr_loss_fn = ALRLoss(class_weights=None, reduction="mean")


In [10]:
# Continue conservative ALR training for a few epochs (run after optimizer creation)
import time, torch

# Settings for continuation
continue_epochs = 5
force_uniform_epochs = 2         # force uniform W for first N continuation epochs
lambda_ent = 0.2                 # entropy regularizer
alr_scale = float(100)           # adjust if you changed num_classes
alert_thresh = 0.5               # if W_top1_mean > alert_thresh -> take action
max_val_batches = 50

# helper (reuse normalize_outputs and W_stats from earlier cell; redefine if not present)
def normalize_outputs(outputs):
    logits=None; W=None
    if isinstance(outputs, dict):
        logits = outputs.get("logits", None)
        if logits is None and "probs" in outputs:
            probs = outputs["probs"].clamp(1e-6,1-1e-6)
            logits = torch.log(probs/(1-probs))
        W = outputs.get("W", None)
    elif torch.is_tensor(outputs):
        logits = outputs
    logits = logits.to(device).float()
    if W is not None:
        W = W.to(device).float()
        if not torch.isfinite(W).all(): W = torch.ones_like(logits)
        W = W.clamp(min=0.0)
        rs = W.sum(dim=1, keepdim=True)
        zero_rows = rs == 0
        if zero_rows.any():
            W[zero_rows.expand_as(W)] = 1.0
            rs = W.sum(dim=1, keepdim=True)
        W = W / (rs + 1e-12)
        if not torch.isfinite(W).all() or (W.abs().max() < 1e-12):
            W = torch.ones_like(logits)/float(logits.size(1))
    return logits, W, outputs

def W_stats(W):
    if W is None: return {}
    t1 = float(W.max(dim=1).values.mean())
    ent = float((-(W*(W.clamp(1e-12).log())).sum(dim=1)).mean())
    return {'W_mean': float(W.mean()), 'W_top1_mean': t1, 'W_entropy_mean': ent}

# training and eval loops (small, reuse your CE/ALR functions if present)
ce_loss_fn = torch.nn.CrossEntropyLoss()
# assume alr_loss_fn exists (ALRLoss instance). If not, re-create: alr_loss_fn = ALRLoss(class_weights=None, reduction='mean')

def train_one_epoch_alr(model, opt, loader, force_uniform=False, lambda_ent=0.0, alr_scale=1.0):
    model.train()
    running_loss = 0.0
    running_correct = 0
    seen = 0
    for i,(x,y) in enumerate(loader, 1):
        x = x.to(device); y = y.to(device)
        opt.zero_grad()
        outputs = model(x)
        logits, W, _ = normalize_outputs(outputs)
        if force_uniform:
            W_use = torch.ones_like(logits, device=logits.device) / float(logits.size(1))
        else:
            W_use = W
        raw = alr_loss_fn(logits, y, W_use)
        if isinstance(raw, (tuple, list)):
            raw = raw[0]
        loss = raw * alr_scale
        if lambda_ent > 0.0:
            ent = (-(W_use * (W_use.clamp(1e-12).log())).sum(dim=1)).mean()
            loss = loss - lambda_ent * ent
        loss.backward()
        opt.step()
        preds = logits.argmax(dim=1)
        running_correct += int((preds == y.view(-1)).sum().item())
        running_loss += float(loss.item()) * x.size(0)
        seen += x.size(0)
    return running_loss / max(1, seen), running_correct / max(1, seen)

def eval_alr(model, loader, alr_scale=1.0, max_batches=50):
    model.eval()
    loss_sum=0.0; correct=0; n=0
    with torch.no_grad():
        for i,(x,y) in enumerate(loader):
            x = x.to(device); y = y.to(device)
            outputs = model(x)
            logits, W, _ = normalize_outputs(outputs)
            raw = alr_loss_fn(logits, y, W)
            if isinstance(raw,(tuple,list)): raw = raw[0]
            l = float(raw.item()) * alr_scale
            loss_sum += l * x.size(0)
            preds = logits.argmax(dim=1)
            correct += int((preds == y.view(-1)).sum().item())
            n += x.size(0)
            if i >= max_batches: break
    return loss_sum / max(1,n), correct / max(1,n)

# Run continuation
for epoch in range(1, continue_epochs+1):
    print(f"\n=== Continue ALR epoch {epoch}/{continue_epochs} ===")
    force_uniform = (epoch <= force_uniform_epochs)
    t0=time.time()
    tr_loss, tr_acc = train_one_epoch_alr(model, opt, train_loader, force_uniform=force_uniform, lambda_ent=lambda_ent, alr_scale=alr_scale)
    val_loss, val_acc = eval_alr(model, val_loader, alr_scale=alr_scale)
    # sample W stats on a val batch
    with torch.no_grad():
        _, W_sample, _ = normalize_outputs(model(next(iter(val_loader))[0].to(device)))
    stats = W_stats(W_sample)
    print(f"train_loss {tr_loss:.4f} train_acc {tr_acc:.4f} | val_loss {val_loss:.4f} val_acc {val_acc:.4f} | time {(time.time()-t0):.1f}s")
    print(" W stats:", stats)
    # safety: if W collapsed, strengthen defense
    if stats.get('W_top1_mean', 0.0) > alert_thresh:
        print("  ALERT: W_top1_mean > %.2f â€” increasing cost-att LR suppression and forcing uniform W next epoch." % alert_thresh)
        # set optimizer cost-att lr to very small by rebuilding opt
        try:
            opt = make_optimizer_with_costatt_lr(model, base_lr=lr_base, head_lr=lr_head, costatt_lr=1e-6, weight_decay=weight_decay)
            force_uniform_epochs = max(force_uniform_epochs, epoch+1)
        except Exception as e:
            print("  Failed to rebuild optimizer for safety:", e)



=== Continue ALR epoch 1/5 ===
train_loss 73.6660 train_acc 0.0102 | val_loss 74.2384 val_acc 0.0101 | time 22.8s
 W stats: {'W_mean': 0.009999999776482582, 'W_top1_mean': 0.01275566965341568, 'W_entropy_mean': 4.598027229309082}

=== Continue ALR epoch 2/5 ===
train_loss 73.6689 train_acc 0.0101 | val_loss 74.1307 val_acc 0.0099 | time 23.3s
 W stats: {'W_mean': 0.009999999776482582, 'W_top1_mean': 0.012720715254545212, 'W_entropy_mean': 4.5982513427734375}

=== Continue ALR epoch 3/5 ===
train_loss 73.7928 train_acc 0.0099 | val_loss 74.1156 val_acc 0.0099 | time 24.9s
 W stats: {'W_mean': 0.009999999776482582, 'W_top1_mean': 0.012717755511403084, 'W_entropy_mean': 4.598291873931885}

=== Continue ALR epoch 4/5 ===
train_loss 73.7745 train_acc 0.0098 | val_loss 74.1295 val_acc 0.0102 | time 24.7s
 W stats: {'W_mean': 0.009999999776482582, 'W_top1_mean': 0.012717355042696, 'W_entropy_mean': 4.5982537269592285}

=== Continue ALR epoch 5/5 ===
train_loss 73.7902 train_acc 0.0104 | val_

In [11]:
# Run in notebook to inspect raw ALR scale for one batch
import torch, torch.nn.functional as F
model.eval()
x,y = next(iter(train_loader))
x = x.to(device); y = y.to(device)
with torch.no_grad():
    out = model(x)
logits = out['logits']
W = out.get('W', torch.ones_like(logits)/logits.size(1))
probs = torch.sigmoid(logits).clamp(1e-8, 1-1e-8)

# BCE per entry and raw ALR
B,K = logits.shape
bce_per_entry = -( (y.view(-1,1)==torch.arange(K, device=device)).float() * torch.log(probs) +
                   (1 - (y.view(-1,1)==torch.arange(K, device=device)).float()) * torch.log(1-probs) )
# if y is one-hot multi-hot adapt above; this approximates single-label BCE matrix
raw_alr = (W * bce_per_entry).sum(dim=1)   # [B]
print("raw ALR mean/std:", float(raw_alr.mean()), float(raw_alr.std()))
print("bce_per_entry mean/std:", float(bce_per_entry.mean()), float(bce_per_entry.std()))
print("W top1 mean:", float(W.max(dim=1).values.mean()))


raw ALR mean/std: 0.7393078804016113 0.031193485483527184
bce_per_entry mean/std: 0.7381960153579712 0.2565256357192993
W top1 mean: 0.012673519551753998


In [15]:
opt = make_optimizer_with_costatt_lr(model, base_lr=0.001, head_lr=0.01, costatt_lr=1e-6)


# Combined CE + small ALR continuation (safe)
alpha_alr = 0.1          # weight for ALR term (try 0.05-0.2)
alr_scale = 1.0          # do NOT multiply by 100; keep 1.0 or small
lambda_ent = 0.1         # keep small entropy reg
epochs = 5

# If your optimizer still uses head_lr=0.1, rebuild with smaller head LR:
opt = make_optimizer_with_costatt_lr(model, base_lr=0.001, head_lr=0.01, costatt_lr=1e-6, weight_decay=1e-4)
print("Recreated safer optimizer groups:", [(len(g['params']), g['lr']) for g in opt.param_groups])

ce_loss_fn = torch.nn.CrossEntropyLoss()

for ep in range(epochs):
    model.train()
    tloss = 0.0; tacc = 0; n=0
    for x,y in train_loader:
        x = x.to(device); y = y.to(device)
        opt.zero_grad()
        out = model(x)
        logits, W, _ = normalize_outputs(out)
        # CE
        loss_ce = ce_loss_fn(logits, y.view(-1))
        # ALR (raw)
        raw = alr_loss_fn(logits, y, W)
        if isinstance(raw, (tuple,list)): raw = raw[0]
        loss_alr = raw * alr_scale
        # entropy reg
        ent = (-(W * (W.clamp(1e-12).log())).sum(dim=1)).mean()
        loss = (1.0 - alpha_alr) * loss_ce + alpha_alr * loss_alr - 0.0 * lambda_ent * ent
        loss.backward()
        opt.step()

        preds = logits.argmax(dim=1)
        tacc += int((preds == y.view(-1)).sum().item())
        tloss += float(loss.item()) * x.size(0)
        n += x.size(0)
    print(f"Epoch {ep+1}/{epochs} | train_loss {tloss/n:.4f} train_acc {tacc/n:.4f}")
    # quick eval
    model.eval()
    with torch.no_grad():
        out = model(next(iter(val_loader))[0].to(device))
        _, Wv, _ = normalize_outputs(out)
    print("  sample W stats:", W_stats(Wv))


RuntimeError: The size of tensor a (3) must match the size of tensor b (512) at non-singleton dimension 1

In [16]:
def make_optimizer_with_costatt_lr(model, base_lr=0.01, head_lr=0.1, costatt_lr=1e-6, weight_decay=1e-4):
    base_params, head_params, conv_params = [], [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if "classifier" in name:
            head_params.append(p)
        else:
            base_params.append(p)
        if "cost_att.conv1x1" in name:  # or whatever your module name is
            conv_params.append(p)
    base_params = [p for p in base_params if p not in conv_params]
    head_params = [p for p in head_params if p not in conv_params]
    costatt_params = conv_params
    return torch.optim.SGD(
        [
            {"params": base_params, "lr": base_lr, "weight_decay": weight_decay},
            {"params": head_params, "lr": head_lr, "weight_decay": weight_decay},
            {"params": costatt_params, "lr": costatt_lr, "weight_decay": 0.0},  # cost-att safe
        ],
        momentum=0.9,
        nesterov=True,
    )


In [17]:
# ALR-only continuation (no big scaling)
alr_scale = 1.0
lambda_ent = 0.2
epochs = 5
opt = make_optimizer_with_costatt_lr(model, base_lr=0.001, head_lr=0.01, costatt_lr=1e-6, weight_decay=1e-4)

for ep in range(epochs):
    model.train()
    tloss = 0.0; tacc = 0; n=0
    for x,y in train_loader:
        x=x.to(device); y=y.to(device)
        opt.zero_grad()
        out = model(x)
        logits, W, _ = normalize_outputs(out)
        raw = alr_loss_fn(logits, y, W)
        if isinstance(raw,(list,tuple)): raw = raw[0]
        loss = raw * alr_scale - lambda_ent * (-(W*(W.clamp(1e-12).log())).sum(dim=1)).mean()
        loss.backward()
        opt.step()
        preds = logits.argmax(dim=1)
        tacc += int((preds == y.view(-1)).sum().item())
        tloss += float(loss.item()) * x.size(0)
        n += x.size(0)
    print(f"ALR epoch {ep+1}/{epochs} | train_loss {tloss/n:.4f} train_acc {tacc/n:.4f}")
    # check W stats
    with torch.no_grad():
        _, Wv, _ = normalize_outputs(model(next(iter(val_loader))[0].to(device)))
    print("  sample W stats:", W_stats(Wv))


RuntimeError: The size of tensor a (512) must match the size of tensor b (3) at non-singleton dimension 1

In [18]:
# ===== Fix helper + run CE+small-ALR continuation =====
import time, torch, torch.nn as nn

# Identity-safe optimizer creator that accepts weight_decay
def make_optimizer_with_costatt_lr(model, base_lr=1e-3, head_lr=1e-2, costatt_lr=None, weight_decay=1e-4):
    base_params = []
    head_params = []
    # collect base/head by name heuristics
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if ("classifier" in n) or ("fc" in n) or ("head" in n):
            head_params.append(p)
        else:
            base_params.append(p)

    costatt_params = []
    if costatt_lr is not None:
        conv_params = []
        for name, m in model.named_modules():
            if m.__class__.__name__.lower().startswith("costattention") or hasattr(m, "conv1x1"):
                conv = getattr(m, "conv1x1", None)
                if conv is not None:
                    for p in conv.parameters():
                        conv_params.append(p)
        conv_ids = {id(p) for p in conv_params}
        # filter out conv params by identity (id) to avoid tensor-equality comparisons
        base_params = [p for p in base_params if id(p) not in conv_ids]
        head_params = [p for p in head_params if id(p) not in conv_ids]
        costatt_params = conv_params

    param_groups = []
    if base_params:
        param_groups.append({"params": base_params, "lr": base_lr, "weight_decay": weight_decay})
    if head_params:
        param_groups.append({"params": head_params, "lr": head_lr, "weight_decay": weight_decay})
    if costatt_params:
        # usually no weight decay for small convs; set to 0.0 or keep weight_decay
        param_groups.append({"params": costatt_params, "lr": costatt_lr, "weight_decay": 0.0})

    if not param_groups:
        param_groups = [{"params": model.parameters(), "lr": base_lr, "weight_decay": weight_decay}]

    return torch.optim.SGD(param_groups, momentum=0.9)

# Recreate optimizer with safer LRs for CE+ALR continuation
# (adjust base/head lr if you prefer)
opt = make_optimizer_with_costatt_lr(model, base_lr=1e-3, head_lr=1e-2, costatt_lr=1e-6, weight_decay=1e-4)
print("Optimizer groups:", [(len(g["params"]), g["lr"], g.get("weight_decay")) for g in opt.param_groups])

# === CE + small ALR continuation ===
alpha_alr = 0.1        # weight for ALR term (0 => pure CE; 0.1 => 10% ALR)
alr_scale = 1.0        # do NOT scale by num_classes here
lambda_ent = 0.05      # small entropy regularizer
epochs = 5

ce_loss_fn = nn.CrossEntropyLoss()
# Ensure alr_loss_fn exists; recreate if needed:
try:
    alr_loss_fn
except NameError:
    from losses.alr import ALRLoss
    alr_loss_fn = ALRLoss(class_weights=None, reduction="mean")

# quick helper (reuse normalize_outputs if in scope; else define small one)
def normalize_outputs(outputs):
    logits=None; W=None
    if isinstance(outputs, dict):
        logits = outputs.get("logits", None)
        if logits is None and "probs" in outputs:
            probs = outputs["probs"].clamp(1e-6,1-1e-6)
            logits = torch.log(probs/(1-probs))
        W = outputs.get("W", None)
    elif torch.is_tensor(outputs):
        logits = outputs
    logits = logits.to(device).float()
    if W is not None:
        W = W.to(device).float()
        if not torch.isfinite(W).all(): W = torch.ones_like(logits)
        W = W.clamp(min=0.0)
        rs = W.sum(dim=1, keepdim=True)
        zero_rows = rs == 0
        if zero_rows.any():
            W[zero_rows.expand_as(W)] = 1.0
            rs = W.sum(dim=1, keepdim=True)
        W = W / (rs + 1e-12)
        if not torch.isfinite(W).all() or (W.abs().max() < 1e-12):
            W = torch.ones_like(logits)/float(logits.size(1))
    return logits, W, outputs

def W_stats(W):
    if W is None: return {}
    t1 = float(W.max(dim=1).values.mean())
    ent = float((-(W*(W.clamp(1e-12).log())).sum(dim=1)).mean())
    return {'W_mean': float(W.mean()), 'W_top1_mean': t1, 'W_entropy_mean': ent}

# Training loop
for ep in range(1, epochs+1):
    model.train()
    total_loss=0.0; total_correct=0; total_n=0
    for xb,yb in train_loader:
        xb = xb.to(device); yb = yb.to(device)
        opt.zero_grad()
        out = model(xb)
        logits, W, _ = normalize_outputs(out)
        loss_ce = ce_loss_fn(logits, yb.view(-1))
        raw = alr_loss_fn(logits, yb, W)
        if isinstance(raw, (tuple,list)): raw = raw[0]
        loss_alr = raw * alr_scale
        ent = (-(W * (W.clamp(1e-12).log())).sum(dim=1)).mean()
        loss = (1.0 - alpha_alr) * loss_ce + alpha_alr * loss_alr - lambda_ent * ent
        loss.backward()
        opt.step()
        preds = logits.argmax(dim=1)
        total_correct += int((preds == yb.view(-1)).sum().item())
        total_loss += float(loss.item()) * xb.size(0)
        total_n += xb.size(0)
    # eval quick
    model.eval()
    with torch.no_grad():
        outv = model(next(iter(val_loader))[0].to(device))
        lv, Wv, _ = normalize_outputs(outv)
        val_ce = float(ce_loss_fn(lv, next(iter(val_loader))[1].to(device).view(-1)))
    print(f"Epoch {ep}/{epochs} | train_loss {total_loss/total_n:.4f} train_acc {total_correct/total_n:.4f} | sample_val_ce {val_ce:.4f}")
    print("  W stats:", W_stats(Wv))

print("Done continuation. If acc improves slowly, gradually increase alpha_alr or decrease costatt_lr.")


Optimizer groups: [(108, 0.001, 0.0001), (6, 0.01, 0.0001), (2, 1e-06, 0.0)]
Epoch 1/5 | train_loss 3.6987 train_acc 0.0449 | sample_val_ce 4.0024
  W stats: {'W_mean': 0.009999999776482582, 'W_top1_mean': 0.012885664589703083, 'W_entropy_mean': 4.597731590270996}
Epoch 2/5 | train_loss 3.2088 train_acc 0.1110 | sample_val_ce 3.6298
  W stats: {'W_mean': 0.009999999776482582, 'W_top1_mean': 0.012905037961900234, 'W_entropy_mean': 4.597824573516846}
Epoch 3/5 | train_loss 2.9220 train_acc 0.1626 | sample_val_ce 3.2485
  W stats: {'W_mean': 0.009999999776482582, 'W_top1_mean': 0.012911481782793999, 'W_entropy_mean': 4.598141670227051}
Epoch 4/5 | train_loss 2.6861 train_acc 0.2113 | sample_val_ce 3.2091
  W stats: {'W_mean': 0.009999999776482582, 'W_top1_mean': 0.013038859702646732, 'W_entropy_mean': 4.59785795211792}
Epoch 5/5 | train_loss 2.4785 train_acc 0.2518 | sample_val_ce 2.9636
  W stats: {'W_mean': 0.009999999776482582, 'W_top1_mean': 0.01300949975848198, 'W_entropy_mean': 4.59

In [20]:
# Continue training with cautious schedule + checkpointing
import time, copy, torch
from torch.optim.lr_scheduler import CosineAnnealingLR

# Hyperparams for continuation
extra_epochs = 20
alpha_start = 0.10       # initial CE/ALR blend weight (ALR fraction)
alpha_end = 0.20         # target ALR weight after ramp
alpha_ramp_epochs = 10   # ramp ALR weight over this many epochs
alr_scale = 1.0
lambda_ent = 0.05        # small entropy reg
enable_costatt_epoch = 10   # at this epoch (1-indexed in continuation) increase costatt_lr
new_costatt_lr = 1e-5      # new small lr for cost-att once we allow it slightly adapt
# scheduler params
T_max = extra_epochs
# checkpoint
best_val_ce = float('inf')
best_state = None

# Recreate optimizer with the current conservative costatt_lr (should be 1e-6 now)
opt = make_optimizer_with_costatt_lr(model, base_lr=1e-3, head_lr=1e-2, costatt_lr=1e-6, weight_decay=1e-4)
# identify param group indices for scheduler (we'll attach scheduler to groups 0..n-1 but PyTorch handles all groups)
scheduler = CosineAnnealingLR(opt, T_max=T_max, eta_min=1e-6)

def compute_val_ce_sample(model):
    model.eval()
    with torch.no_grad():
        xb,yb = next(iter(val_loader))
        xb = xb.to(device); yb = yb.to(device)
        out = model(xb)
        logits, W, _ = normalize_outputs(out)
        ce = torch.nn.functional.cross_entropy(logits, yb.view(-1)).item()
    return ce, W

for e in range(1, extra_epochs+1):
    # linear ramp for alpha
    if e <= alpha_ramp_epochs:
        alpha_alr = alpha_start + (alpha_end - alpha_start) * ( (e-1) / max(1,(alpha_ramp_epochs-1)) )
    else:
        alpha_alr = alpha_end

    # optionally increase costatt_lr at scheduled epoch by rebuilding optimizer
    if e == enable_costatt_epoch:
        print(">>> Increasing cost-att lr to", new_costatt_lr, " (rebuilding optimizer )")
        opt = make_optimizer_with_costatt_lr(model, base_lr=1e-3, head_lr=1e-2, costatt_lr=new_costatt_lr, weight_decay=1e-4)
        scheduler = CosineAnnealingLR(opt, T_max=max(1, extra_epochs-e+1), eta_min=1e-6)

    # one epoch training (combined CE + ALR)
    model.train()
    total_loss = 0.0; total_correct = 0; total_n = 0
    t0 = time.time()
    for xb, yb in train_loader:
        xb = xb.to(device); yb = yb.to(device)
        opt.zero_grad()
        out = model(xb)
        logits, W, _ = normalize_outputs(out)
        loss_ce = torch.nn.functional.cross_entropy(logits, yb.view(-1))
        raw = alr_loss_fn(logits, yb, W)
        if isinstance(raw, (tuple,list)): raw = raw[0]
        loss_alr = raw * alr_scale
        ent = (-(W * (W.clamp(1e-12).log())).sum(dim=1)).mean()
        loss = (1.0 - alpha_alr) * loss_ce + alpha_alr * loss_alr - lambda_ent * ent
        loss.backward()
        opt.step()
        total_loss += float(loss.item()) * xb.size(0)
        preds = logits.argmax(dim=1)
        total_correct += int((preds == yb.view(-1)).sum().item())
        total_n += xb.size(0)

    # scheduler step (per-epoch)
    try:
        scheduler.step()
    except Exception:
        pass

    # eval CE on small sample as proxy and compute W stats
    val_ce_sample, W_sample = compute_val_ce_sample(model)
    stats = W_stats(W_sample)
    epoch_time = time.time() - t0
    train_loss = total_loss / max(1,total_n)
    train_acc  = total_correct / max(1,total_n)
    print(f"[cont ep {e}/{extra_epochs}] alpha_alr={alpha_alr:.3f} | train_loss {train_loss:.4f} train_acc {train_acc:.4f} | sample_val_ce {val_ce_sample:.4f} | time {epoch_time:.1f}s")
    print("   W stats:", stats)

    # checkpoint best sample val CE
    if val_ce_sample < best_val_ce:
        best_val_ce = val_ce_sample
        best_state = copy.deepcopy(model.state_dict())
        print("   -> New best sample_val_ce:", best_val_ce)

    # safety: if W collapses too quickly, roll back cost-att aggressiveness
    if stats.get('W_top1_mean', 0.0) > 0.5 and e < enable_costatt_epoch + 3:
        print("   !!! W collapsed early (top1>0.5). Reverting cost-att lr to tiny and forcing uniform for next epoch")
        opt = make_optimizer_with_costatt_lr(model, base_lr=1e-3, head_lr=1e-2, costatt_lr=1e-6, weight_decay=1e-4)
        # optionally reduce alpha_alr for next epoch (not implemented here; you can restart loop)

# after loop: restore best if you want
if best_state is not None:
    model.load_state_dict(best_state)
    print("Restored best sample_val_ce checkpoint state.")

print("Continuation finished. Final W stats:", W_stats(W_sample))


[cont ep 1/20] alpha_alr=0.100 | train_loss 1.7586 train_acc 0.4076 | sample_val_ce 2.5989 | time 21.6s
   W stats: {'W_mean': 0.009999999776482582, 'W_top1_mean': 0.01319112628698349, 'W_entropy_mean': 4.597421169281006}
   -> New best sample_val_ce: 2.598945379257202
[cont ep 2/20] alpha_alr=0.111 | train_loss 1.5431 train_acc 0.4548 | sample_val_ce 2.5885 | time 22.0s
   W stats: {'W_mean': 0.009999999776482582, 'W_top1_mean': 0.013182642869651318, 'W_entropy_mean': 4.597250938415527}
   -> New best sample_val_ce: 2.588477849960327
[cont ep 3/20] alpha_alr=0.122 | train_loss 1.3150 train_acc 0.5093 | sample_val_ce 2.6740 | time 21.9s
   W stats: {'W_mean': 0.009999999776482582, 'W_top1_mean': 0.013111312873661518, 'W_entropy_mean': 4.597574710845947}
[cont ep 4/20] alpha_alr=0.133 | train_loss 1.0714 train_acc 0.5700 | sample_val_ce 2.8547 | time 22.4s
   W stats: {'W_mean': 0.009999999776482582, 'W_top1_mean': 0.01308935135602951, 'W_entropy_mean': 4.597599983215332}
[cont ep 5/20]

In [21]:
import torch, time
import torch.nn.functional as F

def evaluate_ce(model, val_loader, device="cuda"):
    """Evaluate with plain CE on full val set."""
    model.eval()
    total_loss, total_correct, total_samples = 0.0, 0, 0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            outputs = model(x)
            logits = outputs["logits"] if isinstance(outputs, dict) else outputs
            loss = F.cross_entropy(logits, y)
            total_loss += loss.item() * x.size(0)
            preds = logits.argmax(1)
            total_correct += (preds == y).sum().item()
            total_samples += x.size(0)
    return total_loss / total_samples, total_correct / total_samples

# === Continuation training with full val eval ===
extra_epochs = 5
alpha_start, alpha_end, alpha_ramp_epochs = 0.1, 0.2, 5
lambda_ent = 0.1
alr_scale = 1.0

for ep in range(1, extra_epochs + 1):
    # Ramp ALR weight
    alpha_alr = min(alpha_end, alpha_start + (alpha_end - alpha_start) * (ep / alpha_ramp_epochs))

    model.train()
    total_loss, total_correct, total_samples = 0.0, 0, 0
    t0 = time.time()

    for step, (x, y) in enumerate(train_loader, 1):
        x, y = x.to(device), y.to(device)
        opt.zero_grad()

        outputs = model(x)
        logits = outputs["logits"]
        W = outputs.get("W", None)

        ce = F.cross_entropy(logits, y)

        if W is not None:
            alr, _ = alr_loss_fn(logits, y, W)
            loss = (1 - alpha_alr) * ce + alpha_alr * alr
        else:
            loss = ce

        loss.backward()
        opt.step()

        total_loss += loss.item() * x.size(0)
        preds = logits.argmax(1)
        total_correct += (preds == y).sum().item()
        total_samples += x.size(0)

        if step % 200 == 0:
            print(f"  step {step} | avg_loss {total_loss/total_samples:.4f} | avg_acc {total_correct/total_samples:.4f}")

    tr_loss, tr_acc = total_loss / total_samples, total_correct / total_samples
    val_loss, val_acc = evaluate_ce(model, val_loader, device=device)

    W_stats_dict = W_stats(outputs["W"]) if "W" in outputs else {}

    print(f"Epoch {ep}/{extra_epochs} | "
          f"train_loss {tr_loss:.4f} train_acc {tr_acc:.4f} | "
          f"val_loss {val_loss:.4f} val_acc {val_acc:.4f}")
    print("  W stats:", {k: float(v) for k, v in W_stats_dict.items()})


  step 200 | avg_loss 1.5239 | avg_acc 0.5164
Epoch 1/5 | train_loss 1.5230 train_acc 0.5170 | val_loss 2.6366 val_acc 0.3390
  W stats: {'W_mean': 0.010000000707805157, 'W_top1_mean': 0.013088570907711983, 'W_entropy_mean': 4.597934246063232}
  step 200 | avg_loss 1.4668 | avg_acc 0.5266
Epoch 2/5 | train_loss 1.4592 train_acc 0.5257 | val_loss 2.6064 val_acc 0.3417
  W stats: {'W_mean': 0.010000000707805157, 'W_top1_mean': 0.013099893927574158, 'W_entropy_mean': 4.5977067947387695}
  step 200 | avg_loss 1.3991 | avg_acc 0.5343
Epoch 3/5 | train_loss 1.4024 train_acc 0.5342 | val_loss 2.5936 val_acc 0.3462
  W stats: {'W_mean': 0.010000000707805157, 'W_top1_mean': 0.013130143284797668, 'W_entropy_mean': 4.597745895385742}
  step 200 | avg_loss 1.3486 | avg_acc 0.5425
Epoch 4/5 | train_loss 1.3469 train_acc 0.5431 | val_loss 2.5815 val_acc 0.3487
  W stats: {'W_mean': 0.010000000707805157, 'W_top1_mean': 0.013044561259448528, 'W_entropy_mean': 4.597859859466553}
  step 200 | avg_loss 1

KeyboardInterrupt: 