# Create train, validation, and test sets for frame classification task
- use frames from DCC videos
- use frames from VMS farm 
- sets must have different cows
- all collection date-computers are represented across sets

In [1]:
import os
import pandas as pd
import numpy as np
import shutil
from sklearn.model_selection import StratifiedGroupKFold

# settings
# number of folds the cows of each collection group are divided into
splits = 5
# maximum number of frames kept per cow and frame class
frame_cap = 200
# seed used for every random shuffle
seed = 5

# directories: the data folders sit next to the current working directory
dirpath = os.getcwd()
parent_dir = os.path.normpath(dirpath + os.sep + os.pardir)
lbl_dir = os.path.join(parent_dir, "udder_labels")
img_dir = os.path.join(parent_dir, "udder_video")
old_dir = os.path.join(parent_dir, "udder_dcc")

# data collection groups: collection date -> computer -> group label
group_dict = {
    20210625: {"lab": "A"},
    20211022: {"lab": "B"},
    20231117: {"guilherme": "C", "maria": "D"},
}

# columns kept from both collections
frame_cols = ["cow", "frame_class", "filename", "date", "computer"]


def _video_key(bag_name):
    """Turn a metadata file name into the key used to match frame labels.

    The first "_"-separated field is converted to int and back to str (which
    removes leading zeros) and the ".bag" text is removed.
    """
    head, *rest = bag_name.split("_")
    return "_".join([str(int(head))] + rest).replace(".bag", "")


# reading and formatting data
# new collection is at the VMS farm
class_df = pd.read_csv(os.path.join(lbl_dir, "frame_class_list.csv"))
# the video key is the frame file name without its last two "_" fields
class_df["video_file"] = ["_".join(name.split("_")[:-2]) for name in class_df.filename]
# path components are joined separately so the path also resolves on
# non-Windows systems (a literal "\" is not a separator there)
metadata_df = pd.read_csv(
    os.path.join(img_dir, "video_metadata", "video_metadata_20231117.csv")
)
metadata_df["video_file"] = [_video_key(name) for name in metadata_df.filename]
metadata_df = metadata_df.drop(["time", "size", "filename"], axis=1).drop_duplicates()
# left merge: every labelled frame is kept even when no metadata row matches
new_df = pd.merge(class_df, metadata_df, on=["cow", "video_file"], how="left")
new_df = new_df[frame_cols]

# old collection is at dcc
old_frames_df = pd.read_csv(os.path.join(old_dir, "class_filenames.csv"))
old_frames_df["filename"] = [name.replace(".tif", "") for name in old_frames_df.filename]
old_frames_df["date"] = [int(rd) for rd in old_frames_df["round"]]
old_frames_df["computer"] = "lab"
old_frames_df = old_frames_df.rename({"cowID": "cow"}, axis=1)
# "good" -> 1, "empty" -> 0; any other category gets the placeholder 3 and is
# filtered out on the next line
class_codes = {"good": 1, "empty": 0}
old_frames_df["frame_class"] = [class_codes.get(cat, 3) for cat in old_frames_df.category]
old_df = old_frames_df[old_frames_df.frame_class != 3][frame_cols]

# merge the info from both collections
allframes_df = pd.concat([new_df, old_df], axis=0, ignore_index=True)

## Group and count how many frames and cows are in each collection date-computer

In [2]:
# Count frames per cow inside each collection date-computer.
# "count" is the number of frames of the cow and "sum" adds up frame_class
# (the number of class-1 frames, assuming frame_class is coded 0/1).
count_cols = ["date", "computer", "cow", "frame_class"]
cow_keys = ["date", "computer", "cow"]
grouped_df = allframes_df[count_cols].groupby(cow_keys).agg(["count", "sum"]).reset_index()

# flatten the two-level column index: key columns keep their own name,
# aggregated columns become "<column>_<statistic>"
flat_names = []
for top, sub in grouped_df.columns:
    flat_names.append(f"{top}_{sub}" if len(sub) > 1 else top)
grouped_df.columns = flat_names

# frames that are not counted in the sum
grouped_df["frame_class_diff"] = grouped_df["frame_class_count"] - grouped_df["frame_class_sum"]

# summary statistics over cows for each collection date-computer
summary_stats = ["count", "min", "max", "median", "mean"]
summary_cols = ["date", "computer", "frame_class_sum", "frame_class_diff"]
grouped_df[summary_cols].groupby(["date", "computer"]).agg(summary_stats)

Unnamed: 0_level_0,Unnamed: 1_level_0,frame_class_sum,frame_class_sum,frame_class_sum,frame_class_sum,frame_class_sum,frame_class_diff,frame_class_diff,frame_class_diff,frame_class_diff,frame_class_diff
Unnamed: 0_level_1,Unnamed: 1_level_1,count,min,max,median,mean,count,min,max,median,mean
date,computer,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2
20210625,lab,29,5,258,96.0,104.517241,29,0,891,0.0,56.37931
20211022,lab,34,44,282,133.0,127.794118,34,0,391,0.0,12.794118
20231117,guilherme,25,160,799,559.0,508.24,25,0,631,66.0,144.44
20231117,maria,25,43,1376,500.0,541.36,25,0,1405,136.0,292.0


## Select a maximum of n frames from each cow-frame_class

In [3]:
# Keep at most `frame_cap` randomly chosen frames per cow and frame class.
# Result: filtered_frames_df with one row per selected frame and the columns
# cow, group, frame_class, date, computer, filename.
filtered_columns = ["cow", "group", "frame_class", "date", "computer", "filename"]
selected_parts = []

# Each cow is visited once, using its first date-computer row in grouped_df.
# Looping over grouped_df.cow directly would process a cow again for every
# additional date-computer row it has and append its frames more than once.
for cow_row in grouped_df.drop_duplicates(subset="cow").itertuples(index=False):
    cow, date, computer = cow_row.cow, cow_row.date, cow_row.computer
    group = group_dict[date][computer]
    for frame_class in [0, 1]:
        is_cow_class = (allframes_df.cow == cow) & (allframes_df.frame_class == frame_class)
        cow_frames = list(allframes_df[is_cow_class].filename)
        num = len(cow_frames)
        if num == 0:
            continue

        # re-seeding before every shuffle makes the selection for a given
        # cow/class depend only on its own frame list and the seed
        np.random.seed(seed)
        np.random.shuffle(cow_frames)
        # slicing covers both cases: lists shorter than the cap are kept whole
        selected_frames = cow_frames[:frame_cap]

        selected_parts.append(
            pd.DataFrame(
                {
                    "cow": cow,
                    "group": group,
                    "frame_class": frame_class,
                    "date": date,
                    "computer": computer,
                    "filename": selected_frames,
                }
            )
        )

# single concatenation at the end instead of growing a frame inside the loop;
# fall back to an empty frame with the expected columns when nothing was kept
if selected_parts:
    filtered_frames_df = pd.concat(selected_parts, axis=0, ignore_index=True)
else:
    filtered_frames_df = pd.DataFrame(columns=filtered_columns)
            

## Create n fold groups with similar frame_class proportions

In [4]:
# Assign every cow to one of `splits` folds inside its collection group, then
# map whole folds to the train / val / test sets and save the result.
fold_parts = []
for group in ["A", "B", "C", "D"]:
    group_frames = filtered_frames_df[filtered_frames_df.group == group]
    x = np.array(group_frames["filename"])
    # going through list() yields plain typed arrays even when the columns
    # are stored with object dtype
    group_cows = np.array(list(group_frames["cow"]))
    group_class = np.array(list(group_frames["frame_class"]))

    # the number of folds follows the `splits` setting so that it always
    # matches the fold -> set mapping built below
    sgkf = StratifiedGroupKFold(n_splits=splits)
    test_list = []
    fold_list = []
    # cows are the groups, so the test part of fold i holds the cows of fold i
    for i, (_, test_index) in enumerate(sgkf.split(x, group_class, group_cows)):
        test_cows = np.unique(group_cows[test_index])
        test_list.extend(list(test_cows))
        fold_list.extend([i] * len(test_cows))
    fold_parts.append(pd.DataFrame({"cow": test_list, "fold_group": fold_list}))

fold_df = pd.concat(fold_parts, axis=0, ignore_index=True)

merged = pd.merge(filtered_frames_df, fold_df, on="cow")

# shuffle the fold ids so that the folds given to each set are random
split_groups = list(range(splits))
np.random.seed(seed)
np.random.shuffle(split_groups)

# 60 / 20 / 20 division of the folds; val and test are rounded and train
# takes the remainder, so the three sizes always add up to `splits`
n_val = n_test = int(round(splits * 0.2))
n_train = splits - n_val - n_test
if n_train < 1 or n_val < 1:
    raise ValueError(
        f"splits={splits} is too small to give train, val and test at least one fold each"
    )
proportions = np.array([n_train, n_val, n_test])
set_list = ["train"] * n_train + ["val"] * n_val + ["test"] * n_test
set_df = pd.DataFrame({"fold_group": split_groups, "set_name": set_list})

merged2 = pd.merge(merged, set_df, on="fold_group")
merged2.to_csv("frameclass_sets.csv", index=False)

In [5]:
# Per fold and collection group: sum of frame_class and number of frames.
fold_keys = ["fold_group", "group"]
fold_stats = merged[fold_keys + ["frame_class"]].groupby(fold_keys).agg(["sum", "count"])
merged_group = fold_stats.reset_index()

# flatten the two-level column index into "<column>_<statistic>" names,
# leaving the key columns with their own name
flat_names = []
for top, sub in merged_group.columns:
    flat_names.append(f"{top}_{sub}" if len(sub) > 1 else top)
merged_group.columns = flat_names

# show the folds of each group together
merged_group.sort_values(by=["group", "fold_group"]).reset_index(drop=True)

Unnamed: 0,fold_group,group,frame_class_sum,frame_class_count
0,0,A,468,668
1,1,A,559,659
2,2,A,643,795
3,3,A,652,852
4,4,A,651,890
5,0,B,870,899
6,1,B,875,1075
7,2,B,832,832
8,3,B,832,847
9,4,B,854,854


In [6]:
# Per set: sum of frame_class, number of frames, and their ratio.
set_stats = merged2[["set_name", "frame_class"]].groupby("set_name").agg(["sum", "count"])
merged2_group = set_stats.reset_index()

# flatten the two-level column index into "<column>_<statistic>" names,
# leaving the key column with its own name
flat_names = []
for top, sub in merged2_group.columns:
    flat_names.append(f"{top}_{sub}" if len(sub) > 1 else top)
merged2_group.columns = flat_names

# frame_class sum divided by the number of frames in the set
merged2_group["prop"] = merged2_group["frame_class_sum"].div(merged2_group["frame_class_count"])
merged2_group

Unnamed: 0,set_name,frame_class_sum,frame_class_count,prop
0,test,3444,4778,0.720804
1,train,10120,14164,0.714487
2,val,3475,4687,0.741412


## verify cows and collection groups

In [7]:
# verify cows are not repeated across sets
train_cows = set(merged2[merged2.set_name == "train"]["cow"].values)
test_cows = set(merged2[merged2.set_name == "test"]["cow"].values)
val_cows = set(merged2[merged2.set_name == "val"]["cow"].values)
# Each pair of sets is checked separately: a three-way intersection is empty
# as soon as ONE set lacks the cow, so it would miss a cow shared by only two
# sets. Every line printed here must be an empty set.
print(train_cows.intersection(val_cows))
print(train_cows.intersection(test_cows))
print(val_cows.intersection(test_cows))

# verify all 4 collection groups are present in every set
train_groups = np.unique(merged2[merged2.set_name == "train"]["group"])
test_groups = np.unique(merged2[merged2.set_name == "test"]["group"])
# taken from the "val" rows (it was previously read from the "test" rows,
# which left the validation set unchecked)
val_groups = np.unique(merged2[merged2.set_name == "val"]["group"])
print(train_groups)
print(test_groups)
print(val_groups)

set()
set()
set()
['A' 'B' 'C' 'D']
['A' 'B' 'C' 'D']
['A' 'B' 'C' 'D']
