In [1]:
import sys
sys.path.append("../../fastai/")

In [2]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [3]:
from fastai.conv_learner import *
from fastai.dataset import *
#from fastai.models.resnet import vgg_resnet50
import torchvision
from torchvision import datasets, models, transforms, utils
from model_summary import *

import json

In [4]:
torch.cuda.set_device(0)

In [5]:
torch.backends.cudnn.benchmark=True

## Data

In [6]:
PATH = Path('../CSV/')
train_FN = 'oxford_pet_train.csv'
valid_FN = 'oxford_pet_valid.csv'
train_csv = pd.read_csv(PATH/train_FN)
valid_csv = pd.read_csv(PATH/valid_FN)

In [7]:
def show_img(im, figsize=None, ax=None, alpha=None):
    """Draw an image on a matplotlib axes with the axis decorations hidden.

    Args:
        im: Image data accepted by ``ax.imshow``.
        figsize: Figure size forwarded to ``plt.subplots``; only used when a
            new figure has to be created.
        ax: Existing axes to draw on. A new figure and axes are created only
            when this is ``None``.
        alpha: Opacity forwarded to ``ax.imshow``.

    Returns:
        The axes that was drawn on.
    """
    # Compare against None explicitly: truth-testing an arbitrary object can
    # be False (or raise, e.g. for an array of axes) even though a usable
    # axes was supplied by the caller.
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    ax.imshow(im, alpha=alpha)
    ax.set_axis_off()
    return ax

In [8]:
TRAIN_DN = '../Data/oxford_pets/sparse_images/'
MASKS_DN = '../Data/oxford_pets/sparse_masks/'
sz = 128
bs = 64
nw = 16

In [9]:
class MatchedFilesDataset(FilesDataset):
    """Files dataset whose target for item ``i`` is the image stored at ``y[i]``.

    ``fnames`` and ``y`` are parallel sequences and must have the same length.
    """

    def __init__(self, fnames, y, transform, path):
        self.y = y
        assert len(fnames) == len(y)
        super().__init__(fnames, transform, path)

    def get_y(self, i):
        """Open the target image for index ``i``, resolved relative to ``self.path``."""
        target_fn = os.path.join(self.path, self.y[i])
        return open_image(target_fn)

    def get_c(self):
        """Always return 0."""
        # NOTE(review): presumably the base class reads this as a class count
        # and 0 marks "no discrete classes" -- confirm against FilesDataset.
        return 0

In [10]:
x_names = np.array([Path(TRAIN_DN)/o for o in train_csv['name']]+[Path(TRAIN_DN)/o for o in valid_csv['name']])
y_names = np.array([Path(MASKS_DN)/o for o in train_csv['name']]+[Path(MASKS_DN)/o for o in valid_csv['name']])

In [11]:
len(x_names)

2999

In [12]:
# Hold out a random quarter of the samples for validation. The indices are
# drawn from a dedicated, seeded generator so the split is identical on every
# run: later cells reload saved checkpoints, and a different split would
# evaluate them on images they were trained on. A private RandomState is used
# so NumPy's global random state is left untouched.
val_idxs = np.random.RandomState(42).choice(len(x_names), len(x_names)//4, replace=False)
((val_x,trn_x),(val_y,trn_y)) = split_by_idx(val_idxs, x_names, y_names)

In [13]:
aug_tfms = [RandomRotate(4, tfm_y=TfmType.CLASS),
            RandomFlip(tfm_y=TfmType.CLASS),
            RandomLighting(0.05, 0.05, tfm_y=TfmType.CLASS)]

In [14]:
# Transforms, paired image/mask datasets and the data object for `sz` x `sz` inputs.
tfms = tfms_from_model(resnet34, sz, crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)
# Named `dsets` so the torchvision `datasets` module imported in the import
# cell is not overwritten by this tuple of datasets.
dsets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms, path=PATH)
# Take the worker count from `nw` (set next to `sz` and `bs`) instead of
# repeating the literal here.
md = ImageData(PATH, dsets, bs, num_workers=nw, classes=None)
denorm = md.trn_ds.denorm

In [15]:
x,y = next(iter(md.trn_dl))

In [16]:
x.shape,y.shape

(torch.Size([64, 3, 128, 128]), torch.Size([64, 128, 128]))

## Simple upsample

In [17]:
# f = vgg16
# cut,lr_cut = model_meta[f]

# def get_base():
#     layers = cut_model(f(True), cut)
#     return nn.Sequential(*layers)

# def dice(pred, targs):
#     pred = (pred>0).float()
#     return 2. * (pred*targs).sum() / (pred+targs).sum()

# class StdUpsample(nn.Module):
#     def __init__(self, nin, nout):
#         super().__init__()
#         self.conv = nn.ConvTranspose2d(nin, nout, 2, stride=2)
#         self.bn = nn.BatchNorm2d(nout)
        
#     def forward(self, x): return self.bn(F.relu(self.conv(x)))

# class Upsample34(nn.Module):
#     def __init__(self, rn):
#         super().__init__()
#         self.rn = rn
#         self.features = nn.Sequential(
#             rn, nn.ReLU(),
#             StdUpsample(512,256),
#             StdUpsample(256,256),
#             StdUpsample(256,256),
#             StdUpsample(256,256),
#             nn.ConvTranspose2d(256, 1, 2, stride=2))
        
#     def forward(self,x): return self.features(x)[:,0]

# class UpsampleModel():
#     def __init__(self,model,name='upsample'):
#         self.model,self.name = model,name

#     def get_layer_groups(self, precompute):
#         lgs = list(split_by_idxs(children(self.model.rn), [lr_cut]))
#         return lgs + [children(self.model.features)[1:]]

# m_base = get_base()

# m = to_gpu(Upsample34(m_base))
# model = UpsampleModel(m)

In [24]:
def dice(pred, targs):
    """Dice coefficient between thresholded predictions and targets.

    Predictions are binarised at 0 (strictly positive counts as foreground),
    then the score is ``2 * sum(pred * targs) / sum(pred + targs)`` over all
    elements. No guard is applied for a zero denominator.
    """
    hard = (pred > 0).float()
    overlap = (hard * targs).sum()
    total = (hard + targs).sum()
    return 2. * overlap / total

In [None]:
class StdUpsample(nn.Module):
    """One upsampling stage: 2x2 stride-2 transposed conv, then ReLU, then batch norm.

    Args:
        nin: Number of input channels.
        nout: Number of output channels.
    """

    def __init__(self, nin, nout):
        super().__init__()
        self.conv = nn.ConvTranspose2d(nin, nout, kernel_size=2, stride=2)
        self.bn = nn.BatchNorm2d(nout)

    def forward(self, x):
        upsampled = self.conv(x)
        activated = F.relu(upsampled)
        # Batch norm is applied after the activation, not before it.
        return self.bn(activated)
    
class Upsample34(nn.Module):
    """Backbone ``rn`` followed by five stride-2 upsampling steps.

    The decoder is four ``StdUpsample`` stages (512->256, then 256->256 three
    times) and a final transposed convolution down to one channel.

    Args:
        rn: Backbone module; its output must have 512 channels to match the
            first ``StdUpsample`` stage.
    """

    def __init__(self, rn):
        super().__init__()
        self.rn = rn
        # `rn` is stored as an attribute and is also the first stage of
        # `features`, so the same module is registered under both names.
        stages = [rn, StdUpsample(512, 256)]
        stages += [StdUpsample(256, 256) for _ in range(3)]
        stages.append(nn.ConvTranspose2d(256, 1, 2, stride=2))
        self.features = nn.Sequential(*stages)

    def forward(self, x):
        # The single-channel axis is kept in the output.
        # NOTE(review): Unet34 below drops it with `[:,0]` -- confirm which
        # shape the loss and metrics expect from this model.
        return self.features(x)
       

# Pretrained VGG16 (batch-norm variant) from torchvision.
v = models.vgg16_bn(pretrained=True)
# Backbone = every top-level child of the VGG model except the last one.
v1 = nn.Sequential(*list(v.children())[:-1])

# Wrap the backbone in the plain upsampling decoder; `model` is a second name for `m`.
m_base = v1
m = Upsample34(m_base)
model = m

In [18]:
class VGG_UNet(nn.Module):
    """Backbone followed by five stride-2 transposed convolutions.

    Args:
        base: Feature extractor applied first; its output must have 512
            channels to match ``conv1``.

    ``forward`` returns the single-channel output of ``conv3`` with the
    channel axis removed, i.e. one 2-D map per batch element.
    """

    def __init__(self, base):
        super().__init__()
        self.base = base
        self.conv1 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv2 = nn.ConvTranspose2d(256, 256, 2, stride=2)
        self.conv3 = nn.ConvTranspose2d(256, 1, 2, stride=2)
        self.bn = nn.BatchNorm2d(256)
        # NOTE(review): forward() applies `conv2` three times and `bn` four
        # times, so those stages share weights and batch-norm statistics.
        # Kept as-is so existing checkpoints keep loading; confirm whether
        # separate layers per stage were intended.

    def forward(self, x):
        x = self.base(x)
        x = self.bn(F.relu(self.conv1(x)))
        for _ in range(3):
            x = self.bn(F.relu(self.conv2(x)))
        x = self.conv3(x)
        # Remove only the channel axis. A bare squeeze() would also drop the
        # batch axis for a batch of one, breaking the (batch, H, W) shape
        # that the targets have.
        return x.squeeze(1)

# Pretrained VGG16 (batch-norm variant) from torchvision.
v = models.vgg16_bn(pretrained=True)
# Backbone = every top-level child of the VGG model except the last one.
v1 = nn.Sequential(*list(v.children())[:-1])
# Rebinds `model`, replacing whatever an earlier cell assigned to that name.
model = VGG_UNet(v1)

In [None]:
learn = ConvLearner.from_model_data(model,md)
learn.opt_fn=optim.Adam
learn.crit=nn.BCEWithLogitsLoss()
learn.metrics=[accuracy_thresh(0.5),dice]

In [None]:
learn.summary()

In [None]:
learn.unfreeze()

In [None]:
learn.fit(1e-2,5, wds=1e-7, cycle_len=5,use_clr=(20,8),best_save_name='unet_oxford')

In [None]:
learn.save('tmp')

In [None]:
learn.load('tmp')

In [None]:
learn.unfreeze()
learn.bn_freeze(True)

In [None]:
# NOTE(review): `lrs` is first assigned in a later cell (in the U-net
# section), so this cell relies on out-of-order execution and will not run in
# a fresh top-to-bottom pass unless `lrs` is defined earlier.
learn.fit(lrs,1,cycle_len=4,use_clr=(20,8))

In [None]:
learn.save('128')

In [None]:
x,y = next(iter(md.val_dl))
py = to_np(learn.model(V(x)))

In [None]:
show_img(py[0]>0);

In [None]:
show_img(y[0]);

## U-net (ish)

In [19]:
class SaveFeatures:
    """Forward hook that keeps the most recent output of a module.

    After the hooked module runs, its output is available as ``features``.
    Call ``remove()`` to detach the hook.
    """

    # Stays None until the hooked module has produced an output.
    features = None

    def __init__(self, m):
        self.hook = m.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        # Only the output is stored; the module and its input are ignored.
        self.features = output

    def remove(self):
        self.hook.remove()

In [20]:
class UnetBlock(nn.Module):
    """Decoder block that merges an upsampled tensor with a skip tensor.

    The upsampling path goes through a 2x2 stride-2 transposed convolution and
    the skip path through a 1x1 convolution; each produces ``n_out // 2``
    channels. The two are concatenated on the channel axis, passed through
    ReLU and then batch norm.

    Args:
        up_in: Channels of the tensor to upsample.
        x_in: Channels of the skip tensor.
        n_out: Channels of the block output (the batch norm is sized for
            this, so it has to be even for the halves to add up).
    """

    def __init__(self, up_in, x_in, n_out):
        super().__init__()
        half = n_out // 2
        self.x_conv = nn.Conv2d(x_in, half, kernel_size=1)
        self.tr_conv = nn.ConvTranspose2d(up_in, half, kernel_size=2, stride=2)
        self.bn = nn.BatchNorm2d(n_out)

    def forward(self, up_p, x_p):
        upsampled = self.tr_conv(up_p)
        skip = self.x_conv(x_p)
        # Upsampled channels come first, skip channels second.
        merged = torch.cat((upsampled, skip), 1)
        return self.bn(F.relu(merged))

In [21]:
class Unet34(nn.Module):
    """U-Net style decoder on top of a backbone, using hooked skip features.

    Forward hooks are registered on four sub-modules of ``rn[0]``; their
    outputs are fed to four ``UnetBlock`` stages, deepest first, followed by a
    final stride-2 transposed convolution to one channel.

    Args:
        rn: Backbone; ``rn[0]`` must be indexable at positions 12, 22, 32
            and 42.

    Call ``close()`` to remove the hooks when the model is no longer needed.
    """

    def __init__(self, rn):
        super().__init__()
        self.rn = rn
        # NOTE(review): presumably these are the activations just before each
        # down-sampling step of the VGG feature stack -- confirm the indices
        # against the backbone actually passed in.
        hooked = (12, 22, 32, 42)
        self.sfs = [SaveFeatures(rn[0][idx]) for idx in hooked]
        self.up1 = UnetBlock(512, 512, 256)
        self.up2 = UnetBlock(256, 512, 256)
        self.up3 = UnetBlock(256, 256, 256)
        self.up4 = UnetBlock(256, 128, 256)
        self.up5 = nn.ConvTranspose2d(256, 1, 2, stride=2)

    def forward(self, x):
        # Running the backbone also fills in the hooked features used below.
        x = F.relu(self.rn(x))
        decoder = (self.up1, self.up2, self.up3, self.up4)
        # The last hook registered is consumed first.
        for block, saved in zip(decoder, reversed(self.sfs)):
            x = block(x, saved.features)
        # Keep only channel 0, which removes the channel axis.
        return self.up5(x)[:, 0]

    def close(self):
        for saved in self.sfs:
            saved.remove()

In [22]:
m = Unet34(v1)

In [25]:
learn = ConvLearner.from_model_data(m,md)
learn.opt_fn=optim.Adam
learn.crit=nn.BCEWithLogitsLoss()
learn.metrics=[accuracy_thresh(0.5),dice]

In [None]:
learn.summary()

In [None]:
[o.features.size() for o in m.sfs]

In [None]:
learn.freeze_to(1)

In [None]:
learn.lr_find()
learn.sched.plot()

In [None]:
lr=4e-2
wd=1e-7

lrs = np.array([lr/100,lr/10,lr])

In [26]:
learn.fit(1e-2,1,wds=1e-7,cycle_len=8,use_clr=(5,8))

HBox(children=(IntProgress(value=0, description='Epoch', max=8, style=ProgressStyle(description_width='initial…

epoch      trn_loss   val_loss   <lambda>   dice           
    0      0.325353   11.860334  0.428252   0.510133  
    1      0.256698   0.324426   0.870623   0.808468       
    2      0.214785   0.189345   0.912013   0.876879       
    3      0.187506   0.183517   0.907994   0.873609       
    4      0.168405   0.165564   0.916091   0.890073       
    5      0.151413   0.151869   0.925216   0.902037       
    6      0.136981   0.147264   0.925288   0.903854       
    7      0.127521   0.143222   0.932      0.909226       



[array([0.14322]), 0.9319995036112451, 0.9092264362944313]

In [None]:
learn.save('128urn-tmp')

In [None]:
learn.load('128urn-tmp')

In [None]:
learn.unfreeze()
learn.bn_freeze(True)

In [None]:
learn.fit(lrs/4, 1, wds=wd, cycle_len=20,use_clr=(20,10))

In [None]:
learn.save('128urn-0')

In [None]:
learn.load('128urn-0')

In [None]:
x,y = next(iter(md.val_dl))
py = to_np(learn.model(V(x)))

In [None]:
show_img(py[0]>0);

In [None]:
show_img(y[0]);

In [None]:
m.close()

## 512x512

In [None]:
sz=512
bs=16

In [None]:
# Rebuild transforms, datasets and the data object for the new `sz` / `bs`.
tfms = tfms_from_model(resnet34, sz, crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)
# Named `dsets` so the torchvision `datasets` module imported in the import
# cell is not overwritten by this tuple of datasets.
dsets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms, path=PATH)
md = ImageData(PATH, dsets, bs, num_workers=4, classes=None)
denorm = md.trn_ds.denorm

In [None]:
# NOTE(review): `get_base` only appears inside the commented-out cell near the
# top and `UnetModel` is not defined anywhere in the visible notebook, so this
# cell cannot run from a fresh kernel unless another import provides them.
# Assigning to `models` also replaces the torchvision `models` module imported
# in the import cell.
m_base = get_base()
m = to_gpu(Unet34(m_base))
models = UnetModel(m)

In [None]:
# Learner for this resolution: Adam, BCE-with-logits loss, thresholded
# accuracy and dice as metrics.
# NOTE(review): the 128x128 cells build the learner with
# ConvLearner.from_model_data(m, md); this one passes the `models` wrapper
# from the previous cell instead -- confirm that wrapper exists.
learn = ConvLearner(md, models)
learn.opt_fn=optim.Adam
learn.crit=nn.BCEWithLogitsLoss()
learn.metrics=[accuracy_thresh(0.5),dice]

In [None]:
learn.freeze_to(1)

In [None]:
learn.load('128urn-0')

In [None]:
learn.fit(lr,1,wds=wd, cycle_len=5,use_clr=(5,5))

In [None]:
learn.save('512urn-tmp')

In [None]:
learn.unfreeze()
learn.bn_freeze(True)

In [None]:
learn.load('512urn-tmp')

In [None]:
learn.fit(lrs/4,1,wds=wd, cycle_len=8,use_clr=(20,8))

In [None]:
learn.save('512urn')

In [None]:
learn.load('512urn')

In [None]:
x,y = next(iter(md.val_dl))
py = to_np(learn.model(V(x)))

In [None]:
show_img(py[0]>0);

In [None]:
show_img(y[0]);

In [None]:
m.close()

## 1024x1024

In [None]:
sz=1024
bs=4

In [None]:
# Rebuild transforms, datasets and the data object for the new `sz` / `bs`.
# No `aug_tfms` are passed at this resolution.
tfms = tfms_from_model(resnet34, sz, crop_type=CropType.NO, tfm_y=TfmType.CLASS)
# Named `dsets` so the torchvision `datasets` module imported in the import
# cell is not overwritten by this tuple of datasets.
dsets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms, path=PATH)
md = ImageData(PATH, dsets, bs, num_workers=16, classes=None)
denorm = md.trn_ds.denorm

In [None]:
# NOTE(review): `get_base` only appears inside the commented-out cell near the
# top and `UnetModel` is not defined anywhere in the visible notebook, so this
# cell cannot run from a fresh kernel unless another import provides them.
# Assigning to `models` also replaces the torchvision `models` module imported
# in the import cell.
m_base = get_base()
m = to_gpu(Unet34(m_base))
models = UnetModel(m)

In [None]:
# Learner for this resolution: Adam, BCE-with-logits loss, thresholded
# accuracy and dice as metrics.
# NOTE(review): the 128x128 cells build the learner with
# ConvLearner.from_model_data(m, md); this one passes the `models` wrapper
# from the previous cell instead -- confirm that wrapper exists.
learn = ConvLearner(md, models)
learn.opt_fn=optim.Adam
learn.crit=nn.BCEWithLogitsLoss()
learn.metrics=[accuracy_thresh(0.5),dice]

In [None]:
learn.load('512urn')

In [None]:
learn.freeze_to(1)

In [None]:
learn.fit(lr,1, wds=wd, cycle_len=2,use_clr=(5,4))

In [None]:
learn.save('1024urn-tmp')

In [None]:
learn.load('1024urn-tmp')

In [None]:
learn.unfreeze()
learn.bn_freeze(True)

In [None]:
lrs = np.array([lr/200,lr/30,lr])

In [None]:
learn.fit(lrs/10,1, wds=wd,cycle_len=4,use_clr=(20,8))

In [None]:
learn.fit(lrs/10,1, wds=wd,cycle_len=4,use_clr=(20,8))

In [None]:
learn.sched.plot_loss()

In [None]:
learn.save('1024urn')

In [None]:
learn.load('1024urn')

In [None]:
x,y = next(iter(md.val_dl))
py = to_np(learn.model(V(x)))

In [None]:
show_img(py[0]>0);

In [None]:
show_img(y[0]);