In [None]:
import os
import torch 
import torchvision 
from tqdm import tqdm
from torchvision import transforms, datasets, models
import numpy as np 
import pandas as pd
import torch.nn as nn
from torch.utils.data import DataLoader,Dataset
import torch.utils.model_zoo as model_zoo
import math
import torch.nn.functional as F
from PIL import Image

from models.resFPNCBAM import resnet50
from sklearn.metrics import roc_auc_score,f1_score

# Restrict the GPUs visible to this process to device index 0.
# NOTE(review): presumably this has to run before CUDA is first initialised
# in this kernel to take effect — confirm when re-running cells out of order.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
# Hyperparameters shared by the data loaders below.
batch_size = 16
num_workers = 8

# Select the compute device: use CUDA when it is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Data augmentation pipelines.

# Evaluation: deterministic resize to 512x512 followed by tensor conversion.
transform_test = transforms.Compose(
    [
        transforms.Resize([512, 512]),
        transforms.ToTensor(),
    ]
)

# Training: same resize, then random horizontal/vertical flips (p=0.5 each)
# and a random rotation (degrees=90) before the tensor conversion.
transform_train = transforms.Compose(
    [
        transforms.Resize([512, 512]),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomRotation(90),
        transforms.ToTensor(),
    ]
)

In [None]:
# Data loading
class CovidDataset(Dataset):
    """Image dataset described by a CSV file with ``case`` and ``type`` columns.

    Each row names an image file (``case``, relative to ``data_dir``) and its
    label (``type``). Samples are returned as ``(x, y)`` where ``x`` is the
    single-channel image replicated to three channels and ``y`` is the label
    taken from the CSV.

    Args:
        path: Path of the CSV file listing the samples.
        transform: Optional callable applied to the PIL image; it is expected
            to return a tensor of shape ``(1, H, W)`` or ``(H, W)``. When
            omitted, the scaled pixel array is converted to a tensor as is.
        data_dir: Directory the ``case`` entries are resolved against.
    """

    def __init__(self, path, transform=None, data_dir='./data/'):
        df = pd.read_csv(path)
        df = df.set_index('case')
        self.X = list(df.index)
        self.Y = df['type'].values.tolist()
        # Attribute name (including its spelling) kept for backward compatibility.
        self.lenth = len(self.Y)
        self.transform = transform
        self.data_dir = data_dir

    def __len__(self):
        """Return the number of samples listed in the CSV."""
        return self.lenth

    @staticmethod
    def _to_three_channels(image):
        """Replicate a single-channel ``(H, W)`` or ``(1, H, W)`` tensor to ``(3, H, W)``.

        Raises:
            ValueError: if ``image`` is not a single-channel 2-D/3-D tensor.
        """
        if image.dim() == 2:
            image = image.unsqueeze(0)
        if image.dim() != 3 or image.size(0) != 1:
            raise ValueError(
                'expected a single-channel image tensor, got shape {}'.format(
                    tuple(image.shape)))
        # repeat() allocates a fresh tensor whose size follows the input, so
        # the result no longer depends on a hard-coded 512x512 resolution.
        return image.repeat(3, 1, 1)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image_tensor, label)``."""
        picture_path = os.path.join(self.data_dir, self.X[index])
        # Close the file handle deterministically; the array copy made by the
        # division keeps the pixel data alive after the file is closed.
        with Image.open(picture_path) as raw:
            # Scale by 4095 (presumably 12-bit source images — TODO confirm).
            pixels = np.asarray(raw, dtype='float32') / 4095
        if self.transform is not None:
            image = self.transform(Image.fromarray(pixels))
        else:
            image = torch.from_numpy(pixels)
        x = self._to_three_channels(image)
        y = self.Y[index]
        return x, y

In [None]:
# Build one (train, test) DataLoader pair per fold.
# Fold k (k = 1..5) reads ./CSV/train<k>.csv and ./CSV/test<k>.csv.
train_loader_list = []
test_loader_list = []
for i in range(5):
    csv_train_path = './CSV/train' + str(i+1) + '.csv'
    train_dataset = CovidDataset(csv_train_path, transform_train)
    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              shuffle=True, num_workers=num_workers, pin_memory=True)
    train_loader_list.append(train_loader)

    csv_path = './CSV/test' + str(i+1) + '.csv'
    test_dataset = CovidDataset(csv_path, transform_test)
    # Use the shared num_workers setting instead of a second hard-coded value
    # so that changing the configuration cell affects both loaders.
    test_loader = DataLoader(test_dataset, batch_size=8,
                             shuffle=False, num_workers=num_workers, pin_memory=True)
    test_loader_list.append(test_loader)

In [None]:
def calc_accuracy(outputs, targets):
    """Return the fraction of samples whose highest-scoring class matches the target.

    Args:
        outputs: Tensor of per-class scores; the predicted class is the
            arg-max along dimension 1.
        targets: Tensor of class indices, one per row of ``outputs``.

    Returns:
        A Python float in ``[0, 1]``; ``0.0`` for an empty batch (instead of
        raising ``ZeroDivisionError``).
    """
    num_samples = len(targets)
    if num_samples == 0:
        return 0.0
    pred = outputs.argmax(dim=1)
    return (pred == targets).sum().item() / num_samples

In [None]:
# model (The specific model is in the resFPNCBAM file in the models folder)
class MyModel(nn.Module):
    """Two-class classifier: a ``resnet50`` backbone followed by a linear head.

    The backbone output is reshaped into rows of 2*128*128 values and each row
    is mapped to 2 logits by ``self.fc``.
    NOTE(review): assumes the backbone emits 2*128*128 values per sample —
    confirm against models/resFPNCBAM.

    Args:
        pretrained: Truthy to request pretrained backbone weights.
    """

    def __init__(self, pretrained):
        super().__init__()
        flat_features = 2 * 128 * 128
        # The head is registered before the backbone, as in the original
        # module order.
        self.fc = nn.Linear(flat_features, 2)
        self.backbone = resnet50(pretrained=bool(pretrained), num_classes=2)

    def forward(self, x):
        """Run the backbone, flatten its output and apply the linear head."""
        features = self.backbone(x)
        flattened = features.view(-1, self.fc.in_features)
        return self.fc(flattened)

In [None]:
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one full pass over ``loader`` and return ``(mean_loss, mean_accuracy)``.

    When ``optimizer`` is given, the model is put in training mode and its
    weights are updated after every batch; otherwise the model is put in
    evaluation mode and gradient tracking is disabled.

    Both metrics are weighted by the number of samples in each batch, so a
    smaller final batch does not distort the epoch average.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    correct_sum = 0.0
    sample_count = 0
    with torch.set_grad_enabled(training):
        for imgs, targets in tqdm(loader):
            imgs = imgs.to(device)
            targets = targets.to(device)
            outputs = model(imgs)
            curr_loss = criterion(outputs, targets)

            if training:
                optimizer.zero_grad()
                curr_loss.backward()
                optimizer.step()

            # criterion returns the batch mean, so scale it back to a sum.
            n_in_batch = targets.size(0)
            loss_sum += curr_loss.item() * n_in_batch
            correct_sum += calc_accuracy(outputs, targets) * n_in_batch
            sample_count += n_in_batch

    if sample_count == 0:
        raise ValueError('loader produced no samples')
    return loss_sum / sample_count, correct_sum / sample_count


# Make sure the checkpoint directory exists before training starts, so the
# first torch.save() cannot fail after a full epoch of work.
os.makedirs('./weight/', exist_ok=True)

# Train one fresh model per fold and keep the weights of its best epoch.
for fold, (train_loader, test_loader) in enumerate(zip(train_loader_list, test_loader_list)):
    model = MyModel(True)
    model = model.to(device)

    best_acc = 0.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    file_name = './weight/' + str(fold+1) + '-resFPNCBAM.pth'

    for epoch in range(100):
        # training
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
        print('Epoch {}, loss: {:.4f}, acc: {:.4f}'.format(epoch+1, train_loss, train_acc))

        # testing
        val_loss, val_acc = _run_epoch(model, test_loader, criterion, device)

        # NOTE(review): the checkpoint is selected on the same test loader that
        # is reported here, so the saved model's score is an optimistic
        # estimate; a separate validation split would avoid this.
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), file_name)
        print('test_loss: {:.4f}, val_acc: {:.4f}'.format(val_loss, val_acc))