In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES']='1'

from fastai import *
from fastai.vision import *
from superRes.generators import *
from superRes.critics import *
from superRes.dataset import *
from superRes.loss import *
from superRes.save import *
from superRes.fid_loss import *
from superRes.ssim import *
from pathlib import Path

import torchvision
import geffnet # efficient/ mobile net

In [2]:
def get_data(bs:int, sz:int, keep_pct:float, crappy_path:Path=None, good_path:Path=None,
             random_seed:int=None):
    """Build a paired low/high-resolution databunch through `get_databunch`.

    Args:
        bs: batch size, forwarded to `get_databunch`.
        sz: image size, forwarded to `get_databunch`.
        keep_pct: forwarded unchanged to `get_databunch`.
        crappy_path: folder with the low-resolution inputs. Defaults to the global
            `path_lowRes`. NOTE(review): no `path_lowRes` is defined in the visible
            notebook (only `path_lowRes_diff`, `_mild`, `_x8`, `_256`, `_512`), so pass this
            explicitly unless it is defined elsewhere.
        good_path: folder with the high-resolution targets. Defaults to the global
            `path_fullRes`.
        random_seed: forwarded to `get_databunch`. Defaults to None, as before.

    Returns:
        Whatever `get_databunch` returns for these arguments.
    """
    # Resolve the globals lazily so the old call style `get_data(bs, sz, keep_pct)` keeps working.
    if crappy_path is None:
        crappy_path = path_lowRes
    if good_path is None:
        good_path = path_fullRes
    return get_databunch(sz=sz, bs=bs, crappy_path=crappy_path,
                         good_path=good_path,
                         random_seed=random_seed, keep_pct=keep_pct)

def get_DIV2k_data(pLow, bs:int, sz:int, n_train:int=800, n_valid:int=100):
    """Build a DIV2K databunch pairing degraded low-res images with their HR originals.

    Low-resolution file names carry a degradation suffix at the end of the stem
    (e.g. ``0001x8.png``); the matching target is the same name without that suffix
    inside ``path_fullRes`` (e.g. ``path_fullRes/0001.png``).

    Args:
        pLow: folder with the low-resolution images. Its final path component must be one
            of ``DIV2K_train_LR_x8``, ``DIV2K_train_LR_difficult`` or ``DIV2K_train_LR_mild``;
            the parent directory does not matter.
        bs: batch size.
        sz: size that inputs and targets are transformed to.
        n_train: number of leading items used for training (default 800).
        n_valid: number of items following the training ones used for validation
            (default 100).

    Returns:
        An ImageDataBunch with heavy augmentation, normalized with ImageNet statistics on
        both x and y, and ``c`` set to 3.

    Raises:
        KeyError: if ``pLow`` is not one of the known low-resolution folders.
    """
    suffixes = {"DIV2K_train_LR_x8": "x8",
                "DIV2K_train_LR_difficult": "x4d",
                "DIV2K_train_LR_mild": "x4m"}
    # Match on the folder name only: the old lookup on str(pLow) failed for absolute
    # paths, a different dataset root or non-POSIX separators.
    folder = Path(pLow).name
    if folder not in suffixes:
        raise KeyError(f"Unknown low-resolution folder {str(pLow)!r}; "
                       f"expected a folder named one of {sorted(suffixes)}")
    lowResSuffix = suffixes[folder]

    def _hr_target(x):
        # Strip the degradation suffix from the end of the stem only:
        # '0001x8.png' -> path_fullRes/'0001.png'.
        stem = x.stem
        if stem.endswith(lowResSuffix):
            stem = stem[:-len(lowResSuffix)]
        return path_fullRes/(stem + x.suffix)

    # NOTE(review): the index-based split assumes `from_folder` lists files in a stable
    # (ideally sorted) order; verify, otherwise train/valid membership can vary.
    src = ImageImageList.from_folder(pLow).split_by_idxs(
        train_idx=list(range(0, n_train)),
        valid_idx=list(range(n_train, n_train + n_valid)))

    data = (src.label_from_func(_hr_target).transform(
            get_transforms(
                flip_vert=True,
                max_rotate=30,
                max_zoom=3.,
                max_lighting=.4,
                max_warp=.4,
                p_affine=.85
            ),
            size=sz,
            tfm_y=True,
        ).databunch(bs=bs, num_workers=8, no_check=True).normalize(imagenet_stats, do_y=True))
    data.c = 3
    return data

def get_DIV2k_data_QF(pLow, bs:int, sz:int, n_train:int=800, n_valid:int=100):
    """Build a DIV2K databunch from low-quality ``.jpg`` inputs and their ``.png`` originals.

    Each input ``pLow/<name>.jpg`` is paired with the target ``path_fullRes/<name>.png``
    (the file name with '.jpg' replaced by '.png'; names without '.jpg' are used as-is).

    Args:
        pLow: folder with the low-quality images.
        bs: batch size.
        sz: size that inputs and targets are transformed to.
        n_train: number of leading items used for training (default 800).
        n_valid: number of items following the training ones used for validation
            (default 100).

    Returns:
        An ImageDataBunch normalized with ImageNet statistics on both x and y, with
        ``c`` set to 3.
    """
    # NOTE(review): the index-based split assumes `from_folder` lists files in a stable
    # (ideally sorted) order; verify, otherwise train/valid membership can vary.
    src = ImageImageList.from_folder(pLow).split_by_idxs(
        train_idx=list(range(0, n_train)),
        valid_idx=list(range(n_train, n_train + n_valid)))

    data = (src.label_from_func(lambda x: path_fullRes/(x.name.replace(".jpg", ".png"))).transform(
            get_transforms(
                max_rotate=30,
                max_lighting=.4,
                max_warp=.4
            ),
            size=sz,
            tfm_y=True,
        ).databunch(bs=bs, num_workers=8, no_check=True).normalize(imagenet_stats, do_y=True))
    data.c = 3
    return data

def create_training_images(fn, i, p_hr, p_lr, size, quality:int=60, suffix:str=None):
    """Write a downscaled, RGB copy of the high-res image `fn` into the low-res folder.

    The output mirrors `fn`'s location under `p_hr`: it is written to
    ``p_lr/fn.relative_to(p_hr)`` and missing parent folders are created. The image is
    resized to ``resize_to(img, size, use_min=True)`` with bilinear resampling and
    converted to RGB before saving.

    Args:
        fn: path of the source image, located under `p_hr`.
        i: unused; presumably kept so the function fits an ``(item, index)`` callback such
            as fastai's `parallel` -- confirm against the caller.
        p_hr: root folder of the high-resolution images.
        p_lr: root folder where the low-resolution copies are written.
        size: target size forwarded to `resize_to`.
        quality: save quality (default 60). PIL only honours it for lossy formats such as
            JPEG, and picks the format from the destination's extension.
        suffix: optional new extension for the output file (e.g. '.jpg'). By default the
            source extension is kept, so a '.png' source is saved as a lossless PNG and
            `quality` has no effect.
    """
    dest = p_lr/fn.relative_to(p_hr)
    if suffix is not None:
        dest = dest.with_suffix(suffix)
    dest.parent.mkdir(parents=True, exist_ok=True)
    # Image.open is lazy and keeps the file handle open; the context manager releases it.
    with PIL.Image.open(fn) as img:
        targ_sz = resize_to(img, size, use_min=True)
        small = img.resize(targ_sz, resample=PIL.Image.BILINEAR).convert('RGB')
    small.save(dest, quality=quality)

In [3]:
# Dataset root, relative to the notebook's working directory.
path = Path('./dataset')

# DIV2K high-resolution targets and the pre-degraded low-resolution variants.
path_fullRes     = path / 'DIV2K_train_HR'
path_lowRes_diff = path / 'DIV2K_train_LR_difficult'  # file suffix "x4d", ~300px
path_lowRes_mild = path / 'DIV2K_train_LR_mild'       # file suffix "x4m", ~300px
path_lowRes_x8   = path / 'DIV2K_train_LR_x8'         # file suffix "x8",  ~150px

# 64px patch datasets and the resized low-resolution sets.
path_fullRes_patches = path.joinpath('DIV2K_train_HR_Patches', '64px')
path_lowRes_patches  = path.joinpath('DIV2K_train_LR_Patches', '64px')
path_lowRes_256 = path / 'DIV2K_train_LR_256'
path_lowRes_512 = path / 'DIV2K_train_LR_512'

# Experiment identifiers: prefixes for the generator / critic model names.
proj_id = 'unet_superRes_mobilenetV3_Patches64px'
gen_name, crit_name = (f'{proj_id}_{role}' for role in ('gen', 'crit'))

nf_factor = 2      # `nf_factor` setting for the wide U-Net generator
pct_start = 1e-8   # NOTE(review): not used in the visible cells; confirm intended value

In [4]:
# Training hyper-parameters.
bs, sz = 25, 128      # batch size, image size
lr, wd = 1e-3, 1e-3   # learning rate, weight decay
keep_pct = 1.0
epochs = 10

In [5]:
# Generator data: low-quality images from DIV2K_train_LR_512 paired with the HR originals.
data_gen = get_DIV2k_data_QF(
    pLow=path_lowRes_512,
    bs=bs,
    sz=sz,
)

In [6]:
# Candidate encoder: torchvision's ResNet-34 constructor; show its module and signature.
res = models.resnet34

print(getattr(res, '__module__'))
help(res)

torchvision.models.resnet
Help on function resnet34 in module torchvision.models.resnet:

resnet34(pretrained=False, progress=True, **kwargs)
    ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr



In [7]:
# ShuffleNetV2 (x2.0) constructor -- NOTE: not supported by unet_learner.
shuffle = models.shufflenet_v2_x2_0
print(shuffle.__class__)

<class 'function'>


In [8]:
# Candidate encoder: geffnet's EfficientNet-B5 constructor; show its type and signature.
efficient = geffnet.efficientnet_b5
print(efficient.__class__)
help(efficient)

<class 'function'>
Help on function efficientnet_b5 in module geffnet.gen_efficientnet:

efficientnet_b5(pretrained=False, **kwargs)
    EfficientNet-B5



In [9]:
# Candidate encoder: geffnet's MobileNet-V3 (RW variant) constructor; show its type and signature.
mobile = geffnet.mobilenetv3_rw
print(mobile.__class__)
help(mobile)

<class 'function'>
Help on function mobilenetv3_rw in module geffnet.mobilenetv3:

mobilenetv3_rw(pretrained=False, **kwargs)
    MobileNet-V3 RW
    Attn: See note in gen function for this variant.



In [10]:
from efficientnet_pytorch import EfficientNet

def get_EfficientNet(pretrained=False, **kwargs):
    """Return an efficientnet_pytorch EfficientNet-B2 with `data_gen.c` output classes.

    NOTE(review): `pretrained` and `**kwargs` are ignored -- the model is always created
    via `EfficientNet.from_pretrained` (which, by its name, loads pretrained weights),
    even when called with `pretrained=False`. Confirm this is intended before relying
    on the flag.
    """
    n_outputs = data_gen.c
    return EfficientNet.from_pretrained('efficientnet-b2', num_classes=n_outputs)

effNet = get_EfficientNet
effNet

<function __main__.get_EfficientNet(pretrained=False, **kwargs)>

In [11]:
# help(cnn_learner)

In [12]:
# cnn = cnn_learner(data=data_gen, base_arch=res)

In [13]:
# cnn.summary()

In [14]:
# Generator loss: `msssim` from the star imports (presumably superRes.ssim, multi-scale SSIM -- confirm).
loss_func = msssim

In [19]:
# Wide U-Net generator on the MobileNet-V3 encoder. Use the configured `loss_func` and
# `nf_factor` instead of repeating literals, so changing them in one place takes effect.
learn_gen = gen_learner_wide(data=data_gen, arch=mobile, gen_loss=loss_func, nf_factor=nf_factor)

learn_gen.summary()

DynamicUnetWide
Layer (type)         Output Shape         Param #    Trainable 
Conv2d               [16, 64, 64]         432        False     
______________________________________________________________________
BatchNorm2d          [16, 64, 64]         32         True      
______________________________________________________________________
HardSwish            [16, 64, 64]         0          False     
______________________________________________________________________
Conv2d               [16, 64, 64]         144        False     
______________________________________________________________________
BatchNorm2d          [16, 64, 64]         32         True      
______________________________________________________________________
ReLU                 [16, 64, 64]         0          False     
______________________________________________________________________
Identity             [16, 64, 64]         0          False     
______________________________________________

In [16]:
def is_pathlike(x: Any) -> bool:
    """Return True when `x` is a `str` or a `pathlib.Path`, False otherwise.

    NOTE(review): this may shadow a same-named helper brought in by the star imports;
    verify which definition later code is meant to use.
    """
    path_types = (str, Path)
    return isinstance(x, path_types)

In [17]:
# Sanity check: a bare checkpoint-name string is treated as path-like.
is_pathlike("unet_superRes_mobilenetV3_Patches64px_gen_64px_0")

True

In [23]:
# Inspect the learner's root path (shown as dataset/DIV2K_train_LR_512, the databunch folder).
learn_gen.path

PosixPath('dataset/DIV2K_train_LR_512')

In [24]:
# Load the generator weights saved as '<gen_name>_512px_0' under the 64px-patch models dir.
# The path is built from the config constants instead of a hard-coded user home directory.
# It is kept absolute (anchored at the working directory) because fastai's Learner.load
# presumably joins a path-like argument onto learn_gen.path/model_dir; an absolute path
# overrides that prefix, a relative one would not.
gen_checkpoint = Path.cwd() / path_lowRes_patches / 'models' / f'{gen_name}_512px_0'
learn_gen.load(str(gen_checkpoint))

Learner(data=ImageDataBunch;

Train: LabelList (800 items)
x: ImageImageList
Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128)
y: ImageList
Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128)
Path: dataset/DIV2K_train_LR_512;

Valid: LabelList (100 items)
x: ImageImageList
Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128)
y: ImageList
Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128)
Path: dataset/DIV2K_train_LR_512;

Test: None, model=DynamicUnetWide(
  (layers): ModuleList(
    (0): Sequential(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      (2): HardSwish()
      (3): Sequential(
        (0): Sequential(
          (0): DepthwiseSeparableConv(
           