# Import Libraries

In [None]:
import os
import time
import copy
import numpy as np
import pandas as pd
from PIL import Image
import torch,torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets,models,transforms
import torch.optim as optim
# from torchsummary import summary
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.utils import resample
from sklearn.utils import shuffle
from sklearn.metrics import confusion_matrix
from torch.nn import Parameter
from pathlib import Path
import cv2

In [None]:
# Location of the saved whole-model checkpoint; the model-loading cell reads this same file.
model_path = '../input/whole-model/k_cross_CNN.pt'

# Load the train, val and test split files (space-separated .txt)

In [None]:
# Column layout shared by all COVIDx CT-2A split files:
# image file name, integer class label, and the crop bounding box.
split_columns = ['filename', 'label', 'xmin', 'ymin', 'xmax', 'ymax']


def _read_split_file(path):
    """Read one space-separated, header-less split file and name its columns."""
    split_frame = pd.read_csv(path, sep=" ", header=None)
    split_frame.columns = split_columns
    return split_frame


train_df = _read_split_file('../input/covidxct/train_COVIDx_CT-2A.txt')
val_df = _read_split_file('../input/covidxct/val_COVIDx_CT-2A.txt')
test_df = _read_split_file('../input/covidxct/test_COVIDx_CT-2A.txt')

In [None]:
# Raw class distribution of the training split (label 1 is dropped further down).
# The preview below has no effect: only a cell's last expression is displayed.
_ = train_df.head()
train_df['label'].value_counts()

In [None]:
# Turn the bare file names from the split files into full paths under the image directory.
image_path = '../input/covidxct/2A_images/'  # directory path
for split_frame in (train_df, val_df, test_df):
    split_frame['filename'] = image_path + split_frame['filename']
train_df.head()

# Leave out the pneumonia class (label 1); keep Normal vs COVID-19

In [None]:
# Drop label 1 (pneumonia) so the task becomes binary.
# .copy() turns each filtered slice into an independent DataFrame, so the
# column assignments in later cells (e.g. re-mapping 'label') do not trigger
# pandas' SettingWithCopyWarning.
train_df = train_df[train_df['label'] != 1].copy()
val_df = val_df[val_df['label'] != 1].copy()
test_df = test_df[test_df['label'] != 1].copy()

In [None]:
# Re-number label 2 as 1 so the two remaining classes are 0 and 1.
to_binary = {2: 1}
train_df['label'] = train_df['label'].replace(to_binary)
val_df['label'] = val_df['label'].replace(to_binary)
test_df['label'] = test_df['label'].replace(to_binary)

In [None]:
# Class balance of the training split after the binary re-labelling.
train_df['label'].value_counts()

In [None]:
# Randomly permute the rows of each split; sklearn.utils.shuffle returns new frames.
train_df, val_df, test_df = (shuffle(frame) for frame in (train_df, val_df, test_df))
train_df.head()

In [None]:
# Integer label -> readable class name; class_names[i] is the name for label i.
labels = {0: 'Normal', 1: 'COVID-19'}
class_names = [labels[code] for code in sorted(labels)]

for split_frame in (train_df, val_df, test_df):
    split_frame['label_n'] = split_frame['label'].map(lambda code: labels[code])
train_df.head()

In [None]:
# Per-split class counts, using the readable class names.
for split_name, split_frame in (('train', train_df), ('validation', val_df), ('test', test_df)):
    print(f"Normal and Covid-19 values of {split_name}: \n{split_frame['label_n'].value_counts()}")

In [None]:
# Preview the first rows of the prepared training split.
train_df.head(n=5)

In [None]:
# Give each split a fresh 0..n-1 index so label-based lookups such as
# df['col'][idx] line up with row positions. The old (shuffled) index is
# kept as an extra 'index' column because drop is not requested.
train_df, val_df, test_df = (frame.reset_index() for frame in (train_df, val_df, test_df))

In [None]:
class CovidDataset(Dataset):
    """Map-style dataset yielding cropped images and integer labels from a split dataframe.

    Each row of ``dataset_df`` must provide ``filename`` (image path),
    ``label`` (integer class) and the crop box ``xmin``, ``ymin``, ``xmax``,
    ``ymax`` in pixel coordinates. The image is cropped to that box, converted
    to an RGB PIL image and then passed through ``transform`` if one is given.
    """

    def __init__(self, dataset_df, transform=None):
        self.dataset_df = dataset_df
        self.transform = transform

    def __len__(self):
        return self.dataset_df.shape[0]

    def __getitem__(self, idx):
        # Positional row access: correct whether or not the index was reset.
        row = self.dataset_df.iloc[idx]
        image_name = row['filename']
        img = cv2.imread(image_name)
        if img is None:
            # cv2.imread reports missing/unreadable files by returning None
            # rather than raising; fail here with a clear message instead of
            # a confusing "'NoneType' object is not subscriptable" below.
            raise FileNotFoundError(f'Could not read image: {image_name}')
        img = img[row['ymin']:row['ymax'], row['xmin']:row['xmax'], :]
        # OpenCV decodes to BGR while PIL/torchvision assume RGB. This is a
        # no-op whenever all three channels are equal (grey-scale content).
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Convert the uint8 HxWx3 array straight to PIL. The previous
        # ToTensor -> ToPILImage round trip went through float32 and
        # truncated back to uint8 (mul(255).byte()), which may drop a grey
        # level on some pixels, and it cost an extra conversion per sample.
        img = transforms.ToPILImage()(img)
        label = row['label']

        if self.transform:
            img = self.transform(img)
        return img, label
    

In [None]:
# Training / data hyper-parameters.
num_epochs = 10
num_classes = 2          # Normal vs COVID-19 after dropping label 1
batch_size = 128
input_size = (224, 224)  # target size passed to transforms.Resize
input_channel = 3        # only referenced by commented-out Grayscale transforms here

In [None]:
# Random augmentations applied only to training images, after resizing.
_train_augmentations = [
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(30),
]

# 'train': resize + augmentation + tensor conversion.
# 'test' : resize + tensor conversion only (also used for validation).
transform = {
    'train': transforms.Compose(
        [transforms.Resize(input_size), *_train_augmentations, transforms.ToTensor()]
    ),
    'test': transforms.Compose(
        [transforms.Resize(input_size), transforms.ToTensor()]
    ),
}

In [None]:
dataset_names = ['train', 'val', 'test']
# Validation and test share the deterministic (non-augmenting) pipeline.
image_transforms = {'train': transform['train'], 'val': transform['test'], 'test': transform['test']}

train_dataset = CovidDataset(train_df, transform=image_transforms['train'])
val_dataset = CovidDataset(val_df, transform=image_transforms['val'])
test_dataset = CovidDataset(test_df, transform=image_transforms['test'])

image_dataset = {'train': train_dataset, 'val': val_dataset, 'test': test_dataset}

# Only the training split needs shuffling. Evaluation order does not change
# aggregate metrics, and a fixed order keeps val/test passes reproducible.
dataloaders = {
    x: DataLoader(image_dataset[x], batch_size=batch_size, shuffle=(x == 'train'), num_workers=8)
    for x in dataset_names
}

dataset_sizes = {x: len(image_dataset[x]) for x in dataset_names}

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
def show_tensor_img(tensor_img):
    """Display an image tensor in its own matplotlib figure (grey colormap for 1-channel data)."""
    pil_image = transforms.ToPILImage()(tensor_img)
    plt.figure()
    plt.imshow(pil_image, cmap=plt.cm.gray)
    plt.show()


def show_img(idx):
    """Show the idx-th training sample whose label equals idx % 2 (alternating classes)."""
    same_class_rows = train_df[train_df['label'] == idx % 2].index
    image_tensor = train_dataset[same_class_rows[idx]][0]
    show_tensor_img(image_tensor)


for sample_number in range(4):
    show_img(sample_number)

In [None]:
import itertools
# Plot a confusion matrix.
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """
    Plot the confusion matrix as a heat map with each cell's value written on it.

    Prints a one-line header saying whether the values are normalised.
    Input
    - cm : confusion matrix, rows = true class, columns = predicted class.
           Accepts a torch tensor (CPU or GPU), numpy array or nested list.
    - classes : tick labels for the rows/columns, in class-index order
    - normalize : True: show per-row fractions (recall per class);
                  False: show integer counts
    """
    # Accept tensors as well as array-likes; the original only handled CPU tensors.
    if hasattr(cm, 'detach'):
        cm = cm.detach().cpu().numpy()
    cm = np.asarray(cm)
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True).astype(float)
        # Rows with no samples become 0 instead of NaN (and a divide warning).
        cm = np.divide(cm.astype(float), row_sums,
                       out=np.zeros(cm.shape, dtype=float), where=row_sums != 0)
        print("Normalized confusion matrix")
    else:
        cm = cm.astype('int')
        print('Confusion matrix, without normalization')
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    fmt = '{:.2f}' if normalize else '{}'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        # imshow draws row i at y=i and column j at x=j, so the label for
        # cm[i, j] belongs at (x=j, y=i). Using (i, j) transposed the
        # off-diagonal annotations relative to the coloured cells.
        plt.text(j, i, fmt.format(cm[i, j]), horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.show()

# Accumulate a batch into a confusion matrix.
# NOTE: this name shadows sklearn.metrics.confusion_matrix imported at the top.
def confusion_matrix(preds, labels, conf_matrix):
    """Add one batch of predictions to ``conf_matrix`` in place and return it.

    Args:
        preds: logits/scores of shape (batch, num_classes); the predicted class
            is the argmax over dim 1.
        labels: true class indices, one per row of ``preds``.
        conf_matrix: tensor indexed [true, predicted]; incremented in place.

    Returns:
        The same ``conf_matrix`` object.
    """
    # Move indices to the matrix's device: indexing a CPU matrix with CUDA
    # indices (or the reverse) is not portable.
    target_device = conf_matrix.device
    pred_idx = torch.argmax(preds, 1).to(device=target_device, dtype=torch.long)
    true_idx = torch.as_tensor(labels).to(device=target_device, dtype=torch.long)
    ones = torch.ones(true_idx.shape, dtype=conf_matrix.dtype, device=target_device)
    # One vectorised scatter-add instead of a Python loop over the batch;
    # accumulate=True makes repeated (true, pred) pairs add up.
    conf_matrix.index_put_((true_idx, pred_idx), ones, accumulate=True)
    return conf_matrix

def calculate_all_prediction(conf_matrix):
    '''
    Overall accuracy in percent: sum of the diagonal divided by the total count,
    rounded to 2 decimals. Returns 0 for an empty matrix (no samples), matching
    the zero-guard used by the per-class helpers.
    '''
    total_sum = float(conf_matrix.sum())
    if total_sum == 0:
        return 0
    correct_sum = float(np.diag(conf_matrix).sum())
    return round(100 * correct_sum / total_sum, 2)
 
def calculate_label_prediction(conf_matrix, labelidx):
    '''
    Precision of one class, in percent: correct predictions of the class divided
    by everything predicted as that class (its column sum). Returns 0 when the
    class was never predicted.
    '''
    predicted_as_label = conf_matrix.sum(axis=0)[labelidx]
    if predicted_as_label == 0:
        return 0
    true_positives = conf_matrix[labelidx][labelidx]
    return round(100 * float(true_positives) / float(predicted_as_label), 2)
 
def calculate_label_recall(conf_matrix, labelidx):
    '''
    Recall of one class, in percent: correct predictions of the class divided by
    the number of samples truly in that class (its row sum). Returns 0 when the
    class has no samples.
    '''
    actual_in_label = conf_matrix.sum(axis=1)[labelidx]
    if actual_in_label == 0:
        return 0
    true_positives = conf_matrix[labelidx][labelidx]
    return round(100 * float(true_positives) / float(actual_in_label), 2)
 
def calculate_f1(prediction, recall):
    """Harmonic mean of precision (`prediction`) and recall, rounded to 2 decimals; 0 if both are 0."""
    denominator = prediction + recall
    return 0 if denominator == 0 else round(2 * prediction * recall / denominator, 2)

In [None]:
class RESNET_34_WoF(nn.Module):
    """torchvision ResNet-34 whose final fully-connected layer emits `num_classes` logits.

    Args:
        num_classes: size of the replacement output layer.
        pretrained: forwarded to ``models.resnet34``.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        backbone = models.resnet34(pretrained=pretrained)
        # Swap in a new head sized for our classes. A freshly constructed
        # nn.Linear already has requires_grad=True on all its parameters.
        backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
        self.network = backbone

    def forward(self, xb):
        return self.network(xb)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the whole pickled nn.Module (architecture + weights); presumably an
# instance of RESNET_34_WoF above — TODO confirm.
# SECURITY: torch.load unpickles arbitrary objects; only load trusted checkpoints.
# map_location lets a checkpoint saved from GPU load on a CPU-only machine.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects
# whole-model pickles; pass weights_only=False there if loading fails.
model = torch.load(model_path, map_location=device)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.999))

# train() calls sched.step() once per batch. With drop_last=False the loader
# yields ceil(N / batch_size) batches, so len(loader) is the true step count;
# the previous floor(N / batch_size) undercounted by up to one step per epoch,
# letting the cosine schedule wrap past T_max and raise the LR again.
num_iter = len(dataloaders['train']) * num_epochs
sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_iter,
                                                   eta_min=0.00001)

In [None]:
def train(model,epoch,num_epochs,criterion,optimizer,sched):
    """Run one training epoch over ``dataloaders['train']`` and print loss/accuracy.

    Args:
        model: network to train; switched to train mode.
        epoch: epoch index, used only for the progress header.
        num_epochs: total epochs, used only for the progress header.
        criterion: loss function (batch-mean reduction assumed).
        optimizer: optimizer updating ``model``'s parameters.
        sched: LR scheduler, stepped once per batch.

    Returns:
        ``(epoch_loss, epoch_acc)`` as Python floats (accuracy is a fraction).
    """
    model.train()
    print('-' * 100)
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    running_loss = 0.0
    running_corrects = 0  # stays a Python int (the old int/tensor mix crashed on an empty loader)
    for inputs, labels in dataloaders['train']:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, labels)  # batch mean

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        sched.step()

        # Undo the batch mean so the epoch loss is a per-sample average.
        running_loss += loss.item() * inputs.size(0)
        running_corrects += (preds == labels).sum().item()

    train_size = dataset_sizes['train']
    epoch_loss = running_loss / train_size
    epoch_acc = running_corrects / train_size
    print('train_total Loss: {:.4f} Acc: {:.4f}%'.format(epoch_loss, epoch_acc * 100))
    return epoch_loss, epoch_acc

In [None]:
def test(model,epoch,num_epochs,criterion,optimizer,best_acc):
    """Evaluate ``model`` on ``dataloaders['val']`` and report metrics.

    Plots the confusion matrix and prints loss, accuracy, and per-class and
    macro-averaged precision (labelled "prediction" in the output), recall
    and F1.

    Args:
        model: network to evaluate; switched to eval mode.
        epoch, num_epochs, optimizer: unused; kept for call-site compatibility.
        criterion: loss function applied to each batch.
        best_acc: best validation accuracy (fraction) seen so far.

    Returns:
        ``(best_model_wts, best_acc, epoch_acc)``. ``best_acc`` is replaced by
        ``epoch_acc`` when this epoch is better. NOTE: ``best_model_wts`` is a
        deep copy of the *current* weights whether or not they improved;
        callers must compare the returned ``best_acc`` with the previous one
        before treating these weights as the best.
    """
    model.eval()
    running_loss = 0.0
    running_corrects = 0
    conf_matrix = torch.zeros(num_classes, num_classes)  # rows = true, cols = predicted
    with torch.no_grad():
        for inputs, labels in dataloaders['val']:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            # conf_matrix lives on the CPU; pass CPU tensors so indexing is device-safe.
            conf_matrix = confusion_matrix(outputs.cpu(), labels.cpu(), conf_matrix)

            running_loss += loss.item() * inputs.size(0)
            running_corrects += (preds == labels).sum().item()

    plot_confusion_matrix(conf_matrix, classes=class_names, normalize=False, title='confusion matrix')

    val_size = dataset_sizes['val']
    epoch_loss = running_loss / val_size
    epoch_acc = running_corrects / val_size
    print('val_total Loss: {:.4f} Acc: {:.4f}%'.format(epoch_loss, epoch_acc * 100))

    all_prediction = calculate_all_prediction(conf_matrix)  # overall accuracy in percent
    print('all_prediction:{}'.format(all_prediction))
    label_prediction = [calculate_label_prediction(conf_matrix, i) for i in range(num_classes)]
    label_recall = [calculate_label_recall(conf_matrix, i) for i in range(num_classes)]

    # Per-class precision, recall and F1.
    for ei, name in zip(range(num_classes), class_names):
        print(ei,'\t',name,'\t','prediction=',label_prediction[ei],'%,\trecall=',label_recall[ei],'%,\tf1=',calculate_f1(label_prediction[ei],label_recall[ei]))
    p = round(np.array(label_prediction).sum()/len(label_prediction),2)  # macro precision
    r = round(np.array(label_recall).sum()/len(label_recall),2)  # macro recall
    print('MACRO-averaged:\nprediction=',p,'%,recall=',r,'%,f1=',calculate_f1(p,r))

    if epoch_acc > best_acc:
        best_acc = epoch_acc
    # A single copy at the end: the weights do not change during evaluation,
    # so the previous copy-at-start plus copy-on-improvement was redundant.
    best_model_wts = copy.deepcopy(model.state_dict())
    return best_model_wts, best_acc, epoch_acc

In [None]:
if __name__ == '__main__':
    # Fallback: if no epoch ever beats best_acc, the loaded weights are saved.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    acc = []  # validation accuracy per epoch
    for epoch in range(num_epochs):
        train(model, epoch, num_epochs, criterion, optimizer, sched)
        epoch_wts, new_best_acc, epoch_acc = test(model, epoch, num_epochs, criterion, optimizer, best_acc)
        # test() may hand back a copy of the *current* weights even when this
        # epoch did not improve. Only adopt them when best_acc actually rose;
        # otherwise the saved "best" model silently became the last epoch's.
        if new_best_acc > best_acc:
            best_model_wts = epoch_wts
        best_acc = new_best_acc
        acc.append(epoch_acc)
    print('*' * 100)
    print('best_acc:{}'.format(best_acc))
    print('*' * 100)
    torch.save(best_model_wts, 'resnet34_wof.pth')

In [None]:
# Final report on the held-out *test* split. test() always evaluates
# dataloaders['val'], the split used to pick the best epoch, so its accuracy
# is not an unbiased estimate. Note that `model` holds the weights from the
# last training epoch; the best weights were only saved to disk.
model.eval()
test_conf_matrix = torch.zeros(num_classes, num_classes)  # rows = true, cols = predicted
with torch.no_grad():
    for batch_inputs, batch_labels in dataloaders['test']:
        batch_outputs = model(batch_inputs.to(device))
        test_conf_matrix = confusion_matrix(batch_outputs.cpu(), batch_labels.cpu(), test_conf_matrix)
plot_confusion_matrix(test_conf_matrix, classes=class_names, normalize=False, title='test confusion matrix')
acc = (test_conf_matrix.diag().sum() / test_conf_matrix.sum()).item()
print("Resnet 34 WoF on COVIDXCT test set Accuracy: ", acc)