In [1]:
from fastai import *
from fastai.vision import *
from pathlib import Path
import PIL
import cv2

from utils import FocalLoss, f1
from wrn4 import *

In [2]:
# --- Locations (relative to the notebook's working directory) ---
PATH = Path('./')
MASKS = 'train.csv'                      # column 0: image Id, column 1: space-separated labels
SAMPLE = Path('sample_submission.csv')   # test image Ids, in submission order

# Image folders. Only the 128px copy of the training set is enabled in this run;
# the other pre-resized variants (train64/, train256/, train512/, test256/) are disabled.
TRAIN = Path('train/')
TRAIN128 = Path('train128/')

# --- Tables and id lists ---
seg = pd.read_csv(PATH/MASKS)
sample_sub = pd.read_csv(PATH/SAMPLE)
train_names = list(seg['Id'].values)
test_names = list(sample_sub['Id'].values)

In [3]:
def open_image4d(fn:PathOrStr)->Image:
    """Return a fastai `Image` created from the image file `fn`, keeping all its channels.

    No colour-mode conversion is applied, so channels are kept as stored in the file
    (the sample loaded in this notebook comes back as shape [4, 128, 128]).
    Pixel values are divided by 255 to give floats; this assumes 8-bit input
    (NOTE(review): confirm all images are 8-bit per channel).
    """
    # `PIL.Image.open` is lazy and keeps the file handle open until the object is
    # garbage-collected; over a whole dataset that leaks descriptors. The context
    # manager closes it once `pil2tensor` has copied the pixels out.
    with PIL.Image.open(fn) as img:
        return Image(pil2tensor(img).float().div_(255))

In [4]:
# Pick the first training id to spot-check the image loader below.
fname = train_names[0]

In [5]:
%time im = open_image4d(TRAIN128/(fname+'.png'))

CPU times: user 10.3 ms, sys: 3.82 ms, total: 14.1 ms
Wall time: 11.8 ms


In [6]:
# Sanity check: presumably (channels, height, width); the output shows 4 channels at 128x128.
im.shape

torch.Size([4, 128, 128])

In [7]:
class Image4C_ds(ImageMultiDataset):
    "Multi-label image dataset whose inputs are loaded with `open_image4d`, so every image channel is kept."

    def __init__(self, fns:FilePathList, labels:ImgLabels, classes:Optional[Collection[Any]]=None):
        # No extra state: kept explicit so the expected constructor arguments are visible here.
        super().__init__(fns, labels, classes)

    def _get_x(self,i):
        "Load the `i`-th input image with `open_image4d`."
        return open_image4d(self.x[i])

    def __getitem__(self,i:int)->Tuple[Image, np.ndarray]:
        "Return the `i`-th image together with its target as produced by `self.encode`."
        return self._get_x(i), self.encode(self.y[i])

In [8]:
# NOTE: `_df_to_fns_labels` is a private fastai helper and may change between versions.
from fastai.data_block import _df_to_fns_labels

# With `label_delim` set, fastai's helper writes the split label lists back into the
# frame it is given (NOTE(review): confirm for the installed fastai version). Passing a
# copy keeps `seg` untouched, so this cell can be re-run without failing on labels
# that were already split.
fnames, labels = _df_to_fns_labels(seg.copy(), suffix='.png', label_delim=' ', fn_col=0, label_col=1)
# Only the file names of the submission template are needed; its labels are discarded.
test_fnames, _ = _df_to_fns_labels(sample_sub, suffix='.png', fn_col=0, label_col=1)

In [9]:
# Label vocabulary: every distinct label across all training rows.
# NOTE(review): `uniqueify` presumably keeps first-seen order (not sorted), so class
# indices depend on CSV row order - keep that fixed when reusing saved weights.
classes = uniqueify(np.concatenate(labels))

In [10]:
# Per-channel normalisation. The first three values are the usual ImageNet RGB mean/std;
# the 4th channel simply reuses the 3rd channel's values.
# NOTE(review): these are not statistics of this dataset - consider computing them from the training images.
stats = ([0.485, 0.456, 0.406] + [0.406], [0.229, 0.224, 0.225] + [0.225])
norm, denorm = normalize_funcs(*stats)

In [11]:
# Augmentation: flips (including vertical ones), light lighting jitter, slight zoom, no warping.
tfms = get_transforms(
    do_flip=True, flip_vert=True,
    max_zoom=1.05, max_lighting=0.1, max_warp=0.,
)

In [12]:
def get_data(sz=64, bs=64, seed=42):
    """Build an `ImageDataBunch` of 4-channel images at size `sz`.

    Training images come from the pre-resized folder ``PATH/train{sz}/`` and are split
    into train/valid by `Image4C_ds.from_folder` with `valid_pct=0.2`. For `sz == 256`
    a test set is also attached, built from ``PATH/test{sz}/`` in the row order of
    `test_fnames` (i.e. the submission template's order).

    Args:
        sz: image size; selects the folder and is passed as `size` to the DataBunch.
        bs: batch size.
        seed: if not None, NumPy's global RNG is seeded right before the split so the
            validation set is the same on every run. Pass None for an unseeded split.
            NOTE(review): assumes fastai's split draws from `np.random` - confirm for
            the installed version.

    Returns:
        The `ImageDataBunch`, with `tfms` augmentation and `norm` normalisation.

    Raises:
        ValueError: if ``PATH/train{sz}/`` does not exist.
    """
    # Derive the folder from the `train{sz}/` naming convention instead of per-size
    # globals: TRAIN64/TRAIN256 are not defined in the config cell, so sz=64 (the
    # default) and sz=256 raised NameError, and any other size UnboundLocalError.
    folder = Path(f'train{sz}')
    if not (PATH/folder).is_dir():
        raise ValueError(f'No image folder for sz={sz}: {PATH/folder} does not exist')

    test_ds = None
    if sz == 256:
        # Built directly from `test_fnames` (fastai's `from_single_folder` expects a
        # folder to scan, not a list of names) so item order matches the submission
        # template. Test images are unlabelled; each gets the placeholder label
        # `classes[0]` only so the dataset can be constructed.
        test_dir = PATH/f'test{sz}'
        test_fns = [str(test_dir/fn) for fn in test_fnames]
        test_ds = Image4C_ds(test_fns, [[classes[0]]] * len(test_fns), classes=classes)

    # A fixed seed keeps validation images out of training across re-runs and
    # resumed checkpoints.
    if seed is not None:
        np.random.seed(seed)
    train, val = Image4C_ds.from_folder(PATH, folder, fnames, labels, valid_pct=0.2, classes=classes)
    return ImageDataBunch.create(train_ds=train, valid_ds=val, test_ds=test_ds,
                                 ds_tfms=tfms, tfms=norm, bs=bs, size=sz)

In [13]:
# Instantiate the network once, from the local `wrn4` module. Every learner built from
# this object shares the same module and hence the same weights.
# NOTE(review): the name suggests a WideResNet (depth 22, width 4); confirm it accepts the
# 4-channel inputs produced by `open_image4d`.
arch = wrn_22_4()

In [26]:
def get_learner(data, loss=False, fp16=False, model=None):
    """Create a fastai `Learner` on `data`, tracking `accuracy_thresh` and `f1`.

    Args:
        data: the DataBunch to train on (e.g. from `get_data`).
        loss: if True, replace the learner's loss function with `FocalLoss()`.
        fp16: if True, call `learn.to_fp16()` on the new learner.
        model: network to train. Defaults to the notebook-level `arch`, which is a single
            shared module instance, so learners built without `model` share its weights.

    Returns:
        The configured `Learner`.
    """
    if model is None:
        model = arch
    learn = Learner(data, model, metrics=[accuracy_thresh, f1])
    if loss:
        learn.loss_func = FocalLoss()
    if fp16:
        learn.to_fp16()
    return learn

In [39]:
# 128px images, batch size 32; train with focal loss in fp16.
data = get_data(sz=128, bs=32)
learn = get_learner(data, loss=True, fp16=True)

In [40]:
# Resume from the checkpoint saved under the name 'wrn4_128'.
# NOTE(review): the train/valid split in `get_data` (valid_pct=0.2) appears to be random;
# unless it is reproduced exactly, some validation images may have been training images
# for this checkpoint, inflating the validation metrics.
learn.load('wrn4_128')

In [None]:
%time learn.fit_one_cycle(8, 1e-2/7)

epoch,train_loss,valid_loss,accuracy_thresh,f1
1,0.899562,0.886406,0.959370,0.496534
2,0.894591,0.873120,0.959879,0.495570
3,0.924635,0.898049,0.958496,0.460149
,,,,


In [21]:
# Persist the current weights under the same name that was loaded above.
# NOTE(review): this overwrites the 'wrn4_128' checkpoint unconditionally - even when the
# run did not improve (the logged run stopped after 3 of 8 epochs with a rising valid
# loss). Consider saving each run under a new name.
learn.save('wrn4_128')