In [35]:
# Sources
# https://medium.com/@navneetkumar11/loading-image-data-from-google-drive-to-google-colab-using-pytorchs-dataloader-2e5617978a63
# https://github.com/pytorch/examples

from __future__ import print_function
import argparse
import numpy as np
from random import sample
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import matplotlib.pyplot as plt
import copy
import os, os.path
from PIL import Image
from torchvision.utils import save_image
from torch.autograd import Variable
import time
import cv2
from skimage.measure import label 

In [36]:
# Root directory of the project data tree. The original hard-coded location
# is kept as the default; set the MICROPLASTICS_PROJECT_BASE environment
# variable to run the notebook on another machine without editing this cell.
project_base = os.environ.get('MICROPLASTICS_PROJECT_BASE') or '/home/ks/Projects/Microplastics'

class ArgumentsSSD:
  """Settings for generating the semi-synthetic data (SSD) set.

  All values are plain class attributes; the generator functions read them
  from an instance of this class.
  """

  # --- Input locations -------------------------------------------------
  # Directory with the raw training images. Images and masks must be
  # separated into the sub-folders 'imgs' and 'masks'.
  segmentation_raw = project_base + '/Data/segmentation_raw/'
  # Directory with the microplastic templates (count of 1). Do not separate
  # images and masks here: place them all in this one folder.
  count_1_path = segmentation_raw + 'mask_template/'

  # --- Amount of generated data ----------------------------------------
  # The number of generated images is num_batches * batch_size.
  num_batches = 1
  batch_size = 30 # 20 # 10

  # --- Microplastics per image -----------------------------------------
  # gen_synth_batch draws the count with np.random.randint(min_count, max_count).
  min_count = 8
  max_count = 14

  # --- Image geometry --------------------------------------------------
  # Size of the centre crop applied to every loaded image and mask.
  crop_height = 1536
  crop_width = 2048

  # --- Appearance ------------------------------------------------------
  # 1 = uniform bg. + green mp
  # 2 = uniform bg. + rand. colour
  # 3 = clut. bg. + green mp
  # 4 = clut. bg. + rand. col.
  # 5 = clut. bg. + actual/raw mp
  mode = 5
  # When True, gen_synth_batch stores 1.0 - image in the returned batch.
  invert_imgs = True

  # --- Transform pipelines ---------------------------------------------
  # Applied when loading background images and their masks.
  data_transforms = transforms.Compose([
      transforms.CenterCrop((crop_height, crop_width)),
      transforms.ToTensor(),
  ])
  # Applied to each finished image right before it is written into the batch.
  final_transform_norm = transforms.Normalize((0.1307,), (0.3081,))

In [37]:
def imshow(img, a_title):
    """Draw a channel-first image on the current matplotlib axes.

    Args:
        img: array indexed as (channel, row, column).
        a_title: text passed to ``plt.title``.

    The pixel values are mapped with ``img / 2 + 0.5`` before drawing.
    NOTE(review): that mapping undoes a mean-0.5 / std-0.5 normalization,
    which is not the (0.1307, 0.3081) pair configured in ArgumentsSSD -
    confirm it is what the callers expect.
    """
    rescaled = img / 2 + 0.5
    plt.title(a_title)
    # matplotlib wants (row, column, channel), so move the channel axis last.
    channels_last = np.transpose(rescaled, (1, 2, 0))
    plt.imshow(channels_last)
    
def disp_image(img, title):
    """Show one image tensor in a new 25 x 4 inch figure.

    Args:
        img: tensor offering ``.numpy()``; it is handed to ``imshow``, which
            indexes it as (channel, row, column).
        title: figure title.
    """
    as_array = img.numpy()
    plt.figure(figsize=(25, 4))
    imshow(as_array, title)
    
def disp_image_grayscale(img, title):
    """Show a single-channel image tensor in a new 25 x 4 inch figure.

    Args:
        img: tensor offering ``.numpy()``; drawn with the "gray" colour map.
        title: figure title. Previously this argument was accepted but never
            shown; it is now applied, matching ``disp_image``.
    """
    pixels = img.numpy()
    plt.figure(figsize=(25, 4))
    plt.title(title)
    plt.imshow(pixels, cmap="gray")

def disp_batch_images(data, labels):
    """Show every image of a batch in a grid that is ten columns wide.

    Args:
        data: batch tensor offering ``.numpy()``; the first axis indexes the
            images and each image is handed to ``imshow``.
        labels: per-image labels; ``labels[idx].item()`` becomes the title of
            image ``idx``.
    """
    images = data.numpy()
    num_images = images.shape[0]
    num_cols = 10
    # Integer ceiling division: add_subplot requires integer grid dimensions,
    # and the float returned by np.ceil is rejected by current matplotlib.
    num_rows = (num_images + num_cols - 1) // num_cols
    fig = plt.figure(figsize=(25, 4))
    for idx in range(num_images):
        fig.add_subplot(num_rows, num_cols, idx + 1, xticks=[], yticks=[])
        imshow(images[idx], labels[idx].item())

def image_save(image, full_path):
  """Write one image tensor to disk with torchvision's ``save_image``.

  Args:
    image: tensor to store; it is passed on unchanged (no rescaling is
      applied here).
    full_path: destination file name including the extension.
  """
  # The former max() computation and pass-through alias were never used.
  save_image(image, full_path)

def save_batch(data, target, epoch, root_path):
  """Store every image of a batch as its own TIFF file.

  Files are named ``batch_img_<epoch>_<n>_<target>.tiff`` where ``n`` is the
  1-based position of the image inside the batch and ``<target>`` is
  ``target[i].item()``.

  Args:
    data: batch whose first axis indexes the images.
    target: per-image target values (e.g. counts).
    epoch: value placed in the ``<epoch>`` slot of the file name.
    root_path: prefix that is concatenated directly with the file name, so
      it has to end with a path separator.
  """
  for img_idx in range(data.shape[0]):
    name_parts = ['batch_img', str(epoch), str(img_idx + 1), str(target[img_idx].item())]
    filename = '_'.join(name_parts) + '.tiff'
    image_save(data[img_idx], root_path + filename)

# Adapted from https://discuss.pytorch.org/t/how-to-classify-single-image-using-loaded-net/1411/2
def image_loader(transformations, image_name):
    """Load an image file and return it as a transformed float tensor.

    Args:
        transformations: callable applied to the opened PIL image; its result
            must offer ``.float()`` (e.g. a pipeline ending in ``ToTensor``).
        image_name: path of the image file.

    Returns:
        The transformed image converted with ``.float()``. No device transfer
        or batch dimension is added here.
    """
    # Image.open is lazy and can keep the file handle open (notably for TIFF
    # files); the context manager releases it once the pixels have been
    # converted into a tensor.
    with Image.open(image_name) as image:
        return transformations(image).float()

def get_pixel_raw(mp_img,mp_row_i,mp_col_i,a_pix):
  """Return the channel values of `mp_img` at (mp_row_i, mp_col_i) as they are.

  `a_pix` is ignored; it exists so that all get_pixel_* functions share one
  signature. The result is obtained by basic indexing of `mp_img`.
  """
  channel_values = mp_img[:, mp_row_i, mp_col_i]
  return channel_values

def get_pixel_green(mp_img,mp_row_i,mp_col_i,a_pix):
    """Return a pure green pixel shaped like the pixel at (mp_row_i, mp_col_i).

    Args:
        mp_img: tensor indexed as (channel, row, column) with at least three
            channels.
        mp_row_i, mp_col_i: position of the pixel inside `mp_img`.
        a_pix: ignored; present so all get_pixel_* functions share a signature.

    Returns:
        A new tensor whose first three channels are (0, 1, 0).
    """
    # Indexing yields a view; clone it so that writing the colour does not
    # overwrite the caller's microplastic image.
    pixel = mp_img[:,mp_row_i,mp_col_i].clone()
    pixel[0] = 0.
    pixel[1] = 1. # 255.
    pixel[2] = 0.
    return pixel

def get_pixel_rand(mp_img,mp_row_i,mp_col_i,a_pix):
    """Return a pixel carrying the colour `a_pix`, shaped like the source pixel.

    Args:
        mp_img: tensor indexed as (channel, row, column) with at least three
            channels.
        mp_row_i, mp_col_i: position of the pixel inside `mp_img`.
        a_pix: sequence whose first three entries are written into the first
            three channels.

    Returns:
        A new tensor whose first three channels equal ``a_pix[0:3]``.
    """
    # Indexing yields a view; clone it so that writing the colour does not
    # overwrite the caller's microplastic image.
    pixel = mp_img[:,mp_row_i,mp_col_i].clone()
    pixel[0] = a_pix[0]
    pixel[1] = a_pix[1]
    pixel[2] = a_pix[2]
    return pixel

In [38]:
def insert_one_mp(image,coords,rand_x,rand_y,mp_img,args_ssd,a_pix):
  """Blend one microplastic into `image` at the offset (rand_y, rand_x).

  Args:
    image: target tensor indexed as (channel, row, column); modified in place.
    coords: (N, 2) array of (row, column) positions inside `mp_img` that
      belong to the microplastic.
    rand_x, rand_y: column / row offset added to every coordinate.
    mp_img: image of the microplastic, indexed as (channel, row, column).
    args_ssd: settings object; only `mode` is read.
    a_pix: colour handed to the pixel getter (used by modes 2 and 4).

  Returns:
    The same `image` object.
  """
  # Modes 1/3 paint green, modes 2/4 paint `a_pix`; any other mode keeps the
  # pixel of the microplastic itself.
  pixel_getters = {
      1: get_pixel_green,
      3: get_pixel_green,
      2: get_pixel_rand,
      4: get_pixel_rand,
  }
  fetch_pixel = pixel_getters.get(args_ssd.mode, get_pixel_raw)

  for point in coords:
    src_row = point[0]
    src_col = point[1]
    colour = fetch_pixel(mp_img, src_row, src_col, a_pix)
    # Translate the template coordinate into the target image.
    dst_row = rand_y + src_row
    dst_col = rand_x + src_col
    # Fixed blend: 35 % of the existing background, 65 % of the new pixel.
    image[:, dst_row, dst_col] = 0.35 * image[:, dst_row, dst_col] + 0.65 * colour

  return image

def insert_one_mask(mask,coords,rand_x,rand_y,mp_mask,args_ssd):
  """Copy the mask pixels of one microplastic into `mask`.

  Args:
    mask: 2-D target mask indexed as (row, column); modified in place.
    coords: (N, 2) array of (row, column) positions inside `mp_mask`.
    rand_x, rand_y: column / row offset added to every coordinate.
    mp_mask: 2-D mask of the microplastic.
    args_ssd: unused; kept so the signature mirrors insert_one_mp.

  Returns:
    The same `mask` object.
  """
  src_rows = coords[:, 0]
  src_cols = coords[:, 1]
  # One indexed assignment replaces the former per-pixel Python loop; every
  # target cell still receives the value of its own source cell.
  mask[src_rows + rand_y, src_cols + rand_x] = mp_mask[src_rows, src_cols]
  return mask

def get_rand_pix_col():
  """Draw a random colour: three values from ``np.random.random_sample``.

  Returns:
    numpy array of length 3 with samples from the half-open interval [0, 1).
  """
  return np.random.random_sample(size=3)


def insert_microplastics(image, mask, all_mp_img, all_mp_mask, count, args_ssd):
  """Paste `count` randomly chosen, randomly rotated microplastics.

  Args:
    image: background tensor indexed as (channel, row, column).
    mask: 2-D mask belonging to `image`; updated in place.
    all_mp_img: list of microplastic image patches.
    all_mp_mask: list of the matching mask patches (same order).
    count: number of microplastics to insert.
    args_ssd: settings object; `mode` is read here and by insert_one_mp.

  Returns:
    Tuple ``(image, mask)`` after all insertions.
  """
  template_total = len(all_mp_mask)
  num_channels, height, width = image.shape
  # Modes below 3 start from an all-ones image instead of the loaded one.
  if (args_ssd.mode < 3):
    image = torch.ones((num_channels, height, width))

  # Converters for the round trip tensor -> PIL (for rotating) -> tensor.
  to_pil = transforms.ToPILImage()
  to_tensor = transforms.ToTensor()

  for _ in range(count):
    # Pick a template.
    template_idx = np.random.randint(0, template_total)
    template_img = all_mp_img[template_idx]
    template_mask = all_mp_mask[template_idx]

    # A degenerate (angle, angle) range makes RandomRotation turn image and
    # mask by exactly the same amount; expand=True grows the canvas so that
    # nothing is cut off and the new corners are filled with 0.
    angle = np.random.randint(-180, 180)
    rotate = transforms.RandomRotation((angle, angle), expand=True, fill=0)
    mp_img = to_tensor(rotate(to_pil(template_img)))
    mp_mask = torch.squeeze(to_tensor(rotate(to_pil(template_mask))))

    # Pixels above 0.6 count as part of the microplastic.
    coords = (mp_mask > 0.6).nonzero()
    lowest_row = torch.max(coords[:, 0])
    rightmost_col = torch.max(coords[:, 1])

    # Offsets are limited so that the largest coordinate stays inside the
    # image (the upper bound of randint is exclusive).
    rand_x = np.random.randint(0, width - rightmost_col)
    rand_y = np.random.randint(0, height - lowest_row)

    # The colour is drawn for every microplastic; insert_one_mp only uses it
    # in the random-colour modes.
    rand_pixel = get_rand_pix_col()
    image = insert_one_mp(image, coords, rand_x, rand_y, mp_img, args_ssd, rand_pixel)
    mask = insert_one_mask(mask, coords, rand_x, rand_y, mp_mask, args_ssd)

  return image, mask

In [39]:
def get_all_microplastics(count_1_path, num_c1):
  """Load all microplastic templates and crop them to their mask extent.

  Every directory entry whose name does not contain 'tiff' is treated as an
  image; its mask is the file of the same name with 'jpeg' replaced by 'tiff'.

  Args:
    count_1_path: directory holding the template images and their masks.
    num_c1: unused; kept for backward compatibility with existing callers.

  Returns:
    Tuple ``(all_imgs, all_masks)`` of equally long lists. ``all_imgs[i]`` is
    the image patch (channel, row, column) and ``all_masks[i]`` the 2-D mask
    patch of template ``i``, both cropped to the bounding box of the mask
    pixels above 0.6.

  Raises:
    ValueError: if a mask contains no pixel above 0.6.
  """
  all_imgs = []
  all_masks = []
  a_transform = transforms.Compose([
                       transforms.ToTensor()]) # not the most efficient, but convenient for now
  # Note: ToTensor normalizes to [0,1]

  # os.listdir returns entries in arbitrary order; sorting keeps the template
  # indices stable between runs and machines.
  image_list = sorted(im for im in os.listdir(count_1_path) if 'tiff' not in im)

  for im_name in image_list:
    img_path = os.path.join(count_1_path, im_name)
    # Rename on the file name only, so a 'jpeg' that happens to occur in the
    # directory part of the path is left untouched.
    mask_path = os.path.join(count_1_path, im_name.replace('jpeg', 'tiff'))
    a_mask_3d = image_loader(a_transform, mask_path)
    # Collapse the channels: a pixel is set if it is set in any channel.
    a_mask, _ = torch.max(a_mask_3d, dim=0)
    an_image = image_loader(a_transform, img_path)

    mask_coords = (a_mask > 0.6).nonzero()
    if mask_coords.shape[0] == 0:
      raise ValueError('mask has no pixel above 0.6: {0}'.format(mask_path))

    # Bounding box of the mask.
    rows = mask_coords[:, 0]
    cols = mask_coords[:, 1]
    left = torch.min(cols).item()
    right = torch.max(cols).item()
    top = torch.min(rows).item()
    bottom = torch.max(rows).item()

    # Store the cropped mask and image patches (bounds are inclusive).
    all_masks.append(a_mask[top:bottom+1, left:right+1])
    all_imgs.append(an_image[:, top:bottom+1, left:right+1])

  return all_imgs, all_masks

In [40]:
def gen_synth_batch(args, all_mp_img, all_mp_mask, batch_num):
  """Generate one batch of semi-synthetic images and save them to disk.

  For every slot of the batch a random background image (and its mask) is
  loaded from ``args.segmentation_raw``, a random number of microplastics is
  inserted, and image and mask are written to
  ``<project_base>/Data/segmentation_training/``.

  Args:
    args: settings object (see ArgumentsSSD).
    all_mp_img: list of microplastic image patches.
    all_mp_mask: list of the matching mask patches.
    batch_num: number of this batch; it becomes part of the output names.

  Returns:
    Tuple ``(data, target)``: ``data`` is a float32 tensor of shape
    (batch_size, 3, crop_height, crop_width) with the normalized (and, if
    ``args.invert_imgs`` is set, inverted) images; ``target`` is an int64
    tensor holding the number of inserted microplastics per image.

  Raises:
    ValueError: if the image directory is empty.
  """
  batch_size = args.batch_size
  transformations = args.data_transforms

  image_path = args.segmentation_raw + 'imgs/'
  mask_path = args.segmentation_raw + 'masks/'
  # Sorted so that a seeded run picks the same files on every machine.
  image_dir_list = sorted(os.listdir(image_path))
  if not image_dir_list:
    raise ValueError('no background images found in {0}'.format(image_path))

  data = torch.zeros((batch_size,3,args.crop_height,args.crop_width),dtype=torch.float32)
  target = torch.zeros((batch_size),dtype=torch.int64)

  for img_i in range(batch_size):
    # Pick a random background image and load it together with its mask.
    rand_img_file = sample(image_dir_list, 1)[0]
    image = image_loader(transformations, image_path + rand_img_file)
    msk_path = mask_path + rand_img_file.replace('jpeg', 'tiff')
    mask_3d = image_loader(transformations, msk_path)
    mask, _ = torch.max(mask_3d, dim=0)

    # NOTE(review): the upper bound of randint is exclusive, so max_count
    # itself is never drawn - confirm whether that is intended.
    a_rand_count = np.random.randint(args.min_count, args.max_count)

    image, mask = insert_microplastics(image, mask, all_mp_img, all_mp_mask, a_rand_count, args)

    # Batch number and image index are separated by '_': without it
    # different pairs map to the same name (batch 1 / image 10 and
    # batch 11 / image 0 both gave '...110') and overwrite each other.
    save_path = (project_base + '/Data/segmentation_training/synth_mp_img'
                 + 'batch_' + str(batch_num) + '_' + str(img_i))
    # The files on disk hold the image before inversion and normalization.
    image_save(image, save_path + '.jpeg')
    image_save(mask, save_path + '.tiff')

    if (args.invert_imgs):
      image = 1.0 - image # invert
    data[img_i,:,:,:] = args.final_transform_norm(image)
    target[img_i] = a_rand_count

  return data, target

In [41]:
def main():
    """Load the microplastic templates once and generate every batch."""
    args_ssd = ArgumentsSSD()

    template_dir = args_ssd.count_1_path
    # Half the number of directory entries - presumably one image plus one
    # mask per template; verify against the folder contents.
    num_c1 = int(len(os.listdir(template_dir)) / 2)
    all_mp_img, all_mp_mask = get_all_microplastics(template_dir, num_c1)

    # The returned tensors are not used here; each batch is written to disk
    # by gen_synth_batch itself.
    for batch_idx in range(args_ssd.num_batches):
        gen_synth_batch(args_ssd, all_mp_img, all_mp_mask, batch_idx)

In [42]:
# Start the data generation only when this code is executed directly
# (script run or notebook cell), not when it is imported as a module.
if __name__ == '__main__':
    main()