In [43]:
import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import wavfile
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import random
from scipy.signal import detrend
import json

In [44]:
def shift_with_edge_pad(x, shift):
    """Shift a 1-D array by `shift` samples, filling the vacated part with the edge value.

    A positive `shift` moves samples later (to the right) and fills the head with
    ``x[0]``. A negative `shift` moves them earlier and fills the tail with ``x[-1]``.
    The result always has ``len(x)`` elements. A shift with ``abs(shift) >= len(x)``
    yields an array made entirely of the corresponding edge value.

    Args:
        x: 1-D numpy array.
        shift: number of samples to shift (int, may be negative).

    Returns:
        A new array of length ``len(x)``, or ``x`` itself when ``shift == 0`` or
        ``x`` is empty.
    """
    length = len(x)
    if length == 0 or shift == 0:
        return x
    # Clamp so the output never grows beyond len(x) for large shifts.
    shift = max(-length, min(length, shift))
    if shift > 0:
        return np.concatenate([np.full(shift, x[0]), x[:-shift]])
    return np.concatenate([x[-shift:], np.full(-shift, x[-1])])

In [45]:
# --- Build the training set of normal (non-anomaly) segments ---
# For every .wav file in `folder` that has a matching '<name>-events.txt' marker
# file, cut the non-anomaly stretches into full, non-overlapping windows of
# `segment_len_sec` seconds (label 0). Then add one randomly time-shifted copy of
# the same file, cut the same way (label 1, augmentation).
folder = './Dataset/Open/Edited_plant_split/Plectranthus/train'
shift_amount_sec = 1.0
segment_len_sec = 3.0
AUG_SEED = 42  # seeds the random shift so Restart & Run All reproduces the dataset


def _parse_anomaly_regions(txt_path, sr, n_samples):
    """Return sorted (start, end) sample-index pairs of annotated anomaly regions.

    Lines starting with a digit are read as '<marker_id>,<time_sec>'. Marker 1
    opens a region and the next marker 2 closes it. A start marker still open at
    the end of the file is treated as an anomaly running to `n_samples`, so the
    unclosed anomaly audio is not mistaken for normal data.
    """
    regions = []
    start = None
    with open(txt_path, 'r') as f:
        for line in f:
            if not (line.strip() and line[0].isdigit()):
                continue
            m_id, m_time = line.split(',')
            m_id = int(m_id.strip())
            m_sample = int(float(m_time.strip()) * sr)
            if m_id == 1:
                start = m_sample
            elif m_id == 2 and start is not None:
                regions.append((start, m_sample))
                start = None
    if start is not None:
        regions.append((start, n_samples))
    return sorted(regions)


def _normal_regions(anomaly_regions, n_samples):
    """Return the (start, end) gaps in [0, n_samples) not covered by any anomaly region.

    `anomaly_regions` must be sorted by start. Overlapping regions are merged by
    tracking the furthest end seen so far.
    """
    regions = []
    last_end = 0
    for start, end in anomaly_regions:
        if last_end < start:
            regions.append((last_end, start))
        last_end = max(last_end, end)
    if last_end < n_samples:
        regions.append((last_end, n_samples))
    return regions


def _full_segments(signal, regions, seg_len):
    """Yield every full, non-overlapping `seg_len`-sample window lying inside `regions`."""
    for n_start, n_end in regions:
        for i in range(n_start, n_end - seg_len + 1, seg_len):
            yield signal[i:i + seg_len]


random.seed(AUG_SEED)
normal_segments = []
segment_labels = []

for file in sorted(os.listdir(folder)):  # sorted: os.listdir order is arbitrary
    if not file.endswith('.wav'):
        continue
    wav_path = os.path.join(folder, file)
    txt_path = os.path.join(folder, file.replace('.wav', '-events.txt'))
    if not os.path.exists(txt_path):
        continue

    sr, data = wavfile.read(wav_path)
    if data.ndim > 1:
        data = data[:, 0]  # keep the first channel only
    data = detrend(data)
    std = np.std(data)
    if std == 0:
        print(f"Skipping {file}: constant signal cannot be z-normalised")
        continue
    data_norm = (data - np.mean(data)) / std
    seg_len = int(segment_len_sec * sr)
    n_samples = len(data_norm)

    anomaly_regions = _parse_anomaly_regions(txt_path, sr, n_samples)

    # --- Original (non-augmented) normal segments -> label 0 ---
    for seg in _full_segments(data_norm, _normal_regions(anomaly_regions, n_samples), seg_len):
        normal_segments.append(seg)
        segment_labels.append(0)

    # --- Augmentation: shift the whole file by up to +/- shift_amount_sec ---
    # The vacated samples are edge-padded, and the anomaly regions move with the signal.
    max_shift = int(shift_amount_sec * sr)
    shift_amount = random.randint(-max_shift, max_shift)
    data_shifted = shift_with_edge_pad(data_norm, shift_amount)
    n_shifted = len(data_shifted)
    # Clamp to [0, n_shifted]; shifting preserves the sorted order of the regions.
    anomaly_regions_shifted = [
        (min(max(start + shift_amount, 0), n_shifted),
         min(max(end + shift_amount, 0), n_shifted))
        for start, end in anomaly_regions
    ]
    for seg in _full_segments(data_shifted, _normal_regions(anomaly_regions_shifted, n_shifted), seg_len):
        normal_segments.append(seg)
        segment_labels.append(1)

normal_segments = np.array(normal_segments, dtype=np.float32)
segment_labels = np.array(segment_labels, dtype=np.int32)
print("Total normal (non-anomaly) segments:", len(normal_segments))
print("Original (non-augmented):", np.sum(segment_labels == 0))
print("Augmented (shifted):", np.sum(segment_labels == 1))

Total normal (non-anomaly) segments: 344
Original (non-augmented): 172
Augmented (shifted): 172


In [46]:
class AudioSegDataset(Dataset):
    """Autoencoder dataset in which every item is an (input, target) pair of one segment.

    Args:
        segments: numpy array whose first axis indexes segments. It is wrapped with
            ``torch.from_numpy``, which shares memory and keeps the numpy dtype.
    """

    def __init__(self, segments):
        tensor = torch.from_numpy(segments)
        self.data = tensor

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, idx):
        # The reconstruction target is the input segment itself.
        inputs = self.data[idx]
        target = self.data[idx]
        return inputs, target

# Wrap the normal segments for reconstruction training.
dataset = AudioSegDataset(normal_segments)
dataloader = DataLoader(
    dataset,
    batch_size=16,  # segments per mini-batch
    shuffle=True,   # reshuffle segment order every epoch
)

In [47]:
class Autoencoder(nn.Module):
    """Fully connected autoencoder: n -> 256 -> 64 -> 16 -> 64 -> 256 -> n.

    Every encoder layer, including the 16-unit bottleneck, is followed by a ReLU.
    The decoder uses ReLU between layers but leaves the final reconstruction layer
    linear. The ``nn.Sequential`` indices (Linear at 0/2/4) match the saved
    checkpoints' state-dict keys.
    """

    def __init__(self, n):
        super().__init__()
        widths = (n, 256, 64, 16)

        enc_layers = []
        for d_in, d_out in zip(widths[:-1], widths[1:]):
            enc_layers += [nn.Linear(d_in, d_out), nn.ReLU()]
        self.encoder = nn.Sequential(*enc_layers)

        mirrored = widths[::-1]
        dec_layers = []
        for d_in, d_out in zip(mirrored[:-1], mirrored[1:]):
            dec_layers += [nn.Linear(d_in, d_out), nn.ReLU()]
        dec_layers.pop()  # no activation on the reconstruction output
        self.decoder = nn.Sequential(*dec_layers)

    def forward(self, x):
        return self.decoder(self.encoder(x))

In [48]:
# ---- Load the trained Solanum autoencoder and its anomaly threshold ----
load_path = './Dataset/Open/Model/autoencoder/MAE/Solanum.pth'
# The threshold JSON sits next to the checkpoint and shares its stem.
threshold_path = load_path.replace('.pth', '.json')
with open(threshold_path, 'r') as f:
    threshold = json.load(f)['threshold']
print(f"Loaded threshold: {threshold}")

# map_location='cpu': every input below is a CPU tensor, and this also lets a
# checkpoint saved on a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
state_dict = torch.load(load_path, map_location='cpu')
# Take the input width from the checkpoint itself. The first encoder layer is
# Linear(n, 256), so its weight has shape (256, n). This avoids depending on
# `normal_segments`, which an earlier cell builds from another plant's data.
n = state_dict['encoder.0.weight'].shape[1]

model = Autoencoder(n)
model.load_state_dict(state_dict)
model.eval()

splits = [
    ('train', './Dataset/Open/Edited_plant_split/Solanum/train'),
    ('test', './Dataset/Open/Edited_plant_split/Solanum/test')
]

split_ious = {'train': [], 'test': []}
split_success = {'train': [], 'test': []}

success_threshold = 0.5   # Set IoU threshold for successful detection

# ---- Extract mean latent embedding from normal (non-anomaly) segments ----
# NOTE(review): normal segments from BOTH the train and test splits feed the mean
# latent. If this mean is later used to score the test split, that is test-set
# leakage. Confirm this is intended.

latent_list = []

with torch.no_grad():
    for split_name, folder in splits:
        wav_files = sorted(file for file in os.listdir(folder) if file.endswith('.wav'))
        for wav_file in wav_files:
            txt_file = wav_file.replace('.wav', '-events.txt')
            txt_path = os.path.join(folder, txt_file)
            wav_path = os.path.join(folder, wav_file)
            if not os.path.exists(txt_path):
                continue

            # Load, keep the first channel, detrend and z-normalise
            sr, data = wavfile.read(wav_path)
            if data.ndim > 1:
                data = data[:, 0]
            data = detrend(data)
            std = np.std(data)
            if std == 0:
                print(f"Skipping {wav_file}: constant signal cannot be z-normalised")
                continue
            data_norm = (data - np.mean(data)) / std
            seg_len = int(segment_len_sec * sr)
            if seg_len != n:
                raise ValueError(
                    f"{wav_path}: {segment_len_sec}s at {sr} Hz gives {seg_len} samples, "
                    f"but the model expects {n}"
                )

            # Pair each start marker (id 1) with the next end marker (id 2)
            anomaly_regions = []
            temp_start = None
            with open(txt_path, 'r') as f:
                for line in f:
                    if not (line.strip() and line[0].isdigit()):
                        continue
                    m_id, m_time = line.split(',')
                    m_id = int(m_id.strip())
                    m_time = float(m_time.strip())
                    if m_id == 1:
                        temp_start = int(m_time * sr)
                    elif m_id == 2 and temp_start is not None:
                        anomaly_regions.append((temp_start, int(m_time * sr)))
                        temp_start = None
            # Treat an unclosed start marker as anomalous to the end of the file, so
            # the unclosed anomaly audio is not encoded as "normal".
            if temp_start is not None:
                anomaly_regions.append((temp_start, len(data_norm)))
            anomaly_regions.sort()

            # Normal regions are the gaps between anomalies; max() merges overlaps
            normal_regions = []
            last_end = 0
            for start, end in anomaly_regions:
                if last_end < start:
                    normal_regions.append((last_end, start))
                last_end = max(last_end, end)
            if last_end < len(data_norm):
                normal_regions.append((last_end, len(data_norm)))

            # Encode every full, non-overlapping window inside the normal regions
            for n_start, n_end in normal_regions:
                for i in range(n_start, n_end - seg_len + 1, seg_len):
                    segment = data_norm[i:i + seg_len]
                    segment_tensor = torch.from_numpy(segment).float().unsqueeze(0)
                    latent = model.encoder(segment_tensor)  # shape [1, latent_dim]
                    latent_list.append(latent.squeeze(0))   # shape [latent_dim]

# Stack all latent vectors and compute the mean
if len(latent_list) > 0:
    latent_all = torch.stack(latent_list)    # [N, latent_dim]
    mean_latent = latent_all.mean(dim=0)     # [latent_dim]
    # Save mean latent embedding
    mean_latent_path = load_path.replace('.pth', '_mean_latent.pt')
    torch.save(mean_latent, mean_latent_path)
    print(f"Saved mean latent embedding to {mean_latent_path}")
else:
    print("No normal segments found to extract mean latent embedding.")

Loaded threshold: 0.06286919116973877
Saved mean latent embedding to ./Dataset/Open/Model/autoencoder/MAE/Solanum_mean_latent.pt
