In [1]:
import torch
from torch.utils.data import DataLoader
from torch import nn, optim
from torchvision.transforms import transforms
from torch.utils.data import Dataset
import PIL.Image as Image
import os
import matplotlib.pyplot as plt
import numpy as np
import math
import cv2
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
class Attention_SEM_block_1(nn.Module):
    """Spatial attention gate for the level-1 feature map.

    Each deeper map ``layer2`` .. ``layer5`` is projected to ``ch_1`` channels
    (3x3 conv -> BatchNorm -> ReLU) and bilinearly resized to the spatial size
    of ``layer1``. The projections are summed with ``layer1``, passed through
    ReLU, and turned by ``psi`` (1x1 conv -> BatchNorm -> sigmoid) into a
    single-channel map in (0, 1) that multiplies ``layer1``.

    Args:
        ch_1, ch_2, ch_3, ch_4, ch_5: channel counts of ``layer1`` .. ``layer5``.
    """

    def __init__(self, ch_1, ch_2, ch_3, ch_4, ch_5):
        super(Attention_SEM_block_1, self).__init__()
        self.n_channels = ch_1

        def _project(in_ch):
            # 3x3 conv mapping a deeper feature map onto ch_1 channels.
            return nn.Sequential(
                nn.Conv2d(in_ch, ch_1, 3, padding=1),
                nn.BatchNorm2d(ch_1),
                nn.ReLU(inplace=True),
            )

        # Created in the same order as before so parameter init / state_dict keys match.
        self.W_2 = _project(ch_2)
        self.W_3 = _project(ch_3)
        self.W_4 = _project(ch_4)
        self.W_5 = _project(ch_5)

        self.psi = nn.Sequential(
            nn.Conv2d(ch_1, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )

        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()  # not used by forward; kept as a public attribute

    def forward(self, layer1, layer2, layer3, layer4, layer5):
        target_hw = tuple(layer1.shape[2:4])
        fused = layer1
        # Accumulate left to right: layer1 + w_2 + w_3 + w_4 + w_5.
        for branch, deeper in ((self.W_2, layer2), (self.W_3, layer3),
                               (self.W_4, layer4), (self.W_5, layer5)):
            fused = fused + torch.nn.functional.interpolate(
                branch(deeper), size=target_hw, mode='bilinear', align_corners=False)
        attention = self.psi(self.relu(fused))
        return layer1 * attention

In [3]:
class Attention_SEM_block_2(nn.Module):
    """Spatial attention gate for the level-2 feature map.

    ``layer3``, ``layer4`` and ``layer5`` are each projected to ``ch_2``
    channels (3x3 conv -> BatchNorm -> ReLU), resized bilinearly to the size of
    ``layer2`` and summed with it; ReLU followed by ``psi`` (1x1 conv ->
    BatchNorm -> sigmoid) yields a one-channel weight map applied to ``layer2``.

    Args:
        ch_2, ch_3, ch_4, ch_5: channel counts of ``layer2`` .. ``layer5``.
    """
    def __init__(self, ch_2, ch_3, ch_4, ch_5):
        super(Attention_SEM_block_2, self).__init__()
        self.n_channels = ch_2

        # Register W_3, W_4, W_5 in that order (same init order and state_dict keys).
        for attr_name, src_ch in (("W_3", ch_3), ("W_4", ch_4), ("W_5", ch_5)):
            setattr(self, attr_name, nn.Sequential(
                nn.Conv2d(src_ch, ch_2, 3, padding=1),
                nn.BatchNorm2d(ch_2),
                nn.ReLU(inplace=True),
            ))

        self.psi = nn.Sequential(
            nn.Conv2d(ch_2, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )

        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()  # not used by forward; kept as a public attribute

    def forward(self, layer2, layer3, layer4, layer5):
        out_hw = (layer2.shape[2], layer2.shape[3])
        resized = [
            torch.nn.functional.interpolate(proj(feat), size=out_hw, mode='bilinear', align_corners=False)
            for proj, feat in ((self.W_3, layer3), (self.W_4, layer4), (self.W_5, layer5))
        ]
        total = layer2
        for term in resized:
            total = total + term
        return layer2 * self.psi(self.relu(total))

In [4]:
class Attention_SEM_block_3(nn.Module):
    """Spatial attention gate for the level-3 feature map.

    ``layer4`` and ``layer5`` are projected to ``ch_3`` channels (3x3 conv ->
    BatchNorm -> ReLU), bilinearly resized to ``layer3``'s spatial size and
    added to ``layer3``. ReLU + ``psi`` (1x1 conv -> BatchNorm -> sigmoid)
    produce a one-channel map that scales ``layer3``.

    Args:
        ch_3, ch_4, ch_5: channel counts of ``layer3``, ``layer4``, ``layer5``.
    """
    def __init__(self, ch_3, ch_4, ch_5):
        super(Attention_SEM_block_3, self).__init__()
        self.n_channels = ch_3

        # Built W_4 first, then W_5, exactly as before.
        self.W_4, self.W_5 = [
            nn.Sequential(
                nn.Conv2d(src, ch_3, 3, padding=1),
                nn.BatchNorm2d(ch_3),
                nn.ReLU(inplace=True),
            )
            for src in (ch_4, ch_5)
        ]

        self.psi = nn.Sequential(
            nn.Conv2d(ch_3, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )

        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()  # not used by forward; kept as a public attribute

    def forward(self, layer3, layer4, layer5):
        size_hw = (layer3.shape[2], layer3.shape[3])
        fused = layer3
        for proj, deeper in zip((self.W_4, self.W_5), (layer4, layer5)):
            fused = fused + torch.nn.functional.interpolate(
                proj(deeper), size=size_hw, mode='bilinear', align_corners=False)
        gate = self.psi(self.relu(fused))
        return layer3 * gate

In [5]:
class Attention_SEM_block_4(nn.Module):
    """Spatial attention gate for the level-4 feature map.

    ``layer5`` is projected to ``ch_4`` channels (3x3 conv -> BatchNorm ->
    ReLU), resized bilinearly to ``layer4``'s spatial size and added to
    ``layer4``; ReLU + ``psi`` (1x1 conv -> BatchNorm -> sigmoid) give a
    one-channel map that multiplies ``layer4``.

    Args:
        ch_4, ch_5: channel counts of ``layer4`` and ``layer5``.
    """
    def __init__(self, ch_4, ch_5):
        super(Attention_SEM_block_4, self).__init__()
        self.n_channels = ch_4

        w5_layers = [nn.Conv2d(ch_5, ch_4, 3, padding=1), nn.BatchNorm2d(ch_4), nn.ReLU(inplace=True)]
        self.W_5 = nn.Sequential(*w5_layers)

        psi_layers = [
            nn.Conv2d(ch_4, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        ]
        self.psi = nn.Sequential(*psi_layers)

        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()  # not used by forward; kept as a public attribute

    def forward(self, layer4, layer5):
        upsampled = torch.nn.functional.interpolate(
            self.W_5(layer5),
            size=tuple(layer4.shape[2:4]),
            mode='bilinear',
            align_corners=False,
        )
        weights = self.psi(self.relu(layer4 + upsampled))
        return layer4 * weights

In [6]:
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    ``padding=1`` keeps the spatial size unchanged; only the channel count goes
    from ``in_ch`` to ``out_ch``.

    Args:
        in_ch: channels of the input feature map.
        out_ch: channels produced by both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        stages = []
        # First stage maps in_ch -> out_ch, second keeps out_ch -> out_ch.
        for src_ch in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(src_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, input):
        return self.conv(input)


class SEC_Unet(nn.Module):
    """U-Net with one shared encoder and two decoders: a mask head and a contour head.

    Encoder: five DoubleConv stages (64 -> 1024 channels) separated by 2x max-pooling.
    Mask decoder: the skip connections at levels 2-4 are the outputs of
    Attention_SEM_block_2/3/4 (each gates an encoder map using all deeper
    levels); level 1 uses the raw encoder map ``c1``.
    Contour decoder: plain U-Net skip connections ``c4, c3, c2, c1``.
    Both heads end in a 1x1 conv and return raw logits (no activation).

    Args:
        in_ch: number of channels of the input image.
        out_ch: number of output channels of each head.

    Returns (forward):
        ``(mask, contour)`` logits, each with ``out_ch`` channels.

    Input H and W need not be multiples of 16: when max-pooling floored an odd
    size, the transposed-conv output is one pixel short of its skip connection
    and is zero-padded to the skip's size before concatenation.
    """

    def __init__(self, in_ch, out_ch):
        super(SEC_Unet, self).__init__()
        self.n_channels = in_ch
        self.n_classes = out_ch

        # Encoder
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)

        # Attention gates for the mask decoder's skip connections (levels 2-4).
        self.SEM2 = Attention_SEM_block_2(128, 256, 512, 1024)
        self.SEM3 = Attention_SEM_block_3(256, 512, 1024)
        self.SEM4 = Attention_SEM_block_4(512, 1024)

        # Mask decoder
        self.up6_mask = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6_mask = DoubleConv(1024, 512)
        self.up7_mask = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7_mask = DoubleConv(512, 256)
        self.up8_mask = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8_mask = DoubleConv(256, 128)
        self.up9_mask = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9_mask = DoubleConv(128, 64)

        # Contour decoder
        self.up6_contour = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6_contour = DoubleConv(1024, 512)
        self.up7_contour = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7_contour = DoubleConv(512, 256)
        self.up8_contour = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8_contour = DoubleConv(256, 128)
        self.up9_contour = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9_contour = DoubleConv(128, 64)

        # 1x1 output heads (logits)
        self.conv_mask = nn.Conv2d(64, out_ch, 1)
        self.conv_contour = nn.Conv2d(64, out_ch, 1)

    @staticmethod
    def _pad_to(up, skip):
        """Zero-pad ``up`` (N, C, H, W) so that its H and W equal those of ``skip``."""
        diff_h = skip.shape[2] - up.shape[2]
        diff_w = skip.shape[3] - up.shape[3]
        if diff_h == 0 and diff_w == 0:
            return up
        # F.pad order for the last two dims is (left, right, top, bottom).
        return torch.nn.functional.pad(
            up, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])

    def _decode(self, bottom, skips, ups, convs):
        """Run one decoder path: at each level upsample, concatenate the skip, then DoubleConv."""
        feat = bottom
        for up, conv, skip in zip(ups, convs, skips):
            upsampled = self._pad_to(up(feat), skip)
            feat = conv(torch.cat([upsampled, skip], dim=1))
        return feat

    def forward(self, x):
        c1 = self.conv1(x)
        c2 = self.conv2(self.pool1(c1))
        c3 = self.conv3(self.pool2(c2))
        c4 = self.conv4(self.pool3(c3))
        c5 = self.conv5(self.pool4(c4))

        sem2 = self.SEM2(c2, c3, c4, c5)
        sem3 = self.SEM3(c3, c4, c5)
        sem4 = self.SEM4(c4, c5)

        mask_feat = self._decode(
            c5,
            (sem4, sem3, sem2, c1),
            (self.up6_mask, self.up7_mask, self.up8_mask, self.up9_mask),
            (self.conv6_mask, self.conv7_mask, self.conv8_mask, self.conv9_mask),
        )
        mask = self.conv_mask(mask_feat)

        contour_feat = self._decode(
            c5,
            (c4, c3, c2, c1),
            (self.up6_contour, self.up7_contour, self.up8_contour, self.up9_contour),
            (self.conv6_contour, self.conv7_contour, self.conv8_contour, self.conv9_contour),
        )
        contour = self.conv_contour(contour_feat)

        return mask, contour

In [7]:
def make_dataset(root):
    """List ``(image, mask, contour)`` path triples found in ``root``.

    Every regular, non-hidden file whose stem does not end in ``_mask`` or
    ``_contour`` is treated as an input image; its labels are expected at
    ``<stem>_mask.png`` and ``<stem>_contour.png`` in the same directory
    (their existence is not checked here).

    Args:
        root: directory to scan (non-recursive).

    Returns:
        List of ``(img_path, mask_path, contour_path)`` tuples, sorted by file
        name so the dataset order is reproducible.
    """
    imgs = []
    for filename in sorted(os.listdir(root)):
        img = os.path.join(root, filename)
        stem, _ext = os.path.splitext(filename)
        # Skip hidden files (e.g. .DS_Store, which broke the old stem[-1] check)
        # and anything that is not a regular file (sub-directories).
        if not stem or filename.startswith('.') or not os.path.isfile(img):
            continue
        # Label files are matched by their full suffix, not just the last letter.
        if stem.endswith('_mask') or stem.endswith('_contour'):
            continue
        mask = os.path.join(root, stem + '_mask.png')
        contour = os.path.join(root, stem + '_contour.png')
        imgs.append((img, mask, contour))
    return imgs


class BacteriaDataset(Dataset):
    """Dataset of ``(image, mask, contour)`` samples listed by ``make_dataset(root)``.

    Args:
        root: directory holding the images and their ``*_mask.png`` /
            ``*_contour.png`` label files.
        transform: optional callable applied to the input image.
        target_transform: optional callable applied to both the mask and the
            contour image.
        image_mode: PIL mode the input image is converted to. The default
            ``'RGB'`` matches the 3-channel ``Normalize`` and the ``in_ch=3``
            model used in this notebook (palette/RGBA/grayscale files would
            otherwise break them); ``None`` keeps the file's own mode.
    """

    def __init__(self, root, transform=None, target_transform=None, image_mode='RGB'):
        self.imgs = make_dataset(root)
        self.transform = transform
        self.target_transform = target_transform
        self.image_mode = image_mode

    @staticmethod
    def _load(path, mode):
        """Open ``path``, fully decode it (converting to ``mode`` unless None) and close the file."""
        with Image.open(path) as img:
            # convert()/copy() return a decoded image independent of the file handle.
            return img.convert(mode) if mode is not None else img.copy()

    def __getitem__(self, index):
        x_path, y1_path, y2_path = self.imgs[index]

        img_x = self._load(x_path, self.image_mode)
        img_y1 = self._load(y1_path, 'L')  # mask as single-channel grayscale
        img_y2 = self._load(y2_path, 'L')  # contour as single-channel grayscale

        if self.transform is not None:
            img_x = self.transform(img_x)
        if self.target_transform is not None:
            img_y2 = self.target_transform(img_y2)
            img_y1 = self.target_transform(img_y1)

        return img_x, img_y1, img_y2

    def __len__(self):
        return len(self.imgs)

In [8]:
# Input images: PIL -> float tensor in [0, 1] (for 8-bit images), then
# per-channel (v - 0.5) / 0.5, i.e. into [-1, 1]; expects 3 channels.
x_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

# Targets (mask / contour): tensor conversion only, no normalisation.
y_transforms = transforms.Compose([transforms.ToTensor()])

In [9]:
def train_model(model, criterion, optimizer, dataload, w_contour=1.0, w_mask=0.8, num_epochs=20,
                device=None, weights_path='weights_%d.pth'):
    """Jointly train the mask and contour heads, then save the final weights.

    The loss of every batch is the weighted average
    ``(w_mask * criterion(mask) + w_contour * criterion(contour)) / (w_mask + w_contour)``.

    Args:
        model: module whose forward returns ``(mask_pred, contour_pred)``.
        criterion: loss applied to each head (e.g. ``nn.BCEWithLogitsLoss``).
        optimizer: optimizer over ``model``'s parameters.
        dataload: DataLoader yielding ``(x, y_mask, y_contour)`` batches.
        w_contour: weight of the contour loss.
        w_mask: weight of the mask loss.
        num_epochs: number of passes over ``dataload``.
        device: device the batches are moved to; defaults to the device of the
            model's first parameter (CPU if the model has no parameters).
        weights_path: %-format string that receives the last epoch index; the
            state dict is saved there after training. Nothing is saved when
            ``num_epochs <= 0``.

    Returns:
        The trained ``model`` (same object).

    Raises:
        ValueError: if ``w_contour + w_mask == 0``.
    """
    total_weight = w_contour + w_mask
    if total_weight == 0:
        raise ValueError("w_contour + w_mask must be non-zero")
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    n_batches = len(dataload)  # respects drop_last, unlike len(dataset) // batch_size
    model.train()  # BatchNorm must use batch statistics while training

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        epoch_loss = 0.0
        step = 0

        for x, y_mask, y_contour in dataload:
            step += 1
            inputs = x.to(device)
            labels_mask = y_mask.to(device)
            labels_contour = y_contour.to(device)

            optimizer.zero_grad()
            mask_pred, contour_pred = model(inputs)
            loss_mask = w_mask * criterion(mask_pred, labels_mask)
            loss_contour = w_contour * criterion(contour_pred, labels_contour)
            loss = (loss_mask + loss_contour) / total_weight
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss
            print("%d/%d,train_loss:%0.3f" % (step, n_batches, batch_loss))

        # An empty loader would otherwise divide by zero.
        mean_loss = epoch_loss / step if step else float('nan')
        print("epoch %d loss:%0.3f" % (epoch, mean_loss))

    if num_epochs > 0:
        torch.save(model.state_dict(), weights_path % (num_epochs - 1))
    return model

In [None]:
# Seed torch's RNGs (weight initialisation, DataLoader shuffling) so a
# Restart & Run All reproduces the same model.
SEED = 0
torch.manual_seed(SEED)

TRAIN_DIR = "./dataset/train"
NUM_EPOCHS = 9
batch_size = 5

model = SEC_Unet(3, 1).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters())
bacteria_dataset = BacteriaDataset(TRAIN_DIR, transform=x_transforms, target_transform=y_transforms)
dataloaders = DataLoader(bacteria_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
# Trailing ';' keeps the (very long) model repr out of the cell output.
train_model(model, criterion, optimizer, dataloaders, num_epochs=NUM_EPOCHS);

In [None]:
# Evaluation is done on the CPU, one test image per batch.
model = model.to("cpu")
bacteria_dataset_test = BacteriaDataset(
    "dataset/test",
    transform=x_transforms,
    target_transform=y_transforms,
)
dataloaders_test = DataLoader(bacteria_dataset_test, batch_size=1)

In [None]:
class Score:
    """Pixel-wise confusion counts and derived metrics for one binary segmentation.

    A pixel is predicted positive when ``y_pred > threshold``. A predicted
    positive is a TP if the ground truth equals 1, otherwise an FP; a predicted
    negative is a TN if the ground truth equals 0, otherwise an FN.

    Args:
        y_pred: 2-D array of scores to threshold (in this notebook: logits).
        y_true: 2-D ground-truth array, expected to contain 0/1 values.
        size: only the top-left ``size x size`` window is scored; ``None``
            scores the whole array.
        threshold: decision threshold applied to ``y_pred``.
        zero_division: value returned by a metric whose denominator is 0
            (e.g. sensitivity on an image without positive pixels) instead of
            raising ZeroDivisionError.

    Raises:
        ValueError: if the scored prediction and ground-truth windows differ in shape.
    """

    def __init__(self, y_pred, y_true, size=512, threshold=0.5, zero_division=0.0):
        self.y_pred = y_pred > threshold
        self.y_true = y_true
        self.threshold = threshold
        self.zero_division = zero_division

        pred = np.asarray(self.y_pred, dtype=bool)
        true = np.asarray(y_true)
        if size is not None:
            pred = pred[:size, :size]
            true = true[:size, :size]
        if pred.shape != true.shape:
            raise ValueError("prediction shape %s != ground truth shape %s" % (pred.shape, true.shape))

        # Vectorised form of the per-pixel rule (pred == true comparisons).
        true_pos = true == 1
        true_neg = true == 0
        self.TP = int(np.count_nonzero(pred & true_pos))
        self.FP = int(np.count_nonzero(pred & ~true_pos))
        self.TN = int(np.count_nonzero(~pred & true_neg))
        self.FN = int(np.count_nonzero(~pred & ~true_neg))

    def _ratio(self, num, den):
        """``num / den``, or ``self.zero_division`` when ``den`` is 0."""
        return num / den if den else self.zero_division

    def get_Se(self):
        """Sensitivity (recall): TP / (TP + FN)."""
        return self._ratio(self.TP, self.TP + self.FN)

    def get_Sp(self):
        """Specificity: TN / (TN + FP)."""
        return self._ratio(self.TN, self.TN + self.FP)

    def get_Pr(self):
        """Precision: TP / (TP + FP)."""
        return self._ratio(self.TP, self.TP + self.FP)

    def F1(self):
        """F1 = 2*Pr*Se / (Pr + Se), computed from counts as 2TP / (2TP + FP + FN)."""
        return self._ratio(2 * self.TP, 2 * self.TP + self.FP + self.FN)

    def G(self):
        """Geometric mean of sensitivity and specificity."""
        return math.sqrt(self.get_Se() * self.get_Sp())

    def IoU(self):
        """Jaccard index: TP / (TP + FP + FN) (equal to Pr*Se / (Pr + Se - Pr*Se))."""
        return self._ratio(self.TP, self.TP + self.FP + self.FN)

    def DSC(self):
        """Dice similarity coefficient: 2TP / (2TP + FP + FN)."""
        return self._ratio(2 * self.TP, 2 * self.TP + self.FP + self.FN)

    def PA(self):
        """Pixel accuracy: (TP + TN) / number of scored pixels."""
        return self._ratio(self.TP + self.TN, self.TP + self.FP + self.FN + self.TN)

In [None]:
# Inference mode: BatchNorm must use its running statistics. Left in training
# mode, each test image (batch_size=1) would be normalised by its own
# statistics and the running averages would be overwritten during evaluation.
model.eval()

# Weight of the contour logits subtracted from the mask logits before
# thresholding (presumably to separate touching objects — confirm).
CONTOUR_WEIGHT = 0.3

# Collect per-image metrics in lists; np.append would copy the array on every call.
f1_list, pre_list, se_list, pa_list, sp_list = [], [], [], [], []
with torch.no_grad():
    for x, target1, target2 in dataloaders_test:
        y = model(x)  # (mask_logits, contour_logits)
        y_pred = torch.squeeze(y[0] - CONTOUR_WEIGHT * y[1]).numpy()
        y_true = torch.squeeze(target1).numpy()
        y_score = Score(y_pred, y_true, size=512)
        f1_list.append(y_score.F1())
        pre_list.append(y_score.get_Pr())
        se_list.append(y_score.get_Se())
        pa_list.append(y_score.PA())
        sp_list.append(y_score.get_Sp())

F1 = np.array(f1_list, dtype=float)
Pre = np.array(pre_list, dtype=float)
Se = np.array(se_list, dtype=float)
PA = np.array(pa_list, dtype=float)
Sp = np.array(sp_list, dtype=float)

In [None]:
# Mean and sample standard deviation (ddof=1) of the per-image F1 scores.
F1_mean, F1_std = np.mean(F1), np.std(F1, ddof=1)
for stat in (F1_mean, F1_std):
    print(stat)