In [1]:
import torch
import torch.nn as nn
from functools import partial

from timm.models import create_model
import torch.nn.functional as F
from torchvision.transforms import Compose, Resize, CenterCrop, Normalize, ToTensor
from model.deit_reg.models_v2 import deit_small_patch16_LS, Mlp
from model.deit_reg import models_v2

from PIL import Image
from torchvision import datasets, transforms
from torchvision.datasets import Imagenette
from torch.utils.data import Dataset, DataLoader
import numpy as np
import math
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

import os

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cpu


In [2]:
# @title get image paths
def get_paths_in_subfolders(folder_path: str) -> list[list[str]]:
  """Given a folder path, returns a list of all image paths in the subfolders.

  Args:
    folder_path: The path to the folder.

  Returns:
    A nested list of all image paths in the subfolders.
  """

  files : list[list[str]] = []
  # Iterate through the subfolders in the given folder
  for subfolder in os.listdir(folder_path):
    subfolder_path = os.path.join(folder_path, subfolder)
    folder : list[str] = []
    if os.path.isdir(subfolder_path):
      for file_ in os.listdir(subfolder_path):
        file_path = os.path.join(subfolder_path, file_)
        folder.append(file_path)

    files.append(folder)

  return files
     

In [3]:
def preprocess(x: Image.Image | list[Image.Image] | torch.Tensor,
               size: tuple[int, int] | int = (224, 224)):

  if isinstance(size, int):
    if not isinstance(x, Image.Image):
      raise ValueError("size must be a tuple for sequence of images")

    width, height = x.size
    size = ((width // size) * size, (height // size) * size)


  x = transforms.Resize(size)(x)
  x = transforms.ToTensor()(x)
  x = transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN,
                           std=IMAGENET_DEFAULT_STD)(x)

  assert isinstance(x, torch.Tensor)
  if len(x.shape) == 3:
    x = x.unsqueeze(0)

  return x

In [4]:
def to_tensor(img):
    transform_fn = Compose([Resize(249, 3), CenterCrop(224), ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    return transform_fn(img)

In [5]:
checkpoint =torch.load(f'checkpoints/source_checkpoints/checkpoint799.pth', map_location=torch.device(device))
print('epoch', checkpoint['epoch'])

model_deit = deit_small_patch16_LS()
model_deit.default_cfg = models_v2._cfg()

print(checkpoint['model'].keys())

model_deit.load_state_dict(checkpoint["model"])
model_deit.to(device)

  checkpoint =torch.load(f'checkpoints/source_checkpoints/checkpoint799.pth', map_location=torch.device(device))


epoch 799
odict_keys(['cls_token', 'pos_embed', 'patch_embed.proj.weight', 'patch_embed.proj.bias', 'blocks.0.gamma_1', 'blocks.0.gamma_2', 'blocks.0.norm1.weight', 'blocks.0.norm1.bias', 'blocks.0.attn.qkv.weight', 'blocks.0.attn.qkv.bias', 'blocks.0.attn.proj.weight', 'blocks.0.attn.proj.bias', 'blocks.0.norm2.weight', 'blocks.0.norm2.bias', 'blocks.0.mlp.fc1.weight', 'blocks.0.mlp.fc1.bias', 'blocks.0.mlp.fc2.weight', 'blocks.0.mlp.fc2.bias', 'blocks.1.gamma_1', 'blocks.1.gamma_2', 'blocks.1.norm1.weight', 'blocks.1.norm1.bias', 'blocks.1.attn.qkv.weight', 'blocks.1.attn.qkv.bias', 'blocks.1.attn.proj.weight', 'blocks.1.attn.proj.bias', 'blocks.1.norm2.weight', 'blocks.1.norm2.bias', 'blocks.1.mlp.fc1.weight', 'blocks.1.mlp.fc1.bias', 'blocks.1.mlp.fc2.weight', 'blocks.1.mlp.fc2.bias', 'blocks.2.gamma_1', 'blocks.2.gamma_2', 'blocks.2.norm1.weight', 'blocks.2.norm1.bias', 'blocks.2.attn.qkv.weight', 'blocks.2.attn.qkv.bias', 'blocks.2.attn.proj.weight', 'blocks.2.attn.proj.bias', 'b

vit_models(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (blocks): ModuleList(
    (0-11): 12 x Layer_scale_init_Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
  (head): Linear(in_features=384, ou

In [6]:
# Define the preprocessing steps (resize, crop, convert to tensor, and normalize)
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

# Create a dataset from the directory structure. 
# ImageFolder automatically assigns labels based on subfolder names.
train_set = datasets.ImageFolder('datasets/imagenette2/train', transform=transform)
val_set = datasets.ImageFolder('datasets/imagenette2/val', transform=transform)


batch_size = 64

# Create a DataLoader to handle batching, shuffling, and parallel loading.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True, drop_last=True)
if torch.cuda.is_available():
    num_workers = 16
else:
    num_workers = 0
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True, drop_last=True)



In [7]:
model_deit.eval()

vit_models(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (blocks): ModuleList(
    (0-11): 12 x Layer_scale_init_Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
  (head): Linear(in_features=384, ou

In [12]:
#if the model is pretrained, this checkpoint can be used and the training can be skipped
recon_checkpoint =torch.load(f'checkpoints/recon_checkpoints/best_checkpoint.pth', map_location=torch.device(device))
print('epoch', recon_checkpoint['epoch'])

recon_model = Mlp(in_features=384, hidden_features=1536, out_features=3*16*16
        )
print(recon_checkpoint.keys())
recon_model.load_state_dict(recon_checkpoint['model_state_dict'])
recon_model.to(device)
print(recon_checkpoint)

epoch 3
dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'loss', 'val_loss', 'checkpoint_type'])
{'epoch': 3, 'model_state_dict': OrderedDict([('fc1.weight', tensor([[ 0.0064,  0.1183,  0.0193,  ..., -0.0141,  0.0582, -0.0209],
        [-0.0221,  0.0800,  0.0459,  ...,  0.1286, -0.0131,  0.0527],
        [ 0.0017,  0.1089,  0.0709,  ...,  0.0814,  0.0239, -0.0415],
        ...,
        [ 0.0318,  0.0403,  0.0830,  ...,  0.1119,  0.0247, -0.0090],
        [-0.0398,  0.0951,  0.0655,  ..., -0.0409, -0.0029,  0.0260],
        [ 0.0350,  0.0948,  0.2028,  ...,  0.0033,  0.0199,  0.0383]])), ('fc1.bias', tensor([-0.1863, -0.2569, -0.3461,  ..., -0.2510, -0.2404, -0.2555])), ('fc2.weight', tensor([[ 0.0086,  0.0087,  0.0046,  ..., -0.0120,  0.0021, -0.0086],
        [ 0.0022,  0.0096, -0.0028,  ..., -0.0132, -0.0046, -0.0086],
        [ 0.0016,  0.0090, -0.0073,  ..., -0.0124, -0.0052, -0.0108],
        ...,
        [ 0.0011, -0.0031,  0.0041,  ...,  0.0019, -0.0067, -0.0120],

  recon_checkpoint =torch.load(f'checkpoints/recon_checkpoints/best_checkpoint.pth', map_location=torch.device(device))


# Training model

In [10]:
def l2_error(model, transformer, data):
    model.eval()
    errors = []

    for images, _ in tqdm(data, desc="Calculating L2 error"):
        images = images.to(device)
        B = images.shape[0]
        data_inputs = images.reshape(B, 196, 768)

        x = transformer.patch_embed(images)
        cls_tokens = transformer.cls_token.expand(B, -1, -1)
        x = x + transformer.pos_embed
        x = torch.cat((cls_tokens, x), dim=1)
        for blk in transformer.blocks:
            x = blk(x)

        preds = model(x[:, 1:, :])
        preds = preds.squeeze(dim=1)

        error = torch.norm(preds - data_inputs, dim=1).mean()
        errors.append(error.item())

    return errors

In [11]:
def train_model(model, transformer, optimizer, train_data, val_data, loss_module, num_epochs=100, patience=10):
    # check if checkpoint exists
    checkpoint_loaded = False
    if os.path.exists('checkpoints/recon_checkpoints/checkpoint.pth'):
        checkpoint = torch.load('checkpoints/recon_checkpoints/checkpoint.pth')
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        checkpoint_loaded = True
        print('Loaded checkpoint')
        if checkpoint['epoch']+1 >= num_epochs:
            print('Model already trained for', num_epochs, 'epochs.')
            return model
        else:
            print('Model has been trained for', checkpoint['epoch']+1, 'epochs.')
            num_epochs -= checkpoint['epoch']+1

    # Set the model to training mode.
    model.train()

    losses = []
    val_losses = []
    if checkpoint_loaded:
        best_loss = checkpoint['val_loss']
    else:
        best_loss = float('inf')
    best_model_params = None

    no_improvement = 0
    for epoch in tqdm(range(num_epochs)):
        if checkpoint_loaded and epoch == 0:
            epoch += checkpoint['epoch']+1
        
        for images, _ in train_data:  # New DataLoader returns (images, labels); we ignore labels if not needed.
            images = images.to(device)
            B = images.shape[0]  # Actual batch size (may vary for the last batch).

            # Optionally, reshape images for use as targets in the loss.
            # (Using B rather than a fixed batch_size ensures compatibility if the last batch is smaller.)
            data_inputs = images.reshape(B, 196, 768)

            # Pass images through the transformer pipeline.
            x = transformer.patch_embed(images)
            cls_tokens = transformer.cls_token.expand(B, -1, -1)
            x = x + transformer.pos_embed
            x = torch.cat((cls_tokens, x), dim=1)
            for blk in transformer.blocks:
                x = blk(x)

            # Forward pass through the model (skipping the class token).
            preds = model(x[:, 1:, :])
            preds = preds.squeeze(dim=1)  # Change shape from [B, 1] to [B].

            # Compute the loss.
            loss = loss_module(preds, data_inputs).to(device)

            # Backpropagation steps.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        val_loss = l2_error(model, transformer, val_data)
        val_loss = np.mean(val_loss)

        print(epoch)
        checkpoint = {
            'epoch': epoch,  # current epoch number
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,  # optional: current loss value
            'val_loss': val_loss,  # optional: current validation loss
            'checkpoint_type': 'recon'  # identifier for the recon model
        }

        print(f"Epoch {epoch}: Loss = {loss.item()}, Validation Loss = {val_loss}")

        losses.append(loss.item())
        val_losses.append(val_loss)
        if val_loss < best_loss:
            no_improvement = 0
            best_loss = val_loss
            best_model_params = model.state_dict()  # Save best model parameters.

            torch.save(checkpoint, f'checkpoints/recon_checkpoints/best_checkpoint.pth')
            print('New best loss achieved.')
        else:
            no_improvement += 1

        torch.save(checkpoint, f'checkpoints/recon_checkpoints/checkpoint.pth')
    
        if no_improvement >= patience:
            print('No improvement for', patience, 'epochs. Stopping training.')
            break

    return best_model_params


In [None]:
optimizer = torch.optim.AdamW(recon_model.parameters())
loss = nn.MSELoss()
train_model(model=recon_model, transformer= model_deit, optimizer=optimizer, train_data=train_loader, val_data=val_loader, loss_module=loss, num_epochs=200, patience=5)

# clear cuda memory
torch.cuda.empty_cache()

In [None]:
# load best recon model
checkpoint = torch.load('checkpoints/recon_checkpoints/best_checkpoint.pth')
recon_model.load_state_dict(checkpoint['model_state_dict'])
recon_model.to(device)

errors = l2_error(recon_model, model_deit, val_loader)
print(f'Mean L2 error: {np.mean(errors)}')
print(f'Median L2 error: {np.median(errors)}')
print(f'Min L2 error: {np.min(errors)}')
print(f'Max L2 error: {np.max(errors)}')

# clear cuda memory
torch.cuda.empty_cache()

# comparing high and low norm patches

## getting norm threshold

In [9]:
def get_deit_final_patches(model, imgs):
    """
    For DeiT-III, replicate a forward pass to get final patch tokens (excluding the CLS token).
    """
    with torch.no_grad():
        x = model.patch_embed(imgs)
        B, N, D = x.shape

        # add cls token + pos embed
        cls_token = model.cls_token.expand(B, -1, -1)  # [B, 1, D]
        pos_embed = model.pos_embed[:, : (N+1), :]
        x = torch.cat([cls_token, x], dim=1)  # => [B, N+1, D]
        x = x[:, 1:, :] + pos_embed

        for blk in model.blocks:
            x = blk(x)

        x = model.norm(x)  # [B, N+1, D]
        x = x  # => [B, N, D]
    return x

In [10]:
def gather_global_threshold(
    final_extraction_func,
    model,
    loader,
    percentile=0.98
):
    all_norms = []
    for i, (imgs, _) in enumerate(loader):
        imgs = imgs.to(device)
        final_embs = final_extraction_func(model, imgs)  # [B, N, D]
        norms = torch.norm(final_embs, dim=-1)  # [B, N]
        all_norms.append(norms.flatten().cpu())
    all_norms = torch.cat(all_norms, dim=0)
    threshold = torch.quantile(all_norms, percentile).item()
    return threshold

#print(gather_global_threshold(get_deit_final_patches, model_deit, val_loader))

## gather patches with high norms

In [11]:
# To investigate the local information that high and low norm patches hold, they are separated and the loss on them is calculated and compared
high_norm_patches_error = 0
low_norm_patches_error = 0
i = 0
high_norm_patches_errors = []
low_norm_patches_errors = []
batch_size = val_loader.batch_size
threshold = gather_global_threshold(get_deit_final_patches, model_deit, val_loader)
print("Norm threshold:", threshold)

for imgs, _ in val_loader:
    final_embs = get_deit_final_patches(model_deit, imgs)
    norms = torch.norm(final_embs, dim=-1)  # [B, N]
    norms = (norms >= threshold)

    sample_images = imgs.permute(0, 2, 3, 1).reshape(batch_size, 196, 16, 16, 3)
    y = recon_model(final_embs).reshape(batch_size, 196, 16, 16, 3)
    highs = (norms==True)
    lows = (norms==False)

    high_norm_patches = y[highs]
    high_norm_patches_label = sample_images[highs]
    low_norm_patches = y[lows]
    low_norm_patches_label = sample_images[lows]

    high_norm_patches_error += F.mse_loss(high_norm_patches, high_norm_patches_label).detach().numpy()
    low_norm_patches_error += F.mse_loss(low_norm_patches, low_norm_patches_label).detach().numpy()


print("High norm error:", high_norm_patches_error)
print("Low norm error:", low_norm_patches_error)
print("Ratio:", high_norm_patches_error/low_norm_patches_error*100)

Norm threshold: 72.4803695678711
High norm error: 176.54341995716095
Low norm error: 120.02065622806549
ratio: 147.0941965370443
