# **Initialization & Session Params**

In [1]:
__MODEL__ = 'vit_b_16'           # 'Fourier' 'SinCos2D' 'swin', 'deit_b_16' , 'vit_b_16', 'beit_b_16', 'crossvit_base_240' 'vit_art'
__DATASET__ = 'imagenet100'           # 'oxford_iiit', 'dtd' , 'caltech',  'cifar100' 'imagenet100'
__ENCODER__ = 'sin_cos_1d'             # 'sin_cos_1d',  'learnable_1d', 'hilbert', 'no_pe'
__PATCH__ = 16                    # 16, 32, 8
__TVT_SPLIT__ = (80,10,10)
__BATCH_SIZE_TRAIN__ = 96
__BATCH_SIZE_VALID__ = 96
__IMG_SIZE__ = 224
__XIMG_SIZE__ = 256
START_EPOCH = 1
__EPOCHS__ = 500
__MAX_LR__ = 0.0001
__MIN_LR__ = 0.0001/20
__OPTIM__ = 'adam'
__DEVICE__ = 'gpu'
__WARNING__ = 'supressed'
__STEP_COUNTER__ = 50
best_acc = 0.0
best_epoch = 0

In [2]:
import warnings
if __WARNING__ == 'supressed':
    warnings.filterwarnings("ignore")
import torchvision.models as models
import torch.nn as nn
import torch
import time
import gc
from torchinfo import summary
import timm
import torch.nn.functional as F
from transformers import ViTForImageClassification, ViTConfig
from transformers import BeitForImageClassification
import os
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from PIL import Image
import albumentations as A
import torchvision.transforms as T
from albumentations.pytorch import ToTensorV2
import cv2
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import sys
import time
import sklearn
import re
import random
from tqdm.notebook import tqdm
import skimage.io as io
import skimage.color as color
from skimage.feature import local_binary_pattern
from torch import optim 
import torch.optim.lr_scheduler as lr_scheduler
import torch.utils.data

2026-01-26 18:54:09.504184: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1769453649.528388     241 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1769453649.535904     241 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1769453649.555157     241 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1769453649.555186     241 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1769453649.555189     241 computation_placer.cc:177] computation placer alr

In [3]:
BASE_PATH = f"weights_(model-{__MODEL__})_(data-{__DATASET__})_(enc-{__ENCODER__})_(p-{__PATCH__})_(im-{__IMG_SIZE__}).pth"
LAST_PATH = f"weights(last)_(model-{__MODEL__})_(data-{__DATASET__})_(enc-{__ENCODER__})_(p-{__PATCH__})_(im-{__IMG_SIZE__}).pth"
print(BASE_PATH)
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    device = "cuda:0"
else:
    device = "cpu"
print("DEVICE WE WILL BE USING IS ", device)

weights_(model-vit_b_16)_(data-imagenet100)_(enc-sin_cos_1d)_(p-16)_(im-224).pth
DEVICE WE WILL BE USING IS  cuda:0


# **Dataset Preparation**

In [4]:
class RandomRotate90:
    def __init__(self, p=0.5):
        self.p = p
    def __call__(self, img):
        if random.random() < self.p:
            angle = random.choice([0, 90, 180, 270])
            return T.functional.rotate(img, angle)
        return img

class RandomGamma:
    def __init__(self, gamma_range=(0.8, 1.2), p=0.5):
        self.gamma_range = gamma_range
        self.p = p
    def __call__(self, img):
        if random.random() < self.p:
            gamma = random.uniform(*self.gamma_range)
            return T.functional.adjust_gamma(img, gamma)
        return img

**Cifar-100**

In [5]:
class cifar100(Dataset):
    def __init__(self,files,cls_dict,mode='train'):
        self.mode = mode
        self.name = 'cifar_100'
        self.cls_dict = cls_dict
        self.num_cls = cls_dict.__len__()
        print(files[33][2].shape)
        self.images = [np.array(file[2]).reshape((3,32,32)).transpose((1,2,0)) for file in files]
        self.labels = [file[1] for file in files]
            
        if self.mode == 'train':
            self.transforms = A.Compose([
                                    A.Resize(__XIMG_SIZE__,__XIMG_SIZE__),
                                    A.HorizontalFlip(p=0.5),
                                    A.Rotate(limit=15, p=0.5),
                                    A.RandomCrop(__IMG_SIZE__, __IMG_SIZE__),
                                    ToTensorV2()
                                    ])
        else :
            self.transforms = A.Compose([
                                    A.Resize(__IMG_SIZE__, __IMG_SIZE__),
                                    ToTensorV2()
                                ])
    
    def __len__(self):
        return len(self.labels)
    def __getitem__(self,idx):
        image = self.images[idx]
        label = self.labels[idx]
        image = self.transforms(image=image)['image']/255.00
        temp = torch.zeros(self.num_cls).float()
        temp[label] = 1.0
        
        return image, temp

In [6]:
if __DATASET__ == 'cifar100' or __DATASET__ == 'pretest':
    try:
        split_dir = "/kaggle/input/cifar100-splits/cifar100/"
        if not os.path.isdir(split_dir):
            raise FileNotFoundError(f"Directory '{directory}' does not exist!")
        train_files = np.load(split_dir+"train_files.npz", allow_pickle=True)["data"]
        valid_files = np.load(split_dir+"valid_files.npz", allow_pickle=True)["data"]
        test_files = np.load(split_dir+"test_files.npz", allow_pickle=True)["data"]
        print('Successfully preloaded train, valid, test splits.....')
    except:
        def unpickle(file):
            import pickle
            with open(file, 'rb') as fo:
                dict = pickle.load(fo, encoding='bytes')
            return dict
        train_dir = '/kaggle/input/cifar100/train'
        train_dict = unpickle(train_dir)
        lst = list(train_dict.keys())
        length = train_dict[lst[0]].__len__()
        files = [(train_dict[lst[0]][i], train_dict[lst[2]][i], train_dict[lst[4]][i]) for i in range(length)]
        train_files, vt_files = train_test_split(files, test_size = (__TVT_SPLIT__[1]+__TVT_SPLIT__[2])/100, random_state=42)
        valid_files, test_files = train_test_split(vt_files, test_size = (__TVT_SPLIT__[2])/(__TVT_SPLIT__[1]+__TVT_SPLIT__[2]), random_state=42)
        try:
            os.mkdir("/kaggle/working/cifar100")
        except:
            print("sub-folder already created")
        np.savez('/kaggle/working/cifar100/train_files.npz', data=np.array(train_files,dtype=object))
        np.savez('/kaggle/working/cifar100/valid_files.npz', data=np.array(valid_files,dtype=object))
        np.savez('/kaggle/working/cifar100/test_files.npz', data=np.array(test_files,dtype=object))
    print("total classes: ", 100)
    classes = [i for i in range(100)]
    train_dataset = cifar100(train_files, classes, 'train')
    valid_dataset = cifar100(valid_files, classes, 'valid')
    test_dataset = cifar100(test_files, classes, 'test')

# ImageNeT100

In [7]:
# =========================
# ImageNet100 drop-in setup
# =========================

import os, glob
import numpy as np
import torch
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split

import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2


# ---- (recommended) for ViT scratch training on ImageNet-like data
# Keep your existing __IMG_SIZE__ / __XIMG_SIZE__ if already defined.
# Typical: __IMG_SIZE__=224, __XIMG_SIZE__=256
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)

# IMPORTANT:
# Your existing pipeline does: transforms(...)->ToTensorV2() then "/255.0" in __getitem__.
# Albumentations Normalize would become wrong if you keep "/255.0".
# So we normalize using a Lambda that is compatible with the later "/255.0":
#   output = (x - mean*255)/std   and later /255 => (x/255 - mean)/std  (correct)
def _imagenet_norm_keep_div255(img, **kwargs):
    img = img.astype(np.float32)
    mean = np.array(IMAGENET_MEAN, dtype=np.float32) * 255.0
    std  = np.array(IMAGENET_STD,  dtype=np.float32)
    return (img - mean) / std

class imagenet100(Dataset):
    def __init__(self, files, cls_dict, mode='train', preload_ram=True):
        self.mode = mode
        self.name = 'imagenet_100'
        self.cls_dict = cls_dict
        self.num_cls = cls_dict.__len__()

        self.labels = [int(file[1]) for file in files]
        self.preload_ram = preload_ram

        if self.preload_ram:
            # Preload into a single contiguous uint8 array to minimize overhead
            N = len(files)
            self.images = np.empty((N, __IMG_SIZE__, __IMG_SIZE__, 3), dtype=np.uint8)

            for i, file in enumerate(files):
                p = file[2]
                img = cv2.imread(p, cv2.IMREAD_COLOR)
                if img is None:
                    raise FileNotFoundError(f"Failed to read image: {p}")
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # uint8

                # Safety: if any image isn't exactly 224x224, resize once during preload
                if img.shape[0] != __IMG_SIZE__ or img.shape[1] != __IMG_SIZE__:
                    img = cv2.resize(img, (__IMG_SIZE__, __IMG_SIZE__), interpolation=cv2.INTER_AREA)

                self.images[i] = img
        else:
            self.images = [file[2] for file in files]  # paths

        # transforms unchanged
        if self.mode == 'train':
            self.transforms = A.Compose([
                A.RandomResizedCrop(
                    size=(__IMG_SIZE__, __IMG_SIZE__),
                    scale=(0.08, 1.0),
                    ratio=(0.75, 1.3333333),
                    interpolation=cv2.INTER_CUBIC,
                    p=1.0
                ),
                A.HorizontalFlip(p=0.5),
                A.OneOf([
                    A.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1, p=1.0),
                    A.RandomBrightnessContrast(brightness_limit=0.25, contrast_limit=0.25, p=1.0),
                    A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=15, val_shift_limit=10, p=1.0),
                    A.RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=1.0),
                ], p=0.8),
                A.OneOf([
                    A.GaussNoise(var_limit=(10.0, 50.0), p=1.0),
                    A.GaussianBlur(blur_limit=(3, 5), p=1.0),
                    A.MotionBlur(blur_limit=(3, 5), p=1.0),
                ], p=0.2),
                A.CoarseDropout(
                    max_holes=1,
                    max_height=int(__IMG_SIZE__ * 0.25),
                    max_width=int(__IMG_SIZE__ * 0.25),
                    min_holes=1,
                    min_height=int(__IMG_SIZE__ * 0.10),
                    min_width=int(__IMG_SIZE__ * 0.10),
                    fill_value=0,
                    p=0.25
                ),
                A.Lambda(image=_imagenet_norm_keep_div255),
                ToTensorV2(),
            ])
        else:
            self.transforms = A.Compose([
                A.Resize(__XIMG_SIZE__, __XIMG_SIZE__, interpolation=cv2.INTER_CUBIC),
                A.CenterCrop(__IMG_SIZE__, __IMG_SIZE__),
                A.Lambda(image=_imagenet_norm_keep_div255),
                ToTensorV2(),
            ])

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        if self.preload_ram:
            img = self.images[idx]  # uint8 RGB
        else:
            path = self.images[idx]
            img = cv2.imread(path, cv2.IMREAD_COLOR)
            if img is None:
                raise FileNotFoundError(f"Failed to read image: {path}")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        img = self.transforms(image=img)['image'] / 255.0

        label = self.labels[idx]
        temp = torch.zeros(self.num_cls).float()
        temp[label] = 1.0
        return img, temp


In [8]:
# =========================
# Split handling (drop-in)
# =========================

if __DATASET__ == 'imagenet100' or __DATASET__ == 'pretest':
    root_dir = "/kaggle/input/processed-imagenet-dataset-224"   # class folders live here
    split_dir = "/kaggle/working/imagenet100"
    os.makedirs(split_dir, exist_ok=True)

    def _load_splits(dirpath):
        tr = np.load(os.path.join(dirpath, "train_files.npz"), allow_pickle=True)["data"]
        va = np.load(os.path.join(dirpath, "valid_files.npz"), allow_pickle=True)["data"]
        te = np.load(os.path.join(dirpath, "test_files.npz"),  allow_pickle=True)["data"]
        return tr, va, te

    # Try load cached splits (first from an input dataset if you have one, else from working)
    loaded = False
    for cand in [
        "/kaggle/input/imagenet100-splits/imagenet100",  # optional (if you created/uploaded)
        split_dir
    ]:
        try:
            if os.path.isfile(os.path.join(cand, "train_files.npz")):
                train_files, valid_files, test_files = _load_splits(cand)
                print(f"Successfully preloaded train/valid/test splits from: {cand}")
                loaded = True
                break
        except:
            pass

    if not loaded:
        # discover class folders (e.g., n01440764)
        classes_synset = sorted([
            d for d in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, d)) and not d.startswith(".")
        ])
        cls2idx = {c: i for i, c in enumerate(classes_synset)}

        # build file tuples: (placeholder, label_index, image_path)
        files = []
        exts = ("*.JPEG", "*.JPG", "*.jpeg", "*.jpg", "*.png")
        for c in classes_synset:
            cdir = os.path.join(root_dir, c)
            for ext in exts:
                for p in glob.glob(os.path.join(cdir, ext)):
                    files.append((None, cls2idx[c], p))

        # stratified TVT split like your CIFAR code
        labels = [f[1] for f in files]
        vt_ratio = (__TVT_SPLIT__[1] + __TVT_SPLIT__[2]) / 100.0

        train_files, vt_files = train_test_split(
            files,
            test_size=vt_ratio,
            random_state=42,
            stratify=labels
        )

        vt_labels = [f[1] for f in vt_files]
        test_ratio_inside_vt = (__TVT_SPLIT__[2]) / (__TVT_SPLIT__[1] + __TVT_SPLIT__[2])

        valid_files, test_files = train_test_split(
            vt_files,
            test_size=test_ratio_inside_vt,
            random_state=42,
            stratify=vt_labels
        )

        # cache
        np.savez(os.path.join(split_dir, "train_files.npz"), data=np.array(train_files, dtype=object))
        np.savez(os.path.join(split_dir, "valid_files.npz"), data=np.array(valid_files, dtype=object))
        np.savez(os.path.join(split_dir, "test_files.npz"),  data=np.array(test_files,  dtype=object))

        print(f"Created and cached splits in: {split_dir}")
        print("Classes discovered:", len(classes_synset))

    # Your original code passes "classes" only to define num_cls.
    # Since labels are already integer indices 0..C-1, we pass range(C).
    # If you loaded cached splits, infer C from labels.
    all_labels = np.array([f[1] for f in train_files], dtype=np.int64)
    num_classes = int(all_labels.max()) + 1
    print("total classes:", num_classes)

    classes = [i for i in range(num_classes)]
    train_dataset = imagenet100(train_files, classes, 'train')
    valid_dataset = imagenet100(valid_files, classes, 'valid')
    test_dataset  = imagenet100(test_files,  classes, 'test')

Successfully preloaded train/valid/test splits from: /kaggle/working/imagenet100
total classes: 100


In [9]:
__NUM_CLASS__ = train_dataset.num_cls
print(f"DATASET LOADED NAMED: {train_dataset.name} WITH NUM OF CLASS: {__NUM_CLASS__}")

DATASET LOADED NAMED: imagenet_100 WITH NUM OF CLASS: 100


# **Positional Encoders**

**No PE**

In [10]:
def get_zero_pos_embed(embed_dim, seq_length):
    return nn.Parameter(torch.zeros((1, seq_length, embed_dim), dtype=torch.float32), requires_grad=False)

**Sin-Cos 1D PE**

In [11]:
def get_1d_sincos_pos_embed(embed_dim, seq_length):
    position = torch.arange(seq_length, dtype=torch.float32).unsqueeze(1)  # (seq_length, 1)
    div_term = torch.exp(torch.arange(0, embed_dim, 2, dtype=torch.float32) * -(np.log(10000.0) / embed_dim))
    pos_embed = torch.zeros((seq_length, embed_dim), dtype=torch.float32)
    pos_embed[:, 0::2] = torch.sin(position * div_term)  # apply sin to even indices in the embedding
    pos_embed[:, 1::2] = torch.cos(position * div_term)  # apply cos to odd indices in the embedding
    return nn.Parameter(pos_embed.unsqueeze(0),requires_grad=False)  # add a batch dimension so shape becomes (1, seq_length, embed_dim)

**1D learnable PE**

In [12]:
def get_1d_learnable_pos_embed(embed_dim, seq_length):
     return nn.Parameter(torch.empty(1, seq_length, embed_dim).normal_(std=0.02), requires_grad=True)

**Hilbert PE**

In [13]:
def gilbert_2d(width, height):
    def sgn(x):
        return 1 if x > 0 else -1 if x < 0 else 0
    
    def generate(x, y, ax, ay, bx, by):
        w = abs(ax + ay)
        h = abs(bx + by)
        dax, day = sgn(ax), sgn(ay)
        dbx, dby = sgn(bx), sgn(by)

        if h == 1:
            return [(x + dax*i, y + day*i) for i in range(w)]
        if w == 1:
            return [(x + dbx*i, y + dby*i) for i in range(h)]

        ax2, ay2 = ax//2, ay//2
        bx2, by2 = bx//2, by//2
        w2, h2 = abs(ax2 + ay2), abs(bx2 + by2)

        if 2*w > 3*h:
            if w2%2 and w>2:
                ax2 += dax
                ay2 += day
            return generate(x, y, ax2, ay2, bx, by) + \
                   generate(x+ax2, y+ay2, ax-ax2, ay-ay2, bx, by)
        else:
            if h2%2 and h>2:
                bx2 += dbx
                by2 += dby
            return generate(x, y, bx2, by2, ax2, ay2) + \
                   generate(x+bx2, y+by2, ax, ay, bx-bx2, by-by2) + \
                   generate(x+(ax-dax)+(bx2-dbx), 
                            y+(ay-day)+(by2-dby), 
                            -bx2, -by2, -ax+ax2, -ay+ay2)

    # Generate coordinates and create index array
    curve = generate(0, 0, width, 0, 0, height)
    arr = np.zeros((height, width), dtype=int)
    for idx, (x, y) in enumerate(curve):
        arr[y, x] = idx+1  # Note y is first dimension in numpy arrays
    
    return arr

In [14]:
def generate_hilbert_grid(N):
    def rot(n, x, y, rx, ry):
        if ry == 0:
            if rx == 1:
                x, y = n - 1 - x, n - 1 - y
            x, y = y, x
        return x, y

    def hilbert_index(n, d):
        x = y = 0
        t = d
        s = 1
        while s < n:
            rx = (t // 2) & 1
            ry = (t ^ rx) & 1
            x, y = rot(s, x, y, rx, ry)
            x += s * rx
            y += s * ry
            t //= 4
            s *= 2
        return x, y

    grid = [[-1 for _ in range(N)] for _ in range(N)]
    for d in range(N * N):
        x, y = hilbert_index(N, d)
        grid[y][x] = d
    return grid

In [15]:
def get_hilbert_pos_embed(embed_dim, seq_length):
    N = int(np.sqrt(seq_length-1))
    grid = gilbert_2d(N, N)
    grid = torch.Tensor(grid).reshape(seq_length-1)
    position = torch.cat((torch.tensor([0]), grid)).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, embed_dim, 2, dtype=torch.float32) * -(np.log(10000.0) / embed_dim))
    pos_embed = torch.zeros((seq_length, embed_dim), dtype=torch.float32)
    pos_embed[:, 0::2] = torch.sin(position * div_term)  # apply sin to even indices in the embedding
    pos_embed[:, 1::2] = torch.cos(position * div_term)  # apply cos to odd indices in the embedding
    return nn.Parameter(pos_embed.unsqueeze(0),requires_grad=False)

**PE Callers**

In [16]:
if __ENCODER__ == 'no_pe':
    pe_caller = get_zero_pos_embed
elif __ENCODER__ == 'sin_cos_1d' or __ENCODER__ == 'pretest':
    pe_caller = get_1d_sincos_pos_embed
elif __ENCODER__ == 'learnable_1d':
    pe_caller = get_1d_learnable_pos_embed
elif __ENCODER__ == 'hilbert':
    pe_caller = get_hilbert_pos_embed

# **MODEL**

In [17]:
class art_PE(nn.Module):
    def __init__(self, seq_length, hidden_dim, img_size=224, patch_size=16):
        super().__init__()
        self.seq_length, self.hidden_dim, self.img_size, self.patch_size = seq_length, hidden_dim, img_size, patch_size
        self.register_buffer('cord', torch.stack(torch.meshgrid(torch.linspace(-1,1,img_size, dtype=torch.float), torch.linspace(-1,1,img_size, dtype=torch.float), indexing='ij')))
        self.art_conv = nn.Sequential(
            nn.Conv2d(5, 32, 16, 16, 0), nn.GELU(),
            nn.Conv2d(32, 16, 5, 1, 2), nn.GELU(),
            nn.Conv2d(16, 8, 5, 1, 2), nn.GELU(),
            nn.Conv2d(8, 4, 5, 1, 2), nn.GELU(),
            nn.Conv2d(4, 1, 5, 1, 2), nn.GELU(),
            nn.BatchNorm2d(1), nn.Flatten(),
            nn.Linear((img_size // patch_size) ** 2, self.seq_length-1, bias=False), nn.Sigmoid()
        )
        #self.wrap_pos = nn.Parameter(torch.zeros(seq_length, 1, dtype=torch.float32))
        self.register_buffer('wrap_pos_hilbert', torch.Tensor(gilbert_2d((img_size // patch_size), (img_size // patch_size))).reshape(self.seq_length-1).unsqueeze(1))
        self.register_buffer('div_term', torch.exp(torch.arange(0, hidden_dim, 2, dtype=torch.float32).unsqueeze(0) * -(np.log(10000.0) / hidden_dim)))
        #self.div_term = nn.Parameter(torch.randn(1,self.hidden_dim//2))
        self.cls_token = nn.Parameter(torch.zeros((1,self.hidden_dim)))
        #self.fc1 = nn.Linear(self.hidden_dim, 32, bias=True)
        #self.fc2 = nn.Linear(32, self.hidden_dim, bias=True)
        #self.activation = nn.GELU()

    def forward(self, x):
        n = x.shape[0]
        out = self.art_conv(torch.cat([x, self.cord.unsqueeze(0).expand(n, -1, -1, -1)], dim=1)).unsqueeze(2)
        out = 2*out-1
        out = out + self.wrap_pos_hilbert
        #out = self.wrap_pos_hilbert.unsqueeze(0).expand(n,-1,-1)
        sin_part, cos_part = torch.sin(out * self.div_term), torch.cos(out * self.div_term)
        pos_embedding = torch.empty((n, self.seq_length, self.hidden_dim), dtype=self.wrap_pos_hilbert.dtype, device=x.device)
        pos_embedding[:, 1:, 0::2], pos_embedding[:, 1:, 1::2] = sin_part, cos_part
        pos_embedding[:, :1, :] = self.cls_token.unsqueeze(0).expand(n,-1,-1)
        #pos_embedding = self.fc2(self.activation(self.fc1(pos_embedding)))
        return pos_embedding

In [18]:
class ViT_Art(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
        _, self.seq_length, self.hidden_dim = self.model.vit.embeddings.position_embeddings.shape
        self.model.vit.embeddings.position_embeddings = get_zero_pos_embed(self.hidden_dim, self.seq_length)
        self.pos_embedding = art_PE(self.seq_length, self.hidden_dim)
        self.patch = self.model.vit.embeddings.patch_embeddings
        self.encoder = self.model.vit.encoder
        self.layernorm = self.model.vit.layernorm
        self.classifier = nn.Linear(in_features=self.hidden_dim, out_features=__NUM_CLASS__, bias=True)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
    def forward(self, x):
        n = x.shape[0]
        patched = self.patch(x)
        batched_cls_token = self.cls_token.expand(n, -1, -1)
        patched = torch.cat([batched_cls_token, patched], dim=1)
        patched = patched + self.pos_embedding(x)
        patched = self.encoder(patched)
        patched = self.layernorm(patched.last_hidden_state)
        patched = patched[:, 0]
        patched = self.classifier(patched)
        return patched

In [19]:
if __MODEL__ == 'vit_art' or __MODEL__ == 'pretest':
    model = ViT_Art()

**Fourier**

In [20]:
if __MODEL__ == 'Fourier' or __MODEL__ == 'pretest':    
    class FourierPositionalEncoding(nn.Module):
        def __init__(self, num_patches, dim, fourier_dim=768, hidden_dim=32, groups=1):
            super().__init__()
            self.num_patches = num_patches  # N = 196
            self.fourier_dim = fourier_dim  # |F| = 768
            self.hidden_dim = hidden_dim    # |H| = 32
            self.dim = dim                  # D = 768
            self.groups = groups            # G = 1
            self.M = 2                      # M = 2D positional values

            grid_size = int(np.sqrt(self.num_patches))  # Typically 14x14 for 196 patches
            x = np.linspace(-1, 1, grid_size)
            y = np.linspace(-1, 1, grid_size)
            xx, yy = np.meshgrid(x, y)
            positions = np.stack([xx.flatten(), yy.flatten()], axis=-1)  # (num_patches, 2)
            #print(positions.shape)
            
            # Reshape to (N, G, M) = (196, 3, 2)
            self.register_buffer('positions', torch.tensor(positions, dtype=torch.float32).unsqueeze(1))  # (196, 1, 2)
            # Learnable Fourier weights Wr ∈ R^(|F|/2, M), sampled from N(0, γ^-2)
            self.Wr = nn.Parameter(torch.randn(fourier_dim//2, self.M))  # (384, 2)
            
            # MLP layers using Linear instead of manual weight parameters
            self.fc1 = nn.Linear(fourier_dim, hidden_dim, bias=True)  # (768, 32)
            self.fc2 = nn.Linear(hidden_dim, dim // groups, bias=True)  # (32, 256)
            self.activation = nn.GELU()
            self.cls_embed = nn.Parameter(torch.zeros((1, self.dim)))

        def forward(self, x):
            """
            Compute Fourier-based positional encoding.
            """    
            # Compute Fourier features F = [cos(XWr^T); sin(XWr^T)]
            proj = torch.matmul(self.positions, self.Wr.T)  # (196, 3, 384)
            #F = torch.cat([torch.cos(proj), torch.sin(proj)], dim=-1)  # (196, 3, 768)
            F = torch.empty((self.num_patches,self.groups, self.fourier_dim), dtype=proj.dtype, device=x.device)
            F[:,:,0::2] = torch.sin(proj)
            F[:,:,1::2] = torch.cos(proj)
            
    
            # Pass through MLP: Y = GeLU(FW1 + B1)W2 + B2
            Y = self.fc2(self.activation(self.fc1(F)))  # (196, 3, 256)
    
            # Reshape Y to (N, D) = (196, 768)
            PEX = Y.reshape(self.num_patches, self.dim)
            pos = torch.cat([self.cls_embed, PEX], dim=0)
    
            return pos+x
    class ViT_Fourier(nn.Module):
        def __init__(self):
            super().__init__()
            self.model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
            _, self.seq_length, self.hidden_dim = self.model.vit.embeddings.position_embeddings.shape
            self.model.vit.embeddings.position_embeddings = get_zero_pos_embed(self.hidden_dim, self.seq_length)
            num_patches = self.model.vit.embeddings.patch_embeddings.num_patches  # Should be 196 for 224x224 images with 16x16 patches
            dim = self.model.config.hidden_size  # Should be 768
            # Compute Fourier positional encoding
            self.pos_embedding = FourierPositionalEncoding(num_patches=num_patches, dim=dim, fourier_dim=768, hidden_dim=32, groups=1)
            self.patch = self.model.vit.embeddings.patch_embeddings
            self.encoder = self.model.vit.encoder
            self.layernorm = self.model.vit.layernorm
            self.classifier = nn.Linear(in_features=self.hidden_dim, out_features=__NUM_CLASS__, bias=True)
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
        def forward(self, x):
            n = x.shape[0]
            patched = self.patch(x)
            batched_cls_token = self.cls_token.expand(n, -1, -1)
            patched = torch.cat([batched_cls_token, patched], dim=1)
            patched = self.pos_embedding(patched)
            patched = self.encoder(patched)
            patched = self.layernorm(patched.last_hidden_state)
            patched = patched[:, 0]
            patched = self.classifier(patched)
            return patched

    model = ViT_Fourier()

**SinCos2D**

In [21]:
if __MODEL__ == 'SinCos2D' or __MODEL__ == 'pretest':    
    class Sin2DPositionalEncoding(nn.Module):
        def __init__(self, num_patches, dim):
            super().__init__()
            self.num_patches = num_patches
            self.dim = dim
            # Frequency scaling term (same as in Transformers)
            self.register_buffer('div_term', torch.exp(torch.arange(0, dim, 2).float() * (-np.log(10000.0) / dim)))
            grid_size = int(np.sqrt(self.num_patches))  # Typically 14x14 for 196 patches
            x = np.linspace(-1, 1, grid_size)
            y = np.linspace(-1, 1, grid_size)
            xx, yy = np.meshgrid(x, y)
            positions = np.stack([xx.flatten(), yy.flatten()], axis=-1)  # (num_patches, 2)
            self.register_buffer('normalized_positions', torch.tensor(positions, dtype=torch.float32))
            self.register_buffer('cls_embed', torch.zeros((1, self.dim)))
        def forward(self, x):
            positions = self.normalized_positions.unsqueeze(-1)  # (196, 2, 1)
            pe = torch.zeros(self.num_patches, self.dim, device=x.device)  # (196, d)
            pe[:, 0::2] = torch.sin(positions[:, 0] * self.div_term[: self.dim // 2]) + torch.sin(positions[:, 1] * self.div_term[: self.dim // 2])
            pe[:, 1::2] = torch.cos(positions[:, 0] * self.div_term[: self.dim // 2]) + torch.cos(positions[:, 1] * self.div_term[: self.dim // 2])
            pe = torch.cat([self.cls_embed, pe], dim=0).unsqueeze(0)
            return pe+x  # Shape: (196, d)
    class ViT_SinCos2D(nn.Module):
        def __init__(self):
            super().__init__()
            self.model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
            _, self.seq_length, self.hidden_dim = self.model.vit.embeddings.position_embeddings.shape
            self.model.vit.embeddings.position_embeddings = get_zero_pos_embed(self.hidden_dim, self.seq_length)
            num_patches = self.model.vit.embeddings.patch_embeddings.num_patches  # Should be 196 for 224x224 images with 16x16 patches
            dim = self.model.config.hidden_size  # Should be 768
            # Compute Fourier positional encoding
            self.pos_embedding = Sin2DPositionalEncoding(num_patches=num_patches, dim=dim)
            self.patch = self.model.vit.embeddings.patch_embeddings
            self.encoder = self.model.vit.encoder
            self.layernorm = self.model.vit.layernorm
            self.classifier = nn.Linear(in_features=self.hidden_dim, out_features=__NUM_CLASS__, bias=True)
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
        def forward(self, x):
            n = x.shape[0]
            patched = self.patch(x)
            batched_cls_token = self.cls_token.expand(n, -1, -1)
            patched = torch.cat([batched_cls_token, patched], dim=1)
            patched = self.pos_embedding(patched)
            patched = self.encoder(patched)
            patched = self.layernorm(patched.last_hidden_state)
            patched = patched[:, 0]
            patched = self.classifier(patched)
            return patched

    model = ViT_SinCos2D()

**CPEViT**

In [22]:
import torch
import torch.nn as nn
from transformers import ViTModel

class ConditionalPositionalEncoding(nn.Module):
    def __init__(self, embed_dim, num_patches):
        super().__init__()
        self.conv = nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1, groups=embed_dim)

    def forward(self, x):
        B, N, D = x.shape  # (Batch, Num Tokens, Embedding Dim)
        H = W = int((N - 1) ** 0.5)  # Ignore CLS token for reshaping

        cls_token = x[:, 0:1, :]  # Extract CLS token
        x = x[:, 1:, :]
        
        # Reshape for convolution
        cpe = x.transpose(1, 2).reshape(B, D, H, W)  # (B, D, H, W)
        cpe = self.conv(cpe).flatten(2).transpose(1, 2)  # Apply CPE

        x = x + cpe  # Add CPE to the features

        return torch.cat([cls_token, x], dim=1)  # Reinsert CLS token

class CPEViT(nn.Module):
    def __init__(self, model_name="google/vit-base-patch16-224"):
        super().__init__()
        
        self.vit = ViTModel.from_pretrained(model_name)
        self.vit.embeddings.position_embeddings = None  # Remove default positional embeddings

        embed_dim = self.vit.config.hidden_size
        num_patches = (self.vit.config.image_size // self.vit.config.patch_size) ** 2  # Compute patches

        self.encoder1 = self.vit.encoder.layer[0]
        self.cpe = ConditionalPositionalEncoding(embed_dim, num_patches)
        self.remaining_encoders = nn.ModuleList(self.vit.encoder.layer[1:])
        self.fc = nn.Linear(embed_dim, __NUM_CLASS__)

    def forward(self, x):
        x = self.vit.embeddings.patch_embeddings(x)  # (B, 196, D)

        batch_size = x.shape[0]
        cls_tokens = self.vit.embeddings.cls_token.expand(batch_size, -1, -1)  # (B, 1, D)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, 197, D)

        x = self.encoder1(x)[0]  # Extract tensor from tuple

        x = self.cpe(x)  # Apply Conditional Positional Encoding

        for layer in self.remaining_encoders:
            x = layer(x)[0]  # Extract tensor from tuple

        cls_token_final = x[:, 0]  # (B, D) - Extract the class token's final representation
        out = self.fc(cls_token_final)  # (B, num_classes) - Classification output
        
        return out

In [23]:
if __MODEL__ == 'CPEViT' or __MODEL__ == 'pretest':
    model = CPEViT("google/vit-base-patch16-224")

**RPE ViT**

In [24]:
if __MODEL__ == 'RPE' or __MODEL__ == 'pretest':    
    
    # Load Pretrained ViT Model
    model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
    _, seq_length, hidden_dims = model.vit.embeddings.position_embeddings.shape
    model.vit.embeddings.position_embeddings = pe_caller(hidden_dims, seq_length)
    model.classifier = nn.Linear(in_features=768, out_features=__NUM_CLASS__, bias=True)
    
    # Compute grid size dynamically
    grid_size = model.config.image_size // model.config.patch_size  # Example: 224//16 = 14
    
    class GlobalRelativePositionBias(nn.Module):
        def __init__(self, grid_size, num_heads):
            super().__init__()
            self.grid_size = grid_size
            self.num_heads = num_heads
    
            # Define a learnable relative position bias table
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros((2 * grid_size - 1) * (2 * grid_size - 1), num_heads)
            )  # Shape: ((2H-1) * (2W-1), nH)
    
            # Initialize with truncated normal
            nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
    
            # Compute relative position indices
            self.define_relative_position_index()
    
        def define_relative_position_index(self):
            """Compute pairwise relative position indices for the full grid."""
            coords_h = torch.arange(self.grid_size)
            coords_w = torch.arange(self.grid_size)
            coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))  # Shape: (2, H, W)
            #print(coords)
    
            
    
            coords_flatten = torch.flatten(coords, 1)  # Shape: (2, H*W)
            #print(coords_flatten.shape)
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # Shape: (2, H*W, H*W)
    
            #print(coords_flatten[:, None,:].shape)
    
            #print(relative_coords[0,0,:])
    
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Shape: (H*W, H*W, 2)
            relative_coords[:, :, 0] += self.grid_size - 1  # Shift to start from 0
            relative_coords[:, :, 1] += self.grid_size - 1
    
            relative_coords[:, :, 0] *= 2 * self.grid_size - 1
            relative_position_index = relative_coords.sum(-1).flatten()  # Shape: (H*W * H*W)
    
            self.register_buffer("relative_position_index", relative_position_index)
    
        def forward(self):
            """Retrieve relative position bias."""
            return self.relative_position_bias_table[self.relative_position_index].view(
                self.grid_size * self.grid_size, self.grid_size * self.grid_size, self.num_heads
            ).permute(2, 0, 1)  # Shape: (nH, H*W, H*W)
    
    
    class RifatAttention(nn.Module):
        def __init__(self, pretrained_query, pretrained_key, pretrained_value, grid_size, num_heads):
            super().__init__()
            self.query_ = nn.Parameter(pretrained_query.weight)
            self.key_ = nn.Parameter(pretrained_key.weight)
            self.value_ = nn.Parameter(pretrained_value.weight)
    
            self.relative_position_bias_module = GlobalRelativePositionBias(grid_size, num_heads)
    
        def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
            q = torch.matmul(hidden_states, self.query_.T)
            k = torch.matmul(hidden_states, self.key_.T)
            v = torch.matmul(hidden_states, self.value_.T)
    
            print(f'query shape:{q.shape}')
            print(f'key shape:{k.shape}')
    
            attn_scores = torch.matmul(q, k.transpose(-2,-1))  # Standard attention
            print(f'attention score shape:{attn_scores.shape}')
    
            # Add relative positional bias (broadcasted across batches)
            #attn_scores += self.relative_position_bias_module().unsqueeze(0)
            print(f' RPE shape: {self.relative_position_bias_module().shape}')
            print(attn_scores[:, 1: , 1: ].shape)
            attn_scores[:, 1: , 1: ] += self.relative_position_bias_module()
            print(f'attention score after RPE shape:{attn_scores.shape}')
    
    
            # Apply mask if available
            if attention_mask is not None:
                attn_scores = attn_scores + attention_mask  # Hugging Face applies -10000 for padding tokens
    
            attn_probs = torch.softmax(attn_scores, dim=-1)
    
            # Apply head mask if given
            if head_mask is not None:
                attn_probs = attn_probs * head_mask
    
            x = torch.matmul(attn_probs, v)
    
            if output_attentions:
                return x, attn_probs
            return x
    
    for i in range(12):
        for name, module in model.vit.encoder.layer[i].named_modules():
            if 'attention.attention' in name:
                query = model.vit.encoder.layer[i].attention.attention.query
                key = model.vit.encoder.layer[i].attention.attention.key
                value = model.vit.encoder.layer[i].attention.attention.value
                module = RifatAttention(
                query, key, value, grid_size, 1)

**ViT_B_16**

In [25]:
from transformers import ViTConfig, ViTForImageClassification
import torch
import torch.nn as nn

if __MODEL__ == 'vit_b_16' or __MODEL__ == 'pretest':

    config = ViTConfig(
        image_size=__IMG_SIZE__,          # 224
        patch_size=16,
        num_channels=3,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,

        num_labels=__NUM_CLASS__,

        # DeiT-style: usually keep dropout ~0 and use drop-path + strong aug + mixup/cutmix
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
    )

    # stochastic depth (drop-path)
    if hasattr(config, "drop_path_rate"):
        config.drop_path_rate = 0.1
    else:
        # older transformers: safe to skip, training still works
        pass

    # scratch init
    model = ViTForImageClassification(config)

    # replace position embeddings (robust to tensor vs Parameter)
    _, seq_length, hidden_dims = model.vit.embeddings.position_embeddings.shape
    new_pe = pe_caller(hidden_dims, seq_length)

    if isinstance(new_pe, torch.Tensor) and not isinstance(new_pe, nn.Parameter):
        new_pe = nn.Parameter(new_pe)

    model.vit.embeddings.position_embeddings = new_pe

    # classifier is already correct because num_labels was set
    # model.classifier = nn.Linear(768, __NUM_CLASS__, bias=True)  # not needed


**DEIT**

In [26]:
if __MODEL__ == 'deit_b_16' or __MODEL__ == 'pretest':
    model = ViTForImageClassification.from_pretrained('facebook/deit-base-patch16-224')
    _, seq_length, hidden_dim = model.vit.embeddings.position_embeddings.shape
    model.vit.embeddings.position_embeddings = pe_caller(hidden_dim, seq_length)
    model.classifier = nn.Linear(in_features=768, out_features=__NUM_CLASS__, bias=True)

**BEIT**

In [27]:
if __MODEL__ == 'beit_b_16' or __MODEL__=='pretest':
    model = BeitForImageClassification.from_pretrained('microsoft/beit-base-patch16-224')

**Cross-ViT**

In [28]:
if __MODEL__ == 'crossvit_base_240' or __MODEL__ == 'pretest' :
    model = timm.create_model( 'crossvit_base_240.in1k', pretrained = True)
    
    _, seq_length, hidden_dim = model.pos_embed_0.shape
    model.pos_embed_0 = pe_caller(hidden_dim, seq_length)
    
    _, seq_length, hidden_dim = model.pos_embed_1.shape
    model.pos_embed_1 = pe_caller(hidden_dim, seq_length)
    
    model.head = nn.ModuleList([
            nn.Linear(in_features=head.in_features, out_features=__NUM_CLASS__, bias=True) 
            for head in model.head])

**CaiT**

In [29]:
if __MODEL__ == 'cait_s24_224' or __MODEL__ == 'pretest':
    model = timm.create_model('cait_s24_224.fb_dist_in1k', pretrained=True)

    # Get model's embedding dimensions
    _, seq_length, hidden_dim = model.pos_embed.shape

    # Replace positional embeddings
    temp = pe_caller(hidden_dim, seq_length+1)
    model.pos_embed = nn.Parameter(temp[:,1:,:],requires_grad=temp.requires_grad)
    model.cls_token = nn.Parameter(temp[:,:1,:],requires_grad=temp.requires_grad)

    # Modify classifier layer
    model.head = nn.Linear(in_features=hidden_dim, out_features=__NUM_CLASS__, bias=True)
    

**Cross RPE (Useless)**

In [30]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ViTModel
from transformers.models.vit.modeling_vit import ViTSelfAttention

if __MODEL__ == 'Cross-RPE' or __MODEL__ == 'pretest':
    class CrossMethodRPE(nn.Module):
        def __init__(self, image_size, embed_dim, alpha=1.0, beta=10.0, gamma=100.0):
            super(CrossMethodRPE, self).__init__()
    
            self.image_size = image_size  # (Height, Width)
            self.height, self.width = image_size
            self.embed_dim = embed_dim
    
            # Learnable scalars for horizontal and vertical directions
            self.px = nn.Parameter(torch.randn(embed_dim, 1, 1))  # Horizontal
            self.py = nn.Parameter(torch.randn(embed_dim, 1, 1))  # Vertical
    
            # Parameters for piecewise function g(x)
            self.alpha = alpha
            self.beta = beta
            self.gamma = gamma
    
        def g(self, dist):
            alpha_tensor = torch.tensor(self.alpha, device=dist.device)
            gamma_tensor = torch.tensor(self.gamma, device=dist.device)
    
            dist = dist.abs()

            # Case 1: |x| ≤ α
            less_than_alpha = dist <= alpha_tensor
            result = torch.round(dist) * less_than_alpha.float()
    
            # Case 2: |x| > α
            greater_than_alpha = dist > alpha_tensor
            sign_dist = torch.sign(dist)
    
            log_scaled = alpha_tensor + (torch.log(dist / alpha_tensor) / torch.log(gamma_tensor / alpha_tensor)) * (self.beta - alpha_tensor)
            result += sign_dist * torch.min(self.beta * torch.ones_like(log_scaled), log_scaled) * greater_than_alpha.float()
    
            return result
    
        def forward(self):
            """Computes relative positional encoding b_ij"""
            grid_x, grid_y = torch.meshgrid(torch.arange(self.height), torch.arange(self.width), indexing="ij")
            grid_x = grid_x.to(torch.float32).unsqueeze(0).unsqueeze(0)
            grid_y = grid_y.to(torch.float32).unsqueeze(0).unsqueeze(0)
    
            Ix = self.g(grid_x - grid_x.transpose(-1, -2)) * self.px
            Iy = self.g(grid_y - grid_y.transpose(-1, -2)) * self.py
    
            encoding = Ix + Iy  # Final RPE b_ij
            return encoding
    
    
    class ViTSelfAttentionWithRPE(ViTSelfAttention):
        def __init__(self, config, rpe_module):
            super().__init__(config)
            self.rpe_module = rpe_module  # Inject RPE module

        def forward(self, hidden_states, head_mask=None, output_attentions=False):
            batch_size, seq_len, hidden_dim = hidden_states.shape
            num_heads = self.num_attention_heads
            head_dim = hidden_dim // num_heads
    
            # Project queries, keys, values
            query_layer = self.query(hidden_states).view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
            key_layer = self.key(hidden_states).view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
            value_layer = self.value(hidden_states).view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
    
            # Compute attention scores
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) / (head_dim ** 0.5)
    
            # Add relative positional encoding b_ij
            rpe = self.rpe_module().to(hidden_states.device)  # Compute b_ij
            rpe = rpe.unsqueeze(0).unsqueeze(0)  # Shape: [1, 1, height, width]
            attention_scores = attention_scores + rpe
    
            # Softmax and apply attention
            attention_probs = F.softmax(attention_scores, dim=-1)
            attention_probs = self.attention_dropout(attention_probs)
    
            # Compute context
            context_layer = torch.matmul(attention_probs, value_layer)
            context_layer = context_layer.transpose(1, 2).reshape(batch_size, seq_len, hidden_dim)
    
            # Apply output projection
            context_layer = self.dense(context_layer)
    
            # Return values in the expected format
            outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
            return outputs
    
    
    model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
    
    # Extract model configuration
    hidden_size = model.config.hidden_size  # dz
    image_size = (14, 14)  # ViT input size after patch embedding (224/16 = 14)
    rpe_module = CrossMethodRPE(image_size=image_size, embed_dim=hidden_size)
    
    # Replace the self-attention module in all transformer blocks
    for layer in model.encoder.layer:
        layer.attention.attention = ViTSelfAttentionWithRPE(model.config, rpe_module)
    
    print("Modified ViT with Cross Method RPE is ready!")

# **Trainer**

In [31]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [32]:
import gc
import torch
import torch.nn as nn

def Trainer(train_loader, model, optimizer, epoch, device, grad_accum_steps=1, max_grad_norm=1.0, label_smoothing=0.0):
    model.train()

    grad_accum_steps = max(1, int(grad_accum_steps))
    current_lr = optimizer.param_groups[0]["lr"]
    print(f"Epoch [{epoch}] - Learning Rate: {current_lr:.6f}")

    losses = AverageMeter()
    bce_crit = nn.BCEWithLogitsLoss()
    ce_crit = nn.CrossEntropyLoss(label_smoothing=label_smoothing) if label_smoothing > 0 else nn.CrossEntropyLoss()

    optimizer.zero_grad(set_to_none=True)
    accum_count = 0
    micro_steps = 0
    total_batches = len(train_loader)

    for idx, (images, labels) in enumerate(train_loader):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        out = model(images)
        logits = out.logits if hasattr(out, "logits") else out

        if __DATASET__ == "new_dataset":
            loss = bce_crit(logits, labels.float())
        else:
            targets = labels.argmax(dim=1).long() if labels.ndim > 1 else labels.long()
            loss = ce_crit(logits, targets)

        losses.update(loss.detach().item(), labels.size(0))

        accum_count += 1
        micro_steps += 1

        # scale by actual number of micro-batches in this update
        if (idx == total_batches - 1) and (accum_count != grad_accum_steps):
            scale = accum_count
        else:
            scale = grad_accum_steps

        (loss / scale).backward()

        do_step = (accum_count == grad_accum_steps) or (idx == total_batches - 1)
        if do_step:
            if max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            accum_count = 0

        if micro_steps % __STEP_COUNTER__ == 0:
            print(f"micro-step: {micro_steps} | avg loss: {losses.avg:.6f}")

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return losses.avg


# **Validator**

In [33]:
def validator_for_new(model, valid_loader, device):
    model.eval()
    total = 0
    dist_comp = 0
    orientation = 0
    area_comp = 0
    vec_sum = 0
    with torch.no_grad():
        for images, labels in valid_loader:
            images = images.to(device)
            outputs = model(images)
            try:
                outputs = nn.Sigmoid()(outputs.logits).detach().cpu().numpy()
            except:
                outputs = nn.Sigmoid()(outputs).detach().cpu().numpy()
            outputs = np.where(outputs>0.5, 1, 0)
            labels = np.where(labels>0.5, 1, 0)
            dist_comp += (outputs[:,:2]==labels[:,:2]).sum()
            orientation += (outputs[:,2]==labels[:,2]).sum()
            area_comp += (outputs[:,3:5]==labels[:,3:5]).sum()
            vec_sum += (outputs[:,5]==labels[:,5]).sum()
            total += labels.shape[0]
    dist_comp_acc = 100*dist_comp/(total*2)
    orientation_acc = 100*orientation/(total)
    area_comp_acc = 100*area_comp/(2*total)
    vec_sum_acc = 100*vec_sum/total
    avg_accuracy = (dist_comp_acc + orientation_acc + area_comp_acc + vec_sum_acc)/4
    print(f'Distance comparison accuracy: {dist_comp_acc:.4f}%')
    print(f'Orientation accuracy: {orientation_acc:.4f}%')
    print(f'Area comparison accuracy: {area_comp_acc:.4f}%')
    print(f'Vector sum accuracy: {vec_sum_acc:.4f}%')
    print(f'Average Accuracy: {avg_accuracy:.4f}%')
    return avg_accuracy


In [34]:
import torch

def validator_regular(model, valid_loader, device):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in valid_loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            out = model(images)
            logits = out if torch.is_tensor(out) else out.logits  # [B, C]

            preds = logits.argmax(dim=1)  # [B]

            # support both one-hot labels [B, C] and index labels [B]
            if labels.ndim > 1:
                targets = labels.argmax(dim=1)
            else:
                targets = labels.long()

            correct += (preds == targets).sum().item()
            total += targets.size(0)

    accuracy = 100.0 * correct / max(1, total)
    print(f'Validation/Test Accuracy: {accuracy:.4f}%')
    return accuracy

In [35]:
if __DATASET__ == 'new_dataset':
    validator = validator_for_new
else:
    validator = validator_regular

# **Prepare Hardware**

In [36]:
import torch
import torch.optim as optim

# ----------------------------
# device
# ----------------------------
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print("DEVICE WE WILL BE USING IS ", device)
model = model.to(device)

# ----------------------------
# DeiT-style: no weight decay on biases/norm + (recommended) cls/pos tokens
# ----------------------------
def param_groups_weight_decay(model, weight_decay: float = 0.05):
    decay, no_decay = [], []
    skip_wd_names = ("position_embeddings", "pos_embed", "cls_token", "dist_token")
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if (
            p.ndim == 1
            or name.endswith(".bias")
            or ("norm" in name.lower())
            or any(k in name for k in skip_wd_names)
        ):
            no_decay.append(p)
        else:
            decay.append(p)

    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

# ----------------------------
# global batch (make sure these match your actual training loop)
# ----------------------------
# Prefer using your real variables:
# global_batch = __BATCH_SIZE_TRAIN__ * __GRAD_ACCUM_STEPS__ * world_size
__BATCH_SIZE__ = 64
__GRAD_ACCUM_STEPS__ = 16
world_size = 1
global_batch = __BATCH_SIZE__ * __GRAD_ACCUM_STEPS__ * world_size

# ----------------------------
# DeiT-style LR scaling
# defaults: lr=5e-4, min_lr=1e-5, warmup_lr=1e-6 at global batch=1024
# ----------------------------
scale = global_batch / 1024.0
__MAX_LR__    = 5e-4 * scale
__MIN_LR__    = 1e-5 * scale
__WARMUP_LR__ = 1e-6 * scale

# Safety: never let min_lr exceed max_lr
__MIN_LR__ = min(__MIN_LR__, __MAX_LR__)

# ----------------------------
# optimizer: AdamW (DeiT default wd=0.05, eps=1e-8)
# ----------------------------
optimizer = optim.AdamW(
    param_groups_weight_decay(model, weight_decay=0.05),
    lr=__MAX_LR__,
    betas=(0.9, 0.999),
    eps=1e-8
)

# ----------------------------
# scheduler: warmup -> cosine (STRICT epoch-based stepping)
# call scheduler.step() once per epoch
# ----------------------------
warmup_epochs = 5
total_epochs = __EPOCHS__

start_factor = (__WARMUP_LR__ / __MAX_LR__) if __MAX_LR__ > 0 else 1.0
start_factor = max(1e-8, min(1.0, start_factor))

warmup = optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=start_factor,
    end_factor=1.0,
    total_iters=warmup_epochs
)

cosine = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=max(1, total_epochs - warmup_epochs),
    eta_min=__MIN_LR__
)

scheduler = optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[warmup, cosine],
    milestones=[warmup_epochs]
)


DEVICE WE WILL BE USING IS  cuda:0


# **Main Loop**

In [37]:
# DataLoaders (recommended small fixes: valid/test shuffle=False, workers>0 if possible)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=__BATCH_SIZE_TRAIN__,
    shuffle=True,
    num_workers=0,
    pin_memory=True
)
valid_loader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=__BATCH_SIZE_VALID__,
    shuffle=False,
    num_workers=0,
    pin_memory=True
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=__BATCH_SIZE_VALID__,
    shuffle=False,
    num_workers=0,
    pin_memory=True
)

In [None]:
best_acc = 0.0
best_epoch = START_EPOCH

train_losses = []
val_accs = []

for epoch in range(START_EPOCH, __EPOCHS__ + 1):
    start = time.time()

    train_loss = Trainer(
        train_loader=train_loader,
        model=model,
        optimizer=optimizer,
        epoch=epoch,
        device=device,
        grad_accum_steps=__GRAD_ACCUM_STEPS__,
        max_grad_norm=1.0,     # good default for ViT scratch
        # label_smoothing=0.1   # enable only if your Trainer supports it and before mixup/cutmix
    )

    print(f"train loss for epoch {epoch}: {train_loss}")
    train_losses.append(train_loss)

    print("validating.....")
    val_acc = validator(model, valid_loader, device)
    val_accs.append(val_acc)

    if val_acc > best_acc:
        best_acc = val_acc
        best_epoch = epoch
        print(f"saving best (epoch={epoch}, val_acc={val_acc:.4f}).....")
        torch.save(model.state_dict(), BASE_PATH)

        test_acc = validator(model, test_loader, device)
        print(f"test acc (best checkpoint): {test_acc:.4f}")

    torch.save(model.state_dict(), LAST_PATH)

    # STRICT epoch-based scheduler step (exactly once per epoch)
    scheduler.step()

    end = time.time()
    print(f"best so far: epoch={best_epoch}, val_acc={best_acc:.4f}")
    print(f"time elapsed (sec): {(end - start):.2f}")


Epoch [1] - Learning Rate: 0.000001
micro-step: 50 | avg loss: 4.748254
micro-step: 100 | avg loss: 4.728645
micro-step: 150 | avg loss: 4.718540
micro-step: 200 | avg loss: 4.702933
micro-step: 250 | avg loss: 4.688892
micro-step: 300 | avg loss: 4.677930
micro-step: 350 | avg loss: 4.668380
micro-step: 400 | avg loss: 4.659218
micro-step: 450 | avg loss: 4.650622
micro-step: 500 | avg loss: 4.642492
micro-step: 550 | avg loss: 4.634812
micro-step: 600 | avg loss: 4.628419
micro-step: 650 | avg loss: 4.621867
micro-step: 700 | avg loss: 4.616164
micro-step: 750 | avg loss: 4.610176
micro-step: 800 | avg loss: 4.605120
micro-step: 850 | avg loss: 4.600066
micro-step: 900 | avg loss: 4.595452
micro-step: 950 | avg loss: 4.591213
micro-step: 1000 | avg loss: 4.586459
micro-step: 1050 | avg loss: 4.581948
train loss for epoch 1: 4.579237996614896
validating.....
Validation/Test Accuracy: 3.6308%
saving best (epoch=1, val_acc=3.6308).....
Validation/Test Accuracy: 3.3077%
test acc (best ch

In [None]:
img = valid_dataset[3][0].unsqueeze(0).to(device)
img.shape

In [None]:
pos = model.pos_embedding(img).squeeze().detach().cpu().numpy()[1:].reshape(14,14,768)
final = pos

In [None]:
indices = np.arccos(final[:,:,767])

In [None]:
from numpy.linalg import norm
cen = (0,0)
similarity = ((final * final[cen]).sum(-1))/(norm(final, axis=-1)*norm(final[cen]))

In [None]:
plt.imshow(similarity, cmap='hot')
#plt.colorbar()
plt.title('14x14 Array Scaled to 224x224',)
plt.show()

In [None]:
def radial_smooth(array, center=None, bin_width=1.0):
    nrows, ncols = array.shape
    if center is None: center = (nrows / 2, ncols / 2)
    a, b = center
    y, x = np.indices((nrows, ncols))
    r = np.sqrt((y - a)**2 + (x - b)**2)
    bin_indices = np.floor(r / bin_width).astype(int)
    smoothed = np.zeros_like(array)
    for bin_val in np.unique(bin_indices):
        mask = (bin_indices == bin_val)
        smoothed[mask] = array[mask].mean()
    return smoothed

smoothed_array = radial_smooth(similarity, center=(0,0), bin_width=1.0)

plt.subplot(1, 2, 1)
plt.imshow(similarity, cmap='viridis')
plt.subplot(1, 2, 2)
plt.imshow(smoothed_array, cmap='viridis')
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import zoom

def radial_smooth(array, center=None, bin_width=1.0):
    nrows, ncols = array.shape
    if center is None: center = (nrows / 2, ncols / 2)
    a, b = center
    y, x = np.indices((nrows, ncols))
    r = np.sqrt((y - a)**2 + (x - b)**2)
    bin_indices = np.floor(r / bin_width).astype(int)
    values = np.zeros(np.unique(bin_indices).shape[0])
    smoothed = np.zeros_like(array)
    for bin_val in np.unique(bin_indices):
        mask = (bin_indices == bin_val)
        values[bin_val] = array[mask].std()
        smoothed[mask] = array[mask].mean()
    return smoothed, values

array = similarity
#array_224 = zoom(array, (224 / 14, 224 / 14), order=1)
smoothed_array, values = radial_smooth(array, center=(0,0), bin_width=1.0)

plt.figure(figsize=(12,12))
plt.subplot(1, 2, 1)
plt.imshow(array, cmap='hot')
plt.subplot(1, 2, 2)
plt.imshow(smoothed_array, cmap='hot')
plt.show()


In [None]:
def running_mean(x, N):
    cumsum = np.cumsum(np.insert(x, 0, 0)) 
    return (cumsum[N:] - cumsum[:-N]) / float(N)
plt.plot(running_mean(values,3))
plt.show()

# **Test Results**

In [None]:
#model.load_state_dict(torch.load(BASE_PATH, weights_only=True))
test_acc = validator(model, test_loader, device)

In [None]:
img = x.squeeze().cpu().permute(1, 2, 0).numpy()

fig, ax = plt.subplots()
ax.imshow(img)
cell_size = 224 // 14

for i in range(15):
    ax.axhline(i * cell_size, color='white', linewidth=0.5)
    ax.axvline(i * cell_size, color='white', linewidth=0.5)

ax.axis('off')
plt.show()

In [None]:
np.set_printoptions(linewidth=np.inf)

x = valid_dataset[3][0].unsqueeze(0).to(device)
cord = model.pos_embedding.cord
art_conv = model.pos_embedding.art_conv
out = art_conv(torch.cat([x, cord.unsqueeze(0).expand(1, -1, -1, -1)], dim=1)).squeeze().detach().cpu().numpy()
out = 2*out-1
wrap_pos_hilbert = model.pos_embedding.wrap_pos_hilbert.squeeze().detach().cpu().numpy()
np.set_printoptions(precision=2, suppress=True)
print(out.reshape(14,14))
out = out + wrap_pos_hilbert
print(wrap_pos_hilbert.reshape(14,14))
print(out.reshape(14,14))

# **Loss & Val_Acc**

In [None]:
x = np.arange(0, val_accs.__len__(), 1)  # x-axis values  
plt.figure(figsize=(8, 5))  # Optional: Set the figure size  
plt.plot(x, train_losses, label="train_loss", color="blue", linestyle="-")  # First line  
plt.plot(x, val_accs, label="val_accuracy", color="red", linestyle="--")  # Second line  
plt.legend()  # Show the legend  
plt.grid(True)  # Optional: Add a grid  
plt.show()  

In [None]:
p = np.arange(196).reshape((14,14))
#p = gilbert_2d(14,14)
cen = (6,6)
p = p-p[cen]
p = (p-p.min())
p = p/p.max()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import zoom

def radial_smooth(array, center=None, bin_width=1.0):
    nrows, ncols = array.shape
    if center is None: center = (nrows / 2, ncols / 2)
    a, b = center
    y, x = np.indices((nrows, ncols))
    r = np.sqrt((y - a)**2 + (x - b)**2)
    bin_indices = np.floor(r / bin_width).astype(int)
    values = np.zeros(np.unique(bin_indices).shape[0])
    smoothed = np.zeros_like(array)
    for bin_val in np.unique(bin_indices):
        mask = (bin_indices == bin_val)
        values[bin_val] = array[mask].mean()
        smoothed[mask] = array[mask].mean()
    return smoothed, values

array = p
array_224 = zoom(array, (224 / 14, 224 / 14), order=1)
smoothed_array, values = radial_smooth(array_224, center=(cen[0]*16+8, cen[1]*16+8), bin_width=1.0)

plt.subplot(1, 2, 1)
plt.imshow(array_224, cmap='viridis')
plt.subplot(1, 2, 2)
plt.imshow(smoothed_array, cmap='viridis')
plt.show()

In [None]:
plt.plot(values)
plt.show()

In [None]:
values