### imports

In [1]:
import os
from tqdm import tqdm
# Register tqdm's pandas integration so that Series/DataFrame objects gain
# `progress_apply`, which the prediction cells later in this notebook call.
tqdm.pandas()

import numpy as np
import pandas as pd

In [2]:
import scipy
import tensorflow as tf

# Report, one value per line: the TensorFlow version, the GPU devices
# TensorFlow can see, and whether this build was compiled with CUDA.
print(
    tf.__version__,
    tf.config.list_physical_devices('GPU'),
    tf.test.is_built_with_cuda(),
    sep='\n',
)

2.10.1
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
True


In [3]:
from tensorflow.keras.applications import EfficientNetB4, EfficientNetB7
from tensorflow.keras.layers import Dense, Dropout

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.efficientnet import preprocess_input
from tensorflow.keras.preprocessing.image import load_img, img_to_array

from tensorflow.keras.callbacks import ModelCheckpoint

In [4]:
from PIL import ImageFile
# Tell PIL to load image files whose data is truncated instead of raising an
# error — presumably so a single damaged frame does not abort a long run.
ImageFile.LOAD_TRUNCATED_IMAGES = True

### load splits

In [6]:
# Read the train / validation / test split tables in one pass; the tuple order
# below fixes which file lands in which dataframe.
train_df, val_df, test_df = map(
    pd.read_csv,
    (
        "embryoAI_datasets/embryoAI_train.csv",
        "embryoAI_datasets/embryoAI_val.csv",
        "embryoAI_datasets/embryoAI_test.csv",
    ),
)

In [7]:
# Show which columns the training split provides.
train_df.columns

Index(['case', 'frame', 'phase', 'path', 'flag'], dtype='object')

In [8]:
# (rows, columns) of the training split.
train_df.shape

(244767, 5)

In [9]:
# Number of training rows per phase label. The recorded output shows a heavy
# class imbalance (e.g. only 59 rows for tHB).
train_df.phase.value_counts()

t9+     41877
tPNa    35476
t8      26631
t4      24357
t2      24145
tEB     15760
tM      14282
tSB     14086
t7       8779
tB       8704
tPB2     7371
t6       7075
t5       6500
tPNf     5565
t3       4100
tHB        59
Name: phase, dtype: int64

### parameters

In [10]:
# Spatial size every image is resized to, and the channel count the model's
# input layer is declared with.
IMG_WIDTH = 224
IMG_HEIGHT = 224
CHANNELS = 3

# Columns of the split dataframes holding the image path and the class label.
INPUT_COL = 'path'
LABEL_COL = 'phase'

# Directory that the image paths in INPUT_COL are resolved against.
DATA_DIR = 'embryoAI_images'

# Images per batch for every data generator.
BATCH_SIZE = 32
# NOTE(review): not used by the training cell, which passes a literal epoch
# count and only mentions MAX_EPOCHS in a comment.
MAX_EPOCHS = 30

In [11]:
# Index -> phase-name lookup. np.unique returns the distinct labels in sorted
# order, so the integer keys follow the sorted order of the phase names.
categories = dict(enumerate(np.unique(train_df.phase)))

NUM_CLASSES = len(categories)
print('Detected %d classes' % NUM_CLASSES)
categories

Detected 16 classes


{0: 't2',
 1: 't3',
 2: 't4',
 3: 't5',
 4: 't6',
 5: 't7',
 6: 't8',
 7: 't9+',
 8: 'tB',
 9: 'tEB',
 10: 'tHB',
 11: 'tM',
 12: 'tPB2',
 13: 'tPNa',
 14: 'tPNf',
 15: 'tSB'}

### model

In [12]:
# Create model
def create_model(learning_rate=0.0001, dropout_rate=0.5):
    """Build and compile the EfficientNetB4 transfer-learning classifier.

    The ImageNet-pretrained EfficientNetB4 backbone (without its top layer) is
    frozen, and a new head is stacked on it: global average pooling, a
    512-unit ReLU layer, dropout, and a softmax layer with NUM_CLASSES outputs.

    Args:
        learning_rate: Step size for the Adam optimizer.
        dropout_rate: Fraction of head activations dropped during training.

    Returns:
        A compiled `tf.keras.Model` using categorical cross-entropy loss and
        reporting accuracy.
    """
    # Keras image tensors are laid out as (height, width, channels).
    base_model = EfficientNetB4(
        weights='imagenet',
        include_top=False,
        input_shape=(IMG_HEIGHT, IMG_WIDTH, CHANNELS),
    )

    # Freeze the backbone so only the new head is trained.
    base_model.trainable = False

    # Classification head on top of the backbone's feature map.
    x = base_model.output
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = Dense(512, activation='relu')(x)
    x = Dropout(dropout_rate)(x)
    predictions = Dense(NUM_CLASSES, activation='softmax')(x)

    model = tf.keras.models.Model(inputs=base_model.input, outputs=predictions)

    # Categorical cross-entropy matches the one-hot labels produced by the
    # generators' class_mode='categorical'.
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=tf.keras.losses.CategoricalCrossentropy(),
        metrics=['accuracy'],
    )

    return model

model = create_model()
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 rescaling (Rescaling)          (None, 224, 224, 3)  0           ['input_1[0][0]']                
                                                                                                  
 normalization (Normalization)  (None, 224, 224, 3)  7           ['rescaling[0][0]']              
                                                                                                  
 rescaling_1 (Rescaling)        (None, 224, 224, 3)  0           ['normalization[0][0]']      

 block2a_expand_conv (Conv2D)   (None, 112, 112, 14  3456        ['block1b_add[0][0]']            
                                4)                                                                
                                                                                                  
 block2a_expand_bn (BatchNormal  (None, 112, 112, 14  576        ['block2a_expand_conv[0][0]']    
 ization)                       4)                                                                
                                                                                                  
 block2a_expand_activation (Act  (None, 112, 112, 14  0          ['block2a_expand_bn[0][0]']      
 ivation)                       4)                                                                
                                                                                                  
 block2a_dwconv_pad (ZeroPaddin  (None, 113, 113, 14  0          ['block2a_expand_activation[0][0]
 g2D)     

 ivation)                                                                                         
                                                                                                  
 block2c_dwconv (DepthwiseConv2  (None, 56, 56, 192)  1728       ['block2c_expand_activation[0][0]
 D)                                                              ']                               
                                                                                                  
 block2c_bn (BatchNormalization  (None, 56, 56, 192)  768        ['block2c_dwconv[0][0]']         
 )                                                                                                
                                                                                                  
 block2c_activation (Activation  (None, 56, 56, 192)  0          ['block2c_bn[0][0]']             
 )                                                                                                
          

 block3a_dwconv (DepthwiseConv2  (None, 28, 28, 192)  4800       ['block3a_dwconv_pad[0][0]']     
 D)                                                                                               
                                                                                                  
 block3a_bn (BatchNormalization  (None, 28, 28, 192)  768        ['block3a_dwconv[0][0]']         
 )                                                                                                
                                                                                                  
 block3a_activation (Activation  (None, 28, 28, 192)  0          ['block3a_bn[0][0]']             
 )                                                                                                
                                                                                                  
 block3a_se_squeeze (GlobalAver  (None, 192)         0           ['block3a_activation[0][0]']     
 agePoolin

 agePooling2D)                                                                                    
                                                                                                  
 block3c_se_reshape (Reshape)   (None, 1, 1, 336)    0           ['block3c_se_squeeze[0][0]']     
                                                                                                  
 block3c_se_reduce (Conv2D)     (None, 1, 1, 14)     4718        ['block3c_se_reshape[0][0]']     
                                                                                                  
 block3c_se_expand (Conv2D)     (None, 1, 1, 336)    5040        ['block3c_se_reduce[0][0]']      
                                                                                                  
 block3c_se_excite (Multiply)   (None, 28, 28, 336)  0           ['block3c_activation[0][0]',     
                                                                  'block3c_se_expand[0][0]']      
          

 block4a_se_reshape (Reshape)   (None, 1, 1, 336)    0           ['block4a_se_squeeze[0][0]']     
                                                                                                  
 block4a_se_reduce (Conv2D)     (None, 1, 1, 14)     4718        ['block4a_se_reshape[0][0]']     
                                                                                                  
 block4a_se_expand (Conv2D)     (None, 1, 1, 336)    5040        ['block4a_se_reduce[0][0]']      
                                                                                                  
 block4a_se_excite (Multiply)   (None, 14, 14, 336)  0           ['block4a_activation[0][0]',     
                                                                  'block4a_se_expand[0][0]']      
                                                                                                  
 block4a_project_conv (Conv2D)  (None, 14, 14, 112)  37632       ['block4a_se_excite[0][0]']      
          

                                                                                                  
 block4c_project_bn (BatchNorma  (None, 14, 14, 112)  448        ['block4c_project_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 block4c_drop (Dropout)         (None, 14, 14, 112)  0           ['block4c_project_bn[0][0]']     
                                                                                                  
 block4c_add (Add)              (None, 14, 14, 112)  0           ['block4c_drop[0][0]',           
                                                                  'block4b_add[0][0]']            
                                                                                                  
 block4d_expand_conv (Conv2D)   (None, 14, 14, 672)  75264       ['block4c_add[0][0]']            
          

                                                                                                  
 block4e_add (Add)              (None, 14, 14, 112)  0           ['block4e_drop[0][0]',           
                                                                  'block4d_add[0][0]']            
                                                                                                  
 block4f_expand_conv (Conv2D)   (None, 14, 14, 672)  75264       ['block4e_add[0][0]']            
                                                                                                  
 block4f_expand_bn (BatchNormal  (None, 14, 14, 672)  2688       ['block4f_expand_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 block4f_expand_activation (Act  (None, 14, 14, 672)  0          ['block4f_expand_bn[0][0]']      
 ivation) 

 ivation)                                                                                         
                                                                                                  
 block5b_dwconv (DepthwiseConv2  (None, 14, 14, 960)  24000      ['block5b_expand_activation[0][0]
 D)                                                              ']                               
                                                                                                  
 block5b_bn (BatchNormalization  (None, 14, 14, 960)  3840       ['block5b_dwconv[0][0]']         
 )                                                                                                
                                                                                                  
 block5b_activation (Activation  (None, 14, 14, 960)  0          ['block5b_bn[0][0]']             
 )                                                                                                
          

 block5d_bn (BatchNormalization  (None, 14, 14, 960)  3840       ['block5d_dwconv[0][0]']         
 )                                                                                                
                                                                                                  
 block5d_activation (Activation  (None, 14, 14, 960)  0          ['block5d_bn[0][0]']             
 )                                                                                                
                                                                                                  
 block5d_se_squeeze (GlobalAver  (None, 960)         0           ['block5d_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block5d_se_reshape (Reshape)   (None, 1, 1, 960)    0           ['block5d_se_squeeze[0][0]']     
          

                                                                                                  
 block5f_se_squeeze (GlobalAver  (None, 960)         0           ['block5f_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block5f_se_reshape (Reshape)   (None, 1, 1, 960)    0           ['block5f_se_squeeze[0][0]']     
                                                                                                  
 block5f_se_reduce (Conv2D)     (None, 1, 1, 40)     38440       ['block5f_se_reshape[0][0]']     
                                                                                                  
 block5f_se_expand (Conv2D)     (None, 1, 1, 960)    39360       ['block5f_se_reduce[0][0]']      
                                                                                                  
 block5f_s

                                                                                                  
 block6b_se_expand (Conv2D)     (None, 1, 1, 1632)   112608      ['block6b_se_reduce[0][0]']      
                                                                                                  
 block6b_se_excite (Multiply)   (None, 7, 7, 1632)   0           ['block6b_activation[0][0]',     
                                                                  'block6b_se_expand[0][0]']      
                                                                                                  
 block6b_project_conv (Conv2D)  (None, 7, 7, 272)    443904      ['block6b_se_excite[0][0]']      
                                                                                                  
 block6b_project_bn (BatchNorma  (None, 7, 7, 272)   1088        ['block6b_project_conv[0][0]']   
 lization)                                                                                        
          

                                                                                                  
 block6d_project_conv (Conv2D)  (None, 7, 7, 272)    443904      ['block6d_se_excite[0][0]']      
                                                                                                  
 block6d_project_bn (BatchNorma  (None, 7, 7, 272)   1088        ['block6d_project_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 block6d_drop (Dropout)         (None, 7, 7, 272)    0           ['block6d_project_bn[0][0]']     
                                                                                                  
 block6d_add (Add)              (None, 7, 7, 272)    0           ['block6d_drop[0][0]',           
                                                                  'block6c_add[0][0]']            
          

                                                                                                  
 block6f_drop (Dropout)         (None, 7, 7, 272)    0           ['block6f_project_bn[0][0]']     
                                                                                                  
 block6f_add (Add)              (None, 7, 7, 272)    0           ['block6f_drop[0][0]',           
                                                                  'block6e_add[0][0]']            
                                                                                                  
 block6g_expand_conv (Conv2D)   (None, 7, 7, 1632)   443904      ['block6f_add[0][0]']            
                                                                                                  
 block6g_expand_bn (BatchNormal  (None, 7, 7, 1632)  6528        ['block6g_expand_conv[0][0]']    
 ization)                                                                                         
          

                                                                                                  
 block7a_expand_conv (Conv2D)   (None, 7, 7, 1632)   443904      ['block6h_add[0][0]']            
                                                                                                  
 block7a_expand_bn (BatchNormal  (None, 7, 7, 1632)  6528        ['block7a_expand_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 block7a_expand_activation (Act  (None, 7, 7, 1632)  0           ['block7a_expand_bn[0][0]']      
 ivation)                                                                                         
                                                                                                  
 block7a_dwconv (DepthwiseConv2  (None, 7, 7, 1632)  14688       ['block7a_expand_activation[0][0]
 D)       

 dense (Dense)                  (None, 512)          918016      ['global_average_pooling2d[0][0]'
                                                                 ]                                
                                                                                                  
 dropout (Dropout)              (None, 512)          0           ['dense[0][0]']                  
                                                                                                  
 dense_1 (Dense)                (None, 16)           8208        ['dropout[0][0]']                
                                                                                                  
Total params: 18,600,047
Trainable params: 926,224
Non-trainable params: 17,673,823
__________________________________________________________________________________________________


### data generators

In [24]:
# Validation/test pipeline: EfficientNet preprocessing only, no augmentation.
val_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)

# Training pipeline: the same preprocessing plus random geometric
# augmentation (rotation, shifts, shear, zoom and flips along both axes).
train_augmentations = dict(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    vertical_flip=True,
)
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    **train_augmentations,
)

In [25]:
# Create the train and validation generators.
#
# Both generators share one set of options so they cannot drift apart:
# - color_mode='rgb' yields 3-channel batches, matching the CHANNELS the model
#   input is declared with and the RGB loading done in predict_class
#   (the previous 'grayscale' setting produced 1-channel batches).
# - classes pins the label order to `categories`, so label index i means the
#   same phase in every split instead of being inferred per dataframe.
class_names = [categories[i] for i in range(NUM_CLASSES)]

generator_options = dict(
    x_col=INPUT_COL,
    y_col=LABEL_COL,
    directory=DATA_DIR,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    color_mode='rgb',
    classes=class_names,
    class_mode='categorical',
)

train_generator = train_datagen.flow_from_dataframe(train_df, **generator_options)

val_generator = val_datagen.flow_from_dataframe(val_df, **generator_options)

Found 244767 validated image filenames belonging to 16 classes.
Found 61192 validated image filenames belonging to 16 classes.


### training

In [13]:
# Destination of the checkpointed weights; the file name is filled in with the
# backbone tag ('B4') and the task tag ('embryo_16').
checkpoint_filepath = 'embryoAI_weights/weights_{}_{}.h5'.format('B4', 'embryo_16')

# Save weights only (not the full model), and only when validation accuracy
# improves on the best value seen so far.
checkpoint_callback = ModelCheckpoint(
    checkpoint_filepath,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)

In [14]:
# Resume from a previous run: restore the checkpointed weights into the model
# when the weights file already exists; otherwise the model is left as built.
if os.path.exists(checkpoint_filepath):
    model.load_weights(checkpoint_filepath)

In [31]:
# Fit on the training generator and validate on the validation generator after
# every epoch; the checkpoint callback writes the best weights to disk.
history = model.fit(
    x=train_generator,
    validation_data=val_generator,
    epochs=5,  # MAX_EPOCHS
    callbacks=[checkpoint_callback],
    workers=8,
)

### evaluation

In [30]:
# Create a test generator from the augmentation-free image data generator.
#
# - shuffle=False keeps batches in dataframe order.
# - color_mode='rgb' yields 3-channel batches, matching the CHANNELS the model
#   input is declared with and the RGB loading done in predict_class
#   (the previous 'grayscale' setting produced 1-channel batches).
# - classes pins the label order to `categories`, so the one-hot label indices
#   of the test split line up with the indices the model was trained on.
test_generator = val_datagen.flow_from_dataframe(
    test_df,
    x_col=INPUT_COL,
    y_col=LABEL_COL,
    directory=DATA_DIR,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    color_mode='rgb',
    classes=[categories[i] for i in range(NUM_CLASSES)],
    class_mode='categorical',
    shuffle=False
)

Found 137682 validated image filenames belonging to 16 classes.


In [None]:
# Restore the checkpointed weights before evaluating. A missing checkpoint is
# reported rather than skipped silently, because the metrics below would then
# describe whatever weights the in-memory model currently holds.
if os.path.exists(checkpoint_filepath):
    model.load_weights(checkpoint_filepath)
else:
    print(f'WARNING: no checkpoint at {checkpoint_filepath}; evaluating the current in-memory weights')

In [32]:
# Score the model on the test generator. `scores` holds the loss followed by
# the compiled metric (accuracy).
scores = model.evaluate(test_generator, verbose=1)
test_loss, test_accuracy = scores[0], scores[1]
print(f'Test loss: {test_loss}')
print(f'Test accuracy: {test_accuracy}')

Test loss: 0.8278752565383911
Test accuracy: 0.7070423364639282


### prediction

In [None]:
# Restore the checkpointed weights before predicting. A missing checkpoint is
# reported rather than skipped silently, because the predictions below would
# then come from whatever weights the in-memory model currently holds.
if os.path.exists(checkpoint_filepath):
    model.load_weights(checkpoint_filepath)
else:
    print(f'WARNING: no checkpoint at {checkpoint_filepath}; predicting with the current in-memory weights')

In [24]:
def predict_class(file_path):
    """Return the model's class probabilities for a single image.

    Args:
        file_path: Image path relative to DATA_DIR.

    Returns:
        A NumPy array with a leading batch axis of size 1, holding the softmax
        output of `model` for the image. Use `np.argmax` together with
        `categories` to turn it into a phase name.
    """
    image_path = os.path.join(DATA_DIR, file_path)

    # load_img takes target_size as (height, width), the same order the data
    # generators use.
    img = load_img(image_path, target_size=(IMG_HEIGHT, IMG_WIDTH))

    # Convert to an array and apply the same preprocessing as the generators.
    img_arr = preprocess_input(img_to_array(img))

    # Add the batch axis the model expects.
    img_arr = np.expand_dims(img_arr, axis=0)

    # Call the model directly instead of model.predict: for one small batch
    # this avoids predict's per-call setup cost, which adds up when this
    # function is applied row by row. training=False keeps dropout and batch
    # normalisation in inference mode.
    prediction = model(img_arr, training=False).numpy()

    return prediction

In [41]:
# # Parallelize prediction
# import concurrent.futures
# from tqdm.notebook import tqdm_notebook

# with concurrent.futures.ThreadPoolExecutor() as executor:
#     predictions = list(tqdm_notebook(executor.map(predict_class, test_df.path.values), total=len(test_df)))
    
# test_df["prediction"] = predictions

In [26]:
# Run predict_class on every test image path, with a progress bar. Each cell of
# the new column holds the array returned for that row. The column is created
# by this assignment, so no placeholder column is needed beforehand.
test_df["prediction"] = test_df.path.progress_apply(predict_class)

Pandas Apply:   0%|          | 0/137682 [00:00<?, ?it/s]

In [27]:
def _most_likely_phase(probabilities):
    """Map a probability array to the name of its highest-scoring class."""
    return categories.get(np.argmax(probabilities))

test_df["predicted_phase"] = test_df["prediction"].progress_apply(_most_likely_phase)

Pandas Apply:   0%|          | 0/137682 [00:00<?, ?it/s]

In [28]:
# Persist the test rows together with their predictions, without the index.
# NOTE(review): the array-valued `prediction` column is written as text; the
# read-back shown further down displays it as strings such as "[[1.36e-04 ...".
test_df.to_csv("embryoAI_datasets/embryoAI_predictions.csv", index=False)

In [29]:
# file_path = "DS61-1/D2012.01.27_S0364_I132_WELL1_RUN20.jpeg"
# categories.get(np.argmax(predict_class(file_path)))

In [30]:
# Read the saved predictions back from disk and display them as a check that
# the file was written as expected.
predictions = pd.read_csv("embryoAI_datasets/embryoAI_predictions.csv")
predictions

Unnamed: 0,case,frame,phase,path,flag,prediction,predicted_phase
0,DS61-1,19,tPNa,DS61-1/D2012.01.27_S0364_I132_WELL1_RUN19.jpeg,growing,[[1.3657552e-04 1.6470891e-07 1.1463476e-07 9....,tPNa
1,DS61-1,20,tPNa,DS61-1/D2012.01.27_S0364_I132_WELL1_RUN20.jpeg,growing,[[6.9838177e-05 1.4113060e-07 5.1814305e-07 3....,tPB2
2,DS61-1,29,tPNa,DS61-1/D2012.01.27_S0364_I132_WELL1_RUN29.jpeg,growing,[[8.6491083e-05 1.6546151e-07 5.0405244e-08 2....,tPB2
3,DS61-1,30,tPNa,DS61-1/D2012.01.27_S0364_I132_WELL1_RUN30.jpeg,growing,[[8.4943909e-05 1.0639483e-07 3.0971755e-08 8....,tPNa
4,DS61-1,32,tPNa,DS61-1/D2012.01.27_S0364_I132_WELL1_RUN32.jpeg,growing,[[9.5896976e-05 1.5052959e-07 8.7054017e-09 3....,tPNa
...,...,...,...,...,...,...,...
137677,RC755-9,462,tEB,RC755-9/D2013.07.08_S0875_I132_WELL9_RUN462.jpeg,expanded,[[8.3686167e-08 1.3817188e-07 2.7560696e-07 3....,tEB
137678,RC755-9,467,tEB,RC755-9/D2013.07.08_S0875_I132_WELL9_RUN467.jpeg,expanded,[[4.2381889e-06 1.4916056e-06 1.2730972e-04 8....,tEB
137679,RC755-9,470,tEB,RC755-9/D2013.07.08_S0875_I132_WELL9_RUN470.jpeg,expanded,[[1.27709545e-05 2.23672464e-06 1.46593040e-04...,tEB
137680,RC755-9,472,tEB,RC755-9/D2013.07.08_S0875_I132_WELL9_RUN472.jpeg,expanded,[[1.3986474e-04 1.9262983e-05 8.8179752e-04 1....,tEB
