In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell runs, so edits to the local `ml` package are picked up without a
# kernel restart.
%load_ext autoreload
%autoreload 2

<h3>Загрузка библиотек</h3>

In [1]:
import os
import numpy as np

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR

In [2]:
from ml.models.unet3d import U_Net
from ml.models.rog import ROG
from ml.models.unet_deepsup import Unet_MSS
from ml.models.constructor import Net, NetBlock, conv_block, bottle_neck_connection

from ml.utils import get_total_params, load_pretrainned
from ml.tio_dataset import TioDataset
from ml.controller import Controller
from ml.losses import (ExponentialLogarithmicLoss, WeightedExpBCE, TverskyLoss,
                       IOU_Metric, MultyscaleLoss, SumLoss, LinearCombLoss)

In [3]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


In [4]:
# Loader settings passed to TioDataset for the training split.
train_settings = {
    "patch_shape": (64, 64, 64),
    "patches_per_volume": 128,
    "patches_queue_length": 1440,
    "batch_size": 4,
    "num_workers": 4,
    "sampler": "weighted",  # alternative: "uniform"
}

# Loader settings for the validation split: same patch shape as training,
# fewer patches per volume, larger batches, uniform sampling.
val_settings = {
    "patch_shape": (64, 64, 64),
    "patches_per_volume": 32,
    "patches_queue_length": 1440,
    "batch_size": 8,
    "num_workers": 4,
    "sampler": "uniform",  # alternative: "weighted"
}

# Loader settings for the test split: larger patches with an overlap, one per batch.
test_settings = {
    "patch_shape": (128, 128, 128),
    "overlap_shape": (32, 32, 24),
    "batch_size": 1,
    "num_workers": 4,
}

# Dataset root. The original machine-specific path stays the default, but it can
# be overridden through the MEDTECH_DATA_DIR environment variable so the notebook
# runs on other machines without editing this cell.
data_dir = os.environ.get("MEDTECH_DATA_DIR", "/home/msst/Documents/medtech/MainData")
dataset = TioDataset(data_dir,
                     train_settings=train_settings,
                     val_settings=val_settings,
                     test_settings=test_settings)

In [5]:
# Disabled alternative to the dataset cell above: builds a TioDataset from the
# "MainData_test" directory with only test settings (train/val settings are None).
# test_settings  = {
#     "patch_shape" : (128, 128, 128),
#     "overlap_shape" : (32, 32, 24),
#     "batch_size" : 1,
#     "num_workers": 4,
# }

# data_dir = "/home/msst/Documents/medtech/MainData_test"
# dataset = TioDataset(data_dir,
#                      train_settings=None,
#                      val_settings=None,
#                      test_settings=test_settings)

In [6]:
# Disabled alternative model: Unet_MSS with channel_coef=32 and a custom swish
# activation, x * sigmoid(x).
# class swish(nn.Module):
#     def forward(self, input_tensor):
#         return input_tensor * torch.sigmoid(input_tensor)

# model = Unet_MSS(channel_coef=32, act_fn=swish())

In [7]:
# Base channel width; every layer width below is a multiple of it.
channel_coef = 8
# NOTE(review): this single PReLU instance is handed to every block. If
# conv_block / bottle_neck_connection store it without copying, its learnable
# slope is shared across the whole network -- confirm that this is intended.
act_fn = nn.PReLU()


def _conv(in_channels, out_channels, activation=None):
    """Build a conv_block with kernel_size=3, stride=1, padding=1.

    Uses the shared `act_fn` unless another activation module is given.
    """
    chosen_act = act_fn if activation is None else activation
    return conv_block(in_channels, out_channels,
                      kernel_size=3, stride=1, padding=1, act_fn=chosen_act)


def _upsample():
    """Trilinear x2 upsampling for branches coming from a coarser block."""
    return nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)


def _pool():
    """2x max-pooling for branches coming from a finer block."""
    return nn.MaxPool3d(2, 2)


def _bottleneck():
    """bottle_neck_connection with the (4c, 4c, 8c) widths used in this network."""
    return bottle_neck_connection(4 * channel_coef, 4 * channel_coef,
                                  8 * channel_coef, act_fn=act_fn)


# Each settings dict lists the incoming branches ("in_blocks": source block name ->
# module applied to that source's output) and the "backbone" applied afterwards.
# The backbone input widths equal the sum of the incoming branch widths, e.g.
# 3c = 2c (b21) + c (b11), so the branches are presumably concatenated along the
# channel axis inside NetBlock -- confirm against ml.models.constructor.
# The order of construction and of the in_blocks keys is kept as in the original.

# First row: b11 -> b12 -> b13 -> b14.
block_11_settings = dict(
    in_blocks=dict(IN=nn.Identity()),
    backbone=_conv(1, channel_coef),
)

block_12_settings = dict(
    in_blocks=dict(b21=_upsample(), b11=nn.Identity()),
    backbone=_conv(3 * channel_coef, 4 * channel_coef),
)

block_13_settings = dict(
    in_blocks=dict(b22=_upsample(), b12=_bottleneck()),
    backbone=_conv(12 * channel_coef, 4 * channel_coef),
)

block_14_settings = dict(
    in_blocks=dict(b23=_upsample(), b13=nn.Identity()),
    backbone=_conv(12 * channel_coef, 4 * channel_coef),
)

# Second row (fed through max-pooling from the first): b21 -> b22 -> b23.
block_21_settings = dict(
    in_blocks=dict(b11=_pool()),
    backbone=_conv(channel_coef, 2 * channel_coef),
)

block_22_settings = dict(
    in_blocks=dict(b31=_upsample(), b21=nn.Identity(), b12=_pool()),
    backbone=_conv(10 * channel_coef, 8 * channel_coef),
)

block_23_settings = dict(
    in_blocks=dict(b32=_upsample(), b22=nn.Identity(), b13=_pool()),
    backbone=_conv(16 * channel_coef, 8 * channel_coef),
)

# Third row (fed through max-pooling from the second): b31 -> b32.
block_31_settings = dict(
    in_blocks=dict(b21=_pool()),
    backbone=_conv(2 * channel_coef, 4 * channel_coef),
)

block_32_settings = dict(
    in_blocks=dict(b31=_bottleneck(), b22=_pool()),
    backbone=_conv(12 * channel_coef, 4 * channel_coef),
)

# Output head: single-channel conv_block on b14 with a sigmoid activation.
block_out_settings = dict(
    in_blocks=dict(b14=nn.Identity()),
    backbone=_conv(4 * channel_coef, 1, activation=nn.Sigmoid()),
)

# Wrap every settings dict in a NetBlock, keeping the original key order.
_named_settings = (
    ("b11", block_11_settings),
    ("b12", block_12_settings),
    ("b13", block_13_settings),
    ("b14", block_14_settings),
    ("b21", block_21_settings),
    ("b22", block_22_settings),
    ("b23", block_23_settings),
    ("b31", block_31_settings),
    ("b32", block_32_settings),
    ("out", block_out_settings),
)
net_blocks = {block_name: NetBlock(settings) for block_name, settings in _named_settings}

model = Net(net_blocks)

In [8]:
# Display what ml.utils.get_total_params returns for the assembled model
# (presumably its parameter count -- confirm in ml.utils).
get_total_params(model)

1247274

In [9]:
# Disabled alternative loss: a LinearCombLoss built from a list of
# [loss function, coefficient] pairs.
# NOTE(review): the "с" in "funcs_and_сoef_list" looks like a Cyrillic letter
# rather than ASCII "c" -- confirm before uncommenting.
#funcs_and_сoef_list = []

#funcs_and_сoef_list.append([ExponentialLogarithmicLoss(gamma_tversky = 1, gamma_bce = 1, lamb=0.0,
#                                   freq = 0.001, tversky_alfa=0.75), 1])

#funcs_and_сoef_list.append([TverskyLoss(0.75), 1])


#funcs_and_сoef_list.append([SumLoss(alfa=0.5), 0.1])

#loss_fn = LinearCombLoss(funcs_and_сoef_list)

In [10]:
# Training objective: an ExponentialLogarithmicLoss wrapped in MultyscaleLoss.
exp_log_loss = ExponentialLogarithmicLoss(
    gamma_tversky=1,
    gamma_bce=1,
    lamb=0.9,
    freq=0.001,
    tversky_alfa=0.55,
)
loss_fn = MultyscaleLoss(exp_log_loss)
# Metric object handed to the controller alongside the loss.
metric_fn = IOU_Metric()

# The controller receives factories rather than instances for the optimizer and
# the scheduler: each lambda is called with the object it has to wrap.
# The "sheduler_fn" spelling is kept as-is; presumably it is the key Controller
# looks up -- confirm in ml.controller before renaming.
controller_config = dict(
    loss=loss_fn,
    metric=metric_fn,
    device=device,
    optimizer_fn=lambda model: torch.optim.ASGD(model.parameters(), lr=0.1),
    sheduler_fn=lambda optimizer: StepLR(optimizer, step_size=5, gamma=0.5),
)
controller = Controller(controller_config)

cuda


In [None]:
# Epoch budget for this run (shown as "Epoch k/50" in the training log).
n_epochs = 50
controller.fit(model, dataset, n_epochs)

Epoch 1/50


100%|█████████████████████████████████████████| 128/128 [01:12<00:00,  1.78it/s]


{'mean_loss': 1.0435931296087801}


100%|█████████████████████████████████████████████| 8/8 [00:02<00:00,  3.44it/s]


{'mean_loss': 0.919118233025074}


100%|█████████████████████████████████████████████| 2/2 [00:44<00:00, 22.43s/it]


{'metrics': [{'sample': 'P62_CTA_0', 'seg_sum/GT_sum': tensor(0.0003), 'metric1': tensor([3.3500e-11])}, {'sample': 'new_CTA_0', 'seg_sum/GT_sum': tensor(0.0003), 'metric1': tensor([3.3872e-11])}]}
Epoch 2/50


100%|█████████████████████████████████████████| 128/128 [01:13<00:00,  1.73it/s]


{'mean_loss': 0.9031954170204699}


100%|█████████████████████████████████████████████| 8/8 [00:04<00:00,  1.95it/s]


{'mean_loss': 0.766069769859314}


100%|█████████████████████████████████████████████| 2/2 [00:44<00:00, 22.40s/it]


{'metrics': [{'sample': 'P62_CTA_0', 'seg_sum/GT_sum': tensor(0.0598), 'metric1': tensor([0.0414])}, {'sample': 'new_CTA_0', 'seg_sum/GT_sum': tensor(0.2425), 'metric1': tensor([0.0417])}]}
Epoch 3/50


100%|█████████████████████████████████████████| 128/128 [01:13<00:00,  1.75it/s]


{'mean_loss': 0.8734171348623931}


100%|█████████████████████████████████████████████| 8/8 [00:04<00:00,  1.90it/s]


{'mean_loss': 0.9171577021479607}


100%|█████████████████████████████████████████████| 2/2 [00:44<00:00, 22.42s/it]


{'metrics': [{'sample': 'P62_CTA_0', 'seg_sum/GT_sum': tensor(0.0855), 'metric1': tensor([0.0567])}, {'sample': 'new_CTA_0', 'seg_sum/GT_sum': tensor(0.1323), 'metric1': tensor([0.0761])}]}
Epoch 4/50


100%|█████████████████████████████████████████| 128/128 [01:13<00:00,  1.74it/s]


{'mean_loss': 0.7602338092401624}


100%|█████████████████████████████████████████████| 8/8 [00:04<00:00,  1.85it/s]


{'mean_loss': 0.6702056303620338}


100%|█████████████████████████████████████████████| 2/2 [00:44<00:00, 22.45s/it]


{'metrics': [{'sample': 'P62_CTA_0', 'seg_sum/GT_sum': tensor(0.4115), 'metric1': tensor([0.1551])}, {'sample': 'new_CTA_0', 'seg_sum/GT_sum': tensor(0.6365), 'metric1': tensor([0.1409])}]}
Epoch 5/50


100%|█████████████████████████████████████████| 128/128 [01:13<00:00,  1.74it/s]


{'mean_loss': 0.6538168140687048}


100%|█████████████████████████████████████████████| 8/8 [00:04<00:00,  1.88it/s]


{'mean_loss': 0.6042757742106915}


100%|█████████████████████████████████████████████| 2/2 [00:44<00:00, 22.45s/it]


{'metrics': [{'sample': 'P62_CTA_0', 'seg_sum/GT_sum': tensor(0.8836), 'metric1': tensor([0.2324])}, {'sample': 'new_CTA_0', 'seg_sum/GT_sum': tensor(1.8627), 'metric1': tensor([0.1735])}]}
Epoch 6/50


100%|█████████████████████████████████████████| 128/128 [01:13<00:00,  1.74it/s]


{'mean_loss': 0.526658863062039}


100%|█████████████████████████████████████████████| 8/8 [00:04<00:00,  1.87it/s]


{'mean_loss': 0.6319119110703468}


100%|█████████████████████████████████████████████| 2/2 [00:44<00:00, 22.44s/it]


{'metrics': [{'sample': 'P62_CTA_0', 'seg_sum/GT_sum': tensor(0.7329), 'metric1': tensor([0.2991])}, {'sample': 'new_CTA_0', 'seg_sum/GT_sum': tensor(1.8450), 'metric1': tensor([0.1853])}]}
Epoch 7/50


100%|█████████████████████████████████████████| 128/128 [01:14<00:00,  1.73it/s]


{'mean_loss': 0.4892423953860998}


100%|█████████████████████████████████████████████| 8/8 [00:04<00:00,  1.90it/s]


{'mean_loss': 0.5254191905260086}


100%|█████████████████████████████████████████████| 2/2 [00:44<00:00, 22.43s/it]


{'metrics': [{'sample': 'P62_CTA_0', 'seg_sum/GT_sum': tensor(0.7284), 'metric1': tensor([0.3414])}, {'sample': 'new_CTA_0', 'seg_sum/GT_sum': tensor(1.7085), 'metric1': tensor([0.2222])}]}
Epoch 8/50


100%|█████████████████████████████████████████| 128/128 [01:13<00:00,  1.74it/s]


{'mean_loss': 0.4582331113051623}


100%|█████████████████████████████████████████████| 8/8 [00:04<00:00,  1.90it/s]


{'mean_loss': 0.4027121439576149}


100%|█████████████████████████████████████████████| 2/2 [00:44<00:00, 22.45s/it]


{'metrics': [{'sample': 'P62_CTA_0', 'seg_sum/GT_sum': tensor(0.9427), 'metric1': tensor([0.3844])}, {'sample': 'new_CTA_0', 'seg_sum/GT_sum': tensor(2.0056), 'metric1': tensor([0.2484])}]}
Epoch 9/50


100%|█████████████████████████████████████████| 128/128 [01:13<00:00,  1.74it/s]


{'mean_loss': 0.4841310849878937}


100%|█████████████████████████████████████████████| 8/8 [00:04<00:00,  1.87it/s]


{'mean_loss': 0.7648486644029617}


100%|█████████████████████████████████████████████| 2/2 [00:45<00:00, 22.56s/it]


{'metrics': [{'sample': 'P62_CTA_0', 'seg_sum/GT_sum': tensor(1.2659), 'metric1': tensor([0.3879])}, {'sample': 'new_CTA_0', 'seg_sum/GT_sum': tensor(2.2688), 'metric1': tensor([0.2720])}]}
Epoch 10/50


100%|█████████████████████████████████████████| 128/128 [01:13<00:00,  1.74it/s]


{'mean_loss': 0.41496554645709693}


100%|█████████████████████████████████████████████| 8/8 [00:04<00:00,  1.87it/s]


{'mean_loss': 1.1103733032941818}


100%|█████████████████████████████████████████████| 2/2 [00:44<00:00, 22.44s/it]


{'metrics': [{'sample': 'P62_CTA_0', 'seg_sum/GT_sum': tensor(1.6731), 'metric1': tensor([0.3738])}, {'sample': 'new_CTA_0', 'seg_sum/GT_sum': tensor(2.7272), 'metric1': tensor([0.2657])}]}
Epoch 11/50


100%|█████████████████████████████████████████| 128/128 [01:13<00:00,  1.74it/s]


{'mean_loss': 0.3789358652429655}


100%|█████████████████████████████████████████████| 8/8 [00:04<00:00,  1.88it/s]


{'mean_loss': 0.3526517190039158}


100%|█████████████████████████████████████████████| 2/2 [00:44<00:00, 22.42s/it]


{'metrics': [{'sample': 'P62_CTA_0', 'seg_sum/GT_sum': tensor(1.5669), 'metric1': tensor([0.4216])}, {'sample': 'new_CTA_0', 'seg_sum/GT_sum': tensor(2.6094), 'metric1': tensor([0.2902])}]}
Epoch 12/50


100%|█████████████████████████████████████████| 128/128 [01:13<00:00,  1.74it/s]


{'mean_loss': 0.40020469785667956}


100%|█████████████████████████████████████████████| 8/8 [00:04<00:00,  1.89it/s]


{'mean_loss': 0.8096069768071175}


100%|█████████████████████████████████████████████| 2/2 [00:44<00:00, 22.43s/it]


{'metrics': [{'sample': 'P62_CTA_0', 'seg_sum/GT_sum': tensor(1.8335), 'metric1': tensor([0.4060])}, {'sample': 'new_CTA_0', 'seg_sum/GT_sum': tensor(2.8397), 'metric1': tensor([0.2854])}]}
Epoch 13/50


100%|█████████████████████████████████████████| 128/128 [01:14<00:00,  1.73it/s]


{'mean_loss': 0.38302720396313816}


100%|█████████████████████████████████████████████| 8/8 [00:04<00:00,  1.88it/s]


{'mean_loss': 0.69333141669631}


100%|█████████████████████████████████████████████| 2/2 [00:45<00:00, 22.59s/it]


{'metrics': [{'sample': 'P62_CTA_0', 'seg_sum/GT_sum': tensor(2.2051), 'metric1': tensor([0.3677])}, {'sample': 'new_CTA_0', 'seg_sum/GT_sum': tensor(3.1806), 'metric1': tensor([0.2697])}]}
Epoch 14/50


100%|█████████████████████████████████████████| 128/128 [01:14<00:00,  1.73it/s]


{'mean_loss': 0.37585655599832535}


100%|█████████████████████████████████████████████| 8/8 [00:04<00:00,  1.90it/s]


{'mean_loss': 0.835393238812685}


100%|█████████████████████████████████████████████| 2/2 [00:44<00:00, 22.43s/it]


{'metrics': [{'sample': 'P62_CTA_0', 'seg_sum/GT_sum': tensor(2.4173), 'metric1': tensor([0.3500])}, {'sample': 'new_CTA_0', 'seg_sum/GT_sum': tensor(3.3405), 'metric1': tensor([0.2651])}]}
Epoch 15/50


100%|█████████████████████████████████████████| 128/128 [01:13<00:00,  1.74it/s]


{'mean_loss': 0.3727474871557206}


100%|█████████████████████████████████████████████| 8/8 [00:04<00:00,  1.89it/s]


{'mean_loss': 0.43044427409768105}


100%|█████████████████████████████████████████████| 2/2 [00:44<00:00, 22.43s/it]


{'metrics': [{'sample': 'P62_CTA_0', 'seg_sum/GT_sum': tensor(2.4382), 'metric1': tensor([0.3525])}, {'sample': 'new_CTA_0', 'seg_sum/GT_sum': tensor(3.3418), 'metric1': tensor([0.2691])}]}
Epoch 16/50


 67%|████████████████████████████▏             | 86/128 [00:50<00:22,  1.87it/s]

In [14]:
# Checkpoint name shared by the save / load / predict cells.
# Earlier names kept for reference:
#   "Unet16_ExpLog09_25", "Unet16_ExpLog09_100", "Unet16_logTversky_54"
# NOTE(review): the name says "UnetMSS32", while the model assembled earlier in
# this notebook is a custom Net with channel_coef = 8 -- confirm that the
# checkpoint really belongs to the model it is loaded into.
model_name = "UnetMSS32_ExpLog09_34"
# Saving is disabled; uncomment to write the current controller state.
#controller.save("/home/msst/repo/MSRepo/VesselSegmentation/saved_models/" + model_name)

In [10]:
# Disabled alternative way of restoring weights: load into `model` via
# controller.load_model, then attach the model (moved to `device`) to the controller.
#controller.load_model(model, "/home/msst/repo/MSRepo/VesselSegmentation/saved_models/" + model_name)
#controller.model = model.to(device)

In [15]:
# Restore the checkpoint stored under `model_name` in the saved-models directory.
# NOTE(review): absolute, machine-specific path -- adjust when running elsewhere.
checkpoint_dir = "/home/msst/repo/MSRepo/VesselSegmentation/saved_models/"
path_to_check = checkpoint_dir + model_name
controller.load(model, path_to_checkpoint=path_to_check)

In [12]:
# Run the controller's validation pass over the test dataloader; the cell output
# is the returned dict with one metrics entry per sample.
controller.val_epoch(dataset.test_dataloader)

100%|█████████████████████████████████████████████| 7/7 [01:15<00:00, 10.74s/it]


{'metrics': [{'sample': 'P62_CTA',
   'seg_sum/GT_sum': tensor(0.0060),
   'metric1': tensor([7.3981e-11])},
  {'sample': 'P28_CTA',
   'seg_sum/GT_sum': tensor(0.0002),
   'metric1': tensor([0.0011])},
  {'sample': 'P12_CTA',
   'seg_sum/GT_sum': tensor(74.4396),
   'metric1': tensor([0.0129])},
  {'sample': 'P70_CTA',
   'seg_sum/GT_sum': tensor(0.0044),
   'metric1': tensor([0.0011])},
  {'sample': 'P35_CTA',
   'seg_sum/GT_sum': tensor(0.0040),
   'metric1': tensor([0.0006])},
  {'sample': 'new_CTA',
   'seg_sum/GT_sum': tensor(0.0035),
   'metric1': tensor([7.4289e-05])},
  {'sample': 'CT_S5020',
   'seg_sum/GT_sum': tensor(36.7029),
   'metric1': tensor([0.0207])}]}

In [21]:
# Output directory passed to controller.predict for this checkpoint.
# NOTE: this rebinds `data_dir`, which earlier in the notebook held the dataset
# root; re-run the dataset cell before reusing `data_dir` as an input path.
data_dir = "seg_data/" + model_name
# makedirs also creates the missing "seg_data" parent (os.mkdir would raise
# FileNotFoundError there) and exist_ok=True makes re-running the cell safe.
os.makedirs(data_dir, exist_ok=True)
controller.predict(dataset.test_dataloader, data_dir)

100%|█████████████████████████████████████████████| 6/6 [02:49<00:00, 28.33s/it]


[{'sample': 'P62_CTA_0',
  'seg_sum/GT_sum': tensor(3.3345),
  'metric1': tensor([0.2538])},
 {'sample': 'P70_CTA_0',
  'seg_sum/GT_sum': tensor(0.0003),
  'metric1': tensor([7.3287e-05])},
 {'sample': 'P12_CTA_0',
  'seg_sum/GT_sum': tensor(6.4134),
  'metric1': tensor([0.1475])},
 {'sample': 'P28_CTA_0',
  'seg_sum/GT_sum': tensor(7.6404e-06),
  'metric1': tensor([3.6783e-05])},
 {'sample': 'P35_CTA_0',
  'seg_sum/GT_sum': tensor(0.0002),
  'metric1': tensor([2.7909e-05])},
 {'sample': 'CT_S5020_0',
  'seg_sum/GT_sum': tensor(5.6240),
  'metric1': tensor([0.1178])}]