In [1]:
#!pip install --user torch==1.9.0 torchvision==0.10.0 torchaudio==0.9.0 torchtext==0.10.0

In [2]:
#!pip install timm

### Imports

In [1]:
from fastai.vision.all import *
import tqdm

from timm import create_model
from fastai.vision.learner import _update_first_layer
import gc

In [3]:
# Dataset layout: images are read from src/<stage>, masks from src/masks,
# and the annotation table from src/<stage>.csv.
stage = 'train'
path = Path('src')
img_path, mask_path = path / stage, path / 'masks'

# Annotation table for the selected stage.
df = pd.read_csv(path / f'{stage}.csv')

In [4]:
# File names for the saved encoder and the two task learners.
# NOTE(review): these constants are not referenced by the visible cells,
# which build their file names from the task and architecture instead.
enc_path = 'models/resnet34_encoder.pkl'
class_learn_path, mask_learn_path = (
    'models/class_resnet34_learner.pth',
    'models/mask_resnet34_learner.pth',
)

In [5]:
# Show the last row of the annotation table as a quick schema check.
df.tail(1)

Unnamed: 0,id,annotation,width,height,cell_type,plate_time,sample_date,sample_id,elapsed_timedelta
73584,ffdb3cc02eef,249775 2 250477 6 251180 8 251882 11 252585 12 253288 14 253992 14 254695 16 255398 17 256102 17 256805 17 257509 17 258212 17 258917 16 259621 15 260326 13 261031 11 261736 9 262442 6,704,520,cort,11h59m00s,2020-11-01,cort[debris]_D9-3_Vessel-384_Ph_4,0 days 11:59:00


### Label functions for classification & segmentation

In [6]:
def class_label_func(fn):
    """Return the `cell_type` stored in `df` for the image file `fn`.

    The file is matched on its stem against the `id` column and the first
    matching `cell_type` value is returned. An `IndexError` is raised when
    no row matches.
    """
    matching_types = df.loc[df.id == fn.stem, 'cell_type']
    return matching_types.iloc[0]

In [10]:
def mask_label_func(fn):
    """Return the path of the same-named file inside `mask_path` for image `fn`."""
    return mask_path.joinpath(fn.name)

### Use the same train/val split throughout the switch training, to get reliable metrics.

In [7]:
# Hold out a fixed share of the distinct image ids for validation.
# - A dedicated seeded generator makes the split identical across kernel
#   restarts, so checkpoints reloaded later are validated on the same images.
# - Sampling without replacement guarantees the requested number of distinct
#   ids (drawing with replacement yields repeats and a smaller hold-out set).
# - The hold-out size is derived from the data rather than a fixed image count.
VALID_PCT = 0.2
SPLIT_SEED = 42

all_ids = df.id.unique()
val_ids = np.random.default_rng(SPLIT_SEED).choice(
    all_ids, size = int(len(all_ids) * VALID_PCT), replace = False
)

# Boolean flag per annotation row: True when the row's image is held out.
df['is_valid'] = df.id.isin(val_ids)

In [8]:
#is_val_fns = df.loc[df.is_valid == 1, 'file_name'].unique()
#is_train_fns = df.loc[df.is_valid == 0, 'file_name'].unique()

def is_valid(fn):
    """Tell whether the stem of file `fn` is one of the ids in `val_ids`."""
    stem = fn.stem
    return stem in val_ids

In [9]:
def MySplitter():
    """Build a splitter that partitions items with the `is_valid` predicate.

    The returned callable takes a collection exposing `argwhere` and returns
    `(train_idx, val_idx)`: the indices for which `is_valid` is false and
    true respectively.
    """
    def _split(o):
        return o.argwhere(is_valid, negate = True), o.argwhere(is_valid)
    return _split

### Training Parameters

In [12]:
# Tasks alternated during switch training; the order also fixes the index
# used for the per-task `learner_ex` flags.
TASKS = ['class', 'mask']

# Image sizes trained on, in order. None equals native image size of (520, 704).
SIZES = [
    (224, 303),
    (336, 455),
    (448, 606),
    None,
]

# When True, the data-loading helpers subsample the dataset.
DEBUG = True

# Batch-size multipliers, presumably one per entry in SIZES (alternative
# values kept for reference: [40, 18, 10, 6]).
BS_MULTS = [1] * 4

### Shrink dataset for debugging / prototyping

In [13]:
# Upper bound on the number of distinct image ids kept when DEBUG is on.
DEBUG_MAX_IDS = 1000

items = get_image_files(img_path)
print(len(items))

# Image files are matched to annotation rows through the `id` column (the
# file stem), the same key `class_label_func` uses; the table has no
# `file_name` column.
unique_files = df.id.unique()
if DEBUG:
    # Sample without replacement and never request more ids than exist.
    unique_files = np.random.choice(
        unique_files, size = min(DEBUG_MAX_IDS, len(unique_files)), replace = False
    )

# Set membership keeps the filter linear in the number of files.
keep_stems = set(unique_files)
items = L([fn for fn in items if fn.stem in keep_stems])
print(len(items))

# Load one sample image for inspection.
fn = items[0]
img_rgb = load_image(fn)

### Build Dataloaders

In [15]:
def get_all(path, n_debug = 500):
    """Collect the image files under `path`, subsampled when `DEBUG` is on.

    Args:
        path: Folder passed to `get_image_files`.
        n_debug: Maximum number of files returned while `DEBUG` is true.

    Returns:
        An `L` of image files. With `DEBUG` on and more than `n_debug` files
        available, a random subset of `n_debug` distinct files is returned;
        otherwise all files are returned.
    """
    items = get_image_files(path)
    if DEBUG and len(items) > n_debug:
        # Draw indices without replacement so no file is returned twice.
        keep = np.random.choice(len(items), size = n_debug, replace = False)
        items = [items[int(i)] for i in keep]
    return L(items)

In [61]:
def get_dls(src_path, task, size, bs_mult = 1):
    """Build the fastai dataloaders for one task at one image size.

    Args:
        src_path: Source handed to `DataBlock.dataloaders` (and from there to
            `get_all`).
        task: 'mask' for segmentation or 'class' for classification.
        size: Target size forwarded to `aug_transforms`.
        bs_mult: Multiplier applied to the base batch size (1 for 'mask',
            8 for 'class'); the product is truncated to an int.

    Returns:
        The dataloaders created from the configured `DataBlock`.

    Raises:
        ValueError: If `task` is neither 'mask' nor 'class'.
    """
    # Fail early with a clear message instead of an UnboundLocalError below.
    if task not in ('mask', 'class'):
        raise ValueError(f"Unknown task {task!r}; expected 'mask' or 'class'.")

    if task == 'mask':
        print('Build dataloaders for segmentation.')
        TaskBlock = (MaskBlock(codes = np.array(['bg', 'cell'])),)
        label_func = mask_label_func
        # Cast like the classification branch so fractional multipliers
        # cannot produce a non-integer batch size.
        bs = int(1 * bs_mult)
    else:
        print('Build dataloaders for classification.')
        TaskBlock = (CategoryBlock(),)
        label_func = class_label_func
        bs = int(8 * bs_mult)

    db = DataBlock(
        blocks = (ImageBlock(), *TaskBlock),
        get_items = get_all,
        get_y = label_func,
        # Same id-based train/validation split for both tasks.
        splitter = MySplitter(),
        batch_tfms = [*aug_transforms(size = size,
                                      flip_vert = True,
                                      max_rotate = 180.,
                                      max_warp = 0.
                                     ),
                      Normalize.from_stats(*imagenet_stats)
                     ],
        n_inp = 1
    )
    print(f'Using batch size of: {bs} and image size of: {size}')
    return db.dataloaders(src_path, bs = bs)

### Build Model

In [62]:
def create_timm_body(arch:str, pretrained=True, cut=None, n_in=3):
    """Creates a body from any model in the `timm` library.

    Args:
        arch: Architecture name passed to `timm.create_model`.
        pretrained: Whether `timm` should load pretrained weights.
        cut: Where to cut the model. None cuts at the last child that
            `has_pool_type` recognises, an int keeps the first `cut`
            children, and a callable receives the model and returns the body.
        n_in: Number of input channels, applied via `_update_first_layer`.

    Returns:
        The model body (an `nn.Sequential` unless `cut` is a callable).

    Raises:
        NameError: If `cut` is neither None, an int, nor a callable. The
            exception type matches what the former undefined `NamedError`
            reference produced, now with the intended message.
    """
    # Validate `cut` before building the model so a bad argument fails fast.
    if cut is not None and not isinstance(cut, int) and not callable(cut):
        raise NameError("cut must be either integer or function")

    model = create_model(arch, pretrained=pretrained, num_classes=0, global_pool='')
    _update_first_layer(model, n_in, pretrained)
    if cut is None:
        ll = list(enumerate(model.children()))
        cut = next(i for i,o in reversed(ll) if has_pool_type(o))
    if isinstance(cut, int):
        return nn.Sequential(*list(model.children())[:cut])
    return cut(model)

### Workaround to change the classification learner head from 9 classes (pretraining) to 3. 

In [63]:
### Create same model as used in pretraining
#body = create_timm_body('resnet34', pretrained = False)
#n_out = 9
#nf = num_features_model(nn.Sequential(*body.children()))
#head = create_head(nf, n_out, ps = 0.5)
#model = nn.Sequential(body, head)

### Create new learner from that model
#dls = get_dls(img_path, 'class', (224, 303), bs_mult = 1)
#learn = Learner(
#    dls,
#    model,
#    metrics = accuracy,
#    splitter = default_split,
#).to_fp16()

### Load weights from pretraining
#learn.unfreeze()
#learn.load('class_resnet34_learner')

### Substitute last layer by one with needed output features
#learn.model[-1][-1] = nn.Linear(in_features = 512, out_features = 3, bias = False)
#apply_init(model[-1][-1], nn.init.kaiming_normal_)

### Save model weights
#learn.save('class_resnet34_learner_sartorius', with_opt = False)

### Build Learner

In [64]:
def get_learner(task, dls, arch, learner_ex, encoder_ex = False, with_opt = False):
    """Create the fp16 learner for `task`, optionally restoring saved weights.

    Args:
        task: 'mask' (DynamicUnet with Dice/Jaccard metrics) or 'class'
            (body plus `create_head` head with accuracy).
        dls: Dataloaders for the task; also used to derive the output count
            and, for 'mask', the image size.
        arch: `timm` architecture name for the body.
        learner_ex: Per-task flags, indexed in `TASKS` order; a truthy entry
            loads the weights saved as `{task}_{arch}_learner`.
        encoder_ex: When truthy, loads `{arch}_encoder.pkl` into
            `learn.model[0]` after any learner weights.
        with_opt: Forwarded to `learn.load` to also restore optimizer state.

    Returns:
        The configured `Learner`.

    Raises:
        ValueError: If `task` is neither 'mask' nor 'class'.
    """
    # Reject unknown tasks before the body is built.
    if task not in ('mask', 'class'):
        raise ValueError(f"Unknown task {task!r}; expected 'mask' or 'class'.")

    body = create_timm_body(arch, pretrained = True)
    n_out = get_c(dls)

    if task == 'mask':
        img_size = dls.one_batch()[0].shape[-2:]
        model = models.unet.DynamicUnet(body, n_out, img_size, self_attention = True)
        metrics = [Dice, JaccardCoeff]
    else:
        nf = num_features_model(nn.Sequential(*body.children()))
        head = create_head(nf, n_out, ps = 0.5)
        model = nn.Sequential(body, head)
        apply_init(model[1], nn.init.kaiming_normal_)
        metrics = accuracy

    print(f'Create learner for task {task}')
    learn = Learner(
        dls,
        model,
        metrics = metrics,
        splitter = default_split
    ).to_fp16()

    learner_loaded = bool(learner_ex[TASKS.index(task)])
    if learner_loaded:
        print('Load pretrained model')
        learn.load(f'{task}_{arch}_learner', with_opt = with_opt)

    if encoder_ex:
        print('Load encoder from previous task')
        load_model(f'{arch}_encoder.pkl', learn.model[0], opt=None, with_opt=False)
    elif not learner_loaded:
        # Only report a fresh start when nothing was restored above.
        print('Start training from scratch.')

    return learn

## Switch Training

In [41]:
def switch_training(tasks, sizes, epochs, rounds, learner_ex, encoder_ex, lr = 1e-3, arch = 'resnet34', with_opt = False): #efficientnet_b3a
    """Alternate training between tasks while sharing one encoder.

    For every image size, and for `rounds` rounds per size, each task in
    `tasks` is trained in turn. After each task the encoder is saved to
    `{arch}_encoder.pkl` and the learner to `{task}_{arch}_learner`, so the
    next task (and the next size) resumes from them.

    Args:
        tasks: Task names to alternate between ('class' and/or 'mask').
        sizes: Image sizes to train on, in order; None means native size.
        epochs: Two-element sequence: epochs for 'class', then for any
            other task.
        rounds: Number of passes over `tasks` per image size.
        learner_ex: Per-task flags in `TASKS` order telling whether saved
            learner weights already exist. The caller's list is not modified.
        encoder_ex: Whether a saved encoder already exists.
        lr: Learning rate passed to `fine_tune`.
        arch: `timm` architecture name.
        with_opt: Whether the first learner load also restores optimizer
            state; it is switched on after the first save.

    Side effects:
        Writes the encoder and learner checkpoints described above and, for
        the 'mask' task at native size (`size is None`), exports the learner
        to 'mask_learner_final.pkl'.
    """
    # Work on a copy so the caller's flag list is left untouched.
    learner_ex = list(learner_ex)

    print('Multitask training with: ', arch)
    for size in sizes:
        for i in range(rounds):
            print(f'Start round no {i+1} with image size {size}')
            for task in tasks:
                eps = epochs[0] if task == 'class' else epochs[1]
                dls = get_dls(img_path, task, size, bs_mult = 1)
                learn = get_learner(task, dls, arch, learner_ex, encoder_ex, with_opt = with_opt)
                learn.fine_tune(eps, lr, cbs = CutMix(), freeze_epochs = 1)

                # Save encoder for all tasks
                encoder = get_model(learn.model)[0]
                save_model(f'{arch}_encoder.pkl', encoder, opt = None, with_opt=False)
                encoder_ex = True
                # Save learner for specific task
                learn.save(f'{task}_{arch}_learner')
                with_opt = True
                learner_ex[TASKS.index(task)] = True

                # Export only the segmentation learner trained at native
                # size; classification learners must not overwrite this file.
                if task == 'mask' and size is None:
                    learn.export('mask_learner_final.pkl')

                # Release the learner after every task, not just after
                # 'mask'. Collect Python garbage first so the CUDA cache can
                # actually be returned.
                print('Clean up')
                del encoder, learn, dls
                gc.collect()
                torch.cuda.empty_cache()
                print(f'Done with task: {task}')
            print(f'Finish round no: {i+1} with image size: {size}')

In [44]:
# Run the full schedule: one round per size in SIZES, 3 epochs for the
# classification task and 7 for the segmentation task, resuming from
# previously saved learner and encoder weights.
switch_training(TASKS, SIZES, epochs = [3, 7], rounds = 1, learner_ex = [True, True], encoder_ex = True)

Multitask training with:  resnet34
Start round no 1 with image size (224, 303)
Build dataloaders for classification.
Using batch size of: 8 and image size of: (224, 303)
Create learner for task class
Load pretrained model
Load encoder from previous task


epoch,train_loss,valid_loss,accuracy,time
0,1.30674,0.358232,0.82716,00:03


epoch,train_loss,valid_loss,accuracy,time
0,0.985211,0.319932,0.876543,00:04
1,0.894811,0.23717,0.925926,00:04
2,0.954591,0.281618,0.839506,00:04


Done with task: class
Build dataloaders for segmentation.
Using batch size of: 1 and image size of: (224, 303)
Create learner for task mask
Load pretrained model
Load encoder from previous task


epoch,train_loss,valid_loss,dice,jaccard_coeff,time
0,0.181521,0.238498,0.627143,0.456816,00:29


epoch,train_loss,valid_loss,dice,jaccard_coeff,time
0,0.122127,0.278615,0.580188,0.408637,00:33
1,0.141512,0.151199,0.689003,0.525556,00:33
2,0.118166,0.17858,0.626971,0.456633,00:33
3,0.129609,0.15981,0.686207,0.52231,00:33
4,0.102238,0.139839,0.746867,0.596,00:33
5,0.126051,0.149783,0.687869,0.524238,00:34
6,0.1159,0.193118,0.646882,0.478068,00:34


Clean up
Done with task: mask
Finish round no: 1 with image size: (224, 303)
Start round no 1 with image size (336, 455)
Build dataloaders for classification.
Using batch size of: 8 and image size of: (336, 455)
Create learner for task class
Load pretrained model
Load encoder from previous task


epoch,train_loss,valid_loss,accuracy,time
0,0.884488,0.073631,0.988506,00:05


epoch,train_loss,valid_loss,accuracy,time
0,0.837263,0.154621,0.942529,00:06
1,0.823856,0.102536,0.954023,00:06
2,0.799449,0.099561,0.954023,00:06


Done with task: class
Build dataloaders for segmentation.
Using batch size of: 1 and image size of: (336, 455)
Create learner for task mask
Load pretrained model
Load encoder from previous task


epoch,train_loss,valid_loss,dice,jaccard_coeff,time
0,0.132957,0.198514,0.649732,0.481188,00:46


epoch,train_loss,valid_loss,dice,jaccard_coeff,time
0,0.134799,0.167195,0.693514,0.530824,00:51
1,0.112444,0.216779,0.663613,0.496573,00:48
2,0.115242,0.367417,0.576122,0.404615,00:50
3,0.138206,0.190929,0.682216,0.517699,00:50
4,0.128169,0.177357,0.62766,0.457364,00:49
5,0.091084,0.142321,0.742219,0.590101,00:49
6,0.101161,0.146515,0.734802,0.58078,00:49


Clean up
Done with task: mask
Finish round no: 1 with image size: (336, 455)
Start round no 1 with image size (448, 606)
Build dataloaders for classification.
Using batch size of: 8 and image size of: (448, 606)
Create learner for task class
Load pretrained model
Load encoder from previous task


epoch,train_loss,valid_loss,accuracy,time
0,0.653128,0.040696,1.0,00:06


epoch,train_loss,valid_loss,accuracy,time
0,0.688476,0.059961,1.0,00:07
1,0.751672,0.020444,1.0,00:07
2,0.729517,0.034693,1.0,00:07


Done with task: class
Build dataloaders for segmentation.
Using batch size of: 1 and image size of: (448, 606)
Create learner for task mask
Load pretrained model
Load encoder from previous task


epoch,train_loss,valid_loss,dice,jaccard_coeff,time
0,0.13029,0.159508,0.734694,0.580645,01:04


epoch,train_loss,valid_loss,dice,jaccard_coeff,time
0,0.120728,0.127278,0.726634,0.570641,01:09
1,0.111604,0.12104,0.743338,0.591517,01:09
2,0.103367,0.122812,0.742345,0.590261,01:10
3,0.125611,0.124547,0.730576,0.575518,01:10
4,0.124807,0.119049,0.73837,0.585251,01:09
5,0.113568,0.138758,0.670115,0.50389,01:09
6,0.119942,0.117867,0.756083,0.607825,01:09


Clean up
Done with task: mask
Finish round no: 1 with image size: (448, 606)
Start round no 1 with image size None
Build dataloaders for classification.
Using batch size of: 8 and image size of: None
Create learner for task class
Load pretrained model
Load encoder from previous task


epoch,train_loss,valid_loss,accuracy,time
0,0.665901,0.038989,1.0,00:08


epoch,train_loss,valid_loss,accuracy,time
0,0.669046,0.024379,1.0,00:10
1,0.648728,0.06365,0.975904,00:10
2,0.607382,0.06739,0.975904,00:10


Done with task: class
Build dataloaders for segmentation.
Using batch size of: 1 and image size of: None
Create learner for task mask
Load pretrained model
Load encoder from previous task


epoch,train_loss,valid_loss,dice,jaccard_coeff,time
0,0.121166,0.125349,0.745365,0.594089,01:20


epoch,train_loss,valid_loss,dice,jaccard_coeff,time
0,0.101037,0.129411,0.76867,0.62426,01:23
1,0.10138,0.12036,0.76343,0.617377,01:22
2,0.127924,0.125561,0.771399,0.627867,01:22
3,0.103065,0.124629,0.738253,0.585104,01:22
4,0.100548,0.146188,0.676349,0.510972,01:22
5,0.097417,0.136004,0.762022,0.615538,01:22
6,0.111194,0.132303,0.760009,0.612915,01:22


Done with task: mask
Finish round no: 1 with image size: None


## Export learners

In [30]:
# Rebuild the classification learner from its saved weights (no separate
# encoder load) so it can be exported.
task, arch = 'class', 'resnet34'

dls = get_dls(img_path, task, (224, 303))
learn = get_learner(task, dls, arch, [True, True], False)

Build dataloaders for classification.
Using batch size of: 8 and image size of: (224, 303)
Create learner for task class
Load pretrained model
Start training from scratch.


In [31]:
# Export the current learner (the classification learner built in the
# previous cell) to 'class_final.pkl'.
learn.export('class_final.pkl')

In [32]:
# Rebuild the segmentation learner from its saved weights (no separate
# encoder load) so it can be exported.
task, arch = 'mask', 'resnet34'

dls = get_dls(img_path, task, (224, 303))
learn = get_learner(task, dls, arch, [True, True], False)

Build dataloaders for segmentation.
Using batch size of: 1 and image size of: (224, 303)
Create learner for task mask
Load pretrained model
Start training from scratch.


In [33]:
# Export the current learner (the segmentation learner built in the
# previous cell) to 'mask_final.pkl'.
learn.export('mask_final.pkl')