In [6]:
# %%

import pandas as pd
from rich.pretty import pprint
from sklearn.model_selection import train_test_split
import os

In [7]:
# %%

# Load the CSV file
csv_path = "/disk/scratch4/felix/data_sets/my_datasets/deepship/Data/wav_list_with_meta.csv"
df = pd.read_csv("/disk/scratch4/felix/data_sets/my_datasets/deepship/Data/wav_list_with_meta.csv")

In [8]:
ship_labels = df.groupby('Shipname')['Class'].nunique().reset_index()
duplicates = ship_labels[ship_labels['Class'] > 1]
if not duplicates.empty:
    print("Ships appearing in multiple labels:")
    print(duplicates)
else:
    print("No ships appearing in multiple Class")

No ships appearing in multiple Class


In [9]:
# %%

ship_durations = df.groupby(["Class", "Shipname"])["Duration"].sum().reset_index()

ship_durations = ship_durations.sort_values(by="Duration", ascending=True)


# Split the Ships into train, validation, and test sets for each class
def split_ships_by_class_and_duration(
    ship_durations, test_size=0.2, val_size=0.2, random_state=42
):
    train_ships = {}
    val_ships = {}
    test_ships = {}

    for class_ in ship_durations["Class"].unique():
        class_data = ship_durations[ship_durations["Class"] == class_]

        # Calculate target durations for each split
        total_duration = class_data["Duration"].sum()
        test_duration = total_duration * test_size
        val_duration = total_duration * val_size

        # Initialize splits
        train_ships[class_] = []
        val_ships[class_] = []
        test_ships[class_] = []

        current_train_duration = 0
        current_val_duration = 0
        current_test_duration = 0

        for _, row in class_data.iterrows():
            ship = row["Shipname"]
            duration = row["Duration"]
            if current_val_duration < val_duration:
                val_ships[class_].append(ship)
                current_val_duration += duration
            elif current_test_duration < test_duration:
                test_ships[class_].append(ship)
                current_test_duration += duration
            else:
                train_ships[class_].append(ship)
                current_train_duration += duration

    return train_ships, val_ships, test_ships


# Perform the split
train_ships, val_ships, test_ships = split_ships_by_class_and_duration(
    ship_durations, test_size=0.1, val_size=0.1
)

# Assign rows to splits based on Ship and Class
train_df = df[df.apply(lambda row: row["Shipname"] in train_ships[row["Class"]], axis=1)]
val_df = df[df.apply(lambda row: row["Shipname"] in val_ships[row["Class"]], axis=1)]
test_df = df[df.apply(lambda row: row["Shipname"] in test_ships[row["Class"]], axis=1)]

def make_relative_path(path, class_name):
    # Split path into parts
    parts = path.split('/')
    
    # Remove the class from the path and return the rest
    if parts[0].lower() == class_name:
        return os.path.join(*parts[1:])  # Join the remaining parts relative to the class
    return path  # In case something goes wrong (unexpected path format)

# Print duration details for each split
def print_duration_details(df, split_name):
    total_duration = df["Duration"].sum()
    print(f"\n{split_name} Duration: {total_duration} seconds")

    # Ensure all classes are included, even if duration is 0
    classes = sorted(df["Class"].unique())
    for class_ in classes:
        # Get the duration for this class in the split
        class_duration = df[df["Class"] == class_]["Duration"].sum()

        # Get the file path parts for this class
        class_paths = df[df["Class"] == class_]["Filename"].tolist()

        # Convert paths to be relative to the class folder
        relative_paths = [make_relative_path(path, class_) for path in class_paths]

        # Print class details
        print(f"  {class_}: {class_duration} seconds, Assigned file paths: {relative_paths}")

train_df.to_csv("grouped_train.csv", index=False)
val_df.to_csv("grouped_validation.csv", index=False)
test_df.to_csv("grouped_test.csv", index=False)

print_duration_details(train_df, "Train")
print_duration_details(val_df, "Validation")
print_duration_details(test_df, "Test")


Train Duration: 33610.0 seconds
  Cargo: 5827.0 seconds, Assigned file paths: ['Cargo/20171104-1/1.wav', 'Cargo/20171124c-42/134001.wav', 'Cargo/20171205g-73/124454.wav', 'Cargo/20171125e-48/234924.wav', 'Cargo/20171121c-32/111158.wav', 'Cargo/20171130-57/165031.wav', 'Cargo/20171125c-46/134833.wav', 'Cargo/20171202e-64/141802.wav', 'Cargo/20171126b-51/044753.wav', 'Cargo/20171129c-56/075459.wav', 'Cargo/103.wav', 'Cargo/41.wav', 'Cargo/69.wav', 'Cargo/110.wav']
  Passengership: 13305.0 seconds, Assigned file paths: ['Passengership/20180926-190/011928.wav', 'Passengership/20160603-15/042208.wav', 'Passengership/20160625a-22/041906.wav', 'Passengership/20160820a-33/33.wav', 'Passengership/20180511-165/101015.wav', 'Passengership/20180329-150/181312.wav', 'Passengership/20170906-115/104632.wav', 'Passengership/20180828a-181/125034.wav', 'Passengership/20160507a-4/4.wav', 'Passengership/20160629b-24/135359.wav', 'Passengership/20180215-144/185044.wav', 'Passengership/20161214-50/50.wav',

In [10]:
# %%

assert len(set(train_df["Shipname"]) & set(val_df["Shipname"])) == 0, "Ships overlap between Train and Validation!"
assert len(set(train_df["Shipname"]) & set(test_df["Shipname"])) == 0, "Ships overlap between Train and Test!"
assert len(set(val_df["Shipname"]) & set(test_df["Shipname"])) == 0, "Ships overlap between Validation and Test!"