In [1]:
# Imports

# Data and ML
import numpy as np
import tensorflow as tf
# import tensorflow_io as tfio
from cloud_tpu_client import Client
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications.inception_v3 import InceptionV3
import pandas as pd

# Image preprocessing
import pydicom
import cv2
import matplotlib.pyplot as plt

# Miscellaneous
from typing import Tuple, List, Callable
import os
import multiprocessing
import random

In [2]:
# Root directory of the input data; "test.csv" is read from here and the
# per-patient DICOM folders live under DATA_DIR / PATIENT_DIR.
DATA_DIR = r"/kaggle/input/osic-pulmonary-fibrosis-progression"
# Sub-directory of DATA_DIR that contains one folder per patient.
PATIENT_DIR = "test"

def pathfinder(root: str = "/kaggle") -> Tuple[str, str]:
    """Locate the data directory by searching breadth-first for a DICOM file.

    The expected layout is ``<data_dir>/<patient_dir>/<patient>/<slice>.dcm``.
    The first directory found that directly contains a ``*.dcm`` file is
    treated as a patient folder.

    Args:
        root: Directory where the search starts. Defaults to ``"/kaggle"``.

    Returns:
        ``(data_dir, patient_dir)``: the path two levels above the patient
        folder, and the name of the folder that holds the patient folders.

    Raises:
        FileNotFoundError: If no suitably nested ``.dcm`` file exists.
    """
    from collections import deque

    # Each queue entry is the list of path components of one directory.
    pending = deque([[root]])

    while pending:
        parts = pending.popleft()
        try:
            # Sorted so the result does not depend on OS listing order.
            entries = sorted(os.listdir(os.path.join(*parts)))
        except OSError:
            # Unreadable or vanished directory: skip it, keep searching.
            continue

        # A usable hit needs at least <data_dir>/<patient_dir>/<patient>;
        # anything shallower cannot be split into the two return values.
        if len(parts) >= 3 and any(name.endswith(".dcm") for name in entries):
            return os.path.join(*parts[:-2]), parts[-2]

        for name in entries:
            child = parts + [name]
            if os.path.isdir(os.path.join(*child)):
                pending.append(child)

    raise FileNotFoundError("NOT FOUND!")
    
    
# DATA_DIR, PATIENT_DIR = pathfinder()

In [3]:
# Shape of one preprocessed slice; unpacked as (w, h, c) where images are
# resized and where the dataset's tensor shape is declared.
TARGET_IMAGE_SHAPE = (224, 224, 3)
# NOTE(review): `(1)` is the plain int 1, not a one-element tuple; write
# `(1,)` if a shape tuple is intended. Not referenced in the visible cells.
OUTPUT_SHAPE = (1)
# Length of the metadata vector built by extract_meta.
META_SHAPE = (10,)
# Number of slices per image subset (passed to file_set_generator).
QUOTA = 12

In [4]:
def clip_border(image: np.ndarray) -> np.ndarray:
    """Crop an image to the bounding box of its non-zero pixels.

    Args:
        image: Array with at least two dimensions; the crop is applied to
            the first two axes.

    Returns:
        A view of ``image`` covering every row and column that contains a
        non-zero value, including the last such row and column.

    Raises:
        ValueError: If the image has no non-zero pixels, so no bounding box
            exists.
    """
    nonzero = np.nonzero(image != 0)
    rows, cols = nonzero[0], nonzero[1]
    if rows.size == 0:
        raise ValueError("clip_border: image contains no non-zero pixels")

    # The max index is inclusive, so the slice stop must be one past it.
    return image[rows.min():rows.max() + 1, cols.min():cols.max() + 1]


def rescale_for_lungs(image: np.ndarray, meta_image: pydicom.FileDataset, scan_range: str = "none") -> np.ndarray:
    """Convert raw pixel values to Hounsfield units and optionally window them.

    Args:
        image: Raw pixel array.
        meta_image: Object providing ``RescaleSlope`` and ``RescaleIntercept``.
        scan_range: ``"low"``, ``"high"`` or ``"lung"`` selects a window;
            any other value returns the unwindowed Hounsfield units.

    Returns:
        The Hounsfield-unit array, linearly mapped so the chosen window's
        lower bound becomes 0 and its upper bound becomes 1. Values outside
        the window are not clipped.
    """
    windows = {
        "low": (-1400, -950),
        "high": (-240, -160),
        "lung": (-1400, -200),
    }

    hounsfield_units = image * meta_image.RescaleSlope + meta_image.RescaleIntercept
    if scan_range not in windows:
        return hounsfield_units

    lower, upper = windows[scan_range]
    return (hounsfield_units - lower) / (upper - lower)


def map_to_unit(image: np.ndarray) -> np.ndarray:
    """Linearly rescale an array so its minimum maps to 0 and its maximum to 1.

    Args:
        image: Non-empty numeric array.

    Returns:
        A floating-point array of the same shape. A constant image, which
        has no range to stretch, maps to all zeros rather than NaN.
    """
    min_val, max_val = np.min(image), np.max(image)
    shifted = image - min_val
    span = max_val - min_val
    if span == 0:
        # Avoid 0/0; multiplying by 0.0 gives a float array of zeros.
        return shifted * 0.0
    return shifted / span


def unified_tone_map(image: np.ndarray) -> np.ndarray:
    """Apply the fixed affine tone curve ``image * 0.1 + 0.2``."""
    gain, bias = 0.1, 0.2
    return image * gain + bias


def preprocess(medical_image: pydicom.FileDataset, apply_tone_map=False, scan_range: str = "none") -> np.ndarray:
    """Crop a DICOM slice, rescale it, and optionally tone-map it.

    Args:
        medical_image: Dataset whose ``pixel_array`` is processed and which
            is also passed to ``rescale_for_lungs`` for the rescale tags.
        apply_tone_map: When true, the result goes through
            ``unified_tone_map``.
        scan_range: Window name forwarded to ``rescale_for_lungs``.

    Returns:
        The processed single-channel image.
    """
    cropped = clip_border(medical_image.pixel_array)
    rescaled = rescale_for_lungs(cropped, medical_image, scan_range)
    return unified_tone_map(rescaled) if apply_tone_map else rescaled


def preprocess_tri_channel(medical_image: pydicom.FileDataset, apply_tone_map=False) -> np.ndarray:
    """Build a three-channel image from the "low", "high" and "lung" windows.

    Each channel is ``preprocess`` run without tone mapping on one window;
    the channels are stacked along axis 2 in that order.

    Args:
        medical_image: Dataset passed to ``preprocess`` for every channel.
        apply_tone_map: When true, ``unified_tone_map`` is applied once to
            the stacked result.

    Returns:
        The stacked (and optionally tone-mapped) image.
    """
    channels = [
        preprocess(medical_image, False, window)
        for window in ("low", "high", "lung")
    ]
    stacked = np.stack(channels, axis=2)
    return unified_tone_map(stacked) if apply_tone_map else stacked

In [5]:
# Load the tabular records that accompany the scans.
df = pd.read_csv(os.path.join(DATA_DIR, "test.csv"))
# Keep only the first row per patient; extract_meta reads its metadata
# (FVC, Percent, Weeks, Age, Sex, SmokingStatus) from this frame.
first_readings = df.drop_duplicates(["Patient"], keep="first")
# Disabled: derivation of the non-baseline rows joined with each patient's
# baseline values. Not used by the cells below.
# remaining_weeks = df[~df.isin(first_readings)].dropna()
# remaining_weeks = remaining_weeks.merge(first_readings[["Patient", "Weeks", "FVC", "Percent"]], 
#                                         left_on="Patient", right_on="Patient", suffixes=("", "_BASE"))
# remaining_weeks = remaining_weeks.drop("Percent", 1).drop_duplicates(["Patient", "Weeks"])

In [6]:
def splitter(l: List[int], quota: int) -> List[List[int]]:
    """Spread the items of ``l`` over subsets of exactly ``quota`` elements.

    The number of subsets is ``ceil(len(l) / quota)``. Items are sampled at
    evenly spaced fractional positions and dealt round-robin, so each subset
    spans the whole list. When ``len(l)`` is not a multiple of ``quota``,
    some items are repeated to fill the subsets.

    Args:
        l: Items to distribute (ordered).
        quota: Number of elements every subset must contain.

    Returns:
        A list of subsets, each of length ``quota``; empty if ``l`` is empty.
    """
    n_items = len(l)
    qty = n_items // quota + (1 if n_items % quota != 0 else 0)
    aux_qty = qty * quota

    subsets = [[] for _ in range(qty)]

    for i in range(aux_qty):
        # Rounding the last fractional positions can land on len(l) when the
        # list is much shorter than the quota; clamp to the final item.
        index = min(round(n_items * i / aux_qty), n_items - 1)
        subsets[i % qty].append(l[index])

    return subsets

def extract_meta(patient):
    """Build the metadata vector for one patient from ``first_readings``.

    Layout of the returned vector (length ``META_SHAPE``):
        [0]    Base FVC
        [1]    Healthy FVC, computed as ``100 * FVC / Percent``
        [2]    Base week
        [3]    Age
        [4:6]  Sex, one-hot: ``[1, 0]`` for "Male", ``[0, 1]`` otherwise
        [6:9]  Smoking status, one-hot over "Ex-smoker", "Never smoked",
               "Currently smokes" (all zeros for any other value)
        [9]    Prediction week; left at 0 here and filled in by the caller

    Args:
        patient: Patient identifier matched against the "Patient" column.

    Returns:
        A float32 tensor holding the metadata vector.

    Raises:
        ValueError: If the patient has no row in ``first_readings``.
    """
    row = first_readings[first_readings["Patient"] == patient]
    if row.empty:
        raise ValueError(f"No baseline reading found for patient {patient!r}")

    def first(column):
        # Read the scalar explicitly: assigning a one-element Series into an
        # array slot relies on implicit conversion that newer NumPy/pandas
        # releases deprecate.
        return row[column].iloc[0]

    meta = np.zeros(META_SHAPE, dtype=np.float32)
    meta[0] = first("FVC")
    meta[1] = 100 * meta[0] / first("Percent")
    meta[2] = first("Weeks")
    meta[3] = first("Age") # TRANSFORM

    meta[4:6] = [1, 0] if first("Sex") == "Male" else [0, 1]

    status = first("SmokingStatus")
    if status == "Ex-smoker":
        meta[6:9] = [1, 0, 0]
    elif status == "Never smoked":
        meta[6:9] = [0, 1, 0]
    elif status == "Currently smokes":
        meta[6:9] = [0, 0, 1]

    meta[9] = 0

    return tf.convert_to_tensor(meta, dtype_hint=tf.float32)


def load_image_sample(patient, sample_set):
    """Load and preprocess a set of DICOM slices for one patient.

    Args:
        patient: Name of the patient folder under DATA_DIR / PATIENT_DIR.
        sample_set: Slice identifiers; slice ``s`` is read from ``"{s}.dcm"``.

    Returns:
        A float32 tensor of shape ``(len(sample_set), w, h, c)`` with
        ``(w, h, c) = TARGET_IMAGE_SHAPE``. A slice that cannot be read or
        processed is left as all zeros and a warning is issued.
    """
    import warnings

    w, h, c = TARGET_IMAGE_SHAPE
    images = np.zeros([len(sample_set), w, h, c], dtype=np.float32)
    for i, sample in enumerate(sample_set):
        path = os.path.join(DATA_DIR, PATIENT_DIR, patient, f"{sample}.dcm")
        try:
            dcm_file = pydicom.dcmread(path)
            # cv2.resize takes dsize as (columns, rows), while images[i] has
            # w rows and h columns, so the order is swapped here. This is
            # identical to (w, h) for the square target shape.
            images[i] = cv2.resize(preprocess_tri_channel(dcm_file, True), (h, w))
        except Exception as e:
            # Best effort: keep the zero image, but make the failure visible.
            warnings.warn(f"Slice {path} could not be loaded, using zeros: {e!r}")
    return tf.convert_to_tensor(images, dtype=tf.float32)
            
            
def file_set_generator(quota: int):
    """Yield model inputs for every slice subset of every patient.

    Each patient folder under DATA_DIR / PATIENT_DIR is listed, its ``.dcm``
    files are sorted by their integer names and divided by ``splitter`` into
    subsets of ``quota`` slices. All subsets of one patient are yielded
    consecutively.

    Args:
        quota: Number of slices per subset.

    Yields:
        ``((meta, image_data), patient)`` where ``meta`` comes from
        ``extract_meta``, ``image_data`` from ``load_image_sample`` and
        ``patient`` is a one-element string tensor with the patient name.
    """
    patient_root = os.path.join(DATA_DIR, PATIENT_DIR)

    all_subsets = []

    for patient in os.listdir(patient_root):
        patient_path = os.path.join(patient_root, patient)
        if not os.path.isdir(patient_path):
            # Stray files next to the patient folders are not patients.
            continue

        # Only DICOM files are slices; anything else would break int().
        files = sorted(
            int(os.path.splitext(name)[0])
            for name in os.listdir(patient_path)
            if name.endswith(".dcm")
        )
        for ss in splitter(files, quota):
            all_subsets.append((patient, ss))

    for patient, ss in all_subsets:
        # Load the image data
        image_data = load_image_sample(patient, ss)
        meta = extract_meta(patient)
        yield (meta, image_data), tf.convert_to_tensor([patient], dtype_hint=tf.string)


def test_set_generator():
    """Return ``file_set_generator`` bound to the module-level ``QUOTA``.

    Provides the zero-argument callable that is handed to
    ``tf.data.Dataset.from_generator``.
    """
    return file_set_generator(QUOTA)

In [7]:
w, h, c = TARGET_IMAGE_SHAPE

# Wrap the generator as a tf.data pipeline. The second and third arguments
# are the element dtypes and shapes, mirroring what file_set_generator yields:
# ((meta, images), patient_name).
# NOTE(review): recent TensorFlow releases prefer `output_signature` over
# positional output types/shapes — confirm against the TF version in use.
dataset = tf.data.Dataset.from_generator(
    test_set_generator,
    ((tf.float32, tf.float32), tf.string),
    ((tf.TensorShape(META_SHAPE), tf.TensorShape([QUOTA, w, h, c])), tf.TensorShape([1]))
).batch(1)  # batch of one subset; adds a leading axis of size 1 to every tensor

In [8]:
# Load the saved Keras model (HDF5 file) from the attached input dataset.
# The path is relative to the notebook's working directory.
model = keras.models.load_model(r"../input/hopefully-good-model/continue_probably_overfit.hdf5")

In [9]:
# Split the loaded model in two so that the part fed by model.input[1] is
# evaluated once per subset and its result reused across prediction weeks.
# fixed_model: model.input[1] -> output of the layer named "lstm_2".
# (process_patient feeds it the image tensors — presumably input[1] is the
# image input; confirm against the saved model.)
fixed_model = keras.Model(
    inputs=model.input[1],
    outputs=model.get_layer("lstm_2").output
)
# dist_model: (model.input[0], the "lstm_2" output) -> model.output.
# The second input is declared by wrapping fixed_model.output in keras.Input.
# NOTE(review): whether keras.Input(tensor=...) on an intermediate tensor is
# supported depends on the Keras version — confirm.
dist_model = keras.Model(
    inputs=[model.input[0], keras.Input(tensor=fixed_model.output)], 
    outputs=model.output
)

In [10]:
# Number of dist_model evaluations per patient and week in process_patient.
samples_per_patient_week = 1
# Not referenced in the visible cells.
seen_patients = []

# Not referenced in the visible cells.
curr_patient = 1

# Name of the patient whose subsets are currently being accumulated.
last_patient = None

# Not referenced in the visible cells.
last_image = None

# Not referenced in the visible cells.
counter = 0

# (images, meta) numpy pairs collected for `last_patient`, one per subset.
LAST_PATIENT_DATA = []

# Inclusive range of weeks for which a prediction row is written.
START_WEEK = -12
END_WEEK = 133


def process_patient(data, patient_name, out_file):
    """Predict every week for one patient and append the rows to ``out_file``.

    For each ``(images, meta)`` subset the image branch (``fixed_model``) is
    evaluated once. The remaining model (``dist_model``) is then evaluated
    for every week in ``START_WEEK..END_WEEK`` with the week written into
    ``meta[0, 9]``. Predictions are averaged over all subsets and over the
    ``samples_per_patient_week`` repetitions.

    Args:
        data: List of ``(images, meta)`` numpy pairs for this patient, each
            with a leading batch axis of size 1. ``meta`` is modified in
            place (its week slot is overwritten).
        patient_name: Identifier used as the prefix of every output row.
        out_file: Writable text file; one ``"{patient}_{week},{fvc},{conf}"``
            line is written per week.
    """
    num_subsets = len(data)
    num_weeks = END_WEEK - START_WEEK + 1

    data_points = np.zeros([num_weeks, samples_per_patient_week], dtype=np.float32)

    for images, meta in data:
        # Use this subset's own images, not whatever global `image` the
        # calling loop currently holds (that belongs to the next patient).
        base_calc = fixed_model(images, training=False)

        for week in range(START_WEEK, END_WEEK + 1):
            meta[0, 9] = week
            # Weeks start below zero; shift to a 0-based row so negative
            # weeks do not wrap around to the end of the array.
            row = week - START_WEEK
            for retry in range(samples_per_patient_week):
                # training=True is kept from the original — presumably to
                # sample with dropout active; TODO confirm.
                data_points[row, retry] += dist_model((meta, base_calc), training=True) / num_subsets

    week_preds = np.mean(data_points, axis=1)
    # Fixed confidence value for every week (the sample standard deviation
    # was disabled in the original).
    week_stds = np.ones([data_points.shape[0]]) * 200 # np.std(data_points, axis=1, ddof=1)

    for week, mean, std in zip(range(START_WEEK, END_WEEK + 1), week_preds, week_stds):
        out_file.write(f"{patient_name}_{week},{mean},{std}\n")
    

# Stream the dataset, group consecutive subsets by patient, and write one
# block of weekly predictions per patient to submission.csv.
with open("submission.csv", "w") as out_file:
    out_file.write("Patient_Week,FVC,Confidence\n")
    
    for ((meta, image), patient) in dataset:
        # Batched shape is (1, 1); pull out the single patient name.
        patient = patient.numpy()[0, 0].decode("utf8")
        
        if last_patient is None:
            last_patient = patient

        patient_data = df[df["Patient"] == patient]

        # A new patient name means the previous patient's subsets are complete.
        if last_patient != patient:
            process_patient(LAST_PATIENT_DATA, last_patient, out_file)
            LAST_PATIENT_DATA = []

        last_patient = patient

        meta_numpy = meta.numpy()

        LAST_PATIENT_DATA.append((image.numpy(), meta_numpy))

    # The loop only flushes when the patient changes, so the final patient
    # must be written here or it would be missing from the submission.
    if LAST_PATIENT_DATA:
        process_patient(LAST_PATIENT_DATA, last_patient, out_file)
        LAST_PATIENT_DATA = []
