In [None]:
import pandas as pd
import numpy as np
import os
from IPython.display import display
from kaggle_datasets import KaggleDatasets

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.applications.xception import Xception

from sklearn.model_selection import GroupKFold

In [None]:
def auto_select_accelerator():
    """Choose a ``tf.distribute`` strategy, preferring a TPU when one resolves.

    The function tries to resolve a TPU cluster, connect to it, initialise the
    TPU system and wrap it in a ``TPUStrategy``. If any of those steps raises
    ``ValueError`` it falls back to ``tf.distribute.get_strategy()``.

    Returns:
        The selected strategy object. Its replica count is printed before
        returning.
    """
    try:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        selected = tf.distribute.experimental.TPUStrategy(resolver)
        print("Running on TPU:", resolver.master())
    except ValueError:
        # TPU setup failed: use whatever strategy is currently active.
        selected = tf.distribute.get_strategy()

    replica_count = selected.num_replicas_in_sync
    print(f"Running on {replica_count} replicas")
    return selected

In [None]:
# Select the distribution strategy once; CFG.batch_size below scales with its replica count.
strategy = auto_select_accelerator()

In [None]:
class CFG:
    """Notebook-wide settings: data location, training schedule and targets."""

    # Set to True for a quick run (batch of 8, 2 epochs).
    debug = False

    # Dataset root; file names are appended directly, so the trailing slash matters.
    dataset_dir = "../input/ranzcr-clip-catheter-line-classification/"

    # Outside debug mode the global batch is 16 samples per replica.
    batch_size = strategy.num_replicas_in_sync * 16 if not debug else 8
    n_epochs = 20 if not debug else 2
    n_folds = 10

    # Target size handed to the image decoder.
    input_shape = (299, 299)

    # Label columns read from train.csv.
    target_cols = [
        'ETT - Abnormal',
        'ETT - Borderline',
        'ETT - Normal',
        'NGT - Abnormal',
        'NGT - Borderline',
        'NGT - Incompletely Imaged',
        'NGT - Normal',
        'CVC - Abnormal',
        'CVC - Borderline',
        'CVC - Normal',
        'Swan Ganz Catheter Present',
    ]

In [None]:
# Training metadata CSV read from the dataset directory configured in CFG.
train=pd.read_csv(f"{CFG.dataset_dir}train.csv")

# GCS path of the same Kaggle dataset; image paths are built from it further down.
GCS_DS_PATH = KaggleDatasets().get_gcs_path("ranzcr-clip-catheter-line-classification")

In [None]:
def build_decoder(with_labels=True, target_size=(256, 256), ext='jpg'):
    """Create a function that reads an image file into a resized float tensor.

    Args:
        with_labels: when true, the returned function takes ``(path, label)``
            and passes ``label`` through unchanged; otherwise it takes only
            ``path``.
        target_size: size forwarded to ``tf.image.resize``.
        ext: image format selecting the decode op: 'png', 'jpg' or 'jpeg'.

    Returns:
        The decoding callable (with or without label pass-through).

    Raises:
        ValueError: raised by the returned callable, not by this factory,
            when ``ext`` is not one of the supported formats.
    """
    def _load_image(path):
        raw = tf.io.read_file(path)

        if ext == 'png':
            pixels = tf.image.decode_png(raw, channels=3)
        elif ext == 'jpg' or ext == 'jpeg':
            pixels = tf.image.decode_jpeg(raw, channels=3)
        else:
            raise ValueError("Image extension not supported")

        # Cast to float32 and divide by 255 before resizing.
        scaled = tf.cast(pixels, tf.float32) / 255.0
        return tf.image.resize(scaled, target_size)

    if not with_labels:
        return _load_image

    def _load_image_and_label(path, label):
        return _load_image(path), label

    return _load_image_and_label


def augment_with_labels(img,label):
    """Apply a random left-right flip, then a random up-down flip, to ``img``.

    ``label`` is returned unchanged alongside the augmented image.
    """
    flipped = tf.image.random_flip_up_down(tf.image.random_flip_left_right(img))
    return flipped, label
    

def build_dataset(paths, labels=None, bsize=32, cache=True,
                  decode_fn=None, augment_fn=None,
                  augment=True, repeat=True, shuffle=1024,
                  cache_dir="cache"):
    """Assemble a batched ``tf.data`` input pipeline from image paths.

    Pipeline order: slice -> decode -> cache -> augment -> repeat -> shuffle
    -> batch -> prefetch. Each optional stage is skipped when its flag is falsy.

    Args:
        paths: image file paths.
        labels: optional labels aligned with ``paths``; when given, elements
            are ``(image, label)`` pairs, otherwise images only.
        bsize: batch size.
        cache: when truthy, decoded images are cached in memory with
            ``Dataset.cache()``; when falsy no caching is done.
        decode_fn: callable mapping a slice to a decoded element. Defaults to
            ``build_decoder(labels is not None)``.
        augment_fn: callable applied after caching when ``augment`` is truthy.
            Defaults to ``augment_with_labels`` (adapted to image-only
            elements when ``labels`` is None).
        augment: whether to apply ``augment_fn``.
        repeat: whether to repeat the dataset indefinitely.
        shuffle: shuffle buffer size; falsy disables shuffling.
        cache_dir: kept for backward compatibility only. The cache is held in
            memory, so this value is ignored and no directory is created.

    Returns:
        The batched, prefetched ``tf.data.Dataset``.
    """
    has_labels = labels is not None

    if decode_fn is None:
        decode_fn = build_decoder(has_labels)

    if augment_fn is None:
        if has_labels:
            augment_fn = augment_with_labels
        else:
            # augment_with_labels expects (img, label); image-only datasets
            # yield a single element, so adapt the call and drop the label.
            augment_fn = lambda img: augment_with_labels(img, None)[0]

    AUTO = tf.data.experimental.AUTOTUNE
    slices = (paths, labels) if has_labels else paths

    dset = tf.data.Dataset.from_tensor_slices(slices)
    dset = dset.map(decode_fn, num_parallel_calls=AUTO)

    # Cache before augmentation so random flips are re-drawn on every pass.
    if cache:
        dset = dset.cache()
    if augment:
        dset = dset.map(augment_fn, num_parallel_calls=AUTO)
    if repeat:
        dset = dset.repeat()
    if shuffle:
        dset = dset.shuffle(shuffle)

    return dset.batch(bsize).prefetch(AUTO)

In [None]:
# GCS path of each training image, built from the StudyInstanceUID column.
paths = f"{GCS_DS_PATH}/train/" + train['StudyInstanceUID'] + '.jpg'
# Target columns listed in CFG.target_cols, in that order.
labels = train[CFG.target_cols]

In [None]:
def get_fold(train):
    """Return a copy of ``train`` with an integer ``folds`` column.

    Rows are split with ``GroupKFold`` into ``CFG.n_folds`` folds, grouped by
    the ``PatientID`` column. The ``folds`` value of a row is the number of the
    split in which that row is part of the validation set.

    Args:
        train: DataFrame containing a ``PatientID`` column. Any index is
            accepted; the input frame is not modified.

    Returns:
        A new DataFrame with the same rows and index plus the ``folds`` column.
    """
    fold = train.copy()
    splitter = GroupKFold(n_splits=CFG.n_folds)

    # split() yields positional indices, so collect the assignments in a
    # positional array rather than writing through label-based .loc, which is
    # only correct when the frame has a default RangeIndex.
    fold_ids = np.full(len(train), -1, dtype=int)
    for n, (_, val_idx) in enumerate(splitter.split(train, groups=train["PatientID"])):
        fold_ids[val_idx] = n

    fold["folds"] = fold_ids
    return fold

# Copy of `train` carrying a per-row fold id in the "folds" column.
fold=get_fold(train)

In [None]:
# Fold 0 is held out for validation; every other fold is used for training.
val_idx = fold["folds"] == 0
train_idx = ~val_idx

train_paths, valid_paths = paths[train_idx], paths[val_idx]
train_labels, valid_labels = labels[train_idx], labels[val_idx]

In [None]:
# One decoder shared by both pipelines, resizing to the configured input size.
decoder = build_decoder(with_labels=True, target_size=CFG.input_shape)

# Training pipeline: build_dataset defaults apply (augment, repeat, shuffle).
train_dataset = build_dataset(
    paths=train_paths,
    labels=train_labels,
    bsize=CFG.batch_size,
    decode_fn=decoder,
)

# Validation pipeline: a single ordered pass with no augmentation.
valid_dataset = build_dataset(
    paths=valid_paths,
    labels=valid_labels,
    bsize=CFG.batch_size,
    decode_fn=decoder,
    augment=False,
    repeat=False,
    shuffle=False,
)

In [None]:
# Build the network inside the distribution strategy's scope so its variables
# are created under the strategy selected earlier (TPU replicas when present).
# Created outside the scope, the model would not be tied to that strategy.
with strategy.scope():
    # ImageNet-pretrained Xception backbone with global average pooling.
    xception = Xception(
        include_top=False,
        weights="imagenet",
        input_shape=(*CFG.input_shape, 3),
        pooling="avg",
    )
    # One sigmoid output per target column (multi-label classification head).
    xception_dense = keras.Sequential([
        xception,
        keras.layers.Dense(len(CFG.target_cols), activation="sigmoid"),
    ])
xception_dense.summary()

In [None]:
# Adam optimiser, binary cross-entropy loss, and a multi-label AUC metric
# named "auc" (reported as "val_auc" on the validation data).
adam = keras.optimizers.Adam(learning_rate=1e-3)
xception_dense.compile(
    optimizer=adam,
    loss="binary_crossentropy",
    metrics=[keras.metrics.AUC(multi_label=True, name="auc")],
)

# Save options that route SavedModel I/O through the local host.
save_locally = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')

# All three callbacks watch validation AUC, where higher is better.
fit_callbacks = [
    # Stop after 8 epochs without improvement.
    keras.callbacks.EarlyStopping(monitor="val_auc", patience=8, mode="max"),
    # Halve the learning rate after 3 stagnant epochs, down to 1e-7.
    keras.callbacks.ReduceLROnPlateau(
        monitor="val_auc",
        patience=3,
        min_lr=1e-7,
        mode="max",
        factor=0.5,
        verbose=1,
    ),
    # Keep only the best-scoring model under "ckpt".
    keras.callbacks.ModelCheckpoint(
        "ckpt",
        monitor="val_auc",
        mode="max",
        save_best_only=True,
        options=save_locally,
    ),
]

In [None]:
# The training dataset repeats indefinitely, so the epoch length must be given
# explicitly: the number of full batches in the training split.
steps_per_epoch = len(train_paths) // CFG.batch_size

history = xception_dense.fit(
    train_dataset,
    validation_data=valid_dataset,
    epochs=CFG.n_epochs,
    steps_per_epoch=steps_per_epoch,
    callbacks=fit_callbacks,
    verbose=1,
)

In [None]:
# Save the model as it stands after the last epoch (the best-scoring one is
# already written to "ckpt" by ModelCheckpoint). `models` was never imported,
# so the function is reached through the imported `keras` namespace.
keras.models.save_model(xception_dense,"last_ckpt",options=save_locally)

In [None]:
# Per-epoch values recorded by fit(), written out as a CSV file.
hist_df = pd.DataFrame(history.history)
hist_df.to_csv('history.csv')