<a href="https://colab.research.google.com/github/Tanmay240405/disease-detection/blob/main/ViT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Cell 1: Kaggle setup and dataset download
from google.colab import files

# Upload kaggle.json
files.upload()

# Setup Kaggle API credentials
!mkdir -p ~/.kaggle
!mv kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

# Download Breast Histopathology Images dataset
!kaggle datasets download -d paultimothymooney/breast-histopathology-images

# Unzip into "data" folder
!unzip -q breast-histopathology-images.zip -d data

# Verify folder structure
!ls data


Saving kaggle.json to kaggle.json
Dataset URL: https://www.kaggle.com/datasets/paultimothymooney/breast-histopathology-images
License(s): CC0-1.0
Downloading breast-histopathology-images.zip to /content
 99% 3.06G/3.10G [00:23<00:00, 219MB/s]
100% 3.10G/3.10G [00:23<00:00, 141MB/s]
10253  10301  12872  12930  13613  14305  16554  9041  9259
10254  10302  12873  12931  13616  14306  16555  9043  9260
10255  10303  12875  12932  13617  14321  16568  9044  9261
10256  10304  12876  12933  13666  15471  16569  9073  9262
10257  10305  12877  12934  13687  15472  16570  9075  9265
10258  10306  12878  12935  13688  15473  16895  9076  9266
10259  10307  12879  12947  13689  15510  16896  9077  9267
10260  10308  12880  12948  13691  15512  8863	 9078  9290
10261  12241  12881  12949  13692  15513  8864	 9081  9291
10262  12242  12882  12951  13693  15514  8865	 9083  9319
10264  12626  12883  12954  13694  15515  8867	 9123  9320
10268  12748  12884  12955  13916  15516  8913	 9124  9321
10

In [42]:
import os
import random
from glob import glob

main_path = "data"

class0_files = []
class1_files = []

for root, dirs, files in os.walk(main_path):
    for file in files:
        if file.endswith(".png"):
            if file.endswith("class0.png"):
                class0_files.append((os.path.join(root, file), 0))
            else:
                class1_files.append((os.path.join(root, file), 1))

print(f"Class 0: {len(class0_files)} images")
print(f"Class 1: {len(class1_files)} images")

# Reduce number of images to avoid RAM issues
sample_count = 40000  # change if needed
class0_files = random.sample(class0_files, min(sample_count, len(class0_files)))
class1_files = random.sample(class1_files, min(sample_count, len(class1_files)))

# Combine and shuffle
combine_data = class0_files + class1_files
random.shuffle(combine_data)
print("Total images after balancing:", len(combine_data))

# Train/Val split
from sklearn.model_selection import train_test_split
train_files, val_files = train_test_split(combine_data, test_size=0.2, random_state=42)
print(f"Train: {len(train_files)}, Validation: {len(val_files)}")


Class 0: 397476 images
Class 1: 157572 images
Total images after balancing: 80000
Train: 64000, Validation: 16000


In [50]:
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms

# Transform
transform = transforms.Compose([
    transforms.Resize((128,128)),  # same as your ViT patch size
    transforms.ToTensor(),
])

# Dataset class
class CancerDataset(Dataset):
    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        img_path, label = self.file_list[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

# Create Dataset objects
train_dataset = CancerDataset(train_files, transform=transform)
val_dataset = CancerDataset(val_files, transform=transform)

# DataLoader
batch_size = 2000  # change according to your GPU
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

print(f"Number of training batches: {len(train_loader)}")
print(f"Number of validation batches: {len(val_loader)}")


Number of training batches: 32
Number of validation batches: 8


In [51]:
import torch.nn as nn

img_size = 128
patch_size = 16
num_channels = 3
num_patches = (img_size // patch_size) ** 2
num_heads = 1
embed_dim = 128
mlp_dim = 256
transformer_units = 4

class PatchEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        self.patch_embed = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.patch_embed(x)
        x = x.flatten(2)
        x = x.transpose(1,2)
        return x

class TransformerArchitecture(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.self_attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, embed_dim)
        )

    def forward(self, x):
        residual_1 = x
        attn_output = self.self_attention(self.layer_norm_1(x), self.layer_norm_1(x), self.layer_norm_1(x))[0]
        x = attn_output + residual_1
        residual_2 = x
        x = self.mlp(self.layer_norm_2(x)) + residual_2
        return x

class VisionTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.patch_embedding = PatchEmbedding()
        self.cls_token = nn.Parameter(torch.randn(1,1,embed_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))
        self.transformer_layers = nn.Sequential(*[TransformerArchitecture() for _ in range(transformer_units)])
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, 2)  # 2 classes: tumor / no tumor
        )

    def forward(self, x):
        x = self.patch_embedding(x)
        B = x.size(0)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.transformer_layers(x)
        x = x[:,0]
        x = self.mlp_head(x)
        return x


In [52]:
import torch.optim as optim

optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()


In [54]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VisionTransformer().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss()

num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    correct_epoch = 0
    total_epoch = 0
    print(f"\nEpoch {epoch+1}")

    for batch_idx, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = outputs.argmax(dim=1)
        correct_epoch += (preds == labels).sum().item()
        total_epoch += labels.size(0)

        if batch_idx % 1 == 0:  # print every batch
            batch_acc = 100.0 * (preds == labels).sum().item() / labels.size(0)
            print(f"  Batch {batch_idx+1}: Loss = {loss.item():.4f}, Accuracy = {batch_acc:.2f}%")

    epoch_acc = 100.0 * correct_epoch / total_epoch
    print(f"==> Epoch {epoch+1} Summary: Total Loss = {total_loss:.4f}, Accuracy = {epoch_acc:.2f}%")



Epoch 1
  Batch 1: Loss = 0.7565, Accuracy = 48.25%
  Batch 2: Loss = 0.6883, Accuracy = 49.30%
  Batch 3: Loss = 0.7101, Accuracy = 46.95%
  Batch 4: Loss = 0.7006, Accuracy = 51.00%
  Batch 5: Loss = 0.6939, Accuracy = 51.15%
  Batch 6: Loss = 0.6928, Accuracy = 47.65%
  Batch 7: Loss = 0.6760, Accuracy = 57.45%
  Batch 8: Loss = 0.6786, Accuracy = 61.90%
  Batch 9: Loss = 0.6781, Accuracy = 51.45%
  Batch 10: Loss = 0.6832, Accuracy = 49.50%
  Batch 11: Loss = 0.6802, Accuracy = 49.45%
  Batch 12: Loss = 0.6707, Accuracy = 52.80%
  Batch 13: Loss = 0.6673, Accuracy = 72.05%
  Batch 14: Loss = 0.6616, Accuracy = 69.80%
  Batch 15: Loss = 0.6632, Accuracy = 60.65%
  Batch 16: Loss = 0.6658, Accuracy = 57.00%
  Batch 17: Loss = 0.6599, Accuracy = 57.75%
  Batch 18: Loss = 0.6597, Accuracy = 58.75%
  Batch 19: Loss = 0.6532, Accuracy = 64.25%
  Batch 20: Loss = 0.6478, Accuracy = 71.20%
  Batch 21: Loss = 0.6464, Accuracy = 73.70%
  Batch 22: Loss = 0.6444, Accuracy = 72.35%
  Batch 23

In [55]:
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

val_acc = 100 * correct / total
print(f"Validation Accuracy: {val_acc:.2f}%")


Validation Accuracy: 79.76%
