In [1]:
from tqdm.notebook import trange, tqdm

import torch
import torch.utils.data as data

import os
from glob import glob
import os.path as osp

import numpy as np
from skimage import io

from skimage.measure import label, regionprops_table
import pandas
from PIL import Image


In [2]:
class DataGenDataset(data.Dataset):

    def __init__(self,
                 root_fp='/media/miro/4tb_cached/usg_data/base_dataset/',
                 mask_dir_name="lung_mask_ss",
                 transform=None,
                 number_of_frames=32,
                 step = 16,
                 load2ram = False):
        super(DataGenDataset, self).__init__()
        self.transform = transform
        self.number_of_frames = number_of_frames
        self.step = step

        self.image_list = sorted(glob(os.path.join(root_fp, 'frames/*/*/*.png')))
        self.mask_list = [fp.replace("/frames/","/"+mask_dir_name+"/") for fp in self.image_list]

        self.classes, self.identifiers, self.frame_list, self.data, self.mask_data = \
            list(), list(), list(), dict(), dict()

        nos = len(self.image_list)
        idx = 0
        while idx < nos:
            if idx + self.number_of_frames < nos:
                if os.path.dirname(self.image_list[idx]) in os.path.dirname(self.image_list[idx + self.number_of_frames]):
                    self.frame_list += [self.image_list[idx:idx + self.number_of_frames]]
                    self.identifiers += [os.path.dirname(self.image_list[idx])]
            idx = idx+self.step

        for idx in tqdm(range(nos)):
            if "lung_sliding_absence" in self.image_list[idx]:
                self.classes += [1]
            else:
                self.classes+= [0]

            if load2ram:
                self.data[self.image_list[idx]] = io.imread(self.image_list[idx])
                self.mask_data[self.image_list[idx]] = io.imread(self.mask_list[idx])
            else:
                self.data[self.image_list[idx]] = Image.open(self.image_list[idx])
                self.mask_data[self.image_list[idx]] = Image.open(self.mask_list[idx])


    def __getitem__(self, index):
        frames = self.frame_list[index]

        img = np.zeros((self.number_of_frames,480,480))
        mask = np.zeros((256,256))

        for idx, frame in enumerate(frames):
            img[idx,:,:] = np.array(self.data[frame])

        mask[:,:] = self.mask_data[frames[self.number_of_frames//2]]

        if self.transform:
            img = self.transform(img)

        return img, mask, self.frame_list[index][0]

    def __len__(self):
        return len(self.frame_list)

    def get_class_vector(self, indices):
        im_classes = []
        for index in indices:
            if "lung_sliding_absence" in self.image_list[index]:
                im_classes += [1]
            else:
                im_classes += [0]
        return np.array(im_classes)

In [3]:
from skimage.measure import label, regionprops_table
from skimage.morphology import binary_opening,rectangle
from skimage import transform
import matplotlib.pyplot as plt

In [4]:
steps = [8]
nofs = [8]
for idx,nof in enumerate(nofs):
    dataset = DataGenDataset(root_fp = "/home/jin/Documents/miro/data/base_dataset/",number_of_frames=nof,step=steps[idx],load2ram= True)
    for i in tqdm(range(len(dataset))):
        img,mask,path = dataset.__getitem__(i)
        new_path = path.replace("/frames/","/mmode/"+str(dataset.number_of_frames)+"/")
        if not os.path.exists(os.path.dirname(new_path)):
            os.makedirs(os.path.dirname(new_path))

        im  = binary_opening(transform.resize(mask,(480,480)),rectangle(5,5))
        #plt.imshow(im)
        #plt.figure()
        label_image = label(im)
        if np.max(label_image)>0:
            #plt.subplot(122)
            #plt.imshow(label_image)
            regions = regionprops_table(label_image,properties=('label', 'bbox'))
            #plt.axvline(regions['bbox-1'][0])
            #plt.axvline(regions['bbox-3'][0])

            slices = np.arange(regions['bbox-1'][0],regions['bbox-3'][0])
            for slice in slices:
                mmode = img[:,:,slice].transpose(1,0)
                #plt.subplot(121)
                if np.max(mmode) !=0:
                    io.imsave(new_path.replace(".png","_"+str(slice)+".png"),mmode.astype(np.uint8))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=17386.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2126.0), HTML(value='')))

  io.imsave(new_path.replace(".png","_"+str(slice)+".png"),mmode.astype(np.uint8))
  io.imsave(new_path.replace(".png","_"+str(slice)+".png"),mmode.astype(np.uint8))
  io.imsave(new_path.replace(".png","_"+str(slice)+".png"),mmode.astype(np.uint8))
  io.imsave(new_path.replace(".png","_"+str(slice)+".png"),mmode.astype(np.uint8))
  io.imsave(new_path.replace(".png","_"+str(slice)+".png"),mmode.astype(np.uint8))
  io.imsave(new_path.replace(".png","_"+str(slice)+".png"),mmode.astype(np.uint8))
  io.imsave(new_path.replace(".png","_"+str(slice)+".png"),mmode.astype(np.uint8))
  io.imsave(new_path.replace(".png","_"+str(slice)+".png"),mmode.astype(np.uint8))
  io.imsave(new_path.replace(".png","_"+str(slice)+".png"),mmode.astype(np.uint8))
  io.imsave(new_path.replace(".png","_"+str(slice)+".png"),mmode.astype(np.uint8))
  io.imsave(new_path.replace(".png","_"+str(slice)+".png"),mmode.astype(np.uint8))
  io.imsave(new_path.replace(".png","_"+str(slice)+".png"),mmode.astype(np.uint8))
  io


