In [1]:
import torch, torchvision
import os
import random
import logging #Se renseigner, par encore utilisé
import dataset

import constants as cst
import numpy as np
import torch.nn as nn
import torchvision.transforms as transforms
# import matplotlib.pyplot as plt
import torch.nn.functional as F

import loss_fn
from unet import UNET

In [2]:
TERM = "VC"
EXTRA = "ALL"

DATASET = "/datasets/zebrafish_images"
MASKS = "/datasets/zebrafish_vertebraes"

In [3]:
random.seed(cst.SEED)
torch.manual_seed(cst.SEED)
np.random.seed(cst.SEED)

SIZE = (384, 512)

DEVICE_NAME = "cpu"
if torch.cuda.is_available():
    DEVICE_NAME = 'cuda:0'
DEVICE = torch.device(DEVICE_NAME)

transform = transforms.Compose([transforms.Resize(SIZE),
                                transforms.Pad((0, 64, 0, 64))])
untransform = transforms.Compose([transforms.CenterCrop(SIZE),
                                 transforms.Resize((1932, 2576))])

model = UNET(3, 2)
overall_best_model = UNET(3, 2)
overall_best_model = model
model.to(DEVICE)

best_fold = 0
overall_best = 10
best_fold_epoch = 0
fold_best_vals = []

image_folder = DATASET
mask_folder = MASKS


for fold in range(cst.FOLDS):
    # Datasets and loaders
    training_set = dataset.ZebrafishDataset_KFold_v2(image_folder,
                                                  mask_folder,
                                                  actual_fold=fold,
                                                  dataset="train",
                                                  folds=5)
    validation_set = dataset.ZebrafishDataset_KFold_v2(image_folder,
                                                    mask_folder,
                                                    actual_fold=fold,
                                                    dataset="validate",
                                                    folds=5)
    testing_set = dataset.ZebrafishDataset_KFold_v2(image_folder,
                                                 mask_folder,
                                                 actual_fold=fold,
                                                 dataset="test",
                                                 folds=5)

    training_loader = torch.utils.data.DataLoader(training_set,
                                                  batch_size=cst.BATCH_SIZE,
                                                  shuffle=True,
                                                  num_workers=cst.WORKERS)

    validation_loader = torch.utils.data.DataLoader(validation_set,
                                                    batch_size=cst.BATCH_SIZE,
                                                    shuffle=True,
                                                    num_workers=cst.WORKERS)

    testing_loader = torch.utils.data.DataLoader(testing_set,
                                                 batch_size=1,
                                                 shuffle=True,
                                                 num_workers=cst.WORKERS)

    # (Channels x Classes)
    model = UNET(3, 2)
    best_model = UNET(3, 2)
    best_model = model
    model.to(DEVICE)

    criterion = nn.CrossEntropyLoss()
    criterion_string = "CE"

    if cst.LOSS == "Dice":
        print("Dice")
        criterion = loss_fn.DiceLoss()
        criterion_string = "DCE"
    if cst.LOSS == "IOU":
        print("IOU")
        criterion = loss_fn.IoULoss()
        criterion_string = "IOU"
    if cst.LOSS == "Tversky":
        print("Twersky")
        criterion = loss_fn.TverskyLoss()
        criterion_string = "Tversky"

    optimiser = torch.optim.Adam(model.parameters(), lr=cst.LEARNING_RATE, weight_decay=cst.WEIGHT_DECAY)
    optimiser_string = "ADAM" + "_" + "LR" + str(cst.LEARNING_RATE) + "_" + "WD" + str(cst.WEIGHT_DECAY)

    if cst.OPTIMIZER == "SGD":
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=cst.LEARNING_RATE,
                                    momentum=cst.MOMENTUM,
                                    weight_decay=cst.WEIGHT_DECAY)
        optimiser_string = "SGD" + "_" + "LR" + str(cst.LEARNING_RATE) + "_" + "M" + str(cst.MOMENTUM)
        optimiser_string += "_" + "WD" + str(cst.WEIGHT_DECAY)

    model.eval()
    with torch.no_grad():
        val_loss = []
        for images, masks, names in validation_loader:
            images = transform(images)
            outputs = model(images.to(DEVICE))
            outputs = untransform(outputs)

            masks = masks.type(torch.LongTensor)
            masks = torch.squeeze(masks, 1)

            if cst.LOSS == "CE":
                vloss = criterion(outputs, masks.to(DEVICE))
            else:
                vloss = criterion(outputs, F.one_hot(masks, 2).permute(0, 3, 1, 2).float())

            loss = vloss.detach().item()
            val_loss.append(loss)

        loss = np.mean(val_loss)
        print("Validation loss before training: {}".format(loss))

    best_val = loss
    best_epoch = 0

    params_string = "Params" + "_" + "Epoch" + str(cst.EPOCHS) + "_" + "BS" + str(cst.BATCH_SIZE)
    params_string += "_" + "W" + str(cst.WORKERS)

    epochs_train_losses = []
    epochs_val_losses = []
    for i in range(cst.EPOCHS):
        print("Starting epoch {}".format(i+1), end=". ")

        model.train()
        train_loss = []
        for images, masks, names in training_loader:
            images = transform(images)
            outputs = model(images.to(DEVICE))
            outputs = untransform(outputs)

            masks = masks.type(torch.LongTensor)
            masks = torch.squeeze(masks, 1)

            if cst.LOSS == "CE":
                tloss = criterion(outputs, masks.to(DEVICE))
            else:
                tloss = criterion(outputs, F.one_hot(masks, 2).permute(0, 3, 1, 2).float())

            loss = tloss.detach().item()
            train_loss.append(loss)

            optimiser.zero_grad()
            tloss.backward()
            optimiser.step()

        #print("d")
        loss = np.mean(train_loss)
        epochs_train_losses.append(loss)
        print("Trained: {}".format(loss), end=". ")

        model.eval()
        with torch.no_grad():
            val_loss = []
            for images, masks, names in validation_loader:
                images = transform(images)
                outputs = model(images.to(DEVICE))
                outputs = untransform(outputs)

                masks = masks.type(torch.LongTensor)
                masks = torch.squeeze(masks, 1)

                if cst.LOSS == "CE":
                    vloss = criterion(outputs, masks.to(DEVICE))
                else:
                    vloss = criterion(outputs, F.one_hot(masks, 2).permute(0, 3, 1, 2).float())

                loss = vloss.detach().item()
                val_loss.append(loss)

            loss = np.mean(val_loss)
            epochs_val_losses.append(loss)
            print("Validation: {}.".format(loss))

            if loss < best_val:
                best_val = loss
                best_model = model
                best_epoch = i+1

    #print("Training: {}".format(epochs_train_losses))
    #print("Validating: {}".format(epochs_val_losses))
    print("Best score: {}".format(best_val))

    #model_filepath = os.path.join(cst.MODEL, model_name)
    #best_filepath = os.path.join(cst.MODEL, best_name)
    #torch.save(model.state_dict(), model_filepath)
    #torch.save(best_model.state_dict(), best_filepath)

    fold_best_vals.append(best_val)
    
    if best_val < overall_best:
        overall_best = best_val
        overall_best_model = best_model
        best_fold_epoch = best_epoch
        best_fold = fold

    print("--------------------")
    print("Fold: {}".format(fold))
    print("Last val: {}".format(loss))
    print("Best val: {}".format(best_val))
    print("--------------------")


    """
    index = [i+1 for i in range(cst.EPOCHS)]
    plt.plot(index, epochs_train_losses, label="Training")
    plt.plot(index, epochs_val_losses, label="Validation")
    plt.title(cst.LOSS)
    plt.ylabel("Loss")
    plt.xlabel("Epochs")
    plt.legend()
    plot_name = name + ".png"
    plt.savefig(plot_name)
    """

name = TERM + "_" + EXTRA + "_Fold" + str(best_fold) +  "_EPOCH" + str(best_fold_epoch) + "_Val" + str(best_val)
name = name + criterion_string + "_" + optimiser_string + "_" + params_string
model_name = name + ".pth"
best_filepath = os.path.join(cst.MODEL, model_name)
#torch.save(overall_best_model.state_dict(), best_filepath)

print("-------------------------")
print("-----------END-----------")
print("-------------------------")
print("Best fold: {}".format(best_fold))
print("Best validation loss: {}".format(overall_best))
print("Best epoch: {}".format(best_fold_epoch))
print("Mean val of folds: {}".format(np.mean(fold_best_vals)))
print("Best vals for each fold:")
a = 0
for item in fold_best_vals:
    print("Fold :{}".format(a+1), end=" - ")
    print("Val: {}".format(item))

Training set length: 118
Validation set length: 30
Testing set length: 29
Validation loss before training: 0.6688893884420395
Starting epoch 1. Trained: 0.6770721395810445. Validation: 0.6640378162264824.
Starting epoch 2. Trained: 0.6568287551403046. Validation: 0.6479606181383133.
Starting epoch 3. Trained: 0.6287015855312348. Validation: 0.6060688868165016.
Starting epoch 4. Trained: 0.5605874558289846. Validation: 0.48322538658976555.
Starting epoch 5. Trained: 0.13975270371884108. Validation: 0.042093385476619005.
Starting epoch 6. Trained: 0.01889147103453676. Validation: 0.018525584368035197.
Starting epoch 7. Trained: 0.015393812116235495. Validation: 0.015903846826404333.
Starting epoch 8. Trained: 0.014059053920209407. Validation: 0.014704628963954747.
Starting epoch 9. Trained: 0.013171347634245953. Validation: 0.013860409846529365.
Starting epoch 10. Trained: 0.012542510265484452. Validation: 0.013865678803995252.
Starting epoch 11. Trained: 0.011924838383371631. Validation