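**About**: This notebook prepares 3D segmentation volumes, exports the masks predicted by a trained 3D segmentation model as per-frame images, and visualises the model's predictions.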

In [None]:
# Enable IPython's autoreload extension; mode 2 reloads all imported modules before each cell runs,
# so edits to the project sources are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Move into the source folder: every relative path below ("../input/...", "../logs/...") is relative to it,
# and the project imports (params, data, model_zoo, ...) presumably resolve from here.
cd ../src/

## Initialization

### Imports

In [None]:
import os

# Restrict this process to GPU 0. Set before torch is imported so the choice is
# guaranteed to be in place before CUDA is initialised.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

import torch

print(torch.__version__)

# Name of the visible GPU, or "cpu" when no CUDA device is available
# (torch.cuda.get_device_name raises in that case).
device = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu"
print(device)

In [None]:
import os
import re
import sys
import glob
import json
import time
import torch
import warnings
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torch.nn.functional as F

from tqdm import tqdm
from sklearn.metrics import *

# Widen pandas' text output so long file paths stay readable when frames are displayed.
pd.set_option('display.width', 500)
# Use the fully qualified option name instead of the 'max_colwidth' shorthand, which only
# works through substring matching of option names.
pd.set_option('display.max_colwidth', 100)

In [None]:
from params import *
from data.dataset import *
from data.preparation import *
from data.transforms import get_transfos

from model_zoo.models_seg import define_model, convert_3d
from training.main_seg import k_fold
from inference.extract_features import Config

from util.plots import *

In [None]:
from model_zoo.models_seg import define_model
from util.torch import load_model_weights
from training.losses import SegLoss

### Preparation

In [None]:
# Load the two tables returned by prepare_data; df_img is used below through its
# 'series', 'patient_id' and 'frame' columns.
# NOTE(review): prepare_data and DATA_PATH are not defined in this notebook — presumably they
# come from the star imports above (data.preparation / params).
df_patient, df_img = prepare_data(DATA_PATH)

In [None]:
# One row per series holding the maximum patient_id and frame value found for it.
per_series_max = df_img.groupby('series')[['patient_id', "frame"]].max()
series = per_series_max.reset_index()

# One row per segmentation file; the file name minus its 4-character ".nii" suffix
# is the integer series id.
seg_paths = glob.glob("../input/segmentations/*.nii")
segs = pd.DataFrame({"path": seg_paths})
segs['series'] = segs['path'].map(lambda p: int(p.rsplit('/', 1)[-1][:-4]))

# Inner join on the shared 'series' column, then fix the column order.
segs = pd.merge(segs, series)[["patient_id", "series", "frame", "path"]]

In [None]:
# Output folder for the resampled 3D volumes (created if missing).
SAVE_FOLDER = "../input/3ds_2/"
os.makedirs(SAVE_FOLDER, exist_ok=True)

# Keep at most the last MAX_LEN frames of a series (used as `[-MAX_LEN:]` slices below).
MAX_LEN = 600
# Edge length of the (SIZE, SIZE, SIZE) resampling target in the commented-out preparation code.
SIZE = 256

# Flags read by the (currently commented-out) preparation loop: write .npy outputs / show plots.
SAVE = False
PLOT = False

In [None]:
# %%time
# if PLOT:
#     show_cmap()

# for idx in tqdm(range(len(segs))):
#     patient_id = segs['patient_id'][idx]
#     series = segs['series'][idx]

#     imgs = sorted(glob.glob("../input/imgs/" + f"{patient_id}_{series}_*"))
#     imgs = np.array([cv2.imread(f, 0) for f in imgs[-MAX_LEN:]])
#     imgs = center_crop_pad(imgs, 384)

#     imgs = F.interpolate(torch.from_numpy(imgs).unsqueeze(0).unsqueeze(0), size=(SIZE, SIZE, SIZE), mode="nearest")[0][0]
    
# #     imgs = imgs[::2, ::2, ::2]
# #     imgs = F.interpolate(  # Only downsample on x
# #         torch.from_numpy(imgs).contiguous().view(-1, SIZE * SIZE).transpose(0, 1).unsqueeze(1),
# #         size=SIZE,
# #         mode="nearest"
# #     )[:, 0].transpose(0, 1).view(SIZE, SIZE, SIZE)
    
#     imgs = imgs.numpy().astype(np.uint8)
    
#     seg = load_segmentation(segs['path'][idx])[-MAX_LEN:]
#     seg = center_crop_pad(seg, 384).copy()
    
# #     seg = seg[::2, ::2, ::2]
# #     seg = F.interpolate(  # Only downsample on x
# #         torch.from_numpy(seg).contiguous().view(-1, SIZE * SIZE).transpose(0, 1).unsqueeze(1),
# #         size=SIZE,
# #         mode="nearest"
# #     )[:, 0].transpose(0, 1).view(SIZE, SIZE, SIZE)
    
#     seg = F.interpolate(torch.from_numpy(seg).unsqueeze(0).unsqueeze(0), size=(SIZE, SIZE, SIZE), mode="nearest")[0][0]
#     seg = seg.numpy().astype(np.uint8)
    
#     if SAVE:
#         np.save(SAVE_FOLDER + "imgs/" + f"{patient_id}_{series}.npy", imgs)
#         np.save(SAVE_FOLDER + "segs/" + f"{patient_id}_{series}.npy", seg)
    
#     if PLOT:
#         ids = [i * len(imgs) // 5 for i in range(1, 5)]
#         plt.figure(figsize=(20, 5))
#         for i, id_ in enumerate(ids):
#             plt.subplot(1, len(ids), i + 1)
#             plot_mask(imgs[id_], seg[id_])
#             plt.title(f'Frame {id_}')
#         plt.show()
        
#         ids = [i * imgs.shape[1] // 5 for i in range(1, 5)]
#         plt.figure(figsize=(20, 5))
#         for i, id_ in enumerate(ids):
#             plt.subplot(1, len(ids), i + 1)
#             plot_mask(imgs[:, id_], seg[:, id_])
#             plt.title(f'Frame {id_}')
#         plt.show()
        
#         ids = [i * imgs.shape[2] // 5 for i in range(1, 5)]
#         plt.figure(figsize=(20, 5))
#         for i, id_ in enumerate(ids):
#             plt.subplot(1, len(ids), i + 1)
#             plot_mask(imgs[:, :, id_], seg[:, :, id_])
#             plt.title(f'Frame {id_}')
#         plt.show()
# #         if idx > 10:
#     break

In [None]:
# df_seg = pd.DataFrame({
#     "mask_path": sorted(glob.glob(f'{SAVE_FOLDER}/segs/*.npy'))
# })
# df_seg['patient_id'] = df_seg['mask_path'].apply(lambda x: int(x.split('/')[-1].split('_')[0]))
# df_seg['series'] = df_seg['mask_path'].apply(lambda x: int(x.split('/')[-1].split('_')[1][:-4]))

# df_seg['img_path'] = df_seg['mask_path'].apply(lambda x: re.sub("/segs/", "/imgs/", x))
# df_seg = df_seg[["patient_id", "series", "img_path", "mask_path"]]

# df_seg.to_csv('../input/df_seg_3d.csv', index=False)
# df_seg.head()

### Preparation Extra

In [None]:
# PLOT = False
# SAVE = True

# SIZE = 256

# SAVE_FOLDER = '../input/3ds_extra/'

# os.makedirs(SAVE_FOLDER, exist_ok=True)
# os.makedirs(SAVE_FOLDER + "imgs/", exist_ok=True)
# os.makedirs(SAVE_FOLDER + "segs/", exist_ok=True)

# segs_extra = glob.glob('../input/extra_segs/segs/*')

In [None]:
# for i, path in enumerate(tqdm(segs_extra)):
#     study = path.split('/')[-1].split('_')[0]
    
#     if os.path.exists(SAVE_FOLDER + "imgs/" + f"{study}_{study}.npy"):
#         continue

#     try:
#         seg = load_segmentation(path)[::-1]
#         imgs = load_segmentation(f'../input/extra_segs/imgs/{study}/ct.nii.gz')[::-1]
#     except:
#         continue

#     kept = (seg > 0).sum(-1).sum(-1) > 1000
#     start, end = np.argmax(kept), len(kept) - np.argmax(kept[::-1])
#     imgs = imgs[start: end]
#     seg = seg[start: end]
    
#     all_present = all((seg == i).sum() > 1000 for i in range(1, 6))
#     if not all_present:
#         continue

#     imgs, (start, end) = auto_windowing(imgs)
    
#     crop_size = int(0.75 * imgs.shape[1])

#     imgs = center_crop_pad(imgs, crop_size)
#     imgs = F.interpolate(torch.from_numpy(imgs).unsqueeze(0).unsqueeze(0), size=(SIZE, SIZE, SIZE), mode="nearest")[0][0]
#     imgs = imgs.numpy().astype(np.uint8)
    
#     seg = center_crop_pad(seg, crop_size).copy()
#     seg = F.interpolate(torch.from_numpy(seg).unsqueeze(0).unsqueeze(0), size=(SIZE, SIZE, SIZE), mode="nearest")[0][0]
#     seg = seg.numpy().astype(np.uint8)
    
#     if SAVE:
#         np.save(SAVE_FOLDER + "imgs/" + f"{study}_{study}.npy", imgs)
#         np.save(SAVE_FOLDER + "segs/" + f"{study}_{study}.npy", seg)
    
#     if PLOT or (not (i % 100)):
#         ids = [i * len(imgs) // 5 for i in range(1, 5)]
#         plt.figure(figsize=(20, 5))
#         for i, id_ in enumerate(ids):
#             plt.subplot(1, len(ids), i + 1)
#             plot_mask(imgs[id_], seg[id_])
#             plt.title(f'Frame {id_}')
#         plt.show()
        
#         ids = [i * imgs.shape[1] // 5 for i in range(1, 5)]
#         plt.figure(figsize=(20, 5))
#         for i, id_ in enumerate(ids):
#             plt.subplot(1, len(ids), i + 1)
#             plot_mask(imgs[:, id_], seg[:, id_])
#             plt.title(f'Frame {id_}')
#         plt.show()
        
#         ids = [i * imgs.shape[2] // 5 for i in range(1, 5)]
#         plt.figure(figsize=(20, 5))
#         for i, id_ in enumerate(ids):
#             plt.subplot(1, len(ids), i + 1)
#             plot_mask(imgs[:, :, id_], seg[:, :, id_])
#             plt.title(f'Frame {id_}')
#         plt.show()
        
# #     if idx > 10:
# #         break

In [None]:
# df_seg = pd.DataFrame({
#     "mask_path": sorted(glob.glob(f'{SAVE_FOLDER}/segs/*.npy'))
# })
# df_seg['patient_id'] = df_seg['mask_path'].apply(lambda x: x.split('/')[-1].split('_')[0])
# df_seg['series'] = df_seg['mask_path'].apply(lambda x: x.split('/')[-1].split('_')[1][:-4])

# df_seg['img_path'] = df_seg['mask_path'].apply(lambda x: re.sub("/segs/", "/imgs/", x))
# df_seg = df_seg[["patient_id", "series", "img_path", "mask_path"]]

# df_seg.to_csv('../input/df_seg_3d_extra.csv', index=False)

# print(f"-> Saved {len(df_seg)} extra segmentations")
# df_seg.head()

### 3D Inference

In [None]:
# Experiment folder: config.json is read from it and the masks live in its "masks/" subfolder.
EXP_FOLDER = "../logs/2023-09-24/20/"
# EXP_FOLDER = "../logs/2023-09-22/31/"   # slow

# NOTE(review): overrides the SAVE_FOLDER set in the Preparation section; no active cell below reads it.
SAVE_FOLDER = "../input/3ds/"

os.makedirs(EXP_FOLDER + "masks/", exist_ok=True)

In [None]:
# Load the experiment's saved configuration. The context manager closes the file handle,
# which the previous bare open(...) call left open.
with open(EXP_FOLDER + "config.json", "r") as f:
    config = Config(json.load(f))

In [None]:
# Same call as in the Preparation section; reloads df_patient / df_img.
df_patient, df_img = prepare_data(DATA_PATH)

In [None]:
# Maximum `frame` value of every (patient_id, series) pair, one row per pair.
frame_cols = ["patient_id", 'series', 'frame']
series_keys = ["patient_id", 'series']
max_frame = df_img[frame_cols].groupby(series_keys).max()

# Attach the fold assignment. No `on=` is given, so the join uses every column the two
# frames share; how="left" keeps all series even when they have no fold.
folds = pd.read_csv("../input/folds_4.csv")
df_series = pd.merge(max_frame.reset_index(), folds, how="left")

In [None]:
# Flags read by the (currently commented-out) inference loop below.
PLOT = False  # when True, plot every series instead of every 100th
SAVE = True  # when True, write the predicted masks and per-class counts to EXP_FOLDER + "masks/"

In [None]:
# %%time
# # for fold in range(4):
# for fold in [0]:
#     df_seg = df_series[df_series['fold'] == fold].reset_index(drop=True)

#     model = define_model(
#         config.decoder_name,
#         config.name,
#         num_classes=config.num_classes,
#         num_classes_aux=config.num_classes_aux,
#         n_channels=config.n_channels,
#         increase_stride=config.increase_stride,
#     )

#     model = convert_3d(model)
#     model = load_model_weights(model, EXP_FOLDER + f"{config.name}_{fold}.pt")
#     model = model.cuda()
    
#     for idx in tqdm(range(len(df_seg))):
#         patient_id = df_seg['patient_id'][idx]
#         series = df_seg['series'][idx]
#         n_frames = df_seg['frame'][idx]

#         imgs = sorted(glob.glob("../input/imgs/" + f"{patient_id}_{series}_*"))
#         imgs = np.array([cv2.imread(f, 0) for f in imgs[-MAX_LEN:]])
#         n_frames = int(len(imgs))
#         imgs = center_crop_pad(imgs, 384)

#         x = torch.from_numpy(imgs).cuda().float() / 255.
        
# #         t0 = time.time()
        
#         with torch.cuda.amp.autocast(enabled=True):
#             x = x.unsqueeze(0).unsqueeze(0)
#             x = F.interpolate(x, size=(SIZE, SIZE, SIZE), mode="nearest")

#             pred = model(x)[0].argmax(1, keepdims=True).float()
#             pred = F.interpolate(pred, size=(n_frames, 384, 384), mode="nearest")

# #         t1 = time.time()
# #         print(f"inf {t1 - t0 :.3f}")
            
#         msk = pred.cpu().numpy()[0][0].astype(int)
        
#         assert msk.shape == imgs.shape
        
#         counts = np.array([(msk == i).sum(-1).sum(-1) for i in range(1, 5)])

#         if SAVE:
#             np.save(EXP_FOLDER + "masks/" + f"mask_{patient_id}_{series}.npy", msk.astype(np.uint8))
#             np.save(EXP_FOLDER + "masks/" + f"mask_counts_{patient_id}_{series}.npy", counts.astype(int))

#         if PLOT or not (idx % 100):
#             plt.figure(figsize=(20, 4))
#             ids = [i * n_frames // 6 for i in range(1, 5)]

#             plt.subplot(1, 5, 1)
#             plot_mask(imgs[ids[-1]], msk[ids[-1]])
#             plt.title(f'Frame {ids[-1]}')
            
#             ids = [i * 384 // 6 for i in range(1, 5)]

#             plt.subplot(1, 5, 2)
#             plot_mask(imgs[:, ids[2]], msk[:, ids[2]])
#             plt.title(f'Frame {ids[2]}')
            
#             plt.subplot(1, 5, 3)
#             plot_mask(imgs[:, ids[3]], msk[:, ids[3]])
#             plt.title(f'Frame {ids[3]}')

#             plt.subplot(1, 5, 4)
#             plot_mask(imgs[:, :, ids[1]], msk[:, :, ids[1]])
#             plt.title(f'Frame {ids[1]}')

#             plt.subplot(1, 5, 5)
#             plt.plot(counts.T)
#             plt.yticks([], [])
#             plt.title(f'Counts')
            
#             plt.show()

# #         if idx > 5:
# #             break
# #     break

In [None]:
%%time

# Index every image path by its series id so the export loop below can look frames up
# without globbing once per series.
# A dedicated name is used so the `df_img` table returned by prepare_data is not overwritten:
# the df_series cell above needs its patient_id / frame columns when it is re-run.
df_img_paths = pd.DataFrame({"path": sorted(glob.glob("../input/imgs/*"))})
# The series id is the second-to-last "_"-separated token of the path (kept as a string).
df_img_paths['series'] = df_img_paths['path'].apply(lambda x: x.split('_')[-2])
# {series id (str): [paths, in sorted order]}
d = df_img_paths.groupby('series')['path'].agg(list).to_dict()

In [None]:
%%time
# Split every saved 3D mask into one PNG per frame, named after the frame id of the
# matching source image.
# NOTE(review): `cv2` is never imported in this notebook — presumably it is re-exported by one
# of the star imports above; confirm, or import it explicitly.
for fold in [0]:
    df_seg = df_series[df_series['fold'] == fold].reset_index(drop=True)

    for idx in tqdm(range(len(df_seg))):
        patient_id = df_seg['patient_id'][idx]
        series = df_seg['series'][idx]

        # Same frame selection as at inference time: the last MAX_LEN images of the series.
        imgs = sorted(d[str(series)])[-MAX_LEN:]

        msk = np.load(EXP_FOLDER + "masks/" + f"mask_{patient_id}_{series}.npy")

        # Masks and images are paired by position; fail before writing anything if they
        # disagree, otherwise PNGs would be written under the wrong frame ids.
        assert len(msk) == len(imgs), (
            f"{patient_id}_{series}: {len(msk)} mask slices for {len(imgs)} images"
        )

        for img_path, m in zip(imgs, msk):
            # First 4 characters of the last "_"-separated token of the image path.
            frame = img_path.split('_')[-1][:4]
            cv2.imwrite(EXP_FOLDER + "masks/" + f"mask_{patient_id}_{series}_{frame}.png", m)
#         break
#     break

## Exploration

In [None]:
# Table of saved 3D volumes; the commented-out preparation cell writes it with the columns
# patient_id, series, img_path and mask_path.
df_seg = pd.read_csv('../input/df_seg_3d.csv')

In [None]:
# Attach the fold assignment. No `on=` is given, so the join uses every column shared by both
# frames; how="left" keeps every volume even when it has no matching fold row.
folds = pd.read_csv("../input/folds_4.csv")
df_seg = df_seg.merge(folds, how="left")

In [None]:
# Keep only the volumes assigned to fold 0 and renumber the rows from 0.
df_seg = df_seg[df_seg['fold'] == 0].reset_index(drop=True)

In [None]:
# Dataset over the first five volumes only (`head()` default), built with train=False
# (presumably no training-time augmentation — confirm in data.dataset).
dataset = Seg3dDataset(df_seg.head(), train=False)

In [None]:
# for i in tqdm(range(len(dataset))):
#     _ = dataset[i]
#     break

In [None]:
# Pick one sample; the dataset returns three items and the third is unused here.
idx = 1
x, seg, _ = dataset[idx]
# Take index 0 along the first dimension and convert to NumPy for plotting.
# NOTE(review): assumes x and seg are (1, D, H, W) tensors — confirm against Seg3dDataset.
imgs, seg = x[0].numpy(), seg[0].numpy()

In [None]:
# NOTE(review): presumably displays the class colour map used by plot_mask — see util.plots.
show_cmap(True)

In [None]:
# Ground-truth overview: one row per volume axis, four evenly spaced slices per row.
n_slices = 4
plt.figure(figsize=(20, 15))

for axis in range(3):
    # Slice positions at 1/6 .. 4/6 of THIS axis' length. Deriving them from len(imgs) for
    # every axis is only valid for cubic volumes.
    ids = [i * imgs.shape[axis] // 6 for i in range(1, n_slices + 1)]

    for i, id_ in enumerate(ids):
        plt.subplot(3, n_slices, axis * n_slices + i + 1)
        # np.take(a, k, axis) is a[k], a[:, k] or a[:, :, k] for axis 0, 1 or 2.
        plot_mask(np.take(imgs, id_, axis=axis), np.take(seg, id_, axis=axis))
        plt.title(f'Frame {id_}')

plt.show()

In [None]:
# Rebuild the segmentation network and load trained weights for the prediction check below.
# NOTE(review): the hyper-parameters and the checkpoint path are hard-coded rather than taken
# from `config` / EXP_FOLDER (which point at a different experiment folder) — confirm intended.
model = define_model(
    "Unet",
    'resnet18d',
    num_classes=5,
    num_classes_aux=0,
    n_channels=1,
    use_cls=False,
    increase_stride=False
)

# NOTE(review): convert_3d presumably turns the 2D network into its 3D counterpart — see model_zoo.models_seg.
model = convert_3d(model)
model = load_model_weights(model, "../logs/2023-09-22/31/resnet18d_0.pt")
model = model.cuda()  # requires a CUDA device

In [None]:
%%time
# Inference only. define_model / load_model_weights are not shown here, so the mode is set
# explicitly: eval() makes batch-norm / dropout layers use their inference behaviour, and
# no_grad() avoids building an autograd graph for a full 3D volume.
model.eval()
with torch.no_grad():
    # Add the batch dimension; the model returns two outputs and only the first is used.
    pred, _ = model(x=x.unsqueeze(0).cuda())

# Per-voxel class index (argmax over dim 1) of the single batch element.
msk = pred.argmax(1).cpu().numpy()[0].astype(int)

In [None]:
# Prediction overview: one row per volume axis, four evenly spaced slices per row.
# Each title shows the share of predicted-foreground pixels whose class matches `seg`.
n_slices = 4
plt.figure(figsize=(20, 15))

for axis in range(3):
    # Slice positions at 1/6 .. 4/6 of THIS axis' length. Deriving them from len(imgs) for
    # every axis is only valid for cubic volumes.
    ids = [i * imgs.shape[axis] // 6 for i in range(1, n_slices + 1)]

    for i, id_ in enumerate(ids):
        # np.take(a, k, axis) is a[k], a[:, k] or a[:, :, k] for axis 0, 1 or 2.
        img_slice = np.take(imgs, id_, axis=axis)
        seg_slice = np.take(seg, id_, axis=axis)
        msk_slice = np.take(msk, id_, axis=axis)

        # Relabel predicted background (0) as -1 so it is not counted as a match; the +1 in
        # the denominator avoids dividing by zero on slices with no predicted foreground.
        m = np.where(msk_slice == 0, -1, msk_slice)
        acc = (m == seg_slice).sum() / ((m > 0).sum() + 1)

        plt.subplot(3, n_slices, axis * n_slices + i + 1)
        plot_mask(img_slice, msk_slice)
        plt.title(f'Frame {id_} - acc {acc:.3f}')

plt.show()

Done!