In [337]:
from os.path import join as pjoin
import os

import cv2

import matplotlib.pyplot as plt

import numpy as np

from torch.utils.data import DataLoader, Dataset
import albumentations as A

from additonFunc import create_image_plot, uniqufy_path

In [338]:
# Per-channel statistics passed to A.Normalize (the images are converted to RGB
# order when loaded). Presumably computed on this dataset — TODO confirm; the
# trailing comments are the widely used ImageNet defaults.
NORMALIZE_MEAN_IMG = [0.4295, 0.4325, 0.3961]        # [0.485, 0.456, 0.406]
NORMALIZE_DEVIATIONS_IMG = [0.2267, 0.2192, 0.2240]  # [0.229, 0.224, 0.225]

DATASET_DIR = 'dataset/tiff'
# (images_dir, labels_dir) pair; presumably the uncropped test split.
TEST_SET = (pjoin(DATASET_DIR, "test_non_cutted"), pjoin(DATASET_DIR, "test_labels_non_cutted"))

# Output directory for the augmentation preview figures. uniqufy_path is a
# project helper — presumably it returns a not-yet-used variant of the path.
aug_test_path = uniqufy_path("AugmentsTests/aug_tests")
# exist_ok replaces a separate exists() check and avoids its check-then-create race.
os.makedirs(aug_test_path, exist_ok=True)

In [339]:
# Reference pipeline: normalization only, used as the "original" image in the previews.
based_transform = A.Compose(
    [
        A.Normalize(mean=NORMALIZE_MEAN_IMG, std=NORMALIZE_DEVIATIONS_IMG, always_apply=True),
    ]
)


def _normalized_shift_to_pixels(shift_range, channel_std):
    """Convert an additive shift range from normalized units to uint8 pixel levels.

    Normalize maps a pixel x to (x / 255 - mean) / std, so adding s in normalized
    space is the same as adding s * std * 255 to the raw pixel value.
    """
    return tuple(value * channel_std * 255 for value in shift_range)


# Augmentation pipeline under test. The colour augmentations operate on the raw
# uint8 image, and Normalize runs LAST. For float input, albumentations'
# RandomBrightnessContrast clips the result to [0, 1]. Normalizing first would
# therefore wipe out every negative (below-mean) value of the normalized image.
transform = A.Compose(
    [
        # The green/blue shift of 0.1..0.3 was specified in normalized units;
        # it is expressed here in pixel levels so it keeps that magnitude on uint8 input.
        A.RGBShift(
            r_shift_limit=0,
            g_shift_limit=_normalized_shift_to_pixels((0.1, 0.3), NORMALIZE_DEVIATIONS_IMG[1]),
            b_shift_limit=_normalized_shift_to_pixels((0.1, 0.3), NORMALIZE_DEVIATIONS_IMG[2]),
            p=1,
        ),
        # A scalar brightness_limit means the symmetric range (-0.1, 0.1).
        A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=(0, 0.3), p=1),
        A.Normalize(mean=NORMALIZE_MEAN_IMG, std=NORMALIZE_DEVIATIONS_IMG, always_apply=True),
    ]
)

In [340]:
class RoadsDataset(Dataset):
    """Paired image/label dataset read from two directories.

    Files in ``values_dir`` and ``labels_dir`` are paired by position after
    sorting each directory listing by filename, so the i-th image is matched
    with the i-th label. Both are loaded with OpenCV and converted from BGR to RGB.

    Args:
        values_dir: directory containing the input images.
        labels_dir: directory containing the label images.
        class_rgb_values: stored on the instance; not used by this class itself.
        transform: optional albumentations-style callable, invoked as
            ``transform(image=..., mask=...)`` and returning a dict with
            ``'image'`` and ``'mask'`` keys.
        readyToNetwork: optional second callable with the same contract,
            applied after ``transform``.
    """

    def __init__(self, values_dir, labels_dir, class_rgb_values=None, transform=None, readyToNetwork=None):
        self.values_dir = values_dir
        self.labels_dir = labels_dir
        self.class_rgb_values = class_rgb_values
        # NOTE(review): the two listings are not checked to have the same length;
        # a mismatch silently mis-pairs images and labels.
        self.images = [pjoin(self.values_dir, filename) for filename in sorted(os.listdir(self.values_dir))]
        self.labels = [pjoin(self.labels_dir, filename) for filename in sorted(os.listdir(self.labels_dir))]
        self.transform = transform
        self.readyToNetwork = readyToNetwork

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _read_rgb(path):
        """Load ``path`` with OpenCV and return it in RGB channel order.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        # cv2.imread signals failure by returning None instead of raising, which
        # would otherwise surface as an obscure error inside cvtColor.
        bgr = cv2.imread(path)
        if bgr is None:
            raise FileNotFoundError(f"could not read image file: {path}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair at ``index`` after the optional pipelines.

        Raises:
            IndexError: if ``index`` is out of range (this also ends plain iteration).
            FileNotFoundError: if either file cannot be read by OpenCV.
        """
        # Resolve both paths first so an out-of-range index fails before any I/O.
        image_path, label_path = self.images[index], self.labels[index]

        image = self._read_rgb(image_path)
        label = self._read_rgb(label_path)

        # Apply the augmentation pipeline, then the network-preparation pipeline.
        for pipeline in (self.transform, self.readyToNetwork):
            if pipeline is not None:
                sample = pipeline(image=image, mask=label)
                image, label = sample['image'], sample['mask']
        return image, label

In [341]:
# The same test split is loaded twice: once with normalization only (reference)
# and once with the augmentation pipeline. Both list the same directories, so
# index i refers to the same file in each dataset.
based_dataset = RoadsDataset(
    values_dir=TEST_SET[0],
    labels_dir=TEST_SET[1],
    transform=based_transform,
)
sample_dataset = RoadsDataset(
    values_dir=TEST_SET[0],
    labels_dir=TEST_SET[1],
    transform=transform,
)

In [342]:
# Save side-by-side previews (reference vs. augmented) for the first few test images.
N_PREVIEW_IMAGES = 16

# zip() pulls from range() first. Iteration therefore stops after
# N_PREVIEW_IMAGES items without loading an extra sample from either dataset.
for index, (base_image, _), (aug_image, _) in zip(range(N_PREVIEW_IMAGES), based_dataset, sample_dataset):
    fig = create_image_plot(figsize=(12, 6), image=aug_image, origin=base_image)
    fig.savefig(pjoin(aug_test_path, f"image_{index:03}"))
    # Close each figure so memory stays bounded across the loop.
    plt.close(fig)