
# MedMNIST Model Training and Evaluation

This notebook contains the code from `models.py` and `train_and_eval_pytorch.py`.

## models.py
This file defines the ResNet architectures (ResNet-18 and ResNet-50) used for image classification.

## train_and_eval_pytorch.py
This script handles the training and evaluation of the models on the MedMNIST dataset using PyTorch.

Run the cells in order: the first cells define the models and train/evaluate them, and the later cells use LIME to explain the trained model's predictions on test images. Make sure that all necessary dependencies are installed before running the cells.



In [None]:
# models.py content
'''
Adapted from kuangliu/pytorch-cifar .
'''

import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions (used by ResNet-18).

    Args:
        in_planes: channel count of the incoming feature map.
        planes: channel count produced by both convolutions.
        stride: stride of the first convolution and of the projection
            shortcut, when one is needed.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Identity shortcut unless the resolution or channel count changes; in
        # that case a strided 1x1 projection reshapes the input to match.
        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        residual = self.bn2(self.conv2(F.relu(self.bn1(self.conv1(x)))))
        # Out-of-place addition: `x` itself is never modified.
        return F.relu(residual + self.shortcut(x))


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block (used by ResNet-50).

    The block emits ``expansion * planes`` channels; the 3x3 convolution
    carries the stride.

    Args:
        in_planes: channel count of the incoming feature map.
        planes: width of the inner (bottleneck) convolutions.
        stride: stride of the 3x3 convolution and of the projection shortcut.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # Shapes already agree -> identity; otherwise project with a 1x1 conv.
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        out = x
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = F.relu(bn(conv(out)))
        out = self.bn3(self.conv3(out))
        # Out-of-place addition: `x` itself is never modified.
        return F.relu(out + self.shortcut(x))

class ResNet(nn.Module):
    """CIFAR-style ResNet (3x3 stride-1 stem, no max-pool) for small images.

    Args:
        block: residual block class with an ``expansion`` attribute and a
            ``block(in_planes, planes, stride)`` constructor.
        num_blocks: number of blocks in each of the four stages.
        in_channels: channel count of the input images.
        num_classes: length of the output logit vector.
    """

    def __init__(self, block, num_blocks, in_channels=1, num_classes=2):
        super().__init__()
        # Running channel count, advanced by _make_layer as stages are built.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # Four stages; the first block of stages 2-4 halves the resolution.
        # They are registered as layer1..layer4 so checkpoint keys are stable.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for stage_idx, (width, first_stride) in enumerate(stage_specs):
            setattr(self, 'layer{}'.format(stage_idx + 1),
                    self._make_layer(block, width, num_blocks[stage_idx],
                                     stride=first_stride))

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack blocks for one stage; only the first block uses ``stride``.

        A ``num_blocks`` below 1 still yields a single block.
        """
        block_strides = [stride] + [1] * (num_blocks - 1)
        blocks = []
        for block_stride in block_strides:
            blocks.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*blocks)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        pooled = self.avgpool(out)
        return self.linear(pooled.flatten(1))


def ResNet18(in_channels, num_classes):
    """Build a CIFAR-style ResNet-18: BasicBlock stages of depth 2-2-2-2."""
    stage_depths = [2] * 4
    return ResNet(BasicBlock, stage_depths,
                  in_channels=in_channels, num_classes=num_classes)


def ResNet50(in_channels, num_classes):
    """Build a CIFAR-style ResNet-50: Bottleneck stages of depth 3-4-6-3."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths,
                  in_channels=in_channels, num_classes=num_classes)

In [None]:
# train_and_eval_pytorch.py content
import argparse
import os
import time
from collections import OrderedDict
from copy import deepcopy

import medmnist
import numpy as np
import PIL
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
from medmnist import INFO, Evaluator
from models import ResNet18, ResNet50
from tensorboardX import SummaryWriter
from torchvision.models import resnet18, resnet50
from tqdm import trange

### Training

In [None]:



def main(data_flag, output_root, num_epochs, gpu_ids, batch_size, download, model_flag, resize, as_rgb, model_path, run):
    """Train and/or evaluate a ResNet baseline on one MedMNIST 2D dataset.

    Results are written to ``<output_root>/<data_flag>/<timestamp>``: the best
    checkpoint (``best_model.pth``, weights under the ``'net'`` key), a text
    log (``<data_flag>_log.txt``) and TensorBoard event files.

    Args:
        data_flag: key of ``medmnist.INFO``, e.g. ``'organamnist'``.
        output_root: root directory for all outputs.
        num_epochs: training epochs; ``0`` only evaluates ``model_path``.
        gpu_ids: comma-separated GPU ids, e.g. ``'0'``; negative ids are
            dropped, and an empty result selects the CPU.
        batch_size: mini-batch size of every data loader.
        download: forwarded to the MedMNIST dataset constructor.
        model_flag: ``'resnet18'`` or ``'resnet50'``.
        resize: upscale 28x28 images to 224x224 and use torchvision's ResNet.
        as_rgb: load images as 3-channel RGB.
        model_path: optional checkpoint to load (and evaluate) first.
        run: tag forwarded to the MedMNIST evaluator.

    Raises:
        NotImplementedError: if ``model_flag`` is not a supported backbone.
    """
    lr = 0.001
    gamma = 0.1
    # Decay the learning rate at 50% and 75% of training. MultiStepLR compares
    # milestones with the integer epoch counter, so they must be ints: a float
    # such as 0.75 * 10 == 7.5 never matches and that decay is silently skipped.
    # A milestone of 0 is dropped because MultiStepLR would apply it at
    # construction time, before any training.
    milestones = [m for m in (num_epochs // 2, (3 * num_epochs) // 4) if m > 0]

    info = INFO[data_flag]
    task = info['task']
    n_channels = 3 if as_rgb else info['n_channels']
    n_classes = len(info['label'])

    DataClass = getattr(medmnist, info['python_class'])

    # Parse the comma-separated ids, keeping only non-negative ones.
    parsed_ids = [int(token) for token in gpu_ids.split(',')]
    gpu_ids = [gpu_id for gpu_id in parsed_ids if gpu_id >= 0]
    if gpu_ids:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_ids[0])
    # NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been
    # initialised yet in this process; when it does take effect the single
    # visible GPU is 'cuda:0', so an id > 0 below would not exist. Confirm
    # before selecting a GPU other than 0.
    # Fall back to the CPU instead of failing later in `model.to(device)`.
    if gpu_ids and torch.cuda.is_available():
        device = torch.device('cuda:{}'.format(gpu_ids[0]))
    else:
        device = torch.device('cpu')

    output_root = os.path.join(output_root, data_flag, time.strftime("%y%m%d_%H%M%S"))
    os.makedirs(output_root, exist_ok=True)

    print('==> Preparing data...')

    # ToTensor + Normalize(0.5, 0.5) maps pixels to [-1, 1]; with `resize`
    # the images are first upscaled to 224x224 by nearest-neighbour.
    transform_steps = [transforms.ToTensor(), transforms.Normalize(mean=[.5], std=[.5])]
    if resize:
        transform_steps.insert(0, transforms.Resize((224, 224), interpolation=PIL.Image.NEAREST))
    data_transform = transforms.Compose(transform_steps)

    train_dataset = DataClass(split='train', transform=data_transform, download=download, as_rgb=as_rgb)
    val_dataset = DataClass(split='val', transform=data_transform, download=download, as_rgb=as_rgb)
    test_dataset = DataClass(split='test', transform=data_transform, download=download, as_rgb=as_rgb)

    train_loader = data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    # Unshuffled view of the training split, used only for evaluation.
    train_loader_at_eval = data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False)
    val_loader = data.DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

    print('==> Building and training model...')

    # 224x224 inputs use torchvision's ResNet; 28x28 inputs use the
    # CIFAR-style ResNet defined in models.py.
    # NOTE(review): torchvision's ResNet presumably expects 3 input channels,
    # so resize=True likely also needs as_rgb=True for grayscale datasets.
    if model_flag == 'resnet18':
        model = resnet18(pretrained=False, num_classes=n_classes) if resize else ResNet18(in_channels=n_channels, num_classes=n_classes)
    elif model_flag == 'resnet50':
        model = resnet50(pretrained=False, num_classes=n_classes) if resize else ResNet50(in_channels=n_channels, num_classes=n_classes)
    else:
        raise NotImplementedError

    model = model.to(device)

    train_evaluator = medmnist.Evaluator(data_flag, 'train')
    val_evaluator = medmnist.Evaluator(data_flag, 'val')
    test_evaluator = medmnist.Evaluator(data_flag, 'test')

    if task == "multi-label, binary-class":
        criterion = nn.BCEWithLogitsLoss()
    else:
        criterion = nn.CrossEntropyLoss()

    if model_path is not None:
        # Checkpoints written below keep the weights under the 'net' key.
        model.load_state_dict(torch.load(model_path, map_location=device)['net'], strict=True)
        train_metrics = test(model, train_evaluator, train_loader_at_eval, task, criterion, device, run, output_root)
        val_metrics = test(model, val_evaluator, val_loader, task, criterion, device, run, output_root)
        test_metrics = test(model, test_evaluator, test_loader, task, criterion, device, run, output_root)

        print('train  auc: %.5f  acc: %.5f\n' % (train_metrics[1], train_metrics[2]) +
              'val  auc: %.5f  acc: %.5f\n' % (val_metrics[1], val_metrics[2]) +
              'test  auc: %.5f  acc: %.5f\n' % (test_metrics[1], test_metrics[2]))

    if num_epochs == 0:
        return

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)

    logs = ['loss', 'auc', 'acc']
    train_logs = ['train_' + log for log in logs]
    val_logs = ['val_' + log for log in logs]
    test_logs = ['test_' + log for log in logs]
    log_dict = OrderedDict.fromkeys(train_logs + val_logs + test_logs, 0)

    writer = SummaryWriter(log_dir=os.path.join(output_root, 'Tensorboard_Results'))

    best_auc = 0
    best_epoch = 0
    best_model = deepcopy(model)

    # Batch counter shared with `train`, which uses it as the x-axis of the
    # per-batch loss curve and advances it.
    global iteration
    iteration = 0

    for epoch in trange(num_epochs):
        train(model, train_loader, task, criterion, optimizer, device, writer)

        train_metrics = test(model, train_evaluator, train_loader_at_eval, task, criterion, device, run)
        val_metrics = test(model, val_evaluator, val_loader, task, criterion, device, run)
        test_metrics = test(model, test_evaluator, test_loader, task, criterion, device, run)

        scheduler.step()

        # Each metrics list is [loss, auc, acc], matching `logs`.
        for keys, metrics in ((train_logs, train_metrics), (val_logs, val_metrics), (test_logs, test_metrics)):
            log_dict.update(zip(keys, metrics))
        for key, value in log_dict.items():
            writer.add_scalar(key, value, epoch)

        # Model selection on validation AUC.
        cur_auc = val_metrics[1]
        if cur_auc > best_auc:
            best_epoch = epoch
            best_auc = cur_auc
            best_model = deepcopy(model)
            print('cur_best_auc:', best_auc)
            print('cur_best_epoch', best_epoch)

    state = {
        'net': best_model.state_dict(),
    }

    path = os.path.join(output_root, 'best_model.pth')
    torch.save(state, path)

    train_metrics = test(best_model, train_evaluator, train_loader_at_eval, task, criterion, device, run, output_root)
    val_metrics = test(best_model, val_evaluator, val_loader, task, criterion, device, run, output_root)
    test_metrics = test(best_model, test_evaluator, test_loader, task, criterion, device, run, output_root)

    train_log = 'train  auc: %.5f  acc: %.5f\n' % (train_metrics[1], train_metrics[2])
    val_log = 'val  auc: %.5f  acc: %.5f\n' % (val_metrics[1], val_metrics[2])
    test_log = 'test  auc: %.5f  acc: %.5f\n' % (test_metrics[1], test_metrics[2])

    log = '%s\n' % (data_flag) + train_log + val_log + test_log
    print(log)

    with open(os.path.join(output_root, '%s_log.txt' % (data_flag)), 'a') as f:
        f.write(log)

    writer.close()


def train(model, train_loader, task, criterion, optimizer, device, writer):
    """Run one training epoch and return the mean mini-batch loss.

    Each mini-batch loss is also logged to ``writer`` under the tag
    ``'train_loss_logs'``. The step is the module-level ``iteration``
    counter, which this function advances once per batch.
    """
    global iteration
    batch_losses = []

    model.train()
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        logits = model(inputs.to(device))

        if task == 'multi-label, binary-class':
            # BCEWithLogitsLoss consumes float multi-hot targets.
            targets = targets.to(torch.float32).to(device)
        else:
            # CrossEntropyLoss consumes a 1-D tensor of class indices.
            targets = torch.squeeze(targets, 1).long().to(device)
        loss = criterion(logits, targets)

        loss_value = loss.item()
        batch_losses.append(loss_value)
        writer.add_scalar('train_loss_logs', loss_value, iteration)
        iteration += 1

        loss.backward()
        optimizer.step()

    return sum(batch_losses) / len(batch_losses)


def test(model, evaluator, data_loader, task, criterion, device, run, save_folder=None):

    model.eval()
    
    total_loss = []
    y_score = torch.tensor([]).to(device)

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(data_loader):
            outputs = model(inputs.to(device))
            
            if task == 'multi-label, binary-class':
                targets = targets.to(torch.float32).to(device)
                loss = criterion(outputs, targets)
                m = nn.Sigmoid()
                outputs = m(outputs).to(device)
            else:
                targets = torch.squeeze(targets, 1).long().to(device)
                loss = criterion(outputs, targets)
                m = nn.Softmax(dim=1)
                outputs = m(outputs).to(device)
                targets = targets.float().resize_(len(targets), 1)

            total_loss.append(loss.item())
            y_score = torch.cat((y_score, outputs), 0)

        y_score = y_score.detach().cpu().numpy()
        auc, acc = evaluator.evaluate(y_score, save_folder, run)
        
        test_loss = sum(total_loss) / len(total_loss)

        return [test_loss, auc, acc]



    
# Experiment settings for this notebook run, set here instead of via argparse.
# Dataset
data_flag = 'organamnist'
download = False
resize = True
as_rgb = True
# Model
model_flag = 'resnet18'
model_path = None
# Training
num_epochs = 10
batch_size = 64
gpu_ids = '0'
# Outputs
output_root = './output'
run = 'experiment_to_shap_2'

# Train, checkpoint and evaluate with the settings above.
main(
    data_flag=data_flag,
    output_root=output_root,
    num_epochs=num_epochs,
    gpu_ids=gpu_ids,
    batch_size=batch_size,
    download=download,
    model_flag=model_flag,
    resize=resize,
    as_rgb=as_rgb,
    model_path=model_path,
    run=run,
)

### LIME Example ###

In [None]:
import torch
from torchvision import transforms
import medmnist
from models import ResNet18  # 确保与您的模型定义一致
# Rebuild the training-time setup (settings, data loaders, model) in the
# notebook namespace and load the trained checkpoint, so the model can be
# explained with LIME in the cells below. This mirrors the body of `main`;
# every name assigned here is a notebook global.
data_flag = 'organamnist'
output_root = './output'
num_epochs = 10
gpu_ids = '0'
batch_size = 64
download = False
model_flag = 'resnet18'
resize = True
as_rgb = True
# NOTE(review): machine-specific absolute checkpoint path; adjust per machine.
model_path = "C:\\Users\\10618\\Desktop\\experiments-main\\MedMNIST2D\\output\\organamnist\\231216_213642\\best_model.pth"
#model_path = "C:\\Users\\10618\\Desktop\\experiments-main\\weights_organamnist\\resnet18_224_3.pth"
run = 'experiment_to_shap_2'


# Optimiser settings copied from `main`; nothing in this cell trains, and
# they appear unused by the LIME cells below.
lr = 0.001
gamma=0.1
milestones = [0.5 * num_epochs, 0.75 * num_epochs]

# Dataset metadata from medmnist.INFO.
info = INFO[data_flag]
task = info['task']
n_channels = 3 if as_rgb else info['n_channels']
n_classes = len(info['label'])

DataClass = getattr(medmnist, info['python_class'])

# Parse the comma-separated GPU ids, keeping non-negative ones.
# NOTE: `id` shadows the builtin of the same name in the notebook namespace.
str_ids = gpu_ids.split(',')
gpu_ids = []
for str_id in str_ids:
    id = int(str_id)
    if id >= 0:
        gpu_ids.append(id)
if len(gpu_ids) > 0:
    os.environ["CUDA_VISIBLE_DEVICES"]=str(gpu_ids[0])

# NOTE(review): no fallback to the CPU when CUDA is unavailable but gpu_ids is
# non-empty; CUDA_VISIBLE_DEVICES also only applies if CUDA is not yet
# initialised in this kernel.
device = torch.device('cuda:{}'.format(gpu_ids[0])) if gpu_ids else torch.device('cpu') 

# NOTE(review): this creates a new timestamped output directory, although
# nothing in this section appears to write into it.
output_root = os.path.join(output_root, data_flag, time.strftime("%y%m%d_%H%M%S"))
if not os.path.exists(output_root):
    os.makedirs(output_root)

print('==> Preparing data...')

# Same preprocessing as training: optional nearest-neighbour upscale to
# 224x224, then ToTensor + Normalize(0.5, 0.5), which maps pixels to [-1, 1].
if resize:
    data_transform = transforms.Compose(
        [transforms.Resize((224, 224), interpolation=PIL.Image.NEAREST), 
        transforms.ToTensor(),
        transforms.Normalize(mean=[.5], std=[.5])])
else:
    data_transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize(mean=[.5], std=[.5])])

train_dataset = DataClass(split='train', transform=data_transform, download=download, as_rgb=as_rgb)
val_dataset = DataClass(split='val', transform=data_transform, download=download, as_rgb=as_rgb)
test_dataset = DataClass(split='test', transform=data_transform, download=download, as_rgb=as_rgb)


train_loader = data.DataLoader(dataset=train_dataset,
                            batch_size=batch_size,
                            shuffle=True)
train_loader_at_eval = data.DataLoader(dataset=train_dataset,
                            batch_size=batch_size,
                            shuffle=False)
val_loader = data.DataLoader(dataset=val_dataset,
                            batch_size=batch_size,
                            shuffle=False)
test_loader = data.DataLoader(dataset=test_dataset,
                            batch_size=batch_size,
                            shuffle=False)

print('==> Building and training model...')


# resize=True -> torchvision ResNet; otherwise the CIFAR-style ResNet from
# models.py. Must match the architecture the checkpoint was trained with.
if model_flag == 'resnet18':
    model =  resnet18(pretrained=False, num_classes=n_classes) if resize else ResNet18(in_channels=n_channels, num_classes=n_classes)
elif model_flag == 'resnet50':
    model =  resnet50(pretrained=False, num_classes=n_classes) if resize else ResNet50(in_channels=n_channels, num_classes=n_classes)
else:
    raise NotImplementedError

model = model.to(device)

train_evaluator = medmnist.Evaluator(data_flag, 'train')
val_evaluator = medmnist.Evaluator(data_flag, 'val')
test_evaluator = medmnist.Evaluator(data_flag, 'test')

if task == "multi-label, binary-class":
    criterion = nn.BCEWithLogitsLoss()
else:
    criterion = nn.CrossEntropyLoss()

# Load the checkpoint saved by `main` (weights under the 'net' key) and switch
# to inference mode. NOTE: with model_path = None the model stays randomly
# initialised and in training mode.
if model_path is not None:
    model.load_state_dict(torch.load(model_path, map_location=device)['net'], strict=True)
    model.eval()
        
# One normalised test batch on the model's device (the global `images` appears
# unused by the LIME cells below, which work on the raw .npz arrays).
images, _ = next(iter(test_loader))
images = images.to(device)  # move the batch to the same device as the model


In [None]:
from lime import lime_image

# Create a LIME explainer for image classifiers.
explainer = lime_image.LimeImageExplainer()


In [None]:
import numpy as np

# Load the raw (un-normalised uint8) test split straight from the MedMNIST
# .npz cache, so LIME can work on plain images. The path is built from the
# default MedMNIST root (~/.medmnist) and `data_flag`, rather than a
# hard-coded user directory, so it always matches the dataset the model uses.
# The file is opened in a `with` block so its handle is closed after reading.
# The result is deliberately NOT bound to `data`: that global is the
# `torch.utils.data` module, and shadowing it breaks `data.DataLoader` when
# an earlier cell is re-run.
npz_path = os.path.join(os.path.expanduser('~'), '.medmnist', '{}.npz'.format(data_flag))
with np.load(npz_path) as npz_file:
    X_test = npz_file['test_images']  # test-split images
    y_test = npz_file['test_labels']  # test-split labels

# Sanity-check the loaded arrays.
print(X_test.shape, y_test.shape)


In [None]:
# Inspect the raw test labels.
print(y_test)

In [None]:
import random

# Pick a few random test samples to explain.
# NOTE: `random` is not seeded here, so a different subset is drawn on each run.
num_samples_for_lime = 5  # number of samples
selected_indices = random.sample(range(len(X_test)), num_samples_for_lime)
samples_for_lime = X_test[selected_indices]


In [None]:
import cv2

# Preprocessing for LIME.
def preprocess_image(img):
    """Turn a raw MedMNIST image into a 224x224 RGB float32 image in [0, 1].

    The upscale uses nearest-neighbour interpolation to match the training
    transform (``transforms.Resize((224, 224), interpolation=PIL.Image.NEAREST)``).
    cv2's default bilinear filter would produce smoother images than the
    model ever saw. Mean/std normalisation is not applied here, so the result
    can be displayed directly.

    Args:
        img: 2-D grayscale or 3-D (H, W, 3) image array.

    Returns:
        float32 array of shape (224, 224, 3).
    """
    img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_NEAREST)
    if len(img.shape) == 2:  # grayscale -> 3 channels, as the model expects
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    img = img.astype('float32') / 255  # scale to [0, 1]
    return img

# Preprocess the selected samples.
preprocessed_samples = np.array([preprocess_image(img) for img in samples_for_lime])


In [None]:
from lime import lime_image
from lime.wrappers.scikit_image import SegmentationAlgorithm

# Create the LIME image explainer.
explainer = lime_image.LimeImageExplainer()

# Pick a single sample to explain.
sample = preprocessed_samples[0]  # the first preprocessed sample



In [None]:
def model_predict(images, mean=0.5, std=0.5):
    """LIME classifier function: batch of HxWxC images -> class probabilities.

    Uses the notebook-global ``model``.

    Args:
        images: array of shape (N, H, W, C) with values in [0, 1], i.e. the
            output of the preprocessing step and LIME's perturbations of it.
        mean, std: normalisation applied before the forward pass. The defaults
            reproduce ``transforms.Normalize(mean=[.5], std=[.5])`` from the
            training pipeline, which maps [0, 1] to [-1, 1]. Without this step
            the model is evaluated on inputs outside its training distribution.

    Returns:
        numpy array of shape (N, num_classes) with softmax probabilities.
    """
    model.eval()
    with torch.no_grad():
        # NHWC (numpy/LIME layout) -> NCHW (PyTorch layout).
        batch = torch.tensor(images.transpose((0, 3, 1, 2))).float()
        batch = (batch - mean) / std

        # Run on whatever device the model already lives on, instead of moving
        # the model on every LIME call.
        model_device = next(model.parameters()).device
        logits = model(batch.to(model_device))

        probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy()
    return probs


In [None]:
from lime.wrappers.scikit_image import SegmentationAlgorithm

# Segmentation algorithm (SLIC, about 100 segments) that defines the
# superpixels LIME switches on and off when explaining an image.
segmenter = SegmentationAlgorithm('slic', n_segments=100, compactness=1, sigma=1)

In [None]:


# Generate a LIME explanation for the selected sample: top 5 labels,
# 1000 perturbed images, hidden superpixels filled with 0.
explanation = explainer.explain_instance(sample, model_predict, top_labels=5, hide_color=0, num_samples=1000, segmentation_fn=segmenter)


In [None]:
from lime.wrappers.scikit_image import SegmentationAlgorithm

# Same segmenter and explanation as the two previous cells, combined in one
# cell (re-running it recomputes the explanation).
# Segmentation algorithm that defines the superpixels used by the explanation.
segmenter = SegmentationAlgorithm('slic', n_segments=100, compactness=1, sigma=1)

# Generate a LIME explanation for the selected sample.
explanation = explainer.explain_instance(sample, model_predict, top_labels=5, hide_color=0, num_samples=1000, segmentation_fn=segmenter)


In [None]:
from skimage.segmentation import mark_boundaries
import matplotlib.pyplot as plt

In [None]:


# Visualise the explanation for the first entry of `top_labels`: boundaries
# of the 5 most important superpixels, both supporting and opposing.
temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=False, num_features=5, hide_rest=False)

plt.imshow(mark_boundaries(temp, mask))
plt.title('LIME Explanation')
plt.show()


In [None]:
import random

num_samples_for_lime = 81  # select 9 x 9 samples
# NOTE: unseeded, so a different subset is drawn on each run.
selected_indices = random.sample(range(len(X_test)), num_samples_for_lime)
samples_for_lime = X_test[selected_indices]

# Preprocess the samples.
preprocessed_samples = np.array([preprocess_image(img) for img in samples_for_lime])


In [None]:
# Create an explainer instance.
explainer = lime_image.LimeImageExplainer()

# Run LIME on every selected sample and record the model's top-1 prediction.
# The loop variables (`sample`, `explanation`, `prob`, `prediction`) stay
# bound as notebook globals after the loop.
explanations = []
predictions = []
for sample in preprocessed_samples:
    # Explain the sample: top 3 labels, only 100 perturbations per image
    # (faster, but noisier than the 1000 used for the single-sample run).
    explanation = explainer.explain_instance(sample, model_predict, top_labels=3, hide_color=0, num_samples=100, segmentation_fn=segmenter)
    explanations.append(explanation)
    # Model prediction for the unperturbed sample.
    prob = model_predict(sample[np.newaxis, ...])
    prediction = np.argmax(prob, axis=1)[0]  # class with the highest probability
    predictions.append(prediction)



In [None]:
# Show images
import matplotlib.pyplot as plt

true_labels = y_test[selected_indices]

# 81 samples on an 18 x 9 grid: each sample takes two vertically stacked
# cells (row 2k: raw image with predicted/true label, row 2k+1: explanation).
fig, axs = plt.subplots(18, 9, figsize=(45, 90))

for i, (sample, prediction, true_label) in enumerate(zip(samples_for_lime, predictions, true_labels)):
    row = 2 * (i // 9)
    col = i % 9

    # Raw image with predicted and true label.
    axs[row, col].imshow(sample, cmap='gray')
    axs[row, col].set_title(f'Pred: {prediction}\nTrue: {true_label}', fontsize=22)
    axs[row, col].axis('off')

    # Explanation for the first entry of `top_labels`. `temp` is already in
    # [0, 1] (the explained images were scaled by / 255), so it is shown
    # as-is; the `/ 2 + 0.5` rescaling from LIME's examples assumes [-1, 1]
    # inputs and would wash the image out to [0.5, 1].
    temp, mask = explanations[i].get_image_and_mask(explanations[i].top_labels[0], positive_only=False, num_features=5, hide_rest=False)
    axs[row + 1, col].imshow(mark_boundaries(temp, mask))
    axs[row + 1, col].axis('off')

plt.tight_layout()
plt.show()



In [None]:
# Quick side-by-side look at predictions and ground truth (labels are (1,)-shaped arrays).
print("First few predictions:", predictions[:10])
print("First few true labels:", true_labels[:10])

In [None]:
# Accuracy of the model on the samples explained above.
# Both arrays are flattened to 1-D: y_test labels have shape (N, 1), and
# comparing (N,) with (N, 1) would broadcast to an (N, N) matrix.
# reshape(-1) (unlike np.squeeze) also keeps a single sample 1-D.
predictions_array = np.asarray(predictions).reshape(-1)
true_labels_array = np.asarray(true_labels).reshape(-1)

# Print the first few flattened true labels for comparison.
print("First few true labels after squeezing:", true_labels_array[:10])

if predictions_array.shape != true_labels_array.shape:
    print("Error: The number of predictions does not match the number of true labels.")
else:
    # Recompute the number of correct predictions and the accuracy.
    correct_predictions = np.sum(predictions_array == true_labels_array)
    accuracy = correct_predictions / len(true_labels_array)
    print(f"Correct Predictions: {correct_predictions}")
    print(f"Accuracy: {accuracy * 100:.2f}%")

In [None]:
# Show images
import matplotlib.pyplot as plt

# Each sample with its prediction and the explanation for its second entry of
# `top_labels`. 81 samples on an 18 x 9 grid: each sample takes two
# vertically stacked cells (image on row 2k, explanation on row 2k + 1).
fig, axs = plt.subplots(18, 9, figsize=(45, 90))

for i, (sample, explanation, prediction) in enumerate(zip(samples_for_lime, explanations, predictions)):
    row = 2 * (i // 9)
    col = i % 9

    # Raw image and its predicted class.
    axs[row, col].imshow(sample, cmap='gray')
    axs[row, col].set_title(f'Prediction: {prediction}', fontsize=22)
    axs[row, col].axis('off')

    # `temp` is already in [0, 1] (the explained images were scaled by / 255),
    # so it is shown as-is; the `/ 2 + 0.5` rescaling assumes [-1, 1] inputs.
    temp, mask = explanation.get_image_and_mask(explanation.top_labels[1], positive_only=False, num_features=5, hide_rest=False)
    axs[row + 1, col].imshow(mark_boundaries(temp, mask))
    axs[row + 1, col].axis('off')

plt.tight_layout()
plt.show()


In [None]:
# Show images
import matplotlib.pyplot as plt

# Each sample with its prediction and the explanation for its third entry of
# `top_labels`. 81 samples on an 18 x 9 grid: each sample takes two
# vertically stacked cells (image on row 2k, explanation on row 2k + 1).
fig, axs = plt.subplots(18, 9, figsize=(45, 90))

for i, (sample, explanation, prediction) in enumerate(zip(samples_for_lime, explanations, predictions)):
    row = 2 * (i // 9)
    col = i % 9

    # Raw image and its predicted class.
    axs[row, col].imshow(sample, cmap='gray')
    axs[row, col].set_title(f'Prediction: {prediction}', fontsize=22)
    axs[row, col].axis('off')

    # `temp` is already in [0, 1] (the explained images were scaled by / 255),
    # so it is shown as-is; the `/ 2 + 0.5` rescaling assumes [-1, 1] inputs.
    temp, mask = explanation.get_image_and_mask(explanation.top_labels[2], positive_only=False, num_features=5, hide_rest=False)
    axs[row + 1, col].imshow(mark_boundaries(temp, mask))
    axs[row + 1, col].axis('off')

plt.tight_layout()
plt.show()


In [None]:
# Explanation for the first entry of `top_labels` of the current global
# `explanation`, showing only the (up to 10) superpixels that support it.
# `temp` is already in [0, 1], so it is displayed without the `/ 2 + 0.5`
# rescaling meant for [-1, 1] inputs.
temp, mask = explanation.get_image_and_mask(label=explanation.top_labels[0], positive_only=True, num_features=10, hide_rest=False)
plt.imshow(mark_boundaries(temp, mask))
plt.show()


In [None]:
# Show only the superpixels that positively support the second entry of
# `top_labels` (everything else hidden). `temp` is already in [0, 1], so it
# is displayed without the `/ 2 + 0.5` rescaling meant for [-1, 1] inputs.
temp, mask = explanation.get_image_and_mask(explanation.top_labels[1], positive_only=True, num_features=5, hide_rest=True)
plt.imshow(mark_boundaries(temp, mask))
plt.show()



In [None]:
import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries

# `explanations` holds the LIME explanations generated for the sample set.
# Print how many top labels each explanation has; afterwards the global
# `num_labels` holds the count of the last explanation only.
for explanation in explanations:
    num_labels = len(explanation.top_labels)
    print(num_labels)

In [None]:

# One figure per sample, with one panel per label in that explanation's
# `top_labels`.
for explanation in explanations:
    # Count the labels of *this* explanation; the global `num_labels` set by
    # the previous cell only reflects the last explanation it looked at.
    n_panels = len(explanation.top_labels)
    figsize = (6, 6) if n_panels == 1 else (20, 10)
    # squeeze=False always returns a 2-D Axes array, so a single label needs
    # no special case; take its only row.
    fig, axs = plt.subplots(1, n_panels, figsize=figsize, squeeze=False)
    axs = axs[0]

    for i, label in enumerate(explanation.top_labels):
        temp, mask = explanation.get_image_and_mask(label, positive_only=False, num_features=5, hide_rest=False)
        # `temp` is already in [0, 1]; show it without [-1, 1] rescaling.
        axs[i].imshow(mark_boundaries(temp, mask))
        axs[i].set_title(f'Label: {label}')
        axs[i].axis('off')

    plt.tight_layout()
    plt.show()



In [None]:
num_labels = 3  # number of top labels to show per sample
num_columns = 1 + num_labels  # 1 column for the original + one per LIME explanation

# One row per sample: the preprocessed image, then its explanations.
for original, explanation in zip(preprocessed_samples, explanations):
    fig, axs = plt.subplots(1, num_columns, figsize=(20, 5))

    # Preprocessed input image (float values in [0, 1]).
    axs[0].imshow(original)
    axs[0].set_title('Original Image')
    axs[0].axis('off')

    # LIME explanation for each of the first `num_labels` top labels.
    for i, label in enumerate(explanation.top_labels[:num_labels]):
        temp, mask = explanation.get_image_and_mask(label, positive_only=False, num_features=5, hide_rest=False)
        # Column i + 1 because column 0 holds the original; `temp` is already
        # in [0, 1], so it is shown without the `/ 2 + 0.5` rescaling.
        axs[i + 1].imshow(mark_boundaries(temp, mask))
        axs[i + 1].set_title(f'Label: {label}')
        axs[i + 1].axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Class distribution of the test split: label -> number of samples.
unique_labels, counts = np.unique(y_test, return_counts=True)
print(dict(zip(unique_labels, counts)))


In [None]:
import matplotlib.pyplot as plt

# One text panel per class showing how many test samples carry that label,
# laid out on a 5 x 3 grid (room for 15 classes; extra classes are not shown).
fig, axs = plt.subplots(5, 3, figsize=(8, 8))
axs = axs.flatten()  # flatten the 2-D grid so panels can be indexed linearly

# zip() stops at the shorter sequence, so at most 15 labels are drawn.
for ax, label, count in zip(axs, unique_labels, counts):
    ax.text(0.5, 0.5, f'Label: {label}\nCount: {count}', horizontalalignment='center', verticalalignment='center', fontsize=12)

# Hide the axes of every panel, including unused ones when there are fewer
# than 15 classes (also safe when there are no labels at all).
for ax in axs:
    ax.axis('off')

plt.tight_layout()
plt.show()


In [None]:
# wasted
# Command-line entry point kept from train_and_eval_pytorch.py; not used by
# this notebook.
# NOTE(review): inside Jupyter `__name__ == '__main__'` is true, so this block
# does run. parse_args() then parses the kernel's own sys.argv and will likely
# exit with an "unrecognized arguments" error; confirm before running.
# NOTE: the parsed values are only assigned to globals; `main(...)` is never
# called here.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='RUN Baseline model of MedMNIST2D')

    parser.add_argument('--data_flag',
                        default='pathmnist',
                        type=str)
    parser.add_argument('--output_root',
                        default='./output',
                        help='output root, where to save models and results',
                        type=str)
    parser.add_argument('--num_epochs',
                        default=100,
                        help='num of epochs of training, the script would only test model if set num_epochs to 0',
                        type=int)
    parser.add_argument('--gpu_ids',
                        default='0',
                        type=str)
    parser.add_argument('--batch_size',
                        default=128,
                        type=int)
    parser.add_argument('--download',
                        action="store_true")
    parser.add_argument('--resize',
                        help='resize images of size 28x28 to 224x224',
                        action="store_true")
    parser.add_argument('--as_rgb',
                        help='convert the grayscale image to RGB',
                        action="store_true")
    parser.add_argument('--model_path',
                        default=None,
                        help='root of the pretrained model to test',
                        type=str)
    parser.add_argument('--model_flag',
                        default='resnet18',
                        help='choose backbone from resnet18, resnet50',
                        type=str)
    parser.add_argument('--run',
                        default='model1',
                        help='to name a standard evaluation csv file, named as {flag}_{split}_[AUC]{auc:.3f}_[ACC]{acc:.3f}@{run}.csv',
                        type=str)


    # Copy the parsed options into globals named like main()'s parameters.
    args = parser.parse_args()
    data_flag = args.data_flag
    output_root = args.output_root
    num_epochs = args.num_epochs
    gpu_ids = args.gpu_ids
    batch_size = args.batch_size
    download = args.download
    model_flag = args.model_flag
    resize = args.resize
    as_rgb = args.as_rgb
    model_path = args.model_path
    run = args.run