In [1]:
from torch.utils import data
import os
import numpy as np
import sys
sys.path.append('../')
import custom_transforms as tr
import scipy.misc as m
from PIL import Image
from torchvision import transforms# the problem of version not found was solved by reverting to pillow=6.2.1

In [2]:
#this class is based on Cityscapes Dataset in the given repo..
class SaltmarshSegmentation(data.Dataset):
    NUM_CLASSES = 9

    def __init__(self, args,root=r"C:\Users\Jayant\Documents\segPipieline\pytorch-deeplab-xception-master\Data", split="train"):

        self.root = root
        self.split = split
        self.args = args
        self.images = []
        self.masks=[]

        self.set_data_names()

        #self.void_classes = [0]
        self.valid_classes = [0,1,2,3,4,5,6,7,8]
        self.class_names = ['Background','Limonium', 'Spartina', 'Batis', 'Other', 'Spart_dead', \
                            'Juncus', 'Sacricornia', 'Borrichia']

        #self.ignore_index = 255
        #self.class_map = dict(zip(self.valid_classes, range(self.NUM_CLASSES)))

    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, index):
        
        img_path = self.root+"/"+self.split+"/"+self.images[index]
        lbl_path = self.root+"/"+self.split+"/"+self.masks[index]

        _img = Image.open(img_path).convert('RGB')
        _tmp = np.array(Image.open(lbl_path), dtype=np.uint8)
        _target = Image.fromarray(_tmp)

        sample = {'image': _img, 'label': _target}

        if self.split == 'train':
            return self.transform_tr(sample)
        elif self.split == 'val':
            return self.transform_val(sample)
        elif self.split == 'test':
            return self.transform_ts(sample)
    def set_data_names(self):  
        for file in os.listdir(self.root+"/"+self.split+"/"):
            if (file.endswith("mask.png")):
                    self.masks.append(file)
                    s=file.split('_')
                    imgname="_".join(s[:-1])+".jpg"
                    self.images.append(imgname)
        return True

    def transform_tr(self, sample):
        composed_transforms = transforms.Compose([
            tr.RandomHorizontalFlip(),
            tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size, fill=255),
            tr.RandomGaussianBlur(),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])

        return composed_transforms(sample)

    def transform_val(self, sample):

        composed_transforms = transforms.Compose([
            tr.FixScaleCrop(crop_size=self.args.crop_size),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])

        return composed_transforms(sample)

    def transform_ts(self, sample):

        composed_transforms = transforms.Compose([
            tr.FixedResize(size=self.args.crop_size),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])

        return composed_transforms(sample)

In [5]:
saltmash=SaltmarshSegmentation([513,513])

In [6]:
for i in range(0,114): 
    print(saltmash.images[i])
    print(saltmash.masks[i])
    print("-------------------------------------------")

full_2014_Row10_DSC_3516_18.jpg
full_2014_Row10_DSC_3516_18_mask.png
-------------------------------------------
full_2014_Row10_DSC_3516_35.jpg
full_2014_Row10_DSC_3516_35_mask.png
-------------------------------------------
full_2014_Row10_DSC_3516_37.jpg
full_2014_Row10_DSC_3516_37_mask.png
-------------------------------------------
full_2014_Row10_DSC_3516_39.jpg
full_2014_Row10_DSC_3516_39_mask.png
-------------------------------------------
full_2014_Row10_DSC_3516_40.jpg
full_2014_Row10_DSC_3516_40_mask.png
-------------------------------------------
full_2014_Row12_DSC_3692_16.jpg
full_2014_Row12_DSC_3692_16_mask.png
-------------------------------------------
full_2014_Row12_DSC_3692_21.jpg
full_2014_Row12_DSC_3692_21_mask.png
-------------------------------------------
full_2014_Row12_DSC_3692_35.jpg
full_2014_Row12_DSC_3692_35_mask.png
-------------------------------------------
full_2014_Row12_DSC_3692_36.jpg
full_2014_Row12_DSC_3692_36_mask.png
---------------------------

In [7]:
s=saltmash.masks[0].split('_')
imgname="_".join(s[:-1])+".jpg"
imgname

'full_2014_Row10_DSC_3516_18.jpg'

In [8]:
def test_get_data_names(rootdir,data_type,masks=[],images=[]):  
        for file in os.listdir(rootdir+"train"+"/"):
            if(data_type=="masks"):
                if (file.endswith("mask.png")):
                    masks.append(file)
            else : 
                if not (file.endswith("mask.png")):
                    images.append(file)
        return masks if data_type=="masks" else images

In [9]:
#from dataloaders.utils import decode_segmap
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import argparse

parser = argparse.ArgumentParser()
args = parser.parse_args()
args.base_size = 513
args.crop_size = 513

cityscapes_train = SaltmarshSegmentation(args, split='train')
dataloader = DataLoader(cityscapes_train, batch_size=2, shuffle=True, num_workers=2)
for ii, sample in enumerate(dataloader):
    for jj in range(sample["image"].size()[0]):
        img = sample['image'].numpy()
        gt = sample['label'].numpy()
        segmap= np.array(gt[jj]).astype(np.uint8)
        img_tmp = np.transpose(img[jj], axes=[1, 2, 0])
        img_tmp *= (0.229, 0.224, 0.225)
        img_tmp += (0.485, 0.456, 0.406)
        img_tmp *= 255.0
        img_tmp = img_tmp.astype(np.uint8)
        plt.figure()
        plt.title('display')
        plt.subplot(211)
        plt.imshow(img_tmp)
        plt.subplot(212)
        plt.imshow(segmap)
        
        if ii == 1:
            break

usage: ipykernel_launcher.py [-h]
ipykernel_launcher.py: error: unrecognized arguments: -f C:\Users\Jayant\AppData\Roaming\jupyter\runtime\kernel-fc545779-e3c9-4e0e-b134-4c465f5b965a.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
