In [1]:
import torch
import src
import wandb
import numpy as np
import os
import cv2
import random

from PIL import Image
from sklearn.model_selection import KFold

In [2]:
def set_random_seed(seed):
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

In [3]:
ENCODER = 'resnet152'
ENCODER_WEIGHTS = 'imagenet'
ACTIVATION = 'sigmoid' # could be None for logits or 'softmax2d' for multicalss segmentation
DEVICE = ('cuda' if torch.cuda.is_available() else 'cpu') 
PARAMS = {
    "lr" : 0.00005,
    "batch_size" : 1,
    "num_splits" : 3,
}
set_random_seed(345262)

In [4]:
loss = src.utils.base.SumOfLosses(
    src.utils.losses.DiceLoss(),
    src.utils.losses.BCELoss()
)

# loss = src.utils.losses.JaccardLoss()

# loss = src.utils.losses.DiceLoss()

metrics = [
    src.utils.metrics.IoU(threshold=0.9)
]

In [5]:
wandb.login()
wandb.init(project="cancer_research-dima")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdalyutkin01[0m ([33mcancer_research[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [6]:
preprocess = src.utils.dataset.Preprocessing(
    dir='./data/',
    img_size=768,
    augmentation=False
)

In [7]:
clear_images, clear_masks = preprocess.load_folder()
kfold = KFold(n_splits=PARAMS["num_splits"], shuffle=True)
fold_data = []
for fold, (train_indices, test_indices) in enumerate(kfold.split(clear_masks)):
    train_data = {
        "images" : [clear_images[index] for index in train_indices],
        "masks" : [clear_masks[index] for index in train_indices]
    }
    test_data = {
        "images" : [clear_images[index] for index in test_indices],
        "masks" : [clear_masks[index] for index in test_indices]
    }
    fold_data.append((train_data, test_data))

In [8]:
clear_masks[0].shape

(615, 980)

In [9]:
track = 0
TEST_IMAGES = 0
TEST_MASKS = 0
loader_for_top_model = 0
for fold, (train_data, test_data) in enumerate(fold_data):
    print(f"Training on fold {fold + 1}...")
    model = src.UnetPlusPlus(
        encoder_name=ENCODER, 
        encoder_weights=ENCODER_WEIGHTS,
        in_channels=3,
        classes=1,
        activation=ACTIVATION,
    )
    optimizer = torch.optim.Adam([dict(params=model.parameters(), lr=PARAMS["lr"]),])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
        factor=0.1, patience=1, threshold=0.0001, threshold_mode='abs')
    # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.02)
    train_epoch = src.utils.train.TrainEpoch(
        model, 
        loss=loss, 
        metrics=metrics, 
        optimizer=optimizer,
        device=DEVICE,
        verbose=True,
    )
    valid_epoch = src.utils.train.ValidEpoch(
        model,
        loss=loss, 
        metrics=metrics, 
        device=DEVICE,
        verbose=True,
    )
    
    train_dataloader, valid_dataloader = preprocess.Generator(
        batch_size=PARAMS["batch_size"],
        X_train=train_data["images"],
        y_train=train_data["masks"],
        X_test=test_data["images"],
        y_test=test_data["masks"]
    )

    for epoch in range(0, 5):
        print('\nEpoch: {}'.format(epoch + 1))
        train_logs = train_epoch.run(train_dataloader)
        valid_logs = valid_epoch.run(valid_dataloader)
        
        wandb.log({
            f'fold_{fold+1}/train/train_IoU': train_logs['iou_score'],
            f'fold_{fold+1}/train/train_loss': train_logs['dice_loss + bce_loss']
        })
        wandb.log({
            f'fold_{fold+1}/valid/valid_IoU': valid_logs['iou_score'], 
            f'fold_{fold+1}/valid/valid_loss': valid_logs['dice_loss + bce_loss']
        })
        
        scheduler.step(valid_logs['dice_loss + bce_loss'])
        if track < valid_logs['iou_score']:
            TEST_IMAGES = test_data["images"]
            TEST_MASKS = test_data["masks"]
            track = valid_logs['iou_score']
            loader_for_top_model = test_data
            torch.save({
                'model_state_dict': model.state_dict(),
                'best_IoU' : track
                }, f'./checkpoint/best_model_{PARAMS["num_splits"]}fold.pth')
            print('Model saved!')
wandb.finish()

Training on fold 1...

Epoch: 1
train: 100%|██████████| 68/68 [00:38<00:00,  1.77it/s, dice_loss + bce_loss - 1.557, iou_score - 0.1941] 
valid: 100%|██████████| 34/34 [00:13<00:00,  2.53it/s, dice_loss + bce_loss - 1.411, iou_score - 0.3012]
Model saved!

Epoch: 2
train: 100%|██████████| 68/68 [00:34<00:00,  1.96it/s, dice_loss + bce_loss - 1.364, iou_score - 0.4082]
valid: 100%|██████████| 34/34 [00:13<00:00,  2.52it/s, dice_loss + bce_loss - 1.339, iou_score - 0.357] 
Model saved!

Epoch: 3
train: 100%|██████████| 68/68 [00:33<00:00,  2.05it/s, dice_loss + bce_loss - 1.268, iou_score - 0.5361]
valid: 100%|██████████| 34/34 [00:13<00:00,  2.50it/s, dice_loss + bce_loss - 1.253, iou_score - 0.3621]
Model saved!

Epoch: 4
train: 100%|██████████| 68/68 [00:32<00:00,  2.07it/s, dice_loss + bce_loss - 1.192, iou_score - 0.6299]
valid: 100%|██████████| 34/34 [00:13<00:00,  2.55it/s, dice_loss + bce_loss - 1.231, iou_score - 0.3604]

Epoch: 5
train: 100%|██████████| 68/68 [00:33<00:00,  2.0

VBox(children=(Label(value='0.001 MB of 0.009 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.142019…

0,1
fold_1/train/train_IoU,▁▄▆▇█
fold_1/train/train_loss,█▅▃▂▁
fold_1/valid/valid_IoU,▁▅▆▆█
fold_1/valid/valid_loss,█▆▄▃▁
fold_2/train/train_IoU,▁▄▇▇█
fold_2/train/train_loss,█▆▄▂▁
fold_2/valid/valid_IoU,▁▆▇█▅
fold_2/valid/valid_loss,█▅▃▁▂
fold_3/train/train_IoU,▁▄▆▇█
fold_3/train/train_loss,█▄▃▂▁

0,1
fold_1/train/train_IoU,0.70017
fold_1/train/train_loss,1.12567
fold_1/valid/valid_IoU,0.38828
fold_1/valid/valid_loss,1.15039
fold_2/train/train_IoU,0.70853
fold_2/train/train_loss,0.75398
fold_2/valid/valid_IoU,0.30512
fold_2/valid/valid_loss,0.89912
fold_3/train/train_IoU,0.71486
fold_3/train/train_loss,1.17388


In [10]:
valid_pred_folder = './valid_predict_kfold_v1/'
checkpoint = torch.load(f'./checkpoint/best_model_{PARAMS["num_splits"]}fold.pth')
model.load_state_dict(checkpoint['model_state_dict'])
checkpoint['best_IoU']

0.4676332465001406

In [11]:
preds = []
for image in TEST_IMAGES:
    image_prepoc = preprocess.inference_image(image)
    image_prepoc = torch.tensor(np.array([image_prepoc]), dtype = torch.float)
    image_prepoc = image_prepoc.to(DEVICE)
    with torch.no_grad():
        predict = model.forward(image_prepoc)
    predict_value = predict.cpu().detach().numpy()
    preds.append(predict_value)

In [12]:
def test(number):
    test = cv2.resize(
        TEST_MASKS[number],
        (768, 768),
        interpolation = cv2.INTER_AREA
    ).astype('float32')
    return float(metrics[0](torch.tensor(test), torch.tensor(preds[number][0][0])))

In [13]:
test_list = []
for i in range(len(TEST_MASKS)):
    test_list.append(test(i))
print(sum(test_list) / len(test_list))

0.023302702972774997


In [14]:
for num in range(len(TEST_IMAGES)):
    true_mask = TEST_MASKS[num]
    true_mask = cv2.cvtColor(true_mask, cv2.COLOR_BGR2RGB)
    greenBGR = src.utils.plotting.Visual.green_chanel(true_mask)
    concat_true = cv2.addWeighted(TEST_IMAGES[num], 0.4, true_mask, 0.9, 0)
    
    #preproc predict
    predict_mask = src.utils.plotting.Visual.preproc_predict(preds[num])
    blueBGR_predict = src.utils.plotting.Visual.blue_chanel_predict(predict_mask)
    greenBGR_predict = src.utils.plotting.Visual.green_chanel_predict(predict_mask)
    
    #calc IoU
    IoU = src.utils.plotting.Visual.calc_IoU(greenBGR, greenBGR_predict)
    
    #concat two masks
    bitwiseAnd = cv2.bitwise_or(greenBGR, blueBGR_predict)
    
    #concat true images with mask
    concat_predict = cv2.addWeighted(TEST_IMAGES[num], 0.7, bitwiseAnd, 0.7, 0)
    concat_predict_resize = cv2.resize(concat_predict, (1280, 872), interpolation = cv2.INTER_AREA)
    
    #put text
    font = cv2.FONT_HERSHEY_SIMPLEX
    org = (570, 50)
    fontScale = 1
    color = (0, 255, 255)
    thickness = 2
    image = cv2.putText(concat_predict_resize, 'IoU ' + str(round(IoU, 3)), org, font, 
                   fontScale, color, thickness, cv2.LINE_AA)
    
    #save 
    cv2.imwrite(filename=os.path.join(valid_pred_folder + str(num) + "_true" + '.png'), img=concat_true)
    cv2.imwrite(filename=os.path.join(valid_pred_folder + str(num) + "_predict" + '.png'), img=concat_predict_resize)

In [16]:
x = torch.randn(1, 3, 768, 768, requires_grad=True)
model.to('cpu')
model.eval()
torch.onnx.export(
    model, x, "UNET++_trained.onnx", 
    export_params=True, 
    verbose=True,
    input_names = [ "actual_input_1" ] + [ "learned_%d" % i for i in range(20)],
    output_names = [ "output1" ])

  if h % output_stride != 0 or w % output_stride != 0:
