In [1]:
import os
import sys
import random
import warnings

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

from tqdm import tqdm
from itertools import chain
from skimage.io import imread, imshow, imread_collection, concatenate_images
from skimage.transform import resize
from skimage.morphology import label

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torch.nn.functional as F

'''
from google.colab import drive
drive.mount('/content/drive')
'''

# Set some parameters
BATCH_SIZE = 5 # the higher the better
IMG_WIDTH = 512 # for faster computing on kaggle
IMG_HEIGHT = 512 # for faster computing on kaggle
IMG_CHANNELS = 3

#TRAIN_PATH = '/content/drive/My Drive/UNet/Cell'
#TEST_PATH = '/content/drive/My Drive/UNet/Cell'

TRAIN_PATH = './new_data/15'
TEST_PATH = './new_data/15'
CH1_PATH = './ch1_png'
CH2_PATH = './ch2_png'

warnings.filterwarnings('ignore', category=UserWarning, module='skimage')
seed = 42



In [2]:
# Get train and test IDs
train_ids = next(os.walk(TRAIN_PATH))[1]
#test_ids = next(os.walk(TEST_PATH))[1]
num_test_items = int(0.1 * len(train_ids))

# Slice the train_ids list to get the test IDs
test_ids = train_ids[-num_test_items:]
np.random.seed(10)

In [3]:

import re

#print(len(train_ids))
X_train = np.zeros((len(train_ids), IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS), dtype=np.float64)
Y_train = np.zeros((len(train_ids), IMG_HEIGHT, IMG_WIDTH, 1), dtype=np.uint8)
Y_train_1 = np.zeros((len(train_ids), IMG_HEIGHT, IMG_WIDTH, 1), dtype=np.uint8)
Y_train_2 = np.zeros((len(train_ids), IMG_HEIGHT, IMG_WIDTH, 1), dtype=np.uint8)

print('Getting and resizing train images and masks ... ')
sys.stdout.flush()
for n, id_ in tqdm(enumerate(train_ids), total=len(train_ids)):
    
    path = os.path.join(TRAIN_PATH, id_)
    # pattern = r'z\d+'
    # matches = re.finditer(pattern, id_)
    # for match in matches:
    #     start = match.start()
    #     end = match.end()
    #     id_ = id_[:end] + '_mSLIM' + id_[end:]
    img_path = os.path.join(path, 'images', id_ +'_mSLIM'+ '.png')
    if path != "./Cell/Cell/f0_t0_i0_ch0_c15_r44_z0":
        img = imread(img_path)
        img = resize(img, (IMG_HEIGHT, IMG_WIDTH), mode='constant', preserve_range=True)
        img = np.expand_dims(img, axis=-1).astype(np.float64)
        X_train[n] = img.astype(np.float64)

        # mask = np.zeros((IMG_HEIGHT, IMG_WIDTH, 1), dtype=bool)
        # # mask2 = None
        for mask_file in next(os.walk(path + '/masks/'))[2]:
            mask = imread(os.path.join(path, 'masks', mask_file))
            mask = resize(mask, (IMG_HEIGHT, IMG_WIDTH), mode='constant', preserve_range=True)
            if mask_file.endswith('Prot.png'):
                Y_train_2[n] = np.where(mask[:, :, np.newaxis] > 0, 2, 0)
            else:
                Y_train_1[n] = np.where(mask[:, :, np.newaxis] > 0, 1, 0)
        # # mask1_rgb = mask1_color * mask1[:, :, np.newaxis]
        # # mask2_rgb = mask2_color * mask2[:, :, np.newaxis]
        # # background_rgb = background_color * (1 - (mask1 + mask2)[:, :, np.newaxis])
            # mask_path = next(os.walk(path + '/masks/'))[2] 
            # #print(mask_path)
            # mask_count = min(2, len(mask_path))
            # for i in range(mask_count):
            #     mask_file = mask_path[i]  
            #     mask_color = [0, 0, 0]  
            #     if mask_file.startswith('cell'):
            #         mask_color = [255, 255, 0]  
            #     elif mask_file.startswith('dead'):
            #         mask_color = [128, 0, 128] 
            #     else:
            #         mask_color = [255, 255, 0] 
            #     mask = imread(os.path.join(path, 'masks', mask_file))
            #     mask = resize(mask, (IMG_HEIGHT, IMG_WIDTH), mode='constant', preserve_range=True)
            #     mask_rgb = mask_color * (mask[:, :, np.newaxis] > 0)  
            #     Y_train[n, :, :, :] += mask_rgb.astype(np.uint8)
        Y_train[n] = np.maximum(Y_train_1[n], Y_train_2[n])
        


Getting and resizing train images and masks ... 


100%|██████████| 750/750 [07:23<00:00,  1.69it/s]


In [4]:
print(Y_train.shape)
zero = 0
one = 0
two = 0
for array in Y_train[5]:
    for i in array:
        if i == 0:
            zero = zero + 1
        elif i == 1:
            one = one + 1
        else:
            two = two + 1
print(zero)
print(one)
print(two)

(750, 512, 512, 1)
260712
1196
236


In [13]:
print('Processing ch1 labels and updating Y_train ...')
white_threshold = 255  
for n, id_ in tqdm(enumerate(train_ids), total=len(train_ids)):
    ch1_file = id_ + '_mFL#1.png'
    ch1_path = os.path.join(CH1_PATH, ch1_file)
    ch1_path = ch1_path.replace('ch0','ch1')
    #print(ch2_path)
    if os.path.exists(ch1_path):
        ch1_image = imread(ch1_path)
        ch1_image_resized = resize(ch1_image, (IMG_HEIGHT, IMG_WIDTH), mode='constant', preserve_range=True)
        
        ch1_mask = np.all(ch1_image_resized == white_threshold, axis=-1).astype(np.uint8) * 1
        
        ch1_mask = np.expand_dims(ch1_mask, axis=-1)

        Y_train[n] = np.where((ch1_mask == 1) & (Y_train[n] > 0), 1, Y_train[n])
        
print('Done!')

Processing ch1 labels and updating Y_train ...


100%|██████████| 1119/1119 [10:01<00:00,  1.86it/s]

Done!





In [14]:
print(Y_train.shape)
zero = 0
one = 0
two = 0
for array in Y_train[5]:
    for i in array:
        if i == 0:
            zero = zero + 1
        elif i == 1:
            one = one + 1
        else:
            two = two + 1
print(zero)
print(one)
print(two)

(1119, 512, 512, 1)


258335
1824
1985


In [15]:
print('Processing ch2 labels and updating Y_train ...')
white_threshold = 255  
for n, id_ in tqdm(enumerate(train_ids), total=len(train_ids)):
    ch2_file = id_ + '_mFL#2.png'
    ch2_path = os.path.join(CH2_PATH, ch2_file)
    ch2_path = ch2_path.replace('ch0','ch2')
    #print(ch2_path)
    if os.path.exists(ch2_path):
        ch2_image = imread(ch2_path)
        ch2_image_resized = resize(ch2_image, (IMG_HEIGHT, IMG_WIDTH), mode='constant', preserve_range=True)
        
        ch2_mask = np.all(ch2_image_resized == white_threshold, axis=-1).astype(np.uint8) * 2
        
        ch2_mask = np.expand_dims(ch2_mask, axis=-1)

        Y_train[n] = np.where((ch2_mask == 2) & (Y_train[n] > 0), 2, Y_train[n])
        
print('Done!')

Processing ch2 labels and updating Y_train ...


100%|██████████| 1119/1119 [09:49<00:00,  1.90it/s]

Done!





In [16]:
print(Y_train.shape)
zero = 0
one = 0
two = 0
for array in Y_train[5]:
    for i in array:
        if i == 0:
            zero = zero + 1
        elif i == 1:
            one = one + 1
        else:
            two = two + 1
print(zero)
print(one)
print(two)

(1119, 512, 512, 1)
258335
1521
2288


In [5]:
np.save('X.npy',X_train)
np.save('Y.npy',Y_train)