In [None]:
import json
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

class VGGBlock(nn.Module):
    """Two stacked ``3x3 conv -> batch norm -> ReLU`` stages.

    Both convolutions use ``padding=1``, so the spatial size of the input is
    preserved; only the channel count changes
    (``in_channels -> middle_channels -> out_channels``).
    """

    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        # One in-place ReLU is shared by both stages (it holds no parameters).
        self.relu = nn.ReLU(inplace=True)
        # Stage 1: in_channels -> middle_channels.
        self.conv1 = nn.Conv2d(in_channels, middle_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        # Stage 2: middle_channels -> out_channels.
        self.conv2 = nn.Conv2d(middle_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Run both stages on ``x`` and return the activated feature map."""
        hidden = self.relu(self.bn1(self.conv1(x)))
        return self.relu(self.bn2(self.conv2(hidden)))
    
class NestedUNet(nn.Module):
    """Nested U-Net with dense skip connections between encoder and decoder.

    Node ``convI_J`` sits at depth ``I`` (0 = full resolution, each deeper level
    is halved by a 2x2 max-pool) and column ``J`` of the skip pathway. A node
    with ``J > 0`` receives every earlier node of the same depth concatenated
    with the upsampled node ``(I+1, J-1)``.

    Args:
        num_classes: Number of output channels of the final 1x1 convolution(s).
        input_channels: Number of channels of the input image.
        deep_supervision: If True, ``forward`` returns one raw (un-activated)
            output per full-resolution decoder node; otherwise a single
            sigmoid-activated map is returned.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, num_classes, input_channels=3, deep_supervision=False, **kwargs):
        super().__init__()

        # Channel width at each depth level.
        f0, f1, f2, f3, f4 = 32, 64, 128, 256, 512

        self.deep_supervision = deep_supervision

        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Column 0: the encoder backbone.
        self.conv0_0 = VGGBlock(input_channels, f0, f0)
        self.conv1_0 = VGGBlock(f0, f1, f1)
        self.conv2_0 = VGGBlock(f1, f2, f2)
        self.conv3_0 = VGGBlock(f2, f3, f3)
        self.conv4_0 = VGGBlock(f3, f4, f4)

        # Column 1: one same-depth skip + the upsampled deeper node.
        self.conv0_1 = VGGBlock(f0 + f1, f0, f0)
        self.conv1_1 = VGGBlock(f1 + f2, f1, f1)
        self.conv2_1 = VGGBlock(f2 + f3, f2, f2)
        self.conv3_1 = VGGBlock(f3 + f4, f3, f3)

        # Column 2: two same-depth skips + the upsampled deeper node.
        self.conv0_2 = VGGBlock(f0 * 2 + f1, f0, f0)
        self.conv1_2 = VGGBlock(f1 * 2 + f2, f1, f1)
        self.conv2_2 = VGGBlock(f2 * 2 + f3, f2, f2)

        # Column 3: three same-depth skips + the upsampled deeper node.
        self.conv0_3 = VGGBlock(f0 * 3 + f1, f0, f0)
        self.conv1_3 = VGGBlock(f1 * 3 + f2, f1, f1)

        # Column 4: four same-depth skips + the upsampled deeper node.
        self.conv0_4 = VGGBlock(f0 * 4 + f1, f0, f0)

        if not self.deep_supervision:
            self.final = nn.Conv2d(f0, num_classes, kernel_size=1)
        else:
            # One 1x1 head per full-resolution decoder node: final1 .. final4.
            for head_idx in range(1, 5):
                setattr(self, f"final{head_idx}", nn.Conv2d(f0, num_classes, kernel_size=1))

    def forward(self, input):
        """Compute the segmentation output(s) for a batch of images.

        Returns a list of four raw outputs when ``deep_supervision`` is set,
        otherwise a single sigmoid-activated tensor.
        """
        down, up = self.pool, self.up

        def merge(*feature_maps):
            # Concatenate along the channel dimension.
            return torch.cat(feature_maps, 1)

        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(down(x0_0))
        x0_1 = self.conv0_1(merge(x0_0, up(x1_0)))

        x2_0 = self.conv2_0(down(x1_0))
        x1_1 = self.conv1_1(merge(x1_0, up(x2_0)))
        x0_2 = self.conv0_2(merge(x0_0, x0_1, up(x1_1)))

        x3_0 = self.conv3_0(down(x2_0))
        x2_1 = self.conv2_1(merge(x2_0, up(x3_0)))
        x1_2 = self.conv1_2(merge(x1_0, x1_1, up(x2_1)))
        x0_3 = self.conv0_3(merge(x0_0, x0_1, x0_2, up(x1_2)))

        x4_0 = self.conv4_0(down(x3_0))
        x3_1 = self.conv3_1(merge(x3_0, up(x4_0)))
        x2_2 = self.conv2_2(merge(x2_0, x2_1, up(x3_1)))
        x1_3 = self.conv1_3(merge(x1_0, x1_1, x1_2, up(x2_2)))
        x0_4 = self.conv0_4(merge(x0_0, x0_1, x0_2, x0_3, up(x1_3)))

        if not self.deep_supervision:
            return torch.sigmoid(self.final(x0_4))

        return [
            self.final1(x0_1),
            self.final2(x0_2),
            self.final3(x0_3),
            self.final4(x0_4),
        ]

In [None]:
# Given the path of the "confirmed" folder, return two parallel lists: the image file paths and their corresponding label file paths.
def make_pth_list(confirmed_pth):
    """Collect paired image / label file paths under a "confirmed" dataset folder.

    Expected layout (one sub-folder per group)::

        confirmed_pth/<num>/images/cur/<image files>
        confirmed_pth/<num>/labels/<label files>

    Within each group, images and labels are paired by position after both
    directory listings are sorted by file name. ``os.listdir`` returns entries
    in arbitrary order, so without sorting the i-th image and the i-th label
    are not guaranteed to belong together.
    NOTE(review): this assumes image and label files share a naming scheme so
    that sorted order lines them up -- confirm against the dataset.

    Args:
        confirmed_pth (str): Root folder containing one sub-folder per group.

    Returns:
        tuple[list[str], list[str]]: ``(images_pth_list, labels_pth_list)``,
        two lists of equal length where entries at the same index form a pair.
        One pair is produced per label file; surplus images are ignored.

    Raises:
        IndexError: If a group contains fewer images than labels.
    """
    images_pth_list = []
    labels_pth_list = []

    for num in sorted(os.listdir(confirmed_pth)):
        num_pth = os.path.join(confirmed_pth, num)
        if not os.path.isdir(num_pth):
            # Stray files in the root (e.g. .DS_Store) are not groups.
            continue

        num_images_pth = os.path.join(num_pth, "images/cur")
        num_labels_pth = os.path.join(num_pth, "labels")
        num_images_list = sorted(os.listdir(num_images_pth))
        num_labels_list = sorted(os.listdir(num_labels_pth))

        if len(num_images_list) < len(num_labels_list):
            raise IndexError(
                f"{num_pth}: found {len(num_labels_list)} labels but only "
                f"{len(num_images_list)} images"
            )

        # zip stops at the last label, so extra images are left out.
        for image_name, label_name in zip(num_images_list, num_labels_list):
            images_pth_list.append(os.path.join(num_images_pth, image_name))
            labels_pth_list.append(os.path.join(num_labels_pth, label_name))

    return images_pth_list, labels_pth_list

In [None]:
class MaskGenerator():
    """Segmentation mask generator for a single smoke image.

    Notes:
        A pretrained Nested UNet segments one square window of the source image
        (``target_size`` x ``target_size``, 128 x 128 by default) centred on the
        centre of the bounding box read from the label file. The image is
        zero-padded by half the window size on every side so the window always
        fits, even when the box centre lies on the image boundary. Pixels
        outside the window are left at zero in the mask.

    Parameters:
        pretrained (str): Path of the saved ``state_dict`` for ``NestedUNet(1, 3)``
            (anything accepted by ``torch.load``).
        image (str): Path of the source image (.jpg).
        bbox (str): Path of the bbox label data (.json). The first two entries of
            ``data["objects"][0][0]["points"]`` are read as opposite box corners
            with ``"x"`` / ``"y"`` keys.
        threshold (float): Segmentation outputs strictly above this value become 1
            in the binary mask, everything else becomes 0.
        target_size (float): Side length of the segmented window. It should be a
            multiple of 16 because the network halves the resolution four times.
        device (str): Device on which the network runs (e.g. ``'cpu'``, ``'cuda'``).

    Returns (from ``forward``):
        mask (np.ndarray): Binary mask of shape (H, W) marking the segmented smoke.
        masked_image (np.ndarray): Source image multiplied by the un-thresholded
            segmentation output, shape (H, W, 3).
        image (torch.Tensor): Source image as a float tensor in [0, 1] with shape
            (1, 3, H, W).
    """

    def __init__(self, pretrained, image, bbox, threshold = 0.5, target_size = 128., device = 'cpu'):
        super().__init__()

        self.crop = int(target_size / 2)  # half window size == padding width
        self.device = torch.device(device)
        self.segmentation = NestedUNet(1,3)
        # map_location lets weights saved on a GPU load on a CPU-only machine.
        self.segmentation.load_state_dict(torch.load(pretrained, map_location=self.device))
        self.segmentation = self.segmentation.to(self.device)
        # Inference only: BatchNorm must use its stored running statistics,
        # not the statistics of the single crop being segmented.
        self.segmentation.eval()
        self.image_path = image
        self.bbox_path = bbox
        self.threshold = threshold

    def forward(self, idx):
        """Segment the configured image and return ``(mask, masked_image, image)``.

        ``idx`` is accepted for interface compatibility and is not used.
        """
        print(self.image_path)
        with Image.open(self.image_path) as source:
            # Force 3 channels so grayscale / RGBA files match the 3-channel network.
            image = np.asarray(source.convert("RGB"))/255.
        with open(self.bbox_path, 'r') as file:
            data = json.load(file)
        p_1 = data["objects"][0][0]["points"][0]
        p_2 = data["objects"][0][0]["points"][1]
        rows, cols, _ = image.shape  # numpy image layout is (height, width, channels)

        # Centre of the bounding box, clamped into the image so that the window
        # cut from the padded image always has the full size.
        center_row = min(max(round((p_1["y"]+p_2["y"])/2), 0), rows)
        center_col = min(max(round((p_1["x"]+p_2["x"])/2), 0), cols)

        image = torch.tensor(image, dtype = torch.float32).permute(2,0,1).unsqueeze(0)
        pad = nn.ZeroPad2d(self.crop) # zero pad for the case that the bbox centre is at the boundary
        mask = pad(torch.zeros((1, 1, rows, cols)))
        image_pad = pad(image)

        # In padded coordinates the window centred on (center_row, center_col)
        # starts exactly at that index, because the padding equals half a window.
        window = self.crop * 2
        row_span = slice(center_row, center_row + window)
        col_span = slice(center_col, center_col + window)
        image_cropped = image_pad[:, :, row_span, col_span]

        with torch.no_grad():
            # Run on the model's device, then bring the result back to the CPU
            # where the full-size mask lives.
            segmented = self.segmentation(image_cropped.to(self.device)).cpu()
        mask[0, 0, row_span, col_span] = segmented[0, 0]

        mask = mask[0, 0, :, :] # 2d mask of the padded size
        # Broadcast the single-channel mask over the 3 colour channels.
        masked_image = mask.unsqueeze(0) * image_pad.squeeze(0) # no batch

        # Remove the padding again.
        mask = mask[self.crop:self.crop + rows, self.crop:self.crop + cols] # (H, W)
        masked_image = masked_image[:, self.crop:self.crop + rows, self.crop:self.crop + cols] # (3, H, W)

        mask = mask.numpy()
        # Single comparison so that values equal to the threshold are binarised
        # too (they map to 0) instead of being left untouched.
        mask = (mask > self.threshold).astype(mask.dtype)
        masked_image = masked_image.permute(1,2,0).numpy() # channels-last layout (H, W, 3)
        return mask, masked_image, image

In [None]:
import matplotlib.pyplot as plt

def save_mask(confirmed_pth, model_pth):
    """Generate and save a smoke mask image for every labelled image.

    For each group folder ``confirmed_pth/<num>`` the images in ``images/cur``
    are paired with the label files in ``labels`` (by position, after sorting
    both listings by file name), segmented with ``MaskGenerator``, and the
    resulting binary mask is written to ``<num>/masks/<image file name>`` as a
    3-channel image.

    Args:
        confirmed_pth (str): Root folder containing one sub-folder per group.
        model_pth (str): Path of the pretrained weights given to ``MaskGenerator``.

    Raises:
        IndexError: If a group contains fewer images than labels.
    """
    for num in sorted(os.listdir(confirmed_pth)):
        num_pth = os.path.join(confirmed_pth, num)
        if not os.path.isdir(num_pth):
            # Stray files in the root (e.g. .DS_Store) are not groups.
            continue

        # Make the mask directory (no error if it already exists).
        num_masks_pth = os.path.join(num_pth, 'masks')
        os.makedirs(num_masks_pth, exist_ok=True)

        # Sorted listings: os.listdir order is arbitrary, and images are
        # matched to labels by index.
        num_images_pth = os.path.join(num_pth, "images/cur")
        num_images_list = sorted(os.listdir(num_images_pth))

        num_labels_pth = os.path.join(num_pth, "labels")
        num_labels_list = sorted(os.listdir(num_labels_pth))

        if len(num_images_list) < len(num_labels_list):
            raise IndexError(
                f"{num_pth}: found {len(num_labels_list)} labels but only "
                f"{len(num_images_list)} images"
            )

        # One mask per label file; surplus images are left out.
        for image_file_name, label_file_name in zip(num_images_list, num_labels_list):
            mask_pth = os.path.join(num_masks_pth, image_file_name)
            image_pth = os.path.join(num_images_pth, image_file_name)
            label_pth = os.path.join(num_labels_pth, label_file_name)

            # NOTE: MaskGenerator loads the weights in its constructor, so the
            # model is re-read from disk for every image.
            generator = MaskGenerator(model_pth, image_pth, label_pth)
            mask, masked_image, image = generator.forward(0)

            # Replicate the 2-D mask into 3 channels (H, W, 3) before saving.
            final_mask = np.stack([mask, mask, mask], axis=2)
            plt.imsave(mask_pth, final_mask)

In [None]:
# Path of the pretrained weights file handed to save_mask.
model_pth = "./pretrained.pth"
# Root of the "confirmed" dataset: one sub-folder per group, each holding
# images/cur and labels.
confirmed_pth = "./datasets/confirmed"
# Parallel lists of image paths and label paths for the whole dataset.
# NOTE(review): these two lists are not used by any later cell visible here.
images_pth_list, labels_pth_list = make_pth_list(confirmed_pth)

In [None]:
# Segment every labelled image under confirmed_pth and write the masks into
# each group's "masks" folder.
save_mask(confirmed_pth, model_pth)