## Environment setup

In [None]:
# Install necessary libraries
# Restart the runtime after downgrading tensorflow
!pip install -q xgboost
!pip install -q tensorflow==2.12 # Our NASNetLarge weight works with TensorFlow and Keras v2.

In [None]:
# Download and unzip required files and sample images from Google Drive
!pip install -q gdown
!gdown --id 15Y_R_rNcRym61EPp_jAKqHxkNuXEHX7y
!unzip -q Blastocyst_Prediction.zip

In [None]:
# Import necessary libraries
%matplotlib inline
import os, glob, cv2
from datetime import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from glob import glob
from PIL import Image
import tensorflow as tf
from tqdm.notebook import tqdm
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.nasnet import NASNetLarge
import xgboost

import sklearn
import os
import torch
from xgboost import XGBClassifier
import warnings
warnings.filterwarnings('ignore')

# Class index -> label name for the 17 annotation classes.
# Indices 2-16 are grouped three per cell stage, in the order G1, G2, G3.
annotation_label_dict = {
    0: '1cell',
    1: '1cell_PN',
    2: '2cell_G1',
    3: '2cell_G2',
    4: '2cell_G3',
    5: '3cell_G1',
    6: '3cell_G2',
    7: '3cell_G3',
    8: '4cell_G1',
    9: '4cell_G2',
    10: '4cell_G3',
    11: '5-7cell_G1',
    12: '5-7cell_G2',
    13: '5-7cell_G3',
    14: '8-cell_G1',
    15: '8-cell_G2',
    16: '8-cell_G3',
}

# Time grid from timelapse_min to timelapse_max (both included) in quarter-hour
# steps; 0.25 h is the 15-minute imaging interval expressed in hours.
timelapse_min = 24
timelapse_max = 64
time_list = np.arange(timelapse_min, timelapse_max + 0.25, 0.25)

def convert_to_wide_format(df_long,time_name,label_name):
    """Pivot a long table (one row per ID and time) into one row per ID.

    Parameters
    ----------
    df_long : pandas.DataFrame
        Must contain the columns 'ID', `time_name` and `label_name`.
        It is NOT modified (the cast to float is done on an internal copy).
    time_name : str
        Column whose distinct values become the columns of the result.
    label_name : str
        Column whose values (cast to float) fill the result.

    Returns
    -------
    pandas.DataFrame
        One row per ID with an 'ID' column followed by one column per distinct
        value of `time_name`. When an (ID, time) pair occurs more than once,
        only its first row is used.
        NOTE(review): the values pass through two transposes of a mixed-type
        frame, so the result columns are likely object dtype - confirm before
        relying on dtypes.
    """
    # Copy only the needed columns so the float cast below does not leak into
    # the caller's DataFrame.
    long_subset = df_long[['ID', time_name, label_name]].copy()
    long_subset[label_name] = long_subset[label_name].astype(float)
    long_subset = long_subset.drop_duplicates(['ID', time_name]).reset_index(drop=True)

    # unstack() yields two-level columns (label_name, time); reset_index() adds
    # ('ID', '') as the first of them.
    per_id = long_subset.set_index(['ID', time_name]).unstack().reset_index()

    # Transpose to drop the outer column level ('level_0'), transpose back, and
    # give the '' placeholder its real name 'ID'.
    df_wide = (
        per_id.T
        .reset_index()
        .drop(columns=['level_0'])
        .set_index(time_name)
        .T
        .rename(columns={'': 'ID'})
    )
    return df_wide

# Grade and cell-stage label groups.
# Class indices 2..16 are laid out three per cell stage (G1, G2, G3), so every
# grade list is an arithmetic sequence with step 3.
G1_label_list = list(range(2, 17, 3))
G2_label_list = list(range(3, 17, 3))
G3_label_list = list(range(4, 17, 3))
# All graded labels, grade by grade.
cell_2_8_list = [*G1_label_list, *G2_label_list, *G3_label_list]
# Labels of the two 1-cell classes.
cell_1_list = [0, 1]
# Per-stage label triples [G1, G2, G3], taken position by position from the grade lists.
cell_2_list, cell_3_list, cell_4_list, cell_5_list, cell_8_list = (
    list(stage_labels)
    for stage_labels in zip(G1_label_list, G2_label_list, G3_label_list)
)

def process_df_data(df_data,timelapse_min=24, timelapse_max=64, interval=15):
    """Add stage/grade ratio features and unify the grade within each cell stage.

    Parameters
    ----------
    df_data : pandas.DataFrame
        Wide table with one row per embryo and one column per time point
        holding the annotation class index. It is modified in place and also
        returned.
    timelapse_min, timelapse_max : number
        Bounds of the time grid, in the same unit as the time column labels.
        The upper bound is EXCLUSIVE here, so the `timelapse_max` column itself
        is neither counted nor rewritten.
        NOTE(review): the module-level `time_list` includes the upper bound;
        the difference is kept as-is because downstream code reads the raw
        `timelapse_max` column - confirm this is intended.
    interval : number
        Step of the time grid in minutes (divided by 60 to get hours).

    Returns
    -------
    pandas.DataFrame
        `df_data` with the added columns ratio_1cell, ratio_2cell, ratio_3cell,
        ratio_4cell, ratio_5cell, ratio_8cell (share of non-missing time points
        in that stage) and ratio_G1, ratio_G2, ratio_G3 (share of graded time
        points with that grade; NaN when a row has no graded time point), and
        with the labels inside each cell stage replaced by a single label.

    Raises
    ------
    KeyError
        If a time point of the grid is missing from the columns of `df_data`.
    """
    time_list = np.arange(timelapse_min, timelapse_max, interval/60.0)

    # Snapshot of the label columns; every count below uses the labels as they
    # were before any in-stage unification.
    labels = df_data[time_list]
    n_annotated = labels.notna().sum(axis=1)
    n_graded = labels.isin(cell_2_8_list).sum(axis=1)

    stage_table = [
        ('2cell', cell_2_list),
        ('3cell', cell_3_list),
        ('4cell', cell_4_list),
        ('5cell', cell_5_list),
        ('8cell', cell_8_list),
    ]
    grade_table = [
        ('G1', G1_label_list),
        ('G2', G2_label_list),
        ('G3', G3_label_list),
    ]

    # Share of annotated time points spent in each cell stage.
    df_data['ratio_1cell'] = labels.isin(cell_1_list).sum(axis=1)/n_annotated
    for stage_name, stage_labels in stage_table:
        df_data[f'ratio_{stage_name}'] = labels.isin(stage_labels).sum(axis=1)/n_annotated

    # Share of graded (2- to 8-cell) time points carrying each grade.
    for grade_name, grade_labels in grade_table:
        df_data[f'ratio_{grade_name}'] = labels.isin(grade_labels).sum(axis=1)/n_graded

    # Grade at each cell stage is substituted and unified.
    for _, stage_labels in stage_table:
        # Time points per grade within this stage, shape (n_rows, 3). The
        # dominant grade is the argmax of these counts: dividing by the stage
        # total would not change the ordering, and ties (including rows that
        # never reach the stage) resolve to the first grade, G1.
        grade_counts = np.stack(
            [labels.isin([label]).sum(axis=1).values for label in stage_labels],
            axis=1,
        )
        dominant = np.argmax(grade_counts, axis=1)

        n_grades = len(stage_labels)
        for grade_idx in range(n_grades):
            rows = dominant == grade_idx
            # The coding is reversed: dominant G1 is written as the stage's
            # last label and dominant G3 as its first.
            # NOTE(review): presumably this matches the encoding the prediction
            # models were trained on - confirm before changing.
            unified_label = stage_labels[n_grades - 1 - grade_idx]
            df_data.loc[rows, time_list] = df_data.loc[rows, time_list].replace(
                {label: unified_label for label in stage_labels}
            )
    return df_data

## Preprocess

In [None]:
# Preprocess the Images
# Each image is converted to grayscale, the embryo is located by fitting an
# ellipse to an edge contour, and a square crop around it is resized, masked
# with a circle and written to `preprocess_image_dir`.

# File Paths Configuration
# The CSV file includes the following columns: 'ID', 'age' (the age at egg retrieval), 'time(hpi)' (post-fertilization hour), and 'img_path' (original image file path).
csv_path = '/content/sample_img_list.csv'
original_image_dir = '/content/sample_images'
preprocess_image_dir = '/content/sample_images_preprocessed'

# Settings : Adjust to the time-lapse images for each incubator.

extend_r = 15  # Pixels to extend the cropping area from the recognized area
export_img_size = 331  # Exported image size (square)
min_recognition_size = 150  # Minimum size of recognized objects
max_recognition_size = 600  # Maximum size of recognized objects
setting_list = [[10, 50, 50],[5, 50, 50], [4, 40, 40], [7, 50, 50]]  # Contour detection settings [kernel_size, canny_th1, canny_th2]
level_adjustment = False  # Enable level_adjustment (set it to True) when the image is dark.

# Start time for benchmarking
start = datetime.now()

# Create directories for saving processed images
os.makedirs(preprocess_image_dir, exist_ok=True)

# Load CSV
df_process = pd.read_csv(csv_path)
# regex=False: the directory is a literal path prefix, not a regular expression.
df_process['preprocess_img_path'] = df_process['img_path'].str.replace(original_image_dir, preprocess_image_dir, regex=False)

# Execution status per CSV row: 1 = preprocessed image written, 0 = failed
exe = []

for i, row in df_process.iterrows():
    path = row['img_path']
    save_path = row['preprocess_img_path']
    img = cv2.imread(path, cv2.IMREAD_COLOR)

    # cv2.imread returns None when the file cannot be read
    if img is None:
        exe.append(0)
        continue

    # Preprocess image: grayscale conversion and level adjustment
    img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    if level_adjustment:
        # Shift brightness so that the 98th percentile reaches 255
        add_grayscale = 255 - np.percentile(img_gray, 98)
        img_gray = np.clip(img_gray + add_grayscale, 0, 255).astype(np.uint8)

    # Contour detection and processing: try each setting until one yields a
    # contour whose fitted ellipse has an acceptable size
    radius, x, y = None, None, None
    for kernel_size, canny_th1, canny_th2 in setting_list:
        # Canny edge detection
        edges = cv2.Canny(img_gray, canny_th1, canny_th2)
        # Morphological operations
        kernel = np.ones((kernel_size, kernel_size), np.uint8)
        processed = cv2.morphologyEx(edges, cv2.MORPH_CLOSE, kernel)
        processed = cv2.morphologyEx(processed, cv2.MORPH_OPEN, kernel)

        # Extract contours
        contours, _ = cv2.findContours(processed, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            continue

        # Sort contours by area and find the largest valid contour
        contours = sorted(contours, key=cv2.contourArea, reverse=True)
        for cnt in contours:
            if len(cnt) > 5:
                (x, y), (width, height), _ = cv2.fitEllipse(cnt)
                # Half of the longer ellipse axis is used as the radius
                radius = max(width, height) / 2
                if min_recognition_size <= radius <= max_recognition_size:
                    break
                else:
                    radius = None
            else:
                radius = None
        if radius is not None:
            break

    # No setting produced a usable contour
    if radius is None:
        exe.append(0)
        continue

    # Extend the cropping radius
    radius += extend_r
    h, w = img_gray.shape
    center = (int(x), int(y))
    export_size = int(radius * 2)

    # Ensure cropping area is within image bounds
    if (center[1] - radius < 0 or center[1] + radius > h or
        center[0] - radius < 0 or center[0] + radius > w):
        exe.append(0)
        continue

    # Crop and resize the image
    cropped_img = img_gray[
        max(0, center[1] - export_size // 2):min(h, center[1] + export_size // 2),
        max(0, center[0] - export_size // 2):min(w, center[0] + export_size // 2)
    ]
    resized_img = cv2.resize(cropped_img, (export_img_size, export_img_size))

    # Create a circular mask
    mask = np.zeros((export_img_size, export_img_size), dtype=np.uint8)
    cv2.circle(mask, (export_img_size // 2, export_img_size // 2), export_img_size // 2, 255, -1)
    masked_img = cv2.bitwise_and(resized_img, resized_img, mask=mask)

    # Save the processed image
    try:
        # img_path may point into sub-folders that do not exist yet on the output side
        os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
        # cv2.imwrite reports an unwritable path by returning False, not by
        # raising, so its return value has to be checked
        saved = cv2.imwrite(save_path, masked_img)
    except Exception as e:
        saved = False
        print(f"Error saving image {save_path}: {e}")
    else:
        if not saved:
            print(f"Error saving image {save_path}: cv2.imwrite returned False")
    exe.append(1 if saved else 0)

# Record the end time and calculate processing time
end = datetime.now()
processing_time = end - start

# Update the DataFrame with execution status
df_process['execution'] = exe
df_process.loc[df_process['execution'] == 0, 'preprocess_img_path'] = np.nan
success_rate = np.mean(exe) * 100
print(f"Processed {len(df_process)} images with a success rate of {success_rate:.1f}%.")
print(f"Processing time: {processing_time}")

# Display the DataFrame
df_process.head()

In [None]:
# Check the Preprocessed Images
# Shows every row of df_process in grids of rows x cols images. Rows whose
# image cannot be opened (e.g. preprocessing failed and the path is NaN) are
# reported and drawn as an 'Error' placeholder.

# Setting: Number of images per figure
images_per_figure = 48
rows, cols = 6, 8

# `chunk_start` (not `start`/`end`) so the timing variables of the
# preprocessing cell are not overwritten.
for chunk_start in range(0, len(df_process), images_per_figure):
    chunk = df_process.iloc[chunk_start:chunk_start + images_per_figure]
    fig, axes = plt.subplots(rows, cols, figsize=(20, 15))
    axes = axes.flatten()
    for ax, img_path in zip(axes, chunk['preprocess_img_path']):
        try:
            img = Image.open(img_path).convert('RGB')
            ax.imshow(img)
        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            ax.text(0.5, 0.5, 'Error', ha='center', va='center', fontsize=12)
    # Hide the frame of every panel, including the unused ones of the last figure
    for ax in axes:
        ax.axis('off')
    plt.tight_layout()
    plt.show()
    # Release the figure so memory does not grow with the number of chunks
    plt.close(fig)

## Embryo Image Auto-Annotation

In [None]:
# Embryo Image Auto-Annotation Using Fine-Tuned NASNet-A Large

# Exclude unprocessed images
df_process = df_process.dropna(subset=['preprocess_img_path']).reset_index(drop=True)

# Make dataset from DataFrame (image file paths)
# shuffle=False keeps the predictions in the row order of df_process.
# class_mode='input' is used because there are no labels at prediction time.
# The deprecated `drop_duplicates` argument is no longer passed; the remaining
# omitted arguments were equal to the Keras defaults.
pred_dataset = ImageDataGenerator(rescale=1.0/255).flow_from_dataframe(
    df_process,
    directory='',
    x_col='preprocess_img_path',
    target_size=(331, 331),  # same size as the preprocessed images (export_img_size = 331)
    color_mode='rgb',
    class_mode='input',
    batch_size=32,
    shuffle=False,
)

# Load the Model and Fine-Tuned Weight
model_path = '/content/NASNet_embryo_17class'
model = tf.keras.models.load_model(model_path)

# Auto-Annotation Using Fine-Tuned NASNet-A Large
# `features` holds one row of class scores per image; the annotation is the argmax.
features = model.predict(pred_dataset)

# The generator skips files it cannot validate, which would silently shift
# predictions against the rows of df_process; fail loudly instead.
if len(features) != len(df_process):
    raise ValueError(
        f"Got {len(features)} predictions for {len(df_process)} images; "
        "some preprocessed image files are missing or unreadable."
    )

df_process['annotation'] = np.argmax(features,axis=1)
df_process['annotation_label'] = df_process['annotation'].map(annotation_label_dict)
display(df_process)

## Blastocyst Prediction

In [None]:
# Dataset Creation for XGBoost Predictions
# One row per embryo: ID, age and one annotation column per time point,
# extended with the ratio features computed by process_df_data.
df_data = pd.merge(
    df_process[['ID','age']].drop_duplicates(),
    convert_to_wide_format(df_process,'time(hpi)','annotation'),
    on='ID',  # the only column the two frames share
)
df_data = process_df_data(df_data)
display(df_data)

# Loading configurations for prediction models
# NOTE(review): unpickling executes arbitrary code; only load df_config.pkl
# from the trusted archive downloaded above.
df_config = pd.read_pickle('/content/df_config.pkl')

# Prediction with XGBoost
# .copy() so the prediction columns are added to an independent frame and not
# to a slice of df_data (avoids SettingWithCopyWarning).
df_result = df_data[['ID','age',64.0]].copy()
for target_label, list_features in zip(df_config['target_label'].values, df_config['list_var'].values):
    # Feature names that look like numbers refer to time-point columns, whose
    # labels are floats. Assumes every entry is a string - TODO confirm.
    list_features = [float(x) if x.replace('.', '', 1).isdigit() else x for x in list_features]
    xgb = XGBClassifier()
    xgb.load_model(f'/content/xgb_{target_label}.model')
    # Column 1 of predict_proba is the probability of the positive class
    df_result[target_label] = xgb.predict_proba(df_data[list_features].astype(float))[:,1]
df_result['annotation at 64hpi'] = df_result[64.0].map(annotation_label_dict)
df_result  = df_result[['ID','age','annotation at 64hpi']+df_config['target_label'].values.tolist()]
display(df_result)