A notebook for testing zigzag and Hilbert patch embeddings on pre-trained Vision Transformers (ViTs).

In [1]:
# Imports
import math
import os
import sys
import types

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset, Subset
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import StratifiedShuffleSplit

from fastai.basics import *
from timm.loss import SoftTargetCrossEntropy

from torchvision.models.vision_transformer import vit_b_16, ViT_B_16_Weights, VisionTransformer

sys.path.append('..')

from src.tokenizers.hilbert_embedding import HilbertEmbedding
from src.tokenizers.random_embedding import RandomEmbedding
from src.tokenizers.zigzag_embedding import ZigzagEmbedding
# from src.models.vit import VisionTransformer
from src.models.altvit import SimpleViT, HilbertViT
from src.training.scheduler import WarmupCosineScheduler
from src.training.train import train_with_mixup, evaluate

# Set global variables
# Let cuDNN benchmark convolution algorithms for the fixed input size. This is faster,
# but the selected kernels are not guaranteed to be bit-for-bit deterministic.
torch.backends.cudnn.benchmark = True

# Seed the global torch and numpy RNGs so that weight init, shuffling and augmentation
# are repeatable across "Restart & Run All".
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)



Load the pre-trained ViT-B/16 model and adapt it to Hilbert-ordered patch tokens (the datasets are loaded in the next section).

In [2]:
def hilbert_curve(order, size=1.0):
    """
    Return the cell centres of a Hilbert curve, in the order the curve visits them.

    The square [0, size] x [0, size] is split into a 2**order x 2**order grid and
    the curve passes through the centre of every cell exactly once.

    Args:
        order (int): Recursion depth; the result holds 4**order points
            (a single centre point when order <= 0).
        size (float): Side length of the square covered by the curve.

    Returns:
        List[Tuple[float, float]]: The (x, y) points in visiting order.
    """
    visited = []

    def _walk(ox, oy, ax, ay, bx, by, depth):
        # (ox, oy) is the current cell's corner; (ax, ay) and (bx, by) span its two sides.
        if depth <= 0:
            visited.append((ox + (ax + bx) / 2, oy + (ay + by) / 2))
            return
        hax, hay = ax / 2, ay / 2
        hbx, hby = bx / 2, by / 2
        nxt = depth - 1
        # Four quadrants; the first and last are reflected so consecutive cells stay adjacent.
        _walk(ox, oy, hbx, hby, hax, hay, nxt)
        _walk(ox + hax, oy + hay, hax, hay, hbx, hby, nxt)
        _walk(ox + hax + hbx, oy + hay + hby, hax, hay, hbx, hby, nxt)
        _walk(ox + hax + bx, oy + hay + by, -bx / 2, -by / 2, -ax / 2, -ay / 2, nxt)

    _walk(0, 0, size, 0, 0, size, order)
    return visited

def resize_positional_embeddings(model, new_size):
    """Bilinearly resize a ViT's learned position embeddings to a new square patch grid.

    ``model.encoder.pos_embedding`` is expected to be ``[1, 1 + G*G, D]``: one
    class-token slot followed by a row-major G x G patch grid. It is replaced by a
    new ``nn.Parameter`` of shape ``[1, 1 + new_size*new_size, D]``. The class-token
    slot is copied unchanged and the patch grid is bilinearly interpolated.

    Args:
        model: object exposing ``encoder.pos_embedding`` (e.g. torchvision's VisionTransformer).
        new_size (int): side length of the target patch grid.

    Raises:
        ValueError: if the number of patch slots is not a perfect square.
    """
    old_posemb = model.encoder.pos_embedding  # e.g. [1, 197, D] for ViT-B/16 at 224px

    # Resizing is a one-off weight surgery; keep it out of the autograd graph.
    with torch.no_grad():
        cls_token = old_posemb[:, :1, :]
        grid = old_posemb[:, 1:, :]  # [1, G*G, D]

        num_patches = grid.shape[1]
        side = math.isqrt(num_patches)
        if side * side != num_patches:
            raise ValueError(
                f"pos_embedding has {num_patches} patch slots, which is not a square grid"
            )

        grid = grid.reshape(1, side, side, -1).permute(0, 3, 1, 2)  # [1, D, G, G]
        grid = F.interpolate(grid, size=(new_size, new_size), mode='bilinear', align_corners=False)
        grid = grid.permute(0, 2, 3, 1).reshape(1, new_size * new_size, -1)
        new_posemb = torch.cat([cls_token, grid], dim=1)

    # A fresh Parameter object: optimizers created before this call still hold the old one.
    model.encoder.pos_embedding = torch.nn.Parameter(
        new_posemb, requires_grad=old_posemb.requires_grad
    )

def my_forward(self, x):
    """Hilbert-ordered replacement for ``VisionTransformer.forward``.

    Patchifies ``x`` with ``self.conv_proj`` and flattens the patch grid in row-major
    order. It then permutes the patch tokens by ``self.hilbert_indices``, prepends the
    class token and adds ``self.pos_embed``. Finally it runs ``self.encoder`` and
    feeds the class-token output to ``self.heads``.

    Args:
        x: image batch of shape [B, C, H, W].

    Returns:
        ``self.heads`` applied to the encoder output at the class-token position.

    Raises:
        ValueError: if the number of patches produced from ``x`` differs from the
            length of ``self.hilbert_indices`` (e.g. an input at a different resolution
            than the indices were built for). Without this check, extra patches would be
            silently dropped, or the gather would index out of range.
    """
    B = x.shape[0]

    x = self.conv_proj(x)                      # [B, D, H/p, W/p]
    x = x.flatten(2).transpose(1, 2)           # [B, N, D], row-major patch order
    N = x.shape[1]
    if N != self.hilbert_indices.numel():
        raise ValueError(
            f"input yields {N} patches but hilbert_indices has "
            f"{self.hilbert_indices.numel()} entries; resize inputs to the model's image size"
        )
    x = x[:, self.hilbert_indices, :]          # reorder patches along the Hilbert curve

    cls_tok = self.class_token.expand(B, -1, -1)  # [B, 1, D]
    x = torch.cat([cls_tok, x], dim=1)            # [B, N+1, D]

    x = x + self.pos_embed.unsqueeze(0)           # add the fixed Hilbert PE

    # NOTE(review): torchvision's Encoder.forward adds its own learned
    # self.encoder.pos_embedding before the layers — verify. If so, tokens receive both
    # embeddings, and the learned one is laid out in row-major patch order, not the
    # Hilbert order used here.
    x = self.encoder(x)
    x = x[:, 0]                                   # class-token output
    return self.heads(x)

# Load model with pretrained weights
weights = ViT_B_16_Weights.DEFAULT
model = vit_b_16(weights=weights)

# Experiment settings stored on the model; only image_size and patch_size are read in this cell.
model.image_size = 128
model.patch_size = 16
model.num_classes = 257
# NOTE(review): conv_proj still expects 3-channel input; this attribute is not read anywhere in the visible notebook.
model.channels = 1
# NOTE(review): not used by the training cell, which builds its own optimizer. It is also created
# before the head and pos_embedding are replaced, so it references the old parameters.
model.optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=5e-5)
model.forward = types.MethodType(my_forward, model)

grid_size = model.image_size // model.patch_size
# A Hilbert curve of integer order covers exactly a 2**order x 2**order grid. Any other grid
# size would leave some patches out of hilbert_indices.
if grid_size < 1 or grid_size & (grid_size - 1):
    raise ValueError(
        f"Hilbert ordering needs a power-of-two patch grid, got {grid_size}x{grid_size}"
    )
order = grid_size.bit_length() - 1
points = hilbert_curve(order)  # cell centres in [0, 1) x [0, 1), in curve order
# Row-major flat index of each visited patch: x picks the row (height axis), y the column.
flat_idxs = [int(x * grid_size) * grid_size + int(y * grid_size) for x, y in points]
assert sorted(flat_idxs) == list(range(grid_size * grid_size)), \
    "Hilbert order must visit every patch exactly once"
model.register_buffer("hilbert_indices", torch.tensor(flat_idxs, dtype=torch.long))

def build_hilbert_pe(indices, dim, T=4, h_param=3.0):
    """Build a fixed sinusoidal position encoding for a sequence of patch indices.

    Let pos = indices[k], i range over [0, dim/2), n = len(indices) and N = isqrt(n).
    N is the side of the patch grid when n is a perfect square. Then
        arg = (2*i * N**2 * pos * 2*pi) / (T * n * dim) + (h_param * 2*i * pos * 2*pi) / dim
    and the encoding is concat(sin(arg), cos(arg)) along the feature axis.

    Args:
        indices: 1-D integer tensor of positions. This notebook passes the row-major
            patch indices in Hilbert order.
        dim (int): embedding width; must be even (half sine, half cosine).
        T (float): divisor of the first (grid-scaled) frequency term.
        h_param (float): multiplier of the second frequency term.

    Returns:
        float32 tensor of shape [n, dim] on ``indices.device``.

    Raises:
        ValueError: if ``dim`` is odd. Otherwise the result would be [n, dim-1]
            and fail later when concatenated with a [1, dim] class-token row.
    """
    if dim % 2:
        raise ValueError(f"dim must be even (sin/cos halves), got {dim}")
    n = indices.numel()
    N = math.isqrt(n)
    pos = indices.to(torch.float32).unsqueeze(1)   # [n, 1]
    # Build on the same device as `indices` so the broadcasted product below works on GPU too.
    i_ar = torch.arange(dim // 2, dtype=torch.float32, device=indices.device).unsqueeze(0)  # [1, dim/2]
    two_pi = 2 * math.pi
    scale = (2 * i_ar * N**2 * pos * two_pi) / (T * n * dim)
    phase = (h_param * 2 * i_ar * pos * two_pi) / dim
    arg = scale + phase                             # [n, dim/2]
    pe = torch.cat([torch.sin(arg), torch.cos(arg)], dim=1)
    return pe  # [n, dim]

# Fixed sinusoidal PE for the Hilbert-ordered patch tokens (one row per token).
pe = build_hilbert_pe(model.hilbert_indices, model.hidden_dim)
# All-zero row for the class token, which my_forward places at sequence position 0.
cls_pe = torch.zeros(1, model.hidden_dim, device=pe.device, dtype=pe.dtype)
model.register_buffer("pos_embed", torch.cat([cls_pe, pe], dim=0))  # [N+1, D]

# Zigzag tokenizer experiment (disabled):
# old_conv = model.conv_proj
# model.conv_proj = ZigzagEmbedding(
#     img_size=model.image_size,
#     patch_size=model.patch_size,
#     in_channels=3,
#     embed_dim=model.hidden_dim
# )

# with torch.no_grad():
#     model.conv_proj.proj.weight.copy_(old_conv.weight)
#     model.conv_proj.proj.bias.copy_(old_conv.bias)

# Resize the learned pos embeddings to the model's patch grid
# (128x128 input with 16x16 patches = 8x8 = 64 patches).
resize_positional_embeddings(model, new_size=grid_size)
# NOTE(review): my_forward feeds Hilbert-ordered tokens to model.encoder. If the encoder adds
# this learned pos_embedding itself (torchvision's Encoder does — verify), sequence slot i
# receives the embedding of row-major patch i rather than of patch hilbert_indices[i].
assert model.encoder.pos_embedding.shape[1] == model.pos_embed.shape[0], \
    "learned and Hilbert position embeddings must cover the same number of tokens"

# Replace the classifier head with one sized for the target label space.
model.heads.head = torch.nn.Linear(model.heads.head.in_features, model.num_classes)

print(model)

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a

Load the data

In [3]:
def _train_val_transforms(image_size, to_rgb=False):
    """Build the (train, val) transform pair shared by the Caltech-256 and Imagenette pipelines."""
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
    # convert("RGB") guarantees 3 channels (presumably some Caltech-256 images are not RGB).
    prefix = [transforms.Lambda(lambda img: img.convert("RGB"))] if to_rgb else []

    transform_train = transforms.Compose(prefix + [
        transforms.RandomResizedCrop(image_size, scale=(0.08, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
        transforms.RandomErasing(p=0.25, scale=(0.02, 0.33), ratio=(0.3, 3.3), value='random')
    ])

    transform_val = transforms.Compose(prefix + [
        transforms.Resize(2 * image_size),  # 256 for the default 128px crop
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        normalize
    ])
    return transform_train, transform_val


def _load_cifar10(image_size):
    """CIFAR-10 official train/test splits, resized to image_size."""
    transform = transforms.Compose([
        # Native CIFAR-10 images are 32x32, i.e. only 2x2 = 4 patches at patch size 16,
        # while the model's hilbert_indices cover an image_size grid (64 patches at 128px).
        transforms.Resize(image_size),
        transforms.ToTensor(),
        # NOTE(review): unlike the other datasets this uses (0.5, 0.5) rather than ImageNet statistics.
        transforms.Normalize((0.5,), (0.5,))
    ])
    train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    return train_set, test_set


def _load_caltech256(image_size):
    """Caltech-256 with a stratified 80/20 train/val split (random_state=42)."""
    transform_train, transform_val = _train_val_transforms(image_size, to_rgb=True)
    root = "./data/caltech256"

    # Download once without transforms, only to read the integer labels (.y) for the split.
    raw_dataset = datasets.Caltech256(root=root, download=True, transform=None)
    labels = np.array(raw_dataset.y)

    sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
    train_idx, val_idx = next(sss.split(np.zeros(len(labels)), labels))

    # Two views of the same files that differ only in their transform.
    dataset_train = datasets.Caltech256(root=root, download=False, transform=transform_train)
    dataset_val = datasets.Caltech256(root=root, download=False, transform=transform_val)
    return Subset(dataset_train, train_idx), Subset(dataset_val, val_idx)


def _load_imagenette(image_size):
    """Imagenette 'train'/'val' folders from fastai's URLs.IMAGENETTE_320 archive."""
    path = untar_data(URLs.IMAGENETTE_320)
    transform_train, transform_val = _train_val_transforms(image_size)
    train_dataset = datasets.ImageFolder(os.path.join(path, 'train'), transform=transform_train)
    val_dataset = datasets.ImageFolder(os.path.join(path, 'val'), transform=transform_val)
    return train_dataset, val_dataset


def load_data(dataset_name, image_size=128):
    """Load and return PyTorch-ready (train, eval) datasets.

    Args:
        dataset_name (str): one of 'cifar10', 'caltech256', 'imagenette'.
        image_size (int): side length in pixels of the square images every split yields.
            It must match the resolution the model's patch ordering was built for
            (128 in this notebook).

    Returns:
        (train_set, eval_set):
            - 'cifar10': the official train/test splits.
            - 'caltech256': a stratified 80/20 split wrapped in Subsets.
            - 'imagenette': the train/val image folders.

    Raises:
        ValueError: if dataset_name is not supported.
    """
    loaders = {
        'cifar10': _load_cifar10,
        'caltech256': _load_caltech256,
        'imagenette': _load_imagenette,
    }
    if dataset_name not in loaders:
        raise ValueError(f"Dataset '{dataset_name}' not supported.")
    return loaders[dataset_name](image_size)

def _make_loaders(dataset_name, batch_size):
    """Load ``dataset_name`` and wrap its (train, eval) splits in DataLoaders.

    Returns (train_set, eval_set, train_loader, eval_loader); only the training
    loader shuffles.
    """
    train_split, eval_split = load_data(dataset_name)
    shuffled = DataLoader(train_split, batch_size=batch_size, shuffle=True)
    ordered = DataLoader(eval_split, batch_size=batch_size, shuffle=False)
    return train_split, eval_split, shuffled, ordered


# Caltech-256 (stratified 80/20 split): first training phase.
train_set, test_set, train_loader, test_loader = _make_loaders('caltech256', 32)
# CIFAR-10: second training phase.
train_set2, test_set2, train_loader2, test_loader2 = _make_loaders('cifar10', 64)

Train the model: fine-tune on Caltech-256 first, then continue training on CIFAR-10.

In [4]:
epochs = 50

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Presumably train_with_mixup produces soft (mixed) targets, hence the soft-target loss.
train_criterion = SoftTargetCrossEntropy()
test_criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=3e-4)

pre_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=epochs, eta_min=1e-6
)

# Phase 1: fine-tune on Caltech-256.
for epoch in range(epochs):
    train_loss, train_acc = train_with_mixup(model, train_loader, train_criterion, optimizer, pre_scheduler, device)
    test_loss, test_acc = evaluate(model, test_loader, test_criterion, device)
    print(
        f"Epoch {epoch + 1}: "
        f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
        f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}"
    )

    # Save checkpoint every 5 epochs (including the scheduler, so training can be resumed exactly).
    if (epoch + 1) % 5 == 0:
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': pre_scheduler.state_dict(),
            'train_loss': train_loss,
            'train_acc': train_acc,
            'test_loss': test_loss,
            'test_acc': test_acc
        }
        torch.save(checkpoint, f'checkpoint_epoch_{epoch + 1}.pth')
        print(f"Checkpoint saved for epoch {epoch + 1}")

# Phase 2: continue on CIFAR-10 with a fresh cosine schedule.
# CosineAnnealingLR derives each new lr recursively from the optimizer's *current* lr.
# A phase-2 scheduler built up front, alongside pre_scheduler, would therefore continue
# from wherever phase 1 left the lr. That is eta_min if train_with_mixup steps once per
# epoch, and the lr would stay there. Restore each group's starting lr first, then
# create the scheduler.
for group in optimizer.param_groups:
    group['lr'] = group.get('initial_lr', group['lr'])
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=epochs * 3, eta_min=1e-6
)

for epoch in range(epochs * 3):
    train_loss, train_acc = train_with_mixup(model, train_loader2, train_criterion, optimizer, scheduler, device)
    test_loss, test_acc = evaluate(model, test_loader2, test_criterion, device)
    print(
        f"Epoch {epoch + 1}: "
        f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
        f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}"
    )

                                                          

KeyboardInterrupt: 