<a href="https://colab.research.google.com/github/arjvik/BEEHealthy/blob/master/TrainModels.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
# Install the data-versioning tooling into the environment the kernel runs in.
# --quiet trims the install log but still surfaces warnings and errors, so a
# failed install is visible instead of being discarded.
%pip install --quiet dvc fastds
# Fetch the project repository and work from inside its checkout.
!git clone https://dagshub.com/arjvik/BEEHealthy
%cd BEEHealthy
# Download the DVC-tracked metadata CSV and the image directory.
!dvc pull bee_data.csv
!dvc pull bee_imgs

In [5]:
import numpy as np
import pandas as pd
# Load the image metadata and derive the two columns used for cross-validation:
#   healthy - True only when the `health` label is exactly "healthy"
#   group   - numeric prefix of the file name (the text before the first "_")
df = pd.read_csv('bee_data.csv').assign(
    healthy=lambda frame: frame["health"].eq("healthy"),
    group=lambda frame: pd.to_numeric(frame["file"].str.split("_", n=1).str[0]),
)
df

Unnamed: 0,file,date,time,location,zip code,subspecies,health,pollen_carrying,caste,healthy,group
0,041_066.png,8/28/18,16:07,"Alvin, TX, USA",77511,-1,hive being robbed,False,worker,False,41
1,041_072.png,8/28/18,16:07,"Alvin, TX, USA",77511,-1,hive being robbed,False,worker,False,41
2,041_073.png,8/28/18,16:07,"Alvin, TX, USA",77511,-1,hive being robbed,False,worker,False,41
3,041_067.png,8/28/18,16:07,"Alvin, TX, USA",77511,-1,hive being robbed,False,worker,False,41
4,041_059.png,8/28/18,16:07,"Alvin, TX, USA",77511,-1,hive being robbed,False,worker,False,41
...,...,...,...,...,...,...,...,...,...,...,...
5167,027_011.png,8/20/18,10:03,"San Jose, CA, USA",95124,-1,healthy,True,worker,True,27
5168,027_007.png,8/20/18,10:03,"San Jose, CA, USA",95124,-1,healthy,True,worker,True,27
5169,027_013.png,8/20/18,10:03,"San Jose, CA, USA",95124,-1,healthy,False,worker,True,27
5170,027_012.png,8/20/18,10:03,"San Jose, CA, USA",95124,-1,healthy,False,worker,True,27


In [6]:
import tensorflow as tf
from sklearn.model_selection import StratifiedGroupKFold
from tensorflow.data import Dataset
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.resnet50 import preprocess_input

# Build one (train_ds, test_ds) pair of tf.data pipelines per cross-validation fold.
#
# Folds are stratified on `healthy` and grouped on `group`, so rows that share a
# group id never end up on both sides of a split.
#
# Each train_ds element is (images, labels, sample_weights); each test_ds element
# is (images, labels). The underlying Keras iterators cycle indefinitely, so
# consumers must bound iteration themselves (e.g. with steps_per_epoch).
fold_datasets = []

# Keras writes every augmented training image into save_to_dir; make sure the
# directory exists so the first batch cannot fail on a missing path.
tf.io.gfile.makedirs('generated_imgs/')

for train_idx, test_idx in StratifiedGroupKFold(n_splits=3).split(df.file, df.healthy, df.group):
    train_df = df.iloc[train_idx].copy()
    test_df = df.iloc[test_idx].copy()

    # Per-sample weight = (largest class count) / (count of the sample's own class).
    # The count is looked up by label value rather than by position in
    # value_counts() (which is ordered by frequency), so the rarer class is
    # up-weighted whichever label happens to be the majority.
    class_counts = train_df.healthy.value_counts()
    train_df["weight"] = class_counts.max() / train_df.healthy.map(class_counts)

    # Training images are augmented on the fly, then run through ResNet50 preprocessing.
    train_gen = ImageDataGenerator(
                    rotation_range=20, width_shift_range=20, height_shift_range=20,
                    brightness_range=(0.6, 1.4), shear_range=10, zoom_range=0.05, channel_shift_range = 30,
                    horizontal_flip=True, vertical_flip=True, preprocessing_function=preprocess_input
                ).flow_from_dataframe(
                    dataframe=train_df, directory='bee_imgs/', x_col='file', y_col=['healthy', 'weight'],
                    target_size=(224, 224), class_mode='raw',
                    save_to_dir='generated_imgs/'
                )
    # `gen=train_gen` binds THIS fold's iterator at definition time; a bare
    # `lambda: train_gen` would be resolved lazily and every fold would end up
    # reading the last fold's iterator.
    # Images are float32 (not bool), and the batch dimension is left as None
    # because the final batch of each pass is smaller than the batch size.
    train_ds = (Dataset.from_generator(lambda gen=train_gen: gen,
                                       output_types=(tf.float32, tf.float32),
                                       output_shapes=((None, 224, 224, 3), (None, 2)))
                       # Split the packed (label, weight) columns into separate tensors.
                       .map(lambda x, y: (x, y[:, 0], y[:, 1])))

    # Test images only get the ResNet50 preprocessing, no augmentation.
    test_gen  = ImageDataGenerator(
                    preprocessing_function=preprocess_input
                ).flow_from_dataframe(
                    dataframe=test_df, directory='bee_imgs/', x_col='file', y_col='healthy',
                    target_size=(224, 224), class_mode='raw'
                )
    test_ds = Dataset.from_generator(lambda gen=test_gen: gen,
                                     output_types=(tf.float32, tf.float32),
                                     output_shapes=((None, 224, 224, 3), (None,)))

    fold_datasets.append((train_ds, test_ds))

Found 3441 validated image filenames.
Found 1731 validated image filenames.
Found 3456 validated image filenames.
Found 1716 validated image filenames.
Found 3447 validated image filenames.
Found 1725 validated image filenames.
