# Single Vertebrae -- Inference Pipeline for trained models

## Utils

In [None]:
# Print the full-resolution size reminder, then build the pyramid of working
# resolutions (base size divided by 1, 2, 4 and 8) and select the largest one.
print("optimal image size: 4736 1920")
h = w = 256
sizes = [(h // factor, w // factor) for factor in (1, 2, 4, 8)]
print(*sizes)
size = sizes[0]

optimal image size: 4736 1920
(256, 256) (128, 128) (64, 64) (32, 32)


In [None]:
import numpy as np
from numpy.linalg import matrix_power
# 2x2 rotation matrix for a 60-degree turn (cos 60 = 1/2, sin 60 = sqrt(3)/2).
i = np.array([[1/2, -np.sqrt(3)/2], [np.sqrt(3)/2, 1/2]])
# 2022 = 6 * 337, i.e. a whole number of full turns, so the power equals the
# identity matrix up to floating-point rounding.
matrix_power(i, 2022)

array([[ 1.00000000e+00,  6.60027588e-14],
       [-6.59887609e-14,  1.00000000e+00]])

In [None]:
import os
import pydicom
from PIL import Image
from torch.utils.data import Dataset
import numpy as np

class DatasetNew(Dataset):
    """In-memory dataset over a sequence of pre-computed samples.

    Each element of `data` must be a mapping that contains at least the keys
    'image' and 'mask'; any other keys are dropped when a sample is fetched.
    """

    _KEYS = ("image", "mask")

    def __init__(self, data):
        self.data = data

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return a new dict holding only the 'image' and 'mask' of sample `index`."""
        sample = self.data[index]
        return {key: sample[key] for key in self._KEYS}

In [None]:
# Show entries 46..68 of the test-case listing as the cell result.
test_cases[46:69]

['060_SD_C3.png',
 '060_SD_C4.png',
 '060_SD_C5.png',
 '060_SD_C6.png',
 '060_SD_C7.png',
 '060_SD_L1.png',
 '060_SD_L2.png',
 '060_SD_L3.png',
 '060_SD_L4.png',
 '060_SD_L5.png',
 '060_SD_S1.png',
 '060_SD_Th1.png',
 '060_SD_Th10.png',
 '060_SD_Th11.png',
 '060_SD_Th12.png',
 '060_SD_Th2.png',
 '060_SD_Th3.png',
 '060_SD_Th4.png',
 '060_SD_Th5.png',
 '060_SD_Th6.png',
 '060_SD_Th7.png',
 '060_SD_Th8.png',
 '060_SD_Th9.png']

In [None]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from functionality import *

path_to_dataset = "C:\\Users\\gieko\\Dropbox\\NIITO_Vertebrae\\NIITO_Vertebrae_Dataset\\NIITO_Vertebrae_Dataset_Final_Test_resized\\data_single_vertebra"
#  "C:\Users\EUgolnikova\Dropbox\NIITO_Vertebrae\NIITO_Vertebrae_Dataset\NIITO_Vertebrae_Dataset_Test"

# Images and their masks live in sibling folders and share file names.
path_to_images = os.path.join(path_to_dataset, "images")
path_to_labels = os.path.join(path_to_dataset, "labels")

# os.listdir() gives no ordering guarantee; sorting keeps the case order (and
# therefore the index of every saved prediction) reproducible between runs.
test_cases = sorted(os.listdir(path_to_images))

# Resize image and mask together to the working resolution and convert both
# to torch tensors.
test_transforms = A.ReplayCompose(
    [
        A.Resize(height=size[0], width=size[1]),
        ToTensorV2()
    ],
    additional_targets={'image': 'image', 'mask': 'mask'})

# NOTE(review): cv2 is not imported in this cell -- presumably it is
# re-exported by `from functionality import *`; confirm.
test_aug = []
for case in test_cases:
    print(case)
    path_mask = os.path.join(path_to_labels, case)
    path_image = os.path.join(path_to_images, case)

    image = cv2.imread(path_image, 1)  # flag 1: load as 3-channel colour
    mask = cv2.imread(path_mask, 0)    # flag 0: load as single-channel grayscale

    # cv2.imread reports an unreadable file by returning None, not by raising.
    if image is None:
        raise FileNotFoundError(f"could not read image: {path_image}")
    if mask is None:
        raise FileNotFoundError(f"could not read mask: {path_mask}")

    # Map the foreground value 255 to class label 1.
    mask[mask == 255] = 1

    augmentations = test_transforms(image=image, mask=mask)

    test_aug.append({
        "image": augmentations['image'],
        "mask": augmentations['mask']
    })


test_dataset = DatasetNew(test_aug)
# shuffle=False: evaluation does not need shuffling, and a fixed order means
# batch index k always corresponds to test_cases[k].
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, pin_memory=True, shuffle=False)

print(len(test_dataset))

053_SD_C3.png
053_SD_C4.png
053_SD_C5.png
053_SD_C6.png
053_SD_C7.png
053_SD_L1.png
053_SD_L2.png
053_SD_L3.png
053_SD_L4.png
053_SD_L5.png
053_SD_S1.png
053_SD_Th1.png
053_SD_Th10.png
053_SD_Th11.png
053_SD_Th12.png
053_SD_Th2.png
053_SD_Th3.png
053_SD_Th4.png
053_SD_Th5.png
053_SD_Th6.png
053_SD_Th7.png
053_SD_Th8.png
053_SD_Th9.png
057_SD_C3.png
057_SD_C4.png
057_SD_C5.png
057_SD_C6.png
057_SD_C7.png
057_SD_L1.png
057_SD_L2.png
057_SD_L3.png
057_SD_L4.png
057_SD_L5.png
057_SD_S1.png
057_SD_Th1.png
057_SD_Th10.png
057_SD_Th11.png
057_SD_Th12.png
057_SD_Th2.png
057_SD_Th3.png
057_SD_Th4.png
057_SD_Th5.png
057_SD_Th6.png
057_SD_Th7.png
057_SD_Th8.png
057_SD_Th9.png
060_SD_C3.png
060_SD_C4.png
060_SD_C5.png
060_SD_C6.png
060_SD_C7.png
060_SD_L1.png
060_SD_L2.png
060_SD_L3.png
060_SD_L4.png
060_SD_L5.png
060_SD_S1.png
060_SD_Th1.png
060_SD_Th10.png
060_SD_Th11.png
060_SD_Th12.png
060_SD_Th2.png
060_SD_Th3.png
060_SD_Th4.png
060_SD_Th5.png
060_SD_Th6.png
060_SD_Th7.png
060_SD_Th8.png
060_

In [None]:
def get_mean_std(loader):
    """Compute the per-channel mean and standard deviation of a dataset.

    Pixel values are divided by 255 before the statistics are taken.

    Args:
        loader: iterable of batches; each batch is a mapping whose 'image'
            entry is a 4-D tensor reduced over dims 0, 2 and 3 (so dim 1 is
            the channel axis).

    Returns:
        Tuple ``(mean, std)`` of 1-D tensors with one entry per channel.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    ch_sum, ch_squared_sum, n_samples = 0, 0, 0

    for batch in loader:
        # Out-of-place division: `.float()` returns the very same tensor when
        # the input is already float32, so `/=` would modify the caller's data.
        images = batch['image'].float() / 255
        batch_size = images.shape[0]

        # Weight every batch by its sample count so that a smaller final
        # batch does not skew the result.
        ch_sum = ch_sum + torch.mean(images, dim=[0, 2, 3]) * batch_size
        ch_squared_sum = ch_squared_sum + torch.mean(images ** 2, dim=[0, 2, 3]) * batch_size
        n_samples += batch_size

    if n_samples == 0:
        raise ValueError("get_mean_std() got a loader with no samples")

    mean = ch_sum / n_samples
    # Var[x] = E[x^2] - E[x]^2; rounding can push it marginally below zero,
    # which would turn the square root into NaN.
    variance = torch.clamp(ch_squared_sum / n_samples - mean ** 2, min=0)
    std = variance ** 0.5

    return mean, std


def soft_dice(*, y_true, y_pred):
    """Return the Dice coefficient between two tensors, clamped to [0, 1].

    Both tensors are flattened and treated as one long vector, so the score
    is computed over all elements at once. A tiny constant is added to the
    numerator and the denominator, which makes two all-zero inputs score 1.

    Args:
        y_true: ground-truth tensor.
        y_pred: prediction tensor with the same number of elements.

    Returns:
        0-d tensor holding the Dice score.
    """
    smooth = 1e-15
    flat_pred = y_pred.contiguous().view(y_pred.numel())
    flat_true = y_true.contiguous().view(y_true.numel())

    overlap = torch.sum(flat_pred * flat_true, 0)
    total = flat_pred.sum(0) + flat_true.sum(0) + smooth
    per_item = 2. * (overlap + smooth) / total

    dice = per_item.sum() / per_item.numel()
    return torch.clamp(dice, 0., 1.)


def hard_dice(*, y_true, y_pred, thr=0.5):
    """Return the Dice score of `y_pred` binarised at `thr` against `y_true`.

    Elements of `y_pred` strictly greater than `thr` become 1.0, all others
    0.0; the score itself is delegated to `soft_dice`.
    """
    binarised = torch.gt(y_pred, thr).float()
    return soft_dice(y_true=y_true, y_pred=binarised)


def accuracy(y_true, y_pred, thr=0.5):
    """Return the percentage of elements where the thresholded prediction matches `y_true`.

    Args:
        y_true: ground-truth tensor.
        y_pred: prediction tensor; values strictly above `thr` count as 1.0,
            the rest as 0.0.
        thr: binarisation threshold.

    Returns:
        0-d tensor with the accuracy in percent (0..100).
    """
    binary_pred = (y_pred > thr).float()
    matches = (y_true == binary_pred).sum()
    total = torch.numel(binary_pred)

    return matches / total * 100






In [None]:
import torch
import torchvision
import skimage
import numpy as np
import time


def read_mask(mask_name):
    """Load an image file and return a uint8 array that is 1 where a pixel equals 255 and 0 elsewhere."""
    pixels = skimage.io.imread(mask_name)[:, :]
    # Only the exact value 255 counts as foreground.
    return np.equal(pixels, 255).astype(np.uint8)



def make_blending(img_path, mask_path, alpha=0.5):
    """Blend an image with its mask drawn in red.

    Args:
        img_path: path of the source image.
        mask_path: path of the mask file; only its first channel is used.
        alpha: weight of the source image; the colour overlay gets ``1 - alpha``.

    Returns:
        uint8 array ``img * alpha + overlay * (1 - alpha)``.
    """
    # Load the image unchanged. Reading it through read_mask (as was done
    # before) collapses it to 0/1 and wipes the picture out of the blend.
    # NOTE(review): assumes the image broadcasts against an H x W x 3 overlay -- confirm.
    img = skimage.io.imread(img_path)
    mask = read_mask(mask_path)[:, :, 0]
    # Lookup table: label 0 -> black, label 1 -> red.
    colors = np.array([[0, 0, 0], [255, 0, 0]], np.uint8)
    overlay = colors[mask.astype(np.int32)]
    return (img * alpha + overlay * (1. - alpha)).astype(np.uint8)


def show_images_with_mask(image_path,  mask_path_fill, alpha=0.5):
    """Plot the source image next to its blend with the mask.

    Args:
        image_path: path of the source image.
        mask_path_fill: path of the mask file.
        alpha: image weight forwarded to `make_blending`.
    """
    # NOTE(review): `plt` is not imported in this cell; it relies on a
    # matplotlib.pyplot import executed elsewhere in the notebook.
    plt.figure(figsize=(20, 14))

    # Left panel: the image as stored on disk. It used to go through
    # read_mask, which displayed a 0/1 binarisation instead of the picture.
    plt.subplot(1, 2, 1)
    plt.imshow(skimage.io.imread(image_path))

    # Right panel: image blended with the mask overlay.
    plt.subplot(1, 2, 2)
    plt.imshow(make_blending(image_path, mask_path_fill, alpha))


def save_predictions_as_imgs(loader, model, thr=0.5, folder="/content/saved_images", device='cpu'):
    """Run `model` over `loader`, save the results as PNG files and collect metrics.

    For batch number ``idx`` three files are written into `folder`:
    ``orig_{idx}.png`` (input divided by 255), ``pred_{idx}.png``
    (thresholded prediction) and ``gt_{idx}.png`` (ground-truth mask).

    Args:
        loader: iterable of batches, each a mapping with 'image' and 'mask' tensors.
        model: network whose raw output is passed through a sigmoid here.
        thr: probability threshold used to binarise the prediction.
        folder: output directory; created when missing.
        device: device the model and the batches are moved to.

    Returns:
        Dict with per-batch lists "accuracy", "soft DICE", "DICE" and "time"
        (forward-pass seconds), "means" (the averages of these four lists, in
        that order: accuracy, soft DICE, DICE, time) and "dice per case" (hard
        DICE over all batches concatenated).
    """
    os.makedirs(folder, exist_ok=True)

    model.to(device=device).eval()
    acc, s_dice, h_dice, times = [], [], [], []
    y_trues, y_preds = [], []

    for idx, data in enumerate(loader):
        x = data['image'].float().to(device=device)
        y = data['mask'].to(device=device)

        with torch.no_grad():
            start_time = time.perf_counter()
            probs = torch.sigmoid(model(x))
            preds = (probs > thr).float()
            times.append(time.perf_counter() - start_time)

        y_trues.append(y)
        y_preds.append(preds)
        acc.append(accuracy(y, preds))
        h_dice.append(hard_dice(y_true=y, y_pred=preds))
        # Soft DICE is taken on the probabilities; on the thresholded mask it
        # would always be identical to the hard DICE.
        s_dice.append(soft_dice(y_true=y, y_pred=probs))

        # NOTE(review): dividing by 255 assumes the inputs are in 0..255 -- confirm.
        torchvision.utils.save_image(x / 255, f"{folder}/orig_{idx}.png")
        torchvision.utils.save_image(preds, f"{folder}/pred_{idx}.png")
        torchvision.utils.save_image(y.float(), f"{folder}/gt_{idx}.png")

    dice_per_case = hard_dice(y_true=torch.cat(y_trues), y_pred=torch.cat(y_preds))
    # Convert to Python floats first so that averaging also works when the
    # metric tensors live on a non-CPU device.
    means = [np.mean([float(v) for v in values]) for values in (acc, s_dice, h_dice, times)]
    return {"accuracy": acc, "soft DICE": s_dice, "DICE": h_dice, "time": times, "means": means, "dice per case": dice_per_case}
  


In [None]:
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import matplotlib.pyplot as plt

def save_blended(path_to_folder, number_of_images, acc = None, s_dice = None, h_dice = None, alpha = 0.75, beta = 0.95):
    """Blend saved inputs with predicted (red) and ground-truth (blue) masks.

    For every ``k`` in ``range(number_of_images)`` the files ``orig_k.png``,
    ``pred_k.png`` and ``gt_k.png`` are read from `path_to_folder` and the
    result is written next to them as ``blend_k.png``.

    Args:
        path_to_folder: directory holding the input files; also the output directory.
        number_of_images: how many consecutive indices, starting at 0, to process.
        acc: optional per-image accuracies; ``acc[k]`` is drawn on image k.
        s_dice: accepted for interface compatibility; not used.
        h_dice: optional per-image DICE scores; ``h_dice[k]`` is drawn on image k.
        alpha: weight of the input image against the prediction overlay.
        beta: weight of that first blend against the ground-truth overlay.
    """
    red = np.array([1, 0, 0], dtype=np.uint8)
    blue = np.array([0, 0, 1], dtype=np.uint8)

    # Load the caption font once instead of once per image; fall back to
    # Pillow's built-in font where Arial is not installed.
    try:
        font = ImageFont.truetype('arial', size=16)
    except OSError:
        font = ImageFont.load_default()

    for k in range(number_of_images):
        path_img = os.path.join(path_to_folder, "orig_" + str(k) + ".png")
        path_mask = os.path.join(path_to_folder, "pred_" + str(k) + ".png")
        path_mask_gt = os.path.join(path_to_folder, "gt_" + str(k) + ".png")
        path_save = os.path.join(path_to_folder, "blend_" + str(k) + ".png")

        img = plt.imread(path_img)
        msk = plt.imread(path_mask)
        msk_gt = plt.imread(path_mask_gt)

        # Recolour every pixel whose channels are all non-zero: prediction in
        # red, ground truth in blue (vectorised form of the per-pixel test).
        msk[msk.all(axis=-1)] = red
        msk_gt[msk_gt.all(axis=-1)] = blue

        res = img * alpha + msk * (1 - alpha)
        res = res * beta + msk_gt * (1 - beta)

        # The arrays are multiplied by 255 on the assumption that plt.imread
        # returned floats in 0..1.
        res = (res * 255).astype(np.uint8)
        im = Image.fromarray(res)

        str_ = ""
        str_ += f"\n accuracy: {acc[k]}" if acc is not None else ""
        str_ += f"\n DICE: {h_dice[k]}" if h_dice is not None else ""

        ImageDraw.Draw(im).text((0, 0), str_, fill=(255, 185, 93), font=font)

        im.save(path_save)


## Model

In [None]:
import torch
import torch.nn as nn


class ConvLRelu(nn.Module):
    """3x3 convolution (padding 1, no bias) -> batch norm -> LeakyReLU(0.1)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            bias=False,
        )
        self.batchNorm = nn.BatchNorm2d(out_channels)
        self.activation = nn.LeakyReLU(negative_slope=0.1)

    def forward(self, x):
        """Apply the three layers in sequence; height and width are preserved."""
        return self.activation(self.batchNorm(self.conv(x)))

    
class DoubleConvBlock(nn.Module):
    """Two consecutive ConvLRelu layers: in_channels -> out_channels -> out_channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [
            ConvLRelu(in_channels, out_channels),
            ConvLRelu(out_channels, out_channels),
        ]
        self.conv_block = nn.Sequential(*stages)

    def forward(self, x):
        """Run `x` through both convolution stages."""
        return self.conv_block(x)

class EncoderBlock(nn.Module):
    """Double convolution followed by a 2x2 max-pool with stride 2."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_block = DoubleConvBlock(in_channels, out_channels)
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return ``(pooled, skip)``: the pooled features and the features before pooling."""
        skip = self.conv_block(x)
        pooled = self.max_pool(skip)
        return pooled, skip
    
class DecoderBlock(nn.Module):
    """Upsample by 2, concatenate with a skip tensor, then apply a double convolution."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_block = DoubleConvBlock(in_channels, out_channels)

    def forward(self, x, y):
        """Upsample `x` bilinearly (x2), join it with `y` along dim 1 and convolve."""
        upsampled = nn.functional.interpolate(
            x, scale_factor=2, mode='bilinear', align_corners=True
        )
        merged = torch.cat([upsampled, y], dim=1)
        return self.conv_block(merged)

class UNet(nn.Module):
    """U-Net with four encoder stages, a centre block and four decoder stages.

    Channel width starts at `n_filters` and doubles at every encoder stage up
    to ``16 * n_filters`` in the centre. Each decoder stage receives the
    upsampled features concatenated with the matching encoder output. A final
    1x1 convolution maps to `out_channels`; no activation is applied to it.
    """

    def __init__(self, in_channels=3, out_channels=1, n_filters=64):
        super().__init__()
        f = n_filters
        self.pool = nn.MaxPool2d(2, 2)

        # Contracting path.
        self.enc1 = EncoderBlock(in_channels, f)
        self.enc2 = EncoderBlock(f, f * 2)
        self.enc3 = EncoderBlock(f * 2, f * 4)
        self.enc4 = EncoderBlock(f * 4, f * 8)

        self.center = DoubleConvBlock(f * 8, f * 16)

        # Expanding path: input width = upsampled channels + skip channels.
        self.dec4 = DecoderBlock(f * (16 + 8), f * 8)
        self.dec3 = DecoderBlock(f * (8 + 4), f * 4)
        self.dec2 = DecoderBlock(f * (4 + 2), f * 2)
        self.dec1 = DecoderBlock(f * (2 + 1), f)

        self.final = nn.Conv2d(f, out_channels, kernel_size=1)

    def forward(self, x):
        """Cast `x` to float and return the raw output of the final 1x1 convolution."""
        features = x.float()

        # Encode, remembering the pre-pool output of every stage.
        skips = []
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4):
            features, before_pool = encoder(features)
            skips.append(before_pool)

        features = self.center(features)

        # Decode, consuming the skips from the deepest to the shallowest.
        for decoder in (self.dec4, self.dec3, self.dec2, self.dec1):
            features = decoder(features, skips.pop())

        return self.final(features)

# Inference 

## Jolly Wind 17



In [None]:
# Run configuration for this checkpoint; only 'n_filters' is read in this cell.
# NOTE(review): the remaining keys presumably mirror the training setup -- confirm.
hyperparametrs = {
    'n_filters': 32,
    'loss_weight': 0.8,
    'lr': 1e-3,
    'epochs': 50,
    'lr_reduce_rate': 0.5,
    'patience': 4,
    'early_stopping': 50, # early stopping turned off for now
    'model': 'test'
}
# Build the network and load the stored weights onto the CPU.
model = UNet(n_filters=hyperparametrs['n_filters'])
model.load_state_dict(torch.load('C:/Users/gieko/Dropbox/NIITO_Vertebrae/Scripts/weight/UNet_single-vertebrae/jolly-wind-17/weights.pth', map_location=torch.device('cpu')))

# Output directory for the orig_/pred_/gt_ images of this run.
folder = 'C:/Users/gieko/Dropbox/NIITO_Vertebrae/Scripts/inference/UNet_single-vertebrae/jolly-wind-17/'
print("model loaded")
# Predict over the whole test loader, save the images and keep the metrics.
metrics = save_predictions_as_imgs(test_loader, model, folder = folder)
print(metrics["dice per case"])

model loaded
tensor(0.9491)


In [None]:

# Overlay prediction and ground truth on every saved input and caption it with accuracy / DICE.
# number_of_images is hard-coded: indices 0..114 must exist both as files in `folder` and as entries in the metric lists.
save_blended(path_to_folder=folder, number_of_images=115, acc = metrics['accuracy'], h_dice = metrics['DICE'], alpha = 0.75, beta = 0.8)


## Upbeat Tree 20

In [None]:
# Run configuration for this checkpoint; only 'n_filters' is read in this cell.
# NOTE(review): the remaining keys presumably mirror the training setup -- confirm.
hyperparametrs = {
    'n_filters': 32,
    'loss_weight': 0.8,
    'lr': 1e-3,
    'epochs': 50,
    'lr_reduce_rate': 0.5,
    'patience': 4,
    'early_stopping': 50, # early stopping turned off for now
    'model': 'test'
}
# Build the network and load the stored weights onto the CPU.
model = UNet(n_filters=hyperparametrs['n_filters'])
model.load_state_dict(torch.load('C:/Users/gieko/Dropbox/NIITO_Vertebrae/Scripts/weight/UNet_single-vertebrae/upbeat-tree-20/weights.pth', map_location=torch.device('cpu')))

# Output directory for the orig_/pred_/gt_ images of this run.
folder = 'C:/Users/gieko/Dropbox/NIITO_Vertebrae/Scripts/inference/UNet_single-vertebrae/upbeat-tree-20/'
print("model loaded")
# Predict over the whole test loader, save the images and keep the metrics.
metrics = save_predictions_as_imgs(test_loader, model, folder = folder)
print(metrics["dice per case"])

model loaded
tensor(0.9429)


In [None]:

# Overlay prediction and ground truth on every saved input and caption it with accuracy / DICE.
# number_of_images is hard-coded: indices 0..114 must exist both as files in `folder` and as entries in the metric lists.
save_blended(path_to_folder=folder, number_of_images=115, acc = metrics['accuracy'], h_dice = metrics['DICE'], alpha = 0.75, beta = 0.8)


In [None]:
# Display the metrics dict from the most recent save_predictions_as_imgs call.
metrics

{'accuracy': [tensor(97.1207),
  tensor(95.7474),
  tensor(97.5525),
  tensor(93.8568),
  tensor(94.7937),
  tensor(96.9269),
  tensor(96.8979),
  tensor(95.3522),
  tensor(97.2672),
  tensor(95.9946),
  tensor(90.9210),
  tensor(95.7245),
  tensor(91.6199),
  tensor(91.7175),
  tensor(91.5054),
  tensor(96.2555),
  tensor(93.2007),
  tensor(85.8215),
  tensor(96.4813),
  tensor(97.0474),
  tensor(97.3038),
  tensor(95.1508),
  tensor(98.2178),
  tensor(96.6141),
  tensor(94.3298),
  tensor(95.4163),
  tensor(90.4602),
  tensor(93.5593),
  tensor(96.9788),
  tensor(94.7845),
  tensor(93.8110),
  tensor(87.9272),
  tensor(94.7113),
  tensor(94.3893),
  tensor(96.0220),
  tensor(96.4188),
  tensor(97.2168),
  tensor(95.4224),
  tensor(97.6593),
  tensor(91.0889),
  tensor(95.6207),
  tensor(97.3663),
  tensor(96.1639),
  tensor(92.2684),
  tensor(97.9980),
  tensor(96.8933),
  tensor(96.5790),
  tensor(95.1645),
  tensor(94.1605),
  tensor(96.3837),
  tensor(97.6120),
  tensor(91.5909),


In [None]:
# NOTE(review): `folder` points at a spinal-cord run, but `metrics` is not recomputed here -- unless it was
# reassigned in a cell not shown, it still holds the results of a single-vertebra model; confirm the
# accuracy / DICE captions really belong to these images.
folder = "C:/Users/gieko/Dropbox/NIITO_Vertebrae/Scripts/inference/UNet_spinal-cord/apricot-water-43"
# Blend and caption the first 5 saved images (indices 0..4) of that folder.
save_blended(path_to_folder=folder, number_of_images=5, acc = metrics['accuracy'], h_dice = metrics['DICE'], alpha = 0.75, beta = 0.8)
