In [1]:
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=512, patch_size=16, in_ch=3, emb_dim=768):
        super().__init__()
        self.proj = nn.Conv2d(
            in_channels=in_ch,
            out_channels=emb_dim,
            kernel_size=patch_size,
            stride=patch_size
        )
        self.num_patches = (img_size // patch_size) ** 2

    def forward(self, x):
        # x: (B, 3, 32, 32)
        x = self.proj(x)                    # → (B, emb_dim, H', W') = (B, 64, 8, 8)
        x = x.flatten(2)                    # → (B, 64, 64) — flatten H' and W'
        x = x.transpose(1, 2)               # → (B, 64, 64) → (B, num_patches, emb_dim)
        return x


In [2]:
class ViTWithPosition(nn.Module):
    def __init__(self, patch_embed, emb_dim=768, num_patches=1024):
        super().__init__()
        self.patch_embed = patch_embed  # instance of PatchEmbedding
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, emb_dim))  # +1 for CLS

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)                      # (B, 64, 64)
        cls = self.cls_token.expand(B, -1, -1)       # (B, 1, 64)
        x = torch.cat([cls, x], dim=1)               # (B, 65, 64)
        x = x + self.pos_embed                       # add positional encoding
        return x


In [13]:
import torch.nn.functional as F

class SketchSelfAttention(nn.Module):
    def __init__(self, dim, num_heads=1, sketch_dim=None, use_sketch=False,
                 train_mode=False, layer_idx=0, s_q_path=None, s_k_path=None):
        super(SketchSelfAttention, self).__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim if num_heads == 1 else dim // num_heads
        self.sketch_dim = sketch_dim
        self.sketch_head_dim = sketch_dim if num_heads == 1 else sketch_dim // num_heads
        self.use_sketch = use_sketch
        self.train_mode = train_mode

        self.Wq_list = nn.ModuleList()
        self.Wk_list = nn.ModuleList()
        self.Wv_list = nn.ModuleList()

        if use_sketch and train_mode:
            self.S_q_list = nn.ParameterList()
            self.S_k_list = nn.ParameterList()

        for h in range(num_heads):
            if use_sketch and not train_mode:
                # Inference: fused linear layers
                self.Wq_list.append(nn.Linear(self.head_dim, self.sketch_head_dim, bias=False))
                self.Wk_list.append(nn.Linear(self.head_dim, self.sketch_head_dim, bias=False))
            else:
                # Train: full projection before sketching
                self.Wq_list.append(nn.Linear(self.head_dim, self.head_dim, bias=False))
                self.Wk_list.append(nn.Linear(self.head_dim, self.head_dim, bias=False))

            self.Wv_list.append(nn.Linear(self.head_dim, self.head_dim, bias=False))

            if use_sketch and train_mode:
                # Learnable sketch matrices
                S_q = nn.Parameter(torch.empty(self.sketch_head_dim, self.head_dim))
                S_k = nn.Parameter(torch.empty(self.sketch_head_dim, self.head_dim))
                nn.init.kaiming_uniform_(S_q, a=np.sqrt(5))
                nn.init.kaiming_uniform_(S_k, a=np.sqrt(5))

                self.S_q_list.append(S_q)
                self.S_k_list.append(S_k)

                self.register_parameter(f"S_q_layer{layer_idx}_head{h}", S_q)
                self.register_parameter(f"S_k_layer{layer_idx}_head{h}", S_k)

        self.softmax = nn.Softmax(dim=-1)
        self.output_proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        split = torch.chunk(x, self.num_heads, dim=2)
        out_heads = []

        for h in range(self.num_heads):
            q = self.Wq_list[h](split[h])
            k = self.Wk_list[h](split[h])
            v = self.Wv_list[h](split[h])

            if self.use_sketch and self.train_mode:
                q = F.linear(q, self.S_q_list[h])
                k = F.linear(k, self.S_k_list[h])

            attn = self.softmax(torch.matmul(q, k.transpose(-2, -1)) / np.sqrt(q.shape[-1]))
            out = torch.matmul(attn, v)
            out_heads.append(out)

        out = torch.cat(out_heads, dim=-1)
        return self.output_proj(out)


In [4]:
class SimpleTransformerBlock(nn.Module):
    def __init__(self, dim, sketch_dim, num_heads, use_sketch,
                 train_mode, layer_idx,
                 wq_dir=None, wk_dir=None, s_q_path=None, s_k_path=None):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = SketchSelfAttention(dim=dim, sketch_dim=sketch_dim,
                                num_heads=num_heads,
                                use_sketch=use_sketch, train_mode=train_mode,
                                layer_idx=layer_idx,
                                s_q_path=s_q_path, s_k_path=s_k_path)

        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.GELU(),
            nn.Linear(dim * 2, dim)
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


In [5]:
class VisionTransformer(nn.Module):
    def __init__(self, img_size=512, patch_size=16, in_ch=3, emb_dim=768, depth=1, num_heads=1,
                 sketch_dim=256, num_classes=40, use_sketch=True, train_mode=True,
                 wq_dir=None, wk_dir=None, s_q_path=None, s_k_path=None):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_ch, emb_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, emb_dim))

        self.blocks = nn.Sequential(*[
            SimpleTransformerBlock(emb_dim, sketch_dim, num_heads, use_sketch, train_mode, layer_idx=l,
                                  wq_dir=wq_dir, wk_dir=wk_dir,
                                  s_q_path=s_q_path, s_k_path=s_k_path)
            for l in range(depth)
        ])

        self.norm = nn.LayerNorm(emb_dim)
        self.head = nn.Linear(emb_dim, num_classes)

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        cls = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls, x], dim=1)
        x = x + self.pos_embed
        x = self.blocks(x)
        x = self.norm(x[:, 0])
        return self.head(x)


In [6]:
import os
from collections import defaultdict, Counter
import numpy as np
from PIL import Image
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
import torch

def extract_category(class_name):
    return class_name.split("_")[0]  # e.g., 'apple_blackrot' → 'apple'

class PlantVillageBalancedDataset(Dataset):
    def __init__(self, root, selected_classes=None, transform=None, split="train"):
        self.root = root
        self.transform = transform
        self.split = split

        self.base_dataset = datasets.ImageFolder(root)
        self.classes = self.base_dataset.classes
        self.class_to_idx = self.base_dataset.class_to_idx

        if selected_classes is None:
            # Count images per class
            class_counts = Counter([self.classes[label] for _, label in self.base_dataset.samples])
            # Filter classes with enough images (≥ 650 total)
            eligible_classes = {cls: count for cls, count in class_counts.items() if count >= 650}

            # Group classes by category
            category_to_classes = defaultdict(list)
            for cls in eligible_classes:
                category = extract_category(cls)
                category_to_classes[category].append(cls)

            # Only keep categories with ≥ 2 classes
            final_classes = []
            for category, cls_list in category_to_classes.items():
                if len(cls_list) >= 2:
                    final_classes.extend(cls_list)
                if len(final_classes) >= 25:
                    break

            self.selected_classes = final_classes[:25]  #  top 25
            print(f"[INFO] Selected {len(self.selected_classes)} classes from grouped categories.")
        else:
            self.selected_classes = selected_classes

        self.selected_class_to_idx = {cls: i for i, cls in enumerate(self.selected_classes)}
        self.idx_remap = {self.class_to_idx[cls]: i for i, cls in enumerate(self.selected_classes)}

        self.image_paths, self.image_to_class = self._build_balanced_dataset()

    def _build_balanced_dataset(self):
        class_to_images = defaultdict(list)
        for img_path, label in self.base_dataset.samples:
            class_name = self.classes[label]
            if class_name in self.selected_classes:
                class_to_images[class_name].append(img_path)

        image_paths, image_to_class = [], {}
        for cls in self.selected_classes:
            imgs = class_to_images[cls]
            np.random.shuffle(imgs)

            if self.split == "train":
                selected = imgs[:max(0, len(imgs) - 150)][:500]
            elif self.split == "val":
                selected = imgs[-150:-50][:100]
            else:  # test
                selected = imgs[-50:][:50]

            for path in selected:
                image_paths.append(path)
                image_to_class[path] = self.class_to_idx[cls]

        return image_paths, image_to_class

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        img = Image.open(img_path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        label = self.idx_remap[self.image_to_class[img_path]]
        return img, torch.tensor(label)


transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
])


root_path = "/kaggle/input/plantvillage-dataset/color"


train_dataset = PlantVillageBalancedDataset(root=root_path, transform=transform, split="train")
selected_classes = train_dataset.selected_classes

val_dataset = PlantVillageBalancedDataset(root=root_path, transform=transform, split="val", selected_classes=selected_classes)
test_dataset = PlantVillageBalancedDataset(root=root_path, transform=transform, split="test", selected_classes=selected_classes)


def collate_fn(batch):
    images = torch.stack([item[0] for item in batch])
    labels = torch.tensor([item[1] for item in batch])
    return images, labels

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=2, collate_fn=collate_fn)


img, label = train_dataset[0]
print(f"Image shape: {img.shape}, Label: {label}")


[INFO] Selected 21 classes from grouped categories.
Image shape: torch.Size([3, 512, 512]), Label: 0


In [7]:
def verify_class_distribution(dataset, class_names):
    # Count occurrences of each class
    class_counts = torch.zeros(len(class_names))
    
    for _, label in DataLoader(dataset, batch_size=256, collate_fn=collate_fn):
        unique, counts = torch.unique(label, return_counts=True)
        class_counts[unique] += counts.float()
    
    print("\nClass distribution:")
    for i, (name, count) in enumerate(zip(class_names, class_counts)):
        print(f"{i:2d} {name:15s}: {int(count)} samples")

# Run verification
print("=== Training Set ===")
verify_class_distribution(train_dataset, selected_classes)
print("\n=== Validation Set ===")
verify_class_distribution(val_dataset, selected_classes)
print("\n=== Test Set ===")
verify_class_distribution(test_dataset, selected_classes)

=== Training Set ===

Class distribution:
 0 Cherry_(including_sour)___Powdery_mildew: 500 samples
 1 Cherry_(including_sour)___healthy: 500 samples
 2 Corn_(maize)___Common_rust_: 500 samples
 3 Corn_(maize)___Northern_Leaf_Blight: 500 samples
 4 Corn_(maize)___healthy: 500 samples
 5 Grape___Black_rot: 500 samples
 6 Grape___Esca_(Black_Measles): 500 samples
 7 Grape___Leaf_blight_(Isariopsis_Leaf_Spot): 500 samples
 8 Pepper,_bell___Bacterial_spot: 500 samples
 9 Pepper,_bell___healthy: 500 samples
10 Potato___Early_blight: 500 samples
11 Potato___Late_blight: 500 samples
12 Tomato___Bacterial_spot: 500 samples
13 Tomato___Early_blight: 500 samples
14 Tomato___Late_blight: 500 samples
15 Tomato___Leaf_Mold: 500 samples
16 Tomato___Septoria_leaf_spot: 500 samples
17 Tomato___Spider_mites Two-spotted_spider_mite: 500 samples
18 Tomato___Target_Spot: 500 samples
19 Tomato___Tomato_Yellow_Leaf_Curl_Virus: 500 samples
20 Tomato___healthy: 500 samples

=== Validation Set ===

Class distri

In [8]:
import torch
images, labels = next(iter(train_loader))
print("Image batch shape:", images.shape)
print("Image batch dtype:", images.dtype)
print("Image batch min value:", torch.min(images))
print("Image batch max value:", torch.max(images))
print("\nLabel batch shape:", labels.shape)
print("Label batch dtype:", labels.dtype)
print("Label batch min value:", torch.min(labels))
print("Label batch max value:", torch.max(labels))
print("Number of unique labels in the batch:", len(torch.unique(labels)))

Image batch shape: torch.Size([32, 3, 512, 512])
Image batch dtype: torch.float32
Image batch min value: tensor(0.)
Image batch max value: tensor(1.)

Label batch shape: torch.Size([32])
Label batch dtype: torch.int64
Label batch min value: tensor(0)
Label batch max value: tensor(20)
Number of unique labels in the batch: 20


In [9]:
def train_model(model, optimizer, criterion, dataloader, device, epochs=5):
    model.train()
    for epoch in range(epochs):
        correct, total = 0, 0
        loop = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
        for x, y in loop:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
            pred = out.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
            loop.set_postfix(loss=loss.item(), acc=correct/total)
    return correct / total

def evaluate(model, dataloader, device):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            pred = out.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    return correct / total

def time_inference(model, dataloader, device, repetitions=30):
    model.eval()
    start = time.time()
    with torch.no_grad():
        for i, (x, _) in enumerate(dataloader):
            if i >= repetitions: break
            x = x.to(device)
            _ = model(x)
    end = time.time()
    return (end - start) * 1000  # ms



In [10]:
def save_sketched_weights_per_head(model, s_q_path, s_k_path, save_dir_q, save_dir_k):
    os.makedirs(save_dir_q, exist_ok=True)
    os.makedirs(save_dir_k, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    depth = len(model.blocks)
    for l in range(depth):
        attn = model.blocks[l].attn
        for h in range(attn.num_heads):
            Wq = attn.Wq_list[h].weight.data
            Wk = attn.Wk_list[h].weight.data

            
            s_q_file = f"{s_q_path}_q.pt" if attn.num_heads == 1 else f"{s_q_path}_head{h}_q.pt"
            s_k_file = f"{s_k_path}_k.pt" if attn.num_heads == 1 else f"{s_k_path}_head{h}_k.pt"
            # Sq_h = torch.load(s_q_file, weights_only=False).clone().detach().to(device)
            # Sk_h = torch.load(s_k_file, weights_only=False).clone().detach().to(device)
            Sq_h = attn.S_q_list[h].detach().to(device)
            Sk_h = attn.S_k_list[h].detach().to(device)


            Wq_sk = Wq @ Sq_h.T
            Wk_sk = Wk @ Sk_h.T

            torch.save(Wq_sk, f"{save_dir_q}/Wq_layer{l}_head{h}.pt")
            torch.save(Wk_sk, f"{save_dir_k}/Wk_layer{l}_head{h}.pt")

In [11]:
from tqdm import tqdm
# EXPERIMENT
import time
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Shared params
# base_s_q_path = "/kaggle/working/gauss_matrix_all/gauss_sketch_matrix"
# base_s_k_path = "/kaggle/working/gauss_matrix_all/gauss_sketch_matrix"
WQ_SAVE_DIR = "/kaggle/working/sketched_weights_q"
WK_SAVE_DIR = "/kaggle/working/sketched_weights_k"

depth = 4
num_heads = 4   # Set to 1 for single-head
sketch_dim = 256
dim = 768


head_dim = dim if num_heads == 1 else dim // num_heads  # 768 (single-head) or 192 (multi-head)
sketch_head_dim = sketch_dim if num_heads == 1 else sketch_dim // num_heads  # 384 or 96

# # Generate sketch matrices
# if num_heads == 1:
#     gauss_sketch_matrix_file(sketch_head_dim, head_dim, 2,
#                              f"{base_s_q_path}_q.pt",
#                              f"{base_s_k_path}_k.pt")
# else:
#     for h in range(num_heads):
#         gauss_sketch_matrix_file(sketch_head_dim, head_dim, 2,
#                                  f"{base_s_q_path}_head{h}_q.pt",
#                                  f"{base_s_k_path}_head{h}_k.pt")
# criterion = nn.CrossEntropyLoss()

import os
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Normal ViT
normal_model = VisionTransformer(depth=depth, num_heads=num_heads,
                                 use_sketch=False).to(device)
optimizer = torch.optim.Adam(normal_model.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss()
print("Training Normal ViT")
acc_train_normal = train_model(normal_model, optimizer, criterion, train_loader, device, epochs=10)
acc_val_normal = evaluate(normal_model, val_loader, device)  
acc_test_normal = evaluate(normal_model, test_loader, device)
import time
inf_time_normal = time_inference(normal_model, test_loader, device)

Training Normal ViT


Epoch 1/10: 100%|██████████| 329/329 [08:30<00:00,  1.55s/it, acc=0.233, loss=1.77] 
Epoch 2/10: 100%|██████████| 329/329 [08:43<00:00,  1.59s/it, acc=0.544, loss=1.94] 
Epoch 3/10: 100%|██████████| 329/329 [08:43<00:00,  1.59s/it, acc=0.684, loss=0.713]
Epoch 4/10: 100%|██████████| 329/329 [08:41<00:00,  1.59s/it, acc=0.784, loss=0.242]
Epoch 5/10: 100%|██████████| 329/329 [08:43<00:00,  1.59s/it, acc=0.852, loss=0.038]
Epoch 6/10: 100%|██████████| 329/329 [08:43<00:00,  1.59s/it, acc=0.895, loss=0.0902]
Epoch 7/10: 100%|██████████| 329/329 [08:43<00:00,  1.59s/it, acc=0.921, loss=0.727] 
Epoch 8/10: 100%|██████████| 329/329 [08:43<00:00,  1.59s/it, acc=0.937, loss=0.121] 
Epoch 9/10: 100%|██████████| 329/329 [08:43<00:00,  1.59s/it, acc=0.945, loss=0.0694]
Epoch 10/10: 100%|██████████| 329/329 [08:43<00:00,  1.59s/it, acc=0.956, loss=0.00796]


In [15]:
print("Training Sketched ViT")
sketch_model = VisionTransformer(depth=depth, num_heads=num_heads,
                                 sketch_dim=sketch_dim,
                                 use_sketch=True, train_mode=True).to(device)
optimizer_sketch = torch.optim.Adam(sketch_model.parameters(), lr=0.0001)

acc_train_sketch = train_model(sketch_model, optimizer_sketch, criterion, train_loader, device, epochs=10)
acc_val_sketch = evaluate(sketch_model, val_loader, device)  
acc_test_sketch = evaluate(sketch_model, test_loader, device)

dummy_path = "unused_dummy_path"

from pathlib import Path
Path(WQ_SAVE_DIR).mkdir(parents=True, exist_ok=True)
Path(WK_SAVE_DIR).mkdir(parents=True, exist_ok=True)

save_sketched_weights_per_head(sketch_model,dummy_path, dummy_path, WQ_SAVE_DIR, WK_SAVE_DIR)

# Sketched ViT: INFERENCE ONLY
sketched_infer_model = VisionTransformer(depth=depth, num_heads=num_heads,
                                         sketch_dim=sketch_dim,
                                         use_sketch=True, train_mode=False,
                                         s_q_path=f"{WQ_SAVE_DIR}/Wq",s_k_path=f"{WK_SAVE_DIR}/Wk").to(device)
inf_time_sketch = time_inference(sketched_infer_model, test_loader, device)

print("\nFINAL RESULTS")
print(f"Normal ViT     → Train Acc: {acc_train_normal:.4f}, Test Acc: {acc_test_normal:.4f}, Inf Time: {inf_time_normal:.2f} ms")
print(f"Sketched ViT   → Train Acc: {acc_train_sketch:.4f}, Test Acc: {acc_test_sketch:.4f}, Inf Time: {inf_time_sketch:.2f} ms")

Training Sketched ViT


Epoch 1/10: 100%|██████████| 329/329 [08:04<00:00,  1.47s/it, acc=0.225, loss=1.44] 
Epoch 2/10: 100%|██████████| 329/329 [08:04<00:00,  1.47s/it, acc=0.536, loss=2.39] 
Epoch 3/10: 100%|██████████| 329/329 [08:05<00:00,  1.48s/it, acc=0.673, loss=0.584]
Epoch 4/10: 100%|██████████| 329/329 [08:05<00:00,  1.47s/it, acc=0.783, loss=0.341]
Epoch 5/10: 100%|██████████| 329/329 [08:05<00:00,  1.48s/it, acc=0.845, loss=0.297] 
Epoch 6/10: 100%|██████████| 329/329 [08:05<00:00,  1.48s/it, acc=0.893, loss=0.0104]
Epoch 7/10: 100%|██████████| 329/329 [08:04<00:00,  1.47s/it, acc=0.913, loss=2.35]  
Epoch 8/10: 100%|██████████| 329/329 [08:07<00:00,  1.48s/it, acc=0.928, loss=0.149] 
Epoch 9/10: 100%|██████████| 329/329 [08:05<00:00,  1.48s/it, acc=0.952, loss=0.529] 
Epoch 10/10: 100%|██████████| 329/329 [08:05<00:00,  1.48s/it, acc=0.959, loss=0.0485]



FINAL RESULTS
Normal ViT     → Train Acc: 0.9556, Test Acc: 0.9010, Inf Time: 18207.97 ms
Sketched ViT   → Train Acc: 0.9588, Test Acc: 0.9219, Inf Time: 16006.45 ms


In [16]:
inf_time_normal_val = time_inference(normal_model, val_loader, device)
inf_time_sketch_val = time_inference(sketched_infer_model, val_loader, device)

print(f"Normal val ViT     → val Acc: {acc_val_normal :.4f}, Inf Time val: {inf_time_normal_val:.2f} ms")
print(f"sketched val  ViT     → val Acc: {acc_val_sketch:.4f}, Inf Time val: {inf_time_sketch_val:.2f} ms")

Normal val ViT     → val Acc: 0.9124, Inf Time val: 18403.46 ms
sketched val  ViT     → val Acc: 0.9214, Inf Time val: 17313.93 ms


In [17]:
import time
import numpy as np
import torch

def time_inference(model, dataloader, device, repetitions=4, verbose=True):
    model.eval()
    timings = []

    if device.type == "cuda":
        torch.cuda.reset_peak_memory_stats(device)

    with torch.no_grad():
        for i, (x, _) in enumerate(dataloader):
            if i >= repetitions:
                break
            x = x.to(device)

            start = time.time()
            _ = model(x)
            torch.cuda.synchronize() if device.type == "cuda" else None
            end = time.time()

            batch_time = (end - start) * 1000  # ms
            timings.append(batch_time)

            if verbose:
                print(f"[Batch {i+1}] Time: {batch_time:.2f} ms")

    avg_time = np.mean(timings) if timings else 0.0
    total_time = np.sum(timings)
    peak_memory_MB = 0.0

    if device.type == "cuda":
        peak_memory_MB = torch.cuda.max_memory_allocated(device) / (1024 ** 2)
        if verbose:
            print(f"[CUDA] Peak memory usage: {peak_memory_MB:.2f} MB")

    if verbose:
        print(f"\n[Inference Summary] Avg Time/Batch: {avg_time:.2f} ms | Total Time: {total_time:.2f} ms")

    return avg_time, peak_memory_MB


In [20]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inf_time_normal, mem_normal = time_inference(normal_model, test_loader, device)
inf_time_sketch, mem_sketch = time_inference(sketched_infer_model, test_loader, device)

print(f"Normal Model  → Time: {inf_time_normal:.2f} ms | Peak Memory: {mem_normal:.2f} MB")
print(f"Sketched Model → Time: {inf_time_sketch:.2f} ms | Peak Memory: {mem_sketch:.2f} MB")

[Batch 1] Time: 590.98 ms
[Batch 2] Time: 549.99 ms
[Batch 3] Time: 549.81 ms
[Batch 4] Time: 561.22 ms
[CUDA] Peak memory usage: 2212.80 MB

[Inference Summary] Avg Time/Batch: 563.00 ms | Total Time: 2252.00 ms
[Batch 1] Time: 482.99 ms
[Batch 2] Time: 482.77 ms
[Batch 3] Time: 487.72 ms
[Batch 4] Time: 492.26 ms
[CUDA] Peak memory usage: 2181.68 MB

[Inference Summary] Avg Time/Batch: 486.44 ms | Total Time: 1945.75 ms
Normal Model  → Time: 563.00 ms | Peak Memory: 2212.80 MB
Sketched Model → Time: 486.44 ms | Peak Memory: 2181.68 MB


In [19]:
def fuse_sketch_into_model(model, save_path=None):
    """
    Fuses Wq and Wk with their learned sketch matrices in-place,
    deletes S_q and S_k parameters, and optionally saves the cleaned model.
    """
    device = next(model.parameters()).device
    model.eval()

    for layer_idx, block in enumerate(model.blocks):
        attn = block.attn
        if not attn.use_sketch or not attn.train_mode:
            continue  # Skip if not sketch mode or already fused

        for h in range(attn.num_heads):
            # Get original Wq, Wk and sketch matrices
            Wq = attn.Wq_list[h].weight.data.to(device)
            Wk = attn.Wk_list[h].weight.data.to(device)
            Sq = attn.S_q_list[h].detach().to(device)
            Sk = attn.S_k_list[h].detach().to(device)

            # Fused projections
            Wq_fused = Wq @ Sq.T  # [head_dim, sketch_head_dim]
            Wk_fused = Wk @ Sk.T

            # Replace with fused Linear layer
            fused_q = nn.Linear(attn.head_dim, attn.sketch_head_dim, bias=False).to(device)
            fused_k = nn.Linear(attn.head_dim, attn.sketch_head_dim, bias=False)
            fused_q.weight.data.copy_(Wq_fused.T)
            fused_k.weight.data.copy_(Wk_fused.T)

            attn.Wq_list[h] = fused_q
            attn.Wk_list[h] = fused_k

        # Clean up sketch params
        attn.S_q_list = nn.ParameterList()
        attn.S_k_list = nn.ParameterList()
        attn.train_mode = False  # Mark as fused/inference mode

    print("[INFO] Sketch weights fused into model.")

    # Optional: Save clean model state_dict
    if save_path:
        full_state = model.state_dict()
        clean_state = {k: v for k, v in full_state.items() if "S_q" not in k and "S_k" not in k}
        torch.save(clean_state, save_path)
        print(f"[SAVED] Fused model saved to: {save_path}")

    return model


In [21]:
model_save_path = "/kaggle/working/normal_model_trained.pth"
torch.save(normal_model.state_dict(), model_save_path)

In [22]:
fused_path = "/kaggle/working/sketch_model_fused_clean.pth"
fuse_sketch_into_model(sketch_model, save_path=fused_path)


[INFO] Sketch weights fused into model.
[SAVED] Fused model saved to: /kaggle/working/sketch_model_fused_clean.pth


VisionTransformer(
  (patch_embed): PatchEmbedding(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  )
  (blocks): Sequential(
    (0): SimpleTransformerBlock(
      (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): SketchSelfAttention(
        (Wq_list): ModuleList(
          (0-3): 4 x Linear(in_features=192, out_features=64, bias=False)
        )
        (Wk_list): ModuleList(
          (0-3): 4 x Linear(in_features=192, out_features=64, bias=False)
        )
        (Wv_list): ModuleList(
          (0-3): 4 x Linear(in_features=192, out_features=192, bias=False)
        )
        (S_q_list): ParameterList()
        (S_k_list): ParameterList()
        (softmax): Softmax(dim=-1)
        (output_proj): Linear(in_features=768, out_features=768, bias=True)
      )
      (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): Sequential(
        (0): Linear(in_features=768, out_features=1536, bias=True)
        (1): GELU(a

In [23]:
fused_model = VisionTransformer(depth=4, num_heads=4, sketch_dim=256,
                                use_sketch=True, train_mode=False).to(device)
fused_model.load_state_dict(torch.load(fused_path))
fused_model.eval()


VisionTransformer(
  (patch_embed): PatchEmbedding(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  )
  (blocks): Sequential(
    (0): SimpleTransformerBlock(
      (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): SketchSelfAttention(
        (Wq_list): ModuleList(
          (0-3): 4 x Linear(in_features=192, out_features=64, bias=False)
        )
        (Wk_list): ModuleList(
          (0-3): 4 x Linear(in_features=192, out_features=64, bias=False)
        )
        (Wv_list): ModuleList(
          (0-3): 4 x Linear(in_features=192, out_features=192, bias=False)
        )
        (softmax): Softmax(dim=-1)
        (output_proj): Linear(in_features=768, out_features=768, bias=True)
      )
      (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): Sequential(
        (0): Linear(in_features=768, out_features=1536, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=1536, out_features=76

In [24]:
device = torch.device("cpu")
fused_model.to("cpu")
normal_model.to("cpu")
inf_time_sketch, mem_sketch = time_inference(fused_model, test_loader, device,repetitions=4)
inf_time_normal, mem_normal = time_inference(normal_model, test_loader, device, repetitions=4)

[Batch 1] Time: 13126.11 ms
[Batch 2] Time: 11990.22 ms
[Batch 3] Time: 12333.32 ms
[Batch 4] Time: 12425.33 ms

[Inference Summary] Avg Time/Batch: 12468.75 ms | Total Time: 49874.98 ms
[Batch 1] Time: 13203.80 ms
[Batch 2] Time: 13395.11 ms
[Batch 3] Time: 12937.77 ms
[Batch 4] Time: 12913.56 ms

[Inference Summary] Avg Time/Batch: 13112.56 ms | Total Time: 52450.25 ms


In [26]:
import shutil
import os

folder_to_zip = '/kaggle/working/sketched_weights_k' 
output_zip = '/kaggle/working/sketched_weights_k.zip'  
shutil.make_archive(output_zip.replace('.zip', ''), 'zip', folder_to_zip)
if os.path.exists(output_zip):
    print(f"Folder zipped successfully as {output_zip}")
else:
    print("Failed to create zip file")

Folder zipped successfully as /kaggle/working/sketched_weights_k.zip


In [28]:
import os
import torch

def save_learned_sketch_matrices(model, save_dir="/kaggle/working//learned_sketches"):
    os.makedirs(save_dir, exist_ok=True)
    device = next(model.parameters()).device

    for l, block in enumerate(model.blocks):
        attn = block.attn
        if not attn.use_sketch or not attn.train_mode:
            continue

        for h in range(attn.num_heads):
            Sq = attn.S_q_list[h].detach().cpu()
            Sk = attn.S_k_list[h].detach().cpu()

            torch.save(Sq, os.path.join(save_dir, f"S_q_layer{l}_head{h}.pt"))
            torch.save(Sk, os.path.join(save_dir, f"S_k_layer{l}_head{h}.pt"))

    print(f"[INFO] Saved sketch matrices to: {save_dir}")


In [31]:
!ls -l /kaggle/working/learned_sketches


total 0


In [32]:
# 1. Save learned sketch matrices
save_learned_sketch_matrices(sketch_model, save_dir="/kaggle/working/learned_sketches")

[INFO] Saved sketch matrices to: /kaggle/working/learned_sketches
