In [1]:
from pathlib import Path
from PIL import Image
import cv2

import numpy as np

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset as BaseDataset
import albumentations as albu

import segmentation_models_pytorch as smp
import segmentation_models_pytorch.utils

from tqdm import tqdm

import random

from cellseg_utils import (
    get_training_augmentation,
    get_validation_augmentation,
    get_preprocessing,
    get_all_fp_data,
    CellDataset4,
    split_image,
    get_squares,
    unsplit_image
)

In [4]:
root_dir = Path('datasets/Cells_2.0_for_Ivan/masked_MSC')
dir01 = root_dir / 'pics 2024-20240807T031703Z-001' / 'pics 2024'
dir02 = root_dir / 'pics 2024-20240807T031703Z-002' / 'pics 2024'

lf_dir = dir01 / 'LF1'

exps_dir_list = list()
for v in lf_dir.iterdir():
    exps_dir_list.append(v.name)
exps_dir_list.sort()
exps_dir_list

['+2024-05-05-LF1-p6-sl2',
 '+2024-05-06-LF1-p12',
 '+2024-05-06-LF1p9-sl2',
 '+2024-05-07-LF1p15',
 '+2024-05-08-LF1p18sl2',
 '+2024-05-31-LF1-p22',
 'test.pt',
 'train.pt',
 'val.pt']

In [5]:
dataset_dir = lf_dir
exp_class_dict = {'+2024-05-05-LF1-p6-sl2': 6,
                  '+2024-05-06-LF1-p12': 12,
                  '+2024-05-06-LF1p9-sl2': 9,
                  '+2024-05-07-LF1p15': 15,
                  '+2024-05-08-LF1p18sl2': 18,
                  '+2024-05-31-LF1-p22': 22
                  }

In [6]:
all_fp_data = get_all_fp_data(dataset_dir, exp_class_dict)
all_fp_data = all_fp_data[:10]
total_len = len(all_fp_data)
train_num = int(total_len * 0.7)
val_num = int(total_len * 0.2)
test_num = total_len - val_num

random.shuffle(all_fp_data)

train_fp_data = all_fp_data[:train_num]
val_fp_data = all_fp_data[train_num:train_num+val_num]
test_fp_data = all_fp_data[train_num+val_num:]

In [7]:
mask_img = Image.open(all_fp_data[0]['mask_fp'])
w, h = mask_img.size[0], mask_img.size[1]
w, h = int(w/2), int(h/2)
square_a = 512
border = 10

ENCODER = 'timm-efficientnet-b0'  # 'resnet101',  # 'efficientnet-b2',  # 'timm-efficientnet-b8',  # 'efficientnet-b0'
ENCODER_WEIGHTS = 'imagenet'

target_size = (w, h)
square_w, square_h = square_a, square_a
square_size = (square_w, square_h)

# full_size, full_size_with_borders, squares = get_squares(target_size,
#                                                          square_size,
#                                                          border)

full_size, squares = None, None
add_shadow_to_img = True

preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
preprocessing_fn = None
preprocessing = get_preprocessing(preprocessing_fn)

train_dataset = CellDataset4(train_fp_data,
                      exp_class_dict,
                      full_size=full_size,
                      add_shadow_to_img=add_shadow_to_img,
                      squares=squares,
                      border=border,
                      channels=None,
                      classes_num=2,
                      augmentation=get_training_augmentation(target_size=target_size),
                      preprocessing=preprocessing,
                      classes=None,
                      target_size=target_size
                      )

val_dataset = CellDataset4(val_fp_data,
                      exp_class_dict,
                      full_size=full_size,
                      add_shadow_to_img=add_shadow_to_img,
                      squares=squares,
                      border=border,
                      channels=None,
                      classes_num=2,
                      augmentation=get_validation_augmentation(target_size=target_size),
                      preprocessing=preprocessing,
                      classes=None,
                      target_size=target_size
                      )

test_dataset = CellDataset4(test_fp_data,
                      exp_class_dict,
                      full_size=full_size,
                      add_shadow_to_img=add_shadow_to_img,
                      squares=squares,
                      border=border,
                      channels=None,
                      classes_num=2,
                      augmentation=get_validation_augmentation(target_size=target_size),
                      preprocessing=preprocessing,
                      classes=None,
                      target_size=target_size
                      )

In [8]:


# torch.save(train_dataset, lf_dir / 'train.pt')
# del train_dataset

100%|█████████████████████████████████████████████| 7/7 [00:05<00:00,  1.29it/s]


In [9]:

# torch.save(val_dataset, lf_dir / 'val.pt')
# del val_dataset

100%|█████████████████████████████████████████████| 2/2 [00:01<00:00,  1.27it/s]


In [10]:
test_dataset = CellDataset4(test_fp_data,
                      exp_class_dict,
                      full_size=full_size,
                      add_shadow_to_img=add_shadow_to_img,
                      squares=squares,
                      border=border,
                      channels=None,
                      classes_num=2,
                      augmentation=get_validation_augmentation(target_size=target_size),
                      preprocessing=preprocessing,
                      classes=None,
                      target_size=target_size
                      )
# torch.save(test_dataset, lf_dir / 'test.pt')
# del test_dataset

100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.30it/s]


In [11]:
# train_dataset = torch.load(lf_dir / 'train.pt')
# val_dataset = torch.load(lf_dir / 'val.pt')
# test_dataset = torch.load(lf_dir / 'test.pt')

In [26]:
# idx = 0
# full_img, full_mask = train_dataset[idx]

# full_img_ = Image.fromarray((full_img[:3].transpose(1, 2, 0)).astype('uint8')).resize(target_size)
# display(full_img_)

# full_img_ = Image.fromarray((np.stack([full_img[-1]]*3, axis=-1)).astype('uint8')).resize(target_size)
# display(full_img_)

# # for idx in range(full_img.shape[0]):
# #     img = Image.fromarray((np.stack([full_img[idx]]*3, axis=-1)).astype('uint8')).resize(target_size)
# #     display(img)
    
# for idx in range(full_mask.shape[0]):
#     img = Image.fromarray((np.stack([full_mask[idx]]*3, axis=-1) * 255).astype('uint8')).resize(target_size)
#     display(img)