In [1]:
#!pip install --user torch==1.9.0 torchvision==0.10.0 torchaudio==0.9.0 torchtext==0.10.0

In [None]:
#!pip install timm

### Imports

In [1]:
from fastai.vision.all import *
import tqdm


from timm import create_model
from fastai.vision.learner import _update_first_layer
import gc

In [2]:
path = Path('src/LIVECell_dataset_2021')
img_path = path / 'images/png'
mask_path = path / 'images/masks'

df = pd.read_csv(path / 'livecell_train_val.csv', index_col = 0)

  mask |= (ar1 == a)


In [3]:
df.tail(1)

Unnamed: 0,id,file_name,width,height,segmentation,area,bbox,iscrowd,is_valid,cell_type
1200185,1617270,SHSY5Y_Phase_B10_2_00d16h00m_2,704,520,"[[334.61, 146.01, 333.52, 147.11, 332.26, 148.05, 331.32, 148.84, 329.75, 149.31, 327.87, 150.41, 326.14, 151.5, 323.94, 152.92, 321.9, 154.64, 319.86, 156.05, 319.71, 156.21, 319.71, 156.21, 319.24, 155.11, 318.29, 153.7, 317.51, 150.56, 317.35, 146.95, 316.73, 143.5, 316.88, 141.15, 317.35, 138.01, 318.45, 135.81, 320.02, 133.62, 321.43, 130.32, 321.9, 127.03, 322.22, 121.85, 323.0, 114.79, 323.0, 106.79, 324.57, 108.2, 325.98, 111.34, 327.08, 114.0, 326.92, 116.2, 326.61, 118.55, 326.3, 120.91, 326.14, 123.73, 326.14, 126.56, 326.3, 128.75, 327.4, 130.32, 328.65, 132.36, 330.22, 132.36,...",809.03835,"[316.73, 106.79, 50.20999999999998, 49.42]",0,True,SHSY5Y


### Label functions for classification & segmentation

In [4]:
def mask_label_func(fn):
    return mask_path / fn.name

In [5]:
def class_label_func(fn):
    return re.search('(.*)_Phase.*',fn.stem).group(1)

### Use the same train/val split troughout the switch training, to get reliable metrics.

In [None]:
df['is_valid'] = False

val_ids = np.random.choice(df.id.unique(), size = int(606 * 0.2))
df.loc[df.id.isin(val_ids), 'is_valid'] = True

In [None]:
#is_val_fns = df.loc[df.is_valid == 1, 'file_name'].unique()
#is_train_fns = df.loc[df.is_valid == 0, 'file_name'].unique()

def is_valid(fn):
    return fn.stem in val_ids

In [None]:
def MySplitter():
    def _inner(o):
        train_idx = o.argwhere(is_valid, negate = True)
        val_idx = o.argwhere(is_valid)    
        return train_idx, val_idx
    return _inner

### Training Parameters

In [7]:
TASKS = ['class', 'mask']
SIZES = [(224, 303), (336, 455), (448, 606)]
DEBUG = True

### Shrink dataset for debuging / prototyping

In [8]:
items = get_image_files(img_path)
print(len(items))
unique_files = df.file_name.unique()
if DEBUG:
    unique_files = np.random.choice(unique_files, 1000)
items = L([fn for fn in items if fn.stem in unique_files])
print(len(items))
fn = items[0]
img_rgb = load_image(fn)

4184
865


### Build Dataloaders

In [52]:
def get_all(path):
    items = get_image_files(path)
    if DEBUG:
        unique_files = np.random.choice(items, 500)
    return L(*unique_files)

In [54]:
def get_dls(src_path, task, size, bs_mult = 1):
    if task == 'mask':
        print('Build dataloaders for segmentation.')
        TaskBlock = (MaskBlock(codes = np.array(['bg', 'cell'])),)
        label_func = mask_label_func
        bs = 1 * bs_mult
        get_items = get_annotated
        
    if task == 'class':
        print('Build dataloaders for classification.')
        TaskBlock = (CategoryBlock(),)
        label_func = class_label_func
        bs = 4 * bs_mult
        get_items = get_all
        
    db = DataBlock(
        blocks = (ImageBlock(), *TaskBlock),
        get_items = get_items,
        get_y = label_func,
        splitter = RandomSplitter(),
        item_tfms = Resize((448, 606)),
        batch_tfms = [*aug_transforms(size = size, 
                                    flip_vert = True, 
                                    max_rotate = 180., 
                                    max_warp = 0.
                                   ),
                      Normalize.from_stats(*imagenet_stats)
                     ],
        n_inp = 1
    )
    print(f'Using batch size of: {bs} and image size of: {size}')
    return db.dataloaders(src_path, bs = bs)

### Build Model

In [66]:
def create_timm_body(arch:str, pretrained=True, cut=None, n_in=3):
    "Creates a body from any model in the `timm` library."
    model = create_model(arch, pretrained=pretrained, num_classes=0, global_pool='')
    _update_first_layer(model, n_in, pretrained)
    if cut is None:
        ll = list(enumerate(model.children()))
        cut = next(i for i,o in reversed(ll) if has_pool_type(o))
    if isinstance(cut, int): 
        return nn.Sequential(*list(model.children())[:cut])
    elif callable(cut): 
        return cut(model)
    else: 
        raise NamedError("cut must be either integer or function")

### Build Learner

In [122]:
def get_learner(task, dls, arch, learner_ex, encoder_ex = False, with_opt = False):

    body = create_timm_body(arch, pretrained = True)
    n_out = get_c(dls)

    if task == 'mask':
        img_size = dls.one_batch()[0].shape[-2:]
        model = models.unet.DynamicUnet(body, n_out, img_size, self_attention = True)
        metrics = [Dice, JaccardCoeff]

    if task == 'class':
        nf = num_features_model(nn.Sequential(*body.children()))
        head = create_head(nf, n_out, ps = 0.5)
        model = nn.Sequential(body, head)
        apply_init(model[1], nn.init.kaiming_normal_)
        metrics = accuracy

    print(f'Create learner for task {task}')
    learn = Learner(
        dls,
        model,
        metrics = metrics,
        splitter = default_split
    ).to_fp16()
    
    if learner_ex[TASKS.index(task)]:
        print(f'Load pretrained model')
        #learn = load_learner(f'{task}_{arch}_learner.pkl', cpu = False)
        #learn.dls = dls
        learn.load(f'{task}_{arch}_learner', with_opt = with_opt)
    
    if encoder_ex:
        print('Load encoder from previous task')
        load_model(f'{arch}_encoder.pkl', learn.model[0], opt=None, with_opt=False)
    else:
        print('Start training from scratch.')

    return learn

## Switch Test

In [125]:
def switch_training(tasks, sizes, epochs, rounds, learner_ex, encoder_ex, lr = 1e-3, arch = 'resnet34', with_opt = False): #efficientnet_b3a
    print('Multitask training with: ', arch)
    for k, size in enumerate(sizes):
        for i in range(rounds):
            print(f'Start round no {i+1} with image size {size}')
            for task in tasks:
                if task == 'class':
                    eps = epochs[0]
                else:
                    eps = epochs[1]
                dls = get_dls(img_path, task, size, bs_mult = 1)
                learn = get_learner(task, dls, arch, learner_ex, encoder_ex, with_opt = with_opt)
                learn.fine_tune(eps, lr, cbs = CutMix(), freeze_epochs = 1)

                # Save encoder for all tasks
                encoder = get_model(learn.model)[0]
                save_model(f'{arch}_encoder.pkl', encoder, opt = None, with_opt=False)
                encoder_ex = True
                # Save learner for specific task
                #learn.export(f'{task}_{arch}_learner.pkl')
                learn.save(f'{task}_{arch}_learner')
                with_opt = True
                learner_ex[TASKS.index(task)] = True

                #Cleanup
                if (size != None) & (task == 'mask'):
                    print('Clean up')
                    del(dls)
                    del(learn)
                    torch.cuda.empty_cache()
                    gc.collect()
                else:
                    learn.export('mask_learner_final.pkl')
                print(f'Done with task: {task}')
            print(f'Finish round no: {i+1} with image size: {size}')

In [126]:
switch_training(tasks = TASKS, sizes = SIZES[:2], rounds = 3, learner_ex = [False, False], encoder_ex = False, epochs = 1)

Multitask training with:  resnet34
Start round no: 0 with image size: (224, 303)
Build dataloaders for classification.
Using batch size of: 4 and image size of: (224, 303)
Build classifier.
Create new learner for task class
Start training from scratch.


epoch,train_loss,valid_loss,accuracy,time
0,3.334918,1.492302,0.43,00:04


epoch,train_loss,valid_loss,accuracy,time
0,2.707852,1.142073,0.58,00:05


Clean up
Done with task: class
Build dataloaders for segmentation.
Using batch size of: 1 and image size of: (224, 303)
Build Unet learner.
Create new learner for task mask
Load encoder from previous task


epoch,train_loss,valid_loss,dice,jaccard_coeff,time
0,1.022464,0.40543,0.697332,0.535311,00:24


epoch,train_loss,valid_loss,dice,jaccard_coeff,time
0,0.256063,0.315576,0.852543,0.742985,00:28


Clean up
Done with task: mask
Finish round no: 0 with image size: (224, 303)
Start round no: 1 with image size: (224, 303)
Build dataloaders for classification.
Using batch size of: 4 and image size of: (224, 303)
Load pretrained learner for task: class
Load encoder from previous task


epoch,train_loss,valid_loss,accuracy,time
0,2.449506,0.977844,0.61,00:04


epoch,train_loss,valid_loss,accuracy,time
0,2.329289,0.835243,0.72,00:05


Clean up
Done with task: class
Build dataloaders for segmentation.
Using batch size of: 1 and image size of: (224, 303)
Load pretrained learner for task: mask
Load encoder from previous task


epoch,train_loss,valid_loss,dice,jaccard_coeff,time
0,0.289198,0.275097,0.84202,0.727146,00:22


epoch,train_loss,valid_loss,dice,jaccard_coeff,time
0,0.203344,0.258822,0.860021,0.754418,00:26


Clean up
Done with task: mask
Finish round no: 1 with image size: (224, 303)
Start round no: 2 with image size: (224, 303)
Build dataloaders for classification.
Using batch size of: 4 and image size of: (224, 303)
Load pretrained learner for task: class
Load encoder from previous task


epoch,train_loss,valid_loss,accuracy,time
0,2.40583,0.752788,0.75,00:04


epoch,train_loss,valid_loss,accuracy,time
0,2.105946,0.621471,0.8,00:05


Clean up
Done with task: class
Build dataloaders for segmentation.
Using batch size of: 1 and image size of: (224, 303)
Load pretrained learner for task: mask
Load encoder from previous task


epoch,train_loss,valid_loss,dice,jaccard_coeff,time
0,0.235964,0.335457,0.802514,0.670166,00:23


epoch,train_loss,valid_loss,dice,jaccard_coeff,time
0,0.192223,0.583948,0.652633,0.484376,00:26


Clean up
Done with task: mask
Finish round no: 2 with image size: (224, 303)
Start round no: 0 with image size: (336, 455)
Build dataloaders for classification.
Using batch size of: 4 and image size of: (336, 455)
Load pretrained learner for task: class
Load encoder from previous task


epoch,train_loss,valid_loss,accuracy,time
0,2.038789,0.401373,0.89,00:06


epoch,train_loss,valid_loss,accuracy,time
0,2.040146,0.328276,0.93,00:08


Clean up
Done with task: class
Build dataloaders for segmentation.
Using batch size of: 1 and image size of: (336, 455)
Load pretrained learner for task: mask
Load encoder from previous task


epoch,train_loss,valid_loss,dice,jaccard_coeff,time
0,0.196726,0.219574,0.867222,0.765572,00:39


epoch,train_loss,valid_loss,dice,jaccard_coeff,time
0,0.164285,0.211647,0.858054,0.751396,00:44


Clean up
Done with task: mask
Finish round no: 0 with image size: (336, 455)
Start round no: 1 with image size: (336, 455)
Build dataloaders for classification.
Using batch size of: 4 and image size of: (336, 455)
Load pretrained learner for task: class
Load encoder from previous task


epoch,train_loss,valid_loss,accuracy,time
0,1.979987,0.257572,0.95,00:06


epoch,train_loss,valid_loss,accuracy,time
0,1.966719,0.293886,0.93,00:08


Clean up
Done with task: class
Build dataloaders for segmentation.
Using batch size of: 1 and image size of: (336, 455)
Load pretrained learner for task: mask
Load encoder from previous task


epoch,train_loss,valid_loss,dice,jaccard_coeff,time
0,0.19597,0.233474,0.87731,0.781436,00:40


epoch,train_loss,valid_loss,dice,jaccard_coeff,time
0,0.174975,0.193935,0.896892,0.813059,00:44


Clean up
Done with task: mask
Finish round no: 1 with image size: (336, 455)
Start round no: 2 with image size: (336, 455)
Build dataloaders for classification.
Using batch size of: 4 and image size of: (336, 455)
Load pretrained learner for task: class
Load encoder from previous task


epoch,train_loss,valid_loss,accuracy,time
0,1.89853,0.32261,0.93,00:06


epoch,train_loss,valid_loss,accuracy,time
0,1.824323,0.35955,0.92,00:08


Clean up
Done with task: class
Build dataloaders for segmentation.
Using batch size of: 1 and image size of: (336, 455)
Load pretrained learner for task: mask
Load encoder from previous task


epoch,train_loss,valid_loss,dice,jaccard_coeff,time
0,0.180745,0.187035,0.877678,0.78202,00:38


epoch,train_loss,valid_loss,dice,jaccard_coeff,time
0,0.16255,0.320526,0.858165,0.751566,00:43


Clean up
Done with task: mask
Finish round no: 2 with image size: (336, 455)
