In [None]:
import json
import shutil
import subprocess
from pathlib import Path

import torchvision
from torch.utils.data.dataloader import DataLoader
from torchvision.models.vgg import VGG11_BN_Weights
from tqdm import tqdm

from deepcell.cli.modules.create_dataset import construct_dataset, \
    VoteTallyingStrategy
from deepcell.datasets.channel import Channel
from deepcell.datasets.model_input import ModelInput
from deepcell.data_splitter import DataSplitter
from deepcell.datasets.roi_dataset import RoiDataset
from deepcell.models.classifier import Classifier
from deepcell.inference import inference

# Fetch data

In [None]:
# Fetch the ROI input images into data_cell_classifier_102323/ from three
# sources: the Learning mFISH data, the five SSF training-fold archives, and
# the ROIs that were in the SSF test set.
#
# Commands run through subprocess.run(check=True) rather than `!cmd`: IPython
# does not stop when a `!` command fails, so a failed download would silently
# leave the dataset incomplete. Here a failure raises CalledProcessError.
DATA_DIR = Path('data_cell_classifier_102323')
S3_BUCKET = 's3://dev.deepcell.alleninstitute.org'
SSF_TRAIN_PREFIX = f'{S3_BUCKET}/input_data/2023-06-05_05:06:06-551773/train'
N_SSF_FOLDS = 5

# exist_ok so that re-running the cell does not error on an existing dir
DATA_DIR.mkdir(parents=True, exist_ok=True)

# Copy learning mfish data
subprocess.run(
    ['aws', 's3', 'sync', f'{S3_BUCKET}/learning-mfish/', str(DATA_DIR)],
    check=True)

# Copy SSF data: one archive per fold, staged in tmp/. Extracted without
# tar's -v flag, which would print every extracted file into the notebook.
tmp_dir = Path('tmp')
tmp_dir.mkdir(parents=True, exist_ok=True)
for fold in range(N_SSF_FOLDS):
    archive = tmp_dir / f'{fold}.tar.gz'
    subprocess.run(
        ['aws', 's3', 'cp', f'{SSF_TRAIN_PREFIX}/{fold}/{fold}.tar.gz',
         str(archive)],
        check=True)
    subprocess.run(['tar', '-xzf', str(archive), '-C', str(DATA_DIR)],
                   check=True)
shutil.rmtree(tmp_dir)

# Copy data that was in the SSF test set
subprocess.run(
    ['aws', 's3', 'sync', f'{S3_BUCKET}/102623_existing_test_data/',
     str(DATA_DIR)],
    check=True)

# Fetch model weights

In [None]:
# Copy SSF model weights: one checkpoint per cross-validation fold. Each
# training-job archive unpacks to <fold>/<fold>_model.pt, which is then moved
# up to ssf_baseline_checkpoints/<fold>_model.pt.
#
# subprocess.run(check=True) replaces `!cmd` so that a failed download stops
# the notebook. With `!`, a failed `aws s3 cp` went unnoticed and the
# following `tar` could re-extract a stale model.tar.gz left in the working
# directory by the previous fold (or a previous run).
SSF_CHECKPOINT_DIR = Path('ssf_baseline_checkpoints')
SSF_TRAINING_JOBS = [
    'deepcell-train-fold-0-1685968030',
    'deepcell-train-fold-1-1685968593',
    'deepcell-train-fold-2-1685968728',
    'deepcell-train-fold-3-1685968867',
    'deepcell-train-fold-4-1685969004',
]

SSF_CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
model_archive = Path('model.tar.gz')
for job_name in SSF_TRAINING_JOBS:
    subprocess.run(
        ['aws', 's3', 'cp',
         f's3://dev.deepcell.alleninstitute.org/{job_name}/output/model.tar.gz',
         str(model_archive)],
        check=True)
    subprocess.run(
        ['tar', '-xzf', str(model_archive), '-C', str(SSF_CHECKPOINT_DIR)],
        check=True)
# Don't leave the last fold's archive around to be picked up by a later cell
model_archive.unlink()

# Flatten <fold>/<fold>_model.pt -> <fold>_model.pt (overwrites on re-run)
for fold in range(len(SSF_TRAINING_JOBS)):
    (SSF_CHECKPOINT_DIR / str(fold) / f'{fold}_model.pt').replace(
        SSF_CHECKPOINT_DIR / f'{fold}_model.pt')

In [None]:
# Copy Learning mFISH model weights: one checkpoint per cross-validation fold.
# Each training-job archive unpacks to <fold>/<fold>_model.pt, which is then
# moved up to learning_mfish_checkpoints/<fold>_model.pt.
#
# subprocess.run(check=True) replaces `!cmd` so that a failed download stops
# the notebook. With `!`, a failed `aws s3 cp` went unnoticed and the
# following `tar` could re-extract a stale model.tar.gz left in the working
# directory by the previous fold (or a previous cell/run).
LEARNING_MFISH_CHECKPOINT_DIR = Path('learning_mfish_checkpoints')
LEARNING_MFISH_TRAINING_JOBS = [
    'deepcell-train-fold-0-1698767254',
    'deepcell-train-fold-1-1698767276',
    'deepcell-train-fold-2-1698767299',
    'deepcell-train-fold-3-1698767322',
    'deepcell-train-fold-4-1698767344',
]

LEARNING_MFISH_CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
model_archive = Path('model.tar.gz')
for job_name in LEARNING_MFISH_TRAINING_JOBS:
    subprocess.run(
        ['aws', 's3', 'cp',
         f's3://dev.deepcell.alleninstitute.org/{job_name}/output/model.tar.gz',
         str(model_archive)],
        check=True)
    subprocess.run(
        ['tar', '-xzf', str(model_archive),
         '-C', str(LEARNING_MFISH_CHECKPOINT_DIR)],
        check=True)
# Don't leave the last fold's archive around to be picked up by a later cell
model_archive.unlink()

# Flatten <fold>/<fold>_model.pt -> <fold>_model.pt (overwrites on re-run)
for fold in range(len(LEARNING_MFISH_TRAINING_JOBS)):
    (LEARNING_MFISH_CHECKPOINT_DIR / str(fold) / f'{fold}_model.pt').replace(
        LEARNING_MFISH_CHECKPOINT_DIR / f'{fold}_model.pt')

# Train model

In [None]:
# Labels live in the cell labeling app's database
# (cell-labeling-app-db-instance-1.cdxknown3r0n.us-west-2.rds.amazonaws.com).
# construct_dataset pulls them through the app's `/get_all_labels` endpoint,
# tallying annotator votes with the MAJORITY strategy.
CELL_LABELING_APP_HOST = 'ec2-34-211-120-165.us-west-2.compute.amazonaws.com'

labels = construct_dataset(
    cell_labeling_app_host=CELL_LABELING_APP_HOST,
    vote_tallying_strategy=VoteTallyingStrategy.MAJORITY,
)

In [None]:
def _get_model_inputs(
        data_dir='/home/ec2-user/SageMaker/data_cell_classifier_102323'):
    """Build one ``ModelInput`` per row of the global ``labels`` frame.

    Inputs are produced experiment by experiment: experiments in order of
    first appearance in ``labels``, rows in their original order within each
    experiment. Rows whose ``experiment_id`` is missing are skipped.

    Parameters
    ----------
    data_dir : str
        Directory holding the per-ROI ``correlation_*``, ``max_*`` and
        ``mask_*`` PNGs. The default is the absolute SageMaker path. The fetch
        cells write to ``data_cell_classifier_102323`` relative to the working
        directory, so the two agree only when the notebook runs from
        ``/home/ec2-user/SageMaker``.

    Returns
    -------
    list of ModelInput
    """
    model_inputs = []
    # A single groupby pass replaces re-filtering the whole frame once per
    # experiment, which was O(n_rows * n_experiments). sort=False keeps
    # experiments in first-appearance order, the same order that
    # labels['experiment_id'].unique() yields.
    by_experiment = labels.groupby('experiment_id', sort=False)
    for exp_id, exp_labels in tqdm(by_experiment, total=by_experiment.ngroups):
        for row in exp_labels.itertuples(index=False):
            model_inputs.append(ModelInput(
                experiment_id=row.experiment_id,
                project_name=row.job_name,
                roi_id=row.roi_id,
                channel_path_map={
                    Channel.CORRELATION_PROJECTION: f'{data_dir}/correlation_{exp_id}_{row.roi_id}.png',
                    Channel.MAX_PROJECTION: f'{data_dir}/max_{exp_id}_{row.roi_id}.png',
                    Channel.MASK: f'{data_dir}/mask_{exp_id}_{row.roi_id}.png',
                },
                channel_order=[Channel.CORRELATION_PROJECTION, Channel.MAX_PROJECTION, Channel.MASK],
                label=row.label,
            ))
    return model_inputs


model_inputs = _get_model_inputs()

## Exclude ROIs that overlap the motion border

In [None]:
# Due to a bug, the motion border info was not correctly calculated.
# It was recalculated outside this notebook
# Incorporate this info by using the 2 files roi_meta.json (SSF) and roi_meta_learningmfish.json (lmf)

!aws s3 cp s3://dev.deepcell.alleninstitute.org/roi_meta.json .
!aws s3 cp s3://dev.deepcell.alleninstitute.org/roi_meta_learningmfish.json .


def _add_motion_border_info():
    with open('roi_meta.json') as f:
        meta_ssf = json.load(f)
    with open('roi_meta_learningmfish.json') as f:
        meta_lmf = json.load(f)
    
    for m in model_inputs:
        if m.project_name == 'SSF':
            m.overlaps_motion_border = meta_ssf[f'{m.experiment_id}_{m.roi_id}']['overlaps_motion_border']
        else:
            m.overlaps_motion_border = meta_lmf[f'{m.experiment_id}_{m.roi_id}']['overlaps_motion_border']
_add_motion_border_info()

In [None]:
# Number of ROIs before excluding those that overlap the motion border
n_rois_before_motion_filter = len(model_inputs)
n_rois_before_motion_filter

In [None]:
# Keep only ROIs that do not overlap the motion border; show how many remain
model_inputs = list(filter(lambda m: not m.overlaps_motion_border, model_inputs))
len(model_inputs)

## Train/test split

In [None]:
# Train and test transforms share the crop size and the normalization
# statistics bundled with the pretrained VGG11-BN weights; only the is_train
# flag differs between them.
vgg11_bn_preprocess = VGG11_BN_Weights.DEFAULT.transforms()
shared_transform_kwargs = {
    'crop_size': (128, 128),
    'means': vgg11_bn_preprocess.mean,
    'stds': vgg11_bn_preprocess.std,
}

splitter = DataSplitter(
    model_inputs=model_inputs,
    train_transform=RoiDataset.get_default_transforms(
        is_train=True, **shared_transform_kwargs),
    test_transform=RoiDataset.get_default_transforms(
        is_train=False, **shared_transform_kwargs),
    seed=1234,
)

In [None]:
# Fetch a file exp_meta.json which has metadata such as experiment depth,
# used for train/test split

!aws s3 cp s3://dev.deepcell.alleninstitute.org/exp_meta.json .

def _get_train_test_split():
    with open('exp_meta.json') as f:
        exp_meta = json.load(f)
    lmf_inputs = [x for x in model_inputs if x.project_name == 'LearningMFish']
    lmf_experiments = set([int(x.experiment_id) for x in lmf_inputs])
    exp_meta = [x for x in exp_meta if x['experiment_id'] in lmf_experiments]
    
    # only include learning mfish data in test
    train, test = splitter.get_train_test_split(
        test_size=0.3,
        full_dataset=lmf_inputs,
        exp_metas=exp_meta,
        n_depth_bins=4
    )
    
    return train, test

train, test = _get_train_test_split()

In [None]:
# Sizes of the train and test splits
n_train, n_test = len(train), len(test)
n_train, n_test

In [None]:
# Save the Learning mFISH part of the training split, serialized via
# to_dict(), so the model can be trained outside this notebook.
with open('train_model_inputs_103023_only_learning_mfish.json', 'w') as fh:
    lmf_train_dicts = [
        model_input.to_dict()
        for model_input in train.model_inputs
        if model_input.project_name == 'LearningMFish'
    ]
    fh.write(json.dumps(lmf_train_dicts, indent=2))

Train the model either with `train.ipynb` or with `python -m deepcell.cli.modules.cloud.train --input_json <input_json>`.

An example input JSON for the CLI command is at `s3://dev.deepcell.alleninstitute.org/cloud_train.json`.

The rest of this notebook assumes the model has already been trained, and evaluates the Learning mFISH checkpoints downloaded above into `learning_mfish_checkpoints`.

# Baseline performance

In [None]:
def _get_baseline_perf(which='test'):
    model = getattr(
        torchvision.models, 'vgg11_bn')(
            pretrained=True,
            progress=False)
    
    model = Classifier(
        model=model,
        truncate_to_layer=22,
        classifier_cfg=[]
    )
    mi = test.model_inputs if which == 'test' else \
        [x for x in train.model_inputs if x.project_name == 'LearningMFish']
    ds = RoiDataset(
        model_inputs=mi,
        transform=RoiDataset.get_default_transforms(
            is_train=False, 
            crop_size=(128, 128),
            means=VGG11_BN_Weights.DEFAULT.transforms().mean, 
            stds=VGG11_BN_Weights.DEFAULT.transforms().std
        )
    )
    test_dataloader = DataLoader(
        dataset=ds, 
        shuffle=False,
        batch_size=64
    )
    metrics, res = inference(
        model=model,
        test_loader=test_dataloader,
        threshold=0.5,
        has_labels=True,
        checkpoint_path='ssf_baseline_checkpoints'
    )
    return metrics, res

In [None]:
# Evaluate the SSF baseline checkpoints on the (Learning mFISH-only) test split
baseline_metrics, baseline_preds = _get_baseline_perf('test')

In [None]:
def calc_accuracy_metrics(metrics):
    """Compute precision, recall and F1 from confusion-matrix counts.

    Parameters
    ----------
    metrics : object
        Anything exposing ``TP``, ``FP`` and ``FN`` attributes, such as the
        metrics object returned by ``inference``. The counts are assumed to be
        scalars.

    Returns
    -------
    tuple
        ``(precision, recall, f1)``. A ratio whose denominator is zero (no
        predicted positives, no actual positives, or precision + recall == 0)
        is reported as 0.0 instead of raising ZeroDivisionError. This matches
        scikit-learn's ``zero_division=0`` convention.
    """
    def _safe_div(num, den):
        # Zero denominator -> 0.0 rather than an exception
        return num / den if den else 0.0

    p = _safe_div(metrics.TP, metrics.TP + metrics.FP)
    r = _safe_div(metrics.TP, metrics.TP + metrics.FN)
    f1 = _safe_div(2 * p * r, p + r)
    return p, r, f1

In [None]:
# Baseline (SSF checkpoints) precision, recall, F1 on the test split
baseline_precision, baseline_recall, baseline_f1 = calc_accuracy_metrics(baseline_metrics)
baseline_precision, baseline_recall, baseline_f1

# Trained model performance

In [None]:
def _calc_test_set_perf(model_inputs=None, checkpoints_path='learning_mfish_checkpoints'):
    """Run ``inference`` with the checkpoints in ``checkpoints_path`` on
    ``model_inputs``.

    Parameters
    ----------
    model_inputs : list of ModelInput, optional
        ROIs to evaluate. ``None`` (the default) means ``test.model_inputs``,
        looked up when the function is called.
    checkpoints_path : str
        Checkpoint directory handed to ``inference``. Defaults to the Learning
        mFISH fold checkpoints fetched earlier in this notebook.

    Returns
    -------
    tuple
        ``(metrics, res)`` exactly as returned by ``inference``.
    """
    # Resolve the default at call time. A signature default of
    # `test.model_inputs` is evaluated once, when this cell runs, and goes
    # stale if the split cell is re-run afterwards.
    if model_inputs is None:
        model_inputs = test.model_inputs

    # `weights=` replaces the deprecated `pretrained=True`;
    # IMAGENET1K_V1 is what pretrained=True resolves to for vgg11_bn.
    backbone = torchvision.models.vgg11_bn(
        weights=VGG11_BN_Weights.IMAGENET1K_V1,
        progress=False)
    model = Classifier(
        model=backbone,
        truncate_to_layer=22,
        classifier_cfg=[]
    )

    preprocess = VGG11_BN_Weights.DEFAULT.transforms()
    ds = RoiDataset(
        model_inputs=model_inputs,
        transform=RoiDataset.get_default_transforms(
            is_train=False,
            crop_size=(128, 128),
            means=preprocess.mean,
            stds=preprocess.std
        )
    )
    test_dataloader = DataLoader(
        dataset=ds,
        shuffle=False,
        batch_size=64
    )
    metrics, res = inference(
        model=model,
        test_loader=test_dataloader,
        threshold=0.5,
        has_labels=True,
        checkpoint_path=checkpoints_path
    )
    return metrics, res

In [None]:
# Pass the current test split and checkpoint dir explicitly rather than relying
# on the function's defaults, which may have been captured when its cell ran.
test_metrics, test_preds = _calc_test_set_perf(
    model_inputs=test.model_inputs,
    checkpoints_path='learning_mfish_checkpoints',
)

In [None]:
# Size of the test split
n_test_rois = len(test)
n_test_rois

In [None]:
# Number of test ROIs labeled 'cell'
sum(model_input.label == 'cell' for model_input in test.model_inputs)

In [None]:
# Learning mFISH checkpoints: precision, recall, F1 on the test split
test_precision, test_recall, test_f1 = calc_accuracy_metrics(test_metrics)
test_precision, test_recall, test_f1