In [1]:
import torch
from transformers import ViTForImageClassification, ViTImageProcessor
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, RandomHorizontalFlip, RandomRotation
from tqdm import tqdm 

train_dir='/kaggle/input/brain-tumor-mixture/Dataset2/train'
test_dir='/kaggle/input/brain-tumor-mixture/Dataset2/test'
val_dir='/kaggle/input/brain-tumor-mixture/Dataset2/val'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "google/vit-base-patch16-224"

classes=2


In [2]:
processor = ViTImageProcessor.from_pretrained(model_name)
transform = Compose([
    Resize((224, 224)),                              # Resize images
    ToTensor(),                                      # Convert image to tensor
    Normalize(mean=processor.image_mean, std=processor.image_std),  # Normalize with pretrained stats
])

# Load dataset
train_dataset = ImageFolder(root=train_dir, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataset = ImageFolder(root=val_dir, transform=transform)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)


preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

In [3]:
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=classes,                
    ignore_mismatched_sizes=True        
)

# Optimizer and Loss Function
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

model.to(device)

config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe

In [4]:
def train(num_epochs):
    correct=0
    total=0
    for epoch in tqdm(range(num_epochs)):  # Train for 10 epochs
        model.train()
        for batch in train_loader:
            images, labels = batch
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(images).logits
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            predictions = torch.argmax(outputs, dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    
        print(f"Epoch {epoch + 1}: Train Accuracy = {correct / total:.2f}")   
        validation(epoch)
        
    
def validation(epoch):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in val_loader:
            images, labels = batch
            images, labels = images.to(device), labels.to(device)
            outputs = model(images).logits
            predictions = torch.argmax(outputs, dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    
    print(f"Epoch {epoch + 1}: Validation Accuracy = {correct / total:.2f}")


In [5]:
train(num_epochs=3)

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 1: Train Accuracy = 0.96


 33%|███▎      | 1/3 [03:29<06:58, 209.24s/it]

Epoch 1: Validation Accuracy = 0.99
Epoch 2: Train Accuracy = 0.98


 67%|██████▋   | 2/3 [06:55<03:27, 207.73s/it]

Epoch 2: Validation Accuracy = 1.00
Epoch 3: Train Accuracy = 0.98


100%|██████████| 3/3 [10:21<00:00, 207.29s/it]

Epoch 3: Validation Accuracy = 0.93



