In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset

import sys
sys.path.append("..")
from hsvit.dataset import BrainTumorDataset
from hsvit.model import ViTBackbone

In [2]:
# Seed torch's global RNG so DataLoader shuffling, the random placeholder targets
# in the training loop, and (presumably) model weight init are reproducible
# across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Run on GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset location. Set the BRAINTUMOR_ROOT environment variable to run on another
# machine; falls back to the original author's local path when unset.
root_dir = os.environ.get("BRAINTUMOR_ROOT", "/Users/darshdave/Documents/BRAINTUMOR/DATASET/FILES/")

In [3]:
# Debug subset: only the first few items, for a fast end-to-end smoke test.
N_DEBUG_SAMPLES = 10
BATCH_SIZE = 4

dataset = BrainTumorDataset(root_dir)
# Clamp so a dataset smaller than N_DEBUG_SAMPLES does not produce out-of-range
# indices (Subset would otherwise fail lazily, on first iteration).
test_subset = Subset(dataset, list(range(min(N_DEBUG_SAMPLES, len(dataset)))))
# NOTE: despite its name, this loader is used for both training and evaluation.
test_loader = DataLoader(test_subset, batch_size=BATCH_SIZE, shuffle=True)

In [4]:
# Multi-task model: classification, bbox regression and boundary/mask heads.
model = ViTBackbone()
model = model.to(device)

LEARNING_RATE = 1e-4
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

# One loss per head, summed in the training loop.
criterion_cls = nn.CrossEntropyLoss()   # class logits vs. integer labels
criterion_bbox = nn.MSELoss()           # bbox regression
# BCELoss expects probabilities in [0, 1], not logits; assumes the model's boundary
# head already applies a sigmoid — TODO confirm (else use nn.BCEWithLogitsLoss).
criterion_mask = nn.BCELoss()

In [5]:
# Smoke-test training loop over the small debug subset.
# NOTE: accuracy is measured on the same subset used for training, so it is a
# training-set sanity check, not a generalisation estimate.
EPOCHS = 3
for epoch in range(EPOCHS):
    # ---- training pass ----
    model.train()
    epoch_loss = 0.0  # summed (not averaged) over batches

    for images, labels in test_loader:
        images = images.to(device)
        # Ensure 0-based labels; presumably the dataset emits 1-based class ids — TODO confirm.
        labels = labels.to(device) - 1

        class_logits, bbox_preds, boundary_preds = model(images)

        # Placeholder targets: the loader yields only (images, labels), so the bbox and
        # boundary heads are fitted to fresh random noise every batch. These terms
        # exercise the multi-task plumbing but carry no learning signal.
        # TODO: replace with real annotations. `*_like` already matches device/dtype.
        bbox_targets = torch.randn_like(bbox_preds)
        mask_targets = torch.rand_like(boundary_preds)  # Valid [0, 1] range for BCE

        loss_cls = criterion_cls(class_logits, labels)
        loss_bbox = criterion_bbox(bbox_preds, bbox_targets)
        # BCELoss needs boundary_preds in [0, 1]; assumes the model applies a sigmoid — TODO confirm.
        loss_mask = criterion_mask(boundary_preds, mask_targets)

        loss = loss_cls + loss_bbox + loss_mask

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # ---- evaluation pass ----
    # Both counters are reset once per epoch, right before evaluation.
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device) - 1
            class_logits, _, _ = model(images)
            preds = torch.argmax(class_logits, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    # An empty loader reports NaN accuracy instead of raising ZeroDivisionError.
    acc = 100.0 * correct / total if total else float("nan")
    print(f"[Epoch {epoch+1}/{EPOCHS}] Loss: {epoch_loss:.4f} | Accuracy: {acc:.2f}%")


[Epoch 1/3] Loss: 9.1687 | Accuracy: 70.00%
[Epoch 2/3] Loss: 7.2277 | Accuracy: 70.00%
[Epoch 3/3] Loss: 6.5107 | Accuracy: 70.00%
