In [5]:
import torch
import random
import numpy as np
from torch.utils.data import Subset, DataLoader, random_split
from torchvision import datasets, transforms
import timm
from tqdm import tqdm



In [6]:


def get_cifar10_split(fraction=0.1, image_size=224, train_ratio=0.8, *, root="./data"):
    """Draw a random subset of the CIFAR-10 training set and split it in two.

    Both returned parts come from CIFAR-10's *training* split (``train=True``);
    the second part is a held-out evaluation subset, not the official CIFAR-10
    test set.

    Randomness comes from the global ``random`` module (index shuffle) and the
    global torch generator (``random_split``); seed both beforehand for a
    reproducible split.

    Args:
        fraction: Share of the training images to keep, in (0, 1].
        image_size: Side length images are resized to (square, bicubic).
        train_ratio: Share of the subset assigned to the first (training) part,
            in [0, 1]; the remainder forms the second (evaluation) part.
        root: Directory CIFAR-10 is read from / downloaded into.

    Returns:
        ``(train_subset, eval_subset)`` as produced by
        ``torch.utils.data.random_split``.

    Raises:
        ValueError: if ``fraction`` or ``train_ratio`` is out of range. This is
            checked before any download so bad arguments fail fast.
    """
    if not 0 < fraction <= 1:
        raise ValueError(f"fraction must be in (0, 1], got {fraction!r}")
    if not 0 <= train_ratio <= 1:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio!r}")

    transform = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        # Maps ToTensor's [0, 1] channel values to [-1, 1].
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ])
    full_data = datasets.CIFAR10(root=root, train=True, download=True, transform=transform)

    # Shuffle with the global `random` RNG so callers that seed `random`
    # get the same subset every run.
    indices = list(range(len(full_data)))
    random.shuffle(indices)
    subset_len = int(fraction * len(full_data))
    subset = Subset(full_data, indices[:subset_len])

    train_len = int(train_ratio * subset_len)
    return random_split(subset, [train_len, subset_len - train_len])



In [7]:


def _train_one_epoch(model, loader, optimizer, criterion, device, desc):
    """Run one optimisation pass over ``loader``.

    Returns:
        ``(summed_loss, num_correct)``: per-batch loss re-weighted by batch size
        and summed over the loader, and the count of correct top-1 predictions.
    """
    model.train()
    total_loss, correct = 0.0, 0
    for imgs, labels in tqdm(loader, desc=desc):
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # The criterion is built with default (mean) reduction, so re-weight by
        # batch size; dividing the total by the dataset length then gives a true
        # per-sample average even when the last batch is short.
        total_loss += loss.item() * imgs.size(0)
        correct += (outputs.argmax(1) == labels).sum().item()
    return total_loss, correct


def _evaluate(model, loader, device, desc):
    """Return the number of correct top-1 predictions of ``model`` over ``loader``."""
    model.eval()
    correct = 0
    with torch.no_grad():
        for imgs, labels in tqdm(loader, desc=desc):
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            correct += (outputs.argmax(1) == labels).sum().item()
    return correct


def train_vit_on_subset(patch_size=16, image_size=None, epochs=6, *, fraction=0.1,
                        lr=1e-4, train_batch_size=16, eval_batch_size=32,
                        num_workers=8, seed=42, checkpoint_path=None):
    """Fine-tune a pretrained timm ViT-Base on a CIFAR-10 subset and evaluate it.

    The data comes from ``get_cifar10_split``. The model is
    ``vit_base_patch{patch_size}_{image_size}`` with a fresh 10-class head. It is
    trained with AdamW and cross-entropy, its ``state_dict`` is saved, and top-1
    accuracy on the held-out part of the split is reported.

    Args:
        patch_size: ViT patch size; selects the timm model variant.
        image_size: Input resolution. If None: 224 for patch 16, else 384.
        epochs: Number of training epochs (0 skips training and only evaluates).
        fraction: Share of the CIFAR-10 training set to use.
        lr: AdamW learning rate.
        train_batch_size: Batch size for the training loader.
        eval_batch_size: Batch size for the evaluation loader.
        num_workers: DataLoader worker processes for both loaders.
        seed: Seed applied to torch, numpy and ``random`` before anything random happens.
        checkpoint_path: Where to save the trained weights. Defaults to
            ``vit_patch{patch_size}_VIT01.pth``.

    Returns:
        Final test accuracy in percent (float).

    Side effects:
        Prints progress, may download data/weights, and writes ``checkpoint_path``.
    """
    if image_size is None:
        # The timm model name below embeds the resolution, so this must match an
        # existing pretrained variant. Note the two defaults differ in resolution
        # as well as patch size (224/16 -> 14x14 tokens, 384/32 -> 12x12 tokens).
        image_size = 224 if patch_size == 16 else 384

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Order kept as before (split -> loaders -> model) so RNG consumption, and
    # therefore the split and head initialisation, match earlier runs.
    train_set, test_set = get_cifar10_split(fraction=fraction, image_size=image_size)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Pinned host memory only speeds up host->GPU copies; skip it on CPU.
    pin_memory = device.type == "cuda"
    train_loader = DataLoader(train_set, batch_size=train_batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin_memory)
    test_loader = DataLoader(test_set, batch_size=eval_batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=pin_memory)

    model = timm.create_model(f'vit_base_patch{patch_size}_{image_size}', pretrained=True, num_classes=10)
    model = model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(1, epochs + 1):
        print(f"[Patch {patch_size}] Epoch {epoch}")
        total_loss, correct = _train_one_epoch(
            model, train_loader, optimizer, criterion, device,
            desc=f"Epoch {epoch} - Training",
        )
        acc = 100. * correct / len(train_set)
        print(f"[Patch {patch_size}] Epoch {epoch}: Train Loss={total_loss:.2f}, Avg Loss: {total_loss / len(train_set):.4f}, Accuracy={acc:.2f}%")

    if checkpoint_path is None:
        checkpoint_path = f"vit_patch{patch_size}_VIT01.pth"
    torch.save(model.state_dict(), checkpoint_path)

    # Label with `epochs`, not the loop variable: the loop variable is never
    # bound when epochs == 0, which previously raised UnboundLocalError here.
    print(f"[Patch {patch_size}] Epoch {epochs} - Evaluating...")
    correct = _evaluate(model, test_loader, device, desc=f"Epoch {epochs} - Evaluating")

    test_acc = 100. * correct / len(test_set)
    print(f"[Patch {patch_size}] Final Test Accuracy: {test_acc:.2f}%")
    return test_acc



In [8]:


if __name__ == "__main__":
    # Fine-tune and evaluate the patch-16 variant; the callee picks the image size.
    print(f"🔍 Training with Patch Size {16}")
    train_vit_on_subset(16)


🔍 Training with Patch Size 16
Files already downloaded and verified
[Patch 16] Epoch 1


Epoch 1 - Training: 100%|██████████| 250/250 [02:53<00:00,  1.44it/s]


[Patch 16] Epoch 1: Train Loss=1939.43, Avg Loss: 0.4849, Accuracy=84.10%
[Patch 16] Epoch 2


Epoch 2 - Training: 100%|██████████| 250/250 [02:54<00:00,  1.44it/s]


[Patch 16] Epoch 2: Train Loss=963.32, Avg Loss: 0.2408, Accuracy=92.25%
[Patch 16] Epoch 3


Epoch 3 - Training: 100%|██████████| 250/250 [02:53<00:00,  1.44it/s]


[Patch 16] Epoch 3: Train Loss=702.00, Avg Loss: 0.1755, Accuracy=94.45%
[Patch 16] Epoch 4


Epoch 4 - Training: 100%|██████████| 250/250 [02:55<00:00,  1.43it/s]


[Patch 16] Epoch 4: Train Loss=527.64, Avg Loss: 0.1319, Accuracy=95.65%
[Patch 16] Epoch 5


Epoch 5 - Training: 100%|██████████| 250/250 [02:53<00:00,  1.44it/s]


[Patch 16] Epoch 5: Train Loss=595.18, Avg Loss: 0.1488, Accuracy=95.42%
[Patch 16] Epoch 6


Epoch 6 - Training: 100%|██████████| 250/250 [02:54<00:00,  1.44it/s]


[Patch 16] Epoch 6: Train Loss=435.82, Avg Loss: 0.1090, Accuracy=96.50%
[Patch 16] Epoch 6 - Evaluating...


Epoch 6 - Evaluating: 100%|██████████| 32/32 [00:37<00:00,  1.17s/it]

[Patch 16] Final Test Accuracy: 84.00%





In [10]:

if __name__ == "__main__":
    # Fine-tune and evaluate the patch-32 variant; the callee picks the image size.
    print(f"\n🔍 Training with Patch Size {32}")
    train_vit_on_subset(32)



🔍 Training with Patch Size 32
Files already downloaded and verified


model.safetensors:   3%|2         | 10.5M/364M [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


[Patch 32] Epoch 1


Epoch 1 - Training: 100%|██████████| 250/250 [02:18<00:00,  1.81it/s]


[Patch 32] Epoch 1: Train Loss=2159.95, Avg Loss: 0.5400, Accuracy=84.12%
[Patch 32] Epoch 2


Epoch 2 - Training: 100%|██████████| 250/250 [02:14<00:00,  1.86it/s]


[Patch 32] Epoch 2: Train Loss=1008.92, Avg Loss: 0.2522, Accuracy=91.97%
[Patch 32] Epoch 3


Epoch 3 - Training: 100%|██████████| 250/250 [02:14<00:00,  1.85it/s]


[Patch 32] Epoch 3: Train Loss=625.02, Avg Loss: 0.1563, Accuracy=94.85%
[Patch 32] Epoch 4


Epoch 4 - Training: 100%|██████████| 250/250 [02:14<00:00,  1.86it/s]


[Patch 32] Epoch 4: Train Loss=703.85, Avg Loss: 0.1760, Accuracy=94.08%
[Patch 32] Epoch 5


Epoch 5 - Training: 100%|██████████| 250/250 [02:14<00:00,  1.86it/s]


[Patch 32] Epoch 5: Train Loss=648.64, Avg Loss: 0.1622, Accuracy=94.83%
[Patch 32] Epoch 6


Epoch 6 - Training: 100%|██████████| 250/250 [02:15<00:00,  1.85it/s]


[Patch 32] Epoch 6: Train Loss=544.27, Avg Loss: 0.1361, Accuracy=95.55%
[Patch 32] Epoch 6 - Evaluating...


Epoch 6 - Evaluating: 100%|██████████| 32/32 [00:35<00:00,  1.11s/it]

[Patch 32] Final Test Accuracy: 81.50%



