In [1]:
import os
from typing import Tuple, List, Callable, Iterator, Optional, Dict, Any
from collections import defaultdict

In [2]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
import tensorboardX
from tensorboardX import SummaryWriter
import os

import pandas as pd 
import numpy as np
import glob
from tqdm import tqdm
import cv2
from sklearn.model_selection import train_test_split
import fiona
import segmentation_models_pytorch as smp

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
from torchvision import datasets, models, transforms
from torchvision.models import resnet18
from torchvision.utils import draw_segmentation_masks
import torchvision.transforms as T

from PIL import Image
import matplotlib.pyplot as plt
from IPython.display import clear_output

import warnings
warnings.filterwarnings("ignore")

import json
import secrets

In [3]:
from unet import *

In [4]:
model_type = "unet"
# Build the network, place it on the GPU and switch to inference mode
# (the architecture contains BatchNorm layers, see the printed repr).
model = UNet_classic(n_channels=3, n_classes=1, bilinear=False).cuda()
model.eval()

UNet_classic(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_run

In [5]:
weights = "./weights/unet_1024_crops_dee943129a3beffb/weights_last_epoch.pth"
# Load tensors directly onto the device the model already lives on, so a
# checkpoint saved from another device (e.g. cuda:1 or CPU) still loads.
# NOTE: torch.load unpickles the file -- only use it on trusted checkpoints.
model.load_state_dict(torch.load(weights, map_location=next(model.parameters()).device))

<All keys matched successfully>

In [9]:
# Directory that make_patches writes one `<image stem>.npy` file per test image into.
out_dir = "unet_1024_crops_dee943129a3beffb_probs"

In [10]:
test_dataset_path = "../test_dataset_mc2/eye_test"
# glob returns files in filesystem order; sort so runs process images
# in the same, reproducible order on every machine.
test_imgs = sorted(glob.glob(test_dataset_path + '/*.png'))

In [11]:
def read_image(path: str) -> np.ndarray:
    """Load an image file as an RGB float32 array scaled to [0, 1].

    Args:
        path: path of the image file; anything ``str()`` accepts
            (e.g. a ``pathlib.Path``) works.

    Returns:
        An ``H x W x 3`` float32 array in RGB channel order, values in [0, 1].

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path`` (missing or
            undecodable file). ``cv2.imread`` reports this by returning
            ``None`` instead of raising, which would otherwise surface as an
            unrelated-looking error inside ``cv2.cvtColor``.
    """
    image = cv2.imread(str(path), cv2.IMREAD_COLOR)
    if image is None:
        raise FileNotFoundError(f"could not read image: {path}")
    # OpenCV decodes colour images in BGR order; convert to RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return np.array(image / 255, dtype=np.float32)

In [12]:
import torchvision.transforms as T
from PIL import Image
import PIL.ImageOps
transform = T.ToPILImage()  # tensor/ndarray -> PIL.Image converter; not used by any cell shown here

In [13]:
def _mask_axis_layout(offset, length, patch_size, overlap):
    """Return ``(core, ramp_before, ramp_after)`` pixel counts for one mask axis.

    A side is ramped only when a neighbouring tile overlaps it: the leading
    side unless the tile starts at the image border (``offset == 0``), the
    trailing side unless the tile was cropped by the image border
    (``length < patch_size``).
    """
    length = min(length, patch_size)
    ramp_before = 0 if offset == 0 else overlap
    ramp_after = 0 if length < patch_size else overlap
    # A trailing tile shorter than the overlap lies entirely inside its
    # neighbour's ramp; shorten its leading ramp so at least one full-weight
    # pixel remains (np.pad cannot linear-ramp away from an empty core).
    ramp_before = min(ramp_before, max(length - ramp_after - 1, 0))
    return length - ramp_before - ramp_after, ramp_before, ramp_after


def make_mask(i, j, patch_size, overlap, patch, end_value=0.05):
    """Build the blending weights for the tile whose top-left corner is ``(i, j)``.

    Weights are 1 in the tile interior and fall linearly (``np.pad`` in
    ``"linear_ramp"`` mode) down to ``end_value`` across the ``overlap``
    pixels shared with each neighbouring tile. Sides on the image border are
    not ramped. A tile touches the top/left border when its offset is 0 and
    the bottom/right border when ``patch`` was cropped shorter than
    ``patch_size`` along that axis, so any image size works. Previously the
    image size was hard-coded as 1232 x 1624.

    Args:
        i: row offset of the tile in the full image.
        j: column offset of the tile in the full image.
        patch_size: side length of a full (uncropped) tile.
        overlap: number of pixels shared by adjacent tiles.
        patch: the tile as cropped from the image; only its shape is used.
        end_value: weight at the outermost pixel of each ramp. It is the same
            for all tiles; previously interior tiles used 0.02 and border
            tiles 0.05.

    Returns:
        A float64 array of shape ``(patch_size, patch_size, 3)``. The weights
        for the cropped tile occupy the top-left
        ``patch.shape[0] x patch.shape[1]`` region; the remainder is 0.

    Raises:
        ValueError: if ``overlap`` is not in ``[0, patch_size)``.
    """
    if not 0 <= overlap < patch_size:
        raise ValueError(
            f"overlap must be in [0, patch_size), got overlap={overlap}, patch_size={patch_size}"
        )

    core_h, top, bottom = _mask_axis_layout(i, patch.shape[0], patch_size, overlap)
    core_w, left, right = _mask_axis_layout(j, patch.shape[1], patch_size, overlap)
    ramped = np.pad(
        np.ones((core_h, core_w, 3)),
        ((top, bottom), (left, right), (0, 0)),
        "linear_ramp",
        end_values=(end_value, end_value),
    )

    # Always return a full-size mask so callers can slice it like the tile.
    mask = np.zeros((patch_size, patch_size, 3))
    mask[:ramped.shape[0], :ramped.shape[1], :] = ramped
    return mask
    
    

In [15]:
def make_patches(image, patch_size, overlap, model, out_dir, name):
    """Run ``model`` over overlapping tiles of ``image``, blend the outputs and save them.

    The image is tiled with stride ``patch_size - overlap``. Tiles that run
    past the bottom or right border are zero-padded to ``patch_size`` before
    inference. Each tile's output is weighted by ``make_mask`` (linear ramps
    over the overlap) and accumulated into a full-size map. The map is then
    divided by the accumulated weights, so overlapping regions become a true
    weighted average; non-overlapping pixels keep their raw output.

    Args:
        image: ``H x W x channels`` array, channel-last (as from ``read_image``).
        patch_size: side length of the square tiles fed to the model.
        overlap: number of pixels shared by neighbouring tiles.
        model: torch module. Assumed to map a ``(1, channels, P, P)`` float
            tensor to a ``(1, C, P, P)`` tensor -- TODO confirm for other
            architectures.
        out_dir: directory for the output file; created if missing.
        name: file stem; the map is saved to ``out_dir/name.npy``.

    Returns:
        The stitched ``H x W x C`` float32 map that was saved.

    Raises:
        ValueError: if ``image`` is empty or ``overlap`` is not in
            ``[0, patch_size)``.
    """
    h, w = image.shape[:2]
    if h == 0 or w == 0:
        raise ValueError(f"cannot tile an empty image of shape {image.shape}")
    if not 0 <= overlap < patch_size:
        raise ValueError(
            f"overlap must be in [0, patch_size), got overlap={overlap}, patch_size={patch_size}"
        )
    stride = patch_size - overlap

    # Send tiles to whatever device the model lives on (was hard-coded .cuda()).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    result = None                  # (h, w, C) weighted sum; allocated once C is known
    weight_sum = np.zeros((h, w))  # accumulated blending weight per pixel

    with torch.no_grad():  # inference only: don't build an autograd graph per tile
        for i in range(0, h, stride):
            for j in range(0, w, stride):
                # Clamp to h / w (not h - 1 / w - 1) so the last row and column are covered.
                x = min(h, i + patch_size)
                y = min(w, j + patch_size)
                patch = image[i:x, j:y, :]

                # Zero-pad border tiles up to the full size the model expects.
                tile = np.zeros((patch_size, patch_size, image.shape[2]), dtype=np.float32)
                tile[:x - i, :y - j, :] = patch
                batch = torch.from_numpy(
                    np.ascontiguousarray(tile.transpose(2, 0, 1)[np.newaxis])
                ).to(device)

                out = model(batch)[0].cpu().numpy().transpose(1, 2, 0)
                if result is None:
                    result = np.zeros((h, w, out.shape[-1]))

                # make_mask returns identical channels; one channel is the weight.
                weight = make_mask(i, j, patch_size, overlap, patch)[:x - i, :y - j, 0]
                result[i:x, j:y, :] += weight[..., np.newaxis] * out[:x - i, :y - j, :]
                weight_sum[i:x, j:y] += weight

    # The ramps of neighbouring tiles don't sum to exactly 1, so normalise.
    covered = weight_sum > 0
    result[covered] /= weight_sum[covered][:, np.newaxis]
    result = result.astype(np.float32)

    # Save the stitched map (previously only the last tile's raw output was saved).
    os.makedirs(out_dir, exist_ok=True)
    np.save(os.path.join(out_dir, name + ".npy"), result)
    return result

In [17]:
# Tile size and overlap for sliding-window inference over each test image.
PATCH_SIZE = 1024
OVERLAP = 32

for img_path in test_imgs:
    # basename + splitext keep stems that contain dots ("a.b.png" -> "a.b")
    # and work with any OS path separator, unlike split("/") / split(".").
    name = os.path.splitext(os.path.basename(img_path))[0]
    # Read the globbed path directly instead of rebuilding it from the stem.
    cur_image = read_image(img_path)
    make_patches(cur_image, PATCH_SIZE, OVERLAP, model, out_dir, name)