In [1]:
# Reload imported project modules (e.g. the `ml` package) automatically before each
# cell executes, so edits to the package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

### Загрузка библиотек

In [2]:
import os
import numpy as np

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR

In [3]:
from ml.models.unet3d import U_Net
from ml.models.rog import ROG
from ml.models.unet_deepsup import U_Net_DeepSup

from ml.utils import get_total_params, load_pretrainned
from ml.dataset import HeadDataset
from ml.controller import Controller
from ml.losses import (ExponentialLogarithmicLoss, WeightedExpBCE, TverskyLoss,
                       IOU_Metric, MultyscaleLoss, SumLoss, LinearCombLoss)

In [4]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
print(device)

cuda


In [5]:
# Patch-based training settings; all keys are consumed by HeadDataset.
train_settings = {
    "patch_shape": (64, 64, 64),
    "patches_per_volume": 256,
    "augmentation_coef": 1,
    "patches_queue_length": 1440,
    "batch_size": 8,
    "num_workers": 4,
    "sampler": "uniform",  # alternative: "weighted"
}
# Evaluation settings: large patches with an overlap between neighbours
# (presumably for sliding-window inference -- confirm in HeadDataset).
test_settings = {
    "patch_shape": (256, 256, 128),
    "overlap_shape": (32, 32, 24),
    "batch_size": 1,
    "num_workers": 4,
}

# Dataset root. Set the HEAD_DATA_DIR environment variable to point at a local copy
# instead of editing this hard-coded, machine-specific default.
data_dir = os.environ.get("HEAD_DATA_DIR", "/home/msst/Documents/medtech/HeadData")
dataset = HeadDataset(data_dir, train_settings, test_settings)

In [None]:
# Optional: an evaluation-only dataset built from a separate directory.
# Distinct names are used on purpose. Rebinding `train_settings`, `test_settings`,
# `data_dir` or `dataset` here would silently replace the training dataset from the
# previous cell during Restart & Run All. This dataset has no train settings (None),
# so it presumably cannot serve `controller.fit`.
# To evaluate on it: controller.val_epoch(external_test_dataset.test_dataloader)
external_data_dir = os.environ.get(
    "HEAD_DATA_TEST_DIR", "/home/msst/Documents/medtech/HeadData_test"
)
external_test_settings = {
    "patch_shape": (256, 256, 128),
    "overlap_shape": (32, 32, 24),
    "batch_size": 1,
    "num_workers": 4,
}
external_test_dataset = HeadDataset(external_data_dir, None, external_test_settings)

In [6]:
class swish(nn.Module):
    """Swish activation, ``x * sigmoid(x)``, applied elementwise.

    Has no parameters. Mathematically equivalent to ``torch.nn.SiLU``.
    """

    def forward(self, input_tensor):
        gate = torch.sigmoid(input_tensor)
        return gate * input_tensor

# Channel multiplier forwarded unchanged to U_Net_DeepSup.
CHANNEL_COEF = 16
model = U_Net_DeepSup(act_fn=swish(), channel_coef=CHANNEL_COEF)

In [101]:
# Alternative (currently unused) loss: a list of [loss_fn, coefficient] pairs combined
# by LinearCombLoss (presumably a coefficient-weighted sum -- confirm in ml.losses).
# Kept commented out for reference; the active loss is configured in the next cell.
# NOTE(review): the "c" in `funcs_and_coef_list` below appears to be a Cyrillic "с"
# (U+0441). If re-enabling this cell, retype the name in Latin letters everywhere.
#funcs_and_сoef_list = []

#funcs_and_сoef_list.append([ExponentialLogarithmicLoss(gamma_tversky = 1, gamma_bce = 1, lamb=0.0,
#                                   freq = 0.001, tversky_alfa=0.75), 1])

#funcs_and_сoef_list.append([TverskyLoss(0.75), 1])


#funcs_and_сoef_list.append([SumLoss(alfa=0.5), 0.1])

#loss_fn = LinearCombLoss(funcs_and_сoef_list)

In [7]:
# Base loss, configured by its Tversky/BCE exponents; wrapped by MultyscaleLoss.
base_loss = ExponentialLogarithmicLoss(gamma_tversky=1, gamma_bce=1, lamb=0.9,
                                       freq=0.001, tversky_alfa=0.75)
loss_fn = MultyscaleLoss(base_loss)

metric_fn = IOU_Metric()


def make_optimizer(net):
    """Build the ASGD optimizer over ``net``'s parameters (lr=0.25)."""
    return torch.optim.ASGD(net.parameters(), lr=0.25)


def make_scheduler(optimizer):
    """Halve the learning rate every 5 scheduler steps (StepLR, gamma=0.5)."""
    return StepLR(optimizer, step_size=5, gamma=0.5)


# Keys are read by Controller; "sheduler_fn" keeps the spelling used in the
# original config (presumably what Controller looks up -- do not "fix" it here).
controller_config = {
    "loss": loss_fn,
    "metric": metric_fn,
    'device': device,
    "optimizer_fn": make_optimizer,
    "sheduler_fn": make_scheduler,
}
controller = Controller(controller_config)

In [8]:
N_EPOCHS = 50  # number of training epochs ("Epoch k/50" in the log)
controller.fit(model, dataset, N_EPOCHS)

Epoch 1/50


100%|███████████████████████████████████████████| 32/32 [00:12<00:00,  2.54it/s]


{'mean_loss': 1.0309889167547226}


100%|█████████████████████████████████████████████| 2/2 [00:32<00:00, 16.14s/it]


{'metrics': [{'sample': 'P62_CTA', 'seg_sum/GT_sum': tensor(0.), 'metric1': tensor([3.1986e-11])}, {'sample': 'new_CTA', 'seg_sum/GT_sum': tensor(0.), 'metric1': tensor([3.1248e-11])}]}
Epoch 2/50


100%|███████████████████████████████████████████| 32/32 [00:15<00:00,  2.12it/s]


{'mean_loss': 0.9861912503838539}


100%|█████████████████████████████████████████████| 2/2 [00:32<00:00, 16.13s/it]


{'metrics': [{'sample': 'P62_CTA', 'seg_sum/GT_sum': tensor(2.2818), 'metric1': tensor([0.0010])}, {'sample': 'new_CTA', 'seg_sum/GT_sum': tensor(0.0042), 'metric1': tensor([3.1118e-11])}]}
Epoch 3/50


100%|███████████████████████████████████████████| 32/32 [00:15<00:00,  2.11it/s]


{'mean_loss': 0.9693324230611324}


100%|█████████████████████████████████████████████| 2/2 [00:32<00:00, 16.14s/it]


{'metrics': [{'sample': 'P62_CTA', 'seg_sum/GT_sum': tensor(0.), 'metric1': tensor([3.1986e-11])}, {'sample': 'new_CTA', 'seg_sum/GT_sum': tensor(0.), 'metric1': tensor([3.1248e-11])}]}
Epoch 4/50


100%|███████████████████████████████████████████| 32/32 [00:15<00:00,  2.10it/s]


{'mean_loss': 0.9771508425474167}


100%|█████████████████████████████████████████████| 2/2 [00:32<00:00, 16.15s/it]


{'metrics': [{'sample': 'P62_CTA', 'seg_sum/GT_sum': tensor(0.), 'metric1': tensor([3.1986e-11])}, {'sample': 'new_CTA', 'seg_sum/GT_sum': tensor(0.), 'metric1': tensor([3.1248e-11])}]}
Epoch 5/50


100%|███████████████████████████████████████████| 32/32 [00:15<00:00,  2.10it/s]


{'mean_loss': 1.0233476720750332}


  0%|                                                     | 0/2 [00:04<?, ?it/s]


KeyboardInterrupt: 

In [11]:
# Checkpoint name under saved_models/; the load cell that follows reads the same path.
# Saving is commented out -- presumably so a re-run does not overwrite the existing
# checkpoint. Uncomment to persist a newly trained model.
model_name = "Unet16_logTversky_54"
#controller.save("/home/msst/repo/MSRepo/VesselSegmentation/saved_models/" + model_name)

In [12]:
# Restore weights for `model` from the named checkpoint and hand the model
# (moved to `device`) to the controller for evaluation.
checkpoint_path = os.path.join(
    "/home/msst/repo/MSRepo/VesselSegmentation/saved_models", model_name
)
controller.load_model(model, checkpoint_path)
model_on_device = model.to(device)
controller.model = model_on_device

In [13]:
# Evaluate the loaded model on the test split; the returned metrics dict is the cell output.
test_loader = dataset.test_dataloader
controller.val_epoch(test_loader)

100%|█████████████████████████████████████████████| 2/2 [00:22<00:00, 11.45s/it]


{'metrics': [{'sample': 'P62_CTA',
   'seg_sum/GT_sum': tensor(0.6784),
   'metric1': tensor([0.4902])},
  {'sample': 'new_CTA',
   'seg_sum/GT_sum': tensor(0.8539),
   'metric1': tensor([0.3330])}]}

In [27]:
# Output directory for this model's segmentations: seg_data/<model_name>.
# os.makedirs(exist_ok=True) also creates a missing `seg_data/` parent (os.mkdir would
# raise FileNotFoundError) and avoids the exists()/mkdir() race.
# `data_dir` is reused (it previously held the dataset root) to keep the name available
# to any later cells.
data_dir = os.path.join("seg_data", model_name)
os.makedirs(data_dir, exist_ok=True)
# NOTE(review): the directory created above is never passed to predict (its second
# argument is None). Confirm whether predict should receive `data_dir` as the save path.
controller.predict(dataset.test_dataloader, None)

  0%|                                                     | 0/2 [00:06<?, ?it/s]


KeyboardInterrupt: 