In [1]:
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingWarmRestarts
sys.path.append('../src')
from modules import (
                    paths,
                    dataset,
                    model,
                    utils,
                    acdc,
                    train
                    )
from torchvision.transforms import v2
from torch.optim import AdamW

/home/user/cv-proj2/notebooks/../src/modules/paths.py


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
transform_train = v2.Compose([
    v2.Lambda(lambda x: x.convert('RGB')),  # some images are in grayscale
    v2.ToImage(), 
    v2.ToDtype(torch.float32, scale=True),
    v2.RandomHorizontalFlip(),
    v2.RandAugment(),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    v2.RandomErasing(p=0.25),

])

transform_valid = v2.Compose([
    v2.Lambda(lambda x: x.convert('RGB')),  # some images are in grayscale
    v2.ToImage(), 
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

In [3]:
toy=False

In [4]:
import importlib
importlib.reload(dataset)
if toy == True:
    print("laoding toy datasets")
    train_dataset, coarse_labels = dataset.load_animal_dataset("train", transform=transform_train, tiny=True, stop=6)
    val_dataset, coarse_labels = dataset.load_animal_dataset("valid", transform=transform_valid, tiny=True, stop=2)

else:
    print("loading full dataet")
    train_dataset, coarse_labels = dataset.load_animal_dataset("train", transform=transform_train)
    val_dataset, coarse_labels = dataset.load_animal_dataset("valid", transform=transform_valid)

train_dataset = dataset.TorchDatasetWrapper(train_dataset, transform=transform_train)
val_dataset = dataset.TorchDatasetWrapper(val_dataset, transform=transform_valid)
print("train:\n"+str(train_dataset))
print("validation:\n"+str(val_dataset))


loading full dataet
Loading animal dataset from /home/user/cv-proj2/notebooks/../data/animal_train.pkl
Loading animal dataset from /home/user/cv-proj2/notebooks/../data/animal_valid.pkl
train:
Dataset({
    features: ['image', 'label'],
    num_rows: 29000
})
validation:
Dataset({
    features: ['image', 'label'],
    num_rows: 2900
})


In [5]:
batch_size = 5 if toy else 4096 

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,  
    pin_memory=True,
    prefetch_factor=4,
    persistent_workers=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,  
    pin_memory=True,
    prefetch_factor=4,
    persistent_workers=True
)

Train the model

In [6]:
importlib.reload(model)
config = {
    "patch_size": 8,           # Kept small for fine-grained patches
    "hidden_size": 64,          # Increased from 48 (better representation)
    "num_hidden_layers": 6,     # Deeper for pruning flexibility
    "num_attention_heads": 8,   # More heads (head_dim = 64/8 = 8)
    "intermediate_size": 4 * 64,# Standard FFN scaling
    "hidden_dropout_prob": 0.2, # Mild dropout for regularization
    "attention_probs_dropout_prob": 0.2,
    "initializer_range": 0.02,
    "image_size": 64,
    "num_classes": 58,
    "num_channels": 3,
    "qkv_bias": True,           # Keep bias for now (can prune later)
}

importlib.reload(train)

class SoftTargetCrossEntropy(nn.Module):
    """Cross-entropy loss compatible with Mixup/Cutmix soft labels"""
    def __init__(self):
        super().__init__()
    
    def forward(self, x, target):
        # x = model outputs (logits)
        # target = mixed labels (probability distributions)
        loss = torch.sum(-target * F.log_softmax(x, dim=1), dim=1)
        return loss.mean()

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
vit = model.ViT(config).to(device)

num_epochs = 500
warmup_epochs = 20
base_lr = 3e-4
min_lr = 1e-6
weight_decay = 0.05  # For AdamW optimizer
label_smoothing = 0.1  # For cross-entropy
patience = 50



optimizer = AdamW(vit.parameters(),
                  lr=base_lr,
                  weight_decay = weight_decay,
                  betas=(0.9, 0.98),
                  eps = 1e-6      
                  )

# Linear warmup for 30 epochs (0 → base_lr)
warmup = LinearLR(
    optimizer,
    start_factor=1e-6,  # Near-zero initial LR
    end_factor=1.0,     # Full LR after warmup
    total_iters=warmup_epochs,
)

cosine = CosineAnnealingWarmRestarts(
    optimizer,
    T_0=num_epochs - warmup_epochs,  
    eta_min=min_lr,
)

# Combine them
scheduler = SequentialLR(
    optimizer,
    schedulers=[warmup, cosine],
    milestones=[warmup_epochs], 
)

mixup_fn = v2.MixUp(
    alpha=1.0,          
    num_classes=58
)

trainer = train.Trainer(model=vit,
                        train_loader=train_loader,
                        val_loader=val_loader,
                        optimizer=optimizer,
                        criterion=SoftTargetCrossEntropy(),
                        val_criterion=nn.CrossEntropyLoss(),
                        scheduler=scheduler,
                        device = device,
                        writer=torch.utils.tensorboard.SummaryWriter(log_dir=paths.logs),
                        scaler=torch.amp.GradScaler(),
                        num_epochs=num_epochs,
                        log_interval=50,
                        model_dir=paths.chekpoints,
                        mixup_fn=mixup_fn,
                        early_stop_patience=20,
                        model_name="vit1",
                        resume=True
                        )


Resuming training from checkpoint: /home/user/cv-proj2/notebooks/../checkpoints/checkpoint.pth
Checkpoint loaded. Resuming from epoch 494 with best accuracy 26.97%.


In [7]:
# acc = trainer.train()

In [8]:
importlib.reload(acdc)
importlib.reload(dataset)
acdc_dataset = dataset.ContrastiveWrapper(val_dataset, coarse_labels)

acdc_loader = DataLoader(
    acdc_dataset,
    batch_size=50,
    shuffle=False,
    # num_workers=1,  
    # pin_memory=False,
    # prefetch_factor=1,
    collate_fn=dataset.contrastive_collate_fn,
    # persistent_workers=False
)
clean_batch, corrupted_batch = next(iter(acdc_loader))


Indexing dataset by class for contrastive sampling...
Indexing complete.


In [9]:
importlib.reload(acdc)
importlib.reload(utils)
run_acdc = False
if run_acdc:
    circuits = {}
    for tau in [0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
        circuits[str(tau)] = acdc.run_ACDC_optimized(vit, tau, acdc_loader, device=device)
else:
    cirtcuits_paths = paths.chekpoints / "circuits.pkl"
    import pickle
    circuits = pickle.load(open(cirtcuits_paths, "rb"))

In [10]:
acdc.test_taus(circuits, val_loader, coarse_labels, config, device)

Checkpoint loaded. Resuming from epoch 397 with best accuracy 26.97%.


                                                           

Original model accuracy: 60.06896551724138
Checkpoint loaded. Resuming from epoch 397 with best accuracy 26.97%.
Pruning 0 unused MLP layers...
Pruning 0 unused attention heads...

Model pruned. Ready for retraining on the circuit.
Testing model with tau 0.0001


                                                           

Accuracy: 60.06896551724138 | Avg Time over 3 runs: 3.1808121744791666 seconds
Checkpoint loaded. Resuming from epoch 397 with best accuracy 26.97%.
Pruning 0 unused MLP layers...
Pruning 0 unused attention heads...

Model pruned. Ready for retraining on the circuit.
Testing model with tau 0.0005


                                                           

Accuracy: 60.06896551724138 | Avg Time over 3 runs: 3.3248758138020835 seconds
Checkpoint loaded. Resuming from epoch 397 with best accuracy 26.97%.
Pruning 0 unused MLP layers...
Pruning 7 unused attention heads...

Model pruned. Ready for retraining on the circuit.
Testing model with tau 0.001


                                                           

Accuracy: 60.172413793103445 | Avg Time over 3 runs: 3.2313787434895835 seconds
Checkpoint loaded. Resuming from epoch 397 with best accuracy 26.97%.
Pruning 0 unused MLP layers...
Pruning 26 unused attention heads...

Model pruned. Ready for retraining on the circuit.
Testing model with tau 0.005


                                                           

Accuracy: 49.13793103448276 | Avg Time over 3 runs: 3.204301188151042 seconds
Checkpoint loaded. Resuming from epoch 397 with best accuracy 26.97%.
Pruning 0 unused MLP layers...
Pruning 41 unused attention heads...

Model pruned. Ready for retraining on the circuit.
Testing model with tau 0.01


                                                           

Accuracy: 31.82758620689655 | Avg Time over 3 runs: 3.1933037109374998 seconds
Checkpoint loaded. Resuming from epoch 397 with best accuracy 26.97%.
Pruning 0 unused MLP layers...
Pruning 48 unused attention heads...

Model pruned. Ready for retraining on the circuit.
Testing model with tau 0.05


                                                           

Accuracy: 6.896551724137931 | Avg Time over 3 runs: 3.2307185058593753 seconds
Checkpoint loaded. Resuming from epoch 397 with best accuracy 26.97%.
Pruning 3 unused MLP layers...
Pruning 48 unused attention heads...

Model pruned. Ready for retraining on the circuit.
Testing model with tau 0.1


                                                           

Accuracy: 8.620689655172415 | Avg Time over 3 runs: 3.20207958984375 seconds
Checkpoint loaded. Resuming from epoch 397 with best accuracy 26.97%.
Pruning 3 unused MLP layers...
Pruning 48 unused attention heads...

Model pruned. Ready for retraining on the circuit.
Testing model with tau 0.2


                                                           

Accuracy: 8.620689655172415 | Avg Time over 3 runs: 3.2388056640625003 seconds
Checkpoint loaded. Resuming from epoch 397 with best accuracy 26.97%.
Pruning 3 unused MLP layers...
Pruning 48 unused attention heads...

Model pruned. Ready for retraining on the circuit.
Testing model with tau 0.3


                                                           

Accuracy: 8.620689655172415 | Avg Time over 3 runs: 3.2212481282552083 seconds
Checkpoint loaded. Resuming from epoch 397 with best accuracy 26.97%.
Pruning 4 unused MLP layers...
Pruning 48 unused attention heads...

Model pruned. Ready for retraining on the circuit.
Testing model with tau 0.4


                                                           

Accuracy: 8.620689655172415 | Avg Time over 3 runs: 3.0961822102864587 seconds
Checkpoint loaded. Resuming from epoch 397 with best accuracy 26.97%.
Pruning 4 unused MLP layers...
Pruning 48 unused attention heads...

Model pruned. Ready for retraining on the circuit.
Testing model with tau 0.5


                                                           

Accuracy: 8.620689655172415 | Avg Time over 3 runs: 3.149993408203125 seconds
Checkpoint loaded. Resuming from epoch 397 with best accuracy 26.97%.
Pruning 4 unused MLP layers...
Pruning 48 unused attention heads...

Model pruned. Ready for retraining on the circuit.
Testing model with tau 0.6


                                                           

Accuracy: 8.620689655172415 | Avg Time over 3 runs: 3.08873828125 seconds
Checkpoint loaded. Resuming from epoch 397 with best accuracy 26.97%.
Pruning 4 unused MLP layers...
Pruning 48 unused attention heads...

Model pruned. Ready for retraining on the circuit.
Testing model with tau 0.7


                                                           

Accuracy: 8.620689655172415 | Avg Time over 3 runs: 3.1799784342447914 seconds
Checkpoint loaded. Resuming from epoch 397 with best accuracy 26.97%.
Pruning 4 unused MLP layers...
Pruning 48 unused attention heads...

Model pruned. Ready for retraining on the circuit.
Testing model with tau 0.8


                                                           

Accuracy: 8.620689655172415 | Avg Time over 3 runs: 3.121033121744792 seconds
Checkpoint loaded. Resuming from epoch 397 with best accuracy 26.97%.
Pruning 4 unused MLP layers...
Pruning 48 unused attention heads...

Model pruned. Ready for retraining on the circuit.
Testing model with tau 0.9


                                                           

Accuracy: 8.620689655172415 | Avg Time over 3 runs: 3.0938041178385416 seconds




Train the new model

In [11]:
importlib.reload(model)
vit.classifier = nn.Linear(config["hidden_size"], len(coarse_labels))
vit.apply(vit._init_weights)
vit = vit.to(device)

In [12]:
optimizer = AdamW(vit.parameters(),
                  lr=base_lr,
                  weight_decay = weight_decay,
                  betas=(0.9, 0.98),
                  eps = 1e-6      
                  )

# Linear warmup for 30 epochs (0 → base_lr)
warmup = LinearLR(
    optimizer,
    start_factor=1e-6,  # Near-zero initial LR
    end_factor=1.0,     # Full LR after warmup
    total_iters=warmup_epochs,
)

cosine = CosineAnnealingWarmRestarts(
    optimizer,
    T_0=num_epochs - warmup_epochs,  
    eta_min=min_lr,
)

# Combine them
scheduler = SequentialLR(
    optimizer,
    schedulers=[warmup, cosine],
    milestones=[warmup_epochs], 
)

mixup_fn = v2.MixUp(
    alpha=1.0,          
    num_classes=58
)

trainer = train.Trainer(model=vit,
                        train_loader=train_loader,
                        val_loader=val_loader,
                        optimizer=optimizer,
                        criterion=SoftTargetCrossEntropy(),
                        val_criterion=nn.CrossEntropyLoss(),
                        scheduler=scheduler,
                        device = device,
                        writer=torch.utils.tensorboard.SummaryWriter(log_dir=paths.logs),
                        scaler=torch.amp.GradScaler(),
                        num_epochs=num_epochs,
                        log_interval=50,
                        model_dir=paths.chekpoints,
                        mixup_fn=mixup_fn,
                        early_stop_patience=20,
                        model_name="vit1",
                        resume=True
                        )


Resuming training from checkpoint: /home/user/cv-proj2/notebooks/../checkpoints/checkpoint.pth


RuntimeError: Error(s) in loading state_dict for ViT:
	size mismatch for classifier.weight: copying a param with shape torch.Size([58, 64]) from checkpoint, the shape in current model is torch.Size([6, 64]).
	size mismatch for classifier.bias: copying a param with shape torch.Size([58]) from checkpoint, the shape in current model is torch.Size([6]).