In [None]:
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingWarmRestarts
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import v2

# The project package lives in ../src; it has to be on sys.path before it is imported.
sys.path.append('../src')
from modules import (
                    paths,
                    dataset,
                    model,
                    utils,
                    acdc,
                    train
                    )

In [None]:
# Channel statistics used to normalise both splits (the widely used ImageNet values).
NORM_MEAN = (0.485, 0.456, 0.406)
NORM_STD = (0.229, 0.224, 0.225)


def _to_rgb(img):
    """Convert an image to 3-channel RGB (some images are in grayscale)."""
    return img.convert('RGB')


# Training pipeline: decode to a float tensor, augment, normalise, then randomly erase a patch.
transform_train = v2.Compose(
    [
        v2.Lambda(_to_rgb),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.RandomHorizontalFlip(),
        v2.RandAugment(),
        v2.Normalize(mean=list(NORM_MEAN), std=list(NORM_STD)),
        v2.RandomErasing(p=0.25),
    ]
)

# Validation pipeline: same decoding and normalisation, no augmentation.
transform_valid = v2.Compose(
    [
        v2.Lambda(_to_rgb),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(NORM_MEAN, NORM_STD),
    ]
)

In [None]:
# Switch for quick smoke runs: when True the notebook loads the tiny datasets and uses a batch size of 5.
toy = True

In [None]:
import importlib

# Re-import the dataset module so edits to it are picked up without restarting the kernel.
importlib.reload(dataset)

# `toy` is a boolean flag; test its truthiness directly, as the batch-size selection does.
if toy:
    print("laoding toy datasets")
    train_dataset, coarse_labels = dataset.load_animal_dataset("train", transform=transform_train, tiny=True, stop=6)
    val_dataset, coarse_labels = dataset.load_animal_dataset("valid", transform=transform_valid, tiny=True, stop=2)
else:
    print("loading full dataet")
    train_dataset, coarse_labels = dataset.load_animal_dataset("train", transform=transform_train)
    val_dataset, coarse_labels = dataset.load_animal_dataset("valid", transform=transform_valid)

# `coarse_labels` ends up holding the value returned by the "valid" call (the second assignment wins).
# NOTE(review): the same transform is passed to both load_animal_dataset and TorchDatasetWrapper —
# confirm it is not applied twice.
train_dataset = dataset.TorchDatasetWrapper(train_dataset, transform=transform_train)
val_dataset = dataset.TorchDatasetWrapper(val_dataset, transform=transform_valid)
print("train:\n"+str(train_dataset))
print("validation:\n"+str(val_dataset))


In [None]:
# Toy runs use a handful of samples per batch; full runs use large batches.
batch_size = 5 if toy else 4096

# Worker / memory settings shared by the training and validation loaders.
loader_options = dict(
    batch_size=batch_size,
    num_workers=2,
    pin_memory=True,
    prefetch_factor=4,
    persistent_workers=True,
)

# Only the training split is shuffled.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

Train the model

In [None]:
# Re-import the model module so edits to it are picked up without restarting the kernel.
importlib.reload(model)

# ViT hyperparameters. Sizes that depend on the hidden width are derived from it.
_hidden_size = 64    # Increased from 48 (better representation)
_num_heads = 8       # More heads (head_dim = 64/8 = 8)

config = dict(
    patch_size=8,                         # Kept small for fine-grained patches
    hidden_size=_hidden_size,
    num_hidden_layers=6,                  # Deeper for pruning flexibility
    num_attention_heads=_num_heads,
    intermediate_size=4 * _hidden_size,   # Standard FFN scaling
    hidden_dropout_prob=0.2,              # Mild dropout for regularization
    attention_probs_dropout_prob=0.2,
    initializer_range=0.02,
    image_size=64,
    num_classes=58,
    num_channels=3,
    qkv_bias=True,                        # Keep bias for now (can prune later)
)

# Same for the training module.
importlib.reload(train)

class SoftTargetCrossEntropy(nn.Module):
    """Cross-entropy against soft targets, as produced by Mixup/Cutmix.

    ``forward(x, target)`` takes logits ``x`` and a ``target`` of per-class
    weights with the same shape, reduces over dimension 1, and returns the
    mean of the per-sample losses.
    """

    def forward(self, x, target):
        # Log-probabilities of the model outputs along the class dimension.
        log_probs = F.log_softmax(x, dim=1)
        # Per-sample loss: negative target-weighted sum of log-probabilities.
        per_sample = -(target * log_probs).sum(dim=1)
        return per_sample.mean()

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
vit = model.ViT(config).to(device)

# --- Optimisation hyperparameters ---
num_epochs = 500
warmup_epochs = 20
base_lr = 3e-4
min_lr = 1e-6
weight_decay = 0.05  # For AdamW optimizer
# NOTE(review): `label_smoothing` and `patience` are defined but not passed to anything in this cell
# (the losses below take no smoothing argument and early_stop_patience is given as 20) — confirm
# which values are intended.
label_smoothing = 0.1  # For cross-entropy
patience = 50



optimizer = AdamW(vit.parameters(),
                  lr=base_lr,
                  weight_decay = weight_decay,
                  betas=(0.9, 0.98),
                  eps = 1e-6
                  )

# Linear warmup over the first `warmup_epochs` scheduler steps (≈0 → base_lr)
warmup = LinearLR(
    optimizer,
    start_factor=1e-6,  # Near-zero initial LR
    end_factor=1.0,     # Full LR after warmup
    total_iters=warmup_epochs,
)

# Cosine decay towards min_lr; the first restart period spans the rest of the schedule.
cosine = CosineAnnealingWarmRestarts(
    optimizer,
    T_0=num_epochs - warmup_epochs,
    eta_min=min_lr,
)

# Combine them: warmup first, then switch to cosine at step `warmup_epochs`.
scheduler = SequentialLR(
    optimizer,
    schedulers=[warmup, cosine],
    milestones=[warmup_epochs],
)

# Soft-label mixing; the class count comes from the model config so the two cannot drift apart.
mixup_fn = v2.MixUp(
    alpha=1.0,
    num_classes=config["num_classes"]
)

# SummaryWriter is imported explicitly in the notebook's import cell:
# `import torch` alone does not load the torch.utils.tensorboard submodule.
trainer = train.Trainer(model=vit,
                        train_loader=train_loader,
                        val_loader=val_loader,
                        optimizer=optimizer,
                        criterion=SoftTargetCrossEntropy(),
                        val_criterion=nn.CrossEntropyLoss(),
                        scheduler=scheduler,
                        device = device,
                        writer=SummaryWriter(log_dir=paths.logs),
                        scaler=torch.amp.GradScaler(),
                        num_epochs=num_epochs,
                        log_interval=50,
                        model_dir=paths.chekpoints,
                        mixup_fn=mixup_fn,
                        early_stop_patience=20,
                        model_name="vit1",
                        resume=True
                        )


In [None]:
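# Training is switched off here; uncomment the next line to train the fine-grained model
# (the Trainer above was built with resume=True).
# acc = trainer.train()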

In [None]:
# Re-import the modules used below so edits to them are picked up without restarting the kernel.
importlib.reload(acdc)
importlib.reload(dataset)

# Wrap the validation split (together with the coarse label grouping) for the ACDC run.
acdc_dataset = dataset.ContrastiveWrapper(val_dataset, coarse_labels)

# Loaded in the main process with default worker settings; the worker / pin-memory
# options used by the other loaders are deliberately left off here.
acdc_loader = DataLoader(
    acdc_dataset,
    batch_size=50,
    shuffle=False,
    collate_fn=dataset.contrastive_collate_fn,
)

# Pull a single batch; the collate function yields a (clean, corrupted) pair.
clean_batch, corrupted_batch = next(iter(acdc_loader))


In [None]:
# Re-import the modules used below so edits to them are picked up without restarting the kernel.
importlib.reload(acdc)
importlib.reload(utils)

# Set to True to recompute the circuits; otherwise they are read from the checkpoint directory.
run_acdc = False
if run_acdc:
    # One ACDC run per threshold; results are keyed by the threshold's string form.
    circuits = {}
    for tau in [0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
        circuits[str(tau)] = acdc.run_ACDC(vit, tau, acdc_loader, device=device)
else:
    import pickle

    cirtcuits_paths = paths.chekpoints / "circuits.pkl"
    # Unpickling executes arbitrary code: only load a circuits.pkl that this project produced.
    # The context manager closes the file handle even if unpickling fails.
    with open(cirtcuits_paths, "rb") as circuits_file:
        circuits = pickle.load(circuits_file)

In [None]:
from tqdm import tqdm
def get_accuracy_on_coarse_labels(model, datalaoder, device, coarse_model=False) -> float:
    """Return the top-1 accuracy (in percent) of ``model`` measured on coarse labels.

    Ground-truth labels are mapped from fine to coarse class indices using the
    module-level ``coarse_labels`` mapping (coarse name -> iterable of fine
    labels; the coarse index is the position of the entry in that mapping).
    Labels that do not appear in the mapping are left unchanged.

    Args:
        model: Module whose forward pass returns a pair whose first element is
            the logits; class scores are read along dimension 1.
        datalaoder: Iterable yielding ``(images, labels)`` batches.
        device: ``torch.device`` on which the evaluation runs.
        coarse_model: If False, the model predicts fine classes and its
            predictions are mapped to coarse indices before comparison. If
            True, predictions are compared to the coarse labels directly.

    Returns:
        Accuracy as a percentage in ``[0, 100]``.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    fine_label_to_coarse = {
        fine_label: coarse_idx
        for coarse_idx, fine_group in enumerate(coarse_labels.values())
        for fine_label in fine_group
    }

    def to_coarse(indices):
        """Map a tensor of fine indices to coarse ones, on whatever device it lives."""
        # Tensor.apply_ only works on CPU tensors, so it fails once the batch has been
        # moved to a GPU. Remap through a Python list and rebuild the tensor in place.
        remapped = [fine_label_to_coarse.get(i, i) for i in indices.flatten().tolist()]
        return torch.tensor(
            remapped, dtype=indices.dtype, device=indices.device
        ).reshape(indices.shape)

    model.eval()
    correct, total = 0.0, 0

    dataloader_tqdm = tqdm(
        datalaoder,
        desc="[Validation]",
        leave=False
    )

    with torch.no_grad():
        for images, labels in dataloader_tqdm:
            images = images.to(device, non_blocking=True)
            labels = to_coarse(labels.to(device, non_blocking=True))

            with torch.amp.autocast(device_type=device.type):
                outputs, _ = model(images)

            _, predicted = outputs.max(1)
            if not coarse_model:
                predicted = to_coarse(predicted)

            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

    return 100.0 * correct / total

In [None]:
import time

importlib.reload(model)

# Restore the saved weights and evaluate the model held by the trainer.
trainer.load_checkpoint(paths.chekpoints/"checkpoint.pth")
vit = trainer.model

# Coarse-label accuracy of the full model; perf_counter is a monotonic clock meant for intervals.
start = time.perf_counter()
print(get_accuracy_on_coarse_labels(vit, val_loader, device, coarse_model=False))
end = time.perf_counter()
print(f"took: {end-start}")

# The key below is not one of the thresholds produced by the ACDC cell of this notebook,
# so it only exists if the loaded circuits file contains it. Fail with a message that
# lists what is available instead of a bare KeyError.
circuit_key = "0.001"
if circuit_key not in circuits:
    raise KeyError(
        f"no circuit stored under {circuit_key!r}; available keys: {list(circuits)}"
    )
vit.retrain_circuit(circuits[circuit_key])

# Same measurement after restricting the model to the selected circuit.
start = time.perf_counter()
print(get_accuracy_on_coarse_labels(vit, val_loader, device, coarse_model=False))
end = time.perf_counter()
print(f"took: {end-start}")

In [None]:
importlib.reload(model)

# Replace the classification head with one that has an output per coarse label group.
coarse_head = nn.Linear(config["hidden_size"], len(coarse_labels))
vit.classifier = coarse_head

# NOTE(review): Module.apply visits every submodule, so this re-runs _init_weights on the
# whole network, not only on the new head — confirm that a full re-initialisation is intended.
vit.apply(vit._init_weights)

# The new head is created on the CPU; move the whole model back to the target device.
vit = vit.to(device)

class SoftTargetCrossEntropy(nn.Module):
    """Cross-entropy against soft targets, as produced by Mixup/Cutmix.

    ``forward(x, target)`` takes logits ``x`` and a ``target`` of per-class
    weights with the same shape, reduces over dimension 1, and returns the
    mean of the per-sample losses.
    """

    def forward(self, x, target):
        # Log-probabilities of the model outputs along the class dimension.
        log_probs = F.log_softmax(x, dim=1)
        # Per-sample loss: negative target-weighted sum of log-probabilities.
        per_sample = -(target * log_probs).sum(dim=1)
        return per_sample.mean()
    
# --- Optimisation hyperparameters for the coarse-label run ---
num_epochs = 500
warmup_epochs = 20
base_lr = 3e-4
min_lr = 1e-6
weight_decay = 0.05
# NOTE(review): `label_smoothing` is defined but not passed to anything in this cell.
label_smoothing = 0.1



optimizer = AdamW(vit.parameters(),
                  lr=base_lr,
                  weight_decay = weight_decay,
                  betas=(0.9, 0.98),
                  eps = 1e-6
                  )

# Linear warmup over the first `warmup_epochs` scheduler steps (≈0 → base_lr)
warmup = LinearLR(
    optimizer,
    start_factor=1e-6,  # Near-zero initial LR
    end_factor=1.0,     # Full LR after warmup
    total_iters=warmup_epochs,
)

# Cosine decay towards min_lr; the first restart period spans the rest of the schedule.
cosine = CosineAnnealingWarmRestarts(
    optimizer,
    T_0=num_epochs - warmup_epochs,
    eta_min=min_lr,
)

# Combine them: warmup first, then switch to cosine at step `warmup_epochs`.
scheduler = SequentialLR(
    optimizer,
    schedulers=[warmup, cosine],
    milestones=[warmup_epochs],
)

# NOTE(review): MixUp still builds 58-way soft labels, while the classifier assigned for this
# run has len(coarse_labels) outputs — confirm CoarseTrainer reconciles the two.
mixup_fn = v2.MixUp(
    alpha=1.0,
    num_classes=58
)

# SummaryWriter is imported explicitly in the notebook's import cell:
# `import torch` alone does not load the torch.utils.tensorboard submodule.
# NOTE(review): model_name and model_dir match the fine-grained trainer's — confirm this run
# is not meant to write to a separate checkpoint / log location.
trainer = train.CoarseTrainer(coarse_labels,
                            model=vit,
                            train_loader=train_loader,
                            val_loader=val_loader,
                            optimizer=optimizer,
                            criterion=SoftTargetCrossEntropy(),
                            val_criterion=nn.CrossEntropyLoss(),
                            scheduler=scheduler,
                            device = device,
                            writer=SummaryWriter(log_dir=paths.logs),
                            scaler=torch.amp.GradScaler(),
                            num_epochs=num_epochs,
                            log_interval=50,
                            model_dir=paths.chekpoints,
                            mixup_fn=mixup_fn,
                            early_stop_patience=20,
                            model_name="vit1",
                            resume=False
                            )
acc = trainer.train()