In [1]:
import os
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import timm
import tkinter as tk
from tkinter import filedialog, Label, Button
from PIL import ImageTk
import threading

In [32]:
class ViTClassifier(nn.Module):
    """Vision Transformer classifier built on timm's ``vit_base_patch16_224``.

    The backbone is created without pretrained weights and its classification
    head is replaced by a fresh linear layer with ``num_classes`` outputs.
    """

    def __init__(self, num_classes=2):
        """Create the backbone and install a ``num_classes``-way linear head.

        Args:
            num_classes: Number of output units of the replacement head.
        """
        super().__init__()
        vit = timm.create_model("vit_base_patch16_224", pretrained=False)
        # Size the new head from the stock head's input width so it lines up
        # with the features the backbone produces.
        vit.head = nn.Linear(vit.head.in_features, num_classes)
        self.backbone = vit

    def forward(self, x):
        """Run ``x`` through the backbone and return its output unchanged."""
        logits = self.backbone(x)
        return logits

In [33]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Checkpoint holding the classifier's state dict, resolved relative to the
# notebook's working directory.
MODEL_PATH = "vit_brain_tumor.pth"

model = ViTClassifier()
# torch.load unpickles the file. weights_only=True restricts unpickling to
# tensors and plain containers so a tampered checkpoint cannot execute
# arbitrary code; a state dict needs nothing beyond that.
state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
# Switch to inference mode (disables dropout etc.); as the cell's last
# expression this also displays the module structure.
model.eval()


ViTClassifier(
  (backbone): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (patch_drop): Identity()
    (norm_pre): Identity()
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (norm): Identity()
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
        

In [34]:
# Per-channel mean / std used to normalize the input tensor.
# NOTE(review): these look like the usual ImageNet statistics — confirm they
# match what the checkpoint was trained with.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every image before it is fed to the model:
# resize to 224x224, convert to a tensor, then normalize each channel.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
transform = transforms.Compose(_preprocess_steps)

# Human-readable label for each output index of the classifier.
class_names = ["No Tumor", "Tumor"]

In [35]:
def predict_image(image_path):
    """Classify a single image file with the loaded model.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        A tuple ``(label, prob, img)`` where ``label`` is the entry of
        ``class_names`` for the predicted class, ``prob`` is the softmax
        probability of that class as a float, and ``img`` is the image
        converted to RGB (a PIL image).

    Raises:
        Whatever ``PIL.Image.open`` raises for a missing or unreadable file.
    """
    # Image.open is lazy and can keep the file handle open (notably for
    # multi-frame formats such as TIFF). convert() loads the pixel data into
    # a new image, so the source can be closed as soon as it returns.
    with Image.open(image_path) as source:
        img = source.convert("RGB")

    # Add a leading batch dimension of size 1 and move to the model's device.
    input_tensor = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_tensor)
        pred = output.argmax(dim=1).item()
        prob = torch.softmax(output, dim=1)[0, pred].item()

    return class_names[pred], prob, img

In [36]:
def classify_and_update(file_path):
    """Classify the image at ``file_path`` and show the outcome in the GUI.

    Safe to call from a worker thread: inference runs in the calling thread,
    while every widget change is scheduled onto the Tk event loop with
    ``window.after`` instead of being made directly from this thread.

    If classification fails (e.g. the file is not a readable image), the
    error is shown in the result label instead of leaving the GUI stuck on
    "Classifying...".

    Args:
        file_path: Path of the image file to classify.
    """

    def show_status(text):
        # Hand the label update to the Tk event loop.
        window.after(0, lambda: result_label.config(text=text))

    def show_result(label, prob, img):
        # Runs on the Tk event loop: build the preview and display the result.
        img.thumbnail((250, 250))  # shrinks in place, preserving aspect ratio
        img_tk = ImageTk.PhotoImage(img)
        panel.config(image=img_tk)
        # Keep a reference on the widget; otherwise the PhotoImage is garbage
        # collected and the preview goes blank.
        panel.image = img_tk
        result_label.config(text=f"Prediction: {label}\nConfidence: {prob:.2%}")

    show_status("Classifying...")

    try:
        label, prob, img = predict_image(file_path)
    except Exception as exc:
        # Thread boundary: nothing upstream would report this exception, so
        # surface it to the user rather than letting the thread die silently.
        show_status(f"Error: {exc}")
        return

    window.after(0, show_result, label, prob, img)

In [37]:
def open_image():
    """Let the user pick an image file and classify it on a background thread.

    Opens a file dialog starting in the current working directory. If the
    dialog is dismissed without a selection, nothing happens.
    """
    image_patterns = "*.png *.jpg *.jpeg *.bmp *.tif *.tiff"
    chosen = filedialog.askopenfilename(
        initialdir=os.getcwd(),
        filetypes=[("Image Files", image_patterns)],
    )
    if not chosen:
        return

    # Daemon thread: it will not keep the process alive once the GUI exits.
    worker = threading.Thread(
        target=classify_and_update,
        args=(chosen,),
        daemon=True,
    )
    worker.start()

In [43]:
# --- Main application window -------------------------------------------------
window = tk.Tk()
window.title("Brain Tumor Classifier (ViT)")
window.geometry("400x450")

# --- Widgets (created in the order they appear top-to-bottom) ----------------
title_label = tk.Label(master=window, text="Brain Tumor Classifier", font=("Arial", 16))
select_button = tk.Button(
    master=window,
    text="Select Image",
    font=("Arial", 12),
    command=open_image,
)
panel = tk.Label(master=window)  # image preview, filled in after a classification
result_label = tk.Label(master=window, text="", font=("Arial", 14))

# --- Layout: pack order defines the vertical stacking ------------------------
title_label.pack(pady=10)
select_button.pack(pady=10)
panel.pack()
result_label.pack(pady=10)

# Blocks this cell until the window is closed.
window.mainloop()