# embryo_binary_segmentation train_mouse_embryo

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchio as tio

In [None]:
import _unet_smaller
from _data_load import upload_data
from _losses import dice_loss, focal_loss
from _config import DATA_PARAMS, FINE_TUNING, TRAINING_PARAMS
from _train_functions import train

# Load parameters

In [None]:
# Unpack the training settings. The values are taken positionally, so the key
# order of TRAINING_PARAMS in _config must match the names listed here.
(
    loss,
    learning_rate,
    batch_size,
    epochs,
    save_model_path,
    fine_tuning,
    save_each,
) = TRAINING_PARAMS.values()

In [None]:
# Unpack the data settings. The values are taken positionally, so the key
# order of DATA_PARAMS in _config must match the names listed here.
data_path, binarize, target_size, patch_size, augmentations = DATA_PARAMS.values()

# Train/ and Val/ sub-folders are resolved relative to the configured data
# root. `data_path` is the only path unpacked from DATA_PARAMS; the previously
# used name `save_path` is not defined anywhere in this notebook.
train_folder = f"{data_path}/Train/"
val_folder = f"{data_path}/Val/"

In [None]:
# Fine-tuning settings, unpacked positionally: FINE_TUNING must hold exactly
# two values, in this order.
(upload_model_path, old_steps) = FINE_TUNING.values()

In [None]:
# Map the configured loss name to the callable that is later passed to train().
if loss == 'bce':
    loss_fn = nn.BCELoss()
elif loss == 'dice':
    loss_fn = dice_loss
elif loss == 'focal':
    loss_fn = focal_loss
else:
    # Fail here with a clear message instead of leaving `loss_fn` unbound and
    # hitting a NameError only when training starts.
    raise ValueError(
        f"Unsupported loss {loss!r}; expected one of 'bce', 'dice', 'focal'."
    )

In [None]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device, '\n')

# Upload Data

In [None]:
# Load the training split from `train_folder` via the project's upload_data
# helper; the 'train' tag identifies the split to the helper.
train_dataset = upload_data(
    train_folder,
    'train',
    binarize,
    patch_size,
)
print("Train data is loaded \n")

In [None]:
# Load the validation split from `val_folder` via the project's upload_data
# helper; the 'val' tag identifies the split to the helper.
val_dataset = upload_data(
    val_folder,
    'val',
    binarize,
    patch_size,
)
print("Validation data is loaded \n")

In [None]:
# Choose what the training loader iterates over: the raw training dataset, or
# the same data wrapped in a torchio SubjectsDataset that applies a random
# augmentation.
if augmentations:
    # Candidate transforms and their relative weights. Despite the variable
    # name `spatial`, the set mixes spatial transforms (elastic deformation,
    # affine, flip) with intensity transforms (ghosting, bias field, noise).
    transform_weights = {
        tio.RandomElasticDeformation(num_control_points=(6, 6, 8), locked_borders=2, max_displacement=(16, 16, 2)): 0.1,
        tio.RandomAffine(scales=(1, 1.05), degrees=5): 0.2,
        tio.RandomFlip(axes=('LR',)): 0.1,
        tio.RandomGhosting(): 0.2,
        tio.RandomBiasField(): 0.1,
        tio.RandomNoise(): 0.1,
    }
    # OneOf applies one transform drawn from the mapping; p=0.8 is the
    # probability that the OneOf transform is applied at all.
    spatial = tio.OneOf(transform_weights, p=0.8)

    subjects_dataset = tio.SubjectsDataset(train_dataset, transform=spatial)
    train_source = subjects_dataset
else:
    train_source = train_dataset

# The validation loader is built the same way in both cases; only the training
# source differs.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
train_loader = DataLoader(train_source, batch_size=batch_size, shuffle=True)

print("Dataloaders are created \n")

# Train

In [None]:
# Build the network from the project module. It is imported as
# `_unet_smaller` (leading underscore), so that is the name to reference here.
unet_small = _unet_smaller.UNet()
# Move the model to the selected device before creating the optimizer, so the
# optimizer is built over the on-device parameters.
unet_small = unet_small.to(device)
optim = torch.optim.SGD(unet_small.parameters(), lr=learning_rate)

In [None]:
# Run training. The checkpoint path unpacked from FINE_TUNING is named
# `upload_model_path`; `upload_path` is not defined in this notebook.
# NOTE(review): `fine_tuning` from TRAINING_PARAMS is not passed to train();
# presumably train() decides from the path/steps alone - confirm.
min_val_loss, best_epoch = train(
    unet_small,
    optim,
    loss_fn,
    epochs,
    train_loader,
    val_loader,
    save_model_path,
    upload_model_path,
    old_steps,
    save_each,
)

In [None]:
# Report the two values returned by train().
best_summary = f"Best epoch: {best_epoch}, Min Validation Loss: {min_val_loss}"
print(best_summary)

In [None]:
# Release the model so its GPU memory can be reclaimed. The optimizer holds
# references to the model's parameters (via its param_groups), so it must be
# dropped too; otherwise the parameter tensors stay alive and empty_cache()
# cannot return their memory.
del unet_small, optim
torch.cuda.empty_cache()