In [None]:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
import mat73
import os

import sys
sys.path.append("..")

from hsvit.dataset import BrainTumorDataset
from hsvit.model import ViTBackbone
from hsvit.utils import compute_classification_metrics

In [9]:
# Runtime configuration: RNG seeds, compute device, and data locations.
# Seeding torch also fixes the DataLoader shuffle order and the model's
# initial weights, so a Restart & Run All reproduces the same run.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Paths default to the original author's machine; set BRAINTUMOR_DATA_DIR to
# point the notebook at a dataset elsewhere without editing code.
DATA_DIR = os.environ.get("BRAINTUMOR_DATA_DIR", "/Users/darshdave/Documents/BRAINTUMOR/DATASET")
# The trailing "" keeps the trailing separator the original root_dir had,
# in case BrainTumorDataset builds file paths by string concatenation — TODO confirm.
root_dir = os.path.join(DATA_DIR, "FILES", "")
cvind_path = os.path.join(DATA_DIR, "cvind.mat")

In [3]:
# Read the cross-validation fold assignment from the .mat file and turn it
# into index arrays: fold 1 -> training samples, fold 2 -> validation samples.
split_data = mat73.loadmat(cvind_path)
split_labels = split_data['cvind']

cvind_array = np.array(split_labels)
train_ids, val_ids = (np.nonzero(cvind_array == fold_id)[0] for fold_id in (1, 2))

In [None]:
# Build the full dataset once, then expose each fold through an index Subset.
dataset = BrainTumorDataset(root_dir)

train_subset = Subset(dataset, train_ids)
val_subset = Subset(dataset, val_ids)

# Only the training stream is shuffled; validation order stays fixed.
train_loader = DataLoader(train_subset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=8, shuffle=False)

In [5]:
# Model and optimiser.
model = ViTBackbone()
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# One criterion per model output (class logits, bbox predictions, boundary predictions).
# NOTE(review): nn.BCELoss expects inputs already in [0, 1]; this assumes the
# boundary head applies a sigmoid internally — confirm in ViTBackbone.
criterion_cls = nn.CrossEntropyLoss()
criterion_bbox = nn.MSELoss()
criterion_mask = nn.BCELoss()

In [None]:
# Train for EPOCHS epochs, evaluating validation accuracy after each one and
# checkpointing whenever it improves on the best seen so far.
EPOCHS = 10

# The loaders only yield (images, labels), so there are no real bbox/boundary
# annotations. The previous code trained those heads against freshly sampled
# random noise (torch.randn_like / torch.rand_like), which pushes noise
# gradients into the shared backbone. Those losses are therefore opt-in; flip
# this to True only to reproduce the old behaviour, and replace the
# placeholders with real targets before relying on the auxiliary heads.
USE_PLACEHOLDER_AUX_LOSSES = False
BEST_WEIGHTS_PATH = "/Users/darshdave/Documents/BRAINTUMOR/HSViT/model-weight/hsvit_best.pt"

best_acc = 0.0
for epoch in range(EPOCHS):
    # ---- Training ----
    model.train()
    train_loss = 0.0
    num_batches = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device) - 1  # convert labels to 0-based for CrossEntropyLoss

        class_logits, bbox_preds, boundary_preds = model(images)
        loss = criterion_cls(class_logits, labels)

        if USE_PLACEHOLDER_AUX_LOSSES:
            # Placeholder targets: random noise, NOT annotations.
            # *_like already allocates on the predictions' device.
            bbox_targets = torch.randn_like(bbox_preds)
            mask_targets = torch.rand_like(boundary_preds)
            loss = (
                loss
                + criterion_bbox(bbox_preds, bbox_targets)
                + criterion_mask(boundary_preds, mask_targets)
            )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        num_batches += 1

    # Report a per-batch mean so the number doesn't scale with dataset size.
    mean_train_loss = train_loss / max(num_batches, 1)

    # ---- Validation ----
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device) - 1
            class_logits, _, _ = model(images)
            preds = torch.argmax(class_logits, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    if total == 0:
        raise RuntimeError("Validation set is empty; cannot compute accuracy. Check the fold split.")
    acc = 100.0 * correct / total

    # ---- Checkpoint best model ----
    if acc > best_acc:
        best_acc = acc
        # torch.save does not create missing parent directories.
        os.makedirs(os.path.dirname(BEST_WEIGHTS_PATH), exist_ok=True)
        torch.save(model.state_dict(), BEST_WEIGHTS_PATH)
        print(f"✅ Saved new best model at epoch {epoch+1} with accuracy {acc:.2f}%")

    print(f"[Epoch {epoch+1}/{EPOCHS}] Train Loss (avg/batch): {mean_train_loss:.4f} | Val Accuracy: {acc:.2f}%")

In [None]:
# Persist the weights from the last training epoch (separate from the
# best-validation-accuracy checkpoint).
FINAL_WEIGHTS_PATH = "/Users/darshdave/Documents/BRAINTUMOR/HSViT/model-weight/hsvit_final.pt"
os.makedirs(os.path.dirname(FINAL_WEIGHTS_PATH), exist_ok=True)  # torch.save won't create parent dirs
torch.save(model.state_dict(), FINAL_WEIGHTS_PATH)
print("✅ Saved final model")

✅ Saved final model
