In [None]:
import os

import cv2
import torch
import albumentations as A
from matplotlib import pyplot as plt


In [None]:
# Root directory with the preprocessed tiles. Set the CORN_TILES_BASE_DIR environment
# variable to point elsewhere instead of editing this machine-specific default path.
TILES_BASE_DIR = os.environ.get("CORN_TILES_BASE_DIR", "/media/przemek/data/corn_data/processed")
# Sub-directories of TILES_BASE_DIR whose tiles are loaded.
SUBDIRECTORIES_TO_PROCESS = [
    "kukurydza_5_ha",
]

# Tiles on disk are UNCROPPED_TILE_SIZE wide; the augmentation pipeline crops them
# down to CROPPED_TILE_SIZE after scaling/rotation.
UNCROPPED_TILE_SIZE = 512 + 256  # in pixels
CROPPED_TILE_SIZE = 512  # in pixels
CROP_TILE_MARGIN = (UNCROPPED_TILE_SIZE - CROPPED_TILE_SIZE) // 2  # per side, in pixels

In [None]:
# Every tile is stored as a pair sharing a prefix: `<prefix>_img.png` and `<prefix>_mask.png`.
IMG_FILE_SUFFIX = '_img.png'
MASK_FILE_SUFFIX = '_mask.png'

tiles_img_paths = []
tiles_mask_paths = []

for dir_name in SUBDIRECTORIES_TO_PROCESS:
    dir_path = os.path.join(TILES_BASE_DIR, dir_name)
    file_names = [f for f in os.listdir(dir_path) if os.path.isfile(os.path.join(dir_path, f))]

    # Classify by the exact suffix (not a substring anywhere in the name), so the prefixes
    # reconstruct exactly the file names joined below.
    img_files_prefixes = {f[:-len(IMG_FILE_SUFFIX)] for f in file_names if f.endswith(IMG_FILE_SUFFIX)}
    mask_files_prefixes = {f[:-len(MASK_FILE_SUFFIX)] for f in file_names if f.endswith(MASK_FILE_SUFFIX)}

    # Symmetric difference: prefixes that exist only as an image or only as a mask.
    missing_files_prefixes = img_files_prefixes ^ mask_files_prefixes
    if missing_files_prefixes:
        raise ValueError(
            f"Some files don't have corresponding pair in mask/image: {sorted(missing_files_prefixes)}"
        )

    # Sorted: iteration order of a set of strings changes between interpreter runs
    # (hash randomization), which would make index -> tile (e.g. tiles_img_paths[433])
    # differ after every kernel restart.
    for file_prefix in sorted(img_files_prefixes):
        tiles_img_paths.append(os.path.join(dir_path, file_prefix + IMG_FILE_SUFFIX))
        tiles_mask_paths.append(os.path.join(dir_path, file_prefix + MASK_FILE_SUFFIX))

print(f'Number of tiles = {len(tiles_img_paths)}')

In [None]:
# Arbitrary tile used to eyeball the data and the augmentations below.
SAMPLE_TILE_IDX = 433
assert 0 <= SAMPLE_TILE_IDX < len(tiles_img_paths), (
    f"SAMPLE_TILE_IDX={SAMPLE_TILE_IDX} out of range for {len(tiles_img_paths)} tiles"
)
sample_img_path = tiles_img_paths[SAMPLE_TILE_IDX]
sample_mask_path = tiles_mask_paths[SAMPLE_TILE_IDX]

# cv2.imread returns None (it does not raise) for unreadable files, and loads colour
# images in BGR channel order while matplotlib expects RGB - convert once right here.
img_bgr = cv2.imread(sample_img_path)
if img_bgr is None:
    raise FileNotFoundError(f"Could not read image: {sample_img_path}")
img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

mask = cv2.imread(sample_mask_path, cv2.IMREAD_GRAYSCALE)
if mask is None:
    raise FileNotFoundError(f"Could not read mask: {sample_mask_path}")

fig, (ax_img, ax_mask) = plt.subplots(1, 2, figsize=(12, 6))
ax_img.imshow(img)
ax_img.set_title(os.path.basename(sample_img_path))
ax_mask.imshow(mask)
ax_mask.set_title(os.path.basename(sample_mask_path))
plt.show()

In [None]:
# Augmentation pipeline, applied jointly to the image and its mask.
# Uses UNCROPPED_TILE_SIZE / CROPPED_TILE_SIZE from the configuration cell.
transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    # Scale factor drawn from [0.85, 1.15]; even at 0.85 a 768 px tile (~652 px) is still
    # larger than the final CROPPED_TILE_SIZE centre crop.
    A.RandomScale(scale_limit=0.15),
    A.Rotate(limit=90),  # degrees
    # Centre crop instead of a fixed-coordinate A.Crop: RandomScale changes the image size,
    # so fixed coordinates drift off-centre; also A.Crop's x_max/y_max are exclusive, so the
    # previous `... - 1` bounds produced 511x511 samples instead of 512x512.
    A.CenterCrop(height=CROPPED_TILE_SIZE, width=CROPPED_TILE_SIZE),
])

# TODO - color, contrast, gamma, randomShadow, rain

transformed = transform(image=img, mask=mask)
transformed_image = transformed["image"]
transformed_mask = transformed["mask"]

# Every augmented sample must come out at CROPPED_TILE_SIZE, with a matching mask.
assert transformed_image.shape[:2] == (CROPPED_TILE_SIZE, CROPPED_TILE_SIZE), transformed_image.shape
assert transformed_mask.shape[:2] == transformed_image.shape[:2], transformed_mask.shape

fig, (ax_img, ax_mask) = plt.subplots(1, 2, figsize=(12, 6))
ax_img.imshow(transformed_image)
ax_img.set_title("augmented image")
ax_mask.imshow(transformed_mask)
ax_mask.set_title("augmented mask")
plt.show()

mask.shape, transformed_mask.shape


In [None]:
class CornFieldDamageDataset(torch.utils.data.Dataset):
    """Dataset of corn-field tiles paired with their masks, loaded lazily from disk.

    Each item is read on access: the tile as an RGB image and the mask as a
    single-channel grayscale image. If ``transform`` is given it is called as
    ``transform(image=..., mask=...)`` (the albumentations convention) and must
    return a dict with ``"image"`` and ``"mask"`` keys.

    Args:
        img_file_paths: paths to the tile images.
        mask_file_paths: paths to the masks; ``mask_file_paths[i]`` belongs to
            ``img_file_paths[i]``.
        transform: optional joint image/mask transform; ``None`` disables augmentation.

    Raises:
        ValueError: if the two path sequences have different lengths.
    """

    def __init__(self, img_file_paths, mask_file_paths, transform=None):
        # Copies, so later mutation of the caller's lists cannot desynchronise the pairs.
        self.img_file_paths = list(img_file_paths)
        self.mask_file_paths = list(mask_file_paths)
        self.transform = transform
        # Explicit check instead of `assert`, which is stripped when Python runs with -O.
        if len(self.img_file_paths) != len(self.mask_file_paths):
            raise ValueError(
                f"Got {len(self.img_file_paths)} image paths but {len(self.mask_file_paths)} mask paths"
            )

    def __len__(self):
        return len(self.mask_file_paths)

    def __getitem__(self, idx):
        """Load tile ``idx`` and return ``(image, mask)``.

        ``image`` is the RGB array and ``mask`` the grayscale array as returned by
        OpenCV, after the optional transform.

        Raises:
            FileNotFoundError: if either file cannot be read by OpenCV.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path = self.img_file_paths[idx]
        mask_path = self.mask_file_paths[idx]

        # cv2.imread returns None instead of raising on unreadable files.
        img_bgr = cv2.imread(img_path)
        if img_bgr is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        # OpenCV loads colour images as BGR; convert to the conventional RGB order.
        img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        transform = self._get_img_and_mask_tranform()
        if transform is not None:
            transformed = transform(image=img, mask=mask)
            img, mask = transformed["image"], transformed["mask"]
        return img, mask

    def _get_img_and_mask_tranform(self):
        """Return the joint image/mask transform for this dataset, or ``None``."""
        return self.transform

    