In [1]:
import pandas as pd
import numpy as np
import os
import re

In [2]:
# df = pd.read_csv(r"E:\数据集\ShipEar\data_preprocessing\annotation\shipear_group_class.csv") 

# def get_article(word):
#     """根据单词首字母选择冠词"""
#     if word and word[0].lower() in 'aeiou':
#         return 'an'
#     return 'a'

# def is_valid_value(value):
#     """检查值是否有效（非空且非NaN且不是'Not available'）"""
#     if pd.isna(value) or value == '' or str(value).strip() == '':
#         return False
#     # 检查是否为"Not available"或类似的无效值
#     str_value = str(value).strip().lower()
#     if str_value in ['not available', 'n/a', 'na', 'null', 'none']:
#         return False
#     return True

# def format_type_with_class(type_val, class_id=None):
#     """格式化船舶类型和类别ID"""
#     if not is_valid_value(type_val):
#         return None
    
#     # 检查class_id是否为有效的数值（包括0）
#     if class_id is not None and not pd.isna(class_id) and str(class_id).strip() != '':
#         try:
#             # 确保class_id是数值
#             class_id_num = int(class_id)
#             return f"{type_val}(size class {class_id_num})"
#         except (ValueError, TypeError):
#             pass
    
#     return str(type_val)

# def format_distance(distance):
#     """格式化距离信息"""
#     if not is_valid_value(distance):
#         return None
    
#     try:
#         dist_val = float(distance)
#         if dist_val >= 100:
#             return f"source distance: >={int(dist_val)} m"
#         elif dist_val <= 50:
#             return f"source distance: <={int(dist_val)} m"
#         else:
#             return f"source distance: {int(dist_val)} m"
#     except:
#         return f"source distance: {distance}"

# def format_depth(depth):
#     """格式化深度信息"""
#     if not is_valid_value(depth):
#         return None
    
#     try:
#         depth_val = float(depth)
#         if depth_val >= 20:
#             return f"channel depth: deep ({depth_val} m)"
#         elif depth_val <= 5:
#             return f"channel depth: shallow ({depth_val} m)"
#         else:
#             return f"channel depth: {depth_val} m"
#     except:
#         return f"channel depth: {depth}"

# def format_wind(wind):
#     """格式化风速信息"""
#     if not is_valid_value(wind):
#         return None
    
#     try:
#         wind_val = float(wind)
#         return f"wind: {wind_val} m/s"
#     except:
#         # 如果无法转换为数字，检查是否为有效的文本值
#         wind_str = str(wind).strip()
#         if wind_str.lower() not in ['not available', 'n/a', 'na', 'null', 'none']:
#             return f"wind: {wind_str}"
#         return None

# def generate_prompt_en(row):
#     """生成英文提示文本"""
#     # 获取船舶类型和class_id
#     type_val = row.get('Type', '')
#     class_id = row.get('class_id', None)  # 使用class_id列
    
#     # 格式化类型信息
#     formatted_type = format_type_with_class(type_val, class_id)
#     if not formatted_type:
#         return ''
    
#     # 选择冠词（基于原始类型，不包含class_id）
#     article = get_article(type_val)
#     base_text = f"The sound belongs to {article} {formatted_type}"
    
#     # 收集有效的参数信息
#     params = []
    
#     # 距离信息
#     distance_text = format_distance(row.get('Distance'))
#     if distance_text:
#         params.append(distance_text)
    
#     # 深度信息
#     depth_text = format_depth(row.get('Channel Depth'))
#     if depth_text:
#         params.append(depth_text)
    
#     # 风速信息
#     wind_text = format_wind(row.get('Wind'))
#     if wind_text:
#         params.append(wind_text)
    
#     # 组合文本
#     if params:
#         # 用分号和空格连接参数
#         param_text = '; '.join(params)
#         return f"{base_text}. {param_text}."
#     else:
#         return f"{base_text}."

# # 应用函数生成prompt_en列
# df['prompt_en'] = df.apply(generate_prompt_en, axis=1)
# out_file = r"E:\数据集\Research_Project\Zoer-Shot_Project\Prompts\zero-shot_prompt.csv"
# df.to_csv(out_file, index=False)

In [5]:
# Load the ShipEar annotation table; each row is passed to the prompt builder below.
df = pd.read_csv(r"E:\数据集\ShipEar\data_preprocessing\annotation\shipear_group_class.csv")

def is_valid_value(value):
    """Tell whether an annotation cell carries usable content.

    Args:
        value: Raw cell value (string, number, None or NaN).

    Returns:
        False for missing values (None/NaN), for empty or whitespace-only
        text, and for placeholder spellings such as "Not available" or "N/A"
        (compared case-insensitively after trimming); True otherwise.
    """
    if pd.isna(value):
        return False

    # Normalise once so the blank check and the placeholder check share it.
    text = str(value).strip().lower()
    placeholders = ('not available', 'n/a', 'na', 'null', 'none')
    return text != '' and text not in placeholders

def extract_numeric_value(value):
    """Extract the first unsigned number embedded in a piece of text.

    Comparison signs and trailing units are ignored, so ">100 m" yields 100.0
    and "<=50" yields 50.0. A number written without a leading zero (".5") is
    read as 0.5 rather than as the digits after the point.

    Args:
        value: Raw cell value; it is converted with str() before searching.

    Returns:
        The first number found as a float, or None when the value is missing
        (None/NaN) or contains no digits.
    """
    if pd.isna(value):
        return None

    text = str(value).strip()

    # Either digits with an optional fractional part, or a bare fraction such
    # as ".5". Every string this pattern can match is accepted by float().
    match = re.search(r'\d+(?:\.\d*)?|\.\d+', text)
    if match is None:
        return None
    return float(match.group(0))

def map_distance_category(distance):
    """Map a distance annotation to a coarse range label.

    Args:
        distance: Raw cell value; a number, a numeric string, or text with an
            embedded number such as ">100".

    Returns:
        "close" for values <= 50, "mid" for values <= 99, "far" for anything
        larger, or None when the value is missing, a placeholder, or does not
        contain a usable number.
    """
    if not is_valid_value(distance):
        return None

    try:
        dist_val = float(distance)
    except (ValueError, TypeError):
        # Not directly numeric: fall back to the first number in the text.
        dist_val = extract_numeric_value(distance)

    # float() accepts spellings such as " nan"; NaN fails every comparison
    # below and would otherwise be reported as "far".
    if dist_val is None or pd.isna(dist_val):
        return None

    if dist_val <= 50:
        return "close"
    if dist_val <= 99:
        return "mid"
    return "far"

def map_depth_category(depth):
    """Map a channel-depth annotation to a coarse depth label.

    Args:
        depth: Raw cell value; a number, a numeric string, or text with an
            embedded number.

    Returns:
        "shallow" for values <= 5, "medium" for values <= 19, "deep" for
        anything larger, or None when the value is missing, a placeholder, or
        does not contain a usable number.
    """
    if not is_valid_value(depth):
        return None

    try:
        depth_val = float(depth)
    except (ValueError, TypeError):
        # Not directly numeric: fall back to the first number in the text.
        depth_val = extract_numeric_value(depth)

    # float() accepts spellings such as " nan"; NaN fails every comparison
    # below and would otherwise be reported as "deep".
    if depth_val is None or pd.isna(depth_val):
        return None

    if depth_val <= 5:
        return "shallow"
    if depth_val <= 19:
        return "medium"
    return "deep"

def map_wind_category(wind):
    """Map a wind annotation to a coarse wind-strength label.

    Args:
        wind: Raw cell value; a number, a numeric string, or text with an
            embedded number.

    Returns:
        "calm" for values <= 2, "moderate" for values <= 6, "strong" for
        anything larger, or None when the value is missing, a placeholder, or
        does not contain a usable number.
    """
    if not is_valid_value(wind):
        return None

    try:
        wind_val = float(wind)
    except (ValueError, TypeError):
        # Not directly numeric: fall back to the first number in the text.
        wind_val = extract_numeric_value(wind)

    # float() accepts spellings such as " nan"; NaN fails every comparison
    # below and would otherwise be reported as "strong".
    if wind_val is None or pd.isna(wind_val):
        return None

    if wind_val <= 2:
        return "calm"
    if wind_val <= 6:
        return "moderate"
    return "strong"

def generate_hydrophone_prompt_en(row):
    """Compose the class-agnostic hydrophone prompt for one annotation row.

    Args:
        row: Mapping-like row exposing .get(); the 'Distance', 'Channel Depth'
            and 'Wind' entries are read.

    Returns:
        "Hydrophone recording of a marine vessel" followed by one clause per
        condition that maps to a label (range, water depth, wind, in that
        order) and a final period.
    """
    distance = map_distance_category(row.get('Distance'))
    depth = map_depth_category(row.get('Channel Depth'))
    wind = map_wind_category(row.get('Wind'))

    # Pair every label with its rendered clause; clauses whose label is
    # missing are dropped by the filter below.
    candidates = (
        (distance, f"at {distance} range"),
        (depth, f"in {depth} water"),
        (wind, f"under {wind} wind"),
    )
    params = [clause for label, clause in candidates if label]

    base_text = "Hydrophone recording of a marine vessel"
    if not params:
        return f"{base_text}."
    return f"{base_text} {' '.join(params)}."

# Build one hydrophone prompt per annotation row and store it in 'prompt_en'.
df['prompt_en'] = df.apply(generate_hydrophone_prompt_en, axis=1)
# Save the table, now including 'prompt_en', in the same folder as the input
# annotations; index=False keeps the row index out of the CSV.
out_file = r"E:\数据集\ShipEar\data_preprocessing\annotation\shipear_group_class_prompt_en.csv"
df.to_csv(out_file, index=False)

In [6]:
# Reload the annotation table from disk; this rebinds `df`, so the 'prompt_en'
# column added to the previous frame is not carried over.
df = pd.read_csv(r"E:\数据集\ShipEar\data_preprocessing\annotation\shipear_group_class.csv") 

def extract_numeric_value(value):
    """Extract the first unsigned number embedded in a piece of text.

    Comparison signs and trailing units are ignored, so ">100 m" yields 100.0
    and "<=50" yields 50.0. A number written without a leading zero (".5") is
    read as 0.5 rather than as the digits after the point.

    NOTE(review): this function is also defined in an earlier cell; consider
    keeping a single copy.

    Args:
        value: Raw cell value; it is converted with str() before searching.

    Returns:
        The first number found as a float, or None when the value is missing
        (None/NaN) or contains no digits.
    """
    if pd.isna(value):
        return None

    text = str(value).strip()

    # Either digits with an optional fractional part, or a bare fraction such
    # as ".5". Every string this pattern can match is accepted by float().
    match = re.search(r'\d+(?:\.\d*)?|\.\d+', text)
    if match is None:
        return None
    return float(match.group(0))

def map_distance_category(distance):
    """Map a distance annotation to a coarse range label.

    NOTE(review): this function is also defined in an earlier cell; consider
    keeping a single copy.

    Args:
        distance: Raw cell value; a number, a numeric string, or text with an
            embedded number such as ">100".

    Returns:
        "close" for values <= 50, "mid" for values <= 99, "far" for anything
        larger, or None when the value is missing, a placeholder, or does not
        contain a usable number.
    """
    if not is_valid_value(distance):
        return None

    try:
        dist_val = float(distance)
    except (ValueError, TypeError):
        # Not directly numeric: fall back to the first number in the text.
        dist_val = extract_numeric_value(distance)

    # float() accepts spellings such as " nan"; NaN fails every comparison
    # below and would otherwise be reported as "far".
    if dist_val is None or pd.isna(dist_val):
        return None

    if dist_val <= 50:
        return "close"
    if dist_val <= 99:
        return "mid"
    return "far"

def map_depth_category(depth):
    """Map a channel-depth annotation to a coarse depth label.

    NOTE(review): this function is also defined in an earlier cell; consider
    keeping a single copy.

    Args:
        depth: Raw cell value; a number, a numeric string, or text with an
            embedded number.

    Returns:
        "shallow" for values <= 5, "medium" for values <= 19, "deep" for
        anything larger, or None when the value is missing, a placeholder, or
        does not contain a usable number.
    """
    if not is_valid_value(depth):
        return None

    try:
        depth_val = float(depth)
    except (ValueError, TypeError):
        # Not directly numeric: fall back to the first number in the text.
        depth_val = extract_numeric_value(depth)

    # float() accepts spellings such as " nan"; NaN fails every comparison
    # below and would otherwise be reported as "deep".
    if depth_val is None or pd.isna(depth_val):
        return None

    if depth_val <= 5:
        return "shallow"
    if depth_val <= 19:
        return "medium"
    return "deep"

def map_wind_category(wind):
    """Map a wind annotation to a coarse wind-strength label.

    NOTE(review): this function is also defined in an earlier cell; consider
    keeping a single copy.

    Args:
        wind: Raw cell value; a number, a numeric string, or text with an
            embedded number.

    Returns:
        "calm" for values <= 2, "moderate" for values <= 6, "strong" for
        anything larger, or None when the value is missing, a placeholder, or
        does not contain a usable number.
    """
    if not is_valid_value(wind):
        return None

    try:
        wind_val = float(wind)
    except (ValueError, TypeError):
        # Not directly numeric: fall back to the first number in the text.
        wind_val = extract_numeric_value(wind)

    # float() accepts spellings such as " nan"; NaN fails every comparison
    # below and would otherwise be reported as "strong".
    if wind_val is None or pd.isna(wind_val):
        return None

    if wind_val <= 2:
        return "calm"
    if wind_val <= 6:
        return "moderate"
    return "strong"

def generate_prompt_en(row):
    """Compose the class-aware English prompt for one annotation row.

    Args:
        row: Mapping-like row exposing .get(); the 'Type', 'Distance',
            'Channel Depth' and 'Wind' entries are read.

    Returns:
        '' when 'Type' is missing or a placeholder. Otherwise
        "Underwater recording of a/an <Type>" followed by one clause per
        condition that maps to a label (range, water depth, wind, in that
        order) and a final period.
    """
    vessel_class = row.get('Type', '')
    if not is_valid_value(vessel_class):
        return ''

    # Pick the indefinite article from the first letter so vowel-initial types
    # read "an Ocean liner" instead of "a Ocean liner". This is a spelling
    # heuristic, not a pronunciation check. The validity check above
    # guarantees the trimmed text is non-empty.
    first_letter = str(vessel_class).strip()[0].lower()
    article = 'an' if first_letter in 'aeiou' else 'a'

    distance = map_distance_category(row.get('Distance'))
    depth = map_depth_category(row.get('Channel Depth'))
    wind = map_wind_category(row.get('Wind'))

    base_text = f"Underwater recording of {article} {vessel_class}"

    # Only conditions that produced a label contribute a clause.
    params = []
    if distance:
        params.append(f"at {distance} range")
    if depth:
        params.append(f"in {depth} water")
    if wind:
        params.append(f"with {wind} wind")

    if params:
        return f"{base_text} {' '.join(params)}."
    return f"{base_text}."

# Build one class-aware prompt per annotation row and store it in 'prompt_en'.
df['prompt_en'] = df.apply(generate_prompt_en, axis=1)
out_file = r"E:\数据集\Research_Project\Zero-Shot_Project\Prompts\zero-shot_prompt.csv"
# The target folder is outside the input tree and to_csv does not create
# missing folders, so make sure it exists before writing.
os.makedirs(os.path.dirname(out_file), exist_ok=True)
# index=False keeps the row index out of the CSV.
df.to_csv(out_file, index=False)

In [7]:
def split_segment_name(segment_name):
    """Split '<original stem>_<digits>' into its source file name and segment id.

    Args:
        segment_name: Segment file name without its extension.

    Returns:
        A tuple (original_filename, segment_id), where original_filename is the
        stem with '.wav' appended and segment_id is the digit suffix as a
        string; None when the name has no '_<digits>' suffix.
    """
    stem, separator, suffix = segment_name.rpartition('_')
    if not separator or not suffix.isdigit():
        return None
    return stem + '.wav', suffix


# Expand the per-recording annotation table (`df`, which must already carry
# 'prompt_en') into one row per segmented .wav file found on disk.
file_directory = r"E:\数据集\ShipEar\data_preprocessing\8_Frame_Windows_10s_50%"
if os.path.exists(file_directory):
    # os.listdir returns entries in arbitrary order; sort so the output CSV is
    # reproducible between runs.
    segmented_files = sorted(
        f for f in os.listdir(file_directory)
        if f.endswith('.wav') and os.path.isfile(os.path.join(file_directory, f))
    )

    # Look-up from original recording file name to its annotation row.
    original_file_mapping = {
        row['Filename']: row.to_dict() for _, row in df.iterrows()
    }

    segmented_records = []
    for segmented_file in segmented_files:
        segment_name = os.path.splitext(segmented_file)[0]

        # Skip files that do not follow the '<stem>_<digits>' pattern instead of
        # reusing the file name / segment id left over from a previous iteration.
        parsed = split_segment_name(segment_name)
        if parsed is None:
            continue
        original_filename, segment_id = parsed

        # Segments whose source recording is not annotated are ignored.
        if original_filename not in original_file_mapping:
            continue

        # Copy the annotation row and attach the segment information.
        original_record = original_file_mapping[original_filename].copy()
        original_record['segment_name'] = segment_name
        original_record['segment_id'] = segment_id
        segmented_records.append(original_record)

    if segmented_records:
        segmented_df = pd.DataFrame(segmented_records)
        # Keep only the columns needed downstream, in a fixed order.
        column_order = ['ID', 'Filename', 'Name', 'Type', 'group_id', 'class_id_5', 'class_id_12', 'segment_name', 'segment_id',
                       'Distance', 'Wind', 'prompt_en']
        segmented_df = segmented_df[column_order]
        segmented_output_path = r"E:\数据集\Research_Project\Zero-Shot_Project\Prompts\zero-shot.csv"
        segmented_df.to_csv(segmented_output_path, index=False)
else:
    # Make the no-op visible instead of finishing silently with no output file.
    print(f"Segment directory not found, nothing written: {file_directory}")