In [1]:
from fastai import *
from fastai.vision import *
# Name of the dataset folder this notebook works on.
datasetname = 'multiframe_001'

# Root of the data tree; every other location below is derived from it.
data_path = Path('/scratch/bpho')
datasets, datasources = data_path/'datasets', data_path/'datasources'
dataset = datasets/datasetname

# Sub-folders of the selected dataset. Presumably: high-res targets ('hr'),
# low-res inputs ('lr') and upsampled low-res inputs ('lr_up') -- confirm
# against the dataset layout.
hr_path, lr_path, lr_up_path = (dataset/sub for sub in ('hr', 'lr', 'lr_up'))

In [2]:
class NpyRawImageList(ImageList):
    "`ImageList` whose `open` loads each item with `np.load`."
    def open(self, fn):
        "Load the array saved at `fn` and wrap it in an `Image` with a new leading axis of size 1."
        raw = np.load(fn)
        # `raw[None]` prepends one axis to whatever shape the stored array has.
        with_leading_axis = raw[None]
        return Image(tensor(with_leading_axis))

class TransformableLists(ItemBase):
    """A list of lists of images that is handled as one item.

    `img_lists` is stored as given; every inner element must expose `.data`
    (used by `data`) and `.apply_tfms` (used by `apply_tfms`).
    """
    def __init__(self, img_lists):
        self.img_lists = img_lists

    def __repr__(self):
        return f'MultiImg: {len(self.img_lists)}'

    @property
    def data(self):
        "Stack the images' tensors: inner lists first, then the outer list, giving one tensor."
        img_data = torch.stack([torch.stack([img.data for img in img_list])
                                for img_list in self.img_lists])
        data = tensor(img_data)
        return data

    def apply_tfms(self, tfms, **kwargs):
        """Apply `tfms` to every image, replace `self.img_lists` with the results and return `self`.

        Only the first image is asked to resolve the transforms; every later
        image is called with `do_resolve=False` so it reuses what the first
        call resolved.

        `do_resolve` may be supplied by the caller (default `True`). It decides
        whether that first image resolves at all, which lets a caller that has
        already resolved the transforms pass `do_resolve=False`.
        """
        # Take `do_resolve` out of kwargs: forwarding it inside **kwargs next to
        # our own explicit `do_resolve=` would raise
        # "TypeError: got multiple values for keyword argument 'do_resolve'".
        do_resolve = kwargs.pop('do_resolve', True)
        save_img_lists = []
        for img_list in self.img_lists:
            save_img_list = []
            for img in img_list:
                new_img = img.apply_tfms(tfms, do_resolve=do_resolve, **kwargs)
                # After the first image, never resolve again.
                do_resolve = False
                save_img_list.append(new_img)
            save_img_lists.append(save_img_list)
        self.img_lists = save_img_lists
        return self
    
class MultiImageDataBunch(ImageDataBunch):
    "`ImageDataBunch` whose batch statistics treat dim 3 of the input batch as the channel axis."
    def batch_stats(self, funcs:Collection[Callable]=None)->Tensor:
        "Grab one batch (validation if available, else training) and apply each reduction in `funcs` per channel"
        reducers = [torch.mean,torch.std] if funcs is None else funcs
        which = DatasetType.Valid if self.valid_dl else DatasetType.Train
        xb = self.one_batch(ds_type=which, denorm=False)[0].cpu()

        def rows_per_channel(t):
            # Move dim 3 to the front, then flatten everything else so each row holds one channel.
            swapped = t.transpose(3,0).contiguous()
            return swapped.view(t.shape[3],-1)

        return [reduce_fn(rows_per_channel(xb), 1) for reduce_fn in reducers]

    def normalize(self, stats:Collection[Tensor]=None, do_x:bool=True, do_y:bool=False)->None:
        "Add a normalize transform built from `stats` (defaults to `batch_stats()`); returns `self` despite the annotation"
        if getattr(self,'norm',False): raise Exception('Can not call normalize twice')
        self.stats = self.batch_stats() if stats is None else stats
        norm_fn, denorm_fn = normalize_funcs(*self.stats, do_x=do_x, do_y=do_y)
        self.norm, self.denorm = norm_fn, denorm_fn
        self.add_tfm(norm_fn)
        return self
    
class MultiImageList(ItemList):
    "`ItemList` whose items are loaded with `np.load` and returned as `TransformableLists`."
    _bunch = MultiImageDataBunch

    def __init__(self, *args, **kwargs):
        # Pure pass-through to the parent constructor.
        super().__init__(*args, **kwargs)

    def open(self, fn):
        """Load the array at `fn` and wrap each `[row, col, :, :]` slice in an `Image`.

        The array is indexed with four subscripts; the first two axes become the
        outer and inner lists, and each 2-D slice gets a leading axis of size 1.
        """
        stack = np.load(fn)
        grid = [[Image(tensor(stack[row, col, :, :][None]))
                 for col in range(stack.shape[1])]
                for row in range(stack.shape[0])]
        return TransformableLists(grid)

    def get(self, i):
        "Fetch item `i` from the parent list and open it."
        return self.open(super().get(i))
    

class MultiToMultiImageList(MultiImageList):
    "`MultiImageList` whose label class (`_label_cls`) is `NpyRawImageList`."
    _label_cls = NpyRawImageList

In [30]:
def map_to_hr(x):
    """Return the path under `hr_path` that mirrors `x`'s location under `lr_path`.

    `Path.relative_to` raises `ValueError` when `x` is not inside `lr_path`.
    NOTE(review): `get_src` defines and uses its own local mapping function, so
    no visible code calls this module-level version.
    """
    relative = x.relative_to(lr_path)
    return hr_path/relative

def get_src(size=128, scale=1):
    """Build the labelled item lists: `.npy` inputs paired with their `hr_path` counterparts.

    Inputs are read from `lr_path` when `scale != 1`, otherwise from
    `lr_up_path`. Each label is the file at the same relative location under
    `hr_path`. `size` is accepted but not used here.
    """
    use_lr_path = lr_path if scale != 1 else lr_up_path

    def lr_to_hr(lr_file):
        # Same relative location, rooted at the high-res folder.
        return hr_path/lr_file.relative_to(use_lr_path)

    lr_items = MultiToMultiImageList.from_folder(use_lr_path, extensions=['.npy'])
    return lr_items.split_by_folder().label_from_func(lr_to_hr)

def get_data(bs, size, scale=1, max_zoom=1.):
    """Create the normalized data bunch for batch size `bs`.

    Inputs are transformed at `size`, targets at `size*scale`, both with the
    transforms from `get_transforms(flip_vert=True, max_zoom=max_zoom)`.
    Normalization is applied to the targets as well (`do_y=True`).
    """
    src = get_src(size, scale=scale)
    tfms = get_transforms(flip_vert=True, max_zoom=max_zoom)
    with_x_tfms = src.transform(tfms, size=size)
    with_xy_tfms = with_x_tfms.transform_y(tfms, size=size*scale)
    bunch = with_xy_tfms.databunch(bs=bs)
    return bunch.normalize(do_y=True)

In [31]:
# Batch size and input size; with scale=4, get_data transforms targets at size*4.
bs, size = 24, 128
data = get_data(bs, size, scale=4)

In [35]:
# Scratch cell: prints a fixed string; no visible cell depends on it.
print('cool')

cool


In [33]:
%%time
# Pull one (input, target) batch from the training loader; the cell magic above times it.
x,y = next(iter(data.train_dl))

CPU times: user 41.4 ms, sys: 690 ms, total: 732 ms
Wall time: 4.32 s


In [36]:
# Show the shapes of the first input item and the first target item in the batch.
x[0].shape, y[0].shape

(torch.Size([2, 3, 1, 128, 128]), torch.Size([1, 512, 512]))