In [1]:
import config
import numpy as np 
import pandas as pd 
import tensorflow as tf
import os
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from skimage.io import imread
from data_preparation import train_df, valid_df
from rle_and_mask_related import make_image_gen
from model import unet, callbacks_list, dg_args
import losses
# Sample data for this run, both read from config.TRAIN_DIR:
#   * train_x / train_y: one batch from a plain training generator (handy for
#     inspecting shapes; the fit() call further down builds its own augmented
#     generator instead).
#   * valid_x / valid_y: a single validation batch sized by
#     config.VALID_IMG_COUNT (presumably make_image_gen's batch-size argument),
#     reused as the fixed validation set.
train_gen = make_image_gen(train_df, config.BATCH_SIZE, config.IMG_SCALING, config.TRAIN_DIR)
train_x, train_y = next(train_gen)
valid_x, valid_y = next(
    make_image_gen(valid_df, config.VALID_IMG_COUNT, config.IMG_SCALING, config.TRAIN_DIR)
)

# Two independent augmenters built from the same dg_args: image_gen for the
# images, label_gen for the masks. create_aug_gen drives both with a shared
# per-batch seed so their random transforms stay aligned.
image_gen, label_gen = (
    tf.keras.preprocessing.image.ImageDataGenerator(**dg_args) for _ in range(2)
)

def gen_pred(test_dir, img, model):
    """Run ``model`` on a single image file and return the image with its prediction.

    Args:
        test_dir: Directory containing the image.
        img: File name of the image, joined onto ``test_dir``.
        model: Object with a Keras-style ``predict`` method that accepts a
            batch of images (leading batch axis).

    Returns:
        ``(image, pred)``: the image array as loaded by ``cv2.imread`` and the
        model's prediction with the leading batch axis removed.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read the file (OpenCV
            returns ``None`` instead of raising).
    """
    # cv2 is not imported at the top of the notebook; import it here so the
    # function does not fail with a NameError.
    import cv2

    # Build the path from the test_dir argument (the previous code read an
    # undefined global TEST_DIR).
    rgb_path = os.path.join(test_dir, img)
    image = cv2.imread(rgb_path)
    if image is None:
        raise FileNotFoundError(f"cv2.imread could not read image: {rgb_path}")

    # NOTE(review): cv2.imread yields BGR pixels in 0-255, while the training
    # augmentation (create_aug_gen) appears to work on inputs in [0, 1] (it
    # multiplies by 255 and divides back). Confirm that channel order and scaling
    # match what make_image_gen feeds the model during training.
    batch = np.expand_dims(image, axis=0)  # predict() expects a batch axis
    pred = np.squeeze(model.predict(batch), axis=0)
    # Return the already-loaded array rather than reading the file a second time.
    return image, pred

def create_aug_gen(in_gen, seed = None):
    """Wrap an (images, masks) batch generator with synchronized augmentation.

    The global NumPy RNG is seeded once up front (with ``seed``, or with a
    value drawn from it when ``seed`` is None). For every incoming batch, a fresh
    per-batch seed is drawn and passed to both ``image_gen.flow`` and
    ``label_gen.flow``, so images and masks get the same random transforms.

    Args:
        in_gen: Iterable yielding ``(in_x, in_y)`` batches.
        seed: Optional integer used to seed ``np.random`` before iterating.

    Yields:
        ``(augmented_x, augmented_y)``: images are scaled by 255 before
        augmentation and divided by 255.0 afterwards; masks are passed through
        unscaled.
    """
    start_seed = np.random.choice(range(9999)) if seed is None else seed
    np.random.seed(start_seed)

    for batch_x, batch_y in in_gen:
        # One shared seed per batch keeps image and mask augmentation in lockstep.
        batch_seed = np.random.choice(range(9999))
        flow_options = dict(batch_size=batch_x.shape[0], seed=batch_seed, shuffle=True)

        image_flow = image_gen.flow(255*batch_x, **flow_options)
        mask_flow = label_gen.flow(batch_y, **flow_options)

        augmented_x = next(image_flow)
        augmented_y = next(mask_flow)
        yield augmented_x/255.0, augmented_y
# Run hyperparameters and export location (previously inline literals).
LEARNING_RATE = 1e-3
EPOCHS = 20
FULLRES_MODEL_PATH = '/kaggle/working/fullres_model & weights/fullres_model.h5'

seg_model = unet()

# Augmented training stream: fresh make_image_gen batches passed through the
# paired image/mask augmentation of create_aug_gen.
aug_gen = create_aug_gen(make_image_gen(train_df, config.BATCH_SIZE, config.IMG_SCALING, config.TRAIN_DIR))

# NOTE(review): `losses.FocalLoss` is passed without being called; if it is a
# class rather than a loss function, it needs to be instantiated — confirm.
seg_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
                  loss=losses.FocalLoss,
                  metrics=[losses.dice_coef, "binary_accuracy"])

seg_model.fit(aug_gen,
              steps_per_epoch=config.MAX_TRAIN_STEPS,
              epochs=EPOCHS,
              validation_data=(valid_x, valid_y),
              callbacks=callbacks_list,
              workers=1)

# After fit(), seg_model holds the weights of the LAST epoch, which are not
# necessarily the best ones. If callbacks_list contains a best-only
# ModelCheckpoint whose file exists, reload it so the exported model uses the
# best checkpointed weights. Templated or missing paths are skipped.
for _callback in callbacks_list:
    if (isinstance(_callback, tf.keras.callbacks.ModelCheckpoint)
            and getattr(_callback, "save_best_only", False)
            and os.path.isfile(_callback.filepath)):
        seg_model.load_weights(_callback.filepath)

# Full-resolution wrapper: average-pool the input by config.IMG_SCALING, run
# seg_model, then upsample by the same factor. The output then matches the input
# size when each spatial dimension is divisible by the scaling factor (assuming
# seg_model preserves its own input size — TODO confirm in model.unet).
fullres_model = tf.keras.models.Sequential()
fullres_model.add(tf.keras.layers.AvgPool2D(config.IMG_SCALING, input_shape = (None, None, 3)))
fullres_model.add(seg_model)
fullres_model.add(tf.keras.layers.UpSampling2D(config.IMG_SCALING))

# Ensure the export directory exists before writing the HDF5 file.
os.makedirs(os.path.dirname(FULLRES_MODEL_PATH), exist_ok=True)
fullres_model.save(FULLRES_MODEL_PATH)


caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io_plugins.so: undefined symbol: _ZN3tsl6StatusC1EN10tensorflow5error4CodeESt17basic_string_viewIcSt11char_traitsIcEENS_14SourceLocationE']
caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io.so: undefined symbol: _ZTVN10tensorflow13GcsFileSystemE']


Epoch 1/20
Epoch 1: val_dice_coef improved from -inf to 0.00897, saving model to /kaggle/working/fullres_model & weights/seg_model_weights.best.hdf5
Epoch 2/20
Epoch 2: val_dice_coef did not improve from 0.00897
Epoch 3/20
Epoch 3: val_dice_coef did not improve from 0.00897
Epoch 4/20
Epoch 4: val_dice_coef improved from 0.00897 to 0.00945, saving model to /kaggle/working/fullres_model & weights/seg_model_weights.best.hdf5
Epoch 5/20
Epoch 5: val_dice_coef improved from 0.00945 to 0.01136, saving model to /kaggle/working/fullres_model & weights/seg_model_weights.best.hdf5
Epoch 6/20
Epoch 6: val_dice_coef improved from 0.01136 to 0.01365, saving model to /kaggle/working/fullres_model & weights/seg_model_weights.best.hdf5
Epoch 7/20
Epoch 7: val_dice_coef improved from 0.01365 to 0.01433, saving model to /kaggle/working/fullres_model & weights/seg_model_weights.best.hdf5
Epoch 8/20
Epoch 8: val_dice_coef improved from 0.01433 to 0.01462, saving model to /kaggle/working/fullres_model & w