## Сегментация изображений

Перед нами стоит задача **сегментации** КТ-снимков лёгких: нужно определить области поражения от Covid-19.

Исходные данные: **json** с названиями файлов,
**images** - сами данные, **labels** - разметка: 0 - этот пиксель **НЕ относится** к поврежденному лёгкому, 1 - пиксель **относится** к повреждённому лёгкому

Каждый скан — это numpy-массив формы (512, 512, n_slices)



## 0. Подготовка данных и импорт библиотек

In [None]:
!pip install -U git+https://github.com/qubvel/segmentation_models.pytorch

In [None]:
import os
import json
import nibabel as nib
import numpy as np
from tqdm.notebook import tqdm
import torchvision
import random

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import Dataset, DataLoader
import albumentations as A # Будем использовать для аугментации данных
import segmentation_models_pytorch as smp

In [None]:
# Use the first CUDA GPU when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# from google.colab import drive # Маунтим диск 
# drive.mount('/content/drive')

Структура проекта:
1. core_path - общий путь до проекта. В коллабе это "./drive/MyDrive/Deep Learning/CovidKaggleTask/"

2. path = core_path + "data/data/" - Путь до самих файлов. В этой директории есть сам json файл, папки images и labels, в которых есть информация о кт-сканах

3. core_path/models - путь до моделей для обучения. В силу больших по объёму данных, необходимо сохранять каждый раз модель. Таким образом в этой папке сохраняются модели после каждой эпохи. lungs_ct_model_1.h5 - корректное название для модели, которая была обучена только на одной эпохе

In [None]:
# core_path = "./drive/MyDrive/Deep Learning/CovidKaggleTask/" # Определяем пути до файлов
# path = core_path + "data/data/"

In [None]:
# Paths used when running on Kaggle.
core_path = "../input/tgcovid/"
path = core_path + "data/data/"

## 1. Работа с датасетом

In [None]:
from os import listdir
from os.path import isfile, join
# Names of all regular files found directly inside the images directory.
onlyfiles = [f for f in listdir(path + "images") if isfile(join(path + "images", f))]
onlyfiles[:5] # Preview the first five file names

Определяемся с аугментацией данных. Пытаемся понять, какие преобразования могут быть полезными для данной задачи

In [None]:
# First candidate augmentation pipeline.
# NOTE: `transforms` is reassigned by the following cells, so only the last
# definition that was executed is in effect.
transforms = A.Compose([
    # Gaussian noise
    A.GaussNoise(always_apply=False, p=0.3, var_limit=(1, 5)),
    # Flip (disabled)
    # A.Flip(always_apply=False, p=0.5),
    # Random grid distortion
    A.GridDistortion(always_apply=False, p=0.2, num_steps=3, distort_limit=(-0.30000001192092896, 0.30000001192092896), interpolation=0, border_mode=0, value=(0, 0, 0), mask_value=None),
    # Horizontal flip (disabled)
    # A.HorizontalFlip(always_apply=False, p=0.5),
    # Motion blur (as from a sudden camera movement)
    A.MotionBlur(always_apply=False, p=0.4, blur_limit=(3, 5)),
    # Multiplicative noise
    A.MultiplicativeNoise(always_apply=False, p=0.3, multiplier=(0.8899999856948853, 1.1699999570846558), per_channel=True, elementwise=True),
    # Optical distortion (disabled)
    # A.OpticalDistortion(always_apply=False, p=0.5, distort_limit=(-0.5, 0.5), shift_limit=(-0.05000000074505806, 0.05000000074505806), interpolation=0, border_mode=0, value=(0, 0, 0), mask_value=None),
    # Random gamma (disabled)
    # A.RandomGamma(always_apply=False, p=0.5, gamma_limit=(80, 120), eps=1e-07),
    # Vertical flip
    A.VerticalFlip(always_apply=False, p=0.3)
    ]
)

In [None]:
# Second candidate pipeline: motion blur almost always (p=0.91) and a very
# rare vertical flip (p=0.01). Replaces the previous `transforms` value.
transforms = A.Compose([
    A.MotionBlur(always_apply=False, p=0.91, blur_limit=(3, 5)),
    A.VerticalFlip(always_apply=False, p=0.01)
    ]
)

In [None]:
# Third candidate pipeline: one of three noise/blur transforms followed by a
# perspective warp that is always applied (p=1.0). Replaces the previous
# `transforms` value.
transforms = A.Compose([
    A.OneOf([A.GaussNoise(always_apply=False, p=0.25, var_limit=(4, 12)),
             A.MotionBlur(always_apply=False, p=0.25, blur_limit=(3, 5)),
             A.Blur(always_apply=False, p=0.25, blur_limit=(3, 5))]),
    # Rotation (disabled)
    # A.Rotate(always_apply=False, p=0.9, limit=(-15, 15)),
    A.augmentations.geometric.transforms.Perspective(scale=(0.03, 0.08), keep_size=True, pad_mode=0, pad_val=0, mask_pad_val=0, fit_output=False, interpolation=1, always_apply=False, p=1.0)
])

Испытаем аугментации сейчас

In [None]:
from PIL import Image
import matplotlib.pyplot as plt

def blend_with_mask(image, mask):
    """Return a PIL image of a grayscale slice with its mask overlaid.

    The slice is min-max scaled to the 0..255 range and replicated into three
    channels. The mask is multiplied by 255 and written to the third channel
    of an otherwise empty image. The two images are blended with alpha 0.3.
    """
    scaled = image.astype(np.float32)
    lowest = scaled.min()
    highest = scaled.max()
    # The small constant keeps the division defined for a constant slice.
    scaled = (scaled - lowest) / (highest - lowest + 1e-8) * 255
    background = np.dstack([scaled, scaled, scaled]).astype(np.uint8)

    empty = np.zeros_like(mask)
    overlay = np.dstack([empty, empty, mask * 255]).astype(np.uint8)

    base_picture = Image.fromarray(background)
    mask_picture = Image.fromarray(overlay)
    return Image.blend(base_picture, mask_picture, alpha=.3)

path_images = os.path.join(path, 'images')
path_labels = os.path.join(path, 'labels')

# Load the first scan with its mask and move the slice axis to the front:
# (H, W, n_slices) -> (n_slices, H, W).
# NOTE(review): the uint8 cast assumes the stored intensities already lie in
# 0..255 — confirm for this dataset.
scan_name = onlyfiles[0]
image = torch.tensor(nib.load(os.path.join(path_images, scan_name)).get_fdata(), dtype=torch.uint8).permute(2, 0, 1).contiguous()
label = torch.tensor(nib.load(os.path.join(path_labels, scan_name[:-4] + "_mask.nii")).get_fdata(), dtype=torch.uint8).permute(2, 0, 1).contiguous()
print(image.shape)

slices = []
slices_num = (20, )
for idx in slices_num:
    # Albumentations expects one H x W (x C) image per call, so each 2-D slice
    # is augmented together with its own mask. Passing the whole
    # (n_slices, H, W) volume would treat the slice axis as image height.
    result_aug = transforms(image=image[idx].numpy(), mask=label[idx].numpy())
    slices.append(blend_with_mask(result_aug["image"], result_aug["mask"]))

figure = plt.figure(figsize=(12, 12))
# A separate loop name keeps the global `image` (the loaded volume) intact.
for i, blended in enumerate(slices, start=1):
    ax = figure.add_subplot(1, len(slices), i)
    ax.imshow(blended)

In [None]:
# Show the shape of the object currently bound to `image`.
np.asarray(image).shape

Тут как обычно: создаём класс датасета, в котором прописываем метод __init__(self, список из сканов для трейна, список из сканов для валидации, аугментации), метод длины датасета и получения элемента из датасета

In [None]:
class CovidDataset(Dataset):
    """Slice-level training dataset built from the scans in training_data.json.

    Every scan is split into 2-D slices. Slices of scans named in
    ``X_train_list`` become training samples. Slices of every other scan are
    kept as the validation set, but only when they contain at least one
    labelled pixel. Indexing the dataset returns a randomly augmented training
    sample.
    """

    # Rotation angles (in degrees) the augmentation chooses from.
    _DEGREES = tuple(range(-35, 40, 5))

    def __init__(self, X_train_list, X_val_list, transforms, without_covid_max=9999999):
        """Load all scans listed in training_data.json and split them into slices.

        Args:
            X_train_list: file names of the scans whose slices go to training.
            X_val_list: accepted for interface compatibility but not consulted;
                every scan that is not in ``X_train_list`` is treated as
                validation. NOTE(review): confirm that this is intended.
            transforms: stored on the instance; ``__getitem__`` does not use it.
            without_covid_max: maximum number of training slices with an empty
                mask to keep, counted across all scans.
        """
        path_images = os.path.join(path, 'images')
        path_labels = os.path.join(path, 'labels')
        # JSON file listing the image/label file pairs.
        with open(core_path + 'training_data.json', 'r') as f:
            dict_training = json.load(f)

        self.X_train = []
        self.Y_train = []
        self.X_val = []
        self.Y_val = []
        self.transforms = transforms

        train_names = set(X_train_list)  # constant-time membership tests
        without_covid = 0
        for entry in tqdm(dict_training):
            # The last three characters of the listed names are dropped
            # (presumably a ".gz" suffix — confirm against the JSON file).
            image_name = entry['image'][:-3]
            image = nib.load(os.path.join(path_images, image_name))
            label = nib.load(os.path.join(path_labels, entry['label'][:-3]))
            # Move the slice axis to the front: (H, W, n) -> (n, H, W).
            # NOTE(review): the uint8 cast assumes intensities already lie in
            # 0..255 — confirm for this dataset.
            image = torch.tensor(image.get_fdata(), dtype=torch.uint8).permute(2, 0, 1)
            label = torch.tensor(label.get_fdata(), dtype=torch.uint8).permute(2, 0, 1)

            in_train = image_name in train_names
            for image_slice, label_slice in zip(image, label):
                has_lesion = bool(label_slice.sum() != 0)
                if in_train:
                    if not has_lesion:
                        # Cap the number of slices without any labelled pixel.
                        if without_covid >= without_covid_max:
                            continue
                        without_covid += 1
                    self.X_train.append(image_slice)
                    self.Y_train.append(label_slice)
                elif has_lesion:
                    self.X_val.append(image_slice)
                    self.Y_val.append(label_slice)

    def __len__(self):
        """Return the number of training slices."""
        return len(self.X_train)

    def __getitem__(self, idx):
        """Return an augmented ``(image, mask)`` pair, each of shape [1, H, W].

        The leading channel axis is required by the 2-D convolutions of the
        model; the image is scaled to the 0..1 range.
        """
        X = self.X_train[idx].type(torch.float).unsqueeze(0) / 255
        y = self.Y_train[idx].type(torch.float).unsqueeze(0)
        return self._augment(X, y)

    def _augment(self, X, y):
        """Apply the same random geometric transforms to image and mask."""
        tf = torchvision.transforms
        if random.random() > 0.5:
            flip = tf.functional.vflip if random.random() > 0.5 else tf.functional.hflip
            X, y = flip(X), flip(y)
        if random.random() >= 0.1:
            degree = random.choice(self._DEGREES)
            X = tf.functional.rotate(X, degree)
            y = tf.functional.rotate(y, degree)
        # 0.25 = the former outer 0.5 gate times the transform's own p=0.5.
        if random.random() < 0.25:
            # Sample the warp once and reuse it for both tensors. Two separate
            # RandomPerspective instances would draw different warps and
            # misalign the mask with the image.
            height, width = X.shape[-2:]
            startpoints, endpoints = tf.RandomPerspective.get_params(width, height, 0.15)
            X = tf.functional.perspective(
                X, startpoints, endpoints,
                interpolation=tf.InterpolationMode.BILINEAR, fill=[0.0])
            # Nearest-neighbour sampling keeps the mask values binary.
            y = tf.functional.perspective(
                y, startpoints, endpoints,
                interpolation=tf.InterpolationMode.NEAREST, fill=[0.0])
        if random.random() > 0.5:
            # Blur only the image; the mask is left untouched.
            X = tf.GaussianBlur(1)(X)
        return X, y

    def get_validation_set(self):
        """Return the raw (un-augmented) validation slices as ``(X_val, Y_val)``."""
        return (self.X_val, self.Y_val)

In [None]:
batch_size = 8

# Shuffle the file names in place so the train/validation split is random.
np.random.shuffle(onlyfiles)
# NOTE(review): CovidDataset does not consult X_val_list. Every scan outside
# X_train_list ends up in validation, so here the first 40 shuffled names form
# the validation pool — confirm that this split is intended.
dataset = CovidDataset(
    X_train_list=onlyfiles[40:],
    X_val_list=onlyfiles[34:],
    transforms=None,
    without_covid_max=50,
)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

## 2. Определяемся со структурой нейронной сети

В данном случае была выбрана сеть Unet (https://pytorch.org/hub/mateuszbuda_brain-segmentation-pytorch_unet/)

Она отлично подходит для анализа медицинских изображений, так как в ней есть множество skip-connection для решения проблемы затухания градиента (vanishing gradient), а также увеличения информации для следующих слоёв, которые отвечают за выбор признаков (т. е. decoder)

In [None]:
import torch.nn as nn

class Unet(nn.Module):
    """U-Net style encoder/decoder with one input and one output channel.

    The encoder has five levels of two conv blocks (32, 64, 128, 256, 512
    channels); the first four are each followed by 2x2 max pooling. The
    decoder has four levels. Each one upsamples by 2, halves the channel
    count, concatenates the encoder output saved at the same resolution and
    applies two more conv blocks. A final conv block maps to one channel and
    a sigmoid is applied to the result.
    """

    # Attribute names of the pooling layers, one per encoder level 1..4.
    _POOL_NAMES = ("max_pooling11", "max_pooling22", "max_pooling33", "max_pooling44")

    def block_down(self, in_features, out_features):
        """Build a 3x3 convolution -> ReLU -> batch norm block."""
        conv = nn.Conv2d(in_features, out_features, (3, 3), padding=1)
        return nn.Sequential(conv, nn.ReLU(), nn.BatchNorm2d(out_features))

    def block_up(self, in_features, out_features):
        """Build a 3x3 convolution -> ReLU -> batch norm block."""
        conv = nn.Conv2d(in_features, out_features, (3, 3), padding=1)
        return nn.Sequential(conv, nn.ReLU(), nn.BatchNorm2d(out_features))

    def __init__(self):
        super().__init__()
        # Encoder levels 1..5 are registered as block_up<level>1 and
        # block_up<level>2. Names and creation order are kept stable so that
        # state_dict keys and weight initialisation order do not change.
        widths = (1, 32, 64, 128, 256, 512)
        for level in range(1, 6):
            narrow, wide = widths[level - 1], widths[level]
            setattr(self, f"block_up{level}1", self.block_down(narrow, wide))
            setattr(self, f"block_up{level}2", self.block_down(wide, wide))
            if level < 5:
                setattr(self, self._POOL_NAMES[level - 1],
                        nn.MaxPool2d((2, 2), stride=(2, 2)))

        # Decoder levels 6..9: upsample, reduce channels, merge skip, refine.
        for level, wide in zip(range(6, 10), (512, 256, 128, 64)):
            narrow = wide // 2
            setattr(self, f"block_up{level}1", nn.Upsample(scale_factor=2))
            setattr(self, f"block_up{level}2", self.block_up(wide, narrow))
            setattr(self, f"block_up{level}3", self.block_up(wide, narrow))
            setattr(self, f"block_up{level}4", self.block_up(narrow, narrow))

        self.block_up100 = self.block_up(32, 1)

    def forward(self, x):
        """Map a [batch, 1, H, W] tensor to per-pixel values in the 0..1 range."""
        out = x
        skips = []
        for level, pool_name in enumerate(self._POOL_NAMES, start=1):
            out = getattr(self, f"block_up{level}1")(out)
            out = getattr(self, f"block_up{level}2")(out)
            skips.append(out.clone())
            out = getattr(self, pool_name)(out)

        out = self.block_up52(self.block_up51(out))

        for level in range(6, 10):
            out = getattr(self, f"block_up{level}1")(out)
            out = getattr(self, f"block_up{level}2")(out)
            # The most recently saved encoder output has the matching size.
            merged = torch.cat((out, skips.pop()), 1)
            out = getattr(self, f"block_up{level}3")(merged)
            out = getattr(self, f"block_up{level}4")(out)

        return torch.sigmoid(self.block_up100(out))

Ниже определены три функции потерь: Dice loss, Tversky loss и focal loss (взвешенная кросс-энтропия). Все они позволяют учесть дисбаланс классов; в цикле обучения ниже используется Tversky loss.

In [None]:
import torch
import torch.nn.functional as F

def dice_loss(inputs: torch.Tensor, targets: torch.Tensor, smooth: float = 1e-7):
    """Return the soft Dice loss ``1 - (2*|A∩B| + s) / (|A| + |B| + s)``.

    Args:
        inputs: predicted probabilities, any shape.
        targets: ground-truth mask with the same number of elements.
        smooth: small constant ``s`` that keeps the ratio defined when both
            tensors are all zeros. It replaces the random value used before,
            so the loss is deterministic.

    Returns:
        A scalar tensor: 0 for a perfect overlap, close to 1 for no overlap.
    """
    inp = inputs.reshape(-1)
    tar = targets.reshape(-1)
    intersection = (inp * tar).sum()
    return 1 - (2 * intersection + smooth) / (inp.sum() + tar.sum() + smooth)

In [None]:
class TverskyLoss(nn.Module):
    """Tversky loss: ``1 - (TP + s) / (TP + alpha*FN + (1 - alpha)*FP + s)``.

    ``alpha`` weights false negatives and ``1 - alpha`` weights false
    positives, so ``alpha > 0.5`` penalises missed positive pixels more.
    """

    def __init__(self, alpha=0.7):
        super().__init__()
        self.alpha = alpha

    def forward(self, inputs, targets, smooth=1):
        """Return the scalar Tversky loss.

        Args:
            inputs: predicted probabilities, any shape.
            targets: ground-truth mask with the same number of elements.
            smooth: constant ``s`` added to numerator and denominator.
        """
        # reshape (not view) so non-contiguous tensors are accepted as well.
        y_true = targets.reshape(-1)
        y_pred = inputs.reshape(-1)
        true_pos = torch.sum(y_true * y_pred)
        false_neg = torch.sum(y_true * (1 - y_pred))
        false_pos = torch.sum((1 - y_true) * y_pred)
        denominator = true_pos + self.alpha * false_neg + (1 - self.alpha) * false_pos + smooth
        return 1 - (true_pos + smooth) / denominator

In [None]:
import torch
import torch.nn.functional as F



def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = 0.25,
    gamma: float = 2,
    reduction: str = "none"):
    """
    Focal loss on logits, as used in RetinaNet (https://arxiv.org/abs/1708.02002).
    Adapted from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py .

    Args:
        inputs: float tensor of raw scores (a sigmoid is applied inside).
        targets: float tensor of the same shape holding binary labels
                (0 for the negative class, 1 for the positive class).
        alpha: weight of the positive class; negatives get (1 - alpha).
                Any negative value disables this weighting. Default = 0.25
        gamma: exponent of the modulating factor (1 - p_t) that down-weights
               well-classified elements.
        reduction: 'mean' averages the result, 'sum' adds it up; any other
                 value returns the element-wise loss unchanged.
    Returns:
        The loss tensor after the requested reduction.
    """
    prob = torch.sigmoid(inputs)
    bce = F.binary_cross_entropy_with_logits(
        inputs, targets, reduction="none"
    )
    # Probability the model assigns to the true class of each element.
    prob_true = prob * targets + (1 - prob) * (1 - targets)
    focal = bce * ((1 - prob_true) ** gamma)

    if alpha >= 0:
        class_weight = alpha * targets + (1 - alpha) * (1 - targets)
        focal = class_weight * focal

    reducer = {"mean": torch.mean, "sum": torch.sum}.get(reduction)
    if reducer is None:
        return focal
    return reducer(focal)


### 3. Обучение модели

In [None]:
import segmentation_models_pytorch as smp

In [None]:
use_previous_versions = False
previous_i = 0  # offset added to the epoch index in checkpoint file names
path_to_model = "output/kaggle/working/"  # NOTE(review): not used by the code below

if use_previous_versions:
    # Keep full paths so checkpoints found in sub-directories can be opened.
    models_variation = []
    for put, papki, files in os.walk("."):
        for el in files:
            if "lungs_ct_model" in el:
                models_variation.append(os.path.join(put, el))

    if models_variation:
        # File names look like lungs_ct_model_<epoch>.h5
        def _epoch_of(checkpoint_path):
            return int(os.path.basename(checkpoint_path).split("_")[-1].split(".")[0])

        # Newest checkpoint first.
        models_variation = sorted(models_variation, key=_epoch_of, reverse=True)
        latest_checkpoint = models_variation[0]
        # NOTE(review): torch.load unpickles a whole model object, so load
        # only trusted files. Recent PyTorch releases may need
        # weights_only=False for this.
        model = torch.load(latest_checkpoint)
        previous_i = _epoch_of(latest_checkpoint)
        print(models_variation)
        print("Загружена прошлая модель Unet: {}".format(str(previous_i)))
        # Continue numbering after the loaded checkpoint, not over it.
        previous_i += 1
    else:
        model = Unet()
        print("Нет предобученных моделей")
else:
    model = smp.UnetPlusPlus(encoder_name='resnet18', in_channels=1, classes=1, activation="sigmoid")
    print("Загружен непредобученный Unet")

# Keep the CPU fallback instead of unconditionally requiring a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = model.to(device)

In [None]:
num_epoch = 175
lr = 0.0005

# Candidate loss functions; the training loop uses the Tversky loss.
dice_loss_criterion = dice_loss
focal_loss_criterion = sigmoid_focal_loss
tverskoy_loss = TverskyLoss(alpha=0.7)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Halve the learning rate when the monitored loss has not improved by at
# least 3.5 % (relative) for 5 consecutive epochs. The deprecated `verbose`
# argument is no longer passed; its default was already silent.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5, threshold=0.035,
    threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08)

1. Unet -> Unet++
2. Add scheduler
3. Add validation_check
4. Loss (?)
5. Оформить

In [None]:
class ValidDataset(Dataset):
    """Dataset over pre-loaded validation slices and their masks.

    NOTE(review): samples are randomly augmented on access, exactly like the
    training data, so validation scores vary between runs — confirm that this
    is intended.
    """

    # Rotation angles (in degrees) the augmentation chooses from.
    _DEGREES = tuple(range(-30, 35, 5))

    def __init__(self, X, Y):
        """Store the image slices ``X`` and the matching masks ``Y``."""
        self.X = X
        self.Y = Y

    def __len__(self):
        """Return the number of stored slices."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return an augmented ``(image, mask)`` pair, each of shape [1, H, W].

        The leading channel axis is required by the 2-D convolutions of the
        model; the image is scaled to the 0..1 range.
        """
        X = self.X[idx].type(torch.float).unsqueeze(0) / 255
        y = self.Y[idx].type(torch.float).unsqueeze(0)
        return self._augment(X, y)

    def _augment(self, X, y):
        """Apply the same random geometric transforms to image and mask."""
        tf = torchvision.transforms
        if random.random() > 0.5:
            flip = tf.functional.vflip if random.random() > 0.5 else tf.functional.hflip
            X, y = flip(X), flip(y)
        if random.random() >= 0.1:
            degree = random.choice(self._DEGREES)
            X = tf.functional.rotate(X, degree)
            y = tf.functional.rotate(y, degree)
        # 0.25 = the former outer 0.5 gate times the transform's own p=0.5.
        if random.random() < 0.25:
            # Sample the warp once and reuse it for both tensors. Two separate
            # RandomPerspective instances would draw different warps and
            # misalign the mask with the image.
            height, width = X.shape[-2:]
            startpoints, endpoints = tf.RandomPerspective.get_params(width, height, 0.15)
            X = tf.functional.perspective(
                X, startpoints, endpoints,
                interpolation=tf.InterpolationMode.BILINEAR, fill=[0.0])
            # Nearest-neighbour sampling keeps the mask values binary.
            y = tf.functional.perspective(
                y, startpoints, endpoints,
                interpolation=tf.InterpolationMode.NEAREST, fill=[0.0])
        if random.random() > 0.5:
            # Blur only the image; the mask is left untouched.
            X = tf.GaussianBlur(1)(X)
        return X, y

In [None]:
def validation_score(model):
    """Return the mean Tversky loss of ``model`` over the validation slices.

    The model is switched to evaluation mode and autograd is disabled while
    scoring, so no graphs are built and batch-norm statistics are not updated
    on validation data. The previous training/eval mode is restored afterwards.

    Raises:
        ZeroDivisionError: if the validation set is empty.
    """
    valid_dataset = ValidDataset(*dataset.get_validation_set())
    valid_loader = DataLoader(valid_dataset, batch_size=4, shuffle=False)

    val_tverskoy_losses = []
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for X, Y in valid_loader:
                output = model(X.to(device))
                val_tverskoy_losses.append(tverskoy_loss(output, Y.to(device)).item())
    finally:
        model.train(was_training)

    return sum(val_tverskoy_losses) / len(val_tverskoy_losses)

In [None]:
losses = []  # mean training loss of every finished epoch

for epoch in range(num_epoch):
    # Enter training mode explicitly in case an earlier cell left the model
    # in evaluation mode.
    model.train()
    epoch_losses = []

    for X, Y in tqdm(loader):
        X = X.to(device)
        Y = Y.to(device)

        optimizer.zero_grad()
        output = model(X)
        loss = tverskoy_loss(output, Y)
        loss.backward()
        # The limit is so large that this only guards against extreme gradients.
        clip_grad_norm_(model.parameters(), 99999)
        optimizer.step()

        epoch_losses.append(loss.item())

        # Drop the batch and release cached GPU memory before the next step.
        del X
        del Y
        torch.cuda.empty_cache()

    # The scheduler monitors the mean training loss of the epoch.
    common_loss = sum(epoch_losses) / len(epoch_losses)
    scheduler.step(common_loss)

    print("Tversky Loss: {}".format(str(common_loss)))
    losses.append(common_loss)

    # Save the whole model after every epoch.
    torch.save(model, "lungs_ct_model_" + str(epoch + previous_i) + ".h5")
    print("Saved model {}".format(epoch + previous_i))

### 4. Визуализация

In [None]:
# Wrap the validation slices collected by `dataset` for the visualisation below.
valid_dataset = ValidDataset(*dataset.get_validation_set())
valid_loader = DataLoader(valid_dataset, batch_size=4, shuffle=False)

In [None]:
#Visualize some of the slices
from PIL import Image
import matplotlib.pyplot as plt

def blend(image, mask):
    """Return a PIL image of a grayscale slice with its mask overlaid.

    The slice is min-max scaled to the 0..255 range and replicated into three
    channels. The mask is multiplied by 255 and written to the third channel
    of an otherwise empty image. The two images are blended with alpha 0.3.
    The leftover debug print of the whole input array has been removed.
    """
    image = image.astype(np.float32)
    min_in = image.min()
    max_in = image.max()
    # The small constant keeps the division defined for a constant slice.
    image = (image - min_in) / (max_in - min_in + 1e-8) * 255
    image = np.dstack((image, image, image)).astype(np.uint8)
    zeros = np.zeros_like(mask)
    mask = np.dstack((zeros, zeros, mask * 255)).astype(np.uint8)
    return Image.blend(
        Image.fromarray(image),
        Image.fromarray(mask),
        alpha=.3
    )

slices_num = (7, )
slices = []

# Predict in evaluation mode without autograd, so batch-norm layers use their
# running statistics instead of single-image batch statistics.
was_training = model.training
model.eval()
for idx in slices_num:
    k = valid_dataset[idx]
    # Ground truth: the slice with its labelled mask overlaid.
    slices.append(blend(
        k[0][0].numpy(),
        k[1][0].numpy()
    ))
    with torch.no_grad():
        prediction = model(k[0].view(1, 1, 512, 512).to(device)).cpu()
    # (1, 1, H, W) -> (H, W, 1), then binarise at 0.85.
    prediction = prediction[0].permute(1, 2, 0)
    prediction = (prediction >= 0.85).float()
    # Repeat the single channel three times so imshow draws a gray image.
    slices.append(
        torch.cat([prediction, prediction, prediction], 2)
    )
model.train(was_training)

figure = plt.figure(figsize=(18, 18))
for i, picture in enumerate(slices, start=1):
    ax = figure.add_subplot(1, len(slices), i)
    ax.imshow(picture)

### 5. Inference

In [None]:
"""
Load testing data into images and labels lists

images list consists of CT scans -  numpy arrays of shape (512, 512, n_slices)
"""
# Read the list of test scans; each entry gives the scan file name under 'image'.
with open(core_path + 'testing_data.json', 'r') as f:
    dict_testing = json.load(f)

images_testing = []  # one array per test scan
label_testing = []  # predictions; filled by the inference cell below
for entry in tqdm(dict_testing):
    # The last three characters of the listed name are dropped before loading
    # (presumably a ".gz" suffix — confirm against the JSON file).
    image = nib.load(os.path.join(path + "images/", entry['image'][:-3]))
    images_testing.append(image.get_fdata())

In [None]:
thresh = 0.99
n_id = len(images_testing)

# Start from an empty list so re-running this cell does not append duplicates.
label_testing = []

was_training = model.training
model.eval()
try:
    with torch.no_grad():
        for i in range(n_id):
            # Same preprocessing as in training: uint8 cast, then scaling to 0..1.
            volume = torch.tensor(images_testing[i], dtype=torch.uint8)
            n_imgs = volume.shape[-1]

            patient_slices = []
            for j in range(n_imgs):
                inp = volume[:, :, j].type(torch.float) / 255
                inp = inp.to(device).reshape(-1, 1, 512, 512)

                res = model(inp).cpu()
                outp = res.numpy()[0][0]

                # NOTE(review): dividing by the slice maximum means every slice
                # gets at least one positive pixel — confirm this is intended.
                outp /= np.max(outp)
                outp = (outp >= thresh).astype(np.float32)

                patient_slices.append(outp)

            # One (512, 512, n_slices) mask per scan. The submission cell pairs
            # exactly one `label_testing` item with each test scan.
            label_testing.append(np.stack(patient_slices, axis=-1))
finally:
    model.train(was_training)

In [None]:
"""
Write your code here

You need to:
 1. Predict labels for CT scans from images list
 2. Store them in the labels_predicted list in form of numpy arrays of shape (512, 512, n_slices), where:
    0 - background class
    1 - regions of consolidation class
"""
# model = model.to("cpu")
# labels_predicted = []
# for ct in images_predicted:
#     for ct_slice in ct.swapaxes(2, 1).swapaxes(1, 0):
#         label = model.forward(torch.tensor(np.array([[ct_slice]]), dtype=torch.float32))
#         try: 
#             ct_label = torch.cat((ct_label, label), 0)
#         except:
#             ct_label = torch.tensor(label)

In [None]:
# Visualize some of the predictions

# patient_num = 5
# slices_num = (10, 20, 30)
# slices = []
# for idx in slices_num:
#     slices.append(blend(
#         images_testing[patient_num][..., idx],
#         label_testing[patient_num][..., idx]
#     ))

# figure = plt.figure(figsize=(18, 18))
# for i, image in enumerate(slices):
#     ax = figure.add_subplot(1, len(slices), i + 1)
#     ax.imshow(slices[i])

In [None]:
def rle_encoding(x):
    """Run-length encode the positions where ``x`` is >= 1.

    The array is transposed and flattened, and positions are numbered from 1.
    The result is a flat list of strings alternating the start of a run and
    its length, e.g. ``['2', '3']`` for one run of three starting at position 2.
    An array without such positions yields an empty list.
    """
    positions = np.flatnonzero(x.T.flatten() >= 1)
    encoded = []
    last = -2  # guarantees that the first position opens a new run
    for pos in positions:
        if pos - last > 1:
            # Gap before this position: open a new run of length one.
            encoded.append(pos + 1)
            encoded.append(1)
        else:
            encoded[-1] += 1
        last = pos
    return list(map(str, encoded))

In [None]:
import csv
# Reload the list of test scans so this cell can be run on its own.
with open(f'{core_path}testing_data.json', 'r') as f:
    dict_testing = json.load(f)

# output = []
# for entry in tqdm(dict_testing):
#     image = nib.load(os.path.join(f'{core_path}data/data/images', entry['image'][:-3]))
#     image = image.get_fdata().swapaxes(0, 2).swapaxes(1, 2).reshape(-1, 1, 512, 512)
#     tmp = None
#     for i in range(image.shape[0]):  
#         img = image[i:i+1]
#         img = torch.tensor(img).to(device).float()
#         img = model(img).cpu().detach().numpy()[0]
#         tmp = img if tmp is None else np.vstack([tmp, img])

#     tmp = tmp.swapaxes(0, 1).swapaxes(1, 2)
#     output.append(tmp)
# print(output[0][0])

# Write one row per test scan: the scan id and its run-length encoded mask.
# newline="" lets the csv module control the line endings itself, which
# avoids blank rows on platforms that translate "\n".
with open('submission.csv', "wt", newline="") as sb:
    submission_writer = csv.writer(sb, delimiter=',')
    submission_writer.writerow(["Id", "Predicted"])
    for k_i, patient_i in tqdm(zip(dict_testing, label_testing), total=len(dict_testing)):
        submission_writer.writerow([
            # The id is the file name without its last seven characters
            # (presumably ".nii.gz" — confirm against the JSON file).
            f"{k_i['image'][:-7]}",
            " ".join(rle_encoding(patient_i)),
        ])