In [1]:
import torch
from transformers import ViTForImageClassification, ViTImageProcessor
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, RandomHorizontalFlip, RandomVerticalFlip,  RandomRotation
from tqdm import tqdm 
from torch.utils.data import ConcatDataset

train_dir='/kaggle/input/brain-tumor-classification-mri/Training'
test_dir='/kaggle/input/brain-tumor-classification-mri/Testing'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "google/vit-base-patch16-224"

classes=4


In [2]:
processor = ViTImageProcessor.from_pretrained(model_name)
#augmentation
# train_transform = Compose([
#     RandomResizedCrop(224, scale=(0.8, 1.0)),  
#     RandomHorizontalFlip(),                    
#     RandomRotation(degrees=15),                
#     ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Random color jitter
#     ToTensor(),                                
#     Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 
# ])

transform = Compose([
    Resize((224, 224)), 
    ToTensor(),                                      
    Normalize(mean=processor.image_mean, std=processor.image_std),  
])
transform1 = Compose([
    Resize((224, 224)), 
    RandomHorizontalFlip(p=1.0),
    ToTensor(),                                      
    Normalize(mean=processor.image_mean, std=processor.image_std),  
])
transform2 = Compose([
    Resize((224, 224)), 
    RandomVerticalFlip(p=1.0),
    ToTensor(),                                      
    Normalize(mean=processor.image_mean, std=processor.image_std),  
])
transform3 = Compose([
    Resize((224, 224)), 
    RandomHorizontalFlip(p=1.0),
    RandomVerticalFlip(p=1.0),
    ToTensor(),                                      
    Normalize(mean=processor.image_mean, std=processor.image_std),  
])
transform4 = Compose([
    Resize((224, 224)),                 
    RandomRotation(degrees=45), 
    ToTensor(),                                      
    Normalize(mean=processor.image_mean, std=processor.image_std),  
])




# Load dataset
train_dataset1 = ImageFolder(root=train_dir, transform=transform)
train_dataset2 = ImageFolder(root=train_dir, transform=transform1)
train_dataset3 = ImageFolder(root=train_dir, transform=transform2)
train_dataset4 = ImageFolder(root=train_dir, transform=transform3)
train_dataset5 = ImageFolder(root=train_dir, transform=transform4)

train_dataset =  ConcatDataset([train_dataset1, train_dataset2, train_dataset3, train_dataset4, train_dataset5])
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)


preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

In [3]:
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=classes,                
    ignore_mismatched_sizes=True        
)

# Optimizer and Loss Function
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

model.to(device)

config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([4]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([4, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe

In [4]:
def train(num_epochs):
    correct=0
    total=0
    for epoch in range(num_epochs):  # Train for 10 epochs
        model.train()
        for batch in tqdm(train_loader):
            images, labels = batch
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(images).logits
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            predictions = torch.argmax(outputs, dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    
        print(f"Epoch {epoch + 1}: Train Accuracy = {correct / total:.2f}")   
        
        
    

In [5]:
train(num_epochs=4)

100%|██████████| 449/449 [09:20<00:00,  1.25s/it]


Epoch 1: Train Accuracy = 0.95


100%|██████████| 449/449 [09:20<00:00,  1.25s/it]


Epoch 2: Train Accuracy = 0.97


100%|██████████| 449/449 [09:16<00:00,  1.24s/it]


Epoch 3: Train Accuracy = 0.98


100%|██████████| 449/449 [09:21<00:00,  1.25s/it]

Epoch 4: Train Accuracy = 0.98





In [6]:
#test
def test():
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in test_loader:
            images, labels = batch
            images, labels = images.to(device), labels.to(device)
            outputs = model(images).logits
            predictions = torch.argmax(outputs, dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    print(f"Test Accuracy = {correct / total:.2f}")

test_dataset = ImageFolder(root=test_dir, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
test()

Test Accuracy = 0.77
