In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES']='1'

from fastai import *
from fastai.vision import *
from superRes.generators import *
from superRes.critics import *
from superRes.dataset import *
from superRes.loss import *
from superRes.save import *
from superRes.fid_loss import *
from superRes.ssim import *
from PIL import Image, ImageDraw, ImageFont
from PIL import ImageFile
from pathlib import Path

import torchvision
import geffnet # efficient/ mobile net

In [None]:
def get_data(bs:int, sz:int, keep_pct:float, path_low=None, path_high=None):
    """Build a low-res -> full-res databunch through ``get_databunch``.

    Args:
        bs: batch size, forwarded to ``get_databunch``.
        sz: image size, forwarded to ``get_databunch``.
        keep_pct: forwarded to ``get_databunch`` as ``keep_pct``.
        path_low: folder of low-resolution ("crappy") inputs. When omitted, the
            module-level ``path_lowRes`` is used; note that the config cell only
            defines ``path_lowRes_diff``, ``path_lowRes_mild`` and
            ``path_lowRes_x8``, so pass one of those explicitly.
        path_high: folder of full-resolution targets; defaults to ``path_fullRes``.

    Returns:
        Whatever ``get_databunch`` returns.

    Raises:
        NameError: if ``path_low`` is omitted and ``path_lowRes`` is undefined.
    """
    if path_low is None:
        try:
            path_low = path_lowRes
        except NameError:
            raise NameError(
                "get_data: no low-resolution folder given and `path_lowRes` is "
                "undefined; pass path_low=... (e.g. path_lowRes_x8)") from None
    if path_high is None:
        path_high = path_fullRes
    return get_databunch(sz=sz, bs=bs, crappy_path=path_low,
                         good_path=path_high,
                         random_seed=None, keep_pct=keep_pct)

def get_DIV2k_data(pLow, bs:int, sz:int, path_high=None,
                   train_idx=range(0, 800), valid_idx=range(800, 900)):
    """Build a DIV2K low-res -> full-res image-to-image databunch.

    The low-resolution folder selects the filename suffix that is stripped to
    find each image's full-resolution counterpart: with the ``x8`` folder, a
    name like ``0001x8.png`` is paired with ``path_high/'0001.png'``.

    Args:
        pLow: folder (str or Path) of low-resolution images. Its final path
            component must be ``DIV2K_train_LR_x8``, ``DIV2K_train_LR_difficult``
            or ``DIV2K_train_LR_mild``.
        bs: batch size.
        sz: transform output size, applied to inputs and targets (``tfm_y=True``).
        path_high: folder of full-resolution targets; defaults to ``path_fullRes``.
        train_idx: item indices used for training (default 0..799).
        valid_idx: item indices used for validation (default 800..899).

    Returns:
        A databunch normalized with ImageNet stats on both x and y, with ``c = 3``.

    Raises:
        KeyError: if ``pLow`` is not one of the known DIV2K low-res folders.
    """
    if path_high is None:
        path_high = path_fullRes

    suffixes = {"DIV2K_train_LR_x8": "x8",
                "DIV2K_train_LR_difficult": "x4d",
                "DIV2K_train_LR_mild": "x4m"}
    # Key on the folder name rather than str(pLow): the string form depends on
    # how the path was spelled (relative vs absolute, OS path separator).
    folder = Path(pLow).name
    if folder not in suffixes:
        raise KeyError(f"Unknown low-resolution folder {folder!r}; "
                       f"expected one of {sorted(suffixes)}")
    lowResSuffix = suffixes[folder]

    # Index-based split. NOTE(review): this assumes from_folder lists files in
    # filename order, which may depend on the fastai version / filesystem, and
    # the default valid_idx (800..899) only exists if the folder holds more than
    # 800 images -- confirm against the dataset on disk.
    src = ImageImageList.from_folder(pLow).split_by_idxs(
        train_idx=list(train_idx), valid_idx=list(valid_idx))

    def _high_res_target(x):
        # Strip the low-res suffix from the file name to locate the HR target.
        return path_high/x.name.replace(lowResSuffix, '')

    tfms = get_transforms(
        flip_vert=True,
        max_rotate=30,
        max_zoom=3.,
        max_lighting=.4,
        max_warp=.4,
        p_affine=.85,
    )
    data = (src.label_from_func(_high_res_target)
               .transform(tfms, size=sz, tfm_y=True)
               .databunch(bs=bs, num_workers=8, no_check=True)
               .normalize(imagenet_stats, do_y=True))
    data.c = 3
    return data

def create_training_images(fn, i, p_hr, p_lr, size):
    """Write a downscaled copy of one high-resolution image into ``p_lr``.

    The copy keeps ``fn``'s path relative to ``p_hr`` (parent folders are
    created as needed). It is resized to ``resize_to(img, size, use_min=True)``
    with bilinear resampling, converted to RGB and saved with ``quality=60``.

    Args:
        fn: Path of the source image; must lie under ``p_hr``.
        i: unused; presumably kept so the function matches an ``(item, index)``
            callback signature of a parallel helper -- TODO confirm.
        p_hr: root folder of the high-resolution images.
        p_lr: root folder where the low-resolution copies are written.
        size: target size passed to ``resize_to``.
    """
    dest = p_lr/fn.relative_to(p_hr)
    dest.parent.mkdir(parents=True, exist_ok=True)
    # Context manager releases the source file handle deterministically
    # instead of leaving it to PIL / the garbage collector.
    with PIL.Image.open(fn) as img:
        targ_sz = resize_to(img, size, use_min=True)
        small = img.resize(targ_sz, resample=PIL.Image.BILINEAR).convert('RGB')
    # NOTE(review): `quality` is a lossy-format (e.g. JPEG) save option; if dest
    # keeps a lossless extension such as .png it is presumably ignored -- confirm.
    small.save(dest, quality=60)

In [None]:
path = Path('dataset')

# DIV2K folders: full-resolution targets plus three low-resolution variants.
path_fullRes = path/'DIV2K_train_HR'
path_lowRes_diff = path/'DIV2K_train_LR_difficult'  # file suffix "x4d", ~300px
path_lowRes_mild = path/'DIV2K_train_LR_mild'  # file suffix "x4m", ~300px
path_lowRes_x8 = path/'DIV2K_train_LR_x8'  # file suffix "x8", ~150px

# Names for the generator and critic, derived from the project id.
proj_id = 'unet_superRes_mobilenetV3_FID'
gen_name, crit_name = (f'{proj_id}_{role}' for role in ('gen', 'crit'))

nf_factor = 2
pct_start = 1e-8

In [None]:
# Hyper-parameters. bs / sz / keep_pct are forwarded to the data builders;
# lr, wd and epochs are presumably consumed by later training cells.
bs, sz = 25, 128
lr, wd = 1e-3, 1e-3
keep_pct = 1.0
epochs = 10

In [None]:
# DIV2K databunch built from the x8 low-resolution folder.
data_gen = get_DIV2k_data(pLow=path_lowRes_x8, sz=sz, bs=bs)

In [None]:
res = models.resnet34
# Show which module provides the architecture. For its documentation use
# `help(res)` / `res?` interactively instead of saving a full help() dump.
res.__module__

In [None]:
# ShuffleNet V2 (x2.0) -- not supported by unet_learner.
# Taken from the explicitly imported torchvision.models: the `models` name
# brought in by the star imports may not re-export it (NOTE(review): confirm
# against the installed fastai version).
shuffle = torchvision.models.shufflenet_v2_x2_0
type(shuffle)

In [None]:
efficient = geffnet.efficientnet_b5
# Show the object's type; use `help(efficient)` interactively rather than
# saving a full help() dump in the notebook.
type(efficient)

In [None]:
mobile = geffnet.mobilenetv3_100
# Show the object's type; use `help(mobile)` interactively rather than
# saving a full help() dump in the notebook.
type(mobile)

In [None]:
from efficientnet_pytorch import EfficientNet


def get_EfficientNet(pretrained=False, model_name='efficientnet-b2',
                     num_classes=None, **kwargs):
    """Build an EfficientNet usable as a fastai `arch` callable.

    Args:
        pretrained: if True, load pretrained weights via
            ``EfficientNet.from_pretrained``; otherwise build a randomly
            initialised network via ``EfficientNet.from_name``. (The previous
            version ignored this flag and always loaded pretrained weights.)
        model_name: efficientnet_pytorch model identifier.
        num_classes: size of the classification head; defaults to ``data_gen.c``.
        **kwargs: accepted for call-site compatibility and ignored.

    Returns:
        An ``EfficientNet`` module whose final linear layer has ``num_classes``
        outputs.
    """
    from torch import nn  # local import keeps the cell self-contained

    if num_classes is None:
        num_classes = data_gen.c
    if pretrained:
        return EfficientNet.from_pretrained(model_name, num_classes=num_classes)
    model = EfficientNet.from_name(model_name)
    # from_name builds the default 1000-class head; resize it to num_classes.
    model._fc = nn.Linear(model._fc.in_features, num_classes)
    return model


effNet = get_EfficientNet
effNet

In [None]:
# NOTE: despite its name this is ResNet-34, not MobileNetV2; the name is kept
# because the learner cell passes `arch=mobileV2`.
# Uses `models` (as the `res` cell does) instead of `fastai.vision.models`,
# because the bare name `fastai` is not guaranteed to be bound by the
# `from fastai... import *` star imports.
mobileV2 = models.resnet34

In [None]:
# help(cnn_learner)

In [None]:
# cnn = cnn_learner(data=data_gen, base_arch=res)

In [None]:
# cnn.summary()

In [None]:
# Generator learner from superRes.generators.gen_learner_wide, using
# FeatureLoss() as the generator loss.
# NOTE(review): nf_factor is hard-coded to 3 here while the config cell sets
# nf_factor = 2 (unused in this notebook) -- confirm which value is intended.
feature_loss = FeatureLoss()
learn_gen = gen_learner_wide(
    data=data_gen,
    arch=mobileV2,
    gen_loss=feature_loss,
    nf_factor=3,
)

learn_gen.summary()