In [None]:
# Enable IPython's autoreload extension. Mode 2 reloads all imported modules
# before each cell executes, so edits to the project's .py files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
cd ../src

In [None]:
import os
import cv2
import json
import glob
import torch
import pydicom
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torch.nn.functional as F

from collections import Counter
from tqdm.notebook import tqdm


# Widen pandas' console rendering so long paths stay readable in cell outputs.
for option_name, option_value in {"display.width": 500, "max_colwidth": 100}.items():
    pd.set_option(option_name, option_value)

In [None]:
from params import *
from data.preparation import *
from data.dataset import *
from data.transforms import *

### External data

- https://www.kaggle.com/datasets/brendanartley/lumbar-coordinate-pretraining-dataset

In [None]:
# Load the external keypoint annotations, attach the image path of each row,
# then collapse to one row per image: after the groupby, every remaining
# column holds a list ordered by level (because of the preceding sort).
df = (
    pd.read_csv(DATA_PATH + "coords/coords_pretrain.csv")
    .assign(
        img_path=lambda frame: (
            DATA_PATH
            + "coords/data/processed_"
            + frame["source"]
            + "_jpgs/"
            + frame["filename"]
        )
    )
    .sort_values(["source", "filename", "level"])
    .groupby(["source", "filename", "img_path"])
    .agg(list)
    .reset_index()
)

In [None]:
# Visual check: draw the annotated levels on top of one external image.
idx = 0
sample = df.loc[idx]
img = cv2.imread(sample["img_path"], cv2.IMREAD_GRAYSCALE)

fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(img, cmap="gray")
for px, py, level in zip(sample["x"], sample["y"], sample["level"]):
    # The leading "x" of the label marks the keypoint position itself.
    ax.text(px, py, f"x   {level}", c="r", horizontalalignment="left", size=12)
ax.axis(False)
plt.show()

### Competition data

In [None]:
# Switches read by the export cells below.
SAVE = False  # when True, the PNG slices and the coords_comp.csv table are written to disk
PLOT = True  # when True, show the processed slice with its keypoints (preview mode: the loop below may stop early)

# Folder receiving the exported PNG slices. It is created here regardless of
# the value of SAVE.
SAVE_FOLDER = "../input/coords/comp_data/"
os.makedirs(SAVE_FOLDER, exist_ok=True)

In [None]:
# Build the table returned by prepare_data(), drop every row containing a
# missing value, and renumber the rows from 0 (the loop below indexes by
# position). This rebinds `df`, replacing the external-data frame above.
df = prepare_data().dropna(axis=0).reset_index(drop=True)

In [None]:
def _normalize_to_uint8(img, upper_percentile=98):
    """
    Clip an image to its [0, upper_percentile] percentile range and rescale it to uint8.

    Args:
        img (np.ndarray): Image to normalize.
        upper_percentile (float, optional): Upper clipping percentile. Defaults to 98.

    Returns:
        np.ndarray: uint8 image with the same shape as the input. A constant
            image maps to all zeros instead of dividing by zero.
    """
    # NOTE(review): the original author tagged this clipping "DO NOT FORGET" -
    # presumably it must be replicated wherever these images are consumed; confirm.
    img = np.clip(img, np.percentile(img, 0), np.percentile(img, upper_percentile))

    value_range = img.max() - img.min()
    if value_range == 0:
        # Min-max scaling would yield NaNs, which become garbage after the uint8 cast.
        return np.zeros(img.shape, dtype=np.uint8)

    img = (img - img.min()) / value_range
    return (img * 255).astype(np.uint8)


# One coordinate frame per exported (non-axial) series; concatenated in the next cell.
dfs = []
preview_shown = False

for idx in tqdm(range(len(df))):
    if df.at[idx, "orient"] == "Axial":
        continue

    # Only the middle slice of the stored array is exported.
    volume = np.load(df.at[idx, "img_path"])
    img = _normalize_to_uint8(volume[len(volume) // 2])

    # One row per level: coordinates sharing a level are averaged.
    df_coords = (
        pd.DataFrame(
            df.at[idx, "coords"], index=df.at[idx, "level"], columns=["z", "x", "y"]
        )
        .reset_index()
        .groupby("index")
        .mean()
        .reset_index()
        .rename(columns={"index": "level"})
        .sort_values("level", ignore_index=True)
    )

    # Coordinates as fractions of the image width / height.
    df_coords["relative_x"] = df_coords["x"] / img.shape[1]
    df_coords["relative_y"] = df_coords["y"] / img.shape[0]

    for col in ["study_id", "series_id", "orient", "weighting"]:
        df_coords[col] = df.at[idx, col]

    img_path = SAVE_FOLDER + f'{df.at[idx, "study_id"]}_{df.at[idx, "series_id"]}.png'
    df_coords["img_path"] = img_path
    dfs.append(df_coords)

    if SAVE:
        # Use the local path: reading it back from df_coords fails when the frame is empty.
        cv2.imwrite(img_path, img)

    if PLOT and not preview_shown:
        plt.figure(figsize=(8, 8))
        plt.imshow(img, cmap="gray")
        for px, py, level in zip(df_coords["x"], df_coords["y"], df_coords["level"]):
            plt.text(px, py, f"x   {level}", c="r", horizontalalignment="left", size=12)
        plt.show()
        preview_shown = True

        # Preview-only runs stop after the first figure. When SAVE is on, keep
        # going so the export is not silently truncated to a single series.
        if not SAVE:
            break

In [None]:
# Write the stacked per-series keypoints next to the exported images.
if SAVE:
    export_columns = [
        "study_id",
        "series_id",
        "img_path",
        "level",
        "x",
        "y",
        "relative_x",
        "relative_y",
    ]
    df_ = pd.concat(dfs)[export_columns]
    df_.to_csv("../input/coords/coords_comp.csv", index=False)
    display(df_.head(1))

### Dataset

In [None]:
from data.dataset import CoordsDataset
from data.preparation import prepare_coords_data

# Coordinate table consumed by the dataset (rebinds `df_coords` from the export loop).
df_coords = prepare_coords_data()

# Keypoint-aware transforms resizing to 384x384; strength=0 presumably disables
# augmentation - confirm in data.transforms.
coords_transforms = get_transfos(resize=(384, 384), strength=0, use_keypoints=True)

dataset = CoordsDataset(df_coords, transforms=coords_transforms)

In [None]:
# Smoke test: the target of a sample must be a (5, 2) tensor.
# The loop exits after its first iteration, so only dataset[0] is checked.
expected_size = torch.Size([5, 2])
for idx in tqdm(range(len(dataset))):
    x, y, _ = dataset[idx]
    assert y.size() == expected_size
    break

In [None]:
# Pick one dataset index uniformly at random (no seed is set in this notebook,
# so the sample differs between runs).
idx = np.random.choice(len(dataset))

# x is the image and y the keypoint target; the third returned item is unused here.
x, y, _  = dataset[idx]

In [None]:

# Keep only the keypoints whose coordinates sum to a positive value
# (presumably missing levels are encoded as non-positive rows - confirm in
# CoordsDataset). New names are used instead of rebinding `y`, so re-running
# this cell does not rescale already-scaled points.
keypoints = y[y.sum(-1) > 0]

# First channel of the sample, min-max scaled for display.
img = x[0]
img = (img - img.min()) / (img.max() - img.min())

# Targets are relative: column 0 is drawn on the horizontal axis (scale by the
# width), column 1 on the vertical axis (scale by the height).
height, width = img.shape[0], img.shape[1]
keypoints_x = keypoints[:, 0] * width
keypoints_y = keypoints[:, 1] * height

plt.figure(figsize=(10, 10))
plt.imshow(img, cmap="gray")
plt.scatter(keypoints_x, keypoints_y)
plt.show()

Done!