In [9]:
import os
from glob import glob

import cv2
import numpy as np
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import StratifiedKFold


import sys
sys.path.append('/root/dacon-epitope/dacon-sem')

from src.datasets.sem_dataset import SEMDataset
from src.datasets.sem_datamodule import SEMDataModule

In [7]:
# Root directory of the dataset; every pattern below is built relative to it.
data_path = os.path.abspath('/shared/Samsung/')

# Sorted array of all simulated SEM image paths (two directory levels below SEM/).
simulation_sem_paths = np.array(sorted(glob(
    os.path.join(data_path, 'simulation_data', 'SEM', '*', '*', '*.png')
)))

# Every depth path is listed twice before sorting, so the duplicates end up
# adjacent in the final array.
# NOTE(review): presumably each depth map is shared by two SEM images, which is
# what would keep the SEM and depth arrays index-aligned — confirm against the
# dataset layout before relying on it.
simulation_depth_paths = os.path.join(data_path, 'simulation_data', 'Depth', '*', '*', '*.png')
simulation_depth_paths = np.array(sorted(2 * glob(simulation_depth_paths)))
data_len = len(simulation_sem_paths)

# All samples carry the same dummy label (0); the splitter is only used to get
# five shuffled, seeded folds over the sample indices.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
splitlist = list(skf.split(range(data_len), [0] * data_len))

# Only the first fold is used: element 0 holds (train indices, validation indices).
train_index, valid_index = splitlist[0]

# The same index arrays select from both path arrays, so SEM/depth pairs stay together.
train_sem_paths, train_depth_paths = (
    simulation_sem_paths[train_index],
    simulation_depth_paths[train_index],
)
valid_sem_paths, valid_depth_paths = (
    simulation_sem_paths[valid_index],
    simulation_depth_paths[valid_index],
)

# Test SEM images live directly under test/SEM/; no depth paths are collected for them here.
test_sem_paths = np.array(sorted(glob(
    os.path.join(data_path, 'test', 'SEM', '*.png')
)))

In [None]:
def get_transform(resize):
    """Build the albumentations pipelines for input images and for labels.

    Args:
        resize: Indexable whose items 0 and 1 are forwarded, in that order,
            as the first two positional arguments of ``A.Resize``.

    Returns:
        A ``(transform, label_transform)`` tuple. ``transform`` applies
        ``A.Normalize(mean=[0.5], std=[0.5])``, then the resize, then
        ``ToTensorV2``. ``label_transform`` applies only the resize and
        ``ToTensorV2`` (no normalisation).
    """
    size_0, size_1 = resize[0], resize[1]

    # Labels: resize and tensor conversion only.
    label_steps = [
        A.Resize(size_0, size_1, always_apply=True),
        ToTensorV2(),
    ]
    label_transform = A.Compose(label_steps)

    # Images: normalise first, then the same resize and tensor conversion.
    image_steps = [
        A.Normalize(mean=[0.5], std=[0.5]),
        A.Resize(size_0, size_1, always_apply=True),
        ToTensorV2(),
    ]
    transform = A.Compose(image_steps)

    return transform, label_transform

In [92]:
# Take the first training pair as a sample for trying out augmentations.
sem_sample_path, depth_sample_path = train_sem_paths[0], train_depth_paths[0]

# Both files are loaded as single-channel (grayscale) images.
sem_img = cv2.imread(sem_sample_path, cv2.IMREAD_GRAYSCALE)
depth_img = cv2.imread(depth_sample_path, cv2.IMREAD_GRAYSCALE)

# Experimental augmentation pipeline: fixed resize, random flips, a small
# random shift/scale/rotate, normalisation, and tensor conversion.
augmentation_steps = [
    A.Resize(96, 64, always_apply=True),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=5, p=0.5),
    # A.GaussNoise(var_limit=(10.0, 50.0), p=1),
    A.Normalize(mean=[0.5], std=[0.5]),
    ToTensorV2(),
]
transform = A.Compose(augmentation_steps)

# Passing the depth map as `mask` sends it through the same pipeline call as the image.
transformed = transform(image=sem_img, mask=depth_img)
transformed_image, transformed_mask = transformed['image'], transformed['mask']

# Original SEM image plus both transformed outputs, collected for inspection.
imgs = [sem_img, transformed_image, transformed_mask]

# Disabled display helper, kept for reference:
# def display_images(imgs):
#     for i in range(len(imgs)):
#         plt.subplot(1, len(imgs), i+1)
#         plt.imshow(imgs[i])
#     plt.title(os.path.basename(sem_sample_path))

# display_images(imgs)

In [97]:
# Shape of the transformed mask after inserting a new leading axis; the recorded
# output below shows this matches the shape reported for the transformed image.
transformed_mask.unsqueeze(dim=0).shape

torch.Size([1, 96, 64])

In [94]:
# Shape of the transformed SEM image tensor (see the recorded output below).
transformed_image.shape

torch.Size([1, 96, 64])

In [None]:
# Build the image/label pipelines for a [96, 64] target size. This rebinds
# `transform`, replacing the experimental augmentation pipeline from the demo cell.
transform, label_transform = get_transform([96, 64])

# Only the first tenth of the training fold is used for the training dataset.
small_len = len(train_sem_paths) // 10
data_train = SEMDataset(
    train_sem_paths[:small_len],
    train_depth_paths[:small_len],
    transform,
    label_transform,
)

# The test dataset is given None in place of depth paths.
data_test = SEMDataset(
    test_sem_paths,
    None,
    transform,
    label_transform,
)