In [1]:
# Imports
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from transformers import ViTForImageClassification, ViTConfig
from sklearn.metrics import accuracy_score
from tqdm import tqdm
import pickle


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Config
# Device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

BATCH_SIZE = 16          # try 32 if VRAM allows
EPOCHS = 5               # 5–10 good for 140K images
NUM_CLASSES = 2
LR = 2e-5

MODEL_NAME = "google/vit-base-patch16-224"

TRAIN_DIR = "dataset/real-vs-fake/train"
VALID_DIR = "dataset/real-vs-fake/valid"
TEST_DIR  = "dataset/real-vs-fake/test"

MODEL_OUTPUT = "pretrained_vit_model.pkl"

In [3]:
print("CUDA available:", torch.cuda.is_available())

if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
    print("CUDA version used by PyTorch:", torch.version.cuda)

CUDA available: True
GPU: NVIDIA GeForce RTX 3050 6GB Laptop GPU
CUDA version used by PyTorch: 12.1


In [4]:
# Image Transforms
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


In [5]:
# Load datasets (ImageFolder)
train_dataset = datasets.ImageFolder(
    root=TRAIN_DIR,
    transform=train_transform
)

valid_dataset = datasets.ImageFolder(
    root=VALID_DIR,
    transform=val_transform
)

test_dataset = datasets.ImageFolder(
    root=TEST_DIR,
    transform=val_transform
)

print("Class mapping:", train_dataset.class_to_idx)
# Example: {'fake': 0, 'real': 1}


Class mapping: {'fake': 0, 'real': 1}


In [6]:
# DataLoaders
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

valid_loader = DataLoader(
    valid_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)


In [7]:
# Load Pretrained ViT + New Head 
config = ViTConfig.from_pretrained(MODEL_NAME)
config.num_labels = NUM_CLASSES

model = ViTForImageClassification.from_pretrained(
    MODEL_NAME,
    config=config,
    ignore_mismatched_sizes=True
)

model.to(DEVICE)



Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermed

In [None]:
# Freeze Backbone
for param in model.vit.parameters():
    param.requires_grad = False

# Train only classifier head
for param in model.classifier.parameters():
    param.requires_grad = True


In [None]:
# Optimizer, Loss, AMP
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LR
)

criterion = nn.CrossEntropyLoss()

# Mixed Precision (RTX 3050 = faster)
scaler = torch.cuda.amp.GradScaler()


  scaler = torch.cuda.amp.GradScaler()


In [11]:
# trTraining Loop (GPU + AMP)
for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0

    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")

    for images, labels in loop:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad()

        with torch.cuda.amp.autocast():
            outputs = model(images).logits
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
        loop.set_postfix(loss=loss.item())

    avg_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1} | Avg Loss: {avg_loss:.4f}")



  with torch.cuda.amp.autocast():
Epoch 1/5: 100%|██████████| 6250/6250 [06:49<00:00, 15.25it/s, loss=0.326]


Epoch 1 | Avg Loss: 0.3742


Epoch 2/5: 100%|██████████| 6250/6250 [06:46<00:00, 15.38it/s, loss=0.279]


Epoch 2 | Avg Loss: 0.3720


Epoch 3/5: 100%|██████████| 6250/6250 [06:46<00:00, 15.39it/s, loss=0.299] 


Epoch 3 | Avg Loss: 0.3660


Epoch 4/5: 100%|██████████| 6250/6250 [06:47<00:00, 15.34it/s, loss=0.515] 


Epoch 4 | Avg Loss: 0.3568


Epoch 5/5: 100%|██████████| 6250/6250 [06:44<00:00, 15.46it/s, loss=0.317]

Epoch 5 | Avg Loss: 0.3493





In [12]:
# Validation
model.eval()
preds, trues = [], []

with torch.no_grad():
    for images, labels in tqdm(valid_loader, desc="Validating"):
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        outputs = model(images).logits
        preds.extend(outputs.argmax(1).cpu().numpy())
        trues.extend(labels.cpu().numpy())

print("Validation Accuracy:", accuracy_score(trues, preds))


Validating: 100%|██████████| 1250/1250 [04:16<00:00,  4.88it/s]

Validation Accuracy: 0.8538





In [13]:
# Test Evaluation
model.eval()
preds, trues = [], []

with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Testing"):
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        outputs = model(images).logits
        preds.extend(outputs.argmax(1).cpu().numpy())
        trues.extend(labels.cpu().numpy())

print("Test Accuracy:", accuracy_score(trues, preds))


Testing: 100%|██████████| 1250/1250 [04:19<00:00,  4.82it/s]

Test Accuracy: 0.85415





In [14]:
# SAVE MODEL AS .pkl
# Move model to CPU before saving
model_cpu = model.to("cpu")
model_cpu.eval()

with open(MODEL_OUTPUT, "wb") as f:
    pickle.dump(model_cpu, f)

print("Model saved as:", MODEL_OUTPUT)

Model saved as: pretrained_vit_model.pkl


In [15]:
# Load PKL to Verify
with open("pretrained_vit_model.pkl", "rb") as f:
    loaded_model = pickle.load(f)

loaded_model.eval()
print("PKL model loaded successfully")


PKL model loaded successfully
