In [7]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from transformers import ViTForImageClassification, ViTConfig

from sklearn.metrics import accuracy_score, confusion_matrix
from tqdm import tqdm

import numpy as np
import random
from collections import Counter
from pathlib import Path

In [8]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)

SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

if DEVICE == "cuda":
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

Using device: cuda


In [9]:
BATCH_SIZE = 16
EPOCHS = 5
LR = 5e-6
NUM_CLASSES = 2

MODEL_NAME = "google/vit-base-patch16-224"

TRAIN_DIR = "archive/Dataset/Train"
VALID_DIR = "archive/Dataset/Validation"
TEST_DIR  = "archive/Dataset/Test"

MODEL_OUTPUT = Path("model/pretrained_vit_model_finetuned.pt")
MODEL_OUTPUT.parent.mkdir(parents=True, exist_ok=True)

In [10]:
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(0.2, 0.2),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

In [11]:
train_dataset = datasets.ImageFolder(TRAIN_DIR, train_transform)
valid_dataset = datasets.ImageFolder(VALID_DIR, val_transform)
test_dataset  = datasets.ImageFolder(TEST_DIR,  val_transform)

print("Class mapping:", train_dataset.class_to_idx)

pin = DEVICE == "cuda"

train_loader = DataLoader(
    train_dataset, BATCH_SIZE,
    shuffle=True, num_workers=2, pin_memory=pin
)

valid_loader = DataLoader(
    valid_dataset, BATCH_SIZE,
    shuffle=False, num_workers=2, pin_memory=pin
)

test_loader = DataLoader(
    test_dataset, BATCH_SIZE,
    shuffle=False, num_workers=2, pin_memory=pin
)

Class mapping: {'Fake': 0, 'Real': 1}


In [12]:
label_counts = Counter([label for _, label in train_dataset.samples])
total = sum(label_counts.values())

weights = [
    total / label_counts[train_dataset.class_to_idx["Fake"]],
    total / label_counts[train_dataset.class_to_idx["Real"]],
]

class_weights = torch.tensor(weights, dtype=torch.float32).to(DEVICE)
print("Class weights:", class_weights)

Class weights: tensor([2., 2.], device='cuda:0')


In [13]:
config = ViTConfig.from_pretrained(MODEL_NAME)
config.num_labels = NUM_CLASSES

model = ViTForImageClassification.from_pretrained(
    MODEL_NAME,
    config=config,
    ignore_mismatched_sizes=True
).to(DEVICE)

# Freeze full backbone
for param in model.parameters():
    param.requires_grad = False

# Unfreeze last 3 transformer blocks
for param in model.vit.encoder.layer[-3:].parameters():
    param.requires_grad = True

# Unfreeze classifier
for param in model.classifier.parameters():
    param.requires_grad = True

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [14]:
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=LR
)

criterion = nn.CrossEntropyLoss(
    weight=class_weights,
    label_smoothing=0.1
)

scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))

  scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))


In [9]:
best_val_acc = 0.0
best_epoch = 0

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0

    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")

    for images, labels in loop:
        images = images.to(DEVICE, non_blocking=True)
        labels = labels.to(DEVICE, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
            logits = model(images).logits
            loss = criterion(logits, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
        loop.set_postfix(loss=loss.item())

    # ---------- Validation ----------
    model.eval()
    preds, trues = [], []

    with torch.no_grad():
        for images, labels in valid_loader:
            images = images.to(DEVICE, non_blocking=True)
            labels = labels.to(DEVICE, non_blocking=True)

            logits = model(images).logits
            preds.extend(logits.argmax(dim=1).cpu().numpy())
            trues.extend(labels.cpu().numpy())

    val_acc = accuracy_score(trues, preds)
    print(f"Epoch {epoch+1} | Val Acc: {val_acc:.4f}")

    # ---------- Save BEST model ----------
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch + 1

        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "architecture": MODEL_NAME,
                "num_classes": NUM_CLASSES,
                "class_mapping": train_dataset.class_to_idx,
                "epoch": best_epoch,
                "val_accuracy": best_val_acc,
            },
            MODEL_OUTPUT
        )

        print(f"Saved BEST model @ epoch {best_epoch}")

  with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
Epoch 1/5: 100%|██████████| 6250/6250 [16:42<00:00,  6.23it/s, loss=0.425]


Epoch 1 | Val Acc: 0.9338
Saved BEST model @ epoch 1


  with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
Epoch 2/5: 100%|██████████| 6250/6250 [11:15<00:00,  9.25it/s, loss=0.316]


Epoch 2 | Val Acc: 0.9563
Saved BEST model @ epoch 2


  with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
Epoch 3/5: 100%|██████████| 6250/6250 [14:53<00:00,  6.99it/s, loss=0.224]


Epoch 3 | Val Acc: 0.9673
Saved BEST model @ epoch 3


  with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
Epoch 4/5: 100%|██████████| 6250/6250 [18:00<00:00,  5.78it/s, loss=0.218]


Epoch 4 | Val Acc: 0.9718
Saved BEST model @ epoch 4


  with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
Epoch 5/5:  17%|█▋        | 1032/6250 [02:02<10:19,  8.42it/s, loss=0.243]


KeyboardInterrupt: 

In [10]:
checkpoint = torch.load(MODEL_OUTPUT, map_location=DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

preds, trues = [], []

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(DEVICE, non_blocking=True)
        logits = model(images).logits
        preds.extend(logits.argmax(dim=1).cpu().numpy())
        trues.extend(labels.numpy())

print("Test Accuracy:", accuracy_score(trues, preds))
print("Confusion Matrix:")
print(confusion_matrix(trues, preds))

  checkpoint = torch.load(MODEL_OUTPUT, map_location=DEVICE)


Test Accuracy: 0.9718
Confusion Matrix:
[[9876  124]
 [ 440 9560]]
