In [2]:
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingWarmRestarts
sys.path.append('../src')
from modules import (
                    paths,
                    dataset,
                    model,
                    utils,
                    acdc,
                    train
                    )
from torchvision.transforms import v2
from torch.optim import AdamW

/home/lexyo/Dev/cv-proj2/notebooks/../src/modules/paths.py


  from .autonotebook import tqdm as notebook_tqdm
/home/lexyo/.config/matplotlib is not a writable directory
Matplotlib created a temporary cache directory at /tmp/matplotlib-5ke2ctk7 because there was an issue with the default path (/home/lexyo/.config/matplotlib); it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


In [3]:
transform_train = v2.Compose([
    v2.Lambda(lambda x: x.convert('RGB')),  # some images are in grayscale
    v2.ToImage(), 
    v2.ToDtype(torch.float32, scale=True),
    v2.RandomHorizontalFlip(),
    v2.RandAugment(),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    v2.RandomErasing(p=0.25),

])

transform_valid = v2.Compose([
    v2.Lambda(lambda x: x.convert('RGB')),  # some images are in grayscale
    v2.ToImage(), 
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

In [40]:
toy=False

In [41]:
import importlib
importlib.reload(dataset)
if toy == True:
    print("laoding toy datasets")
    train_dataset, coarse_labels = dataset.load_animal_dataset("train", transform=transform_train, tiny=True, stop=6)
    val_dataset, coarse_labels = dataset.load_animal_dataset("valid", transform=transform_valid, tiny=True, stop=2)

else:
    print("loading full dataet")
    train_dataset, coarse_labels = dataset.load_animal_dataset("train", transform=transform_train)
    val_dataset, coarse_labels = dataset.load_animal_dataset("valid", transform=transform_valid)

train_dataset = dataset.TorchDatasetWrapper(train_dataset, transform=transform_train)
val_dataset = dataset.TorchDatasetWrapper(val_dataset, transform=transform_valid)
print("train:\n"+str(train_dataset))
print("validation:\n"+str(val_dataset))


loading full dataet
Generating animal dataset...
Loading dataset from /home/lexyo/Dev/cv-proj2/notebooks/../data/train.pkl
Generating animal dataset...
Loading dataset from /home/lexyo/Dev/cv-proj2/notebooks/../data/valid.pkl
train:
Dataset({
    features: ['image', 'label'],
    num_rows: 29000
})
validation:
Dataset({
    features: ['image', 'label'],
    num_rows: 2900
})


In [21]:
batch_size = 5 if toy else 4096 

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,  
    pin_memory=True,
    prefetch_factor=4,
    persistent_workers=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,  
    pin_memory=True,
    prefetch_factor=4,
    persistent_workers=True
)

Train the model

In [None]:
importlib.reload(model)
config = {
    "patch_size": 8,           # Kept small for fine-grained patches
    "hidden_size": 64,          # Increased from 48 (better representation)
    "num_hidden_layers": 6,     # Deeper for pruning flexibility
    "num_attention_heads": 8,   # More heads (head_dim = 64/8 = 8)
    "intermediate_size": 4 * 64,# Standard FFN scaling
    "hidden_dropout_prob": 0.2, # Mild dropout for regularization
    "attention_probs_dropout_prob": 0.2,
    "initializer_range": 0.02,
    "image_size": 64,
    "num_classes": 57,
    "num_channels": 3,
    "qkv_bias": True,           # Keep bias for now (can prune later)
}

importlib.reload(train)

class SoftTargetCrossEntropy(nn.Module):
    """Cross-entropy loss compatible with Mixup/Cutmix soft labels"""
    def __init__(self):
        super().__init__()
    
    def forward(self, x, target):
        # x = model outputs (logits)
        # target = mixed labels (probability distributions)
        loss = torch.sum(-target * F.log_softmax(x, dim=1), dim=1)
        return loss.mean()

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
vit = model.ViT(config).to(device)

num_epochs = 500
warmup_epochs = 20
base_lr = 3e-4
min_lr = 1e-6
weight_decay = 0.05  # For AdamW optimizer
label_smoothing = 0.1  # For cross-entropy
patience = 50



optimizer = AdamW(vit.parameters(),
                  lr=base_lr,
                  weight_decay = weight_decay,
                  betas=(0.9, 0.98),
                  eps = 1e-6      
                  )

# Linear warmup for 30 epochs (0 → base_lr)
warmup = LinearLR(
    optimizer,
    start_factor=1e-6,  # Near-zero initial LR
    end_factor=1.0,     # Full LR after warmup
    total_iters=warmup_epochs,
)

# Cosine decay for remaining epochs (170)
cosine = CosineAnnealingWarmRestarts(
    optimizer,
    T_0=num_epochs - warmup_epochs,  # 170 epochs per cycle
    eta_min=min_lr,
)

# Combine them
scheduler = SequentialLR(
    optimizer,
    schedulers=[warmup, cosine],
    milestones=[warmup_epochs],  # Switch after warmup
)

mixup_fn = v2.MixUp(
    alpha=1.0,          # Add CutMix
    num_classes=57
)

trainer = train.Trainer(model=vit,
                        train_loader=train_loader,
                        val_loader=val_loader,
                        optimizer=optimizer,
                        criterion=SoftTargetCrossEntropy(),
                        val_criterion=nn.CrossEntropyLoss(),
                        scheduler=scheduler,
                        device = device,
                        writer=torch.utils.tensorboard.SummaryWriter(log_dir=paths.logs),
                        scaler=torch.amp.GradScaler(),
                        num_epochs=num_epochs,
                        log_interval=50,
                        model_dir=paths.chekpoints,
                        mixup_fn=mixup_fn,
                        early_stop_patience=20,
                        model_name="vit1",
                        resume=True
                        )


Resuming training from checkpoint: /home/lexyo/Dev/cv-proj2/notebooks/../checkpoints/checkpoint.pth
Checkpoint loaded. Resuming from epoch 387 with best accuracy 21.55%.




In [8]:
# acc = trainer.train()

In [9]:
importlib.reload(acdc)
importlib.reload(dataset)
acdc_dataset = dataset.ContrastiveWrapper(val_dataset, coarse_labels)

acdc_loader = DataLoader(
    acdc_dataset,
    batch_size=50,
    shuffle=False,
    # num_workers=1,  
    # pin_memory=False,
    # prefetch_factor=1,
    collate_fn=dataset.contrastive_collate_fn,
    # persistent_workers=False
)
clean_batch, corrupted_batch = next(iter(acdc_loader))


Indexing dataset by class for contrastive sampling...
Indexing complete.


In [10]:
trainer.load_checkpoint(paths.chekpoints/"checkpoint.pth")
vit = trainer.model

Checkpoint loaded. Resuming from epoch 387 with best accuracy 21.55%.


In [27]:
importlib.reload(acdc)
importlib.reload(utils)

# circuits = {}
# for tau in [0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
#     circuits[str(tau)] = acdc.run_ACDC(vit, tau, acdc_loader, device=device)
circuits = {'0.05': {('embedding.final_output', 'encoder.blocks.0.mlp.final_output'),
  ('encoder.blocks.0.mlp.final_output', 'encoder.blocks.1.mlp.final_output'),
  ('encoder.blocks.0.mlp.final_output', 'encoder.blocks.2.mlp.final_output'),
  ('encoder.blocks.0.mlp.final_output', 'encoder.blocks.3.mlp.final_output')},
 '0.1': {('embedding.final_output', 'encoder.blocks.0.mlp.final_output'),
  ('encoder.blocks.0.mlp.final_output', 'encoder.blocks.1.mlp.final_output'),
  ('encoder.blocks.0.mlp.final_output', 'encoder.blocks.2.mlp.final_output')},
 '0.2': {('embedding.final_output', 'encoder.blocks.0.mlp.final_output'),
  ('encoder.blocks.0.mlp.final_output', 'encoder.blocks.1.mlp.final_output')},
 '0.3': set(),
 '0.4': set(),
 '0.5': set(),
 '0.6': set(),
 '0.7': set(),
 '0.8': set(),
 '0.9': set()}
circuits['0.01'] = {('embedding.final_output', 'encoder.blocks.0.attention.heads.4.final_output'),
 ('embedding.final_output', 'encoder.blocks.0.mlp.final_output'),
 ('embedding.final_output', 'encoder.blocks.1.mlp.final_output'),
 ('encoder.blocks.0.attention.heads.4.final_output',
  'encoder.blocks.0.mlp.final_output'),
 ('encoder.blocks.0.mlp.final_output',
  'encoder.blocks.1.attention.heads.2.final_output'),
 ('encoder.blocks.0.mlp.final_output',
  'encoder.blocks.1.attention.heads.4.final_output'),
 ('encoder.blocks.0.mlp.final_output', 'encoder.blocks.1.mlp.final_output'),
 ('encoder.blocks.0.mlp.final_output', 'encoder.blocks.2.mlp.final_output'),
 ('encoder.blocks.0.mlp.final_output', 'encoder.blocks.3.mlp.final_output'),
 ('encoder.blocks.0.mlp.final_output', 'encoder.blocks.4.mlp.final_output'),
 ('encoder.blocks.0.mlp.final_output', 'encoder.blocks.5.mlp.final_output'),
 ('encoder.blocks.1.attention.heads.2.final_output',
  'encoder.blocks.1.mlp.final_output'),
 ('encoder.blocks.1.attention.heads.3.final_output',
  'encoder.blocks.1.mlp.final_output'),
 ('encoder.blocks.1.attention.heads.4.final_output',
  'encoder.blocks.1.mlp.final_output'),
 ('encoder.blocks.1.attention.heads.4.final_output',
  'encoder.blocks.2.mlp.final_output'),
 ('encoder.blocks.1.attention.heads.6.final_output',
  'encoder.blocks.1.mlp.final_output'),
 ('encoder.blocks.1.attention.heads.7.final_output',
  'encoder.blocks.1.mlp.final_output'),
 ('encoder.blocks.1.mlp.final_output', 'encoder.blocks.2.mlp.final_output'),
 ('encoder.blocks.1.mlp.final_output', 'encoder.blocks.3.mlp.final_output'),
 ('encoder.blocks.1.mlp.final_output', 'encoder.blocks.4.mlp.final_output'),
 ('encoder.blocks.1.mlp.final_output', 'encoder.blocks.5.mlp.final_output'),
 ('encoder.blocks.2.mlp.final_output', 'encoder.blocks.3.mlp.final_output'),
 ('encoder.blocks.2.mlp.final_output', 'encoder.blocks.4.mlp.final_output'),
 ('encoder.blocks.2.mlp.final_output', 'encoder.blocks.5.mlp.final_output'),
 ('encoder.blocks.3.mlp.final_output', 'encoder.blocks.4.mlp.final_output'),
 ('encoder.blocks.3.mlp.final_output', 'encoder.blocks.5.mlp.final_output'),
 ('encoder.blocks.4.mlp.final_output', 'encoder.blocks.5.mlp.final_output')}

In [25]:
coarse_labels
coarse_labels_int = {i:cl for i, cl in enumerate(coarse_labels)}
t = torch.tensor([0, 0, 1, 1, 2])
coarse_labels_batch = [cl for cl, fl in coarse_labels_int.items() for label in t if label in fl]


TypeError: 'in <string>' requires string as left operand, not Tensor

In [26]:
coarse_labels_int

{0: 'Aquatic',
 1: 'Amphibians & Reptiles',
 2: 'Arthropods',
 3: 'Birds',
 4: 'Mammals',
 5: 'Marine Life & Fossils'}

In [22]:
get_accuracy(vit, val_loader, coarse_labels)

                                                    

tensor([0, 0, 1, 1, 2])




UnboundLocalError: cannot access local variable 'val_loss' where it is not associated with a value

In [None]:
from tqdm import tqdm
def get_accuracy(model, datalaoder, coarse_labels=None) -> float:
    if coarse_labels:
        coarse_labels_int = {i:cl for i, cl in enumerate(coarse_labels)}
    model.eval()
    correct, total = 0.0, 0

    dataloader_tqdm = tqdm(
        datalaoder, 
        desc=f"[Validation]", 
        leave=False
    )

    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(dataloader_tqdm):
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            print(labels)

            with torch.amp.autocast(device_type=device.type):
                outputs, _ = model(images)

            _, predicted = outputs.max(1)
            batch_correct = predicted.eq(labels).sum().item()
            batch_total = labels.size(0)

            val_loss += batch_loss * batch_total
            correct += batch_correct
            total += batch_total

            global_step = epoch * len(self.train_loader) + len(self.train_loader) + batch_idx
            batch_acc = 100.0 * batch_correct / batch_total
            self.writer.add_scalar('Loss/val_batch', batch_loss, global_step)
            self.writer.add_scalar('Accuracy/val_batch', batch_acc, global_step)

            if (batch_idx + 1) % self.log_interval == 0 or (batch_idx + 1) == len(self.val_loader):
                cumulative_loss = val_loss / total
                cumulative_acc = 100.0 * correct / total
                val_loader_tqdm.set_postfix(
                    loss=f"{cumulative_loss:.4f}", 
                    accuracy=f"{cumulative_acc:.2f}%"
                )

    epoch_loss = val_loss / total
    epoch_acc = 100.0 * correct / total
    return epoch_loss, epoch_acc