In [1]:
import sys
sys.path.append("../../fastai/")

In [2]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [3]:
from fastai.conv_learner import *
from fastai.dataset import *
from fastai.models.resnet import vgg_resnet50

import json

In [4]:
# Run all CUDA work in this notebook on GPU 0.
torch.cuda.set_device(0)

In [5]:
# Let cuDNN autotune its convolution algorithms for the input shapes it sees.
torch.backends.cudnn.benchmark=True

## Data

In [6]:
# CSV manifests for the train / valid splits; their 'name' column holds image file names.
PATH = Path('../CSV/')
train_FN, valid_FN = 'gain_train.csv', 'gain_valid.csv'
train_csv, valid_csv = (pd.read_csv(PATH/fn) for fn in (train_FN, valid_FN))

In [7]:
def show_img(im, figsize=None, ax=None, alpha=None):
    """Draw image `im` on a matplotlib axes with the axis frame hidden.

    If `ax` is falsy, a new figure/axes of size `figsize` is created.
    `alpha` is forwarded to `imshow`. Returns the axes that was drawn on.
    """
    target = ax
    if not target:
        _, target = plt.subplots(figsize=figsize)
    target.imshow(im, alpha=alpha)
    target.set_axis_off()
    return target

In [8]:
# Image and mask directories; mask file names are derived from image names further down.
TRAIN_DN, MASKS_DN = '../Data/CBIS-DDSM_classification_orient/', '../Data/masks_orient/'
# Training image size, batch size and data-loader worker count.
sz, bs, nw = 128, 64, 16

In [9]:
class MatchedFilesDataset(FilesDataset):
    """FilesDataset whose targets are images too: y[i] is the path of the mask paired with fnames[i]."""

    def __init__(self, fnames, y, transform, path):
        # y is stored before the base initializer runs, in case that initializer calls back
        # into subclass hooks.
        self.y = y
        assert len(fnames) == len(y)
        super().__init__(fnames, transform, path)

    def get_y(self, i):
        # NOTE(review): y[i] already carries its own '../Data/...' prefix, so joining it onto
        # self.path ('../CSV/') only resolves correctly through the '..' — fragile if PATH moves.
        mask_path = os.path.join(self.path, self.y[i])
        return open_image(mask_path)

    def get_c(self):
        # Presumably the class count fastai asks for; 0 because the targets are masks, not labels.
        return 0

In [10]:
x_names = np.array([Path(TRAIN_DN)/o for o in train_csv['name']]+[Path(TRAIN_DN)/o for o in valid_csv['name']])
y_names = np.array([Path(MASKS_DN)/o.replace('.j','_mask.j') for o in train_csv['name']]+[Path(MASKS_DN)/o.replace('.j','_mask.j') for o in valid_csv['name']])

In [12]:
y_names

array([PosixPath('../Data/masks_orient/Calc-Training_P_00858_LEFT_MLO_mask.jpg'),
       PosixPath('../Data/masks_orient/Calc-Training_P_00078_LEFT_MLO_mask.jpg'),
       PosixPath('../Data/masks_orient/Mass-Training_P_00279_LEFT_CC_mask.jpg'), ...,
       PosixPath('../Data/masks_orient/Mass-Training_P_00208_RIGHT_MLO_mask.jpg'),
       PosixPath('../Data/masks_orient/Mass-Training_P_01768_LEFT_CC_mask.jpg'),
       PosixPath('../Data/masks_orient/Mass-Training_P_00419_LEFT_CC_mask.jpg')], dtype=object)

In [13]:
# Hold out a random 25% of all images for validation. A fixed, local seed keeps the split
# identical across kernel restarts (provided the CSVs are unchanged), so checkpoints saved
# and reloaded below ('128urn-0', '512urn', ...) are always validated on held-out images.
SPLIT_SEED = 42
split_rng = np.random.RandomState(SPLIT_SEED)
val_idxs = split_rng.choice(len(x_names), len(x_names)//4, replace=False)
((val_x,trn_x),(val_y,trn_y)) = split_by_idx(val_idxs, x_names, y_names)

In [14]:
# Train-time augmentations. tfm_y=TfmType.CLASS is passed so that the target mask is
# transformed along with the image (fastai 0.7 convention — confirm per transform).
aug_tfms = [
    RandomRotate(4, tfm_y=TfmType.CLASS),
    RandomFlip(tfm_y=TfmType.CLASS),
    RandomLighting(0.05, 0.05, tfm_y=TfmType.CLASS),
]

In [15]:
# Image/mask pipeline at sz x sz with no cropping (CropType.NO), plus the augmentations above.
tfms = tfms_from_model(resnet34, sz, crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)
datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms, path=PATH)
# Worker count comes from the config cell (nw) rather than a second hard-coded literal.
md = ImageData(PATH, datasets, bs, num_workers=nw, classes=None)
denorm = md.trn_ds.denorm

In [None]:
# Sanity check: pull one training batch of images and masks.
x,y = next(iter(md.trn_dl))

In [None]:
# Inspect the image / mask batch shapes.
x.shape,y.shape

## Simple upsample

In [16]:
# Encoder architecture and its fastai model_meta entry: `cut` is where get_base() truncates
# the network; `lr_cut` is where the layer-group splits below divide the encoder.
f = resnet34
cut,lr_cut = model_meta[f]

In [17]:
def get_base():
    """Build a fresh encoder: `f(True)` truncated at `cut` via cut_model, wrapped in nn.Sequential.

    Presumably `f(True)` requests pretrained weights — confirm against the fastai model factory.
    """
    return nn.Sequential(*cut_model(f(True), cut))

In [18]:
def dice(pred, targs, thresh=0., eps=1e-7):
    """Dice coefficient between binarised predictions and a target mask.

    Args:
        pred: raw model outputs, binarised with `pred > thresh`. The models here are trained
            with BCEWithLogitsLoss, so outputs are logits and the default threshold 0
            corresponds to probability 0.5.
        targs: target mask, presumably 0/1 and the same shape as `pred`.
        thresh: binarisation threshold applied to `pred`.
        eps: smoothing added to numerator and denominator so that an empty prediction against
            an empty target scores 1 instead of dividing 0 by 0.

    Returns:
        (2*|P∩T| + eps) / (|P| + |T| + eps).
    """
    pred = (pred > thresh).float()
    intersection = (pred * targs).sum()
    total = (pred + targs).sum()
    return (2. * intersection + eps) / (total + eps)

In [19]:
class StdUpsample(nn.Module):
    """2x learned upsampling: stride-2, kernel-2 transposed conv, then ReLU, then BatchNorm."""

    def __init__(self, nin, nout):
        super().__init__()
        # Attribute names `conv` and `bn` are the parameter names in saved weights; keep them.
        self.conv = nn.ConvTranspose2d(nin, nout, 2, stride=2)
        self.bn = nn.BatchNorm2d(nout)

    def forward(self, x):
        upsampled = self.conv(x)
        activated = F.relu(upsampled)
        return self.bn(activated)

In [20]:
class Upsample34(nn.Module):
    """Encoder `rn` followed by a plain decoder: four StdUpsample stages and a final
    transposed conv down to one channel.

    Each of the five decoder stages is stride-2, so the encoder output is upsampled 2**5 times.
    Expects `rn` to produce 512 channels. `forward` returns channel 0, i.e. one map of logits
    per image.
    """

    def __init__(self, rn):
        super().__init__()
        # `rn` is registered both as `self.rn` (read by UpsampleModel.get_layer_groups) and as
        # features[0]; both names refer to the same module.
        self.rn = rn
        decoder = [StdUpsample(512, 256)] + [StdUpsample(256, 256) for _ in range(3)]
        self.features = nn.Sequential(rn, nn.ReLU(), *decoder, nn.ConvTranspose2d(256, 1, 2, stride=2))

    def forward(self, x):
        out = self.features(x)
        return out[:, 0]

In [21]:
class UpsampleModel():
    """fastai model wrapper: holds `model`/`name` and defines the layer groups that receive
    separate learning rates."""

    def __init__(self, model, name='upsample'):
        self.model = model
        self.name = name

    def get_layer_groups(self, precompute):
        # The encoder's children are split at lr_cut (presumably two groups). The decoder —
        # everything in `features` after the shared encoder at index 0 — is the last group.
        # `precompute` is unused.
        encoder_groups = list(split_by_idxs(children(self.model.rn), [lr_cut]))
        decoder_group = children(self.model.features)[1:]
        return encoder_groups + [decoder_group]

In [22]:
# Fresh encoder for the simple-upsample baseline.
m_base = get_base()

In [23]:
m = to_gpu(Upsample34(m_base))
models = UpsampleModel(m)

In [24]:
learn = ConvLearner(md, models)
learn.opt_fn=optim.Adam
# The model emits logits, so use the logits form of binary cross-entropy.
learn.crit=nn.BCEWithLogitsLoss()
# NOTE(review): dice binarises at logit 0 (p=0.5). If accuracy_thresh(0.5) thresholds the raw
# outputs, it cuts at logit 0.5 instead — confirm the two metrics use the same threshold.
learn.metrics=[accuracy_thresh(0.5),dice]

In [25]:
# Stage 1: keep the first layer group (encoder layers before lr_cut) frozen, so the single
# lr of 1e-2 mainly trains the new decoder. This matches the U-Net sections below. The full
# network is unfrozen for stage 2 after learn.load('tmp').
learn.freeze_to(1)

In [None]:
# Stage 1: 2 cycles of cycle_len=4 epochs at lr 1e-2 (8 epochs, per the progress bar below).
# NOTE(review): best_save_name 'unet_oxford' looks inherited from another notebook; this is the simple-upsample model.
learn.fit(1e-2,2, wds=1e-7, cycle_len=4,use_clr=(20,8),best_save_name='unet_oxford')

HBox(children=(IntProgress(value=0, description='Epoch', max=8, style=ProgressStyle(description_width='initial…

  0%|          | 0/29 [00:00<?, ?it/s]

In [None]:
# Checkpoint after stage 1, then reload it.
learn.save('tmp')

In [None]:
learn.load('tmp')

In [None]:
# Stage 2 setup: make every layer group trainable. bn_freeze(True) presumably keeps the
# BatchNorm layers fixed while fine-tuning.
learn.unfreeze()
learn.bn_freeze(True)

In [None]:
# Per-layer-group learning rates (encoder before lr_cut, encoder after lr_cut, decoder).
# They are defined here because no earlier cell defines `lrs`; the values match the
# U-Net section's.
lr = 4e-2
lrs = np.array([lr/100, lr/10, lr])
learn.fit(lrs,1,cycle_len=4,use_clr=(20,8))

In [None]:
learn.save('128')

In [None]:
# Predict one validation batch. Upsample34.forward drops the channel axis, so py holds one
# logit map per image.
x,y = next(iter(md.val_dl))
py = to_np(learn.model(V(x)))

In [None]:
# Predicted mask for the first image: logit > 0, i.e. probability > 0.5.
show_img(py[0]>0);

In [None]:
# Ground-truth mask for the same image.
show_img(y[0]);

## U-net (ish)

In [None]:
class SaveFeatures():
    """Forward hook that keeps the most recent output of module `m` in `self.features`.

    `features` stays None until the hooked module has run a forward pass.
    """
    features = None

    def __init__(self, m):
        self.hook = m.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        # Invoked after each forward pass of the hooked module.
        self.features = output

    def remove(self):
        """Detach the hook; `features` keeps whatever it last captured."""
        self.hook.remove()

In [None]:
class UnetBlock(nn.Module):
    """One U-Net decoder step.

    Upsamples `up_p` 2x with a stride-2 transposed conv, projects the skip features `x_p` with
    a 1x1 conv, concatenates both along the channel axis, then applies ReLU and BatchNorm.

    Args:
        up_in: channels of the tensor being upsampled.
        x_in: channels of the skip-connection features.
        n_out: channels of the output, split between the two branches.
    """

    def __init__(self, up_in, x_in, n_out):
        super().__init__()
        # The upsampling branch gets n_out//2 channels and the skip branch the remainder, so the
        # concatenation always has exactly n_out channels. This is identical to an even split
        # for even n_out; an odd n_out would otherwise yield n_out-1 channels and fail in
        # BatchNorm2d(n_out).
        up_out = n_out // 2
        x_out = n_out - up_out
        self.x_conv  = nn.Conv2d(x_in,  x_out,  1)
        self.tr_conv = nn.ConvTranspose2d(up_in, up_out, 2, stride=2)
        self.bn = nn.BatchNorm2d(n_out)

    def forward(self, up_p, x_p):
        # up_p must be half the spatial size of x_p so the two branches align after upsampling.
        up_p = self.tr_conv(up_p)
        x_p = self.x_conv(x_p)
        cat_p = torch.cat([up_p,x_p], dim=1)
        return self.bn(F.relu(cat_p))

In [None]:
class Unet34(nn.Module):
    """U-Net-style decoder on top of the truncated encoder `rn`.

    Forward hooks on rn[2], rn[4], rn[5] and rn[6] capture skip-connection features while `rn`
    runs. The decoder then consumes them from deepest to shallowest. Returns channel 0 of the
    final map (logits). Call `close()` to remove the hooks.
    """

    def __init__(self, rn):
        super().__init__()
        self.rn = rn
        # SaveFeatures objects are not nn.Modules, so this plain list adds no parameters.
        self.sfs = [SaveFeatures(rn[idx]) for idx in (2, 4, 5, 6)]
        self.up1 = UnetBlock(512,256,256)
        self.up2 = UnetBlock(256,128,256)
        self.up3 = UnetBlock(256,64,256)
        self.up4 = UnetBlock(256,64,256)
        self.up5 = nn.ConvTranspose2d(256, 1, 2, stride=2)

    def forward(self, x):
        out = F.relu(self.rn(x))
        # up1..up4 pair with the saved features from sfs[3] (deepest) down to sfs[0].
        for block, saved in zip((self.up1, self.up2, self.up3, self.up4), self.sfs[::-1]):
            out = block(out, saved.features)
        return self.up5(out)[:, 0]

    def close(self):
        """Remove every forward hook registered in __init__."""
        for saved in self.sfs:
            saved.remove()

In [None]:
class UnetModel():
    """fastai model wrapper for Unet34: holds `model`/`name` and defines the layer groups that
    receive separate learning rates."""

    def __init__(self, model, name='unet'):
        self.model = model
        self.name = name

    def get_layer_groups(self, precompute):
        # The encoder's children are split at lr_cut. Every Unet34 child after `rn` (up1..up5)
        # forms the final group. `precompute` is unused.
        encoder_groups = list(split_by_idxs(children(self.model.rn), [lr_cut]))
        decoder_group = children(self.model)[1:]
        return encoder_groups + [decoder_group]

In [None]:
# Fresh encoder + U-Net decoder for the 128px run.
m_base = get_base()
m = to_gpu(Unet34(m_base))
models = UnetModel(m)

In [None]:
learn = ConvLearner(md, models)
learn.opt_fn=optim.Adam
learn.crit=nn.BCEWithLogitsLoss()
learn.metrics=[accuracy_thresh(0.5),dice]

In [None]:
learn.summary()

In [None]:
# Shapes of the hooked skip-connection features. They are None until a forward pass has run
# (presumably the one triggered by learn.summary() above).
[o.features.size() for o in m.sfs]

In [None]:
# Stage 1: keep the first layer group (encoder layers before lr_cut) frozen.
learn.freeze_to(1)

In [None]:
learn.lr_find()
learn.sched.plot()

In [None]:
# Base learning rate (presumably read off the plot above) and weight decay.
lr=4e-2
wd=1e-7

# Per-layer-group learning rates: encoder before lr_cut, encoder after lr_cut, decoder.
lrs = np.array([lr/100,lr/10,lr])

In [None]:
# Stage 1 at a quarter of the chosen lr (lr/4 == 1e-2) with the shared weight decay (wd == 1e-7).
learn.fit(lr/4, 1, wds=wd, cycle_len=8, use_clr=(5, 8))

In [None]:
# Checkpoint stage 1 of the 128px U-Net, then reload it.
learn.save('128urn-tmp')

In [None]:
learn.load('128urn-tmp')

In [None]:
# Stage 2 setup: every layer group trainable; bn_freeze(True) presumably keeps BatchNorm fixed.
learn.unfreeze()
learn.bn_freeze(True)

In [None]:
# Stage 2: discriminative learning rates at a quarter of the base values.
learn.fit(lrs/4, 1, wds=wd, cycle_len=20,use_clr=(20,10))

In [None]:
# Final 128px weights; the 512px section starts from this checkpoint.
learn.save('128urn-0')

In [None]:
learn.load('128urn-0')

In [None]:
# Predict one validation batch (one logit map per image).
x,y = next(iter(md.val_dl))
py = to_np(learn.model(V(x)))

In [None]:
# Predicted mask for the first image (logit > 0).
show_img(py[0]>0);

In [None]:
# Ground-truth mask for the same image.
show_img(y[0]);

In [None]:
# Remove the forward hooks before building the next model.
m.close()

## 512x512

In [None]:
# 512px stage: larger images, so a smaller batch.
sz=512
bs=16

In [None]:
tfms = tfms_from_model(resnet34, sz, crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)
datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms, path=PATH)
# NOTE(review): 4 workers here versus nw (16) in the other sections — confirm this is intentional.
md = ImageData(PATH, datasets, bs, num_workers=4, classes=None)
denorm = md.trn_ds.denorm

In [None]:
# Fresh U-Net for 512px; the 128px weights are loaded into it below.
m_base = get_base()
m = to_gpu(Unet34(m_base))
models = UnetModel(m)

In [None]:
learn = ConvLearner(md, models)
learn.opt_fn=optim.Adam
learn.crit=nn.BCEWithLogitsLoss()
learn.metrics=[accuracy_thresh(0.5),dice]

In [None]:
# Stage 1 at 512px: first layer group frozen.
learn.freeze_to(1)

In [None]:
# Start from the 128px weights. Unet34's decoder has only conv/BN layers, which is presumably
# why weights trained at 128px load unchanged at 512px.
learn.load('128urn-0')

In [None]:
learn.fit(lr,1,wds=wd, cycle_len=5,use_clr=(5,5))

In [None]:
learn.save('512urn-tmp')

In [None]:
# Stage 2 setup: every layer group trainable; bn_freeze(True) presumably keeps BatchNorm fixed.
learn.unfreeze()
learn.bn_freeze(True)

In [None]:
learn.load('512urn-tmp')

In [None]:
# Stage 2: discriminative learning rates at a quarter of the base values.
learn.fit(lrs/4,1,wds=wd, cycle_len=8,use_clr=(20,8))

In [None]:
# Final 512px weights; the 1024px section starts from this checkpoint.
learn.save('512urn')

In [None]:
learn.load('512urn')

In [None]:
# Predict one validation batch (one logit map per image).
x,y = next(iter(md.val_dl))
py = to_np(learn.model(V(x)))

In [None]:
# Predicted mask for the first image (logit > 0).
show_img(py[0]>0);

In [None]:
# Ground-truth mask for the same image.
show_img(y[0]);

In [None]:
# Remove the forward hooks before building the next model.
m.close()

## 1024x1024

In [None]:
sz=1024
bs=4

In [None]:
tfms = tfms_from_model(resnet34, sz, crop_type=CropType.NO, tfm_y=TfmType.CLASS)
datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms, path=PATH)
md = ImageData(PATH, datasets, bs, num_workers=16, classes=None)
denorm = md.trn_ds.denorm

In [None]:
# Fresh U-Net for 1024px; the 512px weights are loaded into it below.
m_base = get_base()
m = to_gpu(Unet34(m_base))
models = UnetModel(m)

In [None]:
learn = ConvLearner(md, models)
learn.opt_fn=optim.Adam
learn.crit=nn.BCEWithLogitsLoss()
learn.metrics=[accuracy_thresh(0.5),dice]

In [None]:
# Start from the 512px weights.
learn.load('512urn')

In [None]:
# Stage 1 at 1024px: first layer group frozen.
learn.freeze_to(1)

In [None]:
learn.fit(lr,1, wds=wd, cycle_len=2,use_clr=(5,4))

In [None]:
learn.save('1024urn-tmp')

In [None]:
learn.load('1024urn-tmp')

In [None]:
# Stage 2 setup: every layer group trainable; bn_freeze(True) presumably keeps BatchNorm fixed.
learn.unfreeze()
learn.bn_freeze(True)

In [None]:
# Wider spread between layer-group rates for this stage; this overwrites the global `lrs`.
lrs = np.array([lr/200,lr/30,lr])

In [None]:
learn.fit(lrs/10,1, wds=wd,cycle_len=4,use_clr=(20,8))

In [None]:
# NOTE(review): identical to the previous cell — a second 4-epoch cycle; confirm it is deliberate.
learn.fit(lrs/10,1, wds=wd,cycle_len=4,use_clr=(20,8))

In [None]:
learn.sched.plot_loss()

In [None]:
learn.save('1024urn')

In [None]:
learn.load('1024urn')
In [None]:
x,y = next(iter(md.val_dl))
py = to_np(learn.model(V(x)))

In [None]:
show_img(py[0]>0);

In [None]:
show_img(y[0]);