In [None]:
import torch
from torchvision.io.image import read_image,ImageReadMode
from torch.utils.data import DataLoader, Dataset, ConcatDataset
import torchvision.transforms.functional as F
import os
import torchvision.transforms as T
import albumentations as A

In [None]:
segm_dataset_dir = os.path.join("data", "new_segmentation_data")
class ExpertSegmentationDataset(Dataset):
    """Expert-annotated segmentation dataset.

    Reads grayscale images from ``<root>/data/trans/<name>`` and RGB-coded
    reference segmentations from ``<root>/references/trans/<name>``, where
    ``root`` defaults to the module-level ``segm_dataset_dir``.

    Each item is a pair ``(img, inds)``:
      * ``img``  -- float tensor 1xHxW (uint8 input is scaled to [0, 1]),
      * ``inds`` -- tensor HxW of class indices (see ``mapping``).
    The layout is the same whether or not an albumentations transform is given.
    """

    # RGB colour of each class; the row index is the class index:
    # 0 (black)=background, 1 (red)=wall, 2 (green)=plaque, 3 (blue)=lumen
    mapping = torch.tensor([(0, 0, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255)], dtype=torch.uint8)
    # sentinel for pixels whose colour matches no class (never returned)
    missing = 255
    # fixed crop region (top, left, height, width) applied when `crop=True`
    crop_box = (120, 185, 595, 645)

    @classmethod
    def rgb_to_index(cls, seg):
        """Convert an RGB-coded segmentation to class indices.

        Args:
            seg: tensor of shape 3xHxW whose pixels all have a colour listed
                in ``cls.mapping``.

        Returns:
            uint8 tensor of shape 1xHxW with the class index of every pixel.

        Raises:
            AssertionError: if ``seg`` does not have 3 channels, or contains a
                colour that is not listed in ``cls.mapping``.
        """
        assert seg.shape[0] == 3, f"expected a 3xHxW RGB segmentation, got shape {tuple(seg.shape)}"
        mapping = cls.mapping.to(seg.device)
        inds = torch.full(seg.shape[1:], cls.missing, dtype=torch.uint8, device=seg.device)
        for k, colour in enumerate(mapping):  # go over all classes
            mask = (seg == colour[:, None, None]).all(0)  # pixels of class k
            inds[mask] = k
        assert (inds != cls.missing).all(), "segmentation contains colours not listed in `mapping`"
        return inds.unsqueeze(0)

    @classmethod
    def inds2rgb(cls, inds):
        """Convert class indices back to an RGB-coded segmentation.

        This is the inverse of ``rgb_to_index``.

        Args:
            inds: integer tensor of shape HxW or 1xHxW with class indices.

        Returns:
            uint8 tensor of shape 3xHxW. Pixels whose index is not a class in
            ``cls.mapping`` are left black.

        Raises:
            ValueError: if ``inds`` is not HxW or 1xHxW.
        """
        if inds.dim() == 3 and inds.shape[0] == 1:
            inds = inds[0]  # accept the 1xHxW layout produced by rgb_to_index
        if inds.dim() != 2:
            raise ValueError(f"expected HxW or 1xHxW indices, got shape {tuple(inds.shape)}")
        mapping = cls.mapping.to(inds.device)
        y = torch.zeros((3, *inds.shape), dtype=torch.uint8, device=inds.device)
        for k in range(mapping.shape[0]):
            # y[:, mask] is 3xN; the 3x1 colour column broadcasts over N pixels
            y[:, inds == k] = mapping[k].unsqueeze(1)
        return y

    @classmethod
    def crop(cls, img):
        """Crop ``img`` (...xHxW) to the fixed region ``cls.crop_box`` = (top, left, height, width)."""
        top, left, height, width = cls.crop_box
        return F.crop(img, top, left, height, width)

    def __init__(self, names, atransform=None, crop=True, root=None):
        """
        Args:
            names: list of image filenames, looked up in the ``data/trans`` and
                ``references/trans`` sub-directories of the dataset root.
            atransform: optional albumentations transform, applied jointly to
                the image and the index mask so that both stay aligned.
            crop: if True, crop image and segmentation with ``crop`` before
                any transform.
            root: dataset directory; if None, the module-level
                ``segm_dataset_dir`` is used (looked up at read time).
        """
        super().__init__()
        self.names = names
        self.atransform = atransform
        self.do_crop = crop
        self.root = root

    def __getitem__(self, index):
        """Return ``(img, inds)`` for the ``index``-th filename.

        ``img`` is a float tensor 1xHxW (uint8 pixels scaled to [0, 1]) and
        ``inds`` is an HxW tensor of class indices, with or without ``atransform``.

        Raises:
            ValueError: if image and segmentation sizes differ.
        """
        filename = self.names[index]
        root = self.root if self.root is not None else segm_dataset_dir
        img = read_image(os.path.join(root, 'data/trans', filename), ImageReadMode.GRAY)
        seg = read_image(os.path.join(root, 'references/trans', filename), ImageReadMode.RGB)
        if self.do_crop:
            img = self.crop(img)
            seg = self.crop(seg)
        if img.shape[-2:] != seg.shape[-2:]:
            raise ValueError(
                f"{filename}: image size {tuple(img.shape[-2:])} does not match "
                f"segmentation size {tuple(seg.shape[-2:])}")
        inds = self.rgb_to_index(seg)
        assert img.shape[0] == 1
        assert inds.shape[0] == 1

        img_np = img[0].numpy()  # HxW
        inds = inds[0]           # HxW
        if self.atransform is not None:
            # transform image and segmentation together so they stay aligned
            transformed = self.atransform(image=img_np, mask=inds.numpy())
            img_np = transformed['image']
            inds = torch.as_tensor(transformed['mask'])
        # ToTensor restores the channel axis (1xHxW) and scales uint8 to [0, 1]
        img = T.ToTensor()(img_np)
        return (img, inds)

    def __len__(self):
        return len(self.names)


In [None]:
from segmentation.common.visualization import plot_image_label

# Test-time augmentation: only a deterministic 512x512 centre crop.
atransformations_test = A.Compose([A.CenterCrop(height=512, width=512)])

# Sanity check: load one annotated frame and plot the image with its label.
ds = ExpertSegmentationDataset(
    names=["p_201501061025140196VAS.png"],
    atransform=atransformations_test,
)
img, label = ds[0]
plot_image_label(img, label)


In [None]:
""