In [155]:
import os
import numpy as np
import pandas as pd
from PIL import Image
import cv2

In [156]:
# Output folders for the generated fragments. The trailing slash matters:
# file names are appended with plain string concatenation.
images_folder_path, masks_folder_path = 'train_splitted/images/', 'train_splitted/masks/'

In [157]:
def init_folder(path: str):
  '''
  Create directory `path`, including missing parents, if it does not exist yet.

  os.makedirs(exist_ok=True) replaces a separate exists() check, so there is
  no window in which another process can create the folder between the check
  and the creation. Raises FileExistsError if `path` exists but is not a
  directory.
  '''
  os.makedirs(path, exist_ok=True)


# create the image and mask output folders if they do not exist yet
for _output_folder in (images_folder_path, masks_folder_path):
  init_folder(_output_folder)

Вспомогательные функции: имена и пути файлов изображений и масок

In [158]:
def add_lead_zeros(_id: int, width: int = 3) -> str:
  '''
  Left-pad a non-negative id with zeros to at least `width` digits.

  add_lead_zeros(1) -> '001', add_lead_zeros(42) -> '042'. Ids that already
  have `width` or more digits are returned unchanged ('123', '1234').
  Negative ids are returned as plain str(_id), without padding.

  Args:
    _id: integer file index.
    width: minimum number of digits. The default of 3 matches the
      'train_image_NNN.png' naming used by image_name / mask_name.

  Returns:
    The zero-padded id as a string.
  '''
  if _id < 0:
    # negative ids were never padded; keep that behaviour
    return str(_id)
  return str(_id).zfill(width)

def image_name(_id: int):
  '''File name of source image `_id`, e.g. 1 -> 'train_image_001.png'.'''
  padded = add_lead_zeros(_id)
  return 'train_image_' + padded + '.png'

def mask_name(_id: int):
  '''File name of source mask `_id`, e.g. 1 -> 'train_mask_001.png'.'''
  padded = add_lead_zeros(_id)
  return 'train_mask_' + padded + '.png'

def image_path(_id: int, images_dir: str = './train/train/images/'):
  '''
  Path to source image `_id`, e.g. './train/train/images/train_image_001.png'.

  `images_dir` is prepended with plain string concatenation, so it must end
  with '/'.
  '''
  return images_dir + image_name(_id)

def mask_path(_id: int, masks_dir: str = './train/train/masks/'):
  '''
  Path to source mask `_id`, e.g. './train/train/masks/train_mask_001.png'.

  `masks_dir` is prepended with plain string concatenation, so it must end
  with '/'.
  '''
  return masks_dir + mask_name(_id)


Augmentation functions

In [141]:
!pip install imgaug

zsh:1: command not found: pip


In [173]:
import random
import albumentations as A


def rotate_image(image, mask, rotate_angle, scale_limit: float = 0.3, p: float = 0.7):
  '''
  Apply one random shift/scale/rotate jointly to an image and its mask.

  The mask goes through the same geometric transform call as the image,
  so the pair stays aligned.

  Args:
    image: PIL image (anything np.array accepts).
    mask: PIL mask aligned with `image`.
    rotate_angle: forwarded to A.ShiftScaleRotate as `rotate_limit`.
      NOTE(review): per the albumentations docs a scalar limit defines a
      range, so the applied angle is sampled from roughly
      (-|rotate_angle|, |rotate_angle|) rather than being exactly
      `rotate_angle`. Confirm this is intended.
    scale_limit: forwarded as `scale_limit` (default 0.3, the value that
      used to be hard-coded).
    p: probability that the transform is applied at all (default 0.7, the
      value that used to be hard-coded). With p < 1 some calls leave the
      pair untransformed.

  Returns:
    (rotated_image, rotated_mask) as PIL images.
  '''
  transform = A.ShiftScaleRotate(border_mode=cv2.BORDER_CONSTANT,
                                 scale_limit=scale_limit,
                                 rotate_limit=rotate_angle,
                                 p=p)
  augmented = transform(image=np.array(image), mask=np.array(mask))
  return Image.fromarray(augmented['image']), Image.fromarray(augmented['mask'])


def transform_image(image, mask=None):
  '''
  Randomly augment an image fragment.

  Without `mask`, only pixel-value (photometric) transforms are applied:
  Gaussian blur, equalization, brightness/contrast and gamma.

  Spatial transforms (horizontal flip, optical distortion) move pixels.
  Applied to the image alone, they would misalign it with its mask, so they
  are used only when `mask` is given, and then on image and mask together
  in one call.

  The global `random` module is deliberately not reseeded here. Reseeding
  on every call resets global RNG state, and with albumentations versions
  that sample from `random` it makes every call draw identical augmentation
  decisions. Seed once, up front, if reproducibility is needed.

  Args:
    image: PIL image (anything np.array accepts).
    mask: optional PIL mask aligned with `image`.

  Returns:
    The augmented PIL image when `mask` is None; otherwise a tuple
    (augmented_image, augmented_mask).
  '''
  photometric = [
      A.GaussianBlur(p=0.5),
      A.Equalize(p=0.5),
      A.RandomBrightnessContrast(p=0.5),
      A.RandomGamma(p=0.5),
  ]

  if mask is None:
    transformed = A.Compose(photometric)(image=np.array(image))
    return Image.fromarray(transformed['image'])

  spatial = [
      A.HorizontalFlip(p=0.5),
      A.OpticalDistortion(p=0.5),
  ]
  transformed = A.Compose(spatial + photometric)(
      image=np.array(image),
      mask=np.array(mask),
  )
  return Image.fromarray(transformed['image']), Image.fromarray(transformed['mask'])

In [175]:
# logic for dividing the image into fragments
def split_image(image_id: int, k: int = 2, save: bool = False):
  
  '''
  Cut image `image_id` into a k x k grid of equal fragments (no augmentation).

  Fragments are numbered column by column: the outer loop walks the x axis,
  the inner loop the y axis. When the size is not divisible by k, the
  right/bottom remainder pixels are not part of any fragment.

  Args:
    image_id: index passed to image_path.
    k: number of fragments along each side.
    save: if True, each fragment is written to
      images_folder_path + 'image_{image_id}_fragment_{index}_k_{k}.png'
      and is not kept in memory.

  Returns:
    The list of fragment images, or an empty list when `save` is True.
  '''
  source = Image.open(image_path(image_id))
  tile_w = source.size[0] // k
  tile_h = source.size[1] // k

  boxes = [
    (tile_w * col, tile_h * row, tile_w * (col + 1), tile_h * (row + 1))
    for col in range(k)
    for row in range(k)
  ]

  fragments = []
  for image_index, box in enumerate(boxes):
    fragment = source.crop(box)
    if not save:
      fragments.append(fragment)
      continue
    _image_name = f'image_{image_id}_fragment_{image_index}_k_{k}.png'
    fragment.save(images_folder_path + _image_name, quality=95)

  return fragments


# logic for dividing the mask into fragments
def split_mask(mask_id: int, k: int = 2, save: bool = False):
  
  '''
  Cut mask `mask_id` into a k x k grid of equal fragments (no augmentation).

  Fragments are numbered column by column: the outer loop walks the x axis,
  the inner loop the y axis. When the size is not divisible by k, the
  right/bottom remainder pixels are not part of any fragment.

  Args:
    mask_id: index passed to mask_path.
    k: number of fragments along each side.
    save: if True, each fragment is written to
      masks_folder_path + 'mask_{mask_id}_fragment_{index}_k_{k}.png'
      and is not kept in memory.

  Returns:
    The list of fragment masks, or an empty list when `save` is True.
  '''
  source = Image.open(mask_path(mask_id))
  tile_w = source.size[0] // k
  tile_h = source.size[1] // k

  boxes = [
    (tile_w * col, tile_h * row, tile_w * (col + 1), tile_h * (row + 1))
    for col in range(k)
    for row in range(k)
  ]

  fragments = []
  for mask_index, box in enumerate(boxes):
    fragment = source.crop(box)
    if not save:
      fragments.append(fragment)
      continue
    _mask_name = f'mask_{mask_id}_fragment_{mask_index}_k_{k}.png'
    fragment.save(masks_folder_path + _mask_name, quality=95)

  return fragments

In [223]:
def get_fragments(image_id: int, rotation, k: int = 2):
  
  '''
  Cut image `image_id` and its mask into a k x k grid, augment every
  fragment and save it.

  Fragments are numbered column by column: the outer loop walks the x axis,
  the inner loop the y axis. For each fragment:
    - the image and mask crops go through rotate_image together, using the
      next value from `rotation`;
    - the image crop alone then goes through transform_image.
  The results are written to:
    images_folder_path + 'image_{image_id}_fragment_{index}_k_{k}.png'
    masks_folder_path + 'mask_{image_id}_fragment_{index}_k_{k}.png'
  When the size is not divisible by k, the right/bottom remainder pixels
  are not part of any fragment.

  Args:
    image_id: index passed to image_path / mask_path.
    rotation: iterator (or any iterable) of angles. One value is consumed
      per fragment, k*k in total. Passing the same iterator to successive
      calls continues where the previous call stopped.
    k: number of fragments along each side.

  Raises:
    ValueError: if the image and mask sizes differ. The crop boxes are
      computed from the image size, so the mask fragments would be
      misaligned.
    StopIteration: if `rotation` runs out of angles.
  '''
  # iter() returns an existing iterator unchanged, so a generator shared
  # across calls keeps its position; a plain list is accepted as well.
  rotation = iter(rotation)

  with Image.open(image_path(image_id)) as base_image, \
       Image.open(mask_path(image_id)) as base_mask:
    if base_image.size != base_mask.size:
      raise ValueError(
        f'image {image_id}: image size {base_image.size} '
        f'!= mask size {base_mask.size}'
      )

    base_image_width, base_image_height = base_image.size
    width = base_image_width // k
    height = base_image_height // k

    index = 0

    for i in range(1, k + 1):
      for j in range(1, k + 1):
        # one box shared by the image and mask crops keeps them aligned
        box = (width * (i - 1), height * (j - 1), width * i, height * j)
        image_crop = base_image.crop(box)
        mask_crop = base_mask.crop(box)

        # one angle per fragment, applied to image and mask together
        _rotation = next(rotation)
        image_crop, mask_crop = rotate_image(image_crop, mask_crop, _rotation)

        # Image-only augmentation; mask_crop is saved as rotated.
        # NOTE(review): only pixel-value transforms are safe at this point.
        # A spatial transform applied to the image alone would misalign it
        # with mask_crop.
        image_crop = transform_image(image_crop)

        _image_name = f'image_{image_id}_fragment_{index}_k_{k}.png'
        _mask_name = f'mask_{image_id}_fragment_{index}_k_{k}.png'
        image_crop.save(images_folder_path + _image_name, quality=95)
        mask_crop.save(masks_folder_path + _mask_name, quality=95)

        index += 1

Сгенерируем датасет для 20 картинок: каждая картинка и её маска делится на сетку k × k фрагментов, k = 1…10

In [220]:
def generate_fragments(image_id: int, n_splits: int = 10, save: bool = False):
  
  '''
  Split image `image_id` and its mask into k x k fragments for every
  k = 1..n_splits, without augmentation.

  Args:
    image_id: index passed to split_image / split_mask.
    n_splits: largest grid size k to generate.
    save: forwarded to split_image / split_mask. When True, the fragments
      are written to images_folder_path / masks_folder_path.

  Returns:
    A list of (k, image_crops, mask_crops) tuples, one per k. The crop
    lists are empty when `save` is True, because split_image / split_mask
    do not keep the crops they save.
  '''
  # The fragments are now returned to the caller instead of being discarded.
  # Tuple items are evaluated left to right, so each k still runs
  # split_image before split_mask.
  return [
    (k, split_image(image_id, k, save=save), split_mask(image_id, k, save=save))
    for k in range(1, n_splits + 1)
  ]
    


def generate_fragments_transformed(image_id: int, n_splits: int = 10, seed: int = 42):
  
  '''
  Generate and save augmented fragments of image `image_id` (and its mask)
  for every grid size k = 1..n_splits, using get_fragments.

  The global `random` module is seeded with `seed`. Then n_splits**3
  rotation angles are drawn up front, each uniform over the integers
  -180..180 inclusive. That is at least the sum of k*k over
  k = 1..n_splits, the number of fragments that consume an angle, so the
  shared iterator never runs dry.

  Args:
    image_id: index of the source image/mask pair.
    n_splits: largest grid size k.
    seed: seed for `random` (default 42, the value that used to be
      hard-coded).
      NOTE(review): this reseeds the global `random` module, which also
      affects any later code that draws from it.
  '''
  random.seed(seed)
  rotations = [random.randint(-180, 180) for _ in range(n_splits**3)]

  # one iterator shared across all k, so every fragment gets the next angle
  rotation_generator = iter(rotations)

  for k in range(1, n_splits + 1):
    get_fragments(image_id, rotation_generator, k)


Сгенерируем фрагменты для картинки с индексом 001

In [None]:
N_SPLITS = 10   # largest grid size k; fragments are generated for k = 1..N_SPLITS
IMAGE_ID = 1    # source pair train_image_001.png / train_mask_001.png

generate_fragments_transformed(IMAGE_ID, n_splits=N_SPLITS)

In [37]:
import multiprocessing
from multiprocessing import Pool
# number of CPUs reported by multiprocessing.cpu_count();
# presumably intended for sizing a multiprocessing Pool — TODO confirm
num_cores = multiprocessing.cpu_count()
num_cores

12

In [None]:
# def process_image(arg):
#     image_path, mask_path, i = arg
#     x = Image.open(image_path)
#     y = Image.open(mask_path)
#     for j in range(2):
#         transform = A.Compose([
#             A.HorizontalFlip(p=0.5),
#             A.ShiftScaleRotate(border_mode=cv2.BORDER_CONSTANT, 
#                                 scale_limit=0.3,
#                                 rotate_limit=(10, 30),
#                                 p=0.7),
#             # A.GridDistortion(p=0.5),
#             A.OpticalDistortion(p=0.5),
#             A.GaussianBlur(p=0.5),
#             A.Equalize(p=0.5),
#             A.RandomBrightnessContrast(p=0.5),
#             A.RandomGamma(p=0.5)
#         ])
#         transformed = transform(image=np.array(x), mask=np.array(y))

#         image_trans = transformed['image']
#         mask_trans = transformed['mask']
#         x = Image.fromarray(image_trans)
#         y = Image.fromarray(mask_trans)
#         x.save(f'./input2/{i}v{j}.jpg')
#         y.save(f'./Output2/{i}v{j}.png', 'PNG')

# if __name__ == '__main__':
#     img = sorted([str(os.path.join(dp, f)) for dp, dn, filenames in os.walk(X_path) for f in filenames if os.path.splitext(f)[1] == '.png' or os.path.splitext(f)[1] == '.jpg'])
#     mask = sorted([str(os.path.join(dp, f)) for dp, dn, filenames in os.walk(Y_path) for f in filenames if os.path.splitext(f)[1] == '.png' or os.path.splitext(f)[1] == '.jpg'])
#     args_list = [(image_path, mask_path, i) for i, (image_path, mask_path) in enumerate(zip(img, mask))]
#     with Pool(processes=multiprocessing.cpu_count()) as pool:
#         list(tqdm(pool.imap(process_image, args_list), total=len(args_list)))