# PyTorch Model

### Imports

In [None]:
#hide
#Run once per session
!pip install fastai wwf -q --upgrade

In [None]:
from fastcore.xtras import Path

from fastai.callback.hook import summary
from fastai.callback.progress import ProgressCallback
from fastai.callback.schedule import lr_find, fit_flat_cos

from fastai.data.block import DataBlock
from fastai.data.external import untar_data, URLs
from fastai.data.transforms import get_image_files, FuncSplitter, Normalize

from fastai.layers import Mish
from fastai.losses import BaseLoss
from fastai.optimizer import ranger

from fastai.torch_core import tensor

from fastai.vision.augment import aug_transforms
from fastai.vision.core import PILImage, PILMask
from fastai.vision.data import ImageBlock, MaskBlock, imagenet_stats, SegmentationDataLoaders
from fastai.vision.learner import unet_learner
from fastai.metrics import DiceMulti

from PIL import Image
import numpy as np

from torch import nn
from torchvision.models.resnet import resnet34

import torch
import torch.nn.functional as F

### Dataloading and Transforming

In [None]:
# Dataset locations, given as relative paths.
DATA_PATH_X = 'clean/x-data'  # folder scanned for input images
DATA_PATH_Y = 'clean/y-data'  # folder where label_func looks up each mask
CODES_PATH = "clean/codes.txt"  # text file read with np.loadtxt into `codes`

# Loader settings.
BATCH_SIZE = 4  # passed as `bs` to the data loaders
IMG_SIZE = 128  # passed as `size` to aug_transforms

In [None]:
# Collect every image file found under the input folder.
path = DATA_PATH_X
fnames = get_image_files(path)


def label_func(x):
    """Return the mask path for the image path `x`.

    The mask is looked up inside DATA_PATH_Y under the same file name
    (stem plus suffix) as the image.
    """
    # DATA_PATH_Y is a plain string, so it has to be wrapped in Path before
    # the `/` join; `str / str` raises TypeError.
    return Path(DATA_PATH_Y) / f'{x.stem}{x.suffix}'


# Class codes, read as strings from the codes file.
codes = np.loadtxt(CODES_PATH, dtype=str)

dls = SegmentationDataLoaders.from_label_func(
    path,
    fnames,
    label_func,
    codes=codes,
    valid_pct=0.2,
    seed=11,
    bs=BATCH_SIZE,
    batch_tfms=[
        *aug_transforms(size=IMG_SIZE),
        Normalize.from_stats(*imagenet_stats),
    ],
)

In [None]:
# Display up to 9 items from a batch as a visual sanity check of the data.
dls.show_batch(max_n=9, figsize=(8,6))

### DL Model

#### Model Creation

In [None]:
# Metric and optimizer function handed to unet_learner when the model is built.
metric = DiceMulti
opt = ranger

In [None]:
# Build the segmentation learner: a U-Net on top of a resnet34 architecture.
learn = unet_learner(
    dls,
    resnet34,
    metrics=metric,
    self_attention=True,
    act_cls=Mish,
    opt_func=opt,
)

In [None]:
# Print a summary of the learner's model.
learn.summary()

#### Model training

In [None]:
# Run the learning-rate finder to help choose the value of `lr` used for training.
learn.lr_find()

In [None]:
# Learning rate passed to fit_flat_cos (as slice(lr)).
lr = 1e-3

In [None]:
# Train for 10 epochs with the fit_flat_cos schedule, using slice(lr) as the learning rate.
learn.fit_flat_cos(10, slice(lr))

In [None]:
# Checkpoint the trained learner under the name 'stage-1', then load it back.
learn.save('stage-1')
learn.load('stage-1');

In [None]:
# Display up to 4 result samples from the learner for visual inspection.
learn.show_results(max_n=4, figsize=(12,6))

### Model Testing

In [None]:
# Test-set locations, given as relative paths.
# NOTE(review): presumably x-data holds inputs and y-data the matching masks,
# mirroring the training layout — confirm.
TEST_PATH_X = 'test/x-data'
TEST_PATH_Y = 'test/y-data'

# Both test paths bundled as an (x, y) tuple.
TEST_PATHS = (TEST_PATH_X, TEST_PATH_Y)

In [None]:
def model_tester(mdl, paths):
    """Placeholder for model testing; not implemented yet.

    Both `mdl` and `paths` are accepted but unused. The function does
    nothing and returns None.
    """
    # TODO: implement the evaluation logic.
    return None