In [1]:
import pandas as pd
import os
import torch.nn.functional as F
from tqdm import tqdm
import torch
import torch.nn as nn
import numpy as np
import torchvision
from torchvision import transforms,datasets
from torch.utils.data import Dataset,DataLoader,random_split
from torch.nn.parameter import Parameter


torch.manual_seed(666)
torch.cuda.manual_seed(666)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
DEVICE

device(type='cuda', index=0)

## 数据集：多种肺炎X光图像

### 数据集

In [2]:
data_dir = "../data/COVID_Dataset"
TRAIN = 'train'
TEST = 'test'
VAL = 'val'

def apply_transform(mode=None):
    size = (299,299)
    crop = 299
    if mode == 'train':
        transform = transforms.Compose([transforms.Resize(size),
                               transforms.RandomHorizontalFlip(),
                               transforms.RandomRotation((-20,+20)),
                               transforms.CenterCrop(crop),
                               transforms.ToTensor(),
                               transforms.Normalize([0.485, 0.456, 0.406],
                                           [0.229, 0.224, 0.225])
                              ])

    elif mode == 'test' or mode == 'val':
        transform = transforms.Compose([transforms.Resize(size),
                               transforms.CenterCrop(crop),
                               transforms.ToTensor(),
                               transforms.Normalize([0.485, 0.456, 0.406],
                                           [0.229, 0.224, 0.225])
                              ])

    return transform

trainset = datasets.ImageFolder(os.path.join(data_dir, TRAIN),
                                transform = apply_transform(TRAIN))

trainset,valset = random_split(trainset,[len(trainset)-int(len(trainset)*0.1),int(len(trainset)*0.1)])

testset = datasets.ImageFolder(os.path.join(data_dir, TEST),
                               transform = apply_transform(TEST))

train_loader = DataLoader(trainset,
                          batch_size=50,
                          shuffle=True)

val_loader = DataLoader(valset,
                        batch_size=10)

test_loader = DataLoader(testset,
                         batch_size=1)


print(testset.class_to_idx)
x,y = next(iter(test_loader))
print(x.shape,y.shape)
print(len(train_loader),len(test_loader),len(val_loader))




{'COVID': 0, 'Lung_Opacity': 1, 'Normal': 2, 'Viral': 3}
torch.Size([1, 3, 299, 299]) torch.Size([1])
439 2706 244


### 模型

In [3]:
# SA模块
class SA_Layer(nn.Module):
    def __init__(self,channels,groups=64) -> None:
        super().__init__()
        self.groups = groups
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.cweight = Parameter(torch.zeros(1,channels//(2*groups),1,1))
        self.cbias = Parameter(torch.ones(1,channels//(2*groups),1,1))
        self.sweight = Parameter(torch.zeros(1,channels//(2*groups),1,1))
        self.sbias = Parameter(torch.ones(1,channels//(2*groups),1,1))
        
        self.sigmoid = nn.Sigmoid()
        self.gn = nn.GroupNorm(channels//(2*groups),channels//(2*groups))

    @staticmethod
    def channel_suffle(x,groups):
        b,c,h,w = x.shape
        # 输入特征图分组
        x = x.reshape(b,groups,-1,h,w)
        # 洗牌
        x = x.permute(0,2,1,3,4)
        x = x.reshape(b,-1,h,w)
        return x
    

    def forward(self,x):
        b,c,h,w = x.shape
        x = x.reshape(b*self.groups,-1,h,w)
        x_0, x_1 = x.chunk(2,dim=1)

        # 通道注意力
        xn = self.avg_pool(x_0)
        xn = self.cweight * xn + self.cbias
        xn = x_0*self.sigmoid(xn)

        # 空间注意力
        xs = self.gn(x_1)
        xs = self.sweight * xs + self.sbias
        xs = x_1*self.sigmoid(xs)

        # 在通道维度上拼接
        out = torch.cat([xn,xs],dim=1)
        out = out.reshape(b,-1,h,w)
        out = self.channel_suffle(out,2)

        return out

# Inception 残差块
class Residual(nn.Module):
    def __init__(self,in_channels,out_channels,inception,kernel_size=(1,1),stride=1,padding=0) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels,out_channels,kernel_size=kernel_size,stride=stride,padding=padding)
        self.inception = inception
    def forward(self,x):
        conv = self.conv1(x)
        y = self.inception(x)
        out = F.relu(conv+y)
        return out


# 预训练InceptionV3 + SA
def get_deepLavV3(out_ch):
    model = torchvision.models.inception_v3(pretrained=True)
    
    # 在每个Inception层 添加残差
    model.Mixed_5b = Residual(192,256,model.Mixed_5b)
    model.Mixed_5c = Residual(256,288,model.Mixed_5c)
    model.Mixed_5d = Residual(288,288,model.Mixed_5d)
    model.Mixed_6a = Residual(288,768,model.Mixed_6a,(5,5),2,1)
    model.Mixed_6b = Residual(768,768,model.Mixed_6b)
    model.Mixed_6c = Residual(768,768,model.Mixed_6c)
    model.Mixed_6d = Residual(768,768,model.Mixed_6d)
    model.Mixed_6e = Residual(768,768,model.Mixed_6e)
    model.Mixed_7a = Residual(768,1280,model.Mixed_7a,(5,5),2,1)
    model.Mixed_7b = Residual(1280,2048,model.Mixed_7b)
    model.Mixed_7c = Residual(2048,2048,model.Mixed_7c)

    # 在每个Inception层后 添加SA
    model.Mixed_5d = nn.Sequential(
        model.Mixed_5d,
        SA_Layer(288,8)
    )
    model.Mixed_6e = nn.Sequential(
        model.Mixed_6e,
        SA_Layer(768,64)
    )
    model.Mixed_7c = nn.Sequential(
        model.Mixed_7c,
        SA_Layer(2048,128)
    )
    
    # 冻结参数
    for param in model.parameters():
        param.requires_grad = False
    
    # 改全连接层
    in_features = model.fc.in_features
    classifier = nn.Sequential(
        nn.Linear(in_features, 4096),
        nn.ReLU(inplace=True),
        nn.Dropout(0.5),
        nn.Linear(4096, 4096),
        nn.ReLU(inplace=True),
        nn.Dropout(0.5),
        nn.Linear(4096, out_ch),
        nn.LogSoftmax(dim=1)
    )
    model.fc = classifier
    return model

model = get_deepLavV3(4)
# print(model)
x = torch.tensor(np.random.random((2,3,299,299))).float()
y,_ = model(x)
y.shape,_.shape


Downloading: "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth" to /home/featurize/.cache/torch/hub/checkpoints/inception_v3_google-0cc3c7bd.pth


  0%|          | 0.00/104M [00:00<?, ?B/s]

(torch.Size([2, 4]), torch.Size([2, 1000]))

### 评价指标

In [4]:
def accuracy(preds, labels):
    preds = torch.exp(preds)
    top_p,top_class = preds.topk(1, dim=1)
    equals = top_class == labels.view(*top_class.shape)
    return torch.mean(equals.type(torch.FloatTensor))

### 损失函数

In [5]:
# Focal Loss: alpha类型为tensor,是类别权重
class FocalLoss(nn.Module):
    def __init__(self,class_num,gamma=2,alpha=None,reduction="mean") -> None:
        super().__init__()
        if alpha is None:
            self.alpha = torch.ones((class_num,1))
        else:
            self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.class_num = class_num
    
    def forward(self,predict,target):
        pt = torch.nn.functional.softmax(predict,dim=1)
        class_mask = torch.nn.functional.one_hot(target,self.class_num)
        ids = target.view(-1,1)
        alpha = self.alpha[ids.data.view(-1)]

        probs = (pt*class_mask).sum(1).view(-1,1)
        log_p = probs.log()
        loss = -alpha*(torch.pow(1-probs,self.gamma))*log_p
        if self.reduction=="mean":
            loss = loss.mean()
        elif self.reduction=="sum":
            loss = loss.sum()
        return loss


# loss = FocalLoss(4,alpha=torch.tensor([0.135,0.325,0.435,0.105]))
# predict = torch.tensor([[1,2,3,4],[100,3,2,1]]).float()
# label = torch.tensor([3,0])
# print(predict.shape,label.shape)
# loss(predict,label)


### 冻结训练

In [6]:
model = get_deepLavV3(4)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
schedular = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=3)

epochs = 25
val_loss_min = np.Inf
max_e = 10

# 模型存放路径
model_path = os.path.join('./model')
name = "SA_ResInceptionV3_pre"
model = model.to(DEVICE)
for epoch in range(epochs):

    train_loss = 0.0
    val_loss = 0.0
    train_acc = 0.0
    val_acc = 0.0

    model.train()
    for images,labels in tqdm(train_loader):
        optimizer.zero_grad()
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        preds,_ = model(images)
        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_acc += accuracy(preds, labels)

    avg_train_loss = train_loss / len(train_loader)
    avg_train_acc = train_acc / len(train_loader)

    model.eval()
    with torch.no_grad():
        for images,labels in tqdm(val_loader):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            preds = model(images)
            loss = criterion(preds, labels)
            val_loss += loss.item()
            val_acc += accuracy(preds, labels)

        avg_val_loss = val_loss / len(val_loader)
        avg_val_acc = val_acc / len(val_loader)

    schedular.step(avg_val_loss)

    print("Epoch : {} \ntrain_loss : {:.6f}, \tTrain_acc : {:.6f}, \nVal_loss : {:.6f}, \tVal_acc : {:.6f}".format(epoch + 1,
                                                                                                                   avg_train_loss, avg_train_acc,
                                                                                                                   avg_val_loss, avg_val_acc))
    if avg_val_loss <= val_loss_min:
        print('Validation loss decreased from ({:.6f} --> {:.6f}).\nSaving model ...'.format(val_loss_min, avg_val_loss))
        torch.save(model,os.path.join(model_path,f"{name}.pth"))
        val_loss_min = avg_val_loss
        max_e = 10
    max_e -= 1
    if max_e<=0:
        break


100%|██████████| 439/439 [03:32<00:00,  2.06it/s]
100%|██████████| 244/244 [00:24<00:00, 10.03it/s]


Epoch : 1 
train_loss : 1.221208, 	Train_acc : 0.472107, 
Val_loss : 1.196583, 	Val_acc : 0.481831
Validation loss decreased from (inf --> 1.196583).
Saving model ...


100%|██████████| 439/439 [03:31<00:00,  2.07it/s]
100%|██████████| 244/244 [00:24<00:00,  9.80it/s]


Epoch : 2 
train_loss : 1.189612, 	Train_acc : 0.488551, 
Val_loss : 1.212457, 	Val_acc : 0.449044


100%|██████████| 439/439 [03:36<00:00,  2.03it/s]
100%|██████████| 244/244 [00:24<00:00, 10.02it/s]


Epoch : 3 
train_loss : 1.192763, 	Train_acc : 0.470628, 
Val_loss : 1.175332, 	Val_acc : 0.556011
Validation loss decreased from (1.196583 --> 1.175332).
Saving model ...


100%|██████████| 439/439 [03:33<00:00,  2.06it/s]
100%|██████████| 244/244 [00:25<00:00,  9.62it/s]


Epoch : 4 
train_loss : 1.179735, 	Train_acc : 0.486846, 
Val_loss : 1.172207, 	Val_acc : 0.456831
Validation loss decreased from (1.175332 --> 1.172207).
Saving model ...


100%|██████████| 439/439 [03:30<00:00,  2.09it/s]
100%|██████████| 244/244 [00:25<00:00,  9.75it/s]


Epoch : 5 
train_loss : 1.180675, 	Train_acc : 0.478096, 
Val_loss : 1.152519, 	Val_acc : 0.556967
Validation loss decreased from (1.172207 --> 1.152519).
Saving model ...


100%|██████████| 439/439 [03:30<00:00,  2.09it/s]
100%|██████████| 244/244 [00:24<00:00,  9.99it/s]


Epoch : 6 
train_loss : 1.169997, 	Train_acc : 0.503591, 
Val_loss : 1.148966, 	Val_acc : 0.555874
Validation loss decreased from (1.152519 --> 1.148966).
Saving model ...


100%|██████████| 439/439 [03:28<00:00,  2.10it/s]
100%|██████████| 244/244 [00:23<00:00, 10.22it/s]


Epoch : 7 
train_loss : 1.173980, 	Train_acc : 0.498485, 
Val_loss : 1.155701, 	Val_acc : 0.496585


100%|██████████| 439/439 [03:29<00:00,  2.09it/s]
100%|██████████| 244/244 [00:24<00:00,  9.92it/s]


Epoch : 8 
train_loss : 1.157942, 	Train_acc : 0.511289, 
Val_loss : 1.131294, 	Val_acc : 0.569399
Validation loss decreased from (1.148966 --> 1.131294).
Saving model ...


100%|██████████| 439/439 [03:25<00:00,  2.14it/s]
100%|██████████| 244/244 [00:23<00:00, 10.47it/s]


Epoch : 9 
train_loss : 1.163224, 	Train_acc : 0.509285, 
Val_loss : 1.137196, 	Val_acc : 0.547131


100%|██████████| 439/439 [03:22<00:00,  2.16it/s]
100%|██████████| 244/244 [00:24<00:00, 10.07it/s]


Epoch : 10 
train_loss : 1.160836, 	Train_acc : 0.509987, 
Val_loss : 1.148393, 	Val_acc : 0.535792


100%|██████████| 439/439 [03:26<00:00,  2.13it/s]
100%|██████████| 244/244 [00:23<00:00, 10.20it/s]


Epoch : 11 
train_loss : 1.153625, 	Train_acc : 0.513772, 
Val_loss : 1.093538, 	Val_acc : 0.593169
Validation loss decreased from (1.131294 --> 1.093538).
Saving model ...


100%|██████████| 439/439 [03:29<00:00,  2.09it/s]
100%|██████████| 244/244 [00:25<00:00,  9.75it/s]


Epoch : 12 
train_loss : 1.156053, 	Train_acc : 0.509148, 
Val_loss : 1.119251, 	Val_acc : 0.549454


100%|██████████| 439/439 [03:26<00:00,  2.13it/s]
100%|██████████| 244/244 [00:23<00:00, 10.29it/s]


Epoch : 13 
train_loss : 1.149971, 	Train_acc : 0.516415, 
Val_loss : 1.107236, 	Val_acc : 0.555601


100%|██████████| 439/439 [03:23<00:00,  2.15it/s]
100%|██████████| 244/244 [00:23<00:00, 10.17it/s]


Epoch : 14 
train_loss : 1.144787, 	Train_acc : 0.513931, 
Val_loss : 1.121097, 	Val_acc : 0.582924


100%|██████████| 439/439 [03:22<00:00,  2.17it/s]
100%|██████████| 244/244 [00:23<00:00, 10.42it/s]


Epoch : 15 
train_loss : 1.150418, 	Train_acc : 0.512153, 
Val_loss : 1.081674, 	Val_acc : 0.568443
Validation loss decreased from (1.093538 --> 1.081674).
Saving model ...


100%|██████████| 439/439 [03:23<00:00,  2.15it/s]
100%|██████████| 244/244 [00:24<00:00,  9.99it/s]


Epoch : 16 
train_loss : 1.144337, 	Train_acc : 0.512628, 
Val_loss : 1.105959, 	Val_acc : 0.587842


100%|██████████| 439/439 [03:25<00:00,  2.13it/s]
100%|██████████| 244/244 [00:24<00:00,  9.87it/s]


Epoch : 17 
train_loss : 1.147061, 	Train_acc : 0.511699, 
Val_loss : 1.099244, 	Val_acc : 0.558470


100%|██████████| 439/439 [03:24<00:00,  2.14it/s]
100%|██████████| 244/244 [00:23<00:00, 10.23it/s]


Epoch : 18 
train_loss : 1.146056, 	Train_acc : 0.511083, 
Val_loss : 1.119119, 	Val_acc : 0.578825


100%|██████████| 439/439 [03:26<00:00,  2.13it/s]
100%|██████████| 244/244 [00:24<00:00,  9.86it/s]


Epoch : 19 
train_loss : 1.137377, 	Train_acc : 0.513562, 
Val_loss : 1.084656, 	Val_acc : 0.557240


100%|██████████| 439/439 [03:24<00:00,  2.15it/s]
100%|██████████| 244/244 [00:23<00:00, 10.29it/s]


Epoch : 20 
train_loss : 1.129827, 	Train_acc : 0.519968, 
Val_loss : 1.095895, 	Val_acc : 0.564891


100%|██████████| 439/439 [03:29<00:00,  2.10it/s]
100%|██████████| 244/244 [00:25<00:00,  9.72it/s]


Epoch : 21 
train_loss : 1.124647, 	Train_acc : 0.521173, 
Val_loss : 1.095177, 	Val_acc : 0.566940


100%|██████████| 439/439 [03:28<00:00,  2.11it/s]
100%|██████████| 244/244 [00:23<00:00, 10.30it/s]


Epoch : 22 
train_loss : 1.125574, 	Train_acc : 0.523813, 
Val_loss : 1.080096, 	Val_acc : 0.583060
Validation loss decreased from (1.081674 --> 1.080096).
Saving model ...


100%|██████████| 439/439 [03:26<00:00,  2.13it/s]
100%|██████████| 244/244 [00:24<00:00, 10.03it/s]


Epoch : 23 
train_loss : 1.124657, 	Train_acc : 0.523839, 
Val_loss : 1.083196, 	Val_acc : 0.580464


100%|██████████| 439/439 [03:22<00:00,  2.16it/s]
100%|██████████| 244/244 [00:23<00:00, 10.45it/s]


Epoch : 24 
train_loss : 1.118247, 	Train_acc : 0.526117, 
Val_loss : 1.083535, 	Val_acc : 0.566940


100%|██████████| 439/439 [03:19<00:00,  2.20it/s]
100%|██████████| 244/244 [00:23<00:00, 10.33it/s]

Epoch : 25 
train_loss : 1.118473, 	Train_acc : 0.525888, 
Val_loss : 1.081554, 	Val_acc : 0.579781





In [7]:
val_loss_min

1.0800955942908272

### 解冻参数

In [8]:
# 载入冻结训练的模型,解冻卷积层参数
model_path = os.path.join('./model')
name = "SA_ResInceptionV3_pre"
model = torch.load(os.path.join(model_path,f"{name}.pth"))

for param in model.parameters():
        param.requires_grad = True

# for k,v in model.named_parameters():
#     print(f"{k}: {v.requires_grad}")
# model

### finetune

In [9]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
schedular = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=3)

epochs =50
max_e = 10

name = "SA_ResInceptionV3_pre_finetune"
model = model.to(DEVICE)
for epoch in range(epochs):

    train_loss = 0.0
    val_loss = 0.0
    train_acc = 0.0
    val_acc = 0.0

    model.train()
    for images,labels in tqdm(train_loader):
        optimizer.zero_grad()
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        preds,_ = model(images)
        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_acc += accuracy(preds, labels)

    avg_train_loss = train_loss / len(train_loader)
    avg_train_acc = train_acc / len(train_loader)

    model.eval()
    with torch.no_grad():
        for images,labels in tqdm(val_loader):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            preds = model(images)
            loss = criterion(preds, labels)
            val_loss += loss.item()
            val_acc += accuracy(preds, labels)

        avg_val_loss = val_loss / len(val_loader)
        avg_val_acc = val_acc / len(val_loader)

    schedular.step(avg_val_loss)

    print("Epoch : {} \ntrain_loss : {:.6f}, \tTrain_acc : {:.6f}, \nVal_loss : {:.6f}, \tVal_acc : {:.6f}".format(epoch + 1,
                                                                                                                   avg_train_loss, avg_train_acc,
                                                                                                                   avg_val_loss, avg_val_acc))
    if avg_val_loss <= val_loss_min:
        print('Validation loss decreased from ({:.6f} --> {:.6f}).\nSaving model ...'.format(val_loss_min, avg_val_loss))
        torch.save(model, os.path.join(model_path,f"{name}.pth"))
        val_loss_min = avg_val_loss
        max_e = 10
    max_e -= 1
    if max_e<=0:
        break


100%|██████████| 439/439 [04:44<00:00,  1.54it/s]
100%|██████████| 244/244 [00:23<00:00, 10.34it/s]


Epoch : 1 
train_loss : 0.737462, 	Train_acc : 0.695121, 
Val_loss : 0.787989, 	Val_acc : 0.673770
Validation loss decreased from (1.080096 --> 0.787989).
Saving model ...


100%|██████████| 439/439 [04:52<00:00,  1.50it/s]
100%|██████████| 244/244 [00:23<00:00, 10.17it/s]


Epoch : 2 
train_loss : 0.383825, 	Train_acc : 0.855122, 
Val_loss : 0.490741, 	Val_acc : 0.824590
Validation loss decreased from (0.787989 --> 0.490741).
Saving model ...


100%|██████████| 439/439 [04:52<00:00,  1.50it/s]
100%|██████████| 244/244 [00:24<00:00, 10.00it/s]


Epoch : 3 
train_loss : 0.320856, 	Train_acc : 0.882912, 
Val_loss : 0.478174, 	Val_acc : 0.818169
Validation loss decreased from (0.490741 --> 0.478174).
Saving model ...


100%|██████████| 439/439 [04:54<00:00,  1.49it/s]
100%|██████████| 244/244 [00:24<00:00, 10.13it/s]


Epoch : 4 
train_loss : 0.286972, 	Train_acc : 0.895781, 
Val_loss : 0.407104, 	Val_acc : 0.847678
Validation loss decreased from (0.478174 --> 0.407104).
Saving model ...


100%|██████████| 439/439 [04:50<00:00,  1.51it/s]
100%|██████████| 244/244 [00:24<00:00, 10.09it/s]


Epoch : 5 
train_loss : 0.267528, 	Train_acc : 0.902823, 
Val_loss : 0.267089, 	Val_acc : 0.909563
Validation loss decreased from (0.407104 --> 0.267089).
Saving model ...


100%|██████████| 439/439 [04:56<00:00,  1.48it/s]
100%|██████████| 244/244 [00:24<00:00,  9.77it/s]


Epoch : 6 
train_loss : 0.251985, 	Train_acc : 0.909426, 
Val_loss : 0.257137, 	Val_acc : 0.901366
Validation loss decreased from (0.267089 --> 0.257137).
Saving model ...


100%|██████████| 439/439 [05:00<00:00,  1.46it/s]
100%|██████████| 244/244 [00:24<00:00,  9.86it/s]


Epoch : 7 
train_loss : 0.231268, 	Train_acc : 0.914713, 
Val_loss : 0.292557, 	Val_acc : 0.886612


100%|██████████| 439/439 [05:03<00:00,  1.44it/s]
100%|██████████| 244/244 [00:25<00:00,  9.66it/s]


Epoch : 8 
train_loss : 0.220691, 	Train_acc : 0.919292, 
Val_loss : 0.247322, 	Val_acc : 0.902459
Validation loss decreased from (0.257137 --> 0.247322).
Saving model ...


100%|██████████| 439/439 [05:00<00:00,  1.46it/s]
100%|██████████| 244/244 [00:25<00:00,  9.72it/s]


Epoch : 9 
train_loss : 0.214383, 	Train_acc : 0.922233, 
Val_loss : 0.219643, 	Val_acc : 0.918579
Validation loss decreased from (0.247322 --> 0.219643).
Saving model ...


100%|██████████| 439/439 [04:58<00:00,  1.47it/s]
100%|██████████| 244/244 [00:23<00:00, 10.28it/s]


Epoch : 10 
train_loss : 0.201048, 	Train_acc : 0.924987, 
Val_loss : 0.251653, 	Val_acc : 0.910655


100%|██████████| 439/439 [05:07<00:00,  1.43it/s]
100%|██████████| 244/244 [00:26<00:00,  9.23it/s]


Epoch : 11 
train_loss : 0.196083, 	Train_acc : 0.926399, 
Val_loss : 0.244364, 	Val_acc : 0.910792


100%|██████████| 439/439 [05:04<00:00,  1.44it/s]
100%|██████████| 244/244 [00:26<00:00,  9.18it/s]


Epoch : 12 
train_loss : 0.189166, 	Train_acc : 0.929908, 
Val_loss : 0.251765, 	Val_acc : 0.899590


100%|██████████| 439/439 [05:09<00:00,  1.42it/s]
100%|██████████| 244/244 [00:25<00:00,  9.41it/s]


Epoch : 13 
train_loss : 0.182146, 	Train_acc : 0.931982, 
Val_loss : 0.246450, 	Val_acc : 0.911065


100%|██████████| 439/439 [05:05<00:00,  1.44it/s]
100%|██████████| 244/244 [00:25<00:00,  9.46it/s]


Epoch : 14 
train_loss : 0.133089, 	Train_acc : 0.950660, 
Val_loss : 0.179377, 	Val_acc : 0.930328
Validation loss decreased from (0.219643 --> 0.179377).
Saving model ...


100%|██████████| 439/439 [05:05<00:00,  1.44it/s]
100%|██████████| 244/244 [00:24<00:00,  9.99it/s]


Epoch : 15 
train_loss : 0.121078, 	Train_acc : 0.953962, 
Val_loss : 0.167714, 	Val_acc : 0.940164
Validation loss decreased from (0.179377 --> 0.167714).
Saving model ...


100%|██████████| 439/439 [05:05<00:00,  1.44it/s]
100%|██████████| 244/244 [00:25<00:00,  9.53it/s]


Epoch : 16 
train_loss : 0.107092, 	Train_acc : 0.959522, 
Val_loss : 0.166487, 	Val_acc : 0.941939
Validation loss decreased from (0.167714 --> 0.166487).
Saving model ...


100%|██████████| 439/439 [05:08<00:00,  1.42it/s]
100%|██████████| 244/244 [00:26<00:00,  9.34it/s]


Epoch : 17 
train_loss : 0.104198, 	Train_acc : 0.961093, 
Val_loss : 0.174143, 	Val_acc : 0.934425


100%|██████████| 439/439 [05:05<00:00,  1.44it/s]
100%|██████████| 244/244 [00:26<00:00,  9.38it/s]


Epoch : 18 
train_loss : 0.097041, 	Train_acc : 0.963509, 
Val_loss : 0.183033, 	Val_acc : 0.933196


100%|██████████| 439/439 [05:03<00:00,  1.44it/s]
100%|██████████| 244/244 [00:25<00:00,  9.54it/s]


Epoch : 19 
train_loss : 0.089136, 	Train_acc : 0.966902, 
Val_loss : 0.181977, 	Val_acc : 0.940573


100%|██████████| 439/439 [05:02<00:00,  1.45it/s]
100%|██████████| 244/244 [00:25<00:00,  9.62it/s]


Epoch : 20 
train_loss : 0.086710, 	Train_acc : 0.966424, 
Val_loss : 0.180540, 	Val_acc : 0.941803


100%|██████████| 439/439 [05:01<00:00,  1.46it/s]
100%|██████████| 244/244 [00:25<00:00,  9.51it/s]


Epoch : 21 
train_loss : 0.078191, 	Train_acc : 0.971117, 
Val_loss : 0.170797, 	Val_acc : 0.941803


100%|██████████| 439/439 [05:00<00:00,  1.46it/s]
100%|██████████| 244/244 [00:24<00:00,  9.92it/s]


Epoch : 22 
train_loss : 0.074339, 	Train_acc : 0.972187, 
Val_loss : 0.195507, 	Val_acc : 0.934016


100%|██████████| 439/439 [05:00<00:00,  1.46it/s]
100%|██████████| 244/244 [00:25<00:00,  9.67it/s]


Epoch : 23 
train_loss : 0.075144, 	Train_acc : 0.970753, 
Val_loss : 0.184421, 	Val_acc : 0.935246


100%|██████████| 439/439 [05:03<00:00,  1.44it/s]
100%|██████████| 244/244 [00:25<00:00,  9.57it/s]


Epoch : 24 
train_loss : 0.074019, 	Train_acc : 0.972072, 
Val_loss : 0.182784, 	Val_acc : 0.942622


100%|██████████| 439/439 [04:56<00:00,  1.48it/s]
100%|██████████| 244/244 [00:24<00:00,  9.89it/s]

Epoch : 25 
train_loss : 0.073104, 	Train_acc : 0.971641, 
Val_loss : 0.190575, 	Val_acc : 0.931147





In [10]:
val_loss_min

0.16648661456410543

### 测试

In [11]:
name = "SA_ResInceptionV3_pre_finetune"
model = torch.load(os.path.join(model_path,f"{name}.pth"))
model = model.to(DEVICE)

# 混淆矩阵
M = np.zeros((4,4))
with torch.no_grad():
    test_loss=0
    test_acc=0
    for images,labels in tqdm(test_loader):
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        preds = model(images)
        loss = criterion(preds, labels)
        test_loss += loss.item()
        test_acc += accuracy(preds, labels)
        
        M[preds.argmax(),labels[0]]+=1
        
    avg_test_loss = test_loss / len(test_loader)
    avg_test_acc = test_acc / len(test_loader)
    print("Test_loss : {:.6f}, \tTest_acc : {:.6f}".format(avg_test_loss,avg_test_acc))
    
    M = M.astype("int")
    print(M)
    

100%|██████████| 2706/2706 [01:14<00:00, 36.49it/s]

Test_loss : 0.156628, 	Test_acc : 0.944568
[[ 360    0    1    0]
 [   4  801   28   26]
 [   3   37 1140    3]
 [   0   41    7  255]]



