In [None]:
import time
import os

import cv2 as cv
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim 

import torchvision
import torchvision.transforms.functional as TF

from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.autonotebook import tqdm 

from sklearn.metrics import precision_recall_fscore_support,confusion_matrix, jaccard_score

In [None]:
# Run on the first CUDA GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
# --- Spatial sizes ---------------------------------------------------------
# The U-Net in this notebook uses unpadded 3x3 convolutions, so its output is
# smaller than its input: a 256x256 input yields a 68x68 prediction, which is
# why masks are resized to `msk_size`. (538 / 340 were the values left
# commented out next to these settings; they are a consistent pair too.)
img_size = 256
msk_size = 68

# --- Optimisation settings -------------------------------------------------
batch_size = 32
epochs = 250
learning_rate = 0.2

# Channel width of each encoder level, doubling from 16 up to 256.
features = [16 * 2 ** level for level in range(5)]

In [None]:
class CustomDataset(Dataset):
    """Paired grayscale image / mask dataset read from two folders.

    Images and masks are paired by position after sorting the file names of
    each folder, so the i-th image (alphabetically) must correspond to the
    i-th mask (alphabetically).

    Args:
        img_dir: Folder containing the input images.
        mask_dir: Folder containing the segmentation masks.
        transform: Optional callable applied to the image and then,
            independently, to the mask. Because it is called twice, a
            transform with random behaviour would not stay in sync between
            image and mask.

    Raises:
        ValueError: If the two folders do not contain the same number of
            entries (pairing by position would be meaningless).
    """

    def __init__(self, img_dir, mask_dir, transform = None):
        self.img = img_dir
        self.mask = mask_dir
        self.transform = transform
        # os.listdir() returns entries in arbitrary order; sorting both
        # listings keeps every image aligned with its own mask.
        self.imgpath = sorted(os.listdir(img_dir))
        self.mskpath = sorted(os.listdir(mask_dir))
        if len(self.imgpath) != len(self.mskpath):
            raise ValueError(
                "Image and mask folders differ in size: "
                f"{len(self.imgpath)} images vs {len(self.mskpath)} masks"
            )

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.imgpath)

    def __getitem__(self, index):
        """Load, resize and (optionally) transform the pair at `index`.

        Returns:
            Tuple ``(image, mask)``. The image is resized to
            ``img_size x img_size`` and the mask to ``msk_size x msk_size``
            (both module-level settings) before the transform is applied.

        Raises:
            OSError: If either file cannot be decoded as an image.
        """
        img_path = os.path.join(self.img, self.imgpath[index])
        mask_path = os.path.join(self.mask, self.mskpath[index])

        image = cv.imread(img_path, cv.IMREAD_GRAYSCALE)
        mask = cv.imread(mask_path, cv.IMREAD_GRAYSCALE)
        # cv.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cv.resize.
        if image is None:
            raise OSError(f"Could not read image file: {img_path}")
        if mask is None:
            raise OSError(f"Could not read mask file: {mask_path}")

        image = cv.resize(image, (img_size, img_size))
        mask = cv.resize(mask, (msk_size, msk_size))

        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        return (image, mask)

In [None]:
# Root folder of the LIDC-IDRI data. Set the LIDC_DATA_ROOT environment
# variable to run the notebook on another machine; the default keeps the
# original local layout, so behaviour is unchanged when it is not set.
data_root = os.environ.get("LIDC_DATA_ROOT", "C:/Users/sonbu/Desktop/LVTN/Dataset/LIDC-IDRI").rstrip("/\\")
path_data_train = data_root + "/Images"
path_mask_train = data_root + "/Masks"

# ToTensor is applied to the image and to the mask separately by CustomDataset.
dataset = CustomDataset(
    img_dir=path_data_train,
    mask_dir=path_mask_train,
    transform=transforms.ToTensor(),
)

In [None]:
# 80 % / 10 % / remainder split; the test subset absorbs the rounding error.
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

# Seed the split so the same samples land in train/val/test on every run.
# Without a fixed generator the "test" subset changes each time the notebook
# is executed and the reported metrics are not reproducible.
# NOTE(review): the checkpoint loaded later in this notebook was trained
# elsewhere; use the same seed/split as the training run so the test subset
# does not overlap its training data — confirm against the training notebook.
split_seed = 42
train, val, test = torch.utils.data.random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Only the training data needs reshuffling every epoch; validation and test
# are iterated in a fixed order so their results are repeatable.
train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val, batch_size=batch_size, shuffle=False)

test_loader = torch.utils.data.DataLoader(test, batch_size=batch_size, shuffle=False)

In [None]:
class DoubleConv(nn.Module):
    """Two consecutive (3x3 conv -> batch norm -> ReLU) stages.

    The convolutions are unpadded, so each stage trims 2 pixels from the
    height and the width (4 in total per block). The convolutions carry no
    bias term.

    Args:
        in_channels: Number of channels of the incoming tensor.
        out_channels: Number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # The first stage maps in_channels -> out_channels, the second keeps
        # the channel count at out_channels.
        for stage_in in (in_channels, out_channels):
            stages.append(nn.Conv2d(stage_in, out_channels, kernel_size=3, bias=False))
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU())
        # Kept as a Sequential named `conv` so parameter names stay
        # `conv.0`, `conv.1`, `conv.3`, `conv.4`.
        self.conv = nn.Sequential(*stages)

    def forward(self, inputs):
        """Apply both convolution stages to `inputs`."""
        return self.conv(inputs)

In [None]:
class Unet(nn.Module):
    """U-Net built from unpadded `DoubleConv` blocks with a sigmoid output.

    The encoder has five levels whose channel widths come from the
    module-level `features` list; the decoder mirrors it with four
    transposed-convolution upsampling steps. Because every 3x3 convolution
    is unpadded, the output is spatially smaller than the input (for
    example a 256x256 input produces a 68x68 output).

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels (one probability map per channel).
    """

    def __init__(self,in_c,out_c):
        super().__init__()
        # Attribute names and creation order are kept stable so previously
        # saved state_dicts still load.
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder
        self.down1 = DoubleConv(in_c,features[0])
        self.down2 = DoubleConv(features[0],features[1])
        self.down3 = DoubleConv(features[1],features[2])
        self.down4 = DoubleConv(features[2],features[3])
        self.down5 = DoubleConv(features[3],features[4])

        # Decoder upsampling: each step doubles H and W and halves channels
        self.ConvT1 = nn.ConvTranspose2d(features[4],features[3], kernel_size=2, stride=2)
        self.ConvT2 = nn.ConvTranspose2d(features[3],features[2], kernel_size=2, stride=2)
        self.ConvT3 = nn.ConvTranspose2d(features[2],features[1], kernel_size=2, stride=2)
        self.ConvT4 = nn.ConvTranspose2d(features[1],features[0], kernel_size=2, stride=2)

        # Decoder convolutions; the input is the upsampled map concatenated
        # with the skip connection, hence twice the output width.
        self.up1 = DoubleConv(features[4],features[3])
        self.up2 = DoubleConv(features[3],features[2])
        self.up3 = DoubleConv(features[2],features[1])
        self.up4 = DoubleConv(features[1],features[0])

        # 1x1 convolution mapping to the requested number of output channels
        self.output = nn.Conv2d(features[0], out_c, kernel_size=1)

    @staticmethod
    def _match_skip(skip, reference):
        """Resize the encoder map `skip` to the (H, W) of `reference`.

        Both spatial dimensions are taken from `reference`, so non-square
        inputs work; using the height for both would make the following
        concatenation fail whenever H != W. The resize uses torchvision's
        default interpolation (no mode is passed).
        """
        return TF.resize(skip, size=list(reference.shape[-2:]))

    def forward(self, inputs):
        """Return per-pixel sigmoid probabilities for `inputs`.

        Args:
            inputs: Tensor with layout (batch, in_c, H, W).

        Returns:
            Tensor with layout (batch, out_c, H', W') and values in [0, 1],
            where H' and W' are smaller than H and W.
        """
        # Encoder: keep every level's output for the skip connections.
        d1 = self.down1(inputs)
        d2 = self.down2(self.maxpool(d1))
        d3 = self.down3(self.maxpool(d2))
        d4 = self.down4(self.maxpool(d3))
        d5 = self.down5(self.maxpool(d4))

        # Decoder: upsample, resize the matching encoder map to the same
        # spatial size, concatenate along channels, then convolve.
        x = self.ConvT1(d5)
        x = self.up1(torch.cat([x, self._match_skip(d4, x)], 1))

        x = self.ConvT2(x)
        x = self.up2(torch.cat([x, self._match_skip(d3, x)], 1))

        x = self.ConvT3(x)
        x = self.up3(torch.cat([x, self._match_skip(d2, x)], 1))

        x = self.ConvT4(x)
        x = self.up4(torch.cat([x, self._match_skip(d1, x)], 1))

        return torch.sigmoid(self.output(x))

In [None]:
# Single-channel input, single-channel probability map output.
model = Unet(1,1).to(device)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
# machine (`device` falls back to the CPU when CUDA is unavailable).
# NOTE: torch.load unpickles the file, so only load checkpoints from a
# trusted source.
model.load_state_dict(torch.load('C:/Users/sonbu/Desktop/LVTN/Model/Unet/Unet_LIDC.pt', map_location=device))
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
# The network already ends in a sigmoid, so plain BCE (not BCE-with-logits)
# is the matching loss.
loss_func = nn.BCELoss()


In [None]:
# Evaluate the model on the test loader. Every metric is computed per batch
# over all pixels of that batch and then averaged over the number of batches.
test_accuracy, pre, sen, dsc, spe, iou = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
length_data = len(test_loader)
model.eval()
with torch.no_grad():
    for data, label in test_loader:
        data, label = data.to(device), label.to(device)

        # Binarise the predicted probabilities and the (resized, therefore
        # possibly non-binary) ground truth at 0.5. New tensors are created
        # instead of overwriting `label` in place.
        pred = (model(data) > 0.5).float()
        lb = (label > 0.5).float()

        # Pixel accuracy of this batch, in percent.
        total_test = lb.numel()
        correct_test = pred.eq(lb).sum().item()
        test_accuracy += (float(correct_test*100) / total_test)

        # Flatten once and reuse for every scikit-learn metric.
        y_true = lb.reshape(-1).cpu().numpy().astype(np.int64)
        y_pred = pred.reshape(-1).cpu().numpy().astype(np.int64)

        precision, recall, f1_score, _ = precision_recall_fscore_support(y_true, y_pred, average = "binary")
        # labels=[0, 1] forces a 2x2 matrix even when a batch contains only
        # one class; otherwise cm[0, 1] would raise an IndexError.
        cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
        negatives = cm[0,0] + cm[0,1]
        specificity = cm[0,0] / negatives if negatives else 0.0
        iou_score = jaccard_score(y_true, y_pred)

        pre += precision
        sen += recall
        dsc += f1_score
        spe += specificity
        iou += iou_score

        # Show the first sample of the batch: input, ground truth, prediction.
        fig = plt.figure()
        plt.subplot(131)
        plt.title("Ảnh gốc")  # Vietnamese for "Original image"
        plt.imshow(data[0, 0].cpu().numpy())
        plt.subplot(132)
        plt.title("Ground Truth")
        plt.imshow(lb[0, 0].cpu().numpy(), cmap='gray')
        plt.subplot(133)
        plt.title("Predicted")
        plt.imshow(pred[0, 0].cpu().numpy(), cmap='gray')
        plt.show()
        # Release the figure so one-per-batch figures do not pile up in memory.
        plt.close(fig)

    if length_data:
        print('Accurcy: {:.2f} - DSC: {:.2f}% - PRE: {:.2f}% - SEN: {:.2f}% - SPE: {:.2f}% - IoU: {:.2f}%'.format(test_accuracy /length_data,dsc*100/length_data, pre*100/length_data, sen*100/length_data, spe*100/length_data, iou*100/length_data))
    else:
        # Avoid dividing by zero when the test split is empty.
        print('Test loader is empty; no metrics to report.')