In [1]:
import pandas as pd
import os
import shutil
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision
from torchvision import transforms,datasets
from torch.utils.data import Dataset,DataLoader,random_split

torch.manual_seed(666)
torch.cuda.manual_seed(666)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
DEVICE

device(type='cuda', index=0)

## 数据集：多种肺炎X光图像

### 数据集

In [2]:
data_dir = "../data/COVID_Dataset"
TRAIN = 'train'
TEST = 'test'
VAL = 'val'

def apply_transform(mode=None):
    size = (299,299)
    crop = 299
    if mode == 'train':
        transform = transforms.Compose([transforms.Resize(size),
                               transforms.RandomHorizontalFlip(),
                               transforms.RandomRotation((-20,+20)),
                               transforms.CenterCrop(crop),
                               transforms.ToTensor(),
                               transforms.Normalize([0.485, 0.456, 0.406],
                                           [0.229, 0.224, 0.225])
                              ])

    elif mode == 'test' or mode == 'val':
        transform = transforms.Compose([transforms.Resize(size),
                               transforms.CenterCrop(crop),
                               transforms.ToTensor(),
                               transforms.Normalize([0.485, 0.456, 0.406],
                                           [0.229, 0.224, 0.225])
                              ])

    return transform

trainset = datasets.ImageFolder(os.path.join(data_dir, TRAIN),
                                transform = apply_transform(TRAIN))

trainset,valset = random_split(trainset,[len(trainset)-int(len(trainset)*0.1),int(len(trainset)*0.1)])

testset = datasets.ImageFolder(os.path.join(data_dir, TEST),
                               transform = apply_transform(TEST))

train_loader = DataLoader(trainset,
                          batch_size=50,
                          shuffle=True)

val_loader = DataLoader(valset,
                        batch_size=10)

test_loader = DataLoader(testset,
                         batch_size=1)

len(train_loader),len(test_loader),len(val_loader)




(439, 2706, 244)

### 模型

In [3]:
# Inception 残差块
class Residual(nn.Module):
    def __init__(self,in_channels,out_channels,inception,kernel_size=(1,1),stride=1,padding=0) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels,out_channels,kernel_size=kernel_size,stride=stride,padding=padding)
        self.inception = inception
    def forward(self,x):
        conv = self.conv1(x)
        y = self.inception(x)
        out = F.relu(conv+y)
        return out

def get_deepLavV3(out_ch):
    model = torchvision.models.inception_v3(pretrained=True)
    # 在每个Inception层 添加残差
    model.Mixed_5b = Residual(192,256,model.Mixed_5b)
    model.Mixed_5c = Residual(256,288,model.Mixed_5c)
    model.Mixed_5d = Residual(288,288,model.Mixed_5d)
    model.Mixed_6a = Residual(288,768,model.Mixed_6a,(5,5),2,1)
    model.Mixed_6b = Residual(768,768,model.Mixed_6b)
    model.Mixed_6c = Residual(768,768,model.Mixed_6c)
    model.Mixed_6d = Residual(768,768,model.Mixed_6d)
    model.Mixed_6e = Residual(768,768,model.Mixed_6e)
    model.Mixed_7a = Residual(768,1280,model.Mixed_7a,(5,5),2,1)
    model.Mixed_7b = Residual(1280,2048,model.Mixed_7b)
    model.Mixed_7c = Residual(2048,2048,model.Mixed_7c)

    # 冻结参数
    for param in model.parameters():
        param.requires_grad = False
    # 改全连接层
    in_features = model.fc.in_features
    classifier = nn.Sequential(
        nn.Linear(in_features, 4096),
        nn.ReLU(inplace=True),
        nn.Dropout(0.5),
        nn.Linear(4096, 4096),
        nn.ReLU(inplace=True),
        nn.Dropout(0.5),
        nn.Linear(4096, out_ch),
        nn.LogSoftmax(dim=1)
    )
    model.fc = classifier
    return model

# model = get_deepLavV3(2)
# print(model)
# x = torch.tensor(np.random.random((2,3,299,299))).float()
# y,_ = model(x)
# y.shape,_.shape


### 评价指标

In [4]:
def accuracy(preds, labels):
    preds = torch.exp(preds)
    top_p,top_class = preds.topk(1, dim=1)
    equals = top_class == labels.view(*top_class.shape)
    return torch.mean(equals.type(torch.FloatTensor))

### 冻结训练

In [5]:
model = get_deepLavV3(4)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
schedular = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=3)

epochs = 25
val_loss_min = np.Inf
max_e = 10

# 模型存放路径
model_path = os.path.join('./model')
name = "Res_InceptionV3_pre"
model = model.to(DEVICE)
for epoch in range(epochs):

    train_loss = 0.0
    val_loss = 0.0
    train_acc = 0.0
    val_acc = 0.0

    model.train()
    for images,labels in tqdm(train_loader):
        optimizer.zero_grad()
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        preds,_ = model(images)
        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_acc += accuracy(preds, labels)

    avg_train_loss = train_loss / len(train_loader)
    avg_train_acc = train_acc / len(train_loader)

    model.eval()
    with torch.no_grad():
        for images,labels in tqdm(val_loader):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            preds = model(images)
            loss = criterion(preds, labels)
            val_loss += loss.item()
            val_acc += accuracy(preds, labels)

        avg_val_loss = val_loss / len(val_loader)
        avg_val_acc = val_acc / len(val_loader)

    schedular.step(avg_val_loss)

    print("Epoch : {} \ntrain_loss : {:.6f}, \tTrain_acc : {:.6f}, \nVal_loss : {:.6f}, \tVal_acc : {:.6f}".format(epoch + 1,
                                                                                                                   avg_train_loss, avg_train_acc,
                                                                                                                   avg_val_loss, avg_val_acc))
    if avg_val_loss <= val_loss_min:
        print('Validation loss decreased from ({:.6f} --> {:.6f}).\nSaving model ...'.format(val_loss_min, avg_val_loss))
        torch.save(model,os.path.join(model_path,f"{name}.pth"))
        val_loss_min = avg_val_loss
        max_e = 20
    max_e -= 1
    if max_e<=0:
        break


Downloading: "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth" to /home/featurize/.cache/torch/hub/checkpoints/inception_v3_google-0cc3c7bd.pth


  0%|          | 0.00/104M [00:00<?, ?B/s]

100%|██████████| 439/439 [03:27<00:00,  2.12it/s]
100%|██████████| 244/244 [00:24<00:00,  9.90it/s]


Epoch : 1 
train_loss : 1.130528, 	Train_acc : 0.519214, 
Val_loss : 0.981015, 	Val_acc : 0.617213
Validation loss decreased from (inf --> 0.981015).
Saving model ...


100%|██████████| 439/439 [03:28<00:00,  2.10it/s]
100%|██████████| 244/244 [00:24<00:00, 10.05it/s]


Epoch : 2 
train_loss : 1.044887, 	Train_acc : 0.559582, 
Val_loss : 0.945682, 	Val_acc : 0.628006
Validation loss decreased from (0.981015 --> 0.945682).
Saving model ...


100%|██████████| 439/439 [03:25<00:00,  2.14it/s]
100%|██████████| 244/244 [00:24<00:00, 10.12it/s]


Epoch : 3 
train_loss : 1.016349, 	Train_acc : 0.571699, 
Val_loss : 0.917051, 	Val_acc : 0.640984
Validation loss decreased from (0.945682 --> 0.917051).
Saving model ...


100%|██████████| 439/439 [03:29<00:00,  2.09it/s]
100%|██████████| 244/244 [00:23<00:00, 10.17it/s]


Epoch : 4 
train_loss : 1.008985, 	Train_acc : 0.575774, 
Val_loss : 0.976752, 	Val_acc : 0.625136


100%|██████████| 439/439 [03:27<00:00,  2.11it/s]
100%|██████████| 244/244 [00:23<00:00, 10.27it/s]


Epoch : 5 
train_loss : 1.001285, 	Train_acc : 0.580743, 
Val_loss : 0.902815, 	Val_acc : 0.647268
Validation loss decreased from (0.917051 --> 0.902815).
Saving model ...


100%|██████████| 439/439 [03:23<00:00,  2.15it/s]
100%|██████████| 244/244 [00:23<00:00, 10.27it/s]


Epoch : 6 
train_loss : 0.991850, 	Train_acc : 0.580696, 
Val_loss : 0.920465, 	Val_acc : 0.649590


100%|██████████| 439/439 [03:25<00:00,  2.14it/s]
100%|██████████| 244/244 [00:23<00:00, 10.49it/s]


Epoch : 7 
train_loss : 0.993904, 	Train_acc : 0.583271, 
Val_loss : 0.921895, 	Val_acc : 0.649590


100%|██████████| 439/439 [03:26<00:00,  2.13it/s]
100%|██████████| 244/244 [00:23<00:00, 10.36it/s]


Epoch : 8 
train_loss : 0.992066, 	Train_acc : 0.585112, 
Val_loss : 0.919057, 	Val_acc : 0.659016


100%|██████████| 439/439 [03:24<00:00,  2.15it/s]
100%|██████████| 244/244 [00:24<00:00, 10.05it/s]


Epoch : 9 
train_loss : 0.987968, 	Train_acc : 0.584800, 
Val_loss : 0.897197, 	Val_acc : 0.663661
Validation loss decreased from (0.902815 --> 0.897197).
Saving model ...


100%|██████████| 439/439 [03:27<00:00,  2.11it/s]
100%|██████████| 244/244 [00:23<00:00, 10.58it/s]


Epoch : 10 
train_loss : 0.979910, 	Train_acc : 0.588596, 
Val_loss : 0.890320, 	Val_acc : 0.661885
Validation loss decreased from (0.897197 --> 0.890320).
Saving model ...


100%|██████████| 439/439 [03:23<00:00,  2.16it/s]
100%|██████████| 244/244 [00:24<00:00, 10.06it/s]


Epoch : 11 
train_loss : 0.977702, 	Train_acc : 0.590375, 
Val_loss : 0.902381, 	Val_acc : 0.672541


100%|██████████| 439/439 [03:22<00:00,  2.17it/s]
100%|██████████| 244/244 [00:23<00:00, 10.41it/s]


Epoch : 12 
train_loss : 0.979777, 	Train_acc : 0.583773, 
Val_loss : 0.939931, 	Val_acc : 0.682787


100%|██████████| 439/439 [03:29<00:00,  2.09it/s]
100%|██████████| 244/244 [00:23<00:00, 10.18it/s]


Epoch : 13 
train_loss : 0.979406, 	Train_acc : 0.589375, 
Val_loss : 0.911889, 	Val_acc : 0.643033


100%|██████████| 439/439 [03:21<00:00,  2.18it/s]
100%|██████████| 244/244 [00:23<00:00, 10.41it/s]


Epoch : 14 
train_loss : 0.979113, 	Train_acc : 0.592266, 
Val_loss : 0.959644, 	Val_acc : 0.646311


100%|██████████| 439/439 [03:24<00:00,  2.15it/s]
100%|██████████| 244/244 [00:23<00:00, 10.53it/s]


Epoch : 15 
train_loss : 0.948419, 	Train_acc : 0.601788, 
Val_loss : 0.913864, 	Val_acc : 0.682787


100%|██████████| 439/439 [03:19<00:00,  2.20it/s]
100%|██████████| 244/244 [00:22<00:00, 10.91it/s]


Epoch : 16 
train_loss : 0.948916, 	Train_acc : 0.599330, 
Val_loss : 0.900223, 	Val_acc : 0.681557


100%|██████████| 439/439 [03:21<00:00,  2.18it/s]
100%|██████████| 244/244 [00:22<00:00, 10.73it/s]


Epoch : 17 
train_loss : 0.948235, 	Train_acc : 0.600446, 
Val_loss : 0.918013, 	Val_acc : 0.680464


100%|██████████| 439/439 [03:21<00:00,  2.18it/s]
100%|██████████| 244/244 [00:23<00:00, 10.53it/s]


Epoch : 18 
train_loss : 0.947002, 	Train_acc : 0.600857, 
Val_loss : 0.909396, 	Val_acc : 0.686066


100%|██████████| 439/439 [03:20<00:00,  2.19it/s]
100%|██████████| 244/244 [00:23<00:00, 10.49it/s]


Epoch : 19 
train_loss : 0.953090, 	Train_acc : 0.596209, 
Val_loss : 0.905628, 	Val_acc : 0.687432


100%|██████████| 439/439 [03:19<00:00,  2.20it/s]
100%|██████████| 244/244 [00:22<00:00, 10.73it/s]


Epoch : 20 
train_loss : 0.945794, 	Train_acc : 0.602382, 
Val_loss : 0.907010, 	Val_acc : 0.672268


100%|██████████| 439/439 [03:20<00:00,  2.19it/s]
100%|██████████| 244/244 [00:23<00:00, 10.33it/s]


Epoch : 21 
train_loss : 0.945022, 	Train_acc : 0.601628, 
Val_loss : 0.898196, 	Val_acc : 0.690573


100%|██████████| 439/439 [03:21<00:00,  2.18it/s]
100%|██████████| 244/244 [00:23<00:00, 10.57it/s]


Epoch : 22 
train_loss : 0.944868, 	Train_acc : 0.602862, 
Val_loss : 0.900685, 	Val_acc : 0.678825


100%|██████████| 439/439 [03:20<00:00,  2.19it/s]
100%|██████████| 244/244 [00:23<00:00, 10.55it/s]


Epoch : 23 
train_loss : 0.942678, 	Train_acc : 0.598603, 
Val_loss : 0.901195, 	Val_acc : 0.687295


100%|██████████| 439/439 [03:17<00:00,  2.22it/s]
100%|██████████| 244/244 [00:23<00:00, 10.59it/s]


Epoch : 24 
train_loss : 0.944298, 	Train_acc : 0.599398, 
Val_loss : 0.900669, 	Val_acc : 0.677049


100%|██████████| 439/439 [03:21<00:00,  2.18it/s]
100%|██████████| 244/244 [00:23<00:00, 10.49it/s]

Epoch : 25 
train_loss : 0.948375, 	Train_acc : 0.601128, 
Val_loss : 0.901749, 	Val_acc : 0.683607





In [6]:
val_loss_min

0.8903199996127457

### 解冻参数

In [7]:
model = torch.load(os.path.join(model_path,f"{name}.pth"))

for param in model.parameters():
        param.requires_grad = True

# for k,v in model.named_parameters():
#     print(f"{k}: {v.requires_grad}")
# model

### finetune

In [8]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
schedular = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=3)

epochs =50
# val_loss_min = np.Inf
max_e = 10

name = "Res_InceptionV3_pre"
model = model.to(DEVICE)
for epoch in range(epochs):

    train_loss = 0.0
    val_loss = 0.0
    train_acc = 0.0
    val_acc = 0.0

    model.train()
    for images,labels in tqdm(train_loader):
        optimizer.zero_grad()
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        preds,_ = model(images)
        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_acc += accuracy(preds, labels)

    avg_train_loss = train_loss / len(train_loader)
    avg_train_acc = train_acc / len(train_loader)

    model.eval()
    with torch.no_grad():
        for images,labels in tqdm(val_loader):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            preds = model(images)
            loss = criterion(preds, labels)
            val_loss += loss.item()
            val_acc += accuracy(preds, labels)

        avg_val_loss = val_loss / len(val_loader)
        avg_val_acc = val_acc / len(val_loader)

    schedular.step(avg_val_loss)

    print("Epoch : {} \ntrain_loss : {:.6f}, \tTrain_acc : {:.6f}, \nVal_loss : {:.6f}, \tVal_acc : {:.6f}".format(epoch + 1,
                                                                                                                   avg_train_loss, avg_train_acc,
                                                                                                                   avg_val_loss, avg_val_acc))
    if avg_val_loss <= val_loss_min:
        print('Validation loss decreased from ({:.6f} --> {:.6f}).\nSaving model ...'.format(val_loss_min, avg_val_loss))
        torch.save(model, os.path.join(model_path,f"{name}.pth"))
        val_loss_min = avg_val_loss
        max_e = 10
    max_e -= 1
    if max_e<=0:
        break


100%|██████████| 439/439 [04:49<00:00,  1.52it/s]
100%|██████████| 244/244 [00:23<00:00, 10.59it/s]


Epoch : 1 
train_loss : 0.657715, 	Train_acc : 0.741113, 
Val_loss : 0.386418, 	Val_acc : 0.873224
Validation loss decreased from (0.890320 --> 0.386418).
Saving model ...


100%|██████████| 439/439 [04:45<00:00,  1.54it/s]
100%|██████████| 244/244 [00:23<00:00, 10.40it/s]


Epoch : 2 
train_loss : 0.392956, 	Train_acc : 0.857103, 
Val_loss : 0.372892, 	Val_acc : 0.864208
Validation loss decreased from (0.386418 --> 0.372892).
Saving model ...


100%|██████████| 439/439 [04:46<00:00,  1.53it/s]
100%|██████████| 244/244 [00:24<00:00, 10.15it/s]


Epoch : 3 
train_loss : 0.308455, 	Train_acc : 0.886991, 
Val_loss : 0.280242, 	Val_acc : 0.900136
Validation loss decreased from (0.372892 --> 0.280242).
Saving model ...


100%|██████████| 439/439 [04:44<00:00,  1.54it/s]
100%|██████████| 244/244 [00:22<00:00, 10.66it/s]


Epoch : 4 
train_loss : 0.278430, 	Train_acc : 0.898791, 
Val_loss : 0.339702, 	Val_acc : 0.872541


100%|██████████| 439/439 [04:47<00:00,  1.53it/s]
100%|██████████| 244/244 [00:23<00:00, 10.31it/s]


Epoch : 5 
train_loss : 0.259536, 	Train_acc : 0.906217, 
Val_loss : 0.256560, 	Val_acc : 0.901776
Validation loss decreased from (0.280242 --> 0.256560).
Saving model ...


100%|██████████| 439/439 [04:46<00:00,  1.53it/s]
100%|██████████| 244/244 [00:23<00:00, 10.51it/s]


Epoch : 6 
train_loss : 0.238317, 	Train_acc : 0.913075, 
Val_loss : 0.258006, 	Val_acc : 0.908606


100%|██████████| 439/439 [04:49<00:00,  1.52it/s]
100%|██████████| 244/244 [00:23<00:00, 10.41it/s]


Epoch : 7 
train_loss : 0.230541, 	Train_acc : 0.914439, 
Val_loss : 0.273416, 	Val_acc : 0.882787


100%|██████████| 439/439 [04:44<00:00,  1.54it/s]
100%|██████████| 244/244 [00:22<00:00, 10.65it/s]


Epoch : 8 
train_loss : 0.215989, 	Train_acc : 0.922369, 
Val_loss : 0.240677, 	Val_acc : 0.915710
Validation loss decreased from (0.256560 --> 0.240677).
Saving model ...


100%|██████████| 439/439 [04:48<00:00,  1.52it/s]
100%|██████████| 244/244 [00:23<00:00, 10.17it/s]


Epoch : 9 
train_loss : 0.203246, 	Train_acc : 0.925216, 
Val_loss : 0.207783, 	Val_acc : 0.923360
Validation loss decreased from (0.240677 --> 0.207783).
Saving model ...


100%|██████████| 439/439 [04:47<00:00,  1.53it/s]
100%|██████████| 244/244 [00:24<00:00, 10.16it/s]


Epoch : 10 
train_loss : 0.192066, 	Train_acc : 0.927331, 
Val_loss : 0.242928, 	Val_acc : 0.910109


100%|██████████| 439/439 [04:51<00:00,  1.51it/s]
100%|██████████| 244/244 [00:23<00:00, 10.50it/s]


Epoch : 11 
train_loss : 0.180419, 	Train_acc : 0.932733, 
Val_loss : 0.330035, 	Val_acc : 0.874454


100%|██████████| 439/439 [04:44<00:00,  1.54it/s]
100%|██████████| 244/244 [00:23<00:00, 10.53it/s]


Epoch : 12 
train_loss : 0.169766, 	Train_acc : 0.937450, 
Val_loss : 0.278304, 	Val_acc : 0.890573


100%|██████████| 439/439 [04:47<00:00,  1.53it/s]
100%|██████████| 244/244 [00:23<00:00, 10.40it/s]


Epoch : 13 
train_loss : 0.163676, 	Train_acc : 0.939932, 
Val_loss : 0.210001, 	Val_acc : 0.924180


100%|██████████| 439/439 [04:51<00:00,  1.51it/s]
100%|██████████| 244/244 [00:23<00:00, 10.24it/s]


Epoch : 14 
train_loss : 0.111316, 	Train_acc : 0.958383, 
Val_loss : 0.168599, 	Val_acc : 0.935655
Validation loss decreased from (0.207783 --> 0.168599).
Saving model ...


100%|██████████| 439/439 [04:49<00:00,  1.52it/s]
100%|██████████| 244/244 [00:23<00:00, 10.58it/s]


Epoch : 15 
train_loss : 0.095024, 	Train_acc : 0.965421, 
Val_loss : 0.169664, 	Val_acc : 0.933606


100%|██████████| 439/439 [04:51<00:00,  1.50it/s]
100%|██████████| 244/244 [00:23<00:00, 10.32it/s]


Epoch : 16 
train_loss : 0.088339, 	Train_acc : 0.968792, 
Val_loss : 0.169036, 	Val_acc : 0.936065


100%|██████████| 439/439 [04:47<00:00,  1.53it/s]
100%|██████████| 244/244 [00:22<00:00, 10.61it/s]


Epoch : 17 
train_loss : 0.082164, 	Train_acc : 0.968839, 
Val_loss : 0.178194, 	Val_acc : 0.936475


100%|██████████| 439/439 [04:46<00:00,  1.53it/s]
100%|██████████| 244/244 [00:23<00:00, 10.47it/s]


Epoch : 18 
train_loss : 0.075396, 	Train_acc : 0.972164, 
Val_loss : 0.188425, 	Val_acc : 0.939344


100%|██████████| 439/439 [04:47<00:00,  1.53it/s]
100%|██████████| 244/244 [00:24<00:00, 10.04it/s]


Epoch : 19 
train_loss : 0.068116, 	Train_acc : 0.975263, 
Val_loss : 0.183803, 	Val_acc : 0.942622


100%|██████████| 439/439 [04:49<00:00,  1.51it/s]
100%|██████████| 244/244 [00:23<00:00, 10.45it/s]


Epoch : 20 
train_loss : 0.067052, 	Train_acc : 0.975582, 
Val_loss : 0.175005, 	Val_acc : 0.938114


100%|██████████| 439/439 [04:44<00:00,  1.55it/s]
100%|██████████| 244/244 [00:23<00:00, 10.29it/s]


Epoch : 21 
train_loss : 0.066138, 	Train_acc : 0.975535, 
Val_loss : 0.177611, 	Val_acc : 0.938524


100%|██████████| 439/439 [04:50<00:00,  1.51it/s]
100%|██████████| 244/244 [00:23<00:00, 10.24it/s]


Epoch : 22 
train_loss : 0.063385, 	Train_acc : 0.977176, 
Val_loss : 0.180579, 	Val_acc : 0.933606


100%|██████████| 439/439 [04:45<00:00,  1.54it/s]
100%|██████████| 244/244 [00:23<00:00, 10.42it/s]

Epoch : 23 
train_loss : 0.062125, 	Train_acc : 0.977314, 
Val_loss : 0.179738, 	Val_acc : 0.937704





In [9]:
val_loss_min

0.16859943654195436

### 测试

In [7]:
name = "Res_InceptionV3_pre"
model_path = os.path.join('./model')
model = torch.load(os.path.join(model_path,f"{name}.pth"))
model = model.to(DEVICE)
criterion = nn.CrossEntropyLoss()

# 混淆矩阵
M = np.zeros((4,4))
with torch.no_grad():
    test_loss=0
    test_acc=0
    for images,labels in tqdm(test_loader):
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        preds = model(images)
        loss = criterion(preds, labels)
        test_loss += loss.item()
        test_acc += accuracy(preds, labels)
        
        M[preds.argmax(),labels[0]]+=1
        
    avg_test_loss = test_loss / len(test_loader)
    avg_test_acc = test_acc / len(test_loader)
    print("Test_loss : {:.6f}, \tTest_acc : {:.6f}".format(avg_test_loss,avg_test_acc))
    
    M = M.astype("int")
    print(M)
    

100%|██████████| 2706/2706 [01:10<00:00, 38.60it/s]

Test_loss : 0.160139, 	Test_acc : 0.942720
[[ 362    4    2    0]
 [   1  798   29   28]
 [   3   35 1137    2]
 [   1   42    8  254]]



