In [None]:
import time
import os

import cv2 as cv
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim 

import torchvision
import torchvision.transforms.functional as TF

from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.autonotebook import tqdm 

from sklearn.metrics import precision_recall_fscore_support,confusion_matrix, jaccard_score

In [None]:
# Prefer the first CUDA GPU when one is visible to PyTorch; otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
img_size = 68#538

epochs = 500
batch_size = 32
learning_rate = 0.2

features = [64, 128, 256, 512]

In [None]:
class CustomDataset(Dataset):
    """Pairs grayscale images with binary masks stored in two parallel directories.

    The i-th image is paired with the i-th mask after sorting both directory
    listings by filename. ``os.listdir`` returns entries in an arbitrary,
    filesystem-dependent order, so without sorting an image could silently be
    paired with the wrong mask. Pairing assumes image and mask filenames sort
    into the same order — TODO confirm for the dataset in use.

    Args:
        img_dir: directory containing the input images.
        mask_dir: directory containing the corresponding masks.
        transform: optional callable applied to the resized image array.
        target_transform: optional callable applied to the binarised mask array.
    """

    def __init__(self, img_dir, mask_dir, transform=None, target_transform=None):
        self.img = img_dir
        self.mask = mask_dir
        self.transform = transform
        self.target_transform = target_transform
        # Sorted so that index i refers to matching files in both directories.
        self.imgpath = sorted(os.listdir(img_dir))
        self.mskpath = sorted(os.listdir(mask_dir))

    def __len__(self):
        return len(self.imgpath)

    @staticmethod
    def _load_gray(path):
        """Read ``path`` as a single-channel image resized to (img_size, img_size)."""
        pixels = cv.imread(path, cv.IMREAD_GRAYSCALE)
        # cv.imread reports unreadable/missing files by returning None rather than raising.
        if pixels is None:
            raise FileNotFoundError(f"cv.imread could not read {path!r}")
        return cv.resize(pixels, (img_size, img_size))

    def __getitem__(self, index):
        """Return ``(image, mask)``; mask pixels are 1 where the resized mask is > 127, else 0."""
        image = self._load_gray(os.path.join(self.img, self.imgpath[index]))
        mask = self._load_gray(os.path.join(self.mask, self.mskpath[index]))
        _, mask = cv.threshold(mask, 127, 1, cv.THRESH_BINARY)

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            mask = self.target_transform(mask)

        return (image, mask)

In [None]:
# Location of the LIDC-IDRI slices and masks. Set the LIDC_IDRI_ROOT environment
# variable to use another location; the default is the original author's local folder.
data_root = os.environ.get("LIDC_IDRI_ROOT", "C:/Users/sonbu/Desktop/LVTN/Dataset/LIDC-IDRI")
path_data_train = data_root + "/Images"
path_mask_train = data_root + "/Masks"


def mask_to_onehot(mask):
    """Convert a binarised (H, W) mask ndarray into a (1, H, W, 2) int64 one-hot tensor.

    Uses ``torch.from_numpy`` instead of ``torch.LongTensor([mask])``. Building a
    tensor from a Python list that wraps an ndarray is slow and emits a UserWarning
    in recent PyTorch versions. The resulting shape and dtype are the same.
    """
    return F.one_hot(torch.from_numpy(mask).long().unsqueeze(0), 2)


dataset = CustomDataset(img_dir=path_data_train, mask_dir=path_mask_train,
                        transform=transforms.ToTensor(), target_transform=mask_to_onehot)

In [None]:
# 80/10/10 train/val/test split. A fixed-seed generator makes the partition identical
# on every run. NOTE(review): the loaded checkpoint is only evaluated on unseen images
# if training used this same seed and split — TODO confirm.
SPLIT_SEED = 42
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size  # remainder, so the three sizes sum to len(dataset)

train, val, test = torch.utils.data.random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Shuffle only the training data. A fixed order for val/test keeps evaluation
# (including per-batch averaged metrics and per-batch plots) deterministic.
train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val, batch_size=batch_size, shuffle=False)

test_loader = torch.utils.data.DataLoader(test, batch_size=batch_size, shuffle=False)

In [None]:
class DoubleConv(nn.Module):
    """Two successive 3x3 Conv2d -> BatchNorm2d -> ReLU stages.

    padding=1 keeps the spatial size unchanged. Convolutions are created with
    bias=False, and a BatchNorm2d follows each one. The six layers are stored in
    ``self.conv`` (an ``nn.Sequential``) at indices 0-5, which determines the
    state_dict keys a saved checkpoint must match.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        for stage_in in (in_channels, out_channels):
            layers.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ])
        self.conv = nn.Sequential(*layers)

    def forward(self, inputs):
        return self.conv(inputs)

In [None]:
class DoubleUnConv(nn.Module):
    """Two successive 3x3 ConvTranspose2d -> BatchNorm2d -> ReLU stages.

    With stride 1 and padding=1 the transposed convolutions keep the spatial size
    unchanged. Layers are stored in ``self.unconv`` (an ``nn.Sequential``) at
    indices 0-5, which determines the state_dict keys.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.ConvTranspose2d(stage_in, out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]
        self.unconv = nn.Sequential(*stages)

    def forward(self, inputs):
        return self.unconv(inputs)

In [None]:
class TripleConv(nn.Module):
    """Three successive 3x3 Conv2d -> BatchNorm2d -> ReLU stages.

    padding=1 keeps the spatial size unchanged; convolutions use bias=False, and a
    BatchNorm2d follows each one. The nine layers sit in ``self.conv`` at
    indices 0-8, which determines the state_dict keys.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(*[
            layer
            for stage_in in [in_channels] + [out_channels] * 2
            for layer in (
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            )
        ])

    def forward(self, inputs):
        return self.conv(inputs)

In [None]:
class TripleUnConv(nn.Module):
    """Three successive 3x3 ConvTranspose2d -> BatchNorm2d -> ReLU stages.

    Stride 1 with padding=1 keeps the spatial size unchanged. Layers are stored in
    ``self.unconv`` at indices 0-8, which determines the state_dict keys.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stage_inputs = [in_channels] + [out_channels] * 2
        self.unconv = nn.Sequential(*[
            layer
            for stage_in in stage_inputs
            for layer in (
                nn.ConvTranspose2d(stage_in, out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            )
        ])

    def forward(self, inputs):
        return self.unconv(inputs)

In [None]:
class Segnet(nn.Module):
    """SegNet-style encoder/decoder built from the module-level ``features`` widths.

    Encoder: five conv stages (conv1..conv5), each followed by 2x2 max-pooling that
    also returns the argmax indices. Decoder: five steps, each a max-unpool using
    the matching encoder's indices and pre-pool shape, followed by a transposed-conv
    stage (unconv1..unconv5). The deepest pooling is undone first. Attribute names
    are kept stable because a saved state_dict is loaded into this module.

    Args:
        in_c: number of input channels.
        out_c: number of channels produced by the final decoder stage.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        self.pool = nn.MaxPool2d(2, stride=2, return_indices=True)
        self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)

        w0, w1, w2, w3 = features[0], features[1], features[2], features[3]

        self.conv1 = DoubleConv(in_c, w0)
        self.conv2 = DoubleConv(w0, w1)
        self.conv3 = TripleConv(w1, w2)
        self.conv4 = TripleConv(w2, w3)
        self.conv5 = TripleConv(w3, w3)

        self.unconv1 = TripleUnConv(w3, w3)
        self.unconv2 = TripleUnConv(w3, w2)
        self.unconv3 = TripleUnConv(w2, w1)
        self.unconv4 = DoubleUnConv(w1, w0)
        self.unconv5 = DoubleUnConv(w0, out_c)

    def forward(self, inputs):
        encoders = (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5)
        decoders = (self.unconv1, self.unconv2, self.unconv3, self.unconv4, self.unconv5)

        # Encode, remembering each stage's pooling indices and pre-pool shape
        # (the shape lets unpool restore odd sizes, e.g. 17 -> 8 -> 17).
        x = inputs
        skips = []
        for stage in encoders:
            fmap = stage(x)
            x, idx = self.pool(fmap)
            skips.append((idx, fmap.shape))

        # Decode in reverse: unconv1 undoes the conv5 pooling, ..., unconv5 undoes conv1's.
        for stage, (idx, shape) in zip(decoders, reversed(skips)):
            x = stage(self.unpool(x, indices=idx, output_size=shape))
        return x

In [None]:

# Build the network and restore trained weights. map_location remaps the stored
# tensors onto `device`, so a checkpoint saved from a GPU also loads on a CPU-only machine.
CHECKPOINT_PATH = 'C:/Users/sonbu/Desktop/LVTN/Model/Segnet/Segnet_LIDC.pt'
model = Segnet(1, 2).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
loss_func = nn.CrossEntropyLoss()

In [None]:
# Evaluate the loaded checkpoint on the test split and plot one example per batch.
# Every metric is computed per batch and then averaged over the number of batches.
PRED_THRESHOLD = 0.25  # a pixel is predicted foreground when P(class 1) exceeds this


def _batch_scores(y_true, y_pred):
    """Binary segmentation scores for one batch of flattened 0/1 pixel labels.

    Returns ``(precision, recall, f1, specificity, iou)``.

    The confusion matrix is built with ``labels=[0, 1]`` so it is always 2x2.
    Without that, a batch whose ground truth and prediction are both entirely
    background yields a 1x1 matrix and indexing ``cm[0, 1]`` raises IndexError.
    A ratio with a zero denominator is reported as 0.0, which is the value sklearn
    returns (with a warning) for ill-defined precision/recall/F1/Jaccard.
    """
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

    def ratio(num, den):
        return float(num) / float(den) if den else 0.0

    return (
        ratio(tp, tp + fp),               # precision
        ratio(tp, tp + fn),               # recall / sensitivity
        ratio(2 * tp, 2 * tp + fp + fn),  # F1 == Dice coefficient
        ratio(tn, tn + fp),               # specificity
        ratio(tp, tp + fp + fn),          # IoU / Jaccard index
    )


test_accuracy, pre, sen, dsc, spe, iou = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
test_loss = 0.0
length_data = len(test_loader)  # number of batches; all metrics are averaged over these
assert length_data > 0, "test_loader yields no batches"

model.eval()

with torch.no_grad():
    for data, label in test_loader:
        data, label = data.to(device), label.to(device)
        # The target_transform yields (1, H, W, 2) one-hot masks, so batches are
        # (B, 1, H, W, 2). Squeeze only that singleton axis: a bare .squeeze() would
        # also drop the batch axis when the last batch holds a single sample.
        label = label.squeeze(1).permute(0, 3, 1, 2).float()

        logits = model(data)
        test_loss += loss_func(logits, label).item()

        label = torch.argmax(label, dim=1)  # one-hot -> per-pixel class index in {0, 1}
        # For a 2-class output, 1 - P(class 0) equals P(class 1); binarise it.
        pred = ((1 - F.softmax(logits, dim=1))[:, 0, :, :] > PRED_THRESHOLD).long()

        test_accuracy += pred.eq(label).sum().item() * 100.0 / label.numel()

        y_true = label.reshape(-1).cpu().numpy()
        y_pred = pred.reshape(-1).cpu().numpy()
        precision, recall, f1_score, specificity, iou_score = _batch_scores(y_true, y_pred)
        pre += precision
        sen += recall
        dsc += f1_score
        spe += specificity
        iou += iou_score

        # Show the first sample of the batch: input, ground truth, prediction.
        fig, axes = plt.subplots(1, 3)
        panels = (
            ("Ảnh gốc", data[0], None),
            ("Ground Truth", label[0], "gray"),
            ("Predicted", pred[0], "gray"),
        )
        for ax, (title, image, cmap) in zip(axes, panels):
            ax.set_title(title)
            ax.imshow(image.reshape(img_size, img_size).cpu().numpy(), cmap=cmap)
        plt.show()
        plt.close(fig)  # one figure is created per batch; release it explicitly

print('Accuracy: {:.2f}% - DSC: {:.2f}% - PRE: {:.2f}% - SEN: {:.2f}% - SPE: {:.2f}% - IoU: {:.2f}% - Loss: {:.4f}'.format(
    test_accuracy / length_data,
    dsc * 100 / length_data,
    pre * 100 / length_data,
    sen * 100 / length_data,
    spe * 100 / length_data,
    iou * 100 / length_data,
    test_loss / length_data,
))