# Dataset Splitting for Model Training

This notebook performs the final step of preparing the data for a diffusion model. It splits the dataset into training, validation, and test sets.

**Workflow:**
1.  **Define Paths**: Sets up all input and output directories.
2.  **Create Directories**: Creates the `final` directory structure (`train/sketches`, `train/maps`, etc.).
3.  **Load Data**: Loads the `final_prompts.csv` which contains the metadata for the images.
4.  **Shuffle and Split**:
    - Randomly shuffles the dataset to ensure unbiased splits.
    - Splits the data into 85% training, 5% validation, and 10% testing sets.
    - Adds a `split` column to the dataframe to track which set each sample belongs to.
5.  **Copy Files**: Copies the sketch and map images into their respective `train`, `val`, or `test` folders.
6.  **Save Final Prompts**: Saves a new `prompts.csv` in the `final` directory containing the metadata for all splits.
7.  **Verification**: Checks the file counts in each directory to confirm the split was successful.


In [6]:
# Imports and Path Definitions
import pandas as pd
from pathlib import Path
import numpy as np
from sklearn.model_selection import train_test_split
import shutil
from tqdm.notebook import tqdm

# --- Configuration ---
BASE_PATH = Path(__file__).resolve().parent if '__file__' in globals() else Path.cwd()
PROJECT_ROOT = BASE_PATH.parent
OUTPUT_DIR = PROJECT_ROOT / 'output'

# Input paths
PROMPTS_INPUT_PATH = OUTPUT_DIR / 'amazing_logos_v4' / 'data' / 'meta_postprep' / 'final_prompts.csv'
SKETCHES_DIR = OUTPUT_DIR / 'amazing_logos_v4' / 'images' / 'balanced_sample_2k_512x512_sketches'
MAPS_DIR = OUTPUT_DIR / 'amazing_logos_v4' / 'images' / 'balanced_sample_2k_512x512_maps' # lineart map based on postrocessed sketches
LOGO_DIR = OUTPUT_DIR / 'amazing_logos_v4' / 'images' / 'balanced_sample_2k_512x512'

# Output paths
FINAL_DIR = OUTPUT_DIR / 'final'

# Define split directories
SPLITS = ['train', 'val', 'test']
SUBFOLDERS = ['sketches', 'maps', 'logo']

# --- Create Directories ---
for split in SPLITS:
    for subfolder in SUBFOLDERS:
        path = FINAL_DIR / split / subfolder
        path.mkdir(parents=True, exist_ok=True)
        print(f"Created directory: {path}")

print(f"\nOutput will be saved in: {FINAL_DIR}")

Created directory: c:\studium\master_thesis\data_prep\output\final\train\sketches
Created directory: c:\studium\master_thesis\data_prep\output\final\train\maps
Created directory: c:\studium\master_thesis\data_prep\output\final\train\logo
Created directory: c:\studium\master_thesis\data_prep\output\final\val\sketches
Created directory: c:\studium\master_thesis\data_prep\output\final\val\maps
Created directory: c:\studium\master_thesis\data_prep\output\final\val\logo
Created directory: c:\studium\master_thesis\data_prep\output\final\test\sketches
Created directory: c:\studium\master_thesis\data_prep\output\final\test\maps
Created directory: c:\studium\master_thesis\data_prep\output\final\test\logo

Output will be saved in: c:\studium\master_thesis\data_prep\output\final


In [7]:
# Load the dataset
if not PROMPTS_INPUT_PATH.exists():
    raise FileNotFoundError(f"Input prompts file not found: {PROMPTS_INPUT_PATH}")

df = pd.read_csv(PROMPTS_INPUT_PATH)
print(f"Loaded {len(df)} records from {PROMPTS_INPUT_PATH.name}")

# --- Shuffle and Split the Data ---
# Shuffle the DataFrame
df_shuffled = df.sample(frac=1, random_state=42).reset_index(drop=True)

# 85/5/10 split
train_val_df, val_df = train_test_split(df_shuffled, test_size=0.1, random_state=42)
train_df, test_df = train_test_split(train_val_df, test_size=0.1 / 0.9, random_state=42) # 0.1/0.9 ensures 10% of total

# Add a 'split' column
train_df['split'] = 'train'
val_df['split'] = 'val'
test_df['split'] = 'test'

# Combine back into a single DataFrame with the new 'split' column
final_df = pd.concat([train_df, val_df, test_df]).sort_index()

print("\nDataset split:")
print(f"Training set:   {len(train_df)} samples")
print(f"Validation set: {len(val_df)} samples")
print(f"Test set:       {len(test_df)} samples")
print(f"Total:          {len(final_df)} samples")

final_df.head()


Loaded 1810 records from final_prompts.csv

Dataset split:
Training set:   1448 samples
Validation set: 181 samples
Test set:       181 samples
Total:          1810 samples


Unnamed: 0,id,prompt,category_main,category,description,tags,company,split
0,amazing_logo_v4291200,"minimalistic logo, solid background; descripti...",health,healthcare_general,Care Circle Holistic Leaf Healthcare Hands,"successful_vibe,minimalist,thoughtprovoking,ab...",Simple elegant logo for Elmhurst,train
1,amazing_logo_v4014162,"minimalistic logo, solid background; descripti...",health,wellness_fitness,fitness leap jaguar,"successful_vibe,minimalist,thoughtprovoking,ab...",Simple elegant logo for jaguar,test
2,amazing_logo_v4274807,"minimalistic logo, solid background; descripti...",professional_financial_legal,marketing_advertising,advertising graphic design branding creative d...,"successful_vibe,minimalist,thoughtprovoking,ab...",Simple elegant logo for David Day Associates,train
3,amazing_logo_v4204167,"minimalistic logo, solid background; descripti...",tech,telecommunications,Blue Currier Graph Sans Modern Green Chart Lin...,"successful_vibe,minimalist,thoughtprovoking,ab...",Simple elegant logo for Smith Micro Analytics,train
4,amazing_logo_v4230976,"minimalistic logo, solid background; descripti...",other,lifestyle_personal,gray handwritten wordmark,"successful_vibe,minimalist,thoughtprovoking,ab...",Simple elegant logo for Omma,test


In [8]:
# --- Copy Files to Split Directories ---
tqdm.pandas(desc="Copying files")

def copy_files_for_row(row):
    image_id = row['id']
    split = row['split']
    
    # Source files
    sketch_source = SKETCHES_DIR / f"{image_id}.png"
    map_source = MAPS_DIR / f"{image_id}.png"
    logo_source = LOGO_DIR / f"{image_id}.png"
    
    # Destination paths
    sketch_dest = FINAL_DIR / split / 'sketches' / f"{image_id}.png"
    map_dest = FINAL_DIR / split / 'maps' / f"{image_id}.png"
    logo_dest = FINAL_DIR / split / 'logo' / f"{image_id}.png"
    
    # Copy sketch
    if sketch_source.exists():
        shutil.copy(sketch_source, sketch_dest)
    else:
        print(f"Warning: Sketch not found for {image_id}")
        
    # Copy map
    if map_source.exists():
        shutil.copy(map_source, map_dest)
    else:
        print(f"Warning: Map not found for {image_id}")

    # Copy logo
    if logo_source.exists():
        shutil.copy(logo_source, logo_dest)
    else:
        print(f"Warning: Logo not found for {image_id}")

# Apply the copy function to each row
final_df.progress_apply(copy_files_for_row, axis=1)

print("\nFile copying complete.")

Copying files:   0%|          | 0/1810 [00:00<?, ?it/s]


File copying complete.


In [9]:
# --- Save the Split Prompts CSVs ---
for split_name, df_split in [('train', train_df), ('val', val_df), ('test', test_df)]:
    output_path = FINAL_DIR / split_name / 'prompts.csv'
    df_split.to_csv(output_path, index=False)
    print(f"Saved {split_name} prompts CSV to: {output_path}")

Saved train prompts CSV to: c:\studium\master_thesis\data_prep\output\final\train\prompts.csv
Saved val prompts CSV to: c:\studium\master_thesis\data_prep\output\final\val\prompts.csv
Saved test prompts CSV to: c:\studium\master_thesis\data_prep\output\final\test\prompts.csv


In [10]:
# --- Verification Step ---
print("--- Verifying file counts ---")
total_files = 0
for split in SPLITS:
    print(f"\n-- {split.upper()} --")
    # Verify prompts.csv
    prompts_path = FINAL_DIR / split / 'prompts.csv'
    if prompts_path.exists():
        df_check = pd.read_csv(prompts_path)
        print(f"Found {len(df_check)} records in {split}/prompts.csv")
    else:
        print(f"❌ Missing prompts.csv in {split} directory")

    # Verify image files
    for subfolder in SUBFOLDERS:
        count = len(list((FINAL_DIR / split / subfolder).glob('*.png')))
        print(f"Found {count} files in {split}/{subfolder}")
        total_files += count

print(f"\nTotal image files copied: {total_files}")
print(f"Expected total: {len(final_df) * 3}")

if total_files == len(final_df) * 3:
    print("\n✅ Verification successful: File counts match the dataframe size.")
else:
    print("\n❌ Verification failed: Mismatch in file counts.")

--- Verifying file counts ---

-- TRAIN --
Found 1448 records in train/prompts.csv
Found 1448 files in train/sketches
Found 1448 files in train/maps
Found 1448 files in train/logo

-- VAL --
Found 181 records in val/prompts.csv
Found 181 files in val/sketches
Found 181 files in val/maps
Found 181 files in val/logo

-- TEST --
Found 181 records in test/prompts.csv
Found 181 files in test/sketches
Found 181 files in test/maps
Found 181 files in test/logo

Total image files copied: 5430
Expected total: 5430

✅ Verification successful: File counts match the dataframe size.
