In [1]:
import json
import os

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


In [2]:
class BaseDataset(Dataset):
    """Base dataset that loads one split of a JSON annotation file.

    Subclasses implement ``__getitem__``; this class only loads the
    annotations, selects a split and chooses the image transform.

    Args:
        image_dir: Directory that image paths in the annotations are
            relative to (stored for use by subclasses).
        ann_path: Path to a JSON file whose top-level object maps split
            names (e.g. ``'train'``) to lists of examples.
        split: Key of the split to load from the annotation file. ``'train'``
            selects the augmenting default pipeline; anything else selects
            the deterministic one.
        transform: Optional callable applied to each image. When ``None``
            (or any non-callable value, tolerated for backward
            compatibility), a default torchvision pipeline is chosen
            based on ``split``.

    Raises:
        KeyError: If ``split`` is not a key of the annotation file.
    """

    # Per-channel normalization statistics used by the default pipelines
    # (the same values torchvision documents for ImageNet-pretrained models).
    _NORM_MEAN = (0.485, 0.456, 0.406)
    _NORM_STD = (0.229, 0.224, 0.225)

    def __init__(self, image_dir, ann_path, split, transform=None):
        self.image_dir = image_dir
        self.ann_path = ann_path
        self.split = split
        # Honour a caller-supplied transform (it used to be silently
        # ignored). Non-callables such as a stray ``True`` fall back to the
        # split default so existing call sites keep their behaviour.
        if callable(transform):
            self.transform = transform
        else:
            self.transform = self._default_transform(split)
        # Context manager guarantees the handle is closed; JSON text is UTF-8.
        with open(self.ann_path, 'r', encoding='utf-8') as f:
            self.ann = json.load(f)
        self.examples = self.ann[self.split]

    @classmethod
    def _default_transform(cls, split):
        """Return the default preprocessing pipeline for ``split``.

        ``'train'``: Resize(256), RandomCrop(224), RandomHorizontalFlip.
        Any other split: deterministic Resize((224, 224)).
        Both pipelines end with ToTensor and channel normalization.
        """
        normalize = transforms.Normalize(cls._NORM_MEAN, cls._NORM_STD)
        if split == 'train':
            return transforms.Compose([
                transforms.Resize(256),
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize])
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            normalize])

    def __len__(self):
        """Return the number of examples in the selected split."""
        return len(self.examples)


class IuxrayMultiImageDataset(BaseDataset):
    """IU X-ray dataset in which each example may reference several images.

    Each annotation example must provide ``'id'``, ``'report'`` and
    ``'image_path'``. ``'image_path'`` is a list of paths relative to
    ``image_dir``; a single path string is also accepted.
    """

    def __getitem__(self, idx):
        """Return ``(image_id, images, report)`` for example ``idx``.

        ``images`` is a list with one entry per path in ``image_path``, in the
        same order. Each entry is the RGB-converted image passed through
        ``self.transform``, or the PIL image itself when ``self.transform``
        is ``None``.
        """
        example = self.examples[idx]
        image_id = example['id']
        report = example['report']
        image_paths = example['image_path']
        # A bare string would otherwise be iterated character by character.
        if isinstance(image_paths, str):
            image_paths = [image_paths]
        images = [self._load_rgb(os.path.join(self.image_dir, p)) for p in image_paths]
        if self.transform is not None:
            images = [self.transform(img) for img in images]
        return image_id, images, report

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path`` and return an RGB copy, closing the file."""
        # convert() returns a converted copy, so the result stays usable after
        # the context manager closes the underlying file handle.
        with Image.open(path) as img:
            return img.convert('RGB')

In [5]:
# Dataset root. Set the IU_XRAY_DIR environment variable to point elsewhere
# instead of editing this cluster-specific default.
IU_XRAY_DIR = os.environ.get('IU_XRAY_DIR', '/restricted/projectnb/batmanlab/chyuwang/dataset/iu_xray')
image_dir = os.path.join(IU_XRAY_DIR, 'images/')
ann_path = os.path.join(IU_XRAY_DIR, 'annotation.json')
# No transform argument: the previous positional ``True`` was not a callable
# and had no effect, so the split's default pipeline is used either way.
ds = IuxrayMultiImageDataset(image_dir, ann_path, split='train')