In [45]:
import sys
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
sys.path.append('../src')
from modules import (
                    paths,
                    dataset,
                    model,
                    utils,
                    acdc
                    )
from torchvision.transforms import v2

In [2]:
toy = True
batch_size = 2

In [3]:
import importlib
importlib.reload(dataset)
if toy == True:
    print("laoding toy datasets")
    train_dataset = dataset.load("train", tiny=True, stop=2)
    val_dataset = dataset.load("valid", tiny=True, stop=2)

else:
    print("loading full dataet")
    train_dataset = dataset.load("train", stop=2)
    val_dataset = dataset.load("valid", stop=2)
print("train:\n"+str(train_dataset))
print("validation:\n"+str(val_dataset))


laoding toy datasets
Loading dataset from /home/lexyo/Dev/cv-proj2/notebooks/../data/train.pkl
Loading dataset from /home/lexyo/Dev/cv-proj2/notebooks/../data/valid.pkl
train:
Dataset({
    features: ['image', 'label'],
    num_rows: 400
})
validation:
Dataset({
    features: ['image', 'label'],
    num_rows: 400
})


In [4]:
transform_train = v2.Compose([
    v2.Lambda(lambda x: x.convert('RGB')),  # some images are in grayscale
    v2.ToImage(), 
    v2.ToDtype(torch.float32, scale=True),
    v2.RandomHorizontalFlip(),
    v2.RandAugment(),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    v2.RandomErasing(p=0.25),

])

transform_valid = v2.Compose([
    v2.Lambda(lambda x: x.convert('RGB')),  # some images are in grayscale
    v2.ToImage(), 
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

In [5]:
train_loader = DataLoader(
    dataset.TorchDatasetWrapper(train_dataset, transform_train),
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,  
    pin_memory=True,
    prefetch_factor=4,
    persistent_workers=True
)

val_loader = DataLoader(
    dataset.TorchDatasetWrapper(val_dataset, transform_valid),
    batch_size=3,
    shuffle=False,
    num_workers=2,  
    pin_memory=True,
    prefetch_factor=4,
    persistent_workers=True
)

In [6]:
batch = next(iter(train_loader))



In [49]:
importlib.reload(model)
config = {
    "patch_size": 8,           # Kept small for fine-grained patches
    "hidden_size": 64,          # Increased from 48 (better representation)
    "num_hidden_layers": 6,     # Deeper for pruning flexibility
    "num_attention_heads": 8,   # More heads (head_dim = 64/8 = 8)
    "intermediate_size": 4 * 64,# Standard FFN scaling
    "hidden_dropout_prob": 0.1, # Mild dropout for regularization
    "attention_probs_dropout_prob": 0.1,
    "initializer_range": 0.02,
    "image_size": 64,
    "num_classes": 10,
    "num_channels": 3,
    "qkv_bias": True,           # Keep bias for now (can prune later)
}
# embedding = model.Embeddings(config)
# x = embedding(batch[0])
# single = model.AttentionHead(1000, 10, 0.1)
# single(x)
# multi = model.MultiHeadAttention(config)
# multi(x, output_attentions = True)
# encoder = model.Encoder(config)
# encoder(x)
# input = torch.ones(config["hidden_size"])
# bad_input = torch.ones(config["hidden_size"]) +100000
vit = model.ViT(config)

modules.model


In [55]:
importlib.reload(utils)

computation_graph = utils.ComputationalGraph(vit)
print(computation_graph.nodes.keys())

dict_keys(['embedding', 'encoder.blocks.0.mlp.final_output', 'encoder.blocks.0.attention.heads.0.final_output', 'encoder.blocks.0.attention.heads.1.final_output', 'encoder.blocks.0.attention.heads.2.final_output', 'encoder.blocks.0.attention.heads.3.final_output', 'encoder.blocks.0.attention.heads.4.final_output', 'encoder.blocks.0.attention.heads.5.final_output', 'encoder.blocks.0.attention.heads.6.final_output', 'encoder.blocks.0.attention.heads.7.final_output', 'encoder.blocks.1.mlp.final_output', 'encoder.blocks.1.attention.heads.0.final_output', 'encoder.blocks.1.attention.heads.1.final_output', 'encoder.blocks.1.attention.heads.2.final_output', 'encoder.blocks.1.attention.heads.3.final_output', 'encoder.blocks.1.attention.heads.4.final_output', 'encoder.blocks.1.attention.heads.5.final_output', 'encoder.blocks.1.attention.heads.6.final_output', 'encoder.blocks.1.attention.heads.7.final_output', 'encoder.blocks.2.mlp.final_output', 'encoder.blocks.2.attention.heads.0.final_output'

In [52]:
len(computation_graph.edges)

1317

In [9]:
importlib.reload(acdc)

with acdc.SaveActivations(list(computation_graph.nodes.values())) as ctx:
    vit(batch[0])
    activations = ctx.get_activations()
activations

{'embedding.final_output': tensor([[[ 0.0065, -0.0174, -0.0183,  ..., -0.0457, -0.0050,  0.0233],
          [ 0.2335, -0.2645,  0.0179,  ..., -0.2675, -0.5039,  0.2646],
          [ 0.0192, -0.2172,  0.2180,  ..., -0.1022, -0.3378,  0.2323],
          ...,
          [-0.0906, -0.2557,  0.4264,  ..., -0.2720, -0.0510,  0.0466],
          [ 0.0938, -0.2093,  0.5405,  ..., -0.1086, -0.2870,  0.0099],
          [-0.1115,  0.3952,  0.2975,  ...,  0.0880, -0.3537,  0.2548]],
 
         [[ 0.0065, -0.0174, -0.0183,  ..., -0.0457, -0.0050,  0.0233],
          [-0.1569,  0.2763, -0.2673,  ...,  0.2274,  0.3379,  0.0398],
          [-0.4459,  0.2735, -0.1393,  ...,  0.0847,  0.3884, -0.2525],
          ...,
          [ 0.3496, -0.3625, -0.1129,  ..., -0.2534, -0.5272,  0.2232],
          [ 0.0662, -0.4209,  0.2272,  ...,  0.1860, -0.1374,  0.3052],
          [ 0.1646, -0.5048, -0.0607,  ..., -0.1067, -0.3631,  0.0732]]],
        grad_fn=<AddBackward0>),
 'encoder.blocks.0.mlp.final_output': tens

In [47]:
importlib.reload(dataset)
animal_dataset, coarse_labels = dataset.load_animal_dataset("train")
animal_dataset[0]
data = dataset.load("valid", tiny=True)
data[0]


Loading animal dataset from /home/lexyo/Dev/cv-proj2/notebooks/../data/animal_train.pkl
Loading dataset from /home/lexyo/Dev/cv-proj2/notebooks/../data/valid.pkl


(<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=64x64>, 0)

In [48]:
importlib.reload(dataset)
small_animal_dataset, coarse_labels = dataset.load_animal_dataset("train", tiny=True, start=0, stop=4)
# matching_dataset = dataset.ContrastiveWrapper(small_animal_dataset, coarse_labels)
small_animal_dataset[0]


Loading animal dataset from /home/lexyo/Dev/cv-proj2/notebooks/../data/animal_train.pkl


(<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=64x64>, 0)

In [38]:
matching_dataset[1]

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=64x64 at 0x7F98D6545610>, 'label': 0}


ValueError: Label label at index 1 not found in coarse_labels mapping.

In [None]:
importlib.reload(acdc)
with acdc.ReplaceActivations(m.l2, activations["dense2"]):
    with acdc.SaveActivations([(m.l1, "dense1"), (m.l2, "dense2")]) as ctx:
        print(m(torch.Tensor([1])))
        activations1 = ctx.get_activations()
    
print(activations1)

In [26]:
importlib.reload(dataset)
small_animal_dataset, coarse_labels = dataset.load_animal_dataset("train", tiny=True, start=0, stop=4)
counting = {}
for sample in small_animal_dataset:
    if sample["label"] not in counting: counting[sample["label"]] = 1
    else: counting[sample["label"]]+=1
print(counting)

Loading animal dataset from /home/lexyo/Dev/cv-proj2/notebooks/../data/animal_train.pkl
{0: 4, 1: 4, 2: 4, 3: 4, 4: 4, 5: 4, 6: 4, 7: 4, 8: 4, 9: 4, 10: 4, 11: 4, 12: 4, 13: 4, 14: 4, 15: 4, 16: 4, 17: 4, 18: 4, 19: 4, 20: 4, 21: 4, 22: 4, 23: 4, 24: 4, 25: 4, 26: 4, 27: 4, 28: 4, 29: 4, 30: 4, 31: 4, 32: 4, 33: 4, 34: 4, 35: 4, 36: 4, 37: 4, 38: 4, 39: 4, 40: 4, 41: 4, 42: 4, 43: 4, 44: 4, 45: 4, 46: 4, 47: 4, 48: 4, 49: 4, 50: 4, 51: 4, 52: 4, 53: 4, 170: 4, 184: 4, 187: 4, 190: 4}
