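# embryo_binary_segmentation: train_mouse_embryo

Trains the smaller UNet from `embryo_binary_segmentation._unet_smaller` for binary segmentation of 3D mouse-embryo images, using the settings in `embryo_binary_segmentation._config` (optionally starting from previously saved weights).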

In [1]:
import gc

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchio as tio

In [2]:
import embryo_binary_segmentation._unet_smaller as unet_smaller
from embryo_binary_segmentation._data_load import upload_data
from embryo_binary_segmentation._losses import dice_loss, focal_loss
from embryo_binary_segmentation._config import DATA_PARAMS, FINE_TUNING, TRAINING_PARAMS
from embryo_binary_segmentation._train_functions import train



# Load parameters

In [3]:
# Positional unpacking: relies on TRAINING_PARAMS holding exactly these seven
# entries in this insertion order.
(loss, learning_rate, batch_size, epochs,
 save_model_path, fine_tuning, save_each) = TRAINING_PARAMS.values()

In [4]:
# Positional unpacking: relies on DATA_PARAMS keeping this insertion order.
# NOTE(review): `target_size` is not used anywhere in this notebook.
(data_path, binarize, target_size,
 patch_size, augmentations) = DATA_PARAMS.values()

# The splits live in "Train/" and "Val/" sub-folders of `data_path`; plain
# concatenation assumes `data_path` already ends with a path separator.
train_folder, val_folder = (f"{data_path}{split}/" for split in ("Train", "Val"))

In [5]:
# Checkpoint to start from and, presumably, the number of epochs it was already
# trained for (the training log continues from epoch old_steps + 1) — confirm.
(upload_model_path, old_steps) = FINE_TUNING.values()

In [6]:
# Resolve the configured loss name to a callable. Fail fast on an unknown name:
# otherwise `loss_fn` would be undefined (or silently stale from a previous run
# in the same kernel) when training starts.
if loss == 'bce':
    # NOTE(review): nn.BCELoss expects probabilities in [0, 1]; confirm the UNet
    # output is passed through a sigmoid, otherwise use nn.BCEWithLogitsLoss.
    loss_fn = nn.BCELoss()
elif loss == 'dice':
    loss_fn = dice_loss
elif loss == 'focal':
    loss_fn = focal_loss
else:
    raise ValueError(f"Unknown loss {loss!r}; expected one of 'bce', 'dice', 'focal'")

In [7]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device, '\n')

cuda 



# Load Data

In [8]:
# NOTE(review): the 'train' argument presumably selects split-specific behaviour
# inside upload_data — confirm against its implementation.
train_dataset = upload_data(
    train_folder,
    'train',
    binarize,
    patch_size,
)
print("Train data is loaded \n")

/home/polinasoloveva/Data/Train/e7_woon/SEG_seeds_from_prev_cropped_binary
/home/polinasoloveva/Data/Train/e7_woon/FUSE_raw_cropped
/home/polinasoloveva/Data/Train/e12_JLM/SEG_seeds_from_previous_binary
/home/polinasoloveva/Data/Train/e12_JLM/FUSE
Train data is loaded 



In [9]:
# Sanity check: tensor shape of the first training subject's image.
train_dataset[0]['image'][tio.DATA].size()

torch.Size([1, 32, 256, 256])

In [10]:
# NOTE(review): the 'val' argument presumably selects split-specific behaviour
# inside upload_data — confirm against its implementation.
val_dataset = upload_data(
    val_folder,
    'val',
    binarize,
    patch_size,
)
print("Validation data is loaded \n")

/home/polinasoloveva/Data/Val/e7_woon/SEG_seeds_from_prev_cropped_binary
/home/polinasoloveva/Data/Val/e7_woon/FUSE_raw_cropped
/home/polinasoloveva/Data/Val/e12_JLM/SEG_seeds_from_previous_binary
/home/polinasoloveva/Data/Val/e12_JLM/FUSE
Validation data is loaded 



In [11]:
# Build the DataLoaders. With `augmentations` enabled, the training subjects are
# wrapped in a tio.SubjectsDataset so a random transform is applied whenever a
# subject is fetched; validation data is never augmented.
if augmentations:
    # Despite the name, this mixes spatial transforms (elastic, affine, flip)
    # with intensity transforms (ghosting, bias field, noise).
    # tio.OneOf applies exactly one of them per sample: torchio normalizes the
    # weights below to sum to 1, and p=0.8 is the chance any transform is applied.
    # NOTE(review): torchio orders spatial axes as (W, H, D). With patches of
    # shape (1, 32, 256, 256), max_displacement=(16, 16, 2) puts 16 voxels on the
    # 32-voxel axis and only 2 on a 256-voxel axis — confirm this is intended
    # (it may be what triggers the parse_free_form_transform warning).
    spatial = tio.OneOf({
        tio.RandomElasticDeformation(num_control_points=(6, 6, 8), locked_borders=2, max_displacement=(16, 16, 2)): 0.1,
        tio.RandomAffine(scales=(1, 1.05), degrees=5): 0.2,
        tio.RandomFlip(axes=('LR',)): 0.1,
        tio.RandomGhosting(): 0.2,
        tio.RandomBiasField(): 0.1,
        tio.RandomNoise(): 0.1,
    }, p=0.8)

    subjects_dataset = tio.SubjectsDataset(train_dataset, transform=spatial)
    train_source = subjects_dataset
else:
    train_source = train_dataset

train_loader = DataLoader(train_source, batch_size=batch_size, shuffle=True)
# Shuffling validation data adds nothing; a fixed order keeps validation
# batches (and any per-batch averaging of the loss) reproducible across epochs.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

print("Dataloaders are created \n")

Dataloaders are created 



# Train

In [12]:
# Create the smaller UNet on the selected device (nn.Module.to returns the
# module itself) and attach plain SGD (default: no momentum, no weight decay).
unet_small = unet_smaller.UNet().to(device)
optim = torch.optim.SGD(unet_small.parameters(), lr=learning_rate)



In [13]:
# Run training for the configured number of epochs. This used to be a
# hard-coded 5, which silently ignored `epochs` from TRAINING_PARAMS.
# Judging by the log, `train` loads weights from `upload_model_path` and numbers
# epochs starting after `old_steps`.
# NOTE(review): the `fine_tuning` flag from TRAINING_PARAMS is never consulted;
# confirm whether `upload_model_path` should only be passed when it is True.
min_val_loss, best_epoch = train(
    unet_small, device, optim, loss_fn, epochs,
    train_loader, val_loader,
    save_model_path, save_each,
    upload_model_path, old_steps,
)

Loaded model weights from /home/polinasoloveva/Models/Test/best.model
* Epoch 3/7


  self.parse_free_form_transform(
Training: 100%|█████████████████████████████████████████████████████████████████| 51/51 [00:57<00:00,  1.12s/it]


Train loss: 0.569606
New best model saved with loss 0.5635792016983032
* Epoch 4/7


Training: 100%|█████████████████████████████████████████████████████████████████| 51/51 [00:57<00:00,  1.12s/it]


Train loss: 0.558055
New best model saved with loss 0.551106333732605
* Epoch 5/7


Training: 100%|█████████████████████████████████████████████████████████████████| 51/51 [01:00<00:00,  1.18s/it]


Train loss: 0.547016
New best model saved with loss 0.5406396389007568
* Epoch 6/7


Training: 100%|█████████████████████████████████████████████████████████████████| 51/51 [00:58<00:00,  1.15s/it]


Train loss: 0.536757
New best model saved with loss 0.5293614268302917
* Epoch 7/7


Training: 100%|█████████████████████████████████████████████████████████████████| 51/51 [00:58<00:00,  1.15s/it]


Train loss: 0.529766
New best model saved with loss 0.5197939872741699


In [14]:
# Summarise the best checkpoint reported by `train`.
print("Best epoch: {}, Min Validation Loss: {}".format(best_epoch, min_val_loss))

Best epoch: 6, Min Validation Loss: 0.5197939872741699


In [15]:
# Release the model's GPU memory. The optimizer holds references to the model's
# parameters in its param_groups, so it must be deleted as well; gc.collect()
# clears any lingering reference cycles before the CUDA caching allocator is
# asked to return its unused blocks.
del unet_small, optim
gc.collect()
torch.cuda.empty_cache()