In [31]:
import numpy as np
import cv2
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.optim import Adam
import ternausnet
import ternausnet.models
from torchvision import transforms
import os
from PIL import Image
from torch.utils.data import Dataset
import matplotlib.pyplot as plt

In [32]:
class SegmentationDataset(Dataset):
    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.img_names = os.listdir(img_dir) 

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_names[idx]) # Assumes masks have same filenames as images

        image = Image.open(img_path)
        if image.mode != 'RGB':
            image = Image.merge("RGB", (image, image, image))

        image = self.transform(image)

        return image,self.img_names[idx]

In [33]:
transform = transforms.Compose([
    transforms.ToTensor(),
])
dataset = SegmentationDataset(img_dir='test_image_2_stage_prediction',transform=transform)
batch_S = 1
dataloader = DataLoader(dataset, batch_size=batch_S, shuffle=True)

In [34]:
class DiceLoss(nn.Module):
    def forward(self, input, target):
        smooth = 1.
        iflat = input.view(-1)
        tflat = target.view(-1)
        intersection = (iflat * tflat).sum()
        
        return 1 - ((2. * intersection + smooth) /
              (iflat.sum() + tflat.sum() + smooth))

class IoULoss(nn.Module):
    def forward(self, input, target):
        smooth = 1.
        intersection = (input * target).sum()
        total = (input + target).sum()
        union = total - intersection 
        
        return 1 - ((intersection + smooth) / (union + smooth))
    
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [35]:
model_vgg16_colorised = ternausnet.models.UNet11()
vgg16_colorised_model_path = "path_files/model_vgg11_colorised_4.pth"
checkpoint_vgg16_colorised = torch.load(
    vgg16_colorised_model_path, map_location=device)
model_vgg16_colorised.load_state_dict(
    checkpoint_vgg16_colorised['model_state_dict'])
model_vgg16_colorised = model_vgg16_colorised.to(device)
model_vgg16_colorised.eval()

model_vgg11_colorised = ternausnet.models.UNet11()
vgg11_colorised_model_path = "path_files/model_vgg11_colorised_4.pth"
checkpoint_vgg11_colorised = torch.load(
    vgg11_colorised_model_path, map_location=device)
model_vgg11_colorised.load_state_dict(
    checkpoint_vgg11_colorised['model_state_dict'])
model_vgg11_colorised = model_vgg11_colorised.to(device)
model_vgg11_colorised.eval()

model_vgg11_patchMask = ternausnet.models.UNet11()
model_vgg11_patchMask_path = "path_files/model_vgg11_patched_4.pth"
checkpoint_vgg11_patchMask = torch.load(
    model_vgg11_patchMask_path, map_location=device)
model_vgg11_patchMask.load_state_dict(
    checkpoint_vgg11_patchMask['model_state_dict'])
model_vgg11_patchMask = model_vgg11_patchMask.to(device)
model_vgg11_patchMask.eval()



UNet11(
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (14): ReLU(inplace=True)

In [37]:
with torch.no_grad():
    k = 0  # Don't compute gradients
    for images, img_name in dataloader:
        k += 1
        images = images.to(device)
        img_names = img_name[0].split(".")[0]
        output_VGG16_color = model_vgg16_colorised(images)
        output_VGG16_color = torch.sigmoid(output_VGG16_color)

        output_VGG11_color = model_vgg11_colorised(images)
        output_VGG11_color = torch.sigmoid(output_VGG11_color)

        output_VGG11_patchMask = model_vgg11_patchMask(images)
        output_VGG11_patchMask = torch.sigmoid(output_VGG11_patchMask)

        main_model_input = torch.cat(
            (output_VGG16_color, output_VGG11_color, output_VGG11_patchMask), dim=1)
        # print(main_model_input.shape)
        out = main_model_input.permute(0, 2, 3, 1)
        out = out.cpu().detach().squeeze().numpy()
        image_0 = out[:, :, 0]
        
        image_1 = out[:, :, 1]
        image_2 = out[:, :, 2]
        img_0 = (image_0 > 0.5).astype(np.uint8)
        img_1 = (image_1 > 0.5).astype(np.uint8)
        img_2 = (image_2 > 0.5).astype(np.uint8)
        # plt.imshow(img_0)
        # plt.show()
        intersection_0_1 = np.minimum(img_0, img_1)
        image = np.minimum(intersection_0_1, img_2)
        IMG = image
        img = np.array(IMG)*255.0
        
        cv2.imwrite(f"model_outputs/{img_names}.png", img)