In [1]:
import os
import shutil
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold
from collections import defaultdict
from pathlib import Path

from python import classes

In [2]:
# Default settings for this notebook; a later cell rebinds `classification`.

# Where the per-class image folders are read from, and the base folder the
# cross-validation folds are written under.
input_dir = "crop"
output_dir = "data"

# Number of folds, and which label mapping to use ('models' or 'types').
cv_folds = 5
classification = ""

In [3]:
# Parameters
# Overrides the empty default assigned to `classification` in the settings cell.
# NOTE(review): this cell looks tool-injected (parameterised run) — confirm
# before editing it by hand.
classification = "models"


In [4]:
# Stop here unless the selected classification is one of the supported values.
assert classification in ('models', 'types'), (
    "classification must be one of ['models', 'types']"
)

In [5]:
# Nest the results in a sub-folder named after the selected classification.
# NOTE: `output_dir` is rebound here, so running this cell a second time on the
# same kernel appends the classification once more.
output_dir = Path(output_dir, classification)

In [6]:
# Gather file paths for each class.
#
# Directory listings are sorted because os.listdir() returns entries in
# arbitrary order; without a stable order the seeded split performed later in
# this notebook would assign files to different folds from one machine (or
# run) to the next.
data = defaultdict(list)
for model_dir in sorted(os.listdir(input_dir)):
    class_path = os.path.join(input_dir, model_dir)
    if not os.path.isdir(class_path):
        continue  # loose files next to the class folders are not classes

    # Keep regular files only: a nested folder would otherwise be recorded as
    # an image and make the later shutil.copy call fail.
    img_paths = [
        os.path.join(class_path, img_file)
        for img_file in sorted(os.listdir(class_path))
    ]
    img_paths = [path for path in img_paths if os.path.isfile(path)]
    if not img_paths:
        continue  # nothing to register, so the label lookup is skipped too

    # One lookup per folder (raises KeyError when the folder has no entry in
    # the mapping). `data` is keyed by the mapped label, so folders that share
    # a label accumulate in the same list.
    key = classes.classes_map[classification][model_dir]
    data[key].extend(img_paths)

# One row per label with its file count; the label becomes the index so later
# cells can address rows by label.
df = pd.DataFrame(
    [(label, len(paths)) for label, paths in data.items()],
    columns=['model', 'total'],
).set_index('model')

print(f"Classes found: {len(df)}")

Classes found: 74


In [7]:
# Flatten the label -> paths mapping into two parallel arrays:
# file_paths[i] is a file and labels[i] is the label it was gathered under.
file_paths = np.array(
    [path for class_paths in data.values() for path in class_paths]
)
labels = np.array(
    [name for name, class_paths in data.items() for _ in class_paths]
)

In [8]:
# Stratified K-Fold cross-validation.
# Only the held-out (test) indices of each split are used: every fold folder
# receives the files of its own test partition, grouped by class label.
skf = StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=42)

# Start from a clean output tree (a missing tree is not an error).
shutil.rmtree(output_dir, ignore_errors=True)

for fold_number, (_, held_out) in enumerate(skf.split(file_paths, labels), start=1):
    fold_name = f"cv{fold_number}"
    fold_root = os.path.join(output_dir, fold_name)
    os.makedirs(fold_root, exist_ok=True)

    # Per-class counter column for this fold, filled in as files are copied.
    df[fold_name] = 0

    for source_path, class_label in zip(file_paths[held_out], labels[held_out]):
        destination = os.path.join(fold_root, class_label)
        os.makedirs(destination, exist_ok=True)
        shutil.copy(source_path, destination)
        df.at[f"{class_label}", fold_name] += 1

    print(f"{fold_name} completed.")

cv1 completed.


cv2 completed.


cv3 completed.


cv4 completed.


cv5 completed.


In [9]:
# Per-class file totals next to the number of files copied into each fold.
df

Unnamed: 0_level_0,total,cv1,cv2,cv3,cv4,cv5
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
A10,690,138,138,138,138,138
A400M,463,93,93,93,92,92
AG600,259,52,52,51,52,52
AH64,376,75,75,76,75,75
An124,145,29,29,29,29,29
...,...,...,...,...,...,...
WZ7,95,19,19,19,19,19
XB70,172,34,34,34,35,35
Y20,197,40,40,39,39,39
YF23,136,27,27,28,27,27


In [10]:
# Column-wise totals: overall file count and the size of each fold.
df_sum = df.sum(axis=0)
df_sum

total    31917
cv1       6384
cv2       6384
cv3       6383
cv4       6383
cv5       6383
dtype: int64