In [1]:
import sys
sys.path.append('/home/james/Documents/VS/EmbedSegScrolls')
import numpy as np
import os
import torch
from EmbedSeg.train import begin_training
from EmbedSeg.utils.create_dicts import create_dataset_dict, create_model_dict, create_loss_dict, create_configs
from matplotlib.colors import ListedColormap
import json

### Specify the path to `train`, `val` crops and the type of `center` embedding which we would like to train the network for:

The train-val images, masks and center-images will be accessed from the path specified by `data_dir` and `project_name`.
<a id='center'></a>

In [2]:
# Root folder holding the crops, and the project sub-folder inside it.
data_dir, project_name = 'crops', 'Vesuvius'
# Center embedding type to train for; alternatives are 'centroid' and 'medoid'.
center = 'approximate-medoid'

message = "Project Name chosen as : {}. \nTrain-Val images-masks-center-images will be accessed from : {}"
print(message.format(project_name, data_dir))

Project Name chosen as : Vesuvius. 
Train-Val images-masks-center-images will be accessed from : crops


In [3]:
# Reject an unsupported center type before any dataset dictionaries are built.
# An explicit check is used instead of `assert` so it is not skipped under
# `python -O`; AssertionError is kept so the raised type is unchanged.
valid_centers = ('medoid', 'centroid', 'approximate-medoid')
if center not in valid_centers:
    raise AssertionError(
        'Please specify center as one of : {}, got {!r}'.format(
            ', '.join(valid_centers), center))
print("Spatial Embedding Location chosen as : {}".format(center))

Spatial Embedding Location chosen as : approximate-medoid


### Obtain properties of the dataset 

Here, we read the `data_properties.json` file prepared previously in the `01-data` notebook.

In [4]:
# Load the dataset properties prepared in the `01-data` notebook.
# The loss and config cells below need these values, so fail here with a clear
# message instead of silently skipping and hitting a NameError several cells later.
properties_path = 'data_properties.json'
if not os.path.isfile(properties_path):
    raise FileNotFoundError(
        "'{}' not found in {}; run the 01-data notebook first to generate it.".format(
            properties_path, os.getcwd()))

with open(properties_path) as json_file:
    data = json.load(json_file)

data_type = data['data_type']
foreground_weight = float(data['foreground_weight'])
n_z, n_y, n_x = int(data['n_z']), int(data['n_y']), int(data['n_x'])
pixel_size_z_microns = float(data['pixel_size_z_microns'])
pixel_size_x_microns = float(data['pixel_size_x_microns'])

### Specify training dataset-related parameters

Some hints: 
* The `train_size` attribute indicates the number of image-mask paired examples which the network would see in one complete epoch. Ideally this should be the number of `train` image crops. 

In the cell after this one, a `train_dataset_dict` dictionary is generated from the parameters specified here!

In [5]:
# Examples per epoch = number of entries in the train images folder; one crop per batch.
train_images_dir = os.path.join(data_dir, project_name, 'train', 'images')
train_size = len(os.listdir(train_images_dir))
train_batch_size = 1

### Create the `train_dataset_dict` dictionary  

In [6]:
# Bundle the training-set settings into the dictionary passed to begin_training.
train_dataset_dict = create_dataset_dict(
    name='3d',
    type='train',
    data_dir=data_dir,
    project_name=project_name,
    center=center,
    size=train_size,
    batch_size=train_batch_size,
)

`train_dataset_dict` dictionary successfully created                 with: 
 -- train images accessed from crops/Vesuvius/train/images, 
 -- number of images per epoch equal to 23, 
 -- batch size set at 1, 


### Specify validation dataset-related parameters

Some hints:
* The `val_size` attribute indicates the number of image-mask paired examples which the network would see in one complete validation pass. Here, it is recommended to set `val_size` equal to the total number of validation image crops.

In the cell after this one, a `val_dataset_dict` dictionary is generated from the parameters specified here!

In [7]:
# Examples per validation pass = number of entries in the val images folder; one crop per batch.
val_images_dir = os.path.join(data_dir, project_name, 'val', 'images')
val_size = len(os.listdir(val_images_dir))
val_batch_size = 1

### Create the `val_dataset_dict` dictionary

In [8]:
# Bundle the validation-set settings into the dictionary passed to begin_training.
val_dataset_dict = create_dataset_dict(
    name='3d',
    type='val',
    data_dir=data_dir,
    project_name=project_name,
    center=center,
    size=val_size,
    batch_size=val_batch_size,
)

`val_dataset_dict` dictionary successfully created                 with: 
 -- val images accessed from crops/Vesuvius/val/images, 
 -- number of images per epoch equal to 9, 
 -- batch size set at 1, 


### Specify model-related parameters

Some hints:
* Set the `input_channels` attribute equal to the number of channels in the input images. 
* Set the `num_classes = [6, 1]` for `3d` training and `num_classes = [4, 1]` for `2d` training
<br>(here, 6 implies the offsets and bandwidths in x, y and z dimensions and 1 implies the `seediness` value per pixel)

In the cell after this one, a `model_dict` dictionary is generated from the parameters specified here!

In [9]:
# Single-channel input images; for 3-D training the outputs are
# 6 (offsets + bandwidths in x, y, z) and 1 (per-pixel seediness).
input_channels, num_classes = 1, [6, 1]

### Create the `model_dict` dictionary

In [10]:
# Describe the 3-D model from the channel settings chosen above.
model_dict = create_model_dict(
    name='3d',
    num_classes=num_classes,
    input_channels=input_channels,
)

`model_dict` dictionary successfully created                 with: 
 -- num of classes equal to 1, 
 -- input channels                 equal to [6, 1], 
 -- name equal to branched_erfnet_3d


### Create the `loss_dict` dictionary

In [11]:
# n_sigma = 3 presumably matches the three bandwidth channels (x, y, z) included
# in num_classes[0] = 6; foreground_weight comes from data_properties.json.
loss_dict = create_loss_dict(
    foreground_weight=foreground_weight,
    n_sigma=3,
)

`loss_dict` dictionary successfully created                 with: 
 -- foreground weight equal to 3.708, 
 -- w_inst                 equal to 1, 
 -- w_var                 equal to 10, 
 -- w_seed equal to 1


### Specify additional parameters 

Some hints:
* The `n_epochs` attribute determines how long the training should proceed. In general, for reasonable results you should train for at least 50 epochs, and preferably longer.
* The `save_dir` attribute identifies the location where the checkpoints and loss curve details are saved. 
* If one wishes to **resume training** from a previous checkpoint, they could point the `resume_path` attribute to it. For example, one could set `resume_path = './experiment/Vesuvius-demo/checkpoint.pth'` to resume training from the last checkpoint.


In [16]:
n_epochs = 50
# Checkpoints and loss-curve details are written to experiment/<project_name>-demo.
save_dir = os.path.join('experiment', '{}-demo'.format(project_name))
# Point this at a saved checkpoint to resume training instead of starting fresh.
resume_path = None

In the cell after this one, a `configs` dictionary is generated from the parameters specified here!
<a id='resume'></a>

### Create the  `configs` dictionary 

In [17]:
# Ratio of the z pixel size to the x pixel size (both in microns, from data_properties.json).
anisotropy_factor = pixel_size_z_microns / pixel_size_x_microns
configs = create_configs(
    save_dir=save_dir,
    resume_path=resume_path,
    n_epochs=n_epochs,
    n_x=n_x,
    n_y=n_y,
    n_z=n_z,
    anisotropy_factor=anisotropy_factor,
)

`configs` dictionary successfully created with: 
 -- n_epochs equal to 50, 
 -- save_dir equal to experiment/Vesuvius-demo, 
 -- n_z equal to 256, 
 -- n_y equal to 256, 
 -- n_x equal to 256, 


In [18]:
# Sanity-check the GPU setup before launching training.
# `torch` is already imported in the first cell, so it is not re-imported here.
print(torch.cuda.is_available())  # True if PyTorch can use a CUDA device
print(torch.version.cuda)         # CUDA version PyTorch was built with


True
12.1


### Begin training!

Executing the next cell would begin the training. 

In [19]:
# Launch training; per the notes above, checkpoints and loss curves go to `save_dir`.
training_inputs = (train_dataset_dict, val_dataset_dict, model_dict, loss_dict, configs)
begin_training(*training_inputs)

3-D `train` dataloader created! Accessing data from crops/Vesuvius/train/
Number of images in `train` directory is 23
Number of instances in `train` directory is 23
Number of center images in `train` directory is 23
*************************
3-D `val` dataloader created! Accessing data from crops/Vesuvius/val/
Number of images in `val` directory is 9
Number of instances in `val` directory is 9
Number of center images in `val` directory is 9
*************************
Creating Branched Erfnet 3D with [6, 1] outputs
initialize last layer with size:  torch.Size([16, 6, 2, 2, 2])
Created spatial emb loss function with:                     n_sigma: 3, foreground_weight: 3.7082556884380753
*************************
Created logger with keys:  ('train', 'val', 'iou')
Starting epoch 0
learning rate: 0.0005


100%|██████████| 23/23 [00:09<00:00,  2.53it/s]
100%|██████████| 9/9 [00:02<00:00,  3.72it/s]


===> train loss: 2.36
===> val loss: 2.50, val iou: 0.00
=> saving checkpoint
Starting epoch 1
learning rate: 0.0004977494364660346


100%|██████████| 23/23 [00:08<00:00,  2.57it/s]
100%|██████████| 9/9 [00:02<00:00,  3.84it/s]


===> train loss: 2.23
===> val loss: 2.30, val iou: 0.00
=> saving checkpoint
Starting epoch 2
learning rate: 0.0004954977417064171


100%|██████████| 23/23 [00:09<00:00,  2.54it/s]
100%|██████████| 9/9 [00:02<00:00,  3.87it/s]


===> train loss: 2.18
===> val loss: 2.18, val iou: 0.00
=> saving checkpoint
Starting epoch 3
learning rate: 0.0004932449094349202


100%|██████████| 23/23 [00:09<00:00,  2.49it/s]
100%|██████████| 9/9 [00:02<00:00,  3.90it/s]


===> train loss: 2.12
===> val loss: 2.11, val iou: 0.00
=> saving checkpoint
Starting epoch 4
learning rate: 0.0004909909332982877


100%|██████████| 23/23 [00:09<00:00,  2.55it/s]
100%|██████████| 9/9 [00:02<00:00,  3.84it/s]


===> train loss: 2.09
===> val loss: 2.08, val iou: 0.00
=> saving checkpoint
Starting epoch 5
learning rate: 0.0004887358068751748


100%|██████████| 23/23 [00:08<00:00,  2.56it/s]
100%|██████████| 9/9 [00:02<00:00,  3.85it/s]


===> train loss: 2.06
===> val loss: 2.05, val iou: 0.00
=> saving checkpoint
Starting epoch 6
learning rate: 0.0004864795236750653


100%|██████████| 23/23 [00:09<00:00,  2.55it/s]
100%|██████████| 9/9 [00:02<00:00,  3.75it/s]


===> train loss: 2.05
===> val loss: 2.04, val iou: 0.00
=> saving checkpoint
Starting epoch 7
learning rate: 0.00048422207713716544


100%|██████████| 23/23 [00:08<00:00,  2.56it/s]
100%|██████████| 9/9 [00:02<00:00,  3.79it/s]


===> train loss: 2.03
===> val loss: 2.04, val iou: 0.00
=> saving checkpoint
Starting epoch 8
learning rate: 0.00048196346062927547


100%|██████████| 23/23 [00:09<00:00,  2.51it/s]
100%|██████████| 9/9 [00:02<00:00,  3.89it/s]


===> train loss: 1.89
===> val loss: 2.75, val iou: 0.02
=> saving checkpoint
Starting epoch 9
learning rate: 0.00047970366744663594


100%|██████████| 23/23 [00:09<00:00,  2.49it/s]
100%|██████████| 9/9 [00:02<00:00,  3.81it/s]


===> train loss: 1.53
===> val loss: 1.91, val iou: 0.10
=> saving checkpoint
Starting epoch 10
learning rate: 0.00047744269081074987


100%|██████████| 23/23 [00:09<00:00,  2.56it/s]
100%|██████████| 9/9 [00:02<00:00,  3.91it/s]


===> train loss: 1.46
===> val loss: 1.97, val iou: 0.10
=> saving checkpoint
Starting epoch 11
learning rate: 0.0004751805238681794


100%|██████████| 23/23 [00:08<00:00,  2.60it/s]
100%|██████████| 9/9 [00:02<00:00,  3.71it/s]


===> train loss: 1.47
===> val loss: 1.80, val iou: 0.10
=> saving checkpoint
Starting epoch 12
learning rate: 0.000472917159689316


100%|██████████| 23/23 [00:08<00:00,  2.60it/s]
100%|██████████| 9/9 [00:02<00:00,  3.83it/s]


===> train loss: 1.45
===> val loss: 1.70, val iou: 0.11
=> saving checkpoint
Starting epoch 13
learning rate: 0.00047065259126712457


100%|██████████| 23/23 [00:08<00:00,  2.56it/s]
100%|██████████| 9/9 [00:02<00:00,  4.02it/s]


===> train loss: 1.43
===> val loss: 1.79, val iou: 0.10
=> saving checkpoint
Starting epoch 14
learning rate: 0.00046838681151585874


100%|██████████| 23/23 [00:08<00:00,  2.57it/s]
100%|██████████| 9/9 [00:02<00:00,  3.90it/s]


===> train loss: 1.41
===> val loss: 1.85, val iou: 0.09
=> saving checkpoint
Starting epoch 15
learning rate: 0.0004661198132697498


100%|██████████| 23/23 [00:09<00:00,  2.51it/s]
100%|██████████| 9/9 [00:02<00:00,  3.81it/s]


===> train loss: 1.49
===> val loss: 1.79, val iou: 0.11
=> saving checkpoint
Starting epoch 16
learning rate: 0.0004638515892816641


100%|██████████| 23/23 [00:09<00:00,  2.55it/s]
100%|██████████| 9/9 [00:02<00:00,  3.87it/s]


===> train loss: 1.38
===> val loss: 1.81, val iou: 0.10
=> saving checkpoint
Starting epoch 17
learning rate: 0.00046158213222173284


100%|██████████| 23/23 [00:08<00:00,  2.57it/s]
100%|██████████| 9/9 [00:02<00:00,  3.94it/s]


===> train loss: 1.39
===> val loss: 1.79, val iou: 0.10
=> saving checkpoint
Starting epoch 18
learning rate: 0.0004593114346759497


100%|██████████| 23/23 [00:09<00:00,  2.53it/s]
100%|██████████| 9/9 [00:02<00:00,  3.79it/s]


===> train loss: 1.39
===> val loss: 1.82, val iou: 0.11
=> saving checkpoint
Starting epoch 19
learning rate: 0.00045703948914473726


100%|██████████| 23/23 [00:08<00:00,  2.56it/s]
100%|██████████| 9/9 [00:02<00:00,  3.89it/s]


===> train loss: 1.34
===> val loss: 1.80, val iou: 0.10
=> saving checkpoint
Starting epoch 20
learning rate: 0.00045476628804148113


100%|██████████| 23/23 [00:09<00:00,  2.52it/s]
100%|██████████| 9/9 [00:02<00:00,  3.87it/s]


===> train loss: 1.37
===> val loss: 1.83, val iou: 0.11
=> saving checkpoint
Starting epoch 21
learning rate: 0.00045249182369103055


100%|██████████| 23/23 [00:09<00:00,  2.54it/s]
100%|██████████| 9/9 [00:02<00:00,  3.70it/s]


===> train loss: 1.32
===> val loss: 1.67, val iou: 0.12
=> saving checkpoint
Starting epoch 22
learning rate: 0.00045021608832816447


100%|██████████| 23/23 [00:09<00:00,  2.52it/s]
100%|██████████| 9/9 [00:02<00:00,  3.80it/s]


===> train loss: 1.34
===> val loss: 1.95, val iou: 0.10
=> saving checkpoint
Starting epoch 23
learning rate: 0.0004479390740960227


100%|██████████| 23/23 [00:09<00:00,  2.54it/s]
100%|██████████| 9/9 [00:02<00:00,  3.89it/s]


===> train loss: 1.31
===> val loss: 1.71, val iou: 0.12
=> saving checkpoint
Starting epoch 24
learning rate: 0.00044566077304449995


100%|██████████| 23/23 [00:08<00:00,  2.56it/s]
100%|██████████| 9/9 [00:02<00:00,  3.95it/s]


===> train loss: 1.28
===> val loss: 1.75, val iou: 0.13
=> saving checkpoint
Starting epoch 25
learning rate: 0.00044338117712860363


100%|██████████| 23/23 [00:08<00:00,  2.56it/s]
100%|██████████| 9/9 [00:02<00:00,  3.90it/s]


===> train loss: 1.25
===> val loss: 1.66, val iou: 0.13
=> saving checkpoint
Starting epoch 26
learning rate: 0.00044110027820677195


100%|██████████| 23/23 [00:08<00:00,  2.58it/s]
100%|██████████| 9/9 [00:02<00:00,  3.85it/s]


===> train loss: 1.31
===> val loss: 1.66, val iou: 0.13
=> saving checkpoint
Starting epoch 27
learning rate: 0.000438818068039153


100%|██████████| 23/23 [00:09<00:00,  2.55it/s]
100%|██████████| 9/9 [00:02<00:00,  3.81it/s]


===> train loss: 1.26
===> val loss: 1.61, val iou: 0.15
=> saving checkpoint
Starting epoch 28
learning rate: 0.000436534538285843


100%|██████████| 23/23 [00:09<00:00,  2.52it/s]
100%|██████████| 9/9 [00:02<00:00,  3.80it/s]


===> train loss: 1.17
===> val loss: 1.66, val iou: 0.15
=> saving checkpoint
Starting epoch 29
learning rate: 0.00043424968050508256


100%|██████████| 23/23 [00:09<00:00,  2.51it/s]
100%|██████████| 9/9 [00:02<00:00,  3.81it/s]


===> train loss: 1.13
===> val loss: 1.67, val iou: 0.14
=> saving checkpoint
Starting epoch 30
learning rate: 0.00043196348615140955


100%|██████████| 23/23 [00:09<00:00,  2.55it/s]
100%|██████████| 9/9 [00:02<00:00,  3.85it/s]


===> train loss: 1.19
===> val loss: 1.59, val iou: 0.14
=> saving checkpoint
Starting epoch 31
learning rate: 0.0004296759465737673


100%|██████████| 23/23 [00:09<00:00,  2.51it/s]
100%|██████████| 9/9 [00:02<00:00,  3.83it/s]


===> train loss: 1.13
===> val loss: 1.62, val iou: 0.14
=> saving checkpoint
Starting epoch 32
learning rate: 0.00042738705301356716


100%|██████████| 23/23 [00:08<00:00,  2.57it/s]
100%|██████████| 9/9 [00:02<00:00,  3.77it/s]


===> train loss: 1.11
===> val loss: 1.86, val iou: 0.14
=> saving checkpoint
Starting epoch 33
learning rate: 0.0004250967966027037


100%|██████████| 23/23 [00:08<00:00,  2.56it/s]
100%|██████████| 9/9 [00:02<00:00,  4.02it/s]


===> train loss: 1.11
===> val loss: 1.54, val iou: 0.18
=> saving checkpoint
Starting epoch 34
learning rate: 0.00042280516836152096


100%|██████████| 23/23 [00:09<00:00,  2.54it/s]
100%|██████████| 9/9 [00:02<00:00,  3.92it/s]


===> train loss: 1.08
===> val loss: 1.57, val iou: 0.16
=> saving checkpoint
Starting epoch 35
learning rate: 0.00042051215919672877


100%|██████████| 23/23 [00:09<00:00,  2.49it/s]
100%|██████████| 9/9 [00:02<00:00,  3.85it/s]


===> train loss: 1.02
===> val loss: 1.55, val iou: 0.15
=> saving checkpoint
Starting epoch 36
learning rate: 0.00041821775989926696


100%|██████████| 23/23 [00:09<00:00,  2.53it/s]
100%|██████████| 9/9 [00:02<00:00,  3.86it/s]


===> train loss: 1.12
===> val loss: 1.56, val iou: 0.16
=> saving checkpoint
Starting epoch 37
learning rate: 0.00041592196114211634


100%|██████████| 23/23 [00:08<00:00,  2.58it/s]
100%|██████████| 9/9 [00:02<00:00,  3.99it/s]


===> train loss: 1.04
===> val loss: 1.67, val iou: 0.16
=> saving checkpoint
Starting epoch 38
learning rate: 0.0004136247534780547


100%|██████████| 23/23 [00:09<00:00,  2.51it/s]
100%|██████████| 9/9 [00:02<00:00,  3.83it/s]


===> train loss: 1.03
===> val loss: 1.57, val iou: 0.18
=> saving checkpoint
Starting epoch 39
learning rate: 0.00041132612733735566


100%|██████████| 23/23 [00:08<00:00,  2.58it/s]
100%|██████████| 9/9 [00:02<00:00,  3.97it/s]


===> train loss: 0.98
===> val loss: 1.64, val iou: 0.18
=> saving checkpoint
Starting epoch 40
learning rate: 0.00040902607302542923


100%|██████████| 23/23 [00:08<00:00,  2.58it/s]
100%|██████████| 9/9 [00:02<00:00,  3.80it/s]


===> train loss: 1.02
===> val loss: 1.54, val iou: 0.17
=> saving checkpoint
Starting epoch 41
learning rate: 0.00040672458072040163


100%|██████████| 23/23 [00:09<00:00,  2.55it/s]
100%|██████████| 9/9 [00:02<00:00,  3.85it/s]


===> train loss: 1.02
===> val loss: 1.62, val iou: 0.18
=> saving checkpoint
Starting epoch 42
learning rate: 0.00040442164047063304


100%|██████████| 23/23 [00:08<00:00,  2.57it/s]
100%|██████████| 9/9 [00:02<00:00,  3.77it/s]


===> train loss: 1.03
===> val loss: 1.60, val iou: 0.16
=> saving checkpoint
Starting epoch 43
learning rate: 0.0004021172421921706


100%|██████████| 23/23 [00:09<00:00,  2.52it/s]
100%|██████████| 9/9 [00:02<00:00,  3.93it/s]


===> train loss: 0.94
===> val loss: 1.61, val iou: 0.18
=> saving checkpoint
Starting epoch 44
learning rate: 0.0003998113756661346


100%|██████████| 23/23 [00:08<00:00,  2.57it/s]
100%|██████████| 9/9 [00:02<00:00,  3.95it/s]


===> train loss: 0.97
===> val loss: 1.63, val iou: 0.17
=> saving checkpoint
Starting epoch 45
learning rate: 0.000397504030536037


100%|██████████| 23/23 [00:09<00:00,  2.52it/s]
100%|██████████| 9/9 [00:02<00:00,  4.03it/s]


===> train loss: 1.01
===> val loss: 1.63, val iou: 0.21
=> saving checkpoint
Starting epoch 46
learning rate: 0.0003951951963050278


100%|██████████| 23/23 [00:09<00:00,  2.54it/s]
100%|██████████| 9/9 [00:02<00:00,  3.73it/s]


===> train loss: 0.99
===> val loss: 1.60, val iou: 0.17
=> saving checkpoint
Starting epoch 47
learning rate: 0.00039288486233306853


100%|██████████| 23/23 [00:09<00:00,  2.52it/s]
100%|██████████| 9/9 [00:02<00:00,  3.91it/s]


===> train loss: 0.88
===> val loss: 1.59, val iou: 0.21
=> saving checkpoint
Starting epoch 48
learning rate: 0.0003905730178340304


100%|██████████| 23/23 [00:09<00:00,  2.55it/s]
100%|██████████| 9/9 [00:02<00:00,  4.00it/s]


===> train loss: 0.89
===> val loss: 1.60, val iou: 0.19
=> saving checkpoint
Starting epoch 49
learning rate: 0.0003882596518727134


100%|██████████| 23/23 [00:09<00:00,  2.54it/s]
100%|██████████| 9/9 [00:02<00:00,  3.83it/s]

===> train loss: 0.93
===> val loss: 1.73, val iou: 0.17
=> saving checkpoint





<div class="alert alert-block alert-warning"> 
  Common causes for errors during training, may include : <br>
    1. Not having <b>center images</b> for  <b>both</b> train and val directories  <br>
    2. <b>Mismatch</b> between type of center-images saved in <b>01-data.ipynb</b> and the type of center chosen in this notebook (see the <b><a href="#center"> center</a></b> parameter in the second code cell in this notebook)   <br>
    3. In case of resuming training from a previous checkpoint, please ensure that the model weights are read from the correct directory, using the <b><a href="#resume"> resume_path</a></b> parameter. Additionally, please ensure that the <b>save_dir</b> parameter for saving the model weights points to a relevant directory. 
</div>