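**About**: This notebook prepares the segmentation data (2D masks and resized 3D volumes), sanity-checks a segmentation model on one sample, and trains models.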

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
cd ../src/

## Initialization

### Imports

In [None]:
import os

# Restrict the visible GPUs before torch is imported, so the setting is
# guaranteed to be in place whenever CUDA gets initialised.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

import torch

print(torch.__version__)

# `device` holds the human-readable GPU name (a string, not a torch.device).
# Fall back to "cpu" instead of raising when no CUDA device is available, so the
# preprocessing cells of this notebook can still be run on a CPU-only machine.
device = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu"
print(device)

In [None]:
import os
import re
import sys
import glob
import torch
import warnings
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torch.nn.functional as F

from tqdm import tqdm
from sklearn.metrics import *

# Widen pandas' text output: up to 500 characters per line and up to 100
# characters per column before values are truncated.
# (`max_colwidth` is matched by pandas against the full name `display.max_colwidth`.)
pd.set_option('display.width', 500)
pd.set_option('max_colwidth', 100)

In [None]:
from params import *
from data.dataset import *
from data.preparation import *
from data.transforms import get_transfos

from model_zoo.models import define_model
from training.main_seg import k_fold

from util.plots import *

## Preparation

In [None]:
# Load the two tables returned by `prepare_data` for the data under DATA_PATH.
# Presumably patient-level and image-level metadata (going by the names) - confirm
# in data.preparation. The next cell needs `series`, `patient_id` and `frame` in `df_img`.
df_patient, df_img = prepare_data(DATA_PATH)

In [None]:
def _series_id_from_nii(path):
    """Return the integer series id encoded in a `<series>.nii` file name."""
    file_name = path.split('/')[-1]
    return int(file_name[:-4])  # drop the 4-character ".nii" extension


# One row per series, holding the max `patient_id` and max `frame` found in `df_img`.
series = (
    df_img
    .groupby('series')[['patient_id', "frame"]]
    .max()
    .reset_index()
)

# One row per segmentation file, keyed by the series id parsed from its name.
nii_paths = glob.glob("../input/segmentations/*.nii")
segs = pd.DataFrame({"path": nii_paths})
segs['series'] = segs['path'].apply(_series_id_from_nii)

# The default merge is an inner join on the shared `series` column, so only
# segmentations whose series appears in `df_img` are kept.
segs = segs.merge(series)[["patient_id", "series", "frame", "path"]]

In [None]:
# Switches for the 2D preview/export loop in the next cell.
SAVE = False  # when True, every segmentation frame is written as a PNG to SAVE_FOLDER
PLOT = True  # when True, selected frames of a series are shown with their mask overlaid

# Output folder for the per-frame PNG masks; created up front so it exists before writing.
SAVE_FOLDER = "../input/segs/"
os.makedirs(SAVE_FOLDER, exist_ok=True)

In [None]:
# Preview (and optionally export) the 2D segmentation frames of each series in `segs`.
# `show_cmap`, `load_segmentation`, `plot_mask` and `cv2` are not imported explicitly in
# this notebook; presumably they come from the star imports above - verify.
show_cmap()

for idx in tqdm(range(len(segs))):
    patient_id = segs['patient_id'][idx]
    # NOTE: rebinding `series` here shadows the `series` DataFrame built in the preparation cell.
    series = segs['series'][idx]

    # All image files of this (patient, series), sorted by file name.
    imgs = sorted(glob.glob("../input/imgs/" + f"{patient_id}_{series}_*"))
    seg = load_segmentation(segs['path'][idx])

    # Frame indices at 1/5, 2/5, 3/5 and 4/5 of the image list, used for the preview.
    ids = [i * len(imgs) // 5 for i in range(1, 5)]

    if PLOT:
        plt.figure(figsize=(20, 5))
    for i, frame in enumerate(seg):
        if SAVE:
            # One PNG per segmentation frame: <patient_id>_<series>_<frame index, 4 digits>.png
            cv2.imwrite(SAVE_FOLDER + f"{patient_id}_{series}_{i:04d}.png", frame)

        if i in ids and PLOT:
            plt.subplot(1, len(ids), ids.index(i) + 1)
            # assumes segmentation frame i matches the i-th sorted image file - TODO confirm
            img = cv2.imread(imgs[i], cv2.IMREAD_GRAYSCALE)

            plot_mask(img, frame)
            plt.title(f'Frame {i}')

    if PLOT:
        plt.show()
#         if idx > 10:
    # NOTE(review): unconditional break - only the first row of `segs` is processed,
    # so with SAVE = True only that one series would be exported.
    break

In [None]:
# df_seg = pd.DataFrame({"mask_path": sorted(glob.glob('../input/segs/*.png'))})
# df_seg['patient_id'] = df_seg['mask_path'].apply(lambda x: int(x.split('/')[-1].split('_')[0]))
# df_seg['series'] = df_seg['mask_path'].apply(lambda x: int(x.split('/')[-1].split('_')[1]))
# df_seg['frame'] = df_seg['mask_path'].apply(lambda x: int(x.split('/')[-1].split('_')[2][:-4]))

# df_seg['img_path'] = df_seg['mask_path'].apply(lambda x: re.sub("/segs/", "/imgs/", x))
# df_seg = df_seg[["patient_id", "series", "frame", "img_path", "mask_path"]]

# df_seg.head()

In [None]:
# pixel_counts = []
# ph = np.zeros(6)

# for i in tqdm(range(len(df_seg))):
#     mask = cv2.imread(df_seg['mask_path'][i], cv2.IMREAD_GRAYSCALE)

#     cts = np.zeros(6)
#     counts = np.bincount(mask.flatten())
#     cts[:len(counts)] = counts

#     pixel_counts.append(cts)
    
# pixel_counts = np.array(pixel_counts)
# for k in labels:
#     df_seg[f'pixel_count_{labels[k]}'] = pixel_counts[:, k]

In [None]:
# df_seg.to_csv('../input/df_seg.csv', index=False)
# print('-> Saved df to ', '../input/df_seg.csv')

In [None]:
# i = np.random.choice(df_seg[df_seg['pixel_count_bowel'] > 10000].index)

# img = cv2.imread(df_seg['img_path'][i], cv2.IMREAD_GRAYSCALE)
# mask = cv2.imread(df_seg['mask_path'][i], cv2.IMREAD_GRAYSCALE)
            
# plt.imshow(img, cmap='gray')
# plt.imshow(np.where(mask, mask, np.nan), cmap='Set3', alpha=0.3)        
# plt.axis(False)
# plt.title(f'Frame {i}')
# plt.show()

### 3D

In [None]:
def center_crop_pad(img, size=384):
    """
    Center-crop or zero-pad the last two axes of an array to `size` x `size`.

    The two trailing axes are handled independently: an axis that is at least
    `size` long is cropped around its center, a shorter axis is zero-padded on
    both sides. Leading axes are left untouched.

    Args:
        img (np.ndarray): Array with at least 2 dimensions.
        size (int, optional): Target length of the last two axes. Defaults to 384.

    Returns:
        np.ndarray: Array of shape `img.shape[:-2] + (size, size)` with the same
            dtype as `img`. It is a view of the input when only cropping was needed.
    """
    h, w = img.shape[-2:]

    # Height (second-to-last axis).
    if h >= size:
        margin = (h - size) // 2
        img = img[..., margin: margin + size, :]
    else:
        # Keep the input dtype: np.zeros defaults to float64, which would make the
        # output dtype depend on whether padding happened to be needed.
        new_img = np.zeros(list(img.shape[:-2]) + [size, img.shape[-1]], dtype=img.dtype)
        margin = (size - h) // 2
        new_img[..., margin: margin + h, :] = img
        img = new_img

    # Width (last axis). The height is already equal to `size` at this point.
    if w >= size:
        margin = (w - size) // 2
        img = img[..., margin: margin + size]
    else:
        new_img = np.zeros(list(img.shape[:-2]) + [size, size], dtype=img.dtype)
        margin = (size - w) // 2
        new_img[..., margin: margin + w] = img
        img = new_img

    return img

In [None]:
# Root output folder for the resized 3D volumes.
SAVE_FOLDER = "../input/3ds/"

# The export loop writes to SAVE_FOLDER + "imgs/" and SAVE_FOLDER + "segs/".
# np.save does not create missing directories, so create both sub-folders up
# front (this also creates SAVE_FOLDER itself).
for sub_folder in ("imgs/", "segs/"):
    os.makedirs(SAVE_FOLDER + sub_folder, exist_ok=True)

MAX_LEN = 600  # at most the last MAX_LEN frames of a series are used
SIZE = 256  # every volume is resized to SIZE x SIZE x SIZE

SAVE = True  # write the resized image / segmentation volumes as .npy files
PLOT = False  # show slices of each volume along its three axes

In [None]:
def _resize_to_cube(volume, size):
    """Nearest-neighbour resize of a 3D array to (size, size, size), returned as numpy."""
    batch = torch.from_numpy(volume).unsqueeze(0).unsqueeze(0)  # add batch and channel axes
    resized = F.interpolate(batch, size=(size, size, size), mode="nearest")
    return resized[0][0].numpy()


def _plot_volume_slices(volume, mask, axis):
    """Show the slices at 1/5 .. 4/5 of `axis`, with `mask` overlaid on `volume`."""
    ids = [i * volume.shape[axis] // 5 for i in range(1, 5)]
    plt.figure(figsize=(20, 5))
    for i, id_ in enumerate(ids):
        plt.subplot(1, len(ids), i + 1)
        plot_mask(np.take(volume, id_, axis=axis), np.take(mask, id_, axis=axis))
        plt.title(f'Frame {id_}')
    plt.show()


show_cmap()

for idx in tqdm(range(len(segs))):
    patient_id = segs['patient_id'][idx]
    series = segs['series'][idx]

    # Image stack: last MAX_LEN frames, cropped/padded to 384 x 384, then resized to a cube.
    img_files = sorted(glob.glob("../input/imgs/" + f"{patient_id}_{series}_*"))
    imgs = np.array([cv2.imread(f, 0) for f in img_files[-MAX_LEN:]])
    imgs = _resize_to_cube(center_crop_pad(imgs, 384), SIZE)

    # Segmentation volume: same steps; the crop result is copied before the tensor conversion.
    seg = load_segmentation(segs['path'][idx])[-MAX_LEN:]
    seg = _resize_to_cube(center_crop_pad(seg, 384).copy(), SIZE)

    if SAVE:
        np.save(SAVE_FOLDER + "imgs/" + f"{patient_id}_{series}.npy", imgs)
        np.save(SAVE_FOLDER + "segs/" + f"{patient_id}_{series}.npy", seg)

    if PLOT:
        # One figure per axis of the volume.
        for axis in range(3):
            _plot_volume_slices(imgs, seg, axis)
#         if idx > 10:
#     break

In [None]:
# Index the exported 3D segmentation volumes; files are named <patient_id>_<series>.npy.
# NOTE: SAVE_FOLDER already ends with "/", so the paths stored here contain "//".
mask_paths = sorted(glob.glob(f'{SAVE_FOLDER}/segs/*.npy'))
name_parts = [path.split('/')[-1].split('_') for path in mask_paths]

df_seg = pd.DataFrame({
    "patient_id": [int(parts[0]) for parts in name_parts],
    "series": [int(parts[1][:-4]) for parts in name_parts],  # strip ".npy"
    # The image volume sits next to the mask, under imgs/ instead of segs/.
    "img_path": [re.sub("/segs/", "/imgs/", path) for path in mask_paths],
    "mask_path": mask_paths,
})

df_seg.to_csv('../input/df_seg_3d.csv', index=False)
df_seg.head()

## Data

In [None]:
# Table returned by `prepare_seg_data`; the filtering cells below expect it to carry
# the SEG_TARGETS columns and at least one column whose name contains "norm".
# NOTE: this rebinds `df_seg`, replacing the 3D index frame built in the previous section.
df_seg = prepare_seg_data(DATA_PATH)

In [None]:
# Attach the fold assignment. No `on=` is given, so pandas joins on every column the
# two frames share; `how="left"` keeps all rows of `df_seg`.
# presumably the shared key is `patient_id` - verify against folds_4.csv
folds = pd.read_csv("../input/folds_4.csv")
df_seg = df_seg.merge(folds, how="left")

In [None]:
# Keep only the rows assigned to fold 0 and renumber the index from 0.
# NOTE: `df_seg` is overwritten, so every later cell (including the training call at
# the end of the notebook) sees this fold-0-only frame.
df_seg = df_seg[df_seg['fold'] == 0].reset_index(drop=True)

In [None]:
# Keep rows where at least one "norm" column exceeds 0.1 (the index is not reset here).
norm_cols = [c for c in df_seg.columns if "norm" in c]
has_large_norm = df_seg[norm_cols].max(1) > 0.1
df_seg = df_seg[has_large_norm]

In [None]:
# subsample for speed: keep rows where at least one SEG_TARGETS column is above 1000
above_threshold = df_seg[SEG_TARGETS] > 1000
row_mask = above_threshold.max(1)
df_seg = df_seg[row_mask].reset_index(drop=True)

In [None]:
# Number of rows left after the filters above.
len(df_seg)

In [None]:
# sns.histplot(df_seg['pixel_count_liver_norm'])

In [None]:
# df_seg = df_seg[(df_seg[SEG_TARGETS] > 0).max(1)].reset_index(drop=True)

In [None]:
# Transforms with no augmentation requested (augment=False, strength=0), a 384 x 384
# resize and crop=True. Exact behaviour is defined in data.transforms.get_transfos.
transforms = get_transfos(augment=False, resize=(384, 384), crop=True, strength=0)

# Segmentation view of `df_seg`, with soft targets enabled.
dataset = SegDataset(df_seg, transforms=transforms, for_classification=False, use_soft_target=True)

# Same frame and transforms, built with for_classification=True.
dataset_cls = SegDataset(df_seg, transforms=transforms, for_classification=True)

In [None]:
plt.figure(figsize=(20, 5))

# NOTE: the indices are drawn without a fixed seed, so the sample differs between runs.
for i, idx in enumerate(np.random.choice(len(dataset), 5)):
# for i, idx in enumerate(range(0, len(dataset), 10)):
    img, mask, y = dataset[idx]

#     _, y_cls, _ = dataset_cls[idx]
#     print(y, y_cls)

    # Channel 1 of the image tensor is the one displayed.
    img_ = img.numpy()[1].squeeze()
    mask_ = mask.numpy().squeeze()

    plt.subplot(1, 5, i + 1)
    plot_mask(img_, mask_)
#     plt.title(str(y.numpy().astype(int)))
    plt.title(str(np.round(y.numpy(), 2)))
    plt.axis(False)
    # Stop after the first sample: only one of the five panels is filled. `img`, `mask`,
    # `y` and `img_` from this iteration are reused by the model / loss cells below.
    break

plt.show()

### Seg model

In [None]:
from model_zoo.models_seg import define_model
from util.torch import load_model_weights
from training.losses import SegLoss

In [None]:
# Build the segmentation network. `define_model` here is the one imported from
# model_zoo.models_seg in the previous cell, which shadows the earlier import from
# model_zoo.models.
# NOTE(review): num_classes=5 here, while the SegLoss and one-hot cells below use 6 -
# presumably a background class is added somewhere; confirm.
model = define_model(
    "Unet",
    "tf_efficientnetv2_s",
    num_classes=5,
    num_classes_aux=5,
    n_channels=3,
    use_cls=False,
    increase_stride=False
)

In [None]:
# Load a checkpoint from a previous run (presumably fold 0, going by the `_0` suffix).
# NOTE(review): hard-coded path to a dated log folder - this cell fails on a machine
# that does not have that run.
model = load_model_weights(model, "../logs/2023-09-21/23/tf_efficientnetv2_s_0.pt")

In [None]:
# Move the sample to the GPU and add a leading batch axis. `img` is the sample left over
# from the plotting cell in the "Data" section.
x = img.cuda().unsqueeze(0)  # .repeat(2, 1, 1, 1)

In [None]:
# Move the model to the GPU so it matches the input tensor `x`.
model = model.cuda()

In [None]:
# Inference-only sanity check: switch to eval mode so layers such as batch norm and
# dropout behave deterministically (and running statistics are not updated by this
# single-sample batch), and skip autograd bookkeeping since nothing is back-propagated.
model.eval()
with torch.no_grad():
    mask_pred, pred = model(x)

# Shapes of the segmentation output and of the second (auxiliary) output.
mask_pred.size(), pred.size()

In [None]:
# pred = (mask_pred.sigmoid() > 0.5)
# pred = torch.where(pred.amax(1) > 0, pred.int().argmax(1) + 1, 0)

In [None]:
# Predicted class index per position: argmax over dim 1 (presumably the class/channel
# axis), taken for the first element of the batch and converted to a numpy int array.
msk = mask_pred.argmax(1).cpu().numpy()[0].astype(int)
# pred = torch.where(pred.amax(1) > 0, pred.int().argmax(1) + 1, 0)

In [None]:
# Number of positions assigned to each class index in `msk`.
np.bincount(msk.flatten())

In [None]:
# Overlay the predicted mask on `img_` (the image channel extracted in the sample-plotting cell).
plot_mask(img_, msk)

In [None]:
# Shape of the segmentation output (already displayed by the forward-pass cell).
mask_pred.size()

In [None]:
# Loss object used for the sanity checks below: "ce" main loss, "bce" auxiliary loss
# with weight 0, and 6 classes.
# NOTE(review): 6 classes here vs num_classes=5 passed to define_model - confirm.
loss = SegLoss({"name": "ce", "name_aux": "bce", "aux_loss_weight": 0., "num_classes": 6})

In [None]:
# Shape of the segmentation output (repeat of the earlier shape check).
mask_pred.size()

In [None]:
# Shape of the segmentation output (repeat of the earlier shape check).
mask_pred.size()

In [None]:
import torch.nn.functional as F

# One-hot encode the integer mask over 6 classes (new trailing axis), then move that
# class axis from position 3 to position 1.
mask_one_hot = F.one_hot(mask.long(), num_classes=6).float()
p = mask_one_hot.movedim(3, 1)

In [None]:
# Affine rescale of `p`: with a = 10, entries equal to 1 become -10 and entries equal
# to 0 become -20.
# presumably this turns the one-hot target into confident "logits" for the loss check - confirm
# NOTE: `p` is reassigned from itself, so re-running this cell without re-running the
# one-hot cell shifts the values again.
a = 10
p = p * a - 2*a

In [None]:
# Loss of the model outputs against the sample's targets. The model outputs are moved
# to the CPU, where `mask` and `y` (straight from the dataset) live; unsqueeze(0) adds
# the batch axis to the targets.
loss(
    mask_pred.cpu().float(),
    pred.cpu(),
    mask.unsqueeze(0),  # .repeat(2, 1, 1, 1),
    y.unsqueeze(0),  # .repeat(2, 1)
)

In [None]:
# Same loss call, with the tensor `p` derived from the ground-truth mask passed in
# place of the model's segmentation output.
loss(
    p.cpu().float(),
    pred.cpu(),
    mask.unsqueeze(0),  # .repeat(2, 1, 1, 1),
    y.unsqueeze(0),  # .repeat(2, 1)
)

## Training

In [None]:
class Config:
    """
    Parameters used for training.

    Plain class-attribute container: it is passed as-is to `k_fold` and, outside
    DEBUG mode, to `save_config`. `folds_file` and `data_config["num_classes"]` are
    derived from `k` and `num_classes` when the class body is evaluated.
    """
    # General
    seed = 42
    verbose = 1  # defined once here (the duplicate further down was removed)
    device = "cuda"
    save_weights = True

    # Data
    resize = (512, 512)
    aug_strength = 3
    # NOTE(review): True although k_fold is imported from training.main_seg - confirm.
    for_classification = True

    # k-fold
    k = 4
    folds_file = f"../input/folds_{k}.csv"
    selected_folds = [0]  # , 1, 2, 3]

    # Model
    name = "tf_efficientnetv2_s"
    pretrained_weights = None

    num_classes = 5
    num_classes_aux = 0
    drop_rate = 0
    drop_path_rate = 0
    n_channels = 3
    reduce_stride = False
    replace_pad_conv = False
    use_gem = True

    # Training
    loss_config = {
        "name": "bce",
        "smoothing": 0,
        "activation": "sigmoid",
        "aux_loss_weight": 0,
        "name_aux": "patient",
        "smoothing_aux": 0,
        "activation_aux": "",
        "ousm_k": 0,  # todo ?
    }

    data_config = {
        "batch_size": 16,
        "val_bs": 16,
        "mix": "mixup",
        "mix_proba": 0.,
        "mix_alpha": 4.,
        "additive_mix": False,
        "num_classes": num_classes,
        "num_workers": 8,
    }

    optimizer_config = {
        "name": "AdamW",
        "lr": 5e-4,
        "warmup_prop": 0.,
        "betas": (0.9, 0.999),
        "max_grad_norm": 10.,
        "weight_decay": 0.,
    }

    epochs = 1

    use_fp16 = True
    verbose_eval = 100

    fullfit = False
    n_fullfit = 1

    # Distributed training (single-process defaults)
    local_rank = 0
    distributed = False
    world_size = 1

In [None]:
# With DEBUG = True the next cell skips log-folder creation, config dump and file
# logging, and `k_fold` receives log_folder=None and run=None.
DEBUG = True
log_folder = None
run = None

In [None]:
# Outside DEBUG mode: create a log folder, dump the Config there and start a file logger.
# `prepare_log_folder`, `LOG_PATH`, `save_config` and `create_logger` are not imported
# explicitly in this notebook; presumably they come from the star imports - verify.
if not DEBUG:
    log_folder = prepare_log_folder(LOG_PATH)
    print(f"Logging results to {log_folder}")
    # `config_df` is not used again in the visible cells.
    config_df = save_config(Config, log_folder + "config.json")
    create_logger(directory=log_folder, name="logs.txt")

# NOTE(review): `df_seg` was narrowed to fold 0 and then thresholded in the "Data"
# section for inspection; confirm that this filtered frame is what `k_fold` should
# receive rather than the full output of `prepare_seg_data`.
preds = k_fold(Config, df_seg, log_folder=log_folder, run=run)

Done!