In [1]:
import sys
sys.path.append("../../fastai/")

In [2]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [3]:
from fastai.conv_learner import *
from fastai.dataset import *
#from fastai.models.resnet import vgg_resnet50
import torchvision
from torchvision import datasets, models, transforms, utils
from model_summary import *

import json

In [4]:
# Make CUDA device 0 the current device for everything that follows.
torch.cuda.set_device(0)

In [5]:
# Turn on the cuDNN benchmark flag before any model is built.
torch.backends.cudnn.benchmark=True

## Data

In [6]:
# Directory holding the CSV listings, the two listing file names, and the
# listings themselves (read in train, valid order).
PATH = Path('../CSV/')
train_FN, valid_FN = 'oxford_pet_train.csv', 'oxford_pet_valid.csv'
train_csv, valid_csv = (pd.read_csv(PATH/fn) for fn in (train_FN, valid_FN))

In [7]:
def show_img(im, figsize=None, ax=None, alpha=None):
    """Draw an image on a matplotlib axes with the axis decorations hidden.

    Args:
        im: image data accepted by ``ax.imshow``.
        figsize: size passed to ``plt.subplots`` when a new figure is created;
            ignored when ``ax`` is supplied.
        ax: existing axes to draw on; a new figure/axes pair is created only
            when this is ``None``.
        alpha: opacity forwarded to ``ax.imshow``.

    Returns:
        The axes the image was drawn on.
    """
    # Compare against None explicitly: truth-testing an arbitrary axes-like
    # object can be False (or raise) even though the caller supplied one.
    if ax is None: _, ax = plt.subplots(figsize=figsize)
    ax.imshow(im, alpha=alpha)
    ax.set_axis_off()
    return ax

In [8]:
# Directories of the input images and of their masks.
TRAIN_DN, MASKS_DN = '../Data/oxford_pets/sparse_images/', '../Data/oxford_pets/sparse_masks/'
# Image size, batch size and worker count used by the cells below.
sz, bs, nw = 128, 64, 16

In [9]:
class MatchedFilesDataset(FilesDataset):
    """Files dataset whose target for item ``i`` is an image opened from ``y[i]``."""

    def __init__(self, fnames, y, transform, path):
        # `y` must supply exactly one target filename per input filename; it is
        # stored before the parent constructor runs.
        self.y = y
        assert len(fnames) == len(y)
        super().__init__(fnames, transform, path)

    def get_y(self, i):
        """Open the target image for index ``i``, joined onto ``self.path``."""
        target_fn = os.path.join(self.path, self.y[i])
        return open_image(target_fn)

    def get_c(self):
        """Always reports 0."""
        return 0

In [42]:
# Image and mask paths share the same file names: all train_csv names first,
# then all valid_csv names.
x_names = np.array([Path(TRAIN_DN)/o for frame in (train_csv, valid_csv) for o in frame['name']])
y_names = np.array([Path(MASKS_DN)/o for frame in (train_csv, valid_csv) for o in frame['name']])

In [43]:
# Total number of image paths (train_csv names followed by valid_csv names).
len(x_names)

2999

In [44]:
# x_names lists every train_csv entry first and every valid_csv entry after it,
# so the validation rows start at len(train_csv). Deriving the split point keeps
# it correct if the CSVs change, instead of relying on a hard-coded 1999.
val_idxs = np.arange(len(train_csv), len(x_names))

In [45]:
# Split the image and mask paths into (validation, training) pairs using val_idxs.
(val_x, trn_x), (val_y, trn_y) = split_by_idx(val_idxs, x_names, y_names)

In [46]:
# Training-time augmentations; every transform is created with tfm_y=TfmType.CLASS.
aug_tfms = [
    RandomRotate(4, tfm_y=TfmType.CLASS),
    RandomFlip(tfm_y=TfmType.CLASS),
    RandomLighting(0.05, 0.05, tfm_y=TfmType.CLASS),
]

In [47]:
# Build transforms, datasets and the model-data object for the current sz / bs.
tfms = tfms_from_model(resnet34, sz, crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)
# Bound to `dsets` rather than `datasets` so the `datasets` name imported from
# torchvision in the imports cell is not overwritten.
dsets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms, path=PATH)
# Worker count comes from the `nw` setting instead of a repeated literal.
md = ImageData(PATH, dsets, bs, num_workers=nw, classes=None)
denorm = md.trn_ds.denorm

In [48]:
# Pull a single batch from the training loader to inspect it.
x,y = next(iter(md.trn_dl))

In [49]:
# Display the input and target batch shapes.
x.shape,y.shape

(torch.Size([64, 3, 128, 128]), torch.Size([64, 128, 128]))

## U-net (ish)

In [50]:
def dice(pred, targs):
    """Dice score of thresholded predictions against targets.

    `pred` is binarised at 0 (values > 0 become 1.0, everything else 0.0) and
    the result is 2 * sum(pred * targs) / sum(pred + targs) over all elements.
    There is no guard for a zero denominator (no positives in either input).
    """
    hard = (pred > 0).float()
    overlap = (hard * targs).sum()
    total = (hard + targs).sum()
    return 2. * overlap / total

In [51]:
class SaveFeatures():
    """Keep the most recent forward output of a module via a forward hook."""

    # Last captured output; remains None until the hooked module has run.
    features = None

    def __init__(self, m):
        # Register on `m` and keep the returned handle so the hook can be detached.
        self.hook = m.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        """Hook callback: remember `output`."""
        self.features = output

    def remove(self):
        """Detach the hook from the module."""
        self.hook.remove()

class UnetBlock(nn.Module):
    """Decoder block that fuses an upsampled feature map with a skip connection.

    Args:
        up_in: channel count of the tensor arriving from the previous decoder stage.
        x_in: channel count of the skip-connection tensor.
        n_out: channel count of the block output (upsampled part + skip part).
    """
    def __init__(self, up_in, x_in, n_out):
        super().__init__()
        # Split n_out between the two branches. Taking x_out as the remainder
        # keeps up_out + x_out == n_out for odd n_out as well, so the
        # concatenated tensor always matches the BatchNorm2d width below.
        up_out = n_out//2
        x_out = n_out - up_out
        self.x_conv  = nn.Conv2d(x_in,  x_out,  1)
        self.tr_conv = nn.ConvTranspose2d(up_in, up_out, 2, stride=2)
        self.bn = nn.BatchNorm2d(n_out)

    def forward(self, up_p, x_p):
        """Return bn(relu(cat([tr_conv(up_p), x_conv(x_p)], dim=1))).

        `x_p` must have the same spatial size as `up_p` after the stride-2
        transposed convolution, otherwise the channel concatenation fails.
        """
        up_p = self.tr_conv(up_p)
        x_p = self.x_conv(x_p)
        cat_p = torch.cat([up_p,x_p], dim=1)
        return self.bn(F.relu(cat_p))

class Unet34(nn.Module):
    """U-Net style decoder stacked on a backbone `rn`.

    Forward hooks on `rn[0][12]`, `rn[0][22]`, `rn[0][32]` and `rn[0][42]`
    capture intermediate outputs, which the decoder consumes as skip
    connections in reverse order. The channel counts of the decoder blocks
    are fixed in `__init__`.
    """

    def __init__(self, rn):
        super().__init__()
        self.rn = rn
        # One hook per tapped backbone layer, kept in index order.
        hooks = []
        for idx in (12, 22, 32, 42):
            hooks.append(SaveFeatures(rn[0][idx]))
        self.sfs = hooks
        self.up1 = UnetBlock(512,512,256)
        self.up2 = UnetBlock(256,512,256)
        self.up3 = UnetBlock(256,256,256)
        self.up4 = UnetBlock(256,128,256)
        # Final 2x transposed convolution down to a single output channel.
        self.up5 = nn.ConvTranspose2d(256, 1, 2, stride=2)

    def forward(self,x):
        """Run the backbone, decode with the hooked features, return channel 0."""
        x = F.relu(self.rn(x))
        decoder = (self.up1, self.up2, self.up3, self.up4)
        # Last hook first: up1 pairs with sfs[3], ..., up4 with sfs[0].
        for block, skip in zip(decoder, reversed(self.sfs)):
            x = block(x, skip.features)
        x = self.up5(x)
        return x[:,0]

    def close(self):
        """Remove every forward hook registered in `__init__`."""
        for sf in self.sfs:
            sf.remove()

# Backbone: the `features` stack of a pretrained torchvision VGG16-BN.
v = models.vgg16_bn(pretrained=True)
# Wrap only `v.features`, so `rn[0][i]` inside Unet34 indexes its layers.
# Slicing `list(v.children())[:-1]` depends on how many top-level children the
# VGG module has: torchvision releases that add an `avgpool` child would let it
# slip into the backbone and change the spatial size the decoder receives.
v1 = nn.Sequential(v.features)
m = Unet34(v1)

In [52]:
# Wrap model and data in a fastai learner, then set the optimiser (Adam),
# the loss (BCE with logits) and the metrics (accuracy at 0.5, dice).
learn = ConvLearner.from_model_data(m,md)
learn.opt_fn=optim.Adam
learn.crit=nn.BCEWithLogitsLoss()
learn.metrics=[accuracy_thresh(0.5),dice]

In [53]:
learn.unfreeze()

In [54]:
# Two cycles of five epochs each (10 epochs in the log below).
# NOTE(review): best_save_name presumably checkpoints the best epoch as 'oxford_fastai' - confirm against fastai.
learn.fit(1e-2,2,cycle_len=5,best_save_name='oxford_fastai')

HBox(children=(IntProgress(value=0, description='Epoch', max=10, style=ProgressStyle(description_width='initia…

epoch      trn_loss   val_loss   <lambda>   dice           
    0      0.485965   5.10526    0.645106   0.596308  
    1      0.378305   0.454411   0.818197   0.732497       
    2      0.32249    0.323882   0.829137   0.754386       
    3      0.285301   0.268931   0.869204   0.816265       
    4      0.262679   0.253006   0.877925   0.830482       
    5      0.277865   0.31778    0.851556   0.788082       
    6      0.266159   0.248124   0.881306   0.833142       
    7      0.247248   0.242843   0.888238   0.838351       
    8      0.224514   0.210829   0.897181   0.859887       
    9      0.20553    0.204677   0.902174   0.866159       


[array([0.20468]), 0.9021736459732056, 0.8661589056094138]

In [None]:
# Checkpoint the learner after the first fit.
learn.save('128urn-tmp')

In [None]:
learn.load('128urn-tmp')

In [None]:
learn.unfreeze()
learn.bn_freeze(True)

In [None]:
# NOTE(review): `lrs` and `wd` are not assigned by any earlier cell in this
# notebook (`lrs` is only assigned in the 1024x1024 section, from an `lr` that
# is itself never assigned), so this cell fails on Restart & Run All unless the
# star imports happen to provide them - confirm and define them before use.
learn.fit(lrs/4, 1, wds=wd, cycle_len=20,use_clr=(20,10))

In [None]:
learn.save('128urn-0')

In [None]:
learn.load('128urn-0')

In [None]:
# Run the model on one validation batch and convert the output to numpy.
x,y = next(iter(md.val_dl))
py = to_np(learn.model(V(x)))

In [None]:
# Predicted mask for the first item: model output thresholded at 0.
show_img(py[0]>0);

In [None]:
# Target mask for the same item.
show_img(y[0]);

In [None]:
# Remove the forward hooks registered by the model.
m.close()

## 512x512

In [None]:
# Image size and batch size for the 512x512 stage.
sz, bs = 512, 16

In [None]:
# Rebuild transforms, datasets and the model-data object for the current sz / bs.
tfms = tfms_from_model(resnet34, sz, crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)
# Bound to `dsets` rather than `datasets` so the `datasets` name imported from
# torchvision in the imports cell is not overwritten.
dsets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms, path=PATH)
md = ImageData(PATH, dsets, bs, num_workers=4, classes=None)
denorm = md.trn_ds.denorm

In [None]:
# NOTE(review): `get_base` and `UnetModel` are not defined in any cell of this
# notebook; unless a star import provides them this cell raises NameError.
# NOTE(review): assigning to `models` overwrites the `models` name imported
# from torchvision, which the earlier model-building cell uses as
# `models.vgg16_bn`; re-running that cell after this one would fail.
m_base = get_base()
m = to_gpu(Unet34(m_base))
models = UnetModel(m)

In [None]:
# Same optimiser, loss and metrics as the 128px learner.
learn = ConvLearner(md, models)
learn.opt_fn=optim.Adam
learn.crit=nn.BCEWithLogitsLoss()
learn.metrics=[accuracy_thresh(0.5),dice]

In [None]:
learn.freeze_to(1)

In [None]:
# Start from the weights saved at the end of the 128px stage.
learn.load('128urn-0')

In [None]:
# NOTE(review): `lr` and `wd` are not assigned anywhere above this cell - confirm where they should come from.
learn.fit(lr,1,wds=wd, cycle_len=5,use_clr=(5,5))

In [None]:
learn.save('512urn-tmp')

In [None]:
learn.unfreeze()
learn.bn_freeze(True)

In [None]:
learn.load('512urn-tmp')

In [None]:
# NOTE(review): `lrs` is first assigned in the 1024x1024 section, after this cell; `wd` is never assigned.
learn.fit(lrs/4,1,wds=wd, cycle_len=8,use_clr=(20,8))

In [None]:
learn.save('512urn')

In [None]:
learn.load('512urn')

In [None]:
# Run the model on one validation batch and convert the output to numpy.
x,y = next(iter(md.val_dl))
py = to_np(learn.model(V(x)))

In [None]:
# Predicted mask for the first item: model output thresholded at 0.
show_img(py[0]>0);

In [None]:
# Target mask for the same item.
show_img(y[0]);

In [None]:
# Remove the forward hooks registered by the model.
m.close()

## 1024x1024

In [None]:
# Image size and batch size for the 1024x1024 stage.
sz, bs = 1024, 4

In [None]:
# Rebuild transforms, datasets and the model-data object for the current sz / bs
# (this stage passes no aug_tfms).
tfms = tfms_from_model(resnet34, sz, crop_type=CropType.NO, tfm_y=TfmType.CLASS)
# Bound to `dsets` rather than `datasets` so the `datasets` name imported from
# torchvision in the imports cell is not overwritten.
dsets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms, path=PATH)
# Worker count comes from the `nw` setting instead of a repeated literal.
md = ImageData(PATH, dsets, bs, num_workers=nw, classes=None)
denorm = md.trn_ds.denorm

In [None]:
# NOTE(review): `get_base` and `UnetModel` are not defined in any cell of this
# notebook; unless a star import provides them this cell raises NameError.
# NOTE(review): assigning to `models` overwrites the `models` name imported
# from torchvision, which the earlier model-building cell uses as
# `models.vgg16_bn`; re-running that cell after this one would fail.
m_base = get_base()
m = to_gpu(Unet34(m_base))
models = UnetModel(m)

In [None]:
# Same optimiser, loss and metrics as the earlier learners.
learn = ConvLearner(md, models)
learn.opt_fn=optim.Adam
learn.crit=nn.BCEWithLogitsLoss()
learn.metrics=[accuracy_thresh(0.5),dice]

In [None]:
# Start from the weights saved at the end of the 512px stage.
learn.load('512urn')

In [None]:
learn.freeze_to(1)

In [None]:
# NOTE(review): `lr` and `wd` are not assigned anywhere above this cell - confirm where they should come from.
learn.fit(lr,1, wds=wd, cycle_len=2,use_clr=(5,4))

In [None]:
learn.save('1024urn-tmp')

In [None]:
learn.load('1024urn-tmp')

In [None]:
learn.unfreeze()
learn.bn_freeze(True)

In [None]:
# Three learning rates derived from `lr` (lr/200, lr/30, lr).
# NOTE(review): `lr` itself is never assigned in this notebook.
lrs = np.array([lr/200,lr/30,lr])

In [None]:
learn.fit(lrs/10,1, wds=wd,cycle_len=4,use_clr=(20,8))

In [None]:
# Same call as the previous cell, run a second time.
learn.fit(lrs/10,1, wds=wd,cycle_len=4,use_clr=(20,8))

In [None]:
learn.sched.plot_loss()

In [None]:
learn.save('1024urn')

In [None]:
learn.load('1024urn')

In [None]:
# Run the model on one validation batch and convert the output to numpy.
x,y = next(iter(md.val_dl))
py = to_np(learn.model(V(x)))

In [None]:
# Predicted mask for the first item: model output thresholded at 0.
show_img(py[0]>0);

In [None]:
# Target mask for the same item.
show_img(y[0]);