In [None]:
# install required dependencies

# !conda install -c conda-forge pillow -y
# !conda install -c conda-forge pydicom -y
# !conda install -c conda-forge gdcm -y
# !pip install pylibjpeg pylibjpeg-libjpeg

# only required for tpu processing, which i don't use

# !pip uninstall tensorflow -y
# !pip uninstall tensorflow-io -y
# !pip install -q tensorflow-gpu
# !pip install -q tensorflow-io 

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

import pickle

from tqdm import tqdm
import glob
import os
import matplotlib.pyplot as plt
import matplotlib.pylab as pylab
import seaborn as sns
import pprint
import pydicom as dicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
import albumentations as A 
import cv2

from PIL import Image

# Input data files are available in the read-only "../input/" directory.
# os.listdir returns entries in arbitrary order, so sort the listing to keep the
# displayed output identical across runs.

import os
sorted(os.listdir('../input'))
# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

## EDA and preprocessing
Preparing the dataset so it can be used by the model.

In [None]:
basepath = '../input/siim-covid19-detection/'

# Study-level and image-level training tables, read in that order.
train_study_df, train_image_df = (
    pd.read_csv(basepath + csv_name)
    for csv_name in ("train_study_level.csv", "train_image_level.csv")
)

train_study_df.head()

In [None]:
train_directory = basepath + "train/"
test_directory =  basepath + "test/"

# merging study and image train dataframes
# Study ids look like "<StudyInstanceUID>_study". Checking the column explicitly
# (rather than a bare try/except) keeps the cell safe to re-run without hiding
# unrelated errors.
if 'id' in train_study_df.columns:
    train_study_df['StudyInstanceUID'] = train_study_df['id'].str.replace('_study', '', regex=False)
    del train_study_df['id']
else:
    assert 'StudyInstanceUID' in train_study_df.columns, 'Something went wrong with the dataframe. Rerun previous cells'
    print('train_study_df[id] already deleted')
train_df = train_image_df.merge(train_study_df, on='StudyInstanceUID')

# adding path to the df
# Each row is one image, so resolve <train>/<study>/*/<image_id>.dcm for that image.
# Globbing on the study alone and taking [0] would give every image of a
# multi-image study the same (arbitrary) file.
training_paths = []
for image_id, sid in tqdm(zip(train_df['id'], train_df['StudyInstanceUID']), total=len(train_df)):
    dicom_name = image_id.replace('_image', '') + '.dcm'
    matches = glob.glob(os.path.join(train_directory, sid, '*', dicom_name))
    if not matches:
        raise FileNotFoundError(f'No DICOM found for {image_id} in study {sid}')
    training_paths.append(matches[0])

train_df['path'] = training_paths

train_df.head()

In [None]:
# (rows, columns) of the merged image/study table
n_rows, n_cols = train_df.shape
n_rows, n_cols

A quick look at how each of the four study-level labels is distributed.

In [None]:
# Only font sizes are set globally. The figure size is passed to this figure alone,
# so later plots in the notebook are not forced to 20x32 inches.
params = {'legend.fontsize': 'x-large',
          'axes.labelsize': 'x-large',
          'axes.titlesize': 'x-large',
          'xtick.labelsize': 'x-large',
          'ytick.labelsize': 'x-large'}
pylab.rcParams.update(params)

# label column -> bar colour, in panel order (row-major over the 2x2 grid)
label_colors = {"Negative for Pneumonia": "#ffb4a2",
                "Typical Appearance": "#e5989b",
                "Indeterminate Appearance": "#b5838d",
                "Atypical Appearance": "#6d6875"}

fig, ax = plt.subplots(2, 2, figsize=(20, 14))
for axis, (label, color) in zip(ax.flat, label_colors.items()):
    sns.countplot(x=train_df[label], ax=axis, color=color)
    axis.set_title(f"{label} Distribution", fontfamily="serif", fontsize=20, weight="bold")

# no top=0.5 here: it confined the panels to the lower half of the figure
fig.subplots_adjust(wspace=0.1, hspace=0.2)
plt.show()

In [None]:
def resize(array, size, resample=Image.NEAREST):
    """Resize a 2-D image array with PIL and return it as a numpy array.

    Original from: https://www.kaggle.com/xhlulu/vinbigdata-process-and-resize-to-image

    Args:
        array: numpy array accepted by ``PIL.Image.fromarray``.
        size: target side length in pixels (square output), or a
            ``(width, height)`` pair.
        resample: PIL resampling filter; nearest-neighbour by default.

    Returns:
        numpy array of the resized image.
    """
    target = (size, size) if np.isscalar(size) else tuple(size)
    # Image.convert() returns a new image, so a bare im.convert('1') call is a no-op.
    # It is deliberately not used: a 1-bit conversion would binarise the X-ray.
    im = Image.fromarray(array)
    im = im.resize(target, resample)
    return np.asarray(im)

In [None]:
# Preprocessing switches read by the DICOM loaders defined below.
voi_lut, fix_monochrome = True, True  # apply the VOI LUT / invert MONOCHROME1 images
size = 224  # side length, in pixels, that dicom_to_pixelarray resizes to

# if you only want the pixel array use this
def dicom_to_pixelarray(filename):
    """Read a DICOM file and return its image as a ``size`` x ``size`` uint8 array.

    Steps:
    1. Apply the VOI LUT when the module-level ``voi_lut`` flag is set, the same
       way ``dicom_dataset_to_dict`` does.
    2. Resize to ``size``.
    3. Invert MONOCHROME1 images when ``fix_monochrome`` is set.
    4. Min-max scale to [0, 255].

    Args:
        filename: path of the DICOM file.

    Returns:
        numpy uint8 array.
    """
    dicom_header = dicom.dcmread(filename)
    if voi_lut:
        data = apply_voi_lut(dicom_header.pixel_array, dicom_header)
    else:
        data = dicom_header.pixel_array
    data = resize(data, size)

    # depending on this value, X-ray may look inverted - fix that:
    if fix_monochrome and dicom_header.PhotometricInterpretation == "MONOCHROME1":
        data = np.amax(data) - data
    data = data - np.min(data)
    peak = np.max(data)
    # A constant image would otherwise compute 0/0 (NaN) before the uint8 cast.
    if peak > 0:
        data = data / peak
    modified_image_data = (data * 255).astype(np.uint8)
    return modified_image_data

def _dicom_header_to_dict(dataset):
    """Convert a pydicom Dataset into ``{element name: plain Python value}``.

    Pixel data (tag 7FE0,0010) is skipped. Nested Datasets are converted
    recursively. Every other value goes through ``_convert_value``.
    """
    # Kept from the pydicom issue #319 recipe: repr() walks every element,
    # presumably so still-raw elements are parsed before values() is iterated.
    # Verify before removing.
    repr(dataset)
    header = {}
    for element in dataset.values():
        if element.tag == (0x7fe0, 0x0010):
            # discard pixel data
            continue
        if isinstance(element.value, dicom.dataset.Dataset):
            header[element.name] = _dicom_header_to_dict(element.value)
        else:
            header[element.name] = _convert_value(element.value)
    return header


def dicom_dataset_to_dict(filename,func):
    """Read a DICOM file and return its header as a dict, optionally with the image.

    Credit: https://github.com/pydicom/pydicom/issues/319
            https://www.kaggle.com/raddar/convert-dicom-to-np-array-the-correct-way

    Args:
        filename: path of the DICOM file.
        func: ``'metadata_df'`` returns only the header dict. Any other value
            also decodes and normalises the pixel data.

    Returns:
        ``dicom_dict`` when ``func == 'metadata_df'``. Otherwise the tuple
        ``(dicom_dict, modified_image_data)``, where the image is uint8 in [0, 255].
    """
    dicom_header = dicom.dcmread(filename)

    #====== DICOM FILE DATA ======
    dicom_dict = _dicom_header_to_dict(dicom_header)
    # pop() instead of del, so a file without this tag does not raise KeyError.
    dicom_dict.pop('Pixel Representation', None)

    if func == 'metadata_df':
        return dicom_dict

    #====== DICOM IMAGE DATA ======
    # VOI LUT (if available by DICOM device) is used to transform raw DICOM data to "human-friendly" view
    if voi_lut:
        data = apply_voi_lut(dicom_header.pixel_array, dicom_header)
    else:
        data = dicom_header.pixel_array
    # depending on this value, X-ray may look inverted - fix that:
    if fix_monochrome and dicom_header.PhotometricInterpretation == "MONOCHROME1":
        data = np.amax(data) - data
    data = data - np.min(data)
    peak = np.max(data)
    # A constant image would otherwise compute 0/0 (NaN) before the uint8 cast.
    if peak > 0:
        data = data / peak
    modified_image_data = (data * 255).astype(np.uint8)

    return dicom_dict, modified_image_data

def _sanitise_unicode(s):
    return s.replace(u"\u0000", "").strip()

def _convert_value(v):
    t = type(v)
    if t in (list, int, float):
        cv = v
    elif t == str:
        cv = _sanitise_unicode(v)
    elif t == bytes:
        s = v.decode('ascii', 'replace')
        cv = _sanitise_unicode(s)
    elif t == dicom.valuerep.DSfloat:
        cv = float(v)
    elif t == dicom.valuerep.IS:
        cv = int(v)
    else:
        cv = repr(v)
    return cv

# Preview the first training image (two colour maps) and print its DICOM header.
# The header is a dict, not a DataFrame, hence the name `sample_metadata`.
filename = train_df['path'].iloc[0]

sample_metadata, img_array = dicom_dataset_to_dict(filename, 'fetch_both_values')

fig, ax = plt.subplots(1, 2, figsize=[15, 8])
for axis, cmap in zip(ax, (plt.cm.gray, plt.cm.plasma)):
    axis.imshow(img_array, cmap=cmap)
    axis.set_title(cmap.name)
    axis.axis('off')
plt.show()

pprint.pprint(sample_metadata)

In [None]:
# !! This takes a few minutes to process all the dicom files !!
def load_pixel_array(df):
    """Decode every DICOM referenced by ``df['path']`` into model-ready arrays.

    Args:
        df: dataframe with a ``path`` column and the four study-level label columns.

    Returns:
        (x, y):
        - ``x`` stacks the ``dicom_to_pixelarray`` images with the grayscale
          channel repeated 3 times along a new last axis.
        - ``y`` stacks the 4 label values per image.
        Rows whose DICOM raises RuntimeError are skipped, and the number skipped
        is printed instead of being dropped silently.
    """
    label_cols = ['Negative for Pneumonia',
                  'Typical Appearance',
                  'Indeterminate Appearance',
                  'Atypical Appearance']
    pixel_arrays = []
    y = []
    skipped = 0
    for _, row in tqdm(df.iterrows(), total=df.shape[0]):
        try:
            pixels = dicom_to_pixelarray(row['path'])
        except RuntimeError:
            # presumably an undecodable pixel encoding; skip but keep count
            skipped += 1
            continue
        pixel_arrays.append(pixels)
        y.append(np.array([row[col] for col in label_cols]))
    if skipped:
        print(f'Skipped {skipped} unreadable DICOM file(s)')

    pixel_arrays = np.asarray(pixel_arrays)
    x = np.repeat(pixel_arrays[..., np.newaxis], 3, -1)
    y = np.asarray(y)
    return x, y

In [None]:
reload = False  # True forces re-decoding every DICOM even when cached arrays exist

# Cache locations in priority order: attached Kaggle dataset first, then this session's output.
cached_arrays = [
    ('../input/reti-siim-covid/x_train.npy', '../input/reti-siim-covid/y_train.npy'),
    ('/kaggle/working/x_train.npy', '/kaggle/working/y_train.npy'),
]
# First location whose x_train file exists; None when reloading or nothing is cached.
cache_hit = None if reload else next(
    ((x_path, y_path) for x_path, y_path in cached_arrays if os.path.isfile(x_path)),
    None,
)

if cache_hit is None:
    x_train, y_train = load_pixel_array(train_df)
else:
    x_train, y_train = (np.load(path) for path in cache_hit)

print(x_train.shape, y_train.shape)

In [None]:
# Save the arrays so that, once the notebook is closed, they don't have to be
# slowly re-decoded from the DICOM files next time.
for cache_path, array in (('/kaggle/working/x_train.npy', x_train),
                          ('/kaggle/working/y_train.npy', y_train)):
    np.save(cache_path, array)

## Data augmentation with TPU acceleration
Following [this notebook](https://www.kaggle.com/cdeotte/cutmix-and-mixup-on-gpu-tpu), I try to speed up the learning process.

There is an error when I run it on the TPU that I didn't have time to fix, so I won't use this code.

In [None]:
from kaggle_datasets import KaggleDatasets
import tensorflow.keras as K
import tensorflow as tf
try:
    import tensorflow_io as tfio
except:
    print('tensorflow_io not installed')
import re

In [None]:
# Detect hardware, return appropriate distribution strategy
def get_strategy():
    """Return a TPU strategy when a TPU is reachable, else TensorFlow's default one.

    Prints the TPU master address (when found) and the number of replicas in sync.
    """
    try:
        # TPU detection. No parameters necessary if TPU_NAME environment variable is set. On Kaggle this is always the case.
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        print('Running on TPU ', resolver.master())
    except ValueError:
        resolver = None

    if not resolver:
        # default distribution strategy in Tensorflow. Works on CPU and single GPU.
        strategy = tf.distribute.get_strategy()
    else:
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.experimental.TPUStrategy(resolver)

    print("REPLICAS: ", strategy.num_replicas_in_sync)
    return strategy

In [None]:
AUTO = tf.data.experimental.AUTOTUNE

strategy = get_strategy()

# Configuration
IMAGE_SIZE = [224, 224]
EPOCHS = 25
FOLDS = 3
SEED = 777
BATCH_SIZE = 32 * strategy.num_replicas_in_sync
AUG_BATCH = BATCH_SIZE
FIRST_FOLD_ONLY = False


# Enable mixed precision and XLA?
# Kaggle TPUs and GPUs don't fully support these yet
MIXED_PRECISION = False
XLA_ACCELERATE = False

if MIXED_PRECISION:
    from tensorflow.keras.mixed_precision import experimental as mixed_precision
    # `tpu` only exists inside get_strategy(), so ask TF directly whether TPU
    # devices are attached (bfloat16 on TPU, float16 elsewhere).
    on_tpu = bool(tf.config.list_logical_devices('TPU'))
    policy_name = 'mixed_bfloat16' if on_tpu else 'mixed_float16'
    policy = tf.keras.mixed_precision.experimental.Policy(policy_name)
    mixed_precision.set_policy(policy)
    print('Mixed precision enabled')

if XLA_ACCELERATE:
    tf.config.optimizer.set_jit(True)
    print('Accelerated Linear Algebra enabled')

In [None]:
# Data access
GCS_PATH = KaggleDatasets().get_gcs_path('siim-covid19-detection')

# Train and test DICOM paths on GCS; predictions on the test files should be
# submitted for the competition.
TRAINING_FILENAMES, TEST_FILENAMES = [
    tf.io.gfile.glob(GCS_PATH + pattern)
    for pattern in ('/train/*/*/*.dcm', '/test/*/*/*.dcm')
]
(len(TRAINING_FILENAMES), len(TEST_FILENAMES))

In [None]:
train_df_id = train_df.set_index('id')

def parse_image(filename):
    """Decode one training DICOM and look up its 4-element label vector.

    Meant to run inside ``tf.py_function``, because it calls ``filename.numpy()``.

    Args:
        filename: scalar string tensor holding the DICOM path.

    Returns:
        (image, label):
        - ``image`` is the decoded float32 image resized to ``IMAGE_SIZE``.
        - ``label`` is a uint8 tensor of shape [4] with the study-level labels.

    Raises:
        KeyError: if the image id is not in ``train_df`` (e.g. a test file).
    """
    # "<dir>/<image_id>.dcm" -> "<image_id>_image", the id format used in train_df.
    image_id = bytes.decode(filename.numpy()).split('/')[-1].split('.')[0] + '_image'
    label_values = train_df_id.loc[image_id, ['Negative for Pneumonia',
                                              'Typical Appearance',
                                              'Indeterminate Appearance',
                                              'Atypical Appearance']]
    # The row mixes string and numeric columns, so this selection is an
    # object-dtype Series. Convert explicitly before handing it to tf.constant.
    label = tf.constant(label_values.to_numpy(dtype=np.uint8), dtype=tf.uint8, shape=[4])
    image_bytes = tf.io.read_file(filename)

    # if bad decoding - raise an error
    image = tfio.image.decode_dicom_image(image_bytes, on_error='strict',
        dtype=tf.float32)
    image = tf.image.resize(image, IMAGE_SIZE) # optional

    return image, label

def _parse_unlabeled_image(filename):
    """Decode one DICOM that has no label and return (image, image_id) instead."""
    image_id = bytes.decode(filename.numpy()).split('/')[-1].split('.')[0] + '_image'
    image = tfio.image.decode_dicom_image(tf.io.read_file(filename), on_error='strict',
        dtype=tf.float32)
    image = tf.image.resize(image, IMAGE_SIZE)
    return image, image_id


def load_dataset(filenames, labeled = True, ordered = False):
    """Build a tf.data pipeline of decoded DICOMs from an explicit list of paths.

    Args:
        filenames: list of DICOM paths (e.g. one CV fold of TRAINING_FILENAMES).
        labeled: True yields (image, label) via ``parse_image``. False yields
            (image, image_id), so files without labels (the test set) work.
        ordered: True keeps the order of ``filenames``. False allows
            non-deterministic interleaving for speed.

    Returns:
        tf.data.Dataset of (image, label) or (image, image_id) pairs.
    """
    options = tf.data.Options()
    if not ordered:
        options.experimental_deterministic = False # disable order, increase speed

    # Only the given files, in the given order, so CV folds and the test split
    # stay disjoint (list_files over a glob would pull in every train file and shuffle).
    dataset = tf.data.Dataset.from_tensor_slices(filenames)
    dataset = dataset.with_options(options) # use data as soon as it streams in, rather than in its original order
    if labeled:
        parse_fn, out_types = parse_image, [tf.float32, tf.uint8]
    else:
        parse_fn, out_types = _parse_unlabeled_image, [tf.float32, tf.string]
    dataset = dataset.map(lambda x: tf.py_function(parse_fn, [x], out_types),
                          num_parallel_calls = AUTO)
    return dataset

def data_augment(image, label):
    """Randomly mirror ``image`` left/right; ``label`` is returned untouched.

    Used via ``dataset.map`` in the input pipeline. Per the original note, on TPU
    this runs on the host CPU while the TPU computes gradients.
    """
    flipped = tf.image.random_flip_left_right(image)
    return flipped, label

def get_training_dataset(dataset, do_aug=True):
    """Turn an (image, label) dataset into an endless, shuffled, batched training stream.

    Args:
        dataset: (image, label) dataset, e.g. from ``load_dataset``.
        do_aug: also apply ``transform`` (random rotate/shear/zoom/shift).

    Returns:
        Repeated, shuffled, ``BATCH_SIZE``-batched, prefetched dataset.
    """
    dataset = dataset.map(data_augment, num_parallel_calls=AUTO)
    if do_aug:
        # transform() operates on ONE image ([dim, dim, C]), so map it before batching.
        # NOTE(review): decode_dicom_image may add a leading frames axis in
        # parse_image; confirm the element shape matches transform's contract.
        dataset = dataset.map(transform, num_parallel_calls=AUTO)
    dataset = dataset.repeat() # the training dataset must repeat for several epochs
    dataset = dataset.shuffle(2048)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(AUTO) # prefetch next batch while training (autotune prefetch buffer size)
    return dataset

def get_validation_dataset(dataset, do_onehot=True):
    """Batch, optionally cast, cache and prefetch an (image, label) validation dataset.

    Args:
        dataset: (image, label) dataset, e.g. from ``load_dataset(..., labeled=True)``.
        do_onehot: ``parse_image`` already yields a 4-element label vector, so no
            one-hot encoding is needed. When True the labels are only cast to
            float32 to line up with the model's float predictions.

    Returns:
        Batched, cached, prefetched dataset.
    """
    dataset = dataset.batch(BATCH_SIZE)
    if do_onehot:
        dataset = dataset.map(lambda image, label: (image, tf.cast(label, tf.float32)),
                              num_parallel_calls=AUTO)
    dataset = dataset.cache()
    dataset = dataset.prefetch(AUTO) # prefetch next batch while training (autotune prefetch buffer size)
    return dataset

def get_test_dataset(ordered=False):
    """Batched, prefetched dataset over TEST_FILENAMES, loaded with ``labeled=False``."""
    return (
        load_dataset(TEST_FILENAMES, labeled=False, ordered=ordered)
        .batch(BATCH_SIZE)
        .prefetch(AUTO)  # prefetch next batch while predicting
    )

def count_data_items(filenames):
    """Return how many items ``filenames`` holds (one DICOM file per item).

    Counts the list actually passed in, so the test-set count is not silently
    replaced by the training-set size.
    """
    return len(filenames)

# Split sizes implied by FOLDS-fold cross-validation over the training DICOMs.
NUM_TRAINING_IMAGES = int(count_data_items(TRAINING_FILENAMES) * (FOLDS - 1.) / FOLDS)
NUM_VALIDATION_IMAGES = int(count_data_items(TRAINING_FILENAMES) * (1. / FOLDS))
NUM_TEST_IMAGES = count_data_items(TEST_FILENAMES)
STEPS_PER_EPOCH = NUM_TRAINING_IMAGES // BATCH_SIZE

dataset_msg = 'Dataset: {} training images, {} validation images, {} unlabeled test images'
print(dataset_msg.format(NUM_TRAINING_IMAGES, NUM_VALIDATION_IMAGES, NUM_TEST_IMAGES))

In [None]:
# uncomment to use TPU
# d = load_dataset(TRAINING_FILENAMES)

In [None]:
# show output of dataset
# for element in d.take(1):
#     plt.imshow(element[0].numpy()[0])
#     print(element[0].numpy()[0].shape)
#     print(element[1].numpy())

In [None]:
def get_mat(rotation, shear, height_zoom, width_zoom, height_shift, width_shift):
    """Return the 3x3 matrix that maps destination pixel indices to source indices.

    Every argument is a 1-element float32 tensor (as produced in ``transform``):
    - rotation, shear: angles in degrees.
    - height_zoom, width_zoom: zoom factors (1.0 means no zoom).
    - height_shift, width_shift: shifts in pixels.

    Returns:
        float32 tensor of shape [3, 3]: rotation @ shear @ zoom @ shift.
    """
    # CONVERT DEGREES TO RADIANS (np.pi: `math` is not imported in this notebook)
    rotation = np.pi * rotation / 180.
    shear = np.pi * shear / 180.

    # ROTATION MATRIX
    c1 = tf.math.cos(rotation)
    s1 = tf.math.sin(rotation)
    one = tf.constant([1],dtype='float32')
    zero = tf.constant([0],dtype='float32')
    rotation_matrix = tf.reshape( tf.concat([c1,s1,zero, -s1,c1,zero, zero,zero,one],axis=0),[3,3] )

    # SHEAR MATRIX
    c2 = tf.math.cos(shear)
    s2 = tf.math.sin(shear)
    shear_matrix = tf.reshape( tf.concat([one,s2,zero, zero,c2,zero, zero,zero,one],axis=0),[3,3] )

    # ZOOM MATRIX
    zoom_matrix = tf.reshape( tf.concat([one/height_zoom,zero,zero, zero,one/width_zoom,zero, zero,zero,one],axis=0),[3,3] )

    # SHIFT MATRIX
    shift_matrix = tf.reshape( tf.concat([one,zero,height_shift, zero,one,width_shift, zero,zero,one],axis=0),[3,3] )

    # `K` in this notebook is tensorflow.keras (not keras.backend) and has no dot(),
    # so compose the 3x3 matrices with tf.matmul.
    return tf.matmul(tf.matmul(rotation_matrix, shear_matrix), tf.matmul(zoom_matrix, shift_matrix))

def transform(image,label):
    """Apply a random rotation, shear, zoom and shift to ONE image.

    Args:
        image: a single image tensor of shape [DIM, DIM, C] (not a batch),
            where DIM = IMAGE_SIZE[0].
        label: passed through unchanged.

    Returns:
        (augmented_image, label); the image keeps shape [DIM, DIM, C].
    """
    DIM = IMAGE_SIZE[0]
    XDIM = DIM%2 #fix for size 331

    rot = 15. * tf.random.normal([1],dtype='float32')
    shr = 5. * tf.random.normal([1],dtype='float32')
    h_zoom = 1.0 + tf.random.normal([1],dtype='float32')/10.
    w_zoom = 1.0 + tf.random.normal([1],dtype='float32')/10.
    h_shift = 16. * tf.random.normal([1],dtype='float32')
    w_shift = 16. * tf.random.normal([1],dtype='float32')

    # GET TRANSFORMATION MATRIX
    m = get_mat(rot,shr,h_zoom,w_zoom,h_shift,w_shift)

    # LIST DESTINATION PIXEL INDICES
    x = tf.repeat( tf.range(DIM//2,-DIM//2,-1), DIM )
    y = tf.tile( tf.range(-DIM//2,DIM//2),[DIM] )
    z = tf.ones([DIM*DIM],dtype='int32')
    idx = tf.stack( [x,y,z] )

    # ROTATE DESTINATION PIXELS ONTO ORIGIN PIXELS
    # (tf ops rather than K.*: `K` here is tensorflow.keras, which lacks dot/cast/clip)
    idx2 = tf.matmul(m, tf.cast(idx, dtype='float32'))
    idx2 = tf.cast(idx2, dtype='int32')
    idx2 = tf.clip_by_value(idx2, -DIM//2+XDIM+1, DIM//2)

    # FIND ORIGIN PIXEL VALUES
    idx3 = tf.stack( [DIM//2-idx2[0,], DIM//2-1+idx2[1,]] )
    d = tf.gather_nd(image, tf.transpose(idx3))

    # -1 keeps the input's channel count, so grayscale and RGB inputs both work.
    return tf.reshape(d, [DIM, DIM, -1]), label

In [None]:
# row = 3; col = 4;
# all_elements = get_training_dataset(load_dataset(TRAINING_FILENAMES[:3]),do_aug=False).unbatch()
# one_element = tf.data.Dataset.from_tensors( next(iter(all_elements)) )
# augmented_element = one_element.repeat().map(transform).batch(row*col)

# for (img,label) in augmented_element:
#     plt.figure(figsize=(15,int(15*row/col)))
#     for j in range(row*col):
#         plt.subplot(row,col,j+1)
#         plt.axis('off')
#         plt.imshow(img[j,])
#     plt.show()
#     break

## Model definition
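I try transfer learning with Keras's ImageNet-pretrained ResNet50, training only the final fully connected layer. (Note: the layer-freezing loop in `get_model` is currently commented out, so in practice the whole network is fine-tuned.)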

In [None]:
import tensorflow.keras as K
import tensorflow as tf
import datetime as dt

In [None]:
# building the actual model and enabling GPU acceleration
# TPU acceleration takes a while to set up and i dont have the time
# need to use Google SDK to pipe data directly into tpu, otherwise the 
# TPU is bottlenecked by data transfer

def get_model(file=None):
    """Build, or load, the ResNet50 classifier inside ``strategy.scope()``.

    Args:
        file: optional path to a saved Keras model. When given, that model is
            loaded and returned as-is instead of building a new one.

    Returns:
        A compiled Keras model: ImageNet ResNet50 backbone without its top,
        global average pooling, and a 4-way softmax head.
    """
    with strategy.scope():
        if file is not None:
            return K.models.load_model(file)

        input_t = K.Input(shape=(224, 224, 3))
        res_model = K.applications.ResNet50(include_top=False,
                                            weights="imagenet",
                                            input_tensor=input_t)
        # Uncomment to freeze the backbone and train only the new head:
        # for layer in res_model.layers:
        #     layer.trainable = False

        model = K.models.Sequential()
        model.add(res_model)
        model.add(K.layers.GlobalAveragePooling2D())
        model.add(K.layers.Dense(4, activation='softmax'))

        model.compile(loss='categorical_crossentropy',
                      # `learning_rate` works across TF2 versions; the old `lr` alias is rejected by newer Keras.
                      optimizer=K.optimizers.RMSprop(learning_rate=2e-5),
                      metrics=['accuracy',
                               K.metrics.AUC(multi_label=True),
                               K.metrics.AUC(name='prc', curve='PR')])
    return model

In [None]:
VAL_SPLIT = 0.15
datagen = K.preprocessing.image.ImageDataGenerator(rotation_range=30,
                                                   width_shift_range=0.1,
                                                   height_shift_range=0.1,
                                                   brightness_range=(0.8, 1.2),
                                                   shear_range=15,
                                                   horizontal_flip=True,
                                                   vertical_flip=True,
                                                   validation_split=VAL_SPLIT,
                                                  )
# Validation images must not be randomly augmented. A transform-free generator
# with the same validation_split selects the same validation rows from x_train.
val_datagen = K.preprocessing.image.ImageDataGenerator(validation_split=VAL_SPLIT)

train_it = datagen.flow(x_train, y_train, batch_size=BATCH_SIZE, subset='training')

validation_it = val_datagen.flow(x_train, y_train, batch_size=BATCH_SIZE, subset='validation')

# Preview up to 9 augmented training images from one batch.
batch_images, _ = train_it.next()
fig, axes = plt.subplots(3, 3, figsize=(9, 9))
for image, axis in zip(batch_images[:9], axes.flat):
    # clip to the displayable range before converting to unsigned integers
    axis.imshow(np.clip(image, 0, 255).astype('uint8'))
    axis.axis('off')
plt.show()

In [None]:
file = '../input/resmodel-weights/resmodel_weights.h5'
model = get_model(file=None)  # pass file=file to reload the saved model instead
model.summary()
# Count trainable parameters from the model itself rather than a hard-coded number,
# so the ratio stays right whenever the architecture or layer freezing changes.
n_trainable = int(sum(np.prod(tuple(w.shape)) for w in model.trainable_weights))
print(f'numero data points: {len(x_train)}\nnumero trainable parameters/numero data points: {n_trainable/len(x_train)}')

In [None]:
# Use the tf.keras copy of plot_model. The standalone `keras` package may be a
# different version from the tf.keras that built `model`.
K.utils.plot_model(model, to_file='model.png', show_shapes=True, show_layer_names=True)

In [None]:
# The 'accuracy' metric is logged as 'val_accuracy' (see the history plots),
# so monitoring 'val_acc' meant the best-model checkpoint was never written.
check_point = K.callbacks.ModelCheckpoint(filepath="/kaggle/working/resmodel_weights_ckp.h5",
                                          monitor="val_accuracy",
                                          mode="max",
                                          save_best_only=True,
                                          )

# Stop on validation loss: training loss can keep falling while the model overfits.
early_stopping = K.callbacks.EarlyStopping(monitor='val_loss', patience = 5)

# train_it / validation_it already yield batches, so batch_size is not passed to fit().
history = model.fit(train_it, epochs=100, verbose=1,
                    validation_data=validation_it,
                    callbacks=[check_point, early_stopping])

now = dt.datetime.now()
timestamp = now.strftime('%d_%m_%Y__%H_%M')
name = 'resmodel_pooled'
model.summary()
model.save(f"/kaggle/working/{name}_weights_{timestamp}.h5")
with open(f"/kaggle/working/{name}_history_{timestamp}.pkl", 'wb') as file_pi:
    pickle.dump(history.history, file_pi)

In [None]:
def plot_history_metric(history, metric):
    """Plot the train and validation curves for ``metric`` from a Keras History."""
    fig, ax = plt.subplots()
    ax.plot(history.history[metric])
    ax.plot(history.history[f'val_{metric}'])
    ax.set_title(f'model {metric}')
    ax.set_ylabel(metric)
    ax.set_xlabel('epoch')
    # the second curve comes from validation_data, not from the test set
    ax.legend(['train', 'validation'], loc='upper left')
    plt.show()


# list all data in history
print(history.history.keys())
for metric in ('accuracy', 'loss', 'auc'):
    plot_history_metric(history, metric)

In [None]:

def _kfold_indices(n_samples, n_splits, seed):
    """Yield sorted (train_idx, val_idx) arrays for a shuffled K-fold split.

    Written to reproduce sklearn's ``KFold(n_splits, shuffle=True, random_state=seed)``
    with numpy only. Indices are shuffled with ``RandomState(seed)``, cut into
    ``n_splits`` consecutive folds (the first ``n_samples % n_splits`` folds get
    one extra item), and each fold serves once as the validation set.
    """
    if not 2 <= n_splits <= n_samples:
        raise ValueError(f'need 2 <= folds <= {n_samples}, got {n_splits}')
    shuffled = np.arange(n_samples)
    np.random.RandomState(seed).shuffle(shuffled)
    for fold in np.array_split(shuffled, n_splits):
        is_val = np.zeros(n_samples, dtype=bool)
        is_val[fold] = True
        yield np.flatnonzero(~is_val), np.flatnonzero(is_val)


def TPU_train_cross_validate(folds = 5, callbacks=None):
    """Train one fresh model per CV fold on the DICOM tf.data pipeline.

    Args:
        folds: number of cross-validation folds.
        callbacks: optional list of Keras callbacks passed to every ``fit`` call,
            e.g. ``[tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3)]``.
            Defaults to no callbacks.

    Returns:
        (histories, models), one entry per fold.
    """
    histories = []
    models = []
    callbacks = list(callbacks) if callbacks is not None else []
    splits = _kfold_indices(len(TRAINING_FILENAMES), folds, SEED)
    for f, (trn_ind, val_ind) in enumerate(splits):
        print(); print('#'*25)
        print('### FOLD',f+1)
        print('#'*25)
        train_files = [TRAINING_FILENAMES[i] for i in trn_ind]
        val_files = [TRAINING_FILENAMES[i] for i in val_ind]
        train_dataset = load_dataset(train_files, labeled = True)
        val_dataset = load_dataset(val_files, labeled = True, ordered = True)
        model = get_model()
        history = model.fit(
            get_training_dataset(train_dataset),
            # derived from this fold's size rather than the FOLDS-based STEPS_PER_EPOCH
            steps_per_epoch = max(1, len(train_files) // BATCH_SIZE),
            epochs = EPOCHS,
            callbacks = callbacks,
            validation_data = get_validation_dataset(val_dataset),
            verbose=2
        )
        models.append(model)
        histories.append(history)
    return histories, models

def TPU_train_and_predict(folds = 5):
    """Cross-validate, average the fold models' test probabilities and write submission.csv.

    Args:
        folds: number of CV folds / models to train.

    Returns:
        (histories, models) from ``TPU_train_cross_validate``.
    """
    # since we are splitting the dataset and iterating separately on images and ids, order matters.
    test_ds = get_test_dataset(ordered=True)
    test_images_ds = test_ds.map(lambda image, idnum: image)
    print('Start training %i folds'%folds)
    histories, models = TPU_train_cross_validate(folds = folds)
    print('Computing predictions...')
    # get the mean probability of the folds models
    probabilities = np.average([fold_model.predict(test_images_ds) for fold_model in models], axis = 0)
    predictions = np.argmax(probabilities, axis=-1)
    print('Generating submission.csv file...')
    # Collect every id, not just the first NUM_TEST_IMAGES, and decode the
    # tf.string bytes explicitly instead of relying on astype('U').
    test_ids = np.array([raw_id.decode('utf-8')
                         for id_batch in test_ds.map(lambda image, idnum: idnum)
                         for raw_id in id_batch.numpy()])
    np.savetxt('submission.csv', np.rec.fromarrays([test_ids, predictions]), fmt=['%s', '%d'], delimiter=',', header='id,label', comments='')
    return histories, models

# run train and predict for TPU
# histories, models = TPU_train_and_predict(folds = FOLDS)