# BRAIN-ViT Model Inference for 3-Class Classification

#### class labels: 

0: CN

1: AD

2: MCI

In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from data.data_loader_3c import NiftiDataset
from sklearn.metrics import precision_score, recall_score, f1_score
from config.BRAIN_ViT_config import get_config
# Fix the RNG seed so any stochastic component of this run is reproducible.
torch.manual_seed(100)

config = get_config()
device = config["device"]


#DEFINE YOUR BASE DIR PATH
base_dir = "----your path----"

# Load the test dataset only; this notebook performs inference, not training.
test_dataset = NiftiDataset("data/metadata.csv", base_dir=base_dir)
# The reported metrics do not depend on sample order, so the loader is left
# unshuffled: the printed labels/predictions then follow the dataset order.
test_loader = DataLoader(test_dataset, batch_size=5, shuffle=False)
print("test_loader:",len(test_loader.dataset))

# NOTE(review): torch.load unpickles arbitrary Python objects -- only load
# model files that come from a trusted source.
# NOTE(review): recent PyTorch releases default torch.load to
# weights_only=True, which rejects fully pickled modules; if that applies to
# your version, pass weights_only=False here (trusted files only).
# map_location='cpu' lets a model that was saved on a GPU be restored on a
# CPU-only machine; the model is moved to `device` further below.
model = torch.load('BRAIN-ViT.pt', map_location='cpu')
print(model)

criterion = nn.CrossEntropyLoss()
checkpoint_path = "checkpoints/BRAIN-ViT_ckp.pth"
if checkpoint_path and os.path.isfile(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # strict=False tolerates key mismatches, so report them instead of
    # silently evaluating a partially restored model.
    load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    if load_result.missing_keys:
        print("Missing keys in checkpoint:", load_result.missing_keys)
    if load_result.unexpected_keys:
        print("Unexpected keys in checkpoint:", load_result.unexpected_keys)
    start_epoch = checkpoint['epoch']
    best_val_accuracy = checkpoint.get('val_accuracy', 0.0)
    print(f"Best validation accuracy: {best_val_accuracy:.4f}")
else:
    # Make it explicit that no checkpoint was applied on top of the model file.
    print(f"Checkpoint not found at '{checkpoint_path}'; using the weights stored in BRAIN-ViT.pt")

# Uncomment for inference on multiple GPUs
# if torch.cuda.device_count() > 1:
#     model = nn.DataParallel(model)
model = model.to(device)

In [None]:

# Function to get the predictions along with other metrics
def evaluate_model(model, config, criterion, data_loader, action):
    """Run `model` over `data_loader` and print classification metrics.

    Prints accuracy (percent), macro-averaged precision, recall and F1, and
    the mean of the per-batch losses, prefixed with the capitalised `action`.

    Args:
        model: Module called as ``model(inputs, config)``; it is switched to
            eval mode and left in that mode.
        config: Mapping passed through to the model. Its ``"device"`` entry
            decides where each batch is moved. Side effect: the
            ``"viz_attn_weights"`` and ``"viz_topk"`` entries are set to
            ``False`` in place and remain so after the call.
        criterion: Loss callable applied as ``criterion(outputs, labels)``.
        data_loader: Iterable yielding ``(inputs, labels)`` tensor batches.
        action: Name of the split used in the printed summary (e.g. "test").

    Returns:
        ``(all_labels, all_preds)``: two Python lists holding the true and
        predicted class indices, in the order the loader produced them.
        Both are empty when the loader yields no samples.
    """
    model.eval()
    # Read the device from the config that was passed in rather than relying
    # on a module-level global being defined.
    device = config["device"]

    running_loss, correct, total = 0.0, 0, 0
    num_batches = 0
    all_labels = []
    all_preds = []

    # Visualisation flags are switched off for the evaluation pass.
    config['viz_attn_weights'] = False
    config['viz_topk'] = False

    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs, config)
            running_loss += criterion(outputs, labels).item()
            num_batches += 1

            # Predicted class = index of the largest score along dim 1.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Collect labels and predictions for metric calculations
            all_labels.extend(labels.cpu().tolist())
            all_preds.extend(predicted.cpu().tolist())

    # An empty loader would otherwise fail with ZeroDivisionError below.
    if total == 0:
        print(f"{action.capitalize()} set is empty; no metrics computed.")
        return all_labels, all_preds

    accuracy = 100 * correct / total
    # Mean of the per-batch losses; batches are counted directly so loaders
    # without __len__ also work.
    total_loss = running_loss / num_batches

    # Macro averaging weights every class equally, regardless of its support.
    precision = precision_score(all_labels, all_preds, average="macro")
    recall = recall_score(all_labels, all_preds, average="macro")
    f1 = f1_score(all_labels, all_preds, average="macro")

    print(f"{action.capitalize()} Accuracy: {accuracy:.2f}%, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}, Loss:{total_loss:.4f}")

    return all_labels, all_preds


In [None]:
# Evaluate on the test loader; the metric summary is printed by
# evaluate_model itself, and the raw labels/predictions are kept for later.
labels, preds = evaluate_model(
    model,
    config,
    criterion,
    test_loader,
    action="test",
)

In [None]:
# Print the ground-truth labels followed by the model's predictions.
for caption, values in (("Labels:     ", labels), ("\nPredictions:", preds)):
    print(caption, list(values))