# Create train, validation, and test sets for the udder segmentation task
- use frames from DCC videos
- use frames from the VMS farm
- sets must have different cows
- all collection date-computers are represented across sets

In [65]:
import os
import pandas as pd
import numpy as np
import shutil
from sklearn.model_selection import StratifiedGroupKFold

# settings
# number of cow folds that are later mapped onto train/val/test
splits = 5
# frame cap (not referenced in the cells shown here)
frame_cap = 200
# seed used before every NumPy shuffle so the split is reproducible
seed = 5

# directories
# The data folders are siblings of the current working directory. Paths are
# assembled component by component so they do not depend on the Windows "\"
# separator; on Windows the resulting strings are unchanged.
dirpath = os.getcwd()
parent_dir = os.path.normpath(os.path.join(dirpath, os.pardir))
newlbl_dir = os.path.join(parent_dir, "udder_labels", "labels", "segments")
oldlbl_dir = os.path.join(parent_dir, "udder_dcc", "labels", "segments")
img_dir = os.path.join(parent_dir, "udder_video")


# data collection groups: collection date -> computer -> group letter
group_dict = {
    20210625: {"lab": "A"},
    20211022: {"lab": "B"},
    20231117: {"guilherme": "C", "maria": "D"},
}

# reading and formatting data
# list segment labels; only ".txt" label files are parsed so that stray files
# in the folders cannot end up as rows with malformed names
new_label_files = [file for file in os.listdir(newlbl_dir) if file.endswith(".txt")]
old_label_files = [file for file in os.listdir(oldlbl_dir) if file.endswith(".txt")]

# new collection: the video name is the label name without its last two
# "_"-separated fields
labels_new = pd.DataFrame(
    [[os.path.splitext(file)[0], "_".join(file.split("_")[:-2])] for file in new_label_files],
    columns=["filename", "video_file"],
)
# old collection: date and cow are the first two "_"-separated fields of the
# label name; every frame of this collection is attributed to computer "lab"
old_df = pd.DataFrame(
    [[file.split("_")[1], os.path.splitext(file)[0], file.split("_")[0]] for file in old_label_files],
    columns=["cow", "filename", "date"],
)
old_df["computer"] = "lab"

metadata_df = pd.read_csv(os.path.join(img_dir, "video_metadata", "video_metadata_20231117.csv"))
# Rebuild the video name from the metadata filename: the first field goes
# through int() (dropping any leading zeros) and ".bag" is removed --
# presumably so it matches `video_file` derived from the label names.
metadata_df["video_file"] = [
    "_".join([str(int(file.split("_")[0]))] + file.split("_")[1:]).replace(".bag", "")
    for file in metadata_df.filename
]
metadata_df = metadata_df.drop(["time", "size", "filename"], axis=1).drop_duplicates()
# Left merge keeps every label; a label without a metadata row gets NaN in
# cow/date/computer, which makes the int() in the group lookup below fail.
new_df = pd.merge(labels_new, metadata_df, on=["video_file"], how="left")
new_df = new_df[["cow", "filename", "date", "computer"]]

# merge the info from both collections
allframes_df = pd.concat([new_df, old_df], axis=0, ignore_index=True)
# map each (date, computer) pair to its group letter
allframes_df["group"] = [group_dict[int(x)][y] for x, y in allframes_df[["date", "computer"]].to_numpy()]

## Group and count how many labels and cows are in each collection date-computer

In [67]:
# Labelled frames per cow within each collection date-computer.
grouped_df = (
    allframes_df.groupby(["date", "computer", "cow"])["filename"]
    .count()
    .reset_index(name="filename_count")
)
# Per date-computer summary: number of cows and total labelled frames.
grouped_df.groupby(["date", "computer"]).agg({"cow": "count", "filename_count": "sum"})

Unnamed: 0_level_0,Unnamed: 1_level_0,cow,filename_count
date,computer,Unnamed: 2_level_1,Unnamed: 3_level_1
20231117,guilherme,25,749
20231117,maria,25,750
20210625,lab,29,969
20211022,lab,34,1190


## split cows into n groups

In [120]:
# Assign every cow to one of `splits` folds, separately within each collection
# group, so that each fold gets a near-equal number of cows from every group.
fold_frames = []
for group in ["A", "B", "C", "D"]:
    cows = np.unique(np.array(allframes_df[allframes_df.group == group]["cow"]))
    # balanced fold labels: 0..splits-1 repeated, truncated to one per cow
    y = list(range(splits)) * int(np.ceil(len(cows) / splits))
    fold_list = y[:len(cows)]
    # Re-seeding for every group keeps the assignment reproducible; groups
    # with the same number of cows therefore receive the same permutation.
    np.random.seed(seed)
    np.random.shuffle(fold_list)
    temp_df = pd.DataFrame({"cow": cows, "fold_group": fold_list})
    fold_frames.append(temp_df)

# Concatenate once instead of growing an empty placeholder DataFrame inside
# the loop, so the column dtypes come from the data only.
fold_df = pd.concat(fold_frames, axis=0, ignore_index=True)

In [124]:
# Attach the fold of each cow to every labelled frame.
# NOTE(review): the join key is "cow" alone, which assumes each cow id occurs
# once in fold_df (i.e. no cow id is shared between groups) -- TODO confirm.
merged = pd.merge(allframes_df, fold_df, on="cow")

# Shuffle the fold ids so the folds mapped to train/val/test are random.
split_groups = list(range(splits))
np.random.seed(seed)
np.random.shuffle(split_groups)

# Target split is 60/20/20. Validation and test sizes are rounded (at least
# one fold each) and train takes the remainder, so the three sizes always add
# up to `splits`; truncating 60/20/20 only did so for multiples of 5.
if splits < 3:
    raise ValueError("splits must be at least 3 to build train, val and test sets")
n_val = max(1, int(round(splits * 0.2)))
n_test = max(1, int(round(splits * 0.2)))
n_train = splits - n_val - n_test
proportions = np.array([n_train, n_val, n_test])
set_list = ["train"] * n_train + ["val"] * n_val + ["test"] * n_test
set_df = pd.DataFrame({"fold_group": split_groups, "set_name": set_list})

# Label every frame with its set and save the assignment.
merged2 = pd.merge(merged, set_df, on="fold_group")
merged2.to_csv("segment_sets.csv", index=False)

In [125]:
# Rows of `merged` (labelled frames) per fold within each collection group;
# the count is taken over the "cow" column, hence the name cow_count.
merged_group = (
    merged.groupby(["fold_group", "group"])["cow"]
    .count()
    .reset_index(name="cow_count")
)
merged_group.sort_values(by=["group", "fold_group"]).reset_index(drop=True)

Unnamed: 0,fold_group,group,cow_count
0,0,A,202
1,1,A,210
2,2,A,210
3,3,A,180
4,4,A,167
5,0,B,245
6,1,B,245
7,2,B,245
8,3,B,245
9,4,B,210


## verify cows and collection groups

In [126]:
# verify cows are not repeated across sets
train_cows = set(merged2[merged2.set_name == "train"]["cow"].values)
test_cows = set(merged2[merged2.set_name == "test"]["cow"].values)
val_cows = set(merged2[merged2.set_name == "val"]["cow"].values)
# Check each pair of sets separately: a.intersection(b, c) only reports cows
# present in all three sets and would miss a cow shared by just two of them.
print(val_cows.intersection(test_cows))
print(train_cows.intersection(val_cows))
print(test_cows.intersection(train_cows))

# verify all 4 collection groups are represented in every set
train_groups = np.unique(merged2[merged2.set_name == "train"]["group"])
test_groups = np.unique(merged2[merged2.set_name == "test"]["group"])
val_groups = np.unique(merged2[merged2.set_name == "val"]["group"])
print(train_groups)
print(test_groups)
print(val_groups)

set()
set()
set()
['A' 'B' 'C' 'D']
['A' 'B' 'C' 'D']
['A' 'B' 'C' 'D']
