In [1]:
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
sys.path.append('../src')
from modules import (
                    paths,
                    dataset,
                    model,
                    utils,
                    acdc,
                    train
                    )
from torchvision.transforms import v2
from torch.optim import AdamW

/home/lexyo/Dev/cv-proj2/notebooks/../src/modules/paths.py


  from .autonotebook import tqdm as notebook_tqdm
/home/lexyo/.config/matplotlib is not a writable directory
Matplotlib created a temporary cache directory at /tmp/matplotlib-rm2b4tuu because there was an issue with the default path (/home/lexyo/.config/matplotlib); it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


In [2]:
transform_train = v2.Compose([
    v2.Lambda(lambda x: x.convert('RGB')),  # some images are in grayscale
    v2.ToImage(), 
    v2.ToDtype(torch.float32, scale=True),
    v2.RandomHorizontalFlip(),
    v2.RandAugment(),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    v2.RandomErasing(p=0.25),

])

transform_valid = v2.Compose([
    v2.Lambda(lambda x: x.convert('RGB')),  # some images are in grayscale
    v2.ToImage(), 
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

In [22]:
batch_size = 20
toy=False

In [23]:
import importlib
importlib.reload(dataset)
if toy == True:
    print("laoding toy datasets")
    train_dataset, coarse_labels = dataset.load_animal_dataset("train", transform=transform_train, tiny=True, stop=6)
    val_dataset, coarse_labels = dataset.load_animal_dataset("valid", transform=transform_valid, tiny=True, stop=2)

else:
    print("loading full dataet")
    train_dataset, coarse_labels = dataset.load_animal_dataset("train", transform=transform_train)
    val_dataset, coarse_labels = dataset.load_animal_dataset("valid", transform=transform_valid)
print("train:\n"+str(train_dataset))
print("validation:\n"+str(val_dataset))


loading full dataet
Loading animal dataset from /home/lexyo/Dev/cv-proj2/notebooks/../data/animal_train.pkl
Loading animal dataset from /home/lexyo/Dev/cv-proj2/notebooks/../data/animal_valid.pkl
train:
Dataset({
    features: ['image', 'label'],
    num_rows: 29000
})
validation:
Dataset({
    features: ['image', 'label'],
    num_rows: 2900
})


In [19]:
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,  
    pin_memory=True,
    prefetch_factor=4,
    persistent_workers=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=3,
    shuffle=False,
    num_workers=2,  
    pin_memory=True,
    prefetch_factor=4,
    persistent_workers=True
)

In [20]:
batch = next(iter(train_loader))




In [8]:
torch.cuda.is_available()

False

In [None]:
importlib.reload(model)
config = {
    "patch_size": 8,           # Kept small for fine-grained patches
    "hidden_size": 48,          # Increased from 48 (better representation)
    "num_hidden_layers": 6,     # Deeper for pruning flexibility
    "num_attention_heads": 8,   # More heads (head_dim = 64/8 = 8)
    "intermediate_size": 4 * 64,# Standard FFN scaling
    "hidden_dropout_prob": 0.1, # Mild dropout for regularization
    "attention_probs_dropout_prob": 0.1,
    "initializer_range": 0.02,
    "image_size": 64,
    "num_classes": 200,
    "num_channels": 3,
    "qkv_bias": True,           # Keep bias for now (can prune later)
}
vit = model.ViT(config)

importlib.reload(train)

class SoftTargetCrossEntropy(nn.Module):
    """Cross-entropy loss compatible with Mixup/Cutmix soft labels"""
    def __init__(self):
        super().__init__()
    
    def forward(self, x, target):
        # x = model outputs (logits)
        # target = mixed labels (probability distributions)
        loss = torch.sum(-target * F.log_softmax(x, dim=1), dim=1)
        return loss.mean()

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
num_epochs = 10
warmup_epochs = 2
base_lr = 3e-4
min_lr = 1e-6
weight_decay = 0.05  # For AdamW optimizer
label_smoothing = 0.1  # For cross-entropy
patience = 20



optimizer = AdamW(vit.parameters(),
                  lr=base_lr,
                  weight_decay = weight_decay,
                  betas=(0.9, 0.999)
                  )

scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, 
    T_0=num_epochs - warmup_epochs,
    eta_min=min_lr
)

mixup_fn = v2.MixUp(
    alpha=1.0,          # Add CutMix
    num_classes=200
)

trainer = train.Trainer(model=vit,
                        train_loader=train_loader,
                        val_loader=val_loader,
                        optimizer=optimizer,
                        criterion=SoftTargetCrossEntropy(),
                        val_criterion=nn.CrossEntropyLoss(),
                        scheduler=scheduler,
                        device = device,
                        writer=torch.utils.tensorboard.SummaryWriter(log_dir=paths.logs),
                        scaler=torch.cuda.amp.GradScaler(),
                        num_epochs=num_epochs,
                        log_interval=50,
                        model_dir=paths.chekpoints,
                        mixup_fn=mixup_fn,
                        early_stop_patience=20,
                        model_name="vit1"
                        )
acc = trainer.train()

  scaler=torch.cuda.amp.GradScaler(),


Logging to /home/lexyo/Dev/cv-proj2/notebooks/../logs


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
                                                                                                     

Epoch [1/10] | Train Loss: 5.2698, Acc: 1.44% | Val Loss: 5.1743, Val Acc: 1.72%
New best model saved with accuracy: 1.72%


                                                                                                     

Epoch [2/10] | Train Loss: 4.9291, Acc: 1.72% | Val Loss: 4.4595, Val Acc: 1.72%


                                                                                                     

Epoch [3/10] | Train Loss: 4.3012, Acc: 1.44% | Val Loss: 4.1420, Val Acc: 1.72%


                                                                                                     

Epoch [4/10] | Train Loss: 4.1419, Acc: 2.01% | Val Loss: 4.0818, Val Acc: 2.59%
New best model saved with accuracy: 2.59%


                                                                                                     

Epoch [5/10] | Train Loss: 4.1082, Acc: 1.15% | Val Loss: 4.0710, Val Acc: 2.59%


                                                                                                     

Epoch [6/10] | Train Loss: 4.0847, Acc: 1.44% | Val Loss: 4.0602, Val Acc: 3.45%
New best model saved with accuracy: 3.45%


                                                                                                     

Epoch [7/10] | Train Loss: 4.0708, Acc: 2.30% | Val Loss: 4.0579, Val Acc: 3.45%


                                                                                                     

Epoch [8/10] | Train Loss: 4.0620, Acc: 2.30% | Val Loss: 4.0576, Val Acc: 2.59%


                                                                                                     

Epoch [9/10] | Train Loss: 4.0894, Acc: 1.72% | Val Loss: 4.0598, Val Acc: 5.17%
New best model saved with accuracy: 5.17%


                                                                                                      

Epoch [10/10] | Train Loss: 4.0813, Acc: 2.01% | Val Loss: 4.0391, Val Acc: 1.72%
Training complete. Best validation accuracy: 5.17%




In [10]:
importlib.reload(utils)

computation_graph = utils.ComputationalGraph(vit)
print(computation_graph.nodes.keys())

dict_keys(['embedding', 'encoder.blocks.0.mlp.final_output', 'encoder.blocks.0.attention.heads.0.final_output', 'encoder.blocks.0.attention.heads.1.final_output', 'encoder.blocks.0.attention.heads.2.final_output', 'encoder.blocks.0.attention.heads.3.final_output', 'encoder.blocks.0.attention.heads.4.final_output', 'encoder.blocks.0.attention.heads.5.final_output', 'encoder.blocks.0.attention.heads.6.final_output', 'encoder.blocks.0.attention.heads.7.final_output', 'encoder.blocks.1.mlp.final_output', 'encoder.blocks.1.attention.heads.0.final_output', 'encoder.blocks.1.attention.heads.1.final_output', 'encoder.blocks.1.attention.heads.2.final_output', 'encoder.blocks.1.attention.heads.3.final_output', 'encoder.blocks.1.attention.heads.4.final_output', 'encoder.blocks.1.attention.heads.5.final_output', 'encoder.blocks.1.attention.heads.6.final_output', 'encoder.blocks.1.attention.heads.7.final_output', 'encoder.blocks.2.mlp.final_output', 'encoder.blocks.2.attention.heads.0.final_output'

In [11]:
len(computation_graph.edges)

1317

In [12]:
importlib.reload(acdc)

with acdc.SaveActivations(list(computation_graph.nodes.values())) as ctx:
    vit(batch[0])
    activations = ctx.get_activations()
activations

{'embedding.final_output': tensor([[[-1.1115e-02,  4.5102e-03,  1.4339e-02,  ..., -6.3072e-02,
           -9.9170e-03,  5.8590e-03],
          [ 9.4512e-02,  1.3799e-01, -1.6011e-01,  ...,  4.9907e-01,
           -2.9102e-01, -2.9863e-01],
          [-4.9579e-01, -1.1116e-02,  3.5017e-01,  ..., -8.7025e-01,
           -9.8365e-03,  4.6861e-01],
          ...,
          [ 7.4574e-01, -1.1044e-01,  4.6286e-01,  ..., -1.9132e-01,
           -1.3966e-01, -1.0987e+00],
          [ 7.3893e-01, -1.8903e-01,  4.9467e-01,  ...,  1.1164e-01,
            5.8736e-02, -1.3202e+00],
          [ 6.3944e-01, -1.7670e-01,  3.0681e-01,  ..., -2.9798e-02,
           -2.1390e-01, -1.1330e+00]],
 
         [[-1.1115e-02,  4.5102e-03,  1.4339e-02,  ..., -6.3072e-02,
           -9.9170e-03,  5.8590e-03],
          [-5.4932e-01,  2.1236e-01, -2.0624e-01,  ...,  2.3445e-02,
            5.7441e-01,  9.5717e-01],
          [-4.8156e-01, -6.2293e-01, -4.0438e-01,  ...,  2.8290e-01,
           -1.8369e-01,  3.9221

In [13]:
importlib.reload(dataset)
animal_dataset, coarse_labels = dataset.load_animal_dataset("train")
animal_dataset[0]
data = dataset.load("valid", tiny=True)
data[0]


TypeError: load_animal_dataset() missing 1 required positional argument: 'transform'

In [None]:
importlib.reload(dataset)
small_animal_dataset, coarse_labels = dataset.load_animal_dataset("train", tiny=True, start=0, stop=4)
# matching_dataset = dataset.ContrastiveWrapper(small_animal_dataset, coarse_labels)
small_animal_dataset[0]


In [None]:
matching_dataset[1]

In [None]:
importlib.reload(acdc)
with acdc.ReplaceActivations(m.l2, activations["dense2"]):
    with acdc.SaveActivations([(m.l1, "dense1"), (m.l2, "dense2")]) as ctx:
        print(m(torch.Tensor([1])))
        activations1 = ctx.get_activations()
    
print(activations1)

In [None]:
importlib.reload(dataset)
small_animal_dataset, coarse_labels = dataset.load_animal_dataset("train", tiny=True, start=0, stop=4)
counting = {}
for sample in small_animal_dataset:
    if sample["label"] not in counting: counting[sample["label"]] = 1
    else: counting[sample["label"]]+=1
print(counting)