In [1]:
import pandas as pd
import os
import shutil
from tqdm import tqdm
import torch
import torch.nn as nn
import numpy as np
import torchvision
from torchvision import transforms,datasets
from torch.utils.data import Dataset,DataLoader,random_split
from torch.nn.parameter import Parameter

torch.manual_seed(666)
torch.cuda.manual_seed(666)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
DEVICE

device(type='cuda', index=0)

## 数据集：多种肺炎X光图像

### 数据集

In [2]:
data_dir = "./data/COVID_Dataset"
TRAIN = 'train'
TEST = 'test'
VAL = 'val'

def apply_transform(mode=None):
    size = (299,299)
    crop = 299
    if mode == 'train':
        transform = transforms.Compose([transforms.Resize(size),
                               transforms.RandomHorizontalFlip(),
                               transforms.RandomRotation((-20,+20)),
                               transforms.CenterCrop(crop),
                               transforms.ToTensor(),
                               transforms.Normalize([0.485, 0.456, 0.406],
                                           [0.229, 0.224, 0.225])
                              ])

    elif mode == 'test' or mode == 'val':
        transform = transforms.Compose([transforms.Resize(size),
                               transforms.CenterCrop(crop),
                               transforms.ToTensor(),
                               transforms.Normalize([0.485, 0.456, 0.406],
                                           [0.229, 0.224, 0.225])
                              ])

    return transform

trainset = datasets.ImageFolder(os.path.join(data_dir, TRAIN),
                                transform = apply_transform(TRAIN))

trainset,valset = random_split(trainset,[len(trainset)-int(len(trainset)*0.1),int(len(trainset)*0.1)])

testset = datasets.ImageFolder(os.path.join(data_dir, TEST),
                               transform = apply_transform(TEST))

train_loader = DataLoader(trainset,
                          batch_size=50,
                          shuffle=True)

val_loader = DataLoader(valset,
                        batch_size=10)

test_loader = DataLoader(testset,
                         batch_size=1)

len(train_loader),len(test_loader),len(val_loader)




(439, 2706, 244)

### 模型

In [3]:
# SA模块
class SA_Layer(nn.Module):
    def __init__(self,channels,groups=64) -> None:
        super().__init__()
        self.groups = groups
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.cweight = Parameter(torch.zeros(1,channels//(2*groups),1,1))
        self.cbias = Parameter(torch.ones(1,channels//(2*groups),1,1))
        self.sweight = Parameter(torch.zeros(1,channels//(2*groups),1,1))
        self.sbias = Parameter(torch.ones(1,channels//(2*groups),1,1))
        
        self.sigmoid = nn.Sigmoid()
        self.gn = nn.GroupNorm(channels//(2*groups),channels//(2*groups))

    @staticmethod
    def channel_suffle(x,groups):
        b,c,h,w = x.shape
        # 输入特征图分组
        x = x.reshape(b,groups,-1,h,w)
        # 洗牌
        x = x.permute(0,2,1,3,4)
        x = x.reshape(b,-1,h,w)
        return x
    

    def forward(self,x):
        b,c,h,w = x.shape
        x = x.reshape(b*self.groups,-1,h,w)
        x_0, x_1 = x.chunk(2,dim=1)

        # 通道注意力
        xn = self.avg_pool(x_0)
        xn = self.cweight * xn + self.cbias
        xn = x_0*self.sigmoid(xn)

        # 空间注意力
        xs = self.gn(x_1)
        xs = self.sweight * xs + self.sbias
        xs = x_1*self.sigmoid(xs)

        # 在通道维度上拼接
        out = torch.cat([xn,xs],dim=1)
        out = out.reshape(b,-1,h,w)
        out = self.channel_suffle(out,2)

        return out


# 预训练InceptionV3 + SA
def get_deepLavV3(out_ch):
    model = torchvision.models.inception_v3(pretrained=True)
    # 在每个Inception层后 添加SA
    model.Mixed_5d = nn.Sequential(
        model.Mixed_5d,
        SA_Layer(288,8)
    )
    model.Mixed_6e = nn.Sequential(
        model.Mixed_6e,
        SA_Layer(768,64)
    )
    model.Mixed_7c = nn.Sequential(
        model.Mixed_7c,
        SA_Layer(2048,128)
    )
    



    # 冻结参数
    for param in model.parameters():
        param.requires_grad = False


    in_features = model.fc.in_features
    
    # 改全连接层
    classifier = nn.Sequential(
        nn.Linear(in_features, 4096),
        nn.ReLU(inplace=True),
        nn.Dropout(0.5),
        nn.Linear(4096, 4096),
        nn.ReLU(inplace=True),
        nn.Dropout(0.5),
        nn.Linear(4096, out_ch),
        nn.LogSoftmax(dim=1)
    )
    model.fc = classifier
    return model

# model = get_deepLavV3(4)
# print(model)
# x = torch.tensor(np.random.random((2,3,299,299))).float()
# y,_ = model(x)
# y.shape,_.shape


### 评价指标

In [4]:
def accuracy(preds, labels):
    preds = torch.exp(preds)
    top_p,top_class = preds.topk(1, dim=1)
    equals = top_class == labels.view(*top_class.shape)
    return torch.mean(equals.type(torch.FloatTensor))

### 冻结训练

In [5]:
model = get_deepLavV3(4)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
schedular = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=3)

epochs = 15
val_loss_min = np.Inf
max_e = 10

# 模型存放路径
model_path = os.path.join('./model')
name = "SA_InceptionV3_pre"
model = model.to(DEVICE)
for epoch in range(epochs):

    train_loss = 0.0
    val_loss = 0.0
    train_acc = 0.0
    val_acc = 0.0

    model.train()
    for images,labels in tqdm(train_loader):
        optimizer.zero_grad()
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        preds,_ = model(images)
        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_acc += accuracy(preds, labels)

    avg_train_loss = train_loss / len(train_loader)
    avg_train_acc = train_acc / len(train_loader)

    model.eval()
    with torch.no_grad():
        for images,labels in tqdm(val_loader):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            preds = model(images)
            loss = criterion(preds, labels)
            val_loss += loss.item()
            val_acc += accuracy(preds, labels)

        avg_val_loss = val_loss / len(val_loader)
        avg_val_acc = val_acc / len(val_loader)

    schedular.step(avg_val_loss)

    print("Epoch : {} \ntrain_loss : {:.6f}, \tTrain_acc : {:.6f}, \nVal_loss : {:.6f}, \tVal_acc : {:.6f}".format(epoch + 1,
                                                                                                                   avg_train_loss, avg_train_acc,
                                                                                                                   avg_val_loss, avg_val_acc))
    if avg_val_loss <= val_loss_min:
        print('Validation loss decreased from ({:.6f} --> {:.6f}).\nSaving model ...'.format(val_loss_min, avg_val_loss))
        torch.save(model,os.path.join(model_path,f"{name}.pth"))
        val_loss_min = avg_val_loss
        max_e = 10
    max_e -= 1
    if max_e<=0:
        break


100%|██████████| 439/439 [03:28<00:00,  2.11it/s]
100%|██████████| 244/244 [00:24<00:00,  9.94it/s]


Epoch : 1 
train_loss : 1.181204, 	Train_acc : 0.476250, 
Val_loss : 1.059953, 	Val_acc : 0.571175
Validation loss decreased from (inf --> 1.059953).
Saving model ...


100%|██████████| 439/439 [03:30<00:00,  2.08it/s]
100%|██████████| 244/244 [00:23<00:00, 10.46it/s]


Epoch : 2 
train_loss : 1.136686, 	Train_acc : 0.505663, 
Val_loss : 1.083311, 	Val_acc : 0.560382


100%|██████████| 439/439 [03:25<00:00,  2.13it/s]
100%|██████████| 244/244 [00:22<00:00, 10.99it/s]


Epoch : 3 
train_loss : 1.123085, 	Train_acc : 0.505616, 
Val_loss : 1.065743, 	Val_acc : 0.543306


100%|██████████| 439/439 [03:25<00:00,  2.14it/s]
100%|██████████| 244/244 [00:24<00:00, 10.09it/s]


Epoch : 4 
train_loss : 1.111259, 	Train_acc : 0.512908, 
Val_loss : 1.031817, 	Val_acc : 0.587295
Validation loss decreased from (1.059953 --> 1.031817).
Saving model ...


100%|██████████| 439/439 [03:25<00:00,  2.13it/s]
100%|██████████| 244/244 [00:23<00:00, 10.39it/s]


Epoch : 5 
train_loss : 1.100511, 	Train_acc : 0.517231, 
Val_loss : 1.026407, 	Val_acc : 0.597131
Validation loss decreased from (1.031817 --> 1.026407).
Saving model ...


100%|██████████| 439/439 [03:27<00:00,  2.12it/s]
100%|██████████| 244/244 [00:23<00:00, 10.21it/s]


Epoch : 6 
train_loss : 1.104783, 	Train_acc : 0.516185, 
Val_loss : 1.053353, 	Val_acc : 0.553962


100%|██████████| 439/439 [03:23<00:00,  2.16it/s]
100%|██████████| 244/244 [00:23<00:00, 10.27it/s]


Epoch : 7 
train_loss : 1.095840, 	Train_acc : 0.522677, 
Val_loss : 1.022222, 	Val_acc : 0.591120
Validation loss decreased from (1.026407 --> 1.022222).
Saving model ...


100%|██████████| 439/439 [03:31<00:00,  2.07it/s]
100%|██████████| 244/244 [00:23<00:00, 10.25it/s]


Epoch : 8 
train_loss : 1.094567, 	Train_acc : 0.521470, 
Val_loss : 1.036007, 	Val_acc : 0.589891


100%|██████████| 439/439 [03:25<00:00,  2.13it/s]
100%|██████████| 244/244 [00:22<00:00, 10.75it/s]


Epoch : 9 
train_loss : 1.095666, 	Train_acc : 0.520242, 
Val_loss : 1.039991, 	Val_acc : 0.593989


100%|██████████| 439/439 [03:24<00:00,  2.15it/s]
100%|██████████| 244/244 [00:23<00:00, 10.51it/s]


Epoch : 10 
train_loss : 1.095100, 	Train_acc : 0.523928, 
Val_loss : 1.040453, 	Val_acc : 0.579372


100%|██████████| 439/439 [03:23<00:00,  2.15it/s]
100%|██████████| 244/244 [00:23<00:00, 10.39it/s]


Epoch : 11 
train_loss : 1.086179, 	Train_acc : 0.525753, 
Val_loss : 0.993801, 	Val_acc : 0.599863
Validation loss decreased from (1.022222 --> 0.993801).
Saving model ...


100%|██████████| 439/439 [03:26<00:00,  2.12it/s]
100%|██████████| 244/244 [00:24<00:00,  9.87it/s]


Epoch : 12 
train_loss : 1.082989, 	Train_acc : 0.532063, 
Val_loss : 1.034760, 	Val_acc : 0.589344


100%|██████████| 439/439 [03:22<00:00,  2.17it/s]
100%|██████████| 244/244 [00:22<00:00, 10.88it/s]


Epoch : 13 
train_loss : 1.087267, 	Train_acc : 0.529011, 
Val_loss : 1.012582, 	Val_acc : 0.581557


100%|██████████| 439/439 [03:21<00:00,  2.18it/s]
100%|██████████| 244/244 [00:22<00:00, 10.65it/s]


Epoch : 14 
train_loss : 1.074978, 	Train_acc : 0.535777, 
Val_loss : 1.022065, 	Val_acc : 0.605055


100%|██████████| 439/439 [03:25<00:00,  2.13it/s]
100%|██████████| 244/244 [00:24<00:00, 10.02it/s]

Epoch : 15 
train_loss : 1.078810, 	Train_acc : 0.532198, 
Val_loss : 1.025394, 	Val_acc : 0.609426





In [6]:
val_loss_min

0.993800625937884

### 解冻参数

In [7]:
# 载入冻结训练的模型,解冻卷积层参数
model_path = os.path.join('./model')
name = "SA_InceptionV3_pre"
model = torch.load(os.path.join(model_path,f"{name}.pth"))

for param in model.parameters():
        param.requires_grad = True

# for k,v in model.named_parameters():
#     print(f"{k}: {v.requires_grad}")
# model

### finetune

In [8]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
schedular = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=3)

epochs =50
max_e = 10

name = "SA_InceptionV3_pre_finetune"
model = model.to(DEVICE)
for epoch in range(epochs):

    train_loss = 0.0
    val_loss = 0.0
    train_acc = 0.0
    val_acc = 0.0

    model.train()
    for images,labels in tqdm(train_loader):
        optimizer.zero_grad()
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        preds,_ = model(images)
        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_acc += accuracy(preds, labels)

    avg_train_loss = train_loss / len(train_loader)
    avg_train_acc = train_acc / len(train_loader)

    model.eval()
    with torch.no_grad():
        for images,labels in tqdm(val_loader):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            preds = model(images)
            loss = criterion(preds, labels)
            val_loss += loss.item()
            val_acc += accuracy(preds, labels)

        avg_val_loss = val_loss / len(val_loader)
        avg_val_acc = val_acc / len(val_loader)

    schedular.step(avg_val_loss)

    print("Epoch : {} \ntrain_loss : {:.6f}, \tTrain_acc : {:.6f}, \nVal_loss : {:.6f}, \tVal_acc : {:.6f}".format(epoch + 1,
                                                                                                                   avg_train_loss, avg_train_acc,
                                                                                                                   avg_val_loss, avg_val_acc))
    if avg_val_loss <= val_loss_min:
        print('Validation loss decreased from ({:.6f} --> {:.6f}).\nSaving model ...'.format(val_loss_min, avg_val_loss))
        torch.save(model, os.path.join(model_path,f"{name}.pth"))
        val_loss_min = avg_val_loss
        max_e = 10
    max_e -= 1
    if max_e<=0:
        break


100%|██████████| 439/439 [04:26<00:00,  1.65it/s]
100%|██████████| 244/244 [00:22<00:00, 10.70it/s]


Epoch : 1 
train_loss : 0.693221, 	Train_acc : 0.724438, 
Val_loss : 0.434616, 	Val_acc : 0.857240
Validation loss decreased from (0.993801 --> 0.434616).
Saving model ...


100%|██████████| 439/439 [04:20<00:00,  1.69it/s]
100%|██████████| 244/244 [00:22<00:00, 10.75it/s]


Epoch : 2 
train_loss : 0.405801, 	Train_acc : 0.850861, 
Val_loss : 0.355279, 	Val_acc : 0.881284
Validation loss decreased from (0.434616 --> 0.355279).
Saving model ...


100%|██████████| 439/439 [04:24<00:00,  1.66it/s]
100%|██████████| 244/244 [00:23<00:00, 10.25it/s]


Epoch : 3 
train_loss : 0.333487, 	Train_acc : 0.881296, 
Val_loss : 0.341112, 	Val_acc : 0.887978
Validation loss decreased from (0.355279 --> 0.341112).
Saving model ...


100%|██████████| 439/439 [04:25<00:00,  1.65it/s]
100%|██████████| 244/244 [00:22<00:00, 10.65it/s]


Epoch : 4 
train_loss : 0.296845, 	Train_acc : 0.890795, 
Val_loss : 0.313671, 	Val_acc : 0.887841
Validation loss decreased from (0.341112 --> 0.313671).
Saving model ...


100%|██████████| 439/439 [04:25<00:00,  1.65it/s]
100%|██████████| 244/244 [00:23<00:00, 10.42it/s]


Epoch : 5 
train_loss : 0.262976, 	Train_acc : 0.903029, 
Val_loss : 0.272745, 	Val_acc : 0.904644
Validation loss decreased from (0.313671 --> 0.272745).
Saving model ...


100%|██████████| 439/439 [04:24<00:00,  1.66it/s]
100%|██████████| 244/244 [00:22<00:00, 10.78it/s]


Epoch : 6 
train_loss : 0.250420, 	Train_acc : 0.907629, 
Val_loss : 0.257853, 	Val_acc : 0.912021
Validation loss decreased from (0.272745 --> 0.257853).
Saving model ...


100%|██████████| 439/439 [04:24<00:00,  1.66it/s]
100%|██████████| 244/244 [00:22<00:00, 10.87it/s]


Epoch : 7 
train_loss : 0.222975, 	Train_acc : 0.918472, 
Val_loss : 0.244942, 	Val_acc : 0.920628
Validation loss decreased from (0.257853 --> 0.244942).
Saving model ...


100%|██████████| 439/439 [04:26<00:00,  1.65it/s]
100%|██████████| 244/244 [00:24<00:00, 10.13it/s]


Epoch : 8 
train_loss : 0.218637, 	Train_acc : 0.918518, 
Val_loss : 0.250373, 	Val_acc : 0.914480


100%|██████████| 439/439 [04:24<00:00,  1.66it/s]
100%|██████████| 244/244 [00:22<00:00, 10.65it/s]


Epoch : 9 
train_loss : 0.204279, 	Train_acc : 0.925966, 
Val_loss : 0.246617, 	Val_acc : 0.910109


100%|██████████| 439/439 [04:20<00:00,  1.69it/s]
100%|██████████| 244/244 [00:23<00:00, 10.55it/s]


Epoch : 10 
train_loss : 0.193511, 	Train_acc : 0.929135, 
Val_loss : 0.221338, 	Val_acc : 0.923087
Validation loss decreased from (0.244942 --> 0.221338).
Saving model ...


100%|██████████| 439/439 [04:23<00:00,  1.67it/s]
100%|██████████| 244/244 [00:23<00:00, 10.44it/s]


Epoch : 11 
train_loss : 0.183451, 	Train_acc : 0.933554, 
Val_loss : 0.223082, 	Val_acc : 0.921857


100%|██████████| 439/439 [04:20<00:00,  1.69it/s]
100%|██████████| 244/244 [00:23<00:00, 10.54it/s]


Epoch : 12 
train_loss : 0.181779, 	Train_acc : 0.933164, 
Val_loss : 0.217249, 	Val_acc : 0.926502
Validation loss decreased from (0.221338 --> 0.217249).
Saving model ...


100%|██████████| 439/439 [04:26<00:00,  1.65it/s]
100%|██████████| 244/244 [00:23<00:00, 10.42it/s]


Epoch : 13 
train_loss : 0.164476, 	Train_acc : 0.939772, 
Val_loss : 0.218051, 	Val_acc : 0.921721


100%|██████████| 439/439 [04:30<00:00,  1.63it/s]
100%|██████████| 244/244 [00:25<00:00,  9.60it/s]


Epoch : 14 
train_loss : 0.160718, 	Train_acc : 0.940705, 
Val_loss : 0.205843, 	Val_acc : 0.919672
Validation loss decreased from (0.217249 --> 0.205843).
Saving model ...


100%|██████████| 439/439 [04:22<00:00,  1.67it/s]
100%|██████████| 244/244 [00:22<00:00, 10.78it/s]


Epoch : 15 
train_loss : 0.159324, 	Train_acc : 0.941594, 
Val_loss : 0.197504, 	Val_acc : 0.932786
Validation loss decreased from (0.205843 --> 0.197504).
Saving model ...


100%|██████████| 439/439 [04:28<00:00,  1.63it/s]
100%|██████████| 244/244 [00:25<00:00,  9.73it/s]


Epoch : 16 
train_loss : 0.145533, 	Train_acc : 0.946196, 
Val_loss : 0.202259, 	Val_acc : 0.919262


100%|██████████| 439/439 [04:24<00:00,  1.66it/s]
100%|██████████| 244/244 [00:23<00:00, 10.48it/s]


Epoch : 17 
train_loss : 0.139006, 	Train_acc : 0.948884, 
Val_loss : 0.185773, 	Val_acc : 0.924180
Validation loss decreased from (0.197504 --> 0.185773).
Saving model ...


100%|██████████| 439/439 [04:27<00:00,  1.64it/s]
100%|██████████| 244/244 [00:23<00:00, 10.21it/s]


Epoch : 18 
train_loss : 0.136202, 	Train_acc : 0.949999, 
Val_loss : 0.208609, 	Val_acc : 0.920082


100%|██████████| 439/439 [04:21<00:00,  1.68it/s]
100%|██████████| 244/244 [00:23<00:00, 10.31it/s]


Epoch : 19 
train_loss : 0.129167, 	Train_acc : 0.953053, 
Val_loss : 0.209262, 	Val_acc : 0.916803


100%|██████████| 439/439 [04:21<00:00,  1.68it/s]
100%|██████████| 244/244 [00:22<00:00, 10.70it/s]


Epoch : 20 
train_loss : 0.122246, 	Train_acc : 0.955262, 
Val_loss : 0.185013, 	Val_acc : 0.933606
Validation loss decreased from (0.185773 --> 0.185013).
Saving model ...


100%|██████████| 439/439 [04:26<00:00,  1.65it/s]
100%|██████████| 244/244 [00:23<00:00, 10.29it/s]


Epoch : 21 
train_loss : 0.107115, 	Train_acc : 0.960751, 
Val_loss : 0.198312, 	Val_acc : 0.926639


100%|██████████| 439/439 [04:22<00:00,  1.67it/s]
100%|██████████| 244/244 [00:22<00:00, 10.79it/s]


Epoch : 22 
train_loss : 0.114644, 	Train_acc : 0.959045, 
Val_loss : 0.221509, 	Val_acc : 0.922540


100%|██████████| 439/439 [04:28<00:00,  1.63it/s]
100%|██████████| 244/244 [00:23<00:00, 10.29it/s]


Epoch : 23 
train_loss : 0.103171, 	Train_acc : 0.963165, 
Val_loss : 0.209158, 	Val_acc : 0.931147


100%|██████████| 439/439 [04:29<00:00,  1.63it/s]
100%|██████████| 244/244 [00:22<00:00, 10.74it/s]


Epoch : 24 
train_loss : 0.092067, 	Train_acc : 0.967427, 
Val_loss : 0.219146, 	Val_acc : 0.929098


100%|██████████| 439/439 [04:27<00:00,  1.64it/s]
100%|██████████| 244/244 [00:22<00:00, 10.69it/s]


Epoch : 25 
train_loss : 0.064158, 	Train_acc : 0.977767, 
Val_loss : 0.181896, 	Val_acc : 0.939754
Validation loss decreased from (0.185013 --> 0.181896).
Saving model ...


100%|██████████| 439/439 [04:24<00:00,  1.66it/s]
100%|██████████| 244/244 [00:24<00:00, 10.15it/s]


Epoch : 26 
train_loss : 0.048955, 	Train_acc : 0.982278, 
Val_loss : 0.182527, 	Val_acc : 0.941803


100%|██████████| 439/439 [04:24<00:00,  1.66it/s]
100%|██████████| 244/244 [00:25<00:00,  9.75it/s]


Epoch : 27 
train_loss : 0.044990, 	Train_acc : 0.984055, 
Val_loss : 0.197718, 	Val_acc : 0.939344


100%|██████████| 439/439 [04:25<00:00,  1.65it/s]
100%|██████████| 244/244 [00:23<00:00, 10.36it/s]


Epoch : 28 
train_loss : 0.039294, 	Train_acc : 0.986151, 
Val_loss : 0.199061, 	Val_acc : 0.936885


100%|██████████| 439/439 [04:25<00:00,  1.66it/s]
100%|██████████| 244/244 [00:25<00:00,  9.46it/s]


Epoch : 29 
train_loss : 0.031342, 	Train_acc : 0.988885, 
Val_loss : 0.215062, 	Val_acc : 0.938524


100%|██████████| 439/439 [04:21<00:00,  1.68it/s]
100%|██████████| 244/244 [00:22<00:00, 10.96it/s]


Epoch : 30 
train_loss : 0.030256, 	Train_acc : 0.989385, 
Val_loss : 0.202590, 	Val_acc : 0.942622


100%|██████████| 439/439 [04:21<00:00,  1.68it/s]
100%|██████████| 244/244 [00:22<00:00, 10.73it/s]


Epoch : 31 
train_loss : 0.029753, 	Train_acc : 0.989614, 
Val_loss : 0.213156, 	Val_acc : 0.937295


100%|██████████| 439/439 [04:25<00:00,  1.65it/s]
100%|██████████| 244/244 [00:22<00:00, 10.66it/s]


Epoch : 32 
train_loss : 0.026630, 	Train_acc : 0.990662, 
Val_loss : 0.197757, 	Val_acc : 0.942213


100%|██████████| 439/439 [04:21<00:00,  1.68it/s]
100%|██████████| 244/244 [00:22<00:00, 10.79it/s]


Epoch : 33 
train_loss : 0.025757, 	Train_acc : 0.990501, 
Val_loss : 0.206194, 	Val_acc : 0.937295


100%|██████████| 439/439 [04:22<00:00,  1.67it/s]
100%|██████████| 244/244 [00:24<00:00, 10.16it/s]

Epoch : 34 
train_loss : 0.026983, 	Train_acc : 0.989751, 
Val_loss : 0.204006, 	Val_acc : 0.939344





In [9]:
val_loss_min

0.18189562410550367

### 测试

In [11]:
name = "SA_InceptionV3_pre_finetune"
model = torch.load(os.path.join(model_path,f"{name}.pth"))
model = model.to(DEVICE)

# 混淆矩阵
M = np.zeros((4,4))
with torch.no_grad():
    test_loss=0
    test_acc=0
    for images,labels in tqdm(test_loader):
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        preds = model(images)
        loss = criterion(preds, labels)
        test_loss += loss.item()
        test_acc += accuracy(preds, labels)
        
        M[preds.argmax(),labels[0]]+=1
        
    avg_test_loss = test_loss / len(test_loader)
    avg_test_acc = test_acc / len(test_loader)
    print("Test_loss : {:.6f}, \tTest_acc : {:.6f}".format(avg_test_loss,avg_test_acc))
    
    M = M.astype("int")
    print(M)
    

100%|██████████| 2706/2706 [01:05<00:00, 41.05it/s]

Test_loss : 0.159676, 	Test_acc : 0.946415
[[ 361    3    2    0]
 [   4  809   28   33]
 [   2   43 1141    1]
 [   0   24    5  250]]



