**About**: This notebook prepares the training data. It inspects the annotations and masks, then generates the mmdet metadata CSV.


In [None]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell executes, so edits to the project's .py files are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
cd ../src/

## Initialization

### Imports

In [None]:
import os
from collections import Counter
from multiprocessing import Pool

import cv2
import pycocotools
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

In [None]:
from params import *
from utils.rle import rles_to_mask_fix
from utils.plots import plot_sample
from data.preparation import prepare_mmdet_data

## Data

In [None]:
# Load the raw training table. DATA_PATH is provided by `from params import *`
# and is used as a plain string prefix, so it must end with a path separator.
df = pd.read_csv(f"{DATA_PATH}train.csv")

In [None]:
# Collapse the table to one row per image id: every other column becomes a
# Python list holding that image's values, and `id` is restored as a column.
# NOTE: this overwrites `df` in place, so re-running the cell without reloading
# the CSV first would wrap the values in another level of lists.
df = df.groupby(by='id').aggregate(list).reset_index(drop=False)

In [None]:
def _collapse_unique(values):
    """Reduce a list of per-annotation values to its distinct values.

    Args:
        values: Sequence of values gathered for one image.

    Returns:
        The single distinct value when all entries agree, otherwise the sorted
        array of distinct values returned by np.unique.
    """
    # np.unique is evaluated once here (the original lambda called it up to
    # three times per value).
    uniques = np.unique(values)
    return uniques[0] if len(uniques) == 1 else uniques


# The first two columns are left as lists: column 0 is `id`, and column 1 is
# presumably `annotation` (the per-cell RLE strings) - confirm against train.csv.
for col in df.columns[2:]:
    df[col] = df[col].apply(_collapse_unique)

In [None]:
# Preview the first five image-level rows (5 is the pandas default for head).
df.head(n=5)

In [None]:
# Bar chart of the number of images per cell type.
sns.countplot(data=df, x='cell_type')
plt.show()

In [None]:
# Visual check: among the first 25 images, display those labelled "astro"
# together with their decoded masks.
for idx in range(min(25, len(df))):
    # Filter first so images of other cell types are never read from disk.
    if df['cell_type'][idx] != "astro":
        continue

    img_id = df['id'][idx]
    img_path = TRAIN_IMG_PATH + img_id + ".png"
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising;
        # fail here with the offending path instead of on `img.shape` below.
        raise FileNotFoundError(f"Could not read image: {img_path}")

    rles = df['annotation'][idx]
    mask = rles_to_mask_fix(rles, img.shape[:2], single_channel=False, fix=True)

    # Optional per-image correction file: when one exists, keep only the mask
    # pixels where its channel 2 (red, in OpenCV's BGR order) is non-zero.
    # `[None]` adds a leading axis so the 2D filter broadcasts over `mask`
    # (presumably shaped (n_cells, H, W) - confirm in utils.rle).
    mask_fix = cv2.imread(HCK_FIX_PATH + img_id + ".png")
    if mask_fix is not None:
        mask = mask * (mask_fix[:, :, 2] > 0)[None]

    plt.figure(figsize=(15, 10))
    plot_sample(img, mask, width=1)
    plt.axis(False)
    plt.title(f"{img_id} - {df['cell_type'][idx]}")
    plt.show()

## Generation

### Test

In [None]:
# Smoke test of prepare_mmdet_data on the first images only. The loop leaves
# the last `idx`, `img`, `masks` and `meta` in the namespace; the plotting
# cell that follows reads them.
metas = []

for idx in tqdm(range(len(df))):
    image_file = f"{TRAIN_IMG_PATH}{df['id'][idx]}.png"
    img = cv2.imread(image_file)
    masks, meta = prepare_mmdet_data(df, idx)
    metas += [meta]

    # Stop as soon as index 1 has been processed (indices 0 and 1 are run).
    if idx >= 1:
        break

In [None]:
# Show the last image handled by the smoke-test loop; this cell relies on the
# `img`, `masks`, `meta` and `idx` variables that loop left behind.
# `masks.max(0)` reduces over the first axis (presumably one mask per cell).
merged_mask = masks.max(0)
bboxes = meta['ann']['bboxes']

plt.figure(figsize=(15, 10))
plot_sample(img, merged_mask, bboxes, width=1)
plt.axis(False)
plt.title(f"{df['id'][idx]} - {df['cell_type'][idx]}")
plt.show()

### Run

In [None]:
# Full run: build the mmdet metadata of every training image in parallel and
# save it as a CSV in OUT_PATH.

# Forwarded to prepare_mmdet_data(fix=...) and used to pick the output file name.
FIX = True


def _unique_or_scalar(values):
    """Return the single distinct value of `values`, or all distinct values if several."""
    uniques = np.unique(values)
    return uniques[0] if len(uniques) == 1 else uniques


# Rebuild the image-level dataframe from the CSV so this cell does not depend
# on whatever state the exploration cells left in `df`.
df = pd.read_csv(DATA_PATH + "train.csv")
df = df.groupby('id').agg(list).reset_index()
for col in df.columns[2:]:  # the first two columns are kept as lists
    df[col] = df[col].apply(_unique_or_scalar)


def prepare_mmdet_data_(i):
    """Single-argument wrapper around prepare_mmdet_data, as required by Pool.imap.

    Args:
        i: Row index in the module-level `df`.

    Returns:
        Whatever prepare_mmdet_data returns for that row (unpacked below as a
        pair whose second item is the metadata record).
    """
    return prepare_mmdet_data(df, idx=i, fix=FIX)


# The pool is created only now, after `df`, `FIX` and `prepare_mmdet_data_`
# exist. With the fork start method, workers are a snapshot of this process
# taken when the pool is created: creating it earlier leaves them without the
# worker function on a fresh kernel, or with stale data on a re-run.
# The `with` block also shuts the workers down once all results are collected.
metas = []
with Pool(processes=4) as p:
    for _, meta in tqdm(p.imap(prepare_mmdet_data_, range(len(df))), total=len(df)):
        metas.append(meta)

meta_df = pd.DataFrame.from_dict(metas)

# Only save a complete result: one metadata row per image.
if len(meta_df) == len(df):
    if FIX:
        meta_df.to_csv(OUT_PATH + "mmdet_data_fix.csv", index=False)
    else:
        meta_df.to_csv(OUT_PATH + "mmdet_data_nofix.csv", index=False)
else:
    print(f"Not saving: got {len(meta_df)} metadata rows for {len(df)} images.")