In [1]:
import json
import os
import sys

import numpy as np
import torch
from matplotlib.colors import ListedColormap

# EmbedSeg is imported from a local checkout that must be on sys.path.
# Set the EMBEDSEG_ROOT environment variable to point elsewhere; the default
# keeps the original author's location so existing setups keep working.
# The membership guard keeps re-runs of this cell from appending duplicates.
EMBEDSEG_ROOT = os.environ.get('EMBEDSEG_ROOT', '/home/james/Documents/VS/EmbedSegScrolls')
if EMBEDSEG_ROOT not in sys.path:
    sys.path.append(EMBEDSEG_ROOT)

from EmbedSeg.train import begin_training
from EmbedSeg.utils.create_dicts import create_dataset_dict, create_model_dict, create_loss_dict, create_configs

### Specify the path to `train`, `val` crops and the type of `center` embedding which we would like to train the network for:

The train/val images, masks and center-images are read from the directory given by `data_dir` and `project_name`.
<a id='center'></a>

In [2]:
# Crops live under <data_dir>/<project_name>/{train,val}/...
project_name = 'Mouse-Organoid-Cells-CBG'
data_dir = 'crops'

# Kind of center embedding to train for; allowed values: 'medoid' or 'centroid'.
center = 'medoid'

print(
    "Project Name chosen as : {}. \nTrain-Val images-masks-center-images will be accessed from : {}".format(
        project_name,
        data_dir,
    )
)

Project Name chosen as : Mouse-Organoid-Cells-CBG. 
Train-Val images-masks-center-images will be accessed from : crops


In [3]:
# Validate the `center` choice explicitly. An `assert` would be stripped when
# Python runs with -O, silently skipping the check.
if center not in {'medoid', 'centroid'}:
    raise ValueError(
        'Please specify center as one of : {"medoid", "centroid"} (got %r)' % (center,)
    )
print("Spatial Embedding Location chosen as : {}".format(center))

Spatial Embedding Location chosen as : medoid


### Obtain properties of the dataset 

Here, we read the `data_properties.json` file prepared previously in the `01-data` notebook.

In [4]:
# Load the dataset properties written by the `01-data` notebook.
# Fail loudly when the file is missing. Otherwise the names below stay undefined,
# and later cells crash with an unhelpful NameError.
properties_path = 'data_properties.json'
if not os.path.isfile(properties_path):
    raise FileNotFoundError(
        '{} not found in {}; run the `01-data` notebook first to create it'.format(
            properties_path, os.getcwd()
        )
    )

with open(properties_path) as json_file:
    data = json.load(json_file)

data_type = data['data_type']
foreground_weight = float(data['foreground_weight'])
n_z, n_y, n_x = int(data['n_z']), int(data['n_y']), int(data['n_x'])
pixel_size_z_microns = float(data['pixel_size_z_microns'])
pixel_size_x_microns = float(data['pixel_size_x_microns'])

### Specify training dataset-related parameters

Some hints: 
* The `train_size` attribute indicates the number of image-mask paired examples which the network would see in one complete epoch. Ideally this should be the number of `train` image crops. 

In the cell after this one, a `train_dataset_dict` dictionary is generated from the parameters specified here!

In [5]:
# One training epoch visits every entry in <data_dir>/<project_name>/train/images.
train_size = len(
    os.listdir(
        os.path.join(data_dir, project_name, 'train', 'images')
    )
)
train_batch_size = 32

### Create the `train_dataset_dict` dictionary  

In [6]:
# Training-split dataloader settings (3-D).
train_dataset_dict = create_dataset_dict(
    data_dir=data_dir,
    project_name=project_name,
    center=center,
    size=train_size,
    batch_size=train_batch_size,
    type='train',
    name='3d',
)

`train_dataset_dict` dictionary successfully created                 with: 
 -- train images accessed from crops/Mouse-Organoid-Cells-CBG/train/images, 
 -- number of images per epoch equal to 567, 
 -- batch size set at 32, 


### Specify validation dataset-related parameters

Some hints:
* The `val_size` attribute indicates the number of image-mask paired examples which the network would see in one complete epoch. Here, it is recommended to set `val_size` equal to the total number of validation image crops.

In the cell after this one, a `val_dataset_dict` dictionary is generated from the parameters specified here!

In [7]:
# Validate on every entry in <data_dir>/<project_name>/val/images.
val_size = len(
    os.listdir(
        os.path.join(data_dir, project_name, 'val', 'images')
    )
)
val_batch_size = 16

### Create the `val_dataset_dict` dictionary

In [8]:
# Validation-split dataloader settings (3-D).
val_dataset_dict = create_dataset_dict(
    data_dir=data_dir,
    project_name=project_name,
    center=center,
    size=val_size,
    batch_size=val_batch_size,
    type='val',
    name='3d',
)

`val_dataset_dict` dictionary successfully created                 with: 
 -- val images accessed from crops/Mouse-Organoid-Cells-CBG/val/images, 
 -- number of images per epoch equal to 113, 
 -- batch size set at 16, 


### Specify model-related parameters

Some hints:
* Set the `input_channels` attribute equal to the number of channels in the input images. 
* Set `num_classes = [6, 1]` for `3d` training and `num_classes = [4, 1]` for `2d` training
<br>(here, 6 implies the offsets and bandwidths in x, y and z dimensions and 1 implies the `seediness` value per pixel)

In the cell after this one, a `model_dict` dictionary is generated from the parameters specified here!

In [9]:
# 1 input channel; outputs = [6 (offsets + bandwidths in x, y, z), 1 (seediness)].
input_channels, num_classes = 1, [6, 1]

### Create the `model_dict` dictionary

In [10]:
# Network configuration; name='3d' selects the 3-D variant of the model.
model_dict = create_model_dict(
    input_channels=input_channels,
    num_classes=num_classes,
    name='3d',
)

`model_dict` dictionary successfully created                 with: 
 -- num of classes equal to 1, 
 -- input channels                 equal to [6, 1], 
 -- name equal to branched_erfnet_3d


### Create the `loss_dict` dictionary

In [11]:
loss_dict = create_loss_dict(
    n_sigma=3,  # NOTE(review): presumably one bandwidth per spatial axis (x, y, z) for 3d — confirm
    foreground_weight=foreground_weight,  # loaded from data_properties.json
)

`loss_dict` dictionary successfully created                 with: 
 -- foreground weight equal to 34.143, 
 -- w_inst                 equal to 1, 
 -- w_var                 equal to 10, 
 -- w_seed equal to 1


### Specify additional parameters 

Some hints:
* The `n_epochs` attribute determines how long the training should proceed. In general, for reasonable results you should train for more than 50 epochs.
* The `save_dir` attribute identifies the location where the checkpoints and loss curve details are saved. 
* To **resume training** from a previous checkpoint, set the `resume_path` attribute to that checkpoint. For example, set `resume_path = './experiment/Mouse-Organoid-Cells-CBG-demo/checkpoint.pth'` to resume training from the last checkpoint.


In [12]:
# NOTE(review): 2 epochs is only enough for a smoke test; the hint above
# recommends training for more than 50 epochs for reasonable results.
n_epochs = 2

# Checkpoints and loss curves are written under experiment/<project_name>-demo.
save_dir = os.path.join('experiment', '-'.join((project_name, 'demo')))

# Set to a checkpoint path (e.g. os.path.join(save_dir, 'checkpoint.pth')) to resume.
resume_path = None

In the cell after this one, a `configs` dictionary is generated from the parameters specified here!
<a id='resume'></a>

### Create the  `configs` dictionary 

In [13]:
# Ratio of z pixel size to x pixel size, both read from data_properties.json.
anisotropy_factor = pixel_size_z_microns / pixel_size_x_microns

configs = create_configs(
    n_epochs=n_epochs,
    resume_path=resume_path,
    save_dir=save_dir,
    n_z=n_z,
    n_y=n_y,
    n_x=n_x,
    anisotropy_factor=anisotropy_factor,
)

`configs` dictionary successfully created with: 
 -- n_epochs equal to 200, 
 -- save_dir equal to experiment/Mouse-Organoid-Cells-CBG-demo, 
 -- n_z equal to 72, 
 -- n_y equal to 408, 
 -- n_x equal to 408, 


In [14]:
# Quick environment check before training (torch is already imported in the first cell).
print(torch.cuda.is_available())  # Checks if CUDA is available on your system
print(torch.version.cuda)         # Shows the CUDA version PyTorch was built with


True
12.1


### Begin training!

Executing the next cell would begin the training. 

In [15]:
# Launch training with the dataset, model, loss and run configurations built above.
begin_training(
    train_dataset_dict,
    val_dataset_dict,
    model_dict,
    loss_dict,
    configs,
)

3-D `train` dataloader created! Accessing data from crops/Mouse-Organoid-Cells-CBG/train/
Number of images in `train` directory is 567
Number of instances in `train` directory is 567
Number of center images in `train` directory is 567
*************************
3-D `val` dataloader created! Accessing data from crops/Mouse-Organoid-Cells-CBG/val/
Number of images in `val` directory is 113
Number of instances in `val` directory is 113
Number of center images in `val` directory is 113
*************************
Creating Branched Erfnet 3D with [6, 1] outputs
initialize last layer with size:  torch.Size([16, 6, 2, 2, 2])
Created spatial emb loss function with:                     n_sigma: 3, foreground_weight: 34.143469240821844
*************************
Created logger with keys:  ('train', 'val', 'iou')
Starting epoch 0
learning rate: 0.0005


  0%|          | 0/17 [00:00<?, ?it/s]

xyzm_s shape is torch.Size([3, 24, 152, 152])
prediction shape is torch.Size([32, 7, 24, 152, 152])


  0%|          | 0/17 [00:05<?, ?it/s]


KeyboardInterrupt: 

<div class="alert alert-block alert-warning"> 
  Common causes for errors during training, may include : <br>
    1. Not having <b>center images</b> for  <b>both</b> train and val directories  <br>
    2. <b>Mismatch</b> between type of center-images saved in <b>01-data.ipynb</b> and the type of center chosen in this notebook (see the <b><a href="#center"> center</a></b> parameter in the second code cell in this notebook)   <br>
    3. In case of resuming training from a previous checkpoint, please ensure that the model weights are read from the correct directory, using the <b><a href="#resume"> resume_path</a></b> parameter. Additionally, please ensure that the <b>save_dir</b> parameter for saving the model weights points to a relevant directory. 
</div>