<a href="https://colab.research.google.com/github/Jack1447/Medical-Interpretable-LLM/blob/main/Task2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### MLP在复杂数据集上的分类与决策过程可视化

In [2]:
import torch
from torch.utils.data import Dataset,DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.optim as optim
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import confusion_matrix
import statistics
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import wandb

建立数据集类

In [3]:
def generate_dataset(batch_size):
    # 下载数据集
    train_transformer=transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=10),
        transforms.ColorJitter(brightness=0.2,contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5],std=[0.5])
    ])

    test_transformer=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5],std=[0.5])
    ])


    train_dataset=datasets.FashionMNIST(
        root='./data',train=True,download=True,transform=train_transformer
    )
    test_dataset=datasets.FashionMNIST(
        root='./data',train=False,download=True,transform=test_transformer
    )

    train_dataloader=DataLoader(train_dataset,batch_size=batch_size,shuffle=True,pin_memory=True)
    test_dataloader=DataLoader(test_dataset,batch_size=batch_size,shuffle=False,pin_memory=True)


    print(f"训练集数目:{len(train_dataset)} 测试集数目:{len(test_dataset)}")
    print(f"训练集批数目:{len(train_dataloader)} 测试集批数目:{len(test_dataloader)}\n")

    return train_dataloader,test_dataloader


建立MLP模型

In [4]:
class ClassMLP(nn.Module):
    def __init__(self, input_size, num_classes):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_size, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.2),

            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),

            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.1),

            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)
        return self.layers(x)


指标计算及可视化

In [5]:
def calculate_metrics(label,predicted,average):
    if label.is_cuda:
        label=label.cpu()
    if predicted.is_cuda:
        predicted=predicted.cpu()

    label=label.numpy()
    predicted=predicted.numpy()   # sklearn函数需要numpy数组

    # 准确率
    accuracy=accuracy_score(label,predicted)

    # 精确率 召回率 F1分数
    precision=precision_score(label,predicted,average=average,zero_division=0)
    recall=recall_score(label,predicted,average=average,zero_division=0)
    f1=f1_score(label,predicted,average=average,zero_division=0)

    return accuracy,precision,recall,f1


def validate(model,test_dataloader,criterion,device):
    model.eval()

    with torch.no_grad():
        total_loss=0
        for image,label in test_dataloader:
            image,label=image.to(device),label.to(device)
            output=model(image)
            loss=criterion(output,label)
            total_loss+=loss.item()
    model.train()

    return total_loss/len(test_dataloader)

# 绘制混淆矩阵
def plot_confusion_matrix(model,test_dataloader,device):
    model.eval()
    preds=[]
    labels=[]

    with torch.no_grad():
        for image,label in test_dataloader:
            image,label=image.to(device),label.to(device)
            output=model(image)   # 可见 model 以batch的形式输入
            predicted = torch.argmax(output, dim=1)

            preds.append(predicted.cpu().numpy())
            labels.append(label.cpu().numpy())

    preds=np.concatenate(preds)
    labels=np.concatenate(labels)
    cm = confusion_matrix(labels, preds)

    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
    plt.xlabel("Predicted")
    plt.ylabel("Labels")
    plt.title("Confusion Matrix")
    plt.savefig("Confusion Matrix.png")
    plt.close()

    wandb.log({"Confusion Matrix": wandb.Image("Confusion Matrix.png")})

开始训练

In [6]:
def train(model,train_dataloader,test_dataloader,epochs,device):
    # 定义优化器和学习率调度器
    optimizer = optim.AdamW(
        model.parameters(),
        lr=5e-3,
        weight_decay=0.01
    )

    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=epochs,
        eta_min=1e-5
    )

    # 损失函数
    criterion=nn.CrossEntropyLoss()

    # 初始化
    wandb.init(
    project="FashionMNIST",
    name=f"MLP_Class_Model",
    config={
        "epochs": epochs,
        "device": str(device),
        "optimizer": str(optimizer),
        "scheduler": str(scheduler),
        'criterion':str(criterion),
        "model_parameters": sum(p.numel() for p in model.parameters()),
    }
    )

    # 记录模型架构
    wandb.watch(model, log="all", log_freq=10)


    model=model.to(device)
    model.train()
    for epoch in range(epochs):
        total_loss=0
        accuracies=[]
        precisions=[]
        recalls=[]
        f1_scores=[]

        # if (epoch+1)%20==0:
        pbar=tqdm(train_dataloader,desc=f'Train:{epoch+1}/{epochs}',unit='batch')
        # else:
            # pbar=train_dataloader
        for batch_idx,(image,label) in enumerate(pbar):
            image,label=image.to(device),label.to(device)

            optimizer.zero_grad()
            output=model(image)
            loss=criterion(output,label)
            loss.backward()
            optimizer.step()

            total_loss+=loss.item()
            if (batch_idx+1)%50==0:
                pbar.set_postfix({
                    'current loss':loss.item(),
                    'avg loss':total_loss/(batch_idx+1)
                })

            # 计算各项指标
            _,predicted=torch.max(output,dim=1)
            average='weighted'
            accuracy,precision,recall,f1=calculate_metrics(label,predicted,average)
            accuracies.append(accuracy)
            precisions.append(precision)
            recalls.append(recall)
            f1_scores.append(f1)

        # 计算损失
        train_loss=total_loss/len(train_dataloader)
        test_loss=validate(model,test_dataloader,criterion,device)

        # 打印损失
        # if (epoch+1)%20==0:
        print(f"Epoch:{epoch+1} train loss:{train_loss:.4f} test_loss:{test_loss:.4f}")

        accuracy=statistics.mean(accuracies)
        precision=statistics.mean(precisions)
        recall=statistics.mean(recalls)
        f1=statistics.mean(f1_scores)
        print(f"accuracy:{accuracy} precision:{precision} recall:{recall} f1:{f1}\n")

        # 记录损失及指标
        wandb.log({
            'epoch':epoch+1,
            'train loss':train_loss,
            'test loss':test_loss,
            'learning_rate': optimizer.param_groups[0]['lr'],
            'accuracy':accuracy,
            'precision':precision,
            'recall':recall,
            'f1':f1,
        })


        scheduler.step()

    # 生成混淆矩阵 预测可视化
    plot_confusion_matrix(model,test_dataloader,device)

    torch.save(model.state_dict(),'MNIST_Class_Model.pth')
    print(f"模型保存在 'MNIST_Class_Model.pth'")
    print("==================训练完成===================")

    wandb.finish()


主函数

In [7]:
def main():
    batch_size=128
    epochs=35
    train_dataloader,test_dataloader=generate_dataset(batch_size)

    # 实例化模型
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model=ClassMLP(28*28,10)

    print("==================开始训练==================")
    print(f"device:{device} 模型的参数:{sum(p.numel() for p in model.parameters())}\n")
    train(model,train_dataloader,test_dataloader,epochs,device)


if __name__=="__main__":
    main()

100%|██████████| 26.4M/26.4M [00:02<00:00, 11.6MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 209kB/s]
100%|██████████| 4.42M/4.42M [00:01<00:00, 3.88MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 12.6MB/s]


训练集数目:60000 测试集数目:10000
训练集批数目:469 测试集批数目:79

device:cpu 模型的参数:568970



  | |_| | '_ \/ _` / _` |  _/ -_)


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mhaojiaqi406[0m ([33mhaojiaqi406-sun-yat-sen-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Train:1/35: 100%|██████████| 469/469 [00:58<00:00,  7.98batch/s, current loss=0.503, avg loss=0.603]


Epoch:1 train loss:0.6005 test_loss:0.4633
accuracy:0.7819718372423596 precision:0.7990858747438382 recall:0.7819718372423596 f1:0.7782606372011924



Train:2/35: 100%|██████████| 469/469 [00:58<00:00,  8.02batch/s, current loss=0.415, avg loss=0.48]


Epoch:2 train loss:0.4805 test_loss:0.4262
accuracy:0.8232387171286425 precision:0.8377110002852882 recall:0.8232387171286425 f1:0.8220519622750821



Train:3/35: 100%|██████████| 469/469 [00:57<00:00,  8.15batch/s, current loss=0.426, avg loss=0.445]


Epoch:3 train loss:0.4443 test_loss:0.4357
accuracy:0.8375421997157072 precision:0.8497398087424717 recall:0.8375421997157072 f1:0.8364006223548047



Train:4/35: 100%|██████████| 469/469 [00:57<00:00,  8.11batch/s, current loss=0.356, avg loss=0.421]


Epoch:4 train loss:0.4208 test_loss:0.3737
accuracy:0.8445107054015636 precision:0.8570410498304711 recall:0.8445107054015636 f1:0.8437810357487258



Train:5/35: 100%|██████████| 469/469 [00:58<00:00,  8.01batch/s, current loss=0.278, avg loss=0.403]


Epoch:5 train loss:0.4029 test_loss:0.3772
accuracy:0.8521344171997157 precision:0.8638919902673775 recall:0.8521344171997157 f1:0.8514939195761105



Train:6/35: 100%|██████████| 469/469 [00:58<00:00,  7.98batch/s, current loss=0.349, avg loss=0.386]


Epoch:6 train loss:0.3866 test_loss:0.3718
accuracy:0.857437144633973 precision:0.8684818027489544 recall:0.857437144633973 f1:0.8565043249280783



Train:7/35: 100%|██████████| 469/469 [01:00<00:00,  7.75batch/s, current loss=0.376, avg loss=0.375]


Epoch:7 train loss:0.3748 test_loss:0.3583
accuracy:0.8613017501776831 precision:0.872517719873792 recall:0.8613017501776831 f1:0.8610743876910745



Train:8/35: 100%|██████████| 469/469 [00:57<00:00,  8.15batch/s, current loss=0.308, avg loss=0.37]


Epoch:8 train loss:0.3685 test_loss:0.3388
accuracy:0.8650108830845771 precision:0.8757304905142361 recall:0.8650108830845771 f1:0.8646498231802098



Train:9/35: 100%|██████████| 469/469 [00:57<00:00,  8.09batch/s, current loss=0.37, avg loss=0.355]


Epoch:9 train loss:0.3546 test_loss:0.3454
accuracy:0.8698249822316987 precision:0.880369163178534 recall:0.8698249822316987 f1:0.8694865318609831



Train:10/35: 100%|██████████| 469/469 [00:58<00:00,  7.96batch/s, current loss=0.417, avg loss=0.349]


Epoch:10 train loss:0.3492 test_loss:0.3461
accuracy:0.8701581378820185 precision:0.8802784041636726 recall:0.8701581378820185 f1:0.8696225652038813



Train:11/35: 100%|██████████| 469/469 [00:58<00:00,  8.01batch/s, current loss=0.458, avg loss=0.344]


Epoch:11 train loss:0.3439 test_loss:0.3232
accuracy:0.8726123845060412 precision:0.8819960736443163 recall:0.8726123845060412 f1:0.8721222983058613



Train:12/35: 100%|██████████| 469/469 [00:58<00:00,  8.05batch/s, current loss=0.328, avg loss=0.337]


Epoch:12 train loss:0.3371 test_loss:0.3250
accuracy:0.8735674307036247 precision:0.8835213861548865 recall:0.8735674307036247 f1:0.873302177299698



Train:13/35: 100%|██████████| 469/469 [00:57<00:00,  8.13batch/s, current loss=0.268, avg loss=0.327]


Epoch:13 train loss:0.3276 test_loss:0.3422
accuracy:0.8786369491826581 precision:0.8876401247507218 recall:0.8786369491826581 f1:0.878525031600109



Train:14/35: 100%|██████████| 469/469 [00:59<00:00,  7.93batch/s, current loss=0.285, avg loss=0.321]


Epoch:14 train loss:0.3204 test_loss:0.3155
accuracy:0.8815242981520967 precision:0.8902216976361202 recall:0.8815242981520967 f1:0.881369444147511



Train:15/35: 100%|██████████| 469/469 [00:58<00:00,  8.05batch/s, current loss=0.293, avg loss=0.315]


Epoch:15 train loss:0.3161 test_loss:0.3084
accuracy:0.8830790245202559 precision:0.8920393552882608 recall:0.8830790245202559 f1:0.8829249024576976



Train:16/35: 100%|██████████| 469/469 [00:57<00:00,  8.13batch/s, current loss=0.321, avg loss=0.308]


Epoch:16 train loss:0.3075 test_loss:0.3011
accuracy:0.884944696162047 precision:0.8936093128480639 recall:0.884944696162047 f1:0.8847058052711494



Train:17/35: 100%|██████████| 469/469 [00:57<00:00,  8.12batch/s, current loss=0.268, avg loss=0.3]


Epoch:17 train loss:0.3000 test_loss:0.3072
accuracy:0.8882929104477612 precision:0.8966236232086661 recall:0.8882929104477612 f1:0.8881282845460629



Train:18/35: 100%|██████████| 469/469 [00:58<00:00,  8.08batch/s, current loss=0.243, avg loss=0.298]


Epoch:18 train loss:0.2981 test_loss:0.3118
accuracy:0.8871934968017058 precision:0.8955814205296913 recall:0.8871934968017058 f1:0.8870337070134309



Train:19/35: 100%|██████████| 469/469 [00:58<00:00,  8.02batch/s, current loss=0.29, avg loss=0.288]


Epoch:19 train loss:0.2882 test_loss:0.3046
accuracy:0.8910969705046198 precision:0.8987627011678524 recall:0.8910969705046198 f1:0.8908806756068461



Train:20/35: 100%|██████████| 469/469 [00:59<00:00,  7.95batch/s, current loss=0.325, avg loss=0.283]


Epoch:20 train loss:0.2837 test_loss:0.2907
accuracy:0.8939565565031983 precision:0.9016382112515424 recall:0.8939565565031983 f1:0.893853214506675



Train:21/35: 100%|██████████| 469/469 [00:58<00:00,  8.07batch/s, current loss=0.198, avg loss=0.277]


Epoch:21 train loss:0.2768 test_loss:0.2963
accuracy:0.8959999111584932 precision:0.9034450667899782 recall:0.8959999111584932 f1:0.8959450077529247



Train:22/35: 100%|██████████| 469/469 [00:58<00:00,  7.97batch/s, current loss=0.311, avg loss=0.273]


Epoch:22 train loss:0.2734 test_loss:0.2905
accuracy:0.8968328002842928 precision:0.9041166434230592 recall:0.8968328002842928 f1:0.8967037468179833



Train:23/35: 100%|██████████| 469/469 [00:58<00:00,  8.06batch/s, current loss=0.309, avg loss=0.262]


Epoch:23 train loss:0.2635 test_loss:0.2891
accuracy:0.9000421997157072 precision:0.9072897292136216 recall:0.9000421997157072 f1:0.8999322693725705



Train:24/35: 100%|██████████| 469/469 [00:59<00:00,  7.94batch/s, current loss=0.378, avg loss=0.262]


Epoch:24 train loss:0.2609 test_loss:0.2840
accuracy:0.900314276830135 precision:0.9076278026747371 recall:0.900314276830135 f1:0.9002677269471225



Train:25/35: 100%|██████████| 469/469 [01:00<00:00,  7.79batch/s, current loss=0.244, avg loss=0.258]


Epoch:25 train loss:0.2561 test_loss:0.2821
accuracy:0.9017301883439943 precision:0.9089303098286282 recall:0.9017301883439943 f1:0.9015961118060626



Train:26/35: 100%|██████████| 469/469 [00:58<00:00,  8.05batch/s, current loss=0.254, avg loss=0.252]


Epoch:26 train loss:0.2522 test_loss:0.2805
accuracy:0.9046508528784648 precision:0.9113905424332505 recall:0.9046508528784648 f1:0.9045410413605398



Train:27/35: 100%|██████████| 469/469 [00:59<00:00,  7.93batch/s, current loss=0.337, avg loss=0.245]


Epoch:27 train loss:0.2449 test_loss:0.2791
accuracy:0.9064887615493958 precision:0.9135546630018513 recall:0.9064887615493958 f1:0.9064683202449225



Train:28/35: 100%|██████████| 469/469 [00:58<00:00,  8.07batch/s, current loss=0.198, avg loss=0.244]


Epoch:28 train loss:0.2436 test_loss:0.2772
accuracy:0.9071050995024875 precision:0.9140013288787158 recall:0.9071050995024875 f1:0.9070271925507536



Train:29/35: 100%|██████████| 469/469 [00:57<00:00,  8.16batch/s, current loss=0.314, avg loss=0.24]


Epoch:29 train loss:0.2400 test_loss:0.2765
accuracy:0.9084765902629709 precision:0.9150850302932617 recall:0.9084765902629709 f1:0.9084654857389941



Train:30/35: 100%|██████████| 469/469 [00:58<00:00,  8.04batch/s, current loss=0.198, avg loss=0.235]


Epoch:30 train loss:0.2347 test_loss:0.2738
accuracy:0.9105199449182658 precision:0.9168848195280157 recall:0.9105199449182658 f1:0.910450128953472



Train:31/35: 100%|██████████| 469/469 [00:58<00:00,  8.08batch/s, current loss=0.233, avg loss=0.233]


Epoch:31 train loss:0.2333 test_loss:0.2749
accuracy:0.9118969882729211 precision:0.9183347942782345 recall:0.9118969882729211 f1:0.9117192776672951



Train:32/35: 100%|██████████| 469/469 [00:58<00:00,  8.07batch/s, current loss=0.214, avg loss=0.23]


Epoch:32 train loss:0.2298 test_loss:0.2751
accuracy:0.9125577469793887 precision:0.9185232375278474 recall:0.9125577469793887 f1:0.9123851306042373



Train:33/35: 100%|██████████| 469/469 [00:59<00:00,  7.93batch/s, current loss=0.309, avg loss=0.228]


Epoch:33 train loss:0.2285 test_loss:0.2768
accuracy:0.9134406094527363 precision:0.91944936011317 recall:0.9134406094527363 f1:0.913455699469168



Train:34/35: 100%|██████████| 469/469 [00:59<00:00,  7.94batch/s, current loss=0.18, avg loss=0.227]


Epoch:34 train loss:0.2273 test_loss:0.2737
accuracy:0.9116304637526652 precision:0.9181748166484776 recall:0.9116304637526652 f1:0.9116699221601545



Train:35/35: 100%|██████████| 469/469 [01:00<00:00,  7.75batch/s, current loss=0.21, avg loss=0.23]


Epoch:35 train loss:0.2301 test_loss:0.2751
accuracy:0.9127076670220327 precision:0.9190030756559149 recall:0.9127076670220327 f1:0.9127326673580692

模型保存在 'MNIST_Class_Model.pth'


0,1
accuracy,▁▃▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇██████████
epoch,▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███
f1,▁▃▄▄▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇██████████
learning_rate,███████▇▇▇▇▆▆▆▆▅▅▅▄▄▄▃▃▃▃▂▂▂▂▁▁▁▁▁▁
precision,▁▃▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇██████████
recall,▁▃▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇██████████
test loss,█▇▇▅▅▅▄▃▄▄▃▃▄▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
train loss,█▆▅▅▄▄▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁

0,1
accuracy,0.91271
epoch,35.0
f1,0.91273
learning_rate,2e-05
precision,0.919
recall,0.91271
test loss,0.27511
train loss,0.23008
