# Test Dataloader in dataset.py

In [1]:
# install packages
import os
import json
from torch.utils.data import Dataset
from torchvision.transforms import Compose, Resize, ToTensor
from PIL import Image

Dataset class from earth_obs_seg (Björn)

In [None]:
class SharkBody(Dataset):
    '''
        Image-classification dataset built from a COCO-style annotation file.

        Each item is an (image_tensor, label_index) pair. label_index is the
        position of the annotation's category within meta['categories'].
    '''

    def __init__(self, cfg, split='train'):
        '''
            Constructor. Here, we collect and index the dataset inputs and
            labels.

            Args:
                cfg: mapping with at least 'data_root' (dataset root directory)
                    and 'image_size' (passed unchanged to torchvision's Resize).
                split: 'train' loads train_annotations.json; any other value
                    loads cis_val_annotations.json.
        '''
        self.data_root = cfg['data_root']
        self.split = split
        # Transforms. Here's where we could add data augmentation.
        #  For now, we just resize the images and convert them to torch.Tensor.
        #  NOTE: Resize yields identical dimensions for every image only if
        #  cfg['image_size'] is an (h, w) pair; a single int resizes the shorter
        #  edge and keeps the aspect ratio.
        #  For other transformations see Björn's lecture on August 11.
        self.transform = Compose([
            Resize(cfg['image_size']),
            ToTensor()
        ])

        # index data into a list of [image file name, label index] pairs
        self.data = []

        # load annotation file (the context manager closes the file handle)
        annoPath = os.path.join(
            self.data_root,
            'eccv_18_annotation_files',
            'train_annotations.json' if self.split == 'train' else 'cis_val_annotations.json'
        )
        with open(annoPath, 'r') as f:
            meta = json.load(f)

        # enable filename lookup: maps each image ID to its file name.
        #  If your original images have multiple detections per image, this code assumes
        #  that you've saved each detection as one image that is cropped to the size of the
        #  detection, e.g., via megadetector.
        images = {img['id']: img['file_name'] for img in meta['images']}
        # create custom indices for each category that start at zero. Note: if you already
        #  had indices for each category, they might not match the new indices.
        labels = {cat['id']: idx for idx, cat in enumerate(meta['categories'])}

        # since we're doing classification, we only take the first annotation per image and drop the rest
        images_covered = set()      # all those images for which we have already assigned a label
        for anno in meta['annotations']:
            imgID = anno['image_id']
            if imgID in images_covered:
                continue

            # append image-label pair to data
            imgFileName = images[imgID]
            labelIndex = labels[anno['category_id']]
            self.data.append([imgFileName, labelIndex])
            images_covered.add(imgID)       # make sure image is only added once to dataset

    def __len__(self):
        '''
            Returns the length of the dataset (one entry per annotated image).
        '''
        return len(self.data)

    def __getitem__(self, idx):
        '''
            Returns a single data point at given idx.
            Here's where we actually load the image.

            Returns:
                (img_tensor, label) where img_tensor is the transformed image and
                label is the zero-based category index assigned in __init__.
        '''
        image_name, label = self.data[idx]              # pair stored by __init__

        # load image
        image_path = os.path.join(self.data_root, 'eccv_18_all_images_sm', image_name)
        img = Image.open(image_path).convert('RGB')     # the ".convert" makes sure we always get three bands in Red, Green, Blue order

        # apply the transforms defined in __init__ (resize + ToTensor)
        img_tensor = self.transform(img)

        return img_tensor, label

In [None]:
# test functions
# NOTE(review): `shark_dataset` is not created anywhere in this notebook excerpt; it
#  presumably needs to be a SharkBody instance built first, e.g.
#  shark_dataset = SharkBody(cfg, split='train') -- confirm cfg before running.
print(len(shark_dataset))  # verify __len__
# SharkBody is a classification dataset and defines no get_mask/plot_mask methods,
#  so verify __getitem__ instead: it returns an (image_tensor, label_index) pair.
img_tensor, label = shark_dataset[3]
print(img_tensor.shape, label)