In [22]:
from torchvision.datasets.mnist import FashionMNIST
import torchvision.transforms as transforms
from torch.utils.data.dataset import Subset
import torch
from train_dataset import dataset_mean, dataset_std

val_dataset = FashionMNIST(
    "./data",
    train=False,
    transform=transforms.Compose(
        [
            # transforms.Resize(224),
            # transforms.Grayscale(num_output_channels=3),
            transforms.ToTensor(),
            transforms.Normalize(dataset_mean, dataset_std),
        ]
    ),
    download=True,
)

head_val, head_train = torch.utils.data.random_split(val_dataset, [0.4, 0.6])
# head_train = Subset(head_train, indices=torch.randperm(len(head_train))[:1000])

In [23]:
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
from tqdm import tqdm
import timm
from model import DinoViT, DinoHead

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

head_train_loader = DataLoader(head_train, batch_size=256, num_workers=16)

ViT = DinoViT(device)
ViT.load_state_dict(torch.load("DINO_ViT_simple_head.pth"))

# ViT = timm.create_model("vit_base_patch16_224", pretrained=True).to(device)
ViT.eval()
# linear_head = DinoHead(1000, 10).to(device)
linear_head = nn.Linear(64, 10).to(device)
CE_loss = nn.CrossEntropyLoss()
head_optimizer = torch.optim.Adam(linear_head.parameters(), lr=1e-3)
epochs = 20


for epoch in range(epochs):
    linear_head.train()
    total_loss = 0
    for imgs, labels in tqdm(head_train_loader):
        imgs, labels = imgs.to(device), labels.to(device)
        head_optimizer.zero_grad()
        with torch.no_grad():
            feats = ViT(imgs)
        logits = linear_head(feats)
        loss = CE_loss(logits, labels)
        prediction = torch.argmax(logits)
        loss.backward()
        head_optimizer.step()
        total_loss += loss.item()

    avg = total_loss / len(head_train_loader)
    print(f"Epoch {epoch+1}/{epochs} - Loss: {avg:4f}")

# torch.save(linear_head, 'linear_head.pth')

100%|██████████| 24/24 [00:00<00:00, 32.64it/s]


Epoch 1/20 - Loss: 2.371667


100%|██████████| 24/24 [00:00<00:00, 41.66it/s]


Epoch 2/20 - Loss: 1.907686


100%|██████████| 24/24 [00:00<00:00, 41.43it/s]


Epoch 3/20 - Loss: 1.589299


100%|██████████| 24/24 [00:00<00:00, 40.49it/s]


Epoch 4/20 - Loss: 1.369138


100%|██████████| 24/24 [00:00<00:00, 42.83it/s]


Epoch 5/20 - Loss: 1.214314


100%|██████████| 24/24 [00:00<00:00, 42.00it/s]


Epoch 6/20 - Loss: 1.103464


100%|██████████| 24/24 [00:00<00:00, 43.40it/s]


Epoch 7/20 - Loss: 1.022107


100%|██████████| 24/24 [00:00<00:00, 42.96it/s]


Epoch 8/20 - Loss: 0.960680


100%|██████████| 24/24 [00:00<00:00, 41.28it/s]


Epoch 9/20 - Loss: 0.912987


100%|██████████| 24/24 [00:00<00:00, 38.27it/s]


Epoch 10/20 - Loss: 0.874998


100%|██████████| 24/24 [00:00<00:00, 40.82it/s]


Epoch 11/20 - Loss: 0.844046


100%|██████████| 24/24 [00:00<00:00, 42.53it/s]


Epoch 12/20 - Loss: 0.818323


100%|██████████| 24/24 [00:00<00:00, 40.97it/s]


Epoch 13/20 - Loss: 0.796580


100%|██████████| 24/24 [00:00<00:00, 42.95it/s]


Epoch 14/20 - Loss: 0.777926


100%|██████████| 24/24 [00:00<00:00, 40.27it/s]


Epoch 15/20 - Loss: 0.761718


100%|██████████| 24/24 [00:00<00:00, 42.82it/s]


Epoch 16/20 - Loss: 0.747476


100%|██████████| 24/24 [00:00<00:00, 42.20it/s]


Epoch 17/20 - Loss: 0.734842


100%|██████████| 24/24 [00:00<00:00, 40.95it/s]


Epoch 18/20 - Loss: 0.723537


100%|██████████| 24/24 [00:00<00:00, 41.14it/s]


Epoch 19/20 - Loss: 0.713347


100%|██████████| 24/24 [00:00<00:00, 41.91it/s]

Epoch 20/20 - Loss: 0.704102





In [24]:
head_val_loader = DataLoader(head_train, batch_size=32, num_workers=16)

linear_head.eval()
ViT.eval()
correct, total = 0, 0
with torch.no_grad():
    for imgs, labels in tqdm(head_val_loader):
        imgs, labels = imgs.to(device), labels.to(device)
        feats = ViT(imgs)
        logits = linear_head(feats)
        preds = torch.argmax(logits, dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

accuracy = correct / total
print(f"Test Accuracy: {accuracy * 100:.2f}%")

100%|██████████| 188/188 [00:01<00:00, 155.82it/s]

Test Accuracy: 76.42%



