In [None]:
import os
import shutil

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.nn.parameter import Parameter
from torch.utils.data import Dataset, DataLoader, Subset, random_split
from torchvision import transforms, datasets
from tqdm import tqdm

# Use the first CUDA GPU when one is visible, otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")
DEVICE

## 数据集：多种肺炎X光图像,四分类
## 网络: Residual + Inception

### 数据集

In [None]:
# Dataset root directory and the split identifiers used with it.
data_dir = "./data/COVID_Dataset"
TRAIN, TEST, VAL = 'train', 'test', 'val'

def apply_transform(mode=None):
    """Build the preprocessing pipeline for one dataset split.

    Args:
        mode: 'train' for the augmented pipeline (random horizontal flip plus
            a random rotation in [-20, 20] degrees), or 'test' / 'val' for the
            deterministic pipeline.

    Returns:
        A ``transforms.Compose`` that resizes to 299x299, applies the
        training augmentations (train only), center-crops to 299, converts to
        a tensor and normalizes each channel with the ImageNet mean/std.

    Raises:
        ValueError: if ``mode`` is not 'train', 'test' or 'val'.
    """
    # Validate first: an unknown mode used to fall through both branches and
    # crash later with an UnboundLocalError instead of a clear message.
    if mode not in ('train', 'test', 'val'):
        raise ValueError(f"mode must be 'train', 'test' or 'val', got {mode!r}")

    size = (299, 299)
    crop = 299  # equal to the resized size, so this crop does not remove pixels

    if mode == 'train':
        # Augmentation is applied to training images only.
        augmentations = [transforms.RandomHorizontalFlip(),
                         transforms.RandomRotation((-20, +20))]
    else:
        augmentations = []

    return transforms.Compose([
        transforms.Resize(size),
        *augmentations,
        transforms.CenterCrop(crop),
        transforms.ToTensor(),
        # ImageNet per-channel mean / std
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])

# Two views of the same training folder: an augmented one for training and a
# deterministic one for validation. Splitting a single dataset would make the
# validation images inherit the random flip/rotation of the training transform.
train_source = datasets.ImageFolder(os.path.join(data_dir, TRAIN),
                                    transform=apply_transform(TRAIN))
val_source = datasets.ImageFolder(os.path.join(data_dir, TRAIN),
                                  transform=apply_transform(VAL))
assert train_source.classes == val_source.classes

# Hold out 10% of the training images for validation. The seeded generator
# makes the split identical on every Restart & Run All.
VAL_FRACTION = 0.1
SPLIT_SEED = 42
n_val = int(len(train_source) * VAL_FRACTION)
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
shuffled_indices = torch.randperm(len(train_source), generator=split_generator).tolist()
val_indices, train_indices = shuffled_indices[:n_val], shuffled_indices[n_val:]

trainset = Subset(train_source, train_indices)
valset = Subset(val_source, val_indices)
assert len(trainset) + len(valset) == len(train_source)

testset = datasets.ImageFolder(os.path.join(data_dir, TEST),
                               transform=apply_transform(TEST))

train_loader = DataLoader(trainset, batch_size=50, shuffle=True)
val_loader = DataLoader(valset, batch_size=10)
test_loader = DataLoader(testset, batch_size=1)

len(train_loader), len(test_loader), len(val_loader)




### 模型

In [None]:
# SA (shuffle attention) module
class SA_Layer(nn.Module):
    """Shuffle-attention layer.

    The channels are split into ``groups`` sub-groups, and each sub-group is
    halved. One half is rescaled by a channel gate (global average pool ->
    per-channel affine -> sigmoid). The other half is rescaled by a spatial
    gate (GroupNorm -> per-channel affine -> sigmoid). The halves are then
    concatenated and the channels are shuffled across two groups.

    Args:
        channels: number of input channels. It must be divisible by
            ``2 * groups`` so each half-group matches the gate parameter shapes.
        groups: number of channel sub-groups (default 64).
    """

    def __init__(self, channels, groups=64) -> None:
        super().__init__()
        half_group_channels = channels // (2 * groups)
        gate_shape = (1, half_group_channels, 1, 1)

        self.groups = groups
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Gate weights start at 0 and biases at 1, so each gate initially
        # evaluates to the constant sigmoid(1).
        self.cweight = Parameter(torch.zeros(*gate_shape))
        self.cbias = Parameter(torch.ones(*gate_shape))
        self.sweight = Parameter(torch.zeros(*gate_shape))
        self.sbias = Parameter(torch.ones(*gate_shape))

        self.sigmoid = nn.Sigmoid()
        self.gn = nn.GroupNorm(half_group_channels, half_group_channels)

    @staticmethod
    def channel_suffle(x, groups):
        """Interleave the channels of a 4-D tensor across ``groups`` groups."""
        batch, _, height, width = x.shape
        grouped = x.reshape(batch, groups, -1, height, width)
        # Swap the group axis with the within-group channel axis, then flatten back.
        return grouped.transpose(1, 2).reshape(batch, -1, height, width)

    def forward(self, x):
        batch, _, height, width = x.shape
        grouped = x.reshape(batch * self.groups, -1, height, width)
        channel_half, spatial_half = grouped.chunk(2, dim=1)

        # Channel gate, driven by the globally pooled descriptor.
        channel_gate = self.sigmoid(self.cweight * self.avg_pool(channel_half) + self.cbias)
        # Spatial gate, driven by the group-normalized features.
        spatial_gate = self.sigmoid(self.sweight * self.gn(spatial_half) + self.sbias)

        gated = torch.cat([channel_half * channel_gate, spatial_half * spatial_gate], dim=1)
        return self.channel_suffle(gated.reshape(batch, -1, height, width), 2)

# Inception layer
class Inception(nn.Module):
    """Inception module with four parallel branches that preserve spatial size.

    Branches:
      1. 1x1 conv -> ``c1`` channels
      2. 1x1 conv -> ``c2[0]``, then 3x3 conv (padding 1) -> ``c2[1]``
      3. 1x1 conv -> ``c3[0]``, then 5x5 conv (padding 2) -> ``c3[1]``
      4. 3x3 max-pool (stride 1, padding 1), then 1x1 conv -> ``c4``

    Every convolution is followed by ReLU. The branch outputs are concatenated
    along the channel axis, giving ``c1 + c2[1] + c3[1] + c4`` channels.
    """

    def __init__(self, in_channels, c1, c2, c3, c4) -> None:
        super().__init__()
        # branch 1: 1x1 conv
        self.route1x1_1 = nn.Conv2d(in_channels, c1, kernel_size=(1, 1))
        # branch 2: 1x1 conv -> 3x3 conv
        self.route1x1_2 = nn.Conv2d(in_channels, c2[0], kernel_size=(1, 1))
        self.route3x3_2 = nn.Conv2d(c2[0], c2[1], kernel_size=(3, 3), padding=1)
        # branch 3: 1x1 conv -> 5x5 conv
        self.route1x1_3 = nn.Conv2d(in_channels, c3[0], kernel_size=(1, 1))
        self.route5x5_3 = nn.Conv2d(c3[0], c3[1], kernel_size=(5, 5), padding=2)
        # branch 4: 3x3 max-pool -> 1x1 conv
        self.route3x3_4 = nn.MaxPool2d((3, 3), stride=1, padding=1)
        self.route1x1_4 = nn.Conv2d(in_channels, c4, kernel_size=(1, 1))

    def forward(self, x):
        branch_outputs = (
            F.relu(self.route1x1_1(x)),
            F.relu(self.route3x3_2(F.relu(self.route1x1_2(x)))),
            F.relu(self.route5x5_3(F.relu(self.route1x1_3(x)))),
            F.relu(self.route1x1_4(self.route3x3_4(x))),
        )
        return torch.cat(branch_outputs, dim=1)

# Basic convolution layer
class BasicConv2d(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU packaged as a single module.

    ``kernel``, ``stride`` and ``padding`` are forwarded unchanged to
    ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel, stride=1, padding=0) -> None:
        super().__init__()
        stages = [
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)

# Inception residual block
class Residual(nn.Module):
    """Residual block wrapped around a single Inception module.

    A 1x1 convolution projects the input to ``out_channels``. An Inception
    module is applied to that projection, its output is added back to the
    projection (the skip connection), and the sum goes through ReLU.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the projection and of the block output.
            Must be divisible by 16. The Inception branch widths are split as
            1/4 : 1/2 : 1/8 : 1/8 of it (for 256 this is exactly 64, (64, 128),
            (16, 32), 32), so the Inception output matches the projection in
            the residual addition.
        strid: stride of the 1x1 projection (spelling kept for
            compatibility). The default 1 keeps the spatial size.

    Raises:
        ValueError: if ``out_channels`` is not divisible by 16.
    """

    def __init__(self, in_channels, out_channels, strid=1) -> None:
        super().__init__()
        if out_channels % 16 != 0:
            raise ValueError(f"out_channels must be divisible by 16, got {out_channels}")
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1), stride=strid)
        quarter = out_channels // 4
        eighth = out_channels // 8
        # Derive the branch widths from out_channels so that
        # c1 + c2[1] + c3[1] + c4 == out_channels. Hard-coding the 256-channel
        # configuration breaks x + y for any other width.
        self.inception = nn.Sequential(
            Inception(out_channels, quarter, (quarter, out_channels // 2),
                      (out_channels // 16, eighth), eighth))

    def forward(self, x):
        x = self.conv1(x)
        y = self.inception(x)
        return F.relu(x + y)

# Network
class Res_Inception(nn.Module):
    """Convolutional stem, three Inception residual blocks and a linear classifier.

    Args:
        in_channel: channels of the input image.
        num_classes: number of output logits.

    Input is (batch, in_channel, H, W). For a 299x299 input, the stem (b1-b4)
    reduces the feature map to 9x9 and b5 reduces it to 2x2. The map entering
    b5 must be at least 4x4 for its two 2x2 max-pools, which requires input
    sides of roughly 127 px or more.

    Output is (batch, num_classes) raw logits (no softmax).
    """

    def __init__(self, in_channel, num_classes) -> None:
        super().__init__()
        self.b1 = nn.Sequential(
            BasicConv2d(in_channel, out_channels=64, kernel=(3, 3), stride=2, padding=1),
            nn.MaxPool2d((2, 2), 2)
        )
        self.b2 = nn.Sequential(
            BasicConv2d(64, 128, kernel=(3, 3), padding=1),
            nn.MaxPool2d((2, 2), 2)
        )
        self.b3 = nn.Sequential(
            BasicConv2d(128, 256, kernel=(3, 3), padding=1),
            nn.MaxPool2d((2, 2), 2),
        )
        self.b4 = nn.Sequential(
            BasicConv2d(256, 256, kernel=(3, 3), padding=1),
            nn.MaxPool2d((2, 2), 2)
        )
        self.b5 = nn.Sequential(
            Residual(256, 256),
            nn.MaxPool2d((2, 2), 2),
            Residual(256, 256),
            nn.MaxPool2d((2, 2), 2),
            Residual(256, 256)
        )
        # Global average pooling to 1x1. For a 299x299 input the b5 map is
        # 2x2, where this equals a 2x2 average pool. Unlike a fixed-size pool,
        # it always yields exactly 256 features, so b6 also works for other
        # input resolutions.
        self.AvgPool2D = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten()
        self.b6 = nn.Linear(256, num_classes)

    def forward(self, x):
        x = self.b1(x)
        x = self.b2(x)
        x = self.b3(x)
        x = self.b4(x)
        x = self.b5(x)
        x = self.AvgPool2D(x)
        x = self.flatten(x)
        return self.b6(x)
        

### 评价指标

In [None]:
def accuracy(preds, labels):
    """Fraction of samples whose highest-scoring class equals the label.

    Args:
        preds: scores of shape (batch, num_classes). Only the argmax over
            dim 1 is used, so logits and log-probabilities both work.
        labels: integer class indices with ``batch`` elements.

    Returns:
        A 0-dim float32 CPU tensor holding the batch accuracy in [0, 1].
    """
    # Take the argmax of the raw scores. exp() is monotonic, so it cannot change
    # the argmax, and for float32 logits above ~88 it overflows to inf, turning
    # distinct scores into ties.
    predicted = preds.argmax(dim=1)
    correct = predicted == labels.view_as(predicted)
    return correct.float().mean().cpu()

### 训练

In [None]:
# Reproducible weight initialisation and mini-batch shuffling order.
torch.manual_seed(0)

# Move the model to the device before handing its parameters to the optimizer.
model = Res_Inception(3, 4).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Multiply the learning rate by 0.1 once the validation loss stops improving
# for 3 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=3)

epochs = 50
# Early-stopping budget. The counter is reset to this value on every new best
# validation loss and decremented once per epoch; training stops at 0.
EARLY_STOP_PATIENCE = 20
max_e = EARLY_STOP_PATIENCE
# float('inf') rather than np.Inf, which was removed in NumPy 2.0.
val_loss_min = float('inf')

# Directory for the best checkpoint. Create it so torch.save does not fail on a fresh checkout.
model_path = os.path.join('./model')
os.makedirs(model_path, exist_ok=True)
name = "ResInceptionv3"

for epoch in range(epochs):
    # ---- training ----
    model.train()
    train_loss = 0.0
    train_acc = 0.0
    for images, labels in tqdm(train_loader):
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        optimizer.zero_grad()
        preds = model(images)
        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()

        # Weight the per-batch means by batch size so a smaller final batch
        # does not skew the epoch averages.
        batch_size = labels.size(0)
        train_loss += loss.item() * batch_size
        train_acc += accuracy(preds, labels).item() * batch_size

    avg_train_loss = train_loss / len(train_loader.dataset)
    avg_train_acc = train_acc / len(train_loader.dataset)

    # ---- validation ----
    model.eval()
    val_loss = 0.0
    val_acc = 0.0
    with torch.no_grad():
        for images, labels in tqdm(val_loader):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            preds = model(images)
            batch_size = labels.size(0)
            val_loss += criterion(preds, labels).item() * batch_size
            val_acc += accuracy(preds, labels).item() * batch_size

    avg_val_loss = val_loss / len(val_loader.dataset)
    avg_val_acc = val_acc / len(val_loader.dataset)

    scheduler.step(avg_val_loss)

    print("Epoch : {} \ntrain_loss : {:.6f}, \tTrain_acc : {:.6f}, \nVal_loss : {:.6f}, \tVal_acc : {:.6f}".format(epoch + 1,
                                                                                                                   avg_train_loss, avg_train_acc,
                                                                                                                   avg_val_loss, avg_val_acc))
    if avg_val_loss <= val_loss_min:
        print('Validation loss decreased from ({:.6f} --> {:.6f}).\nSaving model ...'.format(val_loss_min, avg_val_loss))
        torch.save(model, os.path.join(model_path, f"{name}.pth"))
        val_loss_min = avg_val_loss
        max_e = EARLY_STOP_PATIENCE
    max_e -= 1
    if max_e <= 0:
        break


In [None]:
# Best (lowest) average validation loss reached during training; the checkpoint
# written to model_path was saved at the epoch that reached this value.
val_loss_min

### 测试

In [None]:
# Reload the best checkpoint written by the training cell and evaluate it on the test set.
name = "ResInceptionv3"
model_path = os.path.join('./model')
checkpoint_file = os.path.join(model_path, f"{name}.pth")
try:
    # The checkpoint is a whole pickled nn.Module (torch.save(model, ...)).
    # PyTorch >= 2.6 defaults to weights_only=True, which refuses to unpickle it.
    # Full unpickling can execute code, so only load trusted, self-produced files.
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    model = torch.load(checkpoint_file, map_location=DEVICE, weights_only=False)
except TypeError:
    # PyTorch < 1.13 has no weights_only argument.
    model = torch.load(checkpoint_file, map_location=DEVICE)
model = model.to(DEVICE)
model.eval()  # BatchNorm must use running statistics, whatever mode was saved
criterion = nn.CrossEntropyLoss()

test_loss = 0.0
test_acc = 0.0
with torch.no_grad():
    for images, labels in tqdm(test_loader):
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        preds = model(images)
        # Weight by batch size so the averages stay exact for any batch size.
        batch_size = labels.size(0)
        test_loss += criterion(preds, labels).item() * batch_size
        test_acc += accuracy(preds, labels).item() * batch_size

avg_test_loss = test_loss / len(test_loader.dataset)
avg_test_acc = test_acc / len(test_loader.dataset)
print("Test_loss : {:.6f}, \tTest_acc : {:.6f}".format(avg_test_loss, avg_test_acc))

In [None]:
# Confusion matrix of the best checkpoint on the test set.
name = "ResInceptionv3"
model_path = os.path.join('./model')
checkpoint_file = os.path.join(model_path, f"{name}.pth")
try:
    # Whole pickled nn.Module: PyTorch >= 2.6 needs weights_only=False to load it.
    # Only load trusted, self-produced files.
    model = torch.load(checkpoint_file, map_location=DEVICE, weights_only=False)
except TypeError:
    # PyTorch < 1.13 has no weights_only argument.
    model = torch.load(checkpoint_file, map_location=DEVICE)
model = model.to(DEVICE)
model.eval()

class_names = testset.classes
num_classes = len(class_names)
# Rows are predicted classes and columns are true classes.
M = np.zeros((num_classes, num_classes))
with torch.no_grad():
    for images, labels in tqdm(test_loader):
        preds = model(images.to(DEVICE))
        # Argmax per sample (dim=1): a bare argmax() flattens the whole batch
        # and is only correct when batch_size == 1.
        predicted = preds.argmax(dim=1).cpu()
        for pred_idx, true_idx in zip(predicted.tolist(), labels.tolist()):
            M[pred_idx, true_idx] += 1

pd.DataFrame(M.astype(int),
             index=[f"pred {c}" for c in class_names],
             columns=[f"true {c}" for c in class_names])
    
    