In [None]:
import importlib
import sys

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torchvision.transforms import v2

sys.path.append('../src')
from modules import (
                    paths,
                    dataset,
                    model,
                    utils,
                    acdc,
                    train
                    )

In [None]:
# Per-channel normalisation statistics shared by the train and validation
# pipelines (the standard ImageNet mean/std values used throughout torchvision).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _to_float_image_steps():
    """Return fresh preprocessing steps common to both pipelines.

    The steps force three channels (some images are grayscale), wrap the
    input as a torchvision image tensor, and convert it to float32 with
    value scaling enabled.
    """
    return [
        v2.Lambda(lambda img: img.convert('RGB')),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
    ]


# Training pipeline: shared preprocessing, then random augmentation,
# normalisation, and finally random erasing applied with probability 0.25.
transform_train = v2.Compose(
    _to_float_image_steps()
    + [
        v2.RandomHorizontalFlip(),
        v2.RandAugment(),
        v2.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
        v2.RandomErasing(p=0.25),
    ]
)

# Validation pipeline: shared preprocessing and normalisation only, no augmentation.
transform_valid = v2.Compose(
    _to_float_image_steps()
    + [v2.Normalize(_NORM_MEAN, _NORM_STD)]
)

In [None]:
# Run configuration read by later cells.
toy = True        # True -> load the small "tiny" subsets instead of the full data
batch_size = 2    # batch size passed to the training DataLoader

In [None]:
# Reload the local `dataset` module so that edits to its source are picked up
# without restarting the kernel.
importlib.reload(dataset)

# `toy` selects between small subsets (fast iteration) and the full splits.
if toy:
    print("loading toy datasets")
    train_dataset = dataset.load("train", transform=transform_train, tiny=True, stop=6)
    val_dataset = dataset.load("valid", transform=transform_valid, tiny=True, stop=2)
else:
    print("loading full dataset")
    train_dataset = dataset.load("train", transform=transform_train)
    val_dataset = dataset.load("valid", transform=transform_valid)

# Print a short label followed by each dataset's string representation.
print("train:", train_dataset, sep="\n")
print("validation:", val_dataset, sep="\n")


In [None]:
# Worker / prefetch settings that are identical for both loaders.
_loader_kwargs = dict(
    num_workers=2,
    pin_memory=True,
    prefetch_factor=4,
    persistent_workers=True,
)

# Training loader: shuffled, uses the notebook-level `batch_size`.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, **_loader_kwargs)

# Validation loader: fixed order, fixed batch size of 3.
val_loader = DataLoader(val_dataset, batch_size=3, shuffle=False, **_loader_kwargs)

In [None]:
# Pull a single batch from the training loader; later cells feed `batch[0]`
# to the model.
_train_batches = iter(train_loader)
batch = next(_train_batches)


In [None]:
# Reload the local `model` module so source edits take effect in this kernel.
importlib.reload(model)

# Hyperparameters handed to model.ViT as a plain dict.
config = dict(
    patch_size=8,                      # kept small for fine-grained patches
    hidden_size=64,                    # increased from 48 (better representation)
    num_hidden_layers=6,               # deeper for pruning flexibility
    num_attention_heads=8,             # more heads (head_dim = 64 / 8 = 8)
    intermediate_size=4 * 64,          # standard FFN scaling (4 x hidden_size)
    hidden_dropout_prob=0.1,           # mild dropout for regularization
    attention_probs_dropout_prob=0.1,
    initializer_range=0.02,
    image_size=64,
    num_classes=10,
    num_channels=3,
    qkv_bias=True,                     # keep bias for now (can prune later)
)

vit = model.ViT(config)

In [None]:
# Build the optimisation components and hand them to the project's Trainer.
#
# NOTE(review): the optimizer / criterion / scheduler arguments were previously
# left blank, which made this cell a SyntaxError. The concrete choices and
# hyperparameters below are placeholders to be tuned; confirm that
# train.Trainer accepts these object types and how often it steps the
# scheduler (T_max below assumes one scheduler step per epoch).
num_epochs = 10  # placeholder horizon for the cosine schedule — TODO confirm

# Assumes `vit` exposes .parameters() like a torch.nn.Module — TODO confirm.
optimizer = AdamW(vit.parameters(), lr=1e-3, weight_decay=1e-2)
criterion = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

trainer = train.Trainer(
    model=vit,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    criterion=criterion,
    scheduler=scheduler,
)

In [None]:
# Reload the local `utils` module so source edits take effect in this kernel.
importlib.reload(utils)

# Build the project's computational-graph object for the ViT and print the
# keys of its `nodes` mapping.
computation_graph = utils.ComputationalGraph(vit)
_node_keys = computation_graph.nodes.keys()
print(_node_keys)

In [None]:
# Number of entries in the graph's `edges` collection (shown as the cell output).
_graph_edges = computation_graph.edges
len(_graph_edges)

In [None]:
# Reload the local `acdc` module so source edits take effect in this kernel.
importlib.reload(acdc)

# Run one forward pass while acdc.SaveActivations is active on every graph
# node, then collect what it recorded. The model output itself is discarded.
_nodes_to_record = list(computation_graph.nodes.values())
_model_input = batch[0]  # presumably the image tensor of the batch — TODO confirm

with acdc.SaveActivations(_nodes_to_record) as ctx:
    vit(_model_input)
    activations = ctx.get_activations()

activations

In [None]:
# Reload the local `dataset` module so source edits take effect in this kernel.
importlib.reload(dataset)

# Load the animal "train" split together with its coarse labels and index the
# first sample. Only the last expression of a cell is rendered, so this lookup
# just exercises indexing and produces no output.
_animal_split = dataset.load_animal_dataset("train")
animal_dataset, coarse_labels = _animal_split
animal_dataset[0]

# Load the tiny validation subset (no transform given) and show its first sample.
data = dataset.load("valid", tiny=True)
data[0]


In [None]:
# Reload the local `dataset` module so source edits take effect in this kernel.
importlib.reload(dataset)

# Load a small slice (start=0, stop=4) of the tiny animal "train" split,
# along with its coarse labels.
_small_split = dataset.load_animal_dataset("train", tiny=True, start=0, stop=4)
small_animal_dataset, coarse_labels = _small_split

# Contrastive wrapping is currently disabled in this cell:
# matching_dataset = dataset.ContrastiveWrapper(small_animal_dataset, coarse_labels)

# Show the first sample as the cell output.
small_animal_dataset[0]


In [None]:
# `matching_dataset` was not defined by any active line in the notebook, so
# this cell raised NameError on a fresh kernel. Build it here from the small
# animal subset and its coarse labels before indexing.
# NOTE(review): confirm that dataset.ContrastiveWrapper takes (dataset, coarse_labels).
matching_dataset = dataset.ContrastiveWrapper(small_animal_dataset, coarse_labels)
matching_dataset[1]

In [None]:
# Reload the local `acdc` module so source edits take effect in this kernel.
importlib.reload(acdc)


class _ToyTwoLayerNet(nn.Module):
    """Minimal two-layer model exposing `l1` and `l2` for the hook smoke test.

    NOTE(review): the notebook referenced a model `m` with attributes `l1`
    and `l2` that was never defined in any cell. This class is a stand-in so
    the cell runs on a fresh kernel; replace it with the intended model.
    """

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(1, 1)
        self.l2 = nn.Linear(1, 1)

    def forward(self, x):
        return self.l2(self.l1(x))


m = _ToyTwoLayerNet()
_toy_hooks = [(m.l1, "dense1"), (m.l2, "dense2")]
_toy_input = torch.Tensor([1])

# Record baseline activations from the toy model first, so that the "dense2"
# entry exists. The ViT `activations` from the earlier cell are keyed by ViT
# graph nodes and are deliberately left untouched.
with acdc.SaveActivations(_toy_hooks) as ctx:
    m(_toy_input)
    toy_activations = ctx.get_activations()

# Re-run the model with the recorded "dense2" activation passed to
# acdc.ReplaceActivations for `m.l2`, recording both layers again.
with acdc.ReplaceActivations(m.l2, toy_activations["dense2"]):
    with acdc.SaveActivations(_toy_hooks) as ctx:
        print(m(_toy_input))
        activations1 = ctx.get_activations()

print(activations1)

In [None]:
# Reload the local `dataset` module so source edits take effect in this kernel.
importlib.reload(dataset)

# Load a small slice (start=0, stop=4) of the tiny animal "train" split.
small_animal_dataset, coarse_labels = dataset.load_animal_dataset("train", tiny=True, start=0, stop=4)

# Tally how many samples carry each value of the "label" field.
counting = {}
for sample in small_animal_dataset:
    label = sample["label"]
    counting[label] = counting.get(label, 0) + 1

print(counting)