In [2]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Subset, DataLoader

# Make the parent directory (which contains the `dsvit` package) importable.
sys.path.append("..")

from dsvit.model import DSViTDetector
from dsvit.dataset import BrainTumorDataset

In [3]:
# Class-id -> display name. The plotting loop looks names up with `pred_class + 1`,
# so keys are 1-based while the model's argmax index is 0-based.
# NOTE(review): confirm this mapping against how dsvit.dataset encodes labels — if this
# is the public figshare brain-tumour release, its README lists 1=meningioma, 2=glioma,
# 3=pituitary, which would swap the last two entries here.
classes = {1: 'Meningioma', 2: 'Pituitary', 3: 'Glioma'}

# Dataset root. Set BRAINTUMOR_DATA_DIR to run this notebook on another machine instead
# of editing the machine-specific default. Kept as a plain str (trailing slash included)
# because how BrainTumorDataset joins root_dir with file names is not visible here.
DATA_DIR = os.environ.get(
    "BRAINTUMOR_DATA_DIR",
    "/Users/darshdave/Documents/BRAINTUMOR/DATASET/FILES/",
)
# How many leading samples to visualise.
N_SAMPLES = 6

dataset = BrainTumorDataset(root_dir=DATA_DIR)
# Cap at the dataset size so a small dataset does not raise IndexError inside the loop.
subset_ids = list(range(min(N_SAMPLES, len(dataset))))
# batch_size=1 is relied on downstream: the inference loop indexes batch element 0.
test_loader = DataLoader(Subset(dataset, subset_ids), batch_size=1, shuffle=False)

In [4]:
# Trained detector weights. Set BRAINTUMOR_WEIGHTS_PATH to use another checkpoint
# location without editing this machine-specific default.
WEIGHTS_PATH = os.environ.get(
    "BRAINTUMOR_WEIGHTS_PATH",
    "/Users/darshdave/Documents/BRAINTUMOR/DSVIT/model-weight/best_dsvit_detector.pth",
)

# Choose the device before loading so the checkpoint can be mapped onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = DSViTDetector()
# Without map_location, torch.load restores tensors onto the device they were saved
# from, so a checkpoint written during a CUDA run fails to load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints, or pass
# weights_only=True if the installed torch version supports it.
state_dict = torch.load(WEIGHTS_PATH, map_location=device)
model.load_state_dict(state_dict)
model = model.to(device)
# Inference mode for layers such as dropout/batch-norm.
model.eval()


In [None]:
# Run the detector on each sampled image and overlay its predicted box and class.
# Assumes batch_size=1 (as built for test_loader): every `[0]` below selects the single
# sample in the batch, and `label.item()` requires a one-element tensor.

# Multiplier applied to the raw bbox output before drawing it on the image.
# NOTE(review): presumably the model regresses (x, y, w, h) in units 10x smaller than
# image pixels — confirm against the target scaling used during training.
BBOX_SCALE = 10

for i, (img_tensor, label) in enumerate(test_loader):
    img_tensor = img_tensor.to(device)
    true_label = label.item()

    with torch.no_grad():
        bbox_pred, class_logits = model(img_tensor)

    # Prediction output: 0-based argmax class and its softmax probability.
    pred_class = torch.argmax(class_logits, dim=1).item()
    conf = F.softmax(class_logits, dim=1)[0, pred_class].item()
    bbox = bbox_pred[0].cpu().numpy()

    # Drop the leading batch dim, then the next dim if it has size 1, for imshow.
    # NOTE(review): assumes a single-channel (1, 1, H, W) input; a 3-channel image would
    # stay (3, H, W) here, which imshow does not accept.
    img = img_tensor.cpu().squeeze(0).squeeze(0).numpy()
    x, y, w, h = bbox * BBOX_SCALE

    print(f"Image shape: {img.shape}")
    print(f"Box (scaled): x={x:.2f}, y={y:.2f}, w={w:.2f}, h={h:.2f}")

    # Phrased as `not (valid)` so NaN coordinates are rejected too: any comparison
    # involving NaN is False, which the old `w <= 0 or ...` form let through.
    if not (w > 0 and h > 0 and x >= 0 and y >= 0):
        print("⚠️ Skipping invalid box.")
        continue

    # Plotting: one explicit figure per sample, closed after display to free memory.
    fig, ax = plt.subplots()
    ax.imshow(img, cmap='gray')
    ax.add_patch(plt.Rectangle((x, y), w, h, edgecolor='yellow', facecolor='none', linewidth=2))
    # `classes` keys are 1-based, the argmax index is 0-based.
    ax.text(x, y - 5, f"Pred: {classes[pred_class + 1]} ({conf:.2f})",
            color='lime', fontsize=10, backgroundcolor='black')
    # Show the raw ground-truth id alongside the prediction for a quick visual check.
    ax.set_title(f"Image {i+1} – Predicted Bounding Box & Class (true label id: {true_label})")
    ax.axis('off')
    plt.show()
    plt.close(fig)