# Appendix 3 - Data augmentation and resizing 3D images
**Estimated time to run through notebook is 20 minutes** 

This notebook shows how to
-  [Load libraries, predefine some functions, load the manifest, and make a dataset](#preprocessing)
-  [Configuring serotiny](#config)
-  [Resize images for 3D training](#train3D)
-  [Conclusion](#end)

#### Resources 
- Serotiny code: https://github.com/AllenCell/serotiny
- Serotiny documentation: https://allencell.github.io/serotiny
- Hydra for configuration: https://hydra.cc/
- MLflow for experiment tracking: https://mlflow.org/
- PyTorch Lightning for DL training/testing/prediction: https://pytorchlightning.ai/

## <a id='preprocessing'></a>Load libraries, predefine some functions, load the manifest, and make a dataset 


### Load libraries and predefined functions

In [None]:
from upath import UPath as Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import nbvv

from serotiny.io.image import image_loader
from cytodata_aics.io_utils import rescale_image

### Load the manifest and explore dimensions

In [None]:
df = pd.read_parquet("s3://allencell-cytodata-variance-data/processed/hackathon_manifest_092022.parquet")
print(f'Number of cells: {len(df)}')
print(f'Number of columns: {len(df.columns)}')

### Make a simple dataset of edge vs. non-edge cells

In [None]:
from serotiny.transforms.dataframe.transforms import split_dataframe

Path("/home/aicsuser/serotiny_data/").mkdir(parents=True, exist_ok=True)

n = 1000 # number of cells per class
cells_edgeVSnoedge = df.groupby("edge_flag").sample(n)

# Add the train, test and validate split
cells_edgeVSnoedge = split_dataframe(dataframe=cells_edgeVSnoedge, train_frac=0.7, val_frac=0.2, return_splits=False)

cells_edgeVSnoedge.to_csv("/home/aicsuser/serotiny_data/cells_edgeVSnoedge.csv")  # same path as in the `data` config below
print(f"Number of cells: {len(cells_edgeVSnoedge)}")
print(f"Number of columns: {len(cells_edgeVSnoedge.columns)}")
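The cell above does two things: it draws `n` cells from each `edge_flag` class so the classes are balanced, then labels every row as train, validation or test. The sketch below reproduces both steps on a small synthetic table. `toy` and `assign_splits` are hypothetical names; `assign_splits` only illustrates a random 70/20/10 split and is not serotiny's `split_dataframe`, which may round or order rows differently.

```python
import numpy as np
import pandas as pd

# Hypothetical toy manifest standing in for `df`: 10 edge cells and 30 non-edge cells
toy = pd.DataFrame({
    "CellId": np.arange(40),
    "edge_flag": [1] * 10 + [0] * 30,
})

# Draw the same number of cells from each class, like `groupby("edge_flag").sample(n)`
n = 10
balanced = toy.groupby("edge_flag").sample(n, random_state=0)


def assign_splits(frame, train_frac=0.7, val_frac=0.2, seed=0):
    """Hypothetical helper: randomly label rows 'train' / 'valid' / 'test'.

    Illustration only; serotiny's `split_dataframe` may differ in its details.
    """
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(frame))
    n_train = int(round(train_frac * len(frame)))
    n_val = int(round(val_frac * len(frame)))
    labels = np.empty(len(frame), dtype=object)
    labels[order[:n_train]] = "train"
    labels[order[n_train:n_train + n_val]] = "valid"
    labels[order[n_train + n_val:]] = "test"  # whatever remains goes to test
    return frame.assign(split=labels)


balanced = assign_splits(balanced)
print(balanced["edge_flag"].value_counts().to_dict())
print(balanced["split"].value_counts().to_dict())
```

Passing `random_state` makes the class sampling reproducible, which the original cell does not do; with 20 cells the split works out to 14/4/2.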

## <a id='config'></a>Configuring serotiny

In [None]:
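import os
from datetime import datetime

# `serotiny` expects to be run from the project root
os.chdir("/home/aicsuser/cytodata-hackathon-base")
# timestamp string used to give each run a unique name
def now_str():
    return datetime.now().strftime("%Y%m%d_%H%M%S")

run_name = f"memdna_zproj_{now_str()}"
print(run_name)

!serotiny train \
    model=example2_classifier_2d \
    data=example2_dataloader_2d \
    mlflow.experiment_name=cytodata_chapter5 \
    mlflow.run_name={run_name} \
    trainer.gpus=[0] \
    trainer.max_epochs=10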

As mentioned in Chapter 5, `serotiny` requires you to configure five modules.
Below are examples of how some of them are configured, but first it is worth discussing
the syntax used for these configurations.

`serotiny` uses `hydra` as a configuration framework. In `hydra`, configs are written in YAML and they
can use a special syntax to represent the instantiation of classes (and partial functions), or the invocation of functions.

For example, suppose we have a class `SomeClass` inside the module `some_class` of a package `some_package`. To instantiate
this class within a `hydra` config, one would write:

---

   
```yaml
_target_: some_package.some_class.SomeClass  # this is the "path" to the class.

# assuming this class takes `param1` and `param2` as arguments
param1: a
param2: b
```

---

As you may have guessed, if your class takes an instance of another class as a parameter, you can nest instantiations:

---

```yaml
_target_: some_package.some_class.SomeClass  # this is the "path" to the class.

# assuming this class takes `param1` and `param2` as arguments
param1: a
param2: b

# assuming this class also takes `param3` as an argument, and that it should be an instantiation
# of a class some_package.another_class.AnotherClass
param3:
  _target_: some_package.another_class.AnotherClass
  arg1: 1
  arg2: 2
```

---

Calling `hydra.utils.instantiate` on the config object obtained by reading this YAML instantiates the class, including any nested `_target_` entries.
This is the main mechanism `serotiny` uses to build the objects it needs to train, test, and make predictions with a model.
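To make the mechanism concrete, here is a toy re-implementation of `_target_` resolution in plain Python. It is a minimal sketch, not hydra's code: `toy_instantiate` and `_locate` are hypothetical names, it only handles `_target_`, `_args_` (positional arguments), `_partial_` and nested dicts/lists, and it skips everything else hydra supports (interpolation, `_recursive_`, `_convert_`, attribute paths such as `Class.method`). Standard-library targets stand in for `some_package.some_class.SomeClass`.

```python
import functools
import importlib


def _locate(path):
    """Import `package.module.Name` and return the `Name` object."""
    module_path, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_path), attr)


def toy_instantiate(node):
    """Minimal sketch of `_target_` resolution (NOT hydra's implementation)."""
    if isinstance(node, list):
        return [toy_instantiate(item) for item in node]
    if not isinstance(node, dict):
        return node  # plain values pass through unchanged
    if "_target_" not in node:
        return {key: toy_instantiate(value) for key, value in node.items()}

    target = _locate(node["_target_"])
    # nested configs are instantiated first, then passed to the outer target
    args = [toy_instantiate(a) for a in node.get("_args_", [])]
    kwargs = {
        key: toy_instantiate(value)
        for key, value in node.items()
        if key not in ("_target_", "_args_", "_partial_")
    }
    if node.get("_partial_", False):
        return functools.partial(target, *args, **kwargs)
    return target(*args, **kwargs)


# `_target_` plus keyword arguments, like `param1` / `param2` above
delta = toy_instantiate({"_target_": "datetime.timedelta", "hours": 1, "minutes": 30})
print(delta)  # 1:30:00

# Nested instantiation: two Fractions are built, then passed to `sum`
total = toy_instantiate({
    "_target_": "builtins.sum",
    "_args_": [[
        {"_target_": "fractions.Fraction", "numerator": 1, "denominator": 2},
        {"_target_": "fractions.Fraction", "numerator": 1, "denominator": 3},
    ]],
})
print(total)  # 5/6

# `_partial_: true` returns a callable with some arguments already bound
quarter = toy_instantiate({"_target_": "fractions.Fraction", "_partial_": True, "denominator": 4})
print(quarter(3))  # 3/4
```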



##### **Advanced version**

This version of a `model` config uses YAML anchors and aliases (`&name` defines a reusable value, `*name` refers back to it)
and the `_aux_` section (which `serotiny` ignores) to build models more flexibly.

```yaml
_aux_: 
  _a: &hidden_channels 4
  _b: &kernel_size 3
  _c: &conv_block
    _target_: torch.nn.Sequential
    _args_:
      - _target_: torch.nn.LazyConv2d
        out_channels: *hidden_channels
        kernel_size: *kernel_size
        stride: 1
      - _target_: torch.nn.LeakyReLU
      - _target_: torch.nn.LazyBatchNorm2d

_target_: serotiny.models.BasicModel
x_label: image
y_label: class
network:
  _target_: torch.nn.Sequential
  _args_:
    - *conv_block
    - *conv_block
    - *conv_block
    - _target_: serotiny.networks.layers.Flatten
    - _target_: torch.nn.LazyLinear
      out_features: 1
    - _target_: torch.nn.Sigmoid
    
loss:
  _target_: torch.nn.BCELoss
  
  
# a function used by `serotiny predict` to store the results of feeding data through the model
save_predictions:
  _target_: cytodata_aics.model_utils.save_predictions_classifier
  _partial_: true

# fields to include in the output for each batch
fields_to_log:
  - id
```
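Because the network uses lazy layers, you never write down the size of the flattened features, but it is easy to work out. Assuming PyTorch's default `padding=0`, each `conv_block` (a `LazyConv2d` with `kernel_size: 3`, `stride: 1`) shrinks height and width by 2, while `LeakyReLU` and `LazyBatchNorm2d` keep the shape. The sketch below applies PyTorch's Conv2d output-size formula; the function names and the 64 × 64 input are hypothetical, and it assumes `serotiny.networks.layers.Flatten` flattens everything except the batch dimension.

```python
def conv2d_out(size, kernel_size=3, stride=1, padding=0, dilation=1):
    """Output size of a 2D convolution along one spatial axis (PyTorch's formula)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def flattened_features(height, width, n_blocks=3, hidden_channels=4, kernel_size=3):
    """Features reaching `LazyLinear` after `n_blocks` conv blocks and a flatten."""
    for _ in range(n_blocks):
        height = conv2d_out(height, kernel_size)
        width = conv2d_out(width, kernel_size)
    # every conv outputs `hidden_channels` channels, whatever the input channel count
    return hidden_channels * height * width


# Hypothetical 64 x 64 input: 64 -> 62 -> 60 -> 58 per axis
print(flattened_features(64, 64))  # 13456 (= 4 * 58 * 58)
```

`LazyLinear` infers this number on the first forward pass, which is why the same config works for different image sizes.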

## <a id='train3D'></a>Resize images for 3D training  

#### Updating the `data` config to resize the images

```yaml
_target_: serotiny.datamodules.ManifestDatamodule

path: /home/aicsuser/serotiny_data/cells_edgeVSnoedge.csv

batch_size: 64
num_workers: 6
loaders:
  id:
    _target_: serotiny.io.dataframe.loaders.LoadColumn
    column: CellId
    dtype: int
  class:
    _target_: serotiny.io.dataframe.loaders.LoadColumn
    column: edge_flag
    dtype: float32
  image:
    _target_: serotiny.io.dataframe.loaders.LoadImage
    column: registered_path
    select_channels: ['membrane']
    dtype: float32
    ome_zarr_level: 1 # load a downscaled level of the OME-Zarr pyramid (0 = full resolution)
    transform:
      - _partial_: true
        _target_: cytodata_aics.io_utils.rescale_image
        channels: ['membrane']
      - _target_: monai.transforms.GaussianSharpen # applied to all `select_channels`

split_column: "split"
```
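In the `transform` list above, `_partial_: true` tells hydra to return `functools.partial(rescale_image, channels=['membrane'])` instead of calling `rescale_image` right away, so the loader can call it later on each image. The sketch below mimics that with a hypothetical min-max `toy_rescale`; it is not `cytodata_aics.io_utils.rescale_image`, whose exact behavior is not shown here. That the list is applied in order, one transform after another, is likewise an assumption about how `LoadImage` uses it.

```python
import functools

import numpy as np


def toy_rescale(img, channels):
    """Hypothetical stand-in for `rescale_image`: min-max scale each channel to [0, 1]."""
    img = img.astype(np.float32)
    out = np.empty_like(img)
    for c in range(len(channels)):
        lo, hi = img[c].min(), img[c].max()
        # a constant channel has no range to scale, so map it to zeros
        out[c] = (img[c] - lo) / (hi - lo) if hi > lo else 0.0
    return out


# What `_partial_: true` produces: the function with `channels` already bound
rescale = functools.partial(toy_rescale, channels=["membrane"])


def apply_transforms(img, transforms):
    """Apply each transform in order (assumed behavior of the `transform` list)."""
    for transform in transforms:
        img = transform(img)
    return img


# Toy 1-channel 3D image (C, Z, Y, X) with integer values 10..73
image = np.arange(10, 74, dtype=np.uint16).reshape(1, 4, 4, 4)
result = apply_transforms(image, [rescale])
print(result.min(), result.max())  # 0.0 1.0
```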

#### Changing the working directory

In [None]:
# `serotiny` expects commands to be run from the project root,
# so we change into that directory before running commands
# from within the notebook
import os
os.chdir("/home/aicsuser/cytodata-hackathon-base")

#### Creating a run name based on the current date and time

In [None]:
from datetime import datetime


def now_str():
    """Timestamp string, used to avoid reusing a run name unintentionally."""
    return datetime.now().strftime("%Y%m%d_%H%M%S")

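#### Starting a training run (track it at http://mlflow.cytodata.allencell.org/)

To train on a smaller version of each image, choose which OME-Zarr pyramid level to load, either with `ome_zarr_level` in the `data` config or by adding the override `++data.loaders.image.ome_zarr_level=1 \` to the `serotiny train` command below:
- `0`: full-resolution image
- `1`: image scaled by 0.5 in all 3 dimensions
- `2`: image scaled by 0.25 in ....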
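Lower levels cut memory quickly. The sketch below assumes every level halves each of Z, Y and X, as stated for level 1 (some OME-Zarr pyramids do not downsample Z, so check your data), and uses a hypothetical full-resolution shape, since the real size of the registered images is not given here. It counts only the float32 input batch (`batch_size: 64` from the `data` config), not activations or gradients, so real GPU usage is far higher.

```python
def voxels_at_level(shape_zyx, level):
    """Voxel count after `level` halvings of every axis (assumes 2x downsampling per level)."""
    count = 1
    for size in shape_zyx:
        count *= max(1, size // 2 ** level)
    return count


def batch_megabytes(shape_zyx, level, batch_size=64, channels=1, bytes_per_voxel=4):
    """MiB for one float32 input batch (excludes activations and gradients)."""
    return batch_size * channels * voxels_at_level(shape_zyx, level) * bytes_per_voxel / 2**20


shape = (64, 256, 256)  # hypothetical (Z, Y, X) size of a full-resolution image
for level in range(3):
    print(level, voxels_at_level(shape, level), f"{batch_megabytes(shape, level):.0f} MiB")
```

Under these assumptions each level divides the voxel count, and the input-batch memory, by 8: 1024 MiB at level 0, 128 MiB at level 1 and 16 MiB at level 2.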


In [None]:
run_name = f"some_3d_run_{now_str()}"

!serotiny train \
    model=example_classifier_3d \
    data=example_dataloader_3d \
    mlflow.experiment_name=cytodata_chapter5 \
    mlflow.run_name={run_name} \
    trainer.gpus=[0] \
    trainer.max_epochs=1 

Note: the training above needs more than 16 GB of GPU memory, so it will not fit on the AWS machines (observed GPU memory usage: 45618MiB / 81920MiB).

### Make predictions from the trained model

In [None]:
!serotiny predict \
    model=example_classifier_3d \
    data=example_dataloader_3d \
    mlflow.experiment_name=cytodata_chapter5 \
    mlflow.run_name={run_name} \
    trainer.gpus=[0]

### Retrieving predictions from MLFlow

In [None]:
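import mlflow
from serotiny.ml_ops.mlflow_utils import download_artifact  # context manager that fetches a run's artifact locally

mlflow.set_tracking_uri("http://mlflow.mlflow.svc.cluster.local")

with download_artifact("predictions/model_predictions.csv", experiment_name="cytodata_chapter5", run_name=run_name) as path:
    predictions_3d_df = pd.read_csv(path)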

In [None]:
predictions_3d_df = predictions_3d_df.merge(cells_edgeVSnoedge[['CellId','split']].rename(columns={'CellId':'id'}), on = 'id')
predictions_3d_df
# print(len(predictions_3d_df))
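The merge above attaches each cell's split to its prediction: the manifest's `CellId` column is renamed to `id` so both tables share a key, and the default inner join keeps only cells present in both. The sketch below shows this on two small hypothetical tables; `validate="one_to_one"` is an optional pandas check, not used in the original cell, that raises an error if an `id` appears more than once on either side.

```python
import pandas as pd

# Hypothetical predictions shaped like the downloaded CSV (columns id, y, yhat)
preds = pd.DataFrame({"id": [3, 1, 2], "y": [1.0, 0.0, 1.0], "yhat": [0.9, 0.2, 0.4]})
# Hypothetical manifest subset with the split assigned to each cell
cells = pd.DataFrame({"CellId": [1, 2, 3, 4], "split": ["train", "valid", "test", "train"]})

# Rename CellId -> id so both frames share the join key, then inner-join on it;
# cell 4 has no prediction, so it is dropped
merged = preds.merge(
    cells[["CellId", "split"]].rename(columns={"CellId": "id"}),
    on="id",
    validate="one_to_one",
)
print(merged)
```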

In [None]:
plt.hist(predictions_3d_df.yhat.to_numpy())
plt.show()

### Confusion matrices of train, valid and test splits 

In [None]:
from sklearn.metrics import confusion_matrix,accuracy_score,classification_report

# make confusion matrix for each split
splits = ['train','valid','test']
fig, axes = plt.subplots(nrows=1,ncols=len(splits),figsize=(10, 3), dpi=100)

for i,split in enumerate(splits):
    
    y_true = predictions_3d_df[predictions_3d_df['split']==split]['y'].to_numpy()
    y_pred = predictions_3d_df[predictions_3d_df['split']==split]['yhat'].to_numpy()
    y_pred = np.round(y_pred) #get to crisp binary class labels from posterior probability

    # Compute the confusion matrix (rows = true class, columns = predicted class)
    cm = confusion_matrix(y_true, y_pred)
    score = accuracy_score(y_true, y_pred)  # compute accuracy score
    cm_df = pd.DataFrame(cm)
    sns.heatmap(cm_df, annot=True, fmt='d', ax=axes[i])
    axes[i].set_title(f'Accuracy on {split} is {score:.2f}')
    axes[i].set_xlabel('Predicted')
    axes[i].set_ylabel('True')

plt.show()
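For reference, the matrix and the accuracy can be computed by hand. This NumPy sketch uses the same layout as scikit-learn's `confusion_matrix` for labels 0 and 1: rows are true classes and columns are predicted classes. The `binary_confusion` name and the 0.5 threshold are assumptions for illustration. One difference from the cell above: `np.round` maps a probability of exactly 0.5 to 0 (NumPy rounds halves to the nearest even value), whereas the explicit `>= threshold` below maps it to 1.

```python
import numpy as np


def binary_confusion(y_true, y_prob, threshold=0.5):
    """2x2 confusion matrix (rows = true class, columns = predicted class) and accuracy."""
    y_true = np.asarray(y_true).astype(int)
    y_pred = (np.asarray(y_prob) >= threshold).astype(int)
    cm = np.zeros((2, 2), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)  # count each (true, predicted) pair
    accuracy = np.trace(cm) / cm.sum()  # correct predictions sit on the diagonal
    return cm, accuracy


cm, acc = binary_confusion([0, 0, 1, 1, 1], [0.1, 0.7, 0.8, 0.4, 0.9])
print(cm.tolist())  # [[1, 1], [1, 2]]
print(acc)          # 0.6
```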

# <a id='end'></a>Conclusion
