# Expanding the Moscow dataset

## Prepare

In [1]:
!pip install torchinfo
!pip install -U segmentation-models-pytorch
!pip install lungmask
!pip install openpyxl

Collecting torchinfo
  Downloading torchinfo-1.7.1-py3-none-any.whl (22 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.7.1
Collecting segmentation-models-pytorch
  Downloading segmentation_models_pytorch-0.3.0-py3-none-any.whl (97 kB)
     |████████████████████████████████| 97 kB 968 kB/s            
[?25hCollecting pretrainedmodels==0.7.4
  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)
     |████████████████████████████████| 58 kB 2.5 MB/s            
[?25h  Preparing metadata (setup.py) ... [?25l- done
Collecting efficientnet-pytorch==0.7.1
  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)
  Preparing metadata (setup.py) ... [?25l- done
[?25hCollecting timm==0.4.12
  Downloading timm-0.4.12-py3-none-any.whl (376 kB)
     |████████████████████████████████| 376 kB 3.5 MB/s            
Building wheels for collected packages: efficientnet-pytorch, pretrainedmodels
  Building wheel for efficientnet-pytorch (setup.py) ..

In [2]:
import os
import json
import numpy as np
import pandas as pd
import nibabel as nib
# DL
import torch
from torch import nn
from torchinfo import summary
from torch.utils.data import Dataset, DataLoader
from torchvision import models
import SimpleITK as sitk
from lungmask import mask as unet_mask
# Visualization
import matplotlib.pyplot as plt
from PIL import Image
from mpl_toolkits.axes_grid1 import ImageGrid
from tqdm.notebook import tqdm
from IPython.display import clear_output, display
import seaborn as sns

In [3]:
# Root directories of the three source datasets, relative to the notebook's
# working directory. Each loader below joins its registry file and data
# paths onto the matching root.
TGCOVID_PATH = "../input/tgcovid"
KAGGLECOVID_PATH = "../input/covid19-ct-scans"
MOSCOWCOVID_PATH = "../input/covid19moscow/COVID19_1110"

In [4]:
def expand_path(*right_part):
    """Build a function that prefixes a path with fixed leading components.

    Parameters
    ----------
    *right_part : str
        Leading path components; they are passed to ``os.path.join`` ahead
        of the value supplied to the returned function.

    Returns
    -------
    callable
        ``f(left_part)`` returning ``os.path.join(*right_part, left_part)``,
        or ``np.nan`` when ``left_part`` is a missing value.
    """
    def _expand_path(left_part):
        # Missing cells can arrive as None, pd.NA or a NaN float that is not
        # the ``np.nan`` singleton, so test by value instead of by identity.
        if pd.isna(left_part):
            return np.nan
        return os.path.join(*right_part, left_part)
    return _expand_path

def load_tgcovid_data(json_file):
    """Load the TGCOVID registry and turn its entries into full file paths.

    Parameters
    ----------
    json_file : str
        Name of the JSON registry, relative to ``TGCOVID_PATH``. It must
        yield ``image`` and ``label`` columns when passed to ``pd.DataFrame``.

    Returns
    -------
    pandas.DataFrame
        The registry with ``image`` prefixed by ``<root>/data/data/images``
        and ``label`` by ``<root>/data/data/labels``; a trailing ``.gz`` is
        removed from both. Missing entries stay NaN.
    """
    datapath = os.path.join(TGCOVID_PATH, "data", "data")
    path_images = os.path.join(datapath, 'images')
    path_labels = os.path.join(datapath, 'labels')
    with open(os.path.join(TGCOVID_PATH, json_file), 'r') as f:
        dict_data = json.load(f)

    data = pd.DataFrame(dict_data)
    for column, directory in (("image", path_images), ("label", path_labels)):
        data[column] = (
            data[column]
            # expand path to full
            .apply(expand_path(directory))
            # Drop the ".gz" suffix only when it is really there; a blind
            # [:-3] would eat three characters of an uncompressed name.
            .str.replace(r"\.gz$", "", regex=True)
        )
    return data

def load_kagglecovid_data(csv_file):
    """Read the Kaggle CT-scan registry and align its column names.

    ``csv_file`` is resolved relative to ``KAGGLECOVID_PATH``. The
    ``ct_scan`` and ``infection_mask`` columns are renamed to ``image`` and
    ``label``, and ``lung_and_infection_mask`` is dropped.
    """
    registry = pd.read_csv(os.path.join(KAGGLECOVID_PATH, csv_file))
    return (
        registry
        .rename(columns={"ct_scan": "image", "infection_mask": "label"})
        .drop(columns="lung_and_infection_mask")
    )

def load_moscowcovid_data(xlsx_file):
    """Load the Moscow registry and turn its entries into full file paths.

    Parameters
    ----------
    xlsx_file : str
        Name of the Excel registry, relative to ``MOSCOWCOVID_PATH``. It
        must contain ``study_file``, ``mask_file``, ``category`` and
        ``study_id`` columns.

    Returns
    -------
    pandas.DataFrame
        The registry with ``study_file``/``mask_file`` renamed to
        ``image``/``label`` and rooted at ``MOSCOWCOVID_PATH``, and the
        ``category`` and ``study_id`` columns removed. Missing entries
        stay NaN.
    """
    datapath = os.path.join(MOSCOWCOVID_PATH, xlsx_file)
    data = (
        pd.read_excel(datapath)
        .rename(columns={"study_file": "image", "mask_file": "label"})
        .drop(columns=["category", "study_id"])
    )
    for column in ("image", "label"):
        data[column] = (
            data[column]
            # A leading "/" would make os.path.join discard the dataset
            # root, so strip it -- but only when it is actually present.
            .str.lstrip("/")
            .apply(expand_path(MOSCOWCOVID_PATH))
            # Drop the ".gz" suffix only when it is really there.
            .str.replace(r"\.gz$", "", regex=True)
        )
    return data

In [5]:
# Build the Moscow registry table (full image/label paths) and display it.
dataset = load_moscowcovid_data("dataset_registry.xlsx")
dataset

Unnamed: 0,image,label
0,../input/covid19moscow/COVID19_1110/studies/CT...,
1,../input/covid19moscow/COVID19_1110/studies/CT...,
2,../input/covid19moscow/COVID19_1110/studies/CT...,
3,../input/covid19moscow/COVID19_1110/studies/CT...,
4,../input/covid19moscow/COVID19_1110/studies/CT...,
...,...,...
1105,../input/covid19moscow/COVID19_1110/studies/CT...,
1106,../input/covid19moscow/COVID19_1110/studies/CT...,
1107,../input/covid19moscow/COVID19_1110/studies/CT...,
1108,../input/covid19moscow/COVID19_1110/studies/CT...,


## Visualization 

In [6]:
def normalize(x):
    """Min-max scale ``x`` so its values fall in roughly [0, 1].

    A small epsilon (1e-8) is added to the value range, so constant input
    maps to zeros instead of dividing by zero.
    """
    lowest = np.min(x)
    span = np.max(x) - lowest + 1e-8
    return (x - lowest) / span

def slice2rgb(image, normalize_data=True):
    """Convert a single-channel slice into a grey 3-channel PIL image.

    The slice is cast to float32, optionally min-max normalized, multiplied
    by 255, replicated into three identical channels and cast to uint8.
    """
    pixels = image.astype(np.float32)
    if normalize_data:
        pixels = normalize(pixels)
    pixels *= 255
    rgb = np.dstack([pixels] * 3).astype(np.uint8)
    return Image.fromarray(rgb)

def mask2blue(mask):
    """Render a mask as a PIL image with the mask in the third channel.

    The first two channels are zeros; the third holds ``mask * 255`` cast
    to uint8.
    """
    blank = np.zeros_like(mask)
    channels = [blank, blank, mask * 255]
    return Image.fromarray(np.dstack(channels).astype(np.uint8))
    
def blend(image, mask, normalize_data=True):
    """Overlay a blue mask on a grey slice and return the blended PIL image.

    Parameters
    ----------
    image : numpy.ndarray
        Slice to display; converted with ``slice2rgb``.
    mask : numpy.ndarray
        Mask for the same slice; converted with ``mask2blue``.
    normalize_data : bool, default True
        Forwarded to ``slice2rgb``. It was previously ignored because
        ``True`` was hard-coded in the call.

    Returns
    -------
    PIL.Image.Image
        ``Image.blend`` of the slice and the mask with ``alpha=.2``.
    """
    return Image.blend(
        slice2rgb(image, normalize_data=normalize_data),
        mask2blue(mask),
        alpha=.2
    )

In [7]:
def save_ndarray_as_nii(data, path):
    """Write ``data`` to ``path`` through SimpleITK with compression requested."""
    sitk.WriteImage(sitk.GetImageFromArray(data), path, useCompression=True)

## COVID segmentation


In [8]:
# Output folder for the generated masks. exist_ok=True keeps the cell safe to
# re-run, whereas `!mkdir` reports an error once the folder already exists.
os.makedirs("artificial_masks", exist_ok=True)

In [9]:
# Size each slice is resized to before inference (unpacked into A.Resize).
IMG_SIZE = (512, 512)
# Batch size and worker count of the inference DataLoader.
BATCH_SIZE = 4
N_WORKERS = 2
# Model outputs >= THRESHOLD are kept as mask pixels.
THRESHOLD = .9
# Checkpoint file read with torch.load to restore the trained model.
WEIGHTS_PATH = "../input/ct-scans-semantic-segmentation/epoch_90"

In [10]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [11]:
class Scan3D(Dataset):
    """Dataset that serves one volume as individual slices along its last axis.

    The whole volume is read when the dataset is constructed. Every item is
    a min-max normalized slice carrying a trailing singleton axis, passed
    through ``transform`` when one is given.
    """

    def __init__(self, path, transform=None):
        self.transform = transform
        volume = self.load_data(path)
        # One entry per index of the last axis, each with an extra trailing axis.
        self.images = [
            volume[..., slice_idx, np.newaxis]
            for slice_idx in range(volume.shape[-1])
        ]

    def __len__(self):
        """Number of slices held by the dataset."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return slice ``idx``, normalized and optionally transformed."""
        image = self.normalize(self.images[idx])
        if self.transform is None:
            return image
        # The transform is called with an ``image`` keyword and returns a
        # mapping whose ``image`` entry is the result.
        return self.transform(image=image)["image"]

    def load_data(self, path, dtype="float32"):
        """Read the volume at ``path`` with nibabel and cast it to ``dtype``."""
        volume = nib.load(path).get_fdata()
        return volume.astype(dtype)

    def normalize(self, x):
        """Delegate to the module-level ``normalize`` helper."""
        return normalize(x)

In [12]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Per-slice preprocessing applied by Scan3D: resize to IMG_SIZE, then convert
# the array to a torch tensor with ToTensorV2.
transform = A.Compose([
    A.Resize(*IMG_SIZE),
    ToTensorV2()
])

In [13]:
def inference(model, data_path):
    """Run ``model`` over every slice of one scan and stack the predictions.

    Parameters
    ----------
    model : torch.nn.Module
        Network to evaluate; only channel 0 of its output is kept.
    data_path : str
        Path of the volume handed to ``Scan3D``.

    Returns
    -------
    numpy.ndarray
        Per-batch predictions concatenated along axis 0, in slice order.
    """
    data = Scan3D(data_path, transform)
    dataloader = DataLoader(
        data, batch_size=BATCH_SIZE, shuffle=False, num_workers=N_WORKERS)
    masks = []
    # Remember the caller's mode so it can be restored instead of always
    # forcing training mode on the way out.
    was_training = model.training
    model.eval()
    try:
        # No gradients are needed here; disabling autograd avoids building a
        # graph for every batch and keeps memory use down.
        with torch.no_grad():
            for image in dataloader:
                mask = model(image.to(device))
                masks.append(mask.detach().cpu().numpy()[:, 0])
    finally:
        model.train(was_training)
    return np.concatenate(masks, axis=0)

In [14]:
from segmentation_models_pytorch import UnetPlusPlus

# Single-channel in, single-channel sigmoid out U-Net++ on an EfficientNet-B3
# encoder.
# NOTE(review): the ImageNet encoder weights requested here appear to be
# replaced by the checkpoint loaded in a later cell; encoder_weights=None might
# avoid the download -- confirm before changing.
model = UnetPlusPlus(
    encoder_name="timm-efficientnet-b3", encoder_depth=5, encoder_weights="imagenet", 
    in_channels=1, classes=1, activation="sigmoid", decoder_channels=(256, 128, 64, 32, 16),
).to(device)

# The summary is informational only, so a failure must not stop the notebook.
# Catch Exception (not a bare except) so KeyboardInterrupt still propagates,
# and show what actually went wrong.
try:
    print(summary(model, (1, 1, *IMG_SIZE)))
except Exception as error:
    print("Something go wrong with Summary.", repr(error))

Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth" to /root/.cache/torch/hub/checkpoints/tf_efficientnet_b3_aa-84b4657e.pth


  0%|          | 0.00/47.1M [00:00<?, ?B/s]

Layer (type:depth-idx)                             Output Shape              Param #
UnetPlusPlus                                       [1, 1, 512, 512]          --
├─EfficientNetEncoder: 1-1                         [1, 1, 512, 512]          592,896
│    └─Conv2d: 2-1                                 [1, 40, 256, 256]         360
│    └─BatchNorm2d: 2-2                            [1, 40, 256, 256]         80
│    └─Swish: 2-3                                  [1, 40, 256, 256]         --
│    └─Sequential: 2-4                             --                        --
│    │    └─Sequential: 3-1                        [1, 24, 256, 256]         3,504
│    │    └─Sequential: 3-2                        [1, 32, 128, 128]         48,118
│    │    └─Sequential: 3-3                        [1, 48, 64, 64]           110,912
│    │    └─Sequential: 3-4                        [1, 96, 32, 32]           638,700
│    │    └─Sequential: 3-5                        [1, 136, 32, 32]          1,387,760
│    

In [15]:
# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved on a GPU also loads on the CPU fallback.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(WEIGHTS_PATH, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
print("Model from epoch", checkpoint["epoch"])
# Show the last recorded value of every tracked training metric.
for key, item in checkpoint["history"].items():
    print(key, ":", item[-1])

Model from epoch 90
train_loss : 0.11613768646397542
val_loss : 0.31235065827002895
val_dice : 0.7434878257604746
train_dice : 0.8430125106548526
val_acc : 0.9900273573704255
train_acc : 0.9975728546221232
val_precision : nan
train_precision : 0.8164154704391342
val_recall : nan
train_recall : 0.918734166425528
grads : 0.4345550835132599


In [16]:
def get_ones_box(size, borders):
    """Return a ``size`` x ``size`` array of ones framed by a zero margin.

    The margin is ``borders`` cells wide on every side.
    """
    interior = slice(borders, size - borders)
    frame = np.zeros((size, size))
    frame[interior, interior] = 1
    return frame

In [17]:
def infection_segmentation(dataset, model):
    """Predict a mask for every scan, save the masks and write an index CSV.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Table whose ``image`` column holds the path of each scan.
    model : torch.nn.Module
        Segmentation network passed on to ``inference``.

    Side effects
    ------------
    Writes one mask file per scan into ``artificial_masks`` (created if
    missing), named after the scan file, and writes ``infection_data.csv``
    with columns ``study_id`` (the row index of ``dataset``), ``image`` and
    ``mask``.
    """
    output_dir = "artificial_masks"
    border = 10  # width, in pixels, of the margin cleared around each slice
    os.makedirs(output_dir, exist_ok=True)

    records = []
    for idx in tqdm(dataset.index):
        path = dataset.loc[idx, "image"]
        mask = inference(model, path)
        # The box removes mask pixels near the borders. Its size comes from
        # the prediction itself rather than a hard-coded 512, and one 2-D box
        # broadcasts over all slices, so no per-slice copies are needed.
        box = get_ones_box(mask.shape[-1], border)
        mask = (mask >= THRESHOLD) * box
        mask_path = os.path.join(output_dir, os.path.basename(path))
        save_ndarray_as_nii(np.uint8(mask), mask_path)
        records.append([idx, path, mask_path])

    segmentation_table = pd.DataFrame(
        records, columns=["study_id", "image", "mask"]).set_index("study_id")
    segmentation_table.to_csv("infection_data.csv")

In [18]:
# Generate and save a mask for every scan in the registry; this runs inference
# on each row of `dataset`.
infection_segmentation(dataset, model)

  0%|          | 0/1110 [00:00<?, ?it/s]