-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
72 lines (56 loc) · 2.48 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import cv2
from torch.utils.data import Dataset
import os
import numpy as np
def get_image_mask_pairs(base_folder):
    """Pair every mask file with the image file that shares its base name.

    ``base_folder`` must contain a ``masks`` sub-folder and may contain an
    ``images`` sub-folder. Files are matched on their name without the
    extension, so ``images/cell_01.tif`` pairs with ``masks/cell_01.png``.

    Args:
        base_folder: Directory holding the ``masks`` (and optionally
            ``images``) folders.

    Returns:
        A list of ``(image_path, mask_path)`` tuples ordered by sorted mask
        file name. If several images share a mask's base name, the first one
        in sorted order is used. Masks without a matching image are skipped.
        When there is no ``images`` folder, every mask is paired with itself
        as ``(mask_path, mask_path)``.

    Raises:
        FileNotFoundError: if ``base_folder/masks`` does not exist.
    """
    images_folder = os.path.join(base_folder, 'images')
    masks_folder = os.path.join(base_folder, 'masks')
    has_images = os.path.exists(images_folder)
    if not has_images:
        print(f'Note: No images folder found in {base_folder} we use only mask images if it works')
    # Explicit raise rather than `assert`, so the check still runs under `python -O`.
    if not os.path.exists(masks_folder):
        raise FileNotFoundError(f'No masks folder found in {base_folder}')

    # Index image names by base name once, instead of re-listing the images
    # folder for every mask. setdefault keeps the first name in sorted order,
    # which matches a linear search that stops at the first hit.
    image_by_base = {}
    if has_images:
        for image_name in sorted(os.listdir(images_folder)):
            image_base_name, _ = os.path.splitext(image_name)
            image_by_base.setdefault(image_base_name, image_name)

    pairs = []
    for mask_name in sorted(os.listdir(masks_folder)):
        mask_path = os.path.join(masks_folder, mask_name)
        if not has_images:
            pairs.append((mask_path, mask_path))  # both entries are the mask
            continue
        mask_base_name, _ = os.path.splitext(mask_name)
        image_name = image_by_base.get(mask_base_name)
        if image_name is not None:
            pairs.append((os.path.join(images_folder, image_name), mask_path))
    return pairs
class BasicDataset(Dataset):
    """Map-style dataset that loads image/mask file pairs with OpenCV.

    Each item is a dict ``{'image': ..., 'mask': ...}``:

    * the mask is read as grayscale and binarised to float32 (pixel > 0 -> 1.0);
    * the image is min-max rescaled to uint8 in ``[0, 255]``;
    * both are then passed through ``transforms`` when one is given.

    Args:
        data_sample_pairs: Indexable sequence of ``(image_path, mask_path)``
            pairs (the format produced by ``get_image_mask_pairs``).
        transforms: Optional callable ``transforms(img, mask) -> (img, mask)``.
            When ``None``, the numpy arrays are returned unchanged.
        vanilla_aug: If true, the pair list is repeated ``aug_iter`` times,
            so each pair appears ``aug_iter`` times in the dataset.
        aug_iter: Number of repetitions used when ``vanilla_aug`` is set.
        gen_nc: ``1`` loads the image as single-channel grayscale; any other
            value loads it with OpenCV's default 3-channel color mode.
    """

    def __init__(self, data_sample_pairs, transforms=None, vanilla_aug=False, aug_iter=1, gen_nc=1):
        self.data_sample_pairs = data_sample_pairs
        self.transforms = transforms
        # Not read anywhere in this class; presumably used by subclasses or
        # callers. Confirm before removing.
        self.single_cell_mask_crop_bank = []
        self.gen_nc = gen_nc
        if vanilla_aug:
            # Naive oversampling: repeat the whole pair list `aug_iter` times.
            self.data_sample_pairs = list(data_sample_pairs) * aug_iter

    def __len__(self):
        return len(self.data_sample_pairs)

    @classmethod
    def preprocess(cls, img, mask, transforms):
        """Apply ``transforms(img, mask)``, or return the inputs unchanged if it is None."""
        if transforms is None:
            return img, mask
        tensor_img, tensor_mask = transforms(img, mask)
        return tensor_img, tensor_mask

    @staticmethod
    def _read_image(path, flags):
        """Load ``path`` with ``cv2.imread`` using ``flags``, raising if it cannot be read."""
        data = cv2.imread(path, flags)
        # cv2.imread reports a missing or undecodable file by returning None
        # instead of raising, which would otherwise fail later with an
        # unrelated AttributeError/TypeError.
        if data is None:
            raise FileNotFoundError(f'Could not read image file: {path}')
        return data

    @staticmethod
    def _normalize_to_uint8(img):
        """Min-max rescale ``img`` to uint8 so its minimum maps to 0 and maximum to 255.

        A constant image (zero range) maps to all zeros. Multiplying by 255
        before dividing gives the exact floor of ``(v - min) * 255 / range``
        for 8-bit inputs. Dividing by ``range + 1e-6`` and truncating would
        push exact integer levels down by one, leaving 255 unreachable.
        """
        img = np.asarray(img, dtype=np.float32)
        lo = img.min()
        span = img.max() - lo  # avoids ndarray.ptp(), removed in NumPy 2.0
        if span <= 0:
            return np.zeros(img.shape, dtype=np.uint8)
        scaled = (img - lo) * 255.0 / span
        return np.clip(scaled, 0, 255).astype(np.uint8)

    def __getitem__(self, idx):
        img_file, mask_file = self.data_sample_pairs[idx]
        # Binary float32 mask: any non-zero grayscale pixel counts as foreground.
        mask = (self._read_image(mask_file, cv2.IMREAD_GRAYSCALE) > 0).astype('float32')
        if self.gen_nc == 1:  # single-channel input
            raw = self._read_image(img_file, cv2.IMREAD_GRAYSCALE)
        else:  # 3-channel input (OpenCV loads color images in BGR order)
            raw = self._read_image(img_file, cv2.IMREAD_COLOR)
        img = self._normalize_to_uint8(raw.astype('float32'))
        tensor_img, tensor_mask = self.preprocess(img, mask, self.transforms)
        return {
            'image': tensor_img,
            'mask': tensor_mask,
        }