In [None]:
import fastai
from fastai.vision.all import *
from fastdownload import FastDownload
from models.CycleGan import GeneratorUNet, DiscriminatorUNet
from torchsummary import summary
from IPython.core.debugger import set_trace

In [None]:
# Root directory of the dataset; the download cell and `train_path` are both derived from it.
path = Path('data')
# Register the root on Path.BASE_PATH (presumably so paths are displayed relative to it — confirm).
Path.BASE_PATH = path

In [None]:
# Fetch and unpack the apple2orange dataset unless it has already been extracted.
# The guard checks the extracted dataset folder (the one read later through
# `path/'extracted/apple2orange'`) rather than the archive folder: the archive
# folder can already exist after an interrupted download or extraction, in which
# case the fetch would be skipped forever and the notebook left without data.
if not (path/'extracted'/'apple2orange').exists():
    # `base` is the download root; `data='extracted'` names the sub-folder archives are unpacked into.
    loader = FastDownload(base=path.name, data='extracted', module=fastai.data)
    loader.get("https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/apple2orange.zip")

In [None]:
# Instantiate the generator network with input_shape [3, 128, 128]
# (presumably channels, height, width — confirm against GeneratorUNet).
# NOTE(review): the exact meaning of `filters`, the down-/up-sampling kernel
# sizes and the strides is defined in models/CycleGan.py, which is not visible here.
generator = GeneratorUNet(input_shape=[3, 128, 128],
                          filters=[32, 64, 128, 256],
                          ds_kernel_size=5,
                          us_kernel_size=5,
                          ds_stride=2,
                          us_stride=1
                         )

In [None]:
# Print a summary of the generator for an input of size (3, 128, 128), the same
# shape it was constructed with. `depth=4` presumably sets how many nesting
# levels of sub-modules are listed — confirm against the installed torchsummary.
summary(generator, (3, 128, 128), depth=4)

In [None]:
# Instantiate the discriminator with the same input_shape and filter list as the generator.
# NOTE(review): the meaning of `filters` and `kernel_size` is defined in
# models/CycleGan.py, which is not visible here.
discriminator = DiscriminatorUNet(input_shape=[3, 128, 128],
                                  filters=[32, 64, 128, 256],
                                  kernel_size=5
                                 )

In [None]:
# Print a summary of the discriminator for an input of size (3, 128, 128).
summary(discriminator, (3, 128, 128))

In [None]:
# Folder of the unpacked dataset; 'extracted' is the data folder name given to FastDownload above.
train_path = path/'extracted/apple2orange'

In [None]:
# Idea from Siamese tutorial: https://docs.fast.ai/tutorial.siamese.html
class ImageTuple(fastuple):
    """Two images carried around as a single item."""

    @classmethod
    def create(cls, paths):
        """Open every entry of `paths` with `PILImage.create` and wrap the results in an `ImageTuple`."""
        images = tuple(PILImage.create(p) for p in paths)
        return cls(images)

    def show(self, ctx=None, **kwargs):
        """Draw both elements next to each other, joined along the last dimension.

        Nothing is drawn, and `ctx` is returned unchanged, unless both elements
        are tensors with identical shapes.
        """
        left, right = self
        showable = isinstance(left, Tensor) and isinstance(right, Tensor) and left.shape == right.shape
        if not showable:
            return ctx
        # Zero-filled separator: same first two dims as the images, 10 entries along the last dim.
        separator = left.new_zeros(left.shape[0], left.shape[1], 10)
        side_by_side = torch.cat([left, separator, right], dim=2)
        return show_image(side_by_side, ctx=ctx, **kwargs)

    
def ImageTupleBlock():
    """Return the `TransformBlock` for image pairs.

    Items are created with `ImageTuple.create`; `IntToFloatTensor` is applied as
    the batch transform.
    """
    block = TransformBlock(type_tfms=ImageTuple.create, batch_tfms=IntToFloatTensor)
    return block

In [None]:
def get_image_tuples(path: Path, domain_a: str = 'trainA', domain_b: str = 'trainB',
                     label: str = 'apple_orange'):
    """Pair up the image files of two domain folders.

    Args:
        path: Dataset root containing the two domain sub-folders.
        domain_a: Name of the first domain folder under `path`.
        domain_b: Name of the second domain folder under `path`.
        label: Constant value stored as the third entry of every item
            (defaults to the previously hard-coded 'apple_orange').

    Returns:
        A list of `[image_a, image_b, label]` lists, one per pair. Pairing stops
        at the end of the shorter folder, so surplus files of the larger domain
        are not used.
    """
    # Sort both listings: the order returned by a directory walk is
    # filesystem-dependent, which would make the pairing (and any seeded split
    # computed on the resulting item order) differ between machines.
    imgs_a = sorted(get_image_files(path/domain_a))
    imgs_b = sorted(get_image_files(path/domain_b))

    return [[img_a, img_b, label] for img_a, img_b in zip(imgs_a, imgs_b)]

def get_x(tup):
    """Return the first two entries of an item (the image pair) as the input."""
    image_pair = tup[0:2]
    return image_pair

def get_y(tup):
    """Return the third entry of an item (the label) as the target."""
    label = tup[2]
    return label

In [None]:
# Describe how items become (image pair, category) samples:
#   - items come from get_image_tuples, inputs from get_x, targets from get_y;
#   - every item carries the same label by default, so the category target is a constant;
#   - RandomSplitter(seed=42) makes the train/validation split repeatable for a fixed item order.
# NOTE(review): no item_tfms (e.g. a resize) is passed here, while the models
# above are built with input_shape [3, 128, 128] — confirm the image size matches.
datablock = DataBlock(blocks=(ImageTupleBlock, CategoryBlock),
                      get_items=get_image_tuples,
                      get_x=get_x,
                      get_y=get_y,
                      splitter=RandomSplitter(seed=42)
                     )

In [None]:
# Build the DataLoaders from the extracted dataset folder using the DataBlock defined above.
dataloaders = datablock.dataloaders(train_path)

In [None]:
# Visual sanity check: display at most 4 samples from one batch.
dataloaders.show_batch(max_n=4)