In [1]:
!pip install -r /mnt/c/Users/santhosh/Downloads/MedMamba-main/MedMamba-main/requirements.txt


Defaulting to user installation because normal site-packages is not writeable


In [14]:
!git clone https://github.com/YubiaoYue/MedMamba.git

Cloning into 'MedMamba'...
remote: Enumerating objects: 127, done.[K
remote: Counting objects: 100% (123/123), done.[K
remote: Compressing objects: 100% (94/94), done.[K
remote: Total 127 (delta 53), reused 90 (delta 28), pack-reused 4 (from 1)[K
Receiving objects: 100% (127/127), 353.24 KiB | 2.44 MiB/s, done.
Resolving deltas: 100% (53/53), done.


In [19]:
import os
os.chdir('/mnt/c/Users/santhosh/Downloads/MedMamba-main/MedMamba-main/MedMambagit') 

In [20]:
import sys

sys.path.append('/mnt/c/Users/santhosh/Downloads/MedMamba-main/MedMamba-main/MedMambagit')

In [21]:
import MedMamba

In [19]:
import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score
import numpy as np 
import importlib
import MedMamba

importlib.reload(MedMamba)


def calculate_metrics(y_true, y_pred):
    precision = precision_score(y_true, y_pred, average='weighted')
    recall = recall_score(y_true, y_pred, average='weighted')
    f1 = f1_score(y_true, y_pred, average='weighted')
    return precision, recall, f1


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    train_dataset = datasets.ImageFolder(root="/mnt/c/Users/santhosh/Downloads/MedMamba-main/MedMamba-main/colored_images",
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root="/mnt/c/Users/santhosh/Downloads/MedMamba-main/MedMamba-main/colored_images",
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    num_classes = len(cla_dict)
    model_name = "Medmamba"

    net = MedMamba.VSSM(num_classes=num_classes)
    net.to(device)

    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0001)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    epochs = 100
    best_acc = 0.0
    save_path = './{}Net.pth'.format(model_name)
    train_steps = len(train_loader)
    patience = 5
    early_stop_counter = 0

    for epoch in range(epochs):
        # Train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)

        # Validate
        net.eval()
        acc = 0.0  # accurate number / epoch
        val_loss = 0.0
        all_labels = []
        all_preds = []

        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                loss = loss_function(outputs, val_labels.to(device))
                val_loss += loss.item()
                
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

                all_labels.extend(val_labels.cpu().numpy())
                all_preds.extend(predict_y.cpu().numpy())

        val_accurate = acc / val_num
        val_loss /= len(validate_loader)
        precision, recall, f1 = calculate_metrics(np.array(all_labels), np.array(all_preds))

        print('[epoch %d] train_loss: %.3f  val_loss: %.3f  val_accuracy: %.3f  precision: %.3f  recall: %.3f  f1: %.3f' %
              (epoch + 1, running_loss / train_steps, val_loss, val_accurate, precision, recall, f1))

        scheduler.step()

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
            early_stop_counter = 0
        else:
            early_stop_counter += 1

        if early_stop_counter >= patience:
            print("Early stopping triggered. Stopping training.")
            break

    print('Finished Training')


if __name__ == '__main__':
    main()


torch.Size([1, 6])
using cuda:0 device.
Using 8 dataloader workers every process
using 3656 images for training, 3656 images for validation.
train epoch[1/100] loss:1.883: 100%|██████████| 115/115 [02:14<00:00,  1.17s/it]
100%|██████████| 115/115 [00:26<00:00,  4.29it/s]
[epoch 1] train_loss: 1.309  val_loss: 1.158  val_accuracy: 0.520  precision: 0.408  recall: 0.520  f1: 0.388


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


train epoch[2/100] loss:1.703: 100%|██████████| 115/115 [02:12<00:00,  1.15s/it]
100%|██████████| 115/115 [00:26<00:00,  4.33it/s]
[epoch 2] train_loss: 1.204  val_loss: 0.999  val_accuracy: 0.647  precision: 0.519  recall: 0.647  f1: 0.567


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


train epoch[3/100] loss:0.743: 100%|██████████| 115/115 [02:12<00:00,  1.15s/it]
100%|██████████| 115/115 [00:26<00:00,  4.35it/s]
[epoch 3] train_loss: 1.116  val_loss: 0.954  val_accuracy: 0.675  precision: 0.520  recall: 0.675  f1: 0.587


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


train epoch[4/100] loss:0.968: 100%|██████████| 115/115 [02:10<00:00,  1.13s/it]
100%|██████████| 115/115 [00:26<00:00,  4.28it/s]
[epoch 4] train_loss: 1.085  val_loss: 0.860  val_accuracy: 0.685  precision: 0.667  recall: 0.685  f1: 0.611


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


train epoch[5/100] loss:2.730: 100%|██████████| 115/115 [02:11<00:00,  1.15s/it]
100%|██████████| 115/115 [00:26<00:00,  4.27it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



[epoch 5] train_loss: 1.040  val_loss: 0.878  val_accuracy: 0.675  precision: 0.590  recall: 0.675  f1: 0.609
train epoch[6/100] loss:0.973: 100%|██████████| 115/115 [02:12<00:00,  1.15s/it]
100%|██████████| 115/115 [00:26<00:00,  4.31it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



[epoch 6] train_loss: 1.043  val_loss: 1.054  val_accuracy: 0.643  precision: 0.535  recall: 0.643  f1: 0.582
train epoch[7/100] loss:0.629: 100%|██████████| 115/115 [02:10<00:00,  1.13s/it]
100%|██████████| 115/115 [00:28<00:00,  3.99it/s]
[epoch 7] train_loss: 1.011  val_loss: 0.863  val_accuracy: 0.714  precision: 0.584  recall: 0.714  f1: 0.634


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


train epoch[8/100] loss:0.763: 100%|██████████| 115/115 [02:10<00:00,  1.13s/it]
100%|██████████| 115/115 [00:26<00:00,  4.28it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



[epoch 8] train_loss: 1.003  val_loss: 0.845  val_accuracy: 0.696  precision: 0.614  recall: 0.696  f1: 0.616
train epoch[9/100] loss:0.481: 100%|██████████| 115/115 [02:11<00:00,  1.15s/it]
100%|██████████| 115/115 [00:26<00:00,  4.29it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



[epoch 9] train_loss: 0.988  val_loss: 0.870  val_accuracy: 0.711  precision: 0.579  recall: 0.711  f1: 0.631
train epoch[10/100] loss:1.221: 100%|██████████| 115/115 [02:12<00:00,  1.16s/it]
100%|██████████| 115/115 [00:26<00:00,  4.30it/s]
[epoch 10] train_loss: 0.969  val_loss: 0.810  val_accuracy: 0.720  precision: 0.581  recall: 0.720  f1: 0.637


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


train epoch[11/100] loss:1.054: 100%|██████████| 115/115 [02:10<00:00,  1.13s/it]
100%|██████████| 115/115 [00:26<00:00,  4.31it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



[epoch 11] train_loss: 0.943  val_loss: 0.839  val_accuracy: 0.719  precision: 0.580  recall: 0.719  f1: 0.636
train epoch[12/100] loss:0.884: 100%|██████████| 115/115 [02:12<00:00,  1.15s/it]
100%|██████████| 115/115 [00:26<00:00,  4.29it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



[epoch 12] train_loss: 0.921  val_loss: 0.825  val_accuracy: 0.719  precision: 0.580  recall: 0.719  f1: 0.636
train epoch[13/100] loss:1.051: 100%|██████████| 115/115 [02:12<00:00,  1.15s/it]
100%|██████████| 115/115 [00:26<00:00,  4.33it/s]
[epoch 13] train_loss: 0.914  val_loss: 0.839  val_accuracy: 0.721  precision: 0.581  recall: 0.721  f1: 0.637


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


train epoch[14/100] loss:0.595: 100%|██████████| 115/115 [02:09<00:00,  1.13s/it]
100%|██████████| 115/115 [00:29<00:00,  3.93it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



[epoch 14] train_loss: 0.917  val_loss: 0.831  val_accuracy: 0.718  precision: 0.574  recall: 0.718  f1: 0.634
train epoch[15/100] loss:1.326: 100%|██████████| 115/115 [02:09<00:00,  1.13s/it]
100%|██████████| 115/115 [00:26<00:00,  4.30it/s]
[epoch 15] train_loss: 0.909  val_loss: 0.818  val_accuracy: 0.722  precision: 0.584  recall: 0.722  f1: 0.639


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


train epoch[16/100] loss:0.639: 100%|██████████| 115/115 [02:12<00:00,  1.15s/it]
100%|██████████| 115/115 [00:26<00:00,  4.33it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



[epoch 16] train_loss: 0.900  val_loss: 0.819  val_accuracy: 0.721  precision: 0.584  recall: 0.721  f1: 0.639
train epoch[17/100] loss:0.816: 100%|██████████| 115/115 [02:09<00:00,  1.13s/it]
100%|██████████| 115/115 [00:29<00:00,  3.93it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



[epoch 17] train_loss: 0.907  val_loss: 0.830  val_accuracy: 0.721  precision: 0.581  recall: 0.721  f1: 0.637
train epoch[18/100] loss:1.141: 100%|██████████| 115/115 [02:09<00:00,  1.13s/it]
100%|██████████| 115/115 [00:26<00:00,  4.29it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



[epoch 18] train_loss: 0.906  val_loss: 0.821  val_accuracy: 0.721  precision: 0.582  recall: 0.721  f1: 0.638
train epoch[19/100] loss:1.752: 100%|██████████| 115/115 [02:12<00:00,  1.15s/it]
100%|██████████| 115/115 [00:26<00:00,  4.29it/s]
[epoch 19] train_loss: 0.912  val_loss: 0.817  val_accuracy: 0.722  precision: 0.585  recall: 0.722  f1: 0.640


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


train epoch[20/100] loss:1.136: 100%|██████████| 115/115 [02:09<00:00,  1.13s/it]
100%|██████████| 115/115 [00:29<00:00,  3.92it/s]
[epoch 20] train_loss: 0.903  val_loss: 0.806  val_accuracy: 0.723  precision: 0.585  recall: 0.723  f1: 0.640


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


train epoch[21/100] loss:1.116: 100%|██████████| 115/115 [02:09<00:00,  1.13s/it]
100%|██████████| 115/115 [00:26<00:00,  4.30it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



[epoch 21] train_loss: 0.906  val_loss: 0.815  val_accuracy: 0.722  precision: 0.587  recall: 0.722  f1: 0.640
train epoch[22/100] loss:0.851: 100%|██████████| 115/115 [02:12<00:00,  1.15s/it]
100%|██████████| 115/115 [00:26<00:00,  4.33it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



[epoch 22] train_loss: 0.880  val_loss: 0.819  val_accuracy: 0.722  precision: 0.583  recall: 0.722  f1: 0.639
train epoch[23/100] loss:0.745: 100%|██████████| 115/115 [02:09<00:00,  1.13s/it]
100%|██████████| 115/115 [00:29<00:00,  3.94it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



[epoch 23] train_loss: 0.884  val_loss: 0.826  val_accuracy: 0.722  precision: 0.583  recall: 0.722  f1: 0.639
train epoch[24/100] loss:0.889: 100%|██████████| 115/115 [02:09<00:00,  1.13s/it]
100%|██████████| 115/115 [00:26<00:00,  4.31it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



[epoch 24] train_loss: 0.894  val_loss: 0.821  val_accuracy: 0.722  precision: 0.582  recall: 0.722  f1: 0.639
train epoch[25/100] loss:1.380: 100%|██████████| 115/115 [02:12<00:00,  1.15s/it]
100%|██████████| 115/115 [00:26<00:00,  4.31it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



[epoch 25] train_loss: 0.887  val_loss: 0.837  val_accuracy: 0.721  precision: 0.580  recall: 0.721  f1: 0.637
train epoch[26/100] loss:1.842: 100%|██████████| 115/115 [02:09<00:00,  1.13s/it]
100%|██████████| 115/115 [00:29<00:00,  3.93it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



[epoch 26] train_loss: 0.895  val_loss: 0.809  val_accuracy: 0.722  precision: 0.587  recall: 0.722  f1: 0.640
train epoch[27/100] loss:1.232: 100%|██████████| 115/115 [02:09<00:00,  1.13s/it]
100%|██████████| 115/115 [00:26<00:00,  4.31it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



[epoch 27] train_loss: 0.900  val_loss: 0.811  val_accuracy: 0.722  precision: 0.586  recall: 0.722  f1: 0.640
train epoch[28/100] loss:0.737: 100%|██████████| 115/115 [02:12<00:00,  1.15s/it]
100%|██████████| 115/115 [00:26<00:00,  4.30it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



[epoch 28] train_loss: 0.885  val_loss: 0.821  val_accuracy: 0.722  precision: 0.583  recall: 0.722  f1: 0.639
train epoch[29/100] loss:0.723: 100%|██████████| 115/115 [02:09<00:00,  1.13s/it]
100%|██████████| 115/115 [00:29<00:00,  3.92it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



[epoch 29] train_loss: 0.894  val_loss: 0.817  val_accuracy: 0.722  precision: 0.586  recall: 0.722  f1: 0.640
train epoch[30/100] loss:0.902: 100%|██████████| 115/115 [02:09<00:00,  1.13s/it]
100%|██████████| 115/115 [00:26<00:00,  4.31it/s]
[epoch 30] train_loss: 0.878  val_loss: 0.820  val_accuracy: 0.722  precision: 0.583  recall: 0.722  f1: 0.639
Early stopping triggered. Stopping training.
Finished Training


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [None]:
import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm
import importlib
import MedMamba  # import model
importlib.reload(MedMamba)

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    train_dataset = datasets.ImageFolder(root="/mnt/c/Users/santhosh/Downloads/MedMamba-main/MedMamba-main/colored_images",
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root="/mnt/c/Users/santhosh/Downloads/MedMamba-main/MedMamba-main/colored_images",
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))


    net = MedMamba.VSSM(num_classes=5)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0001)

    epochs = 100
    best_acc = 0.0
    save_path = './{}Net.pth'.format("MedmambaNet1")
    train_steps = len(train_loader)
    for epoch in range(epochs):

        

        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        patience=5
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
            early_stop_counter = 0
        else:
            early_stop_counter += 1

        if early_stop_counter >= patience:
            print("Early stopping triggered. Stopping training.")
            break


    print('Finished Training')


if __name__ == '__main__':
    main()


torch.Size([1, 6])
using cuda:0 device.
Using 8 dataloader workers every process
using 3656 images for training, 3656 images for validation.
train epoch[1/100] loss:1.377: 100%|██████████| 115/115 [04:41<00:00,  2.45s/it]
100%|██████████| 115/115 [00:25<00:00,  4.55it/s]
[epoch 1] train_loss: 1.151  val_accuracy: 0.633
train epoch[2/100] loss:1.069: 100%|██████████| 115/115 [04:41<00:00,  2.45s/it]
100%|██████████| 115/115 [00:25<00:00,  4.53it/s]
[epoch 2] train_loss: 1.019  val_accuracy: 0.689
train epoch[3/100] loss:1.707: 100%|██████████| 115/115 [04:39<00:00,  2.43s/it]
100%|██████████| 115/115 [00:25<00:00,  4.53it/s]
[epoch 3] train_loss: 0.976  val_accuracy: 0.711
train epoch[4/100] loss:1.480: 100%|██████████| 115/115 [04:41<00:00,  2.45s/it]
100%|██████████| 115/115 [00:25<00:00,  4.51it/s]
[epoch 4] train_loss: 0.931  val_accuracy: 0.712
train epoch[5/100] loss:0.433: 100%|██████████| 115/115 [04:41<00:00,  2.45s/it]
100%|██████████| 115/115 [00:25<00:00,  4.52it/s]
[epoch 5

In [27]:
import torch
from sklearn.metrics import precision_score, recall_score, f1_score
from torchvision import transforms, datasets
import MedMamba  # Import your model class

# Define device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the saved model
model = MedMamba.VSSM(num_classes=5)  # Replace with your model class
model.load_state_dict(torch.load("/mnt/c/Users/santhosh/Downloads/MedMamba-main/MedMamba-main/MedMambagit/MedmambaNet1Net.pth"))
model.to(device)
model.eval()

from torchvision import transforms, datasets

data_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load test dataset
test_dataset = datasets.ImageFolder(root="/mnt/c/Users/santhosh/Downloads/MedMamba-main/MedMamba-main/colored_images", transform=data_transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

# Get the number of classes
num_classes = len(test_dataset.classes)
class_names = test_dataset.classes


In [28]:
import numpy as np

# Initialize metrics
all_labels = []
all_preds = []
total_loss = 0.0
loss_function = torch.nn.CrossEntropyLoss()

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)

        # Compute loss
        loss = loss_function(outputs, labels)
        total_loss += loss.item()

        # Get predictions
        preds = torch.max(outputs, dim=1)[1]
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Convert to numpy arrays
all_labels = np.array(all_labels)
all_preds = np.array(all_preds)

# Compute accuracy
accuracy = np.mean(all_preds == all_labels)

# Compute precision, recall, F1-score for each class
precision = precision_score(all_labels, all_preds, average=None, labels=range(num_classes))
recall = recall_score(all_labels, all_preds, average=None, labels=range(num_classes))
f1 = f1_score(all_labels, all_preds, average=None, labels=range(num_classes))

# Compute overall metrics (macro-average)
precision_macro = precision_score(all_labels, all_preds, average='macro')
recall_macro = recall_score(all_labels, all_preds, average='macro')
f1_macro = f1_score(all_labels, all_preds, average='macro')

print(f"Test Loss: {total_loss / len(test_loader):.4f}")
print(f"Test Accuracy: {accuracy * 100:.2f}%")
print(f"Precision (per class): {precision}")
print(f"Recall (per class): {recall}")
print(f"F1-score (per class): {f1}")
print(f"Precision (macro): {precision_macro:.4f}")
print(f"Recall (macro): {recall_macro:.4f}")
print(f"F1-score (macro): {f1_macro:.4f}")


Test Loss: 0.4758
Test Accuracy: 83.29%
Precision (per class): [0.66666667 0.72226999 0.95018747 0.72897196 0.67088608]
Recall (per class): [0.60989011 0.84084084 0.98282548 0.52881356 0.2746114 ]
F1-score (per class): [0.63701578 0.77705828 0.96623094 0.6129666  0.38970588]
Precision (macro): 0.7478
Recall (macro): 0.6474
F1-score (macro): 0.6766
