# GAN


In [83]:
# basic
import os
import datetime
import matplotlib.pyplot as plt
import glob
import json
import pandas as pd
import numpy as np

# torch
import torch.autograd as autograd
import torch.optim as optim
from torch.utils.data import Dataset
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.utils.data import DataLoader


# other
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import mean_squared_error
import itertools


from sklearn.metrics import mean_absolute_error
from fastdtw import fastdtw
from scipy.spatial.distance import euclidean

# DATASET + MODELY

* **`Signal1DDataset`**

  * Loads signal data from CSV files, grouped by experiment IDs.
  * Each sample contains:

    * A **1D signal** (from column `"intensity"` or last column).
    * A **condition vector** (from parameter CSV, excluding `experiment` column).
  * Supports:

    * Global **signal normalization** to `[-1, 1]`.
    * **Condition normalization** to zero mean, unit variance.

* **`Generator1D`**

  * Takes a random noise vector and a condition vector.
  * Produces a **synthetic 1D signal** of length `signal_len` in range `[-1, 1]`.

* **`Discriminator1D`**

  * Takes a signal (real or generated) and a condition vector.
  * Uses 1D convolutions to extract features.
  * Outputs a single scalar (real vs fake score).



In [84]:
class Signal1DDataset(Dataset):
    def __init__(self, root_dir, params_csv, allowed_experiments=None,
                 normalize_signals=True, normalize_conditions=True):
        """
        Custom PyTorch dataset for loading 1D signal CSV files along with
        associated experimental conditions.

        Args:
            root_dir (str): Root directory containing subfolders for experiments.
            params_csv (str): Path to CSV with experiment parameters.
                              Must contain a column 'experiment' linking to folders.
            allowed_experiments (list[int], optional): Subset of experiments to use.
            normalize_signals (bool): Whether to apply global min-max normalization
                                      of signals into [-1, 1].
            normalize_conditions (bool): Whether to standardize conditions
                                         (zero mean, unit variance).
        """
        self.root_dir = root_dir
        self.params_df = pd.read_csv(params_csv)

        if allowed_experiments is not None:
            self.params_df = self.params_df[self.params_df["experiment"].isin(allowed_experiments)]

        self.samples = []
        for _, row in self.params_df.iterrows():
            folder_id = str(int(row["experiment"]))
            folder_path = os.path.join(root_dir, folder_id)
            if not os.path.isdir(folder_path):
                continue
            for fname in os.listdir(folder_path):
                if fname.lower().endswith(".csv"):
                    fpath = os.path.join(folder_path, fname)
                    conditions = row.drop("experiment").values
                    self.samples.append((fpath, conditions))

        print(f"Loaded {len(self.samples)} files (experiments={allowed_experiments})")

        # === global signal normalization ===
        self.normalize_signals = normalize_signals
        if normalize_signals and len(self.samples) > 0:
            all_signals = []
            for fpath, _ in self.samples:
                df = pd.read_csv(fpath)
                if "intensity" in df.columns:
                    sig = df["intensity"].values.astype(np.float32)
                else:
                    sig = df.iloc[:, -1].values.astype(np.float32)
                all_signals.append(sig)

            all_signals = np.concatenate(all_signals)
            self.global_min = all_signals.min()
            self.global_max = all_signals.max()
            print(f"Global normalization: min={self.global_min:.4f}, max={self.global_max:.4f}")

        # === condition normalization ===
        self.normalize_conditions = normalize_conditions
        if normalize_conditions and len(self.samples) > 0:
            all_conditions = np.stack([s[1] for s in self.samples])
            self.cond_mean = all_conditions.mean(axis=0)
            self.cond_std = all_conditions.std(axis=0)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        signal_path, cond = self.samples[idx]
        df = pd.read_csv(signal_path)

        if "intensity" in df.columns:
            signal = df["intensity"].values.astype(np.float32)
        else:
            signal = df.iloc[:, -1].values.astype(np.float32)

        # global normalization
        if self.normalize_signals:
            signal = (signal - self.global_min) / (self.global_max - self.global_min + 1e-8)
            signal = signal * 2.0 - 1.0  # range [-1, 1]

        # condition normalization
        cond = cond.astype(np.float32)
        if self.normalize_conditions:
            cond = (cond - self.cond_mean) / (self.cond_std + 1e-8)

        return torch.from_numpy(signal), torch.from_numpy(cond)


class Generator1D(nn.Module):
    def __init__(self, noise_dim=64, cond_dim=5, signal_len=450):
        """
        1D Generator network for conditional GAN.

        Args:
            noise_dim (int): Dimension of input noise vector.
            cond_dim (int): Dimension of condition vector.
            signal_len (int): Length of the generated signal.
        """
        super().__init__()
        input_dim = noise_dim + cond_dim
        self.net = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, signal_len),
            nn.Tanh()   # output range [-1, 1]
        )

    def forward(self, noise, cond):
        x = torch.cat([noise, cond], dim=1)   # (B, noise_dim+cond_dim)
        return self.net(x)                    # (B, signal_len)


class Discriminator1D(nn.Module):
    def __init__(self, cond_dim=5, signal_len=450):
        """
        1D Discriminator network for conditional GAN.

        Args:
            cond_dim (int): Dimension of condition vector.
            signal_len (int): Length of input signal.
        """
        super().__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv1d(1, 16, kernel_size=7, stride=2, padding=3),
            nn.LeakyReLU(0.2),
            nn.Conv1d(16, 32, kernel_size=7, stride=2, padding=3),
            nn.LeakyReLU(0.2),
            nn.Conv1d(32, 64, kernel_size=7, stride=2, padding=3),
            nn.LeakyReLU(0.2),
        )

        # compute feature dimension after convolutions
        with torch.no_grad():
            dummy = torch.zeros(1, 1, signal_len)
            out = self.feature_extractor(dummy)
            flat_dim = out.view(1, -1).size(1)

        # fully connected layers
        self.fc1 = nn.Linear(flat_dim + cond_dim, 128)
        self.dropout = nn.Dropout(0.3)
        self.out = nn.Linear(128, 1)

    def forward(self, signal, cond):
        x = signal.unsqueeze(1)       # (B, 1, L)
        x = self.feature_extractor(x)
        x = x.view(x.size(0), -1)     # flatten
        x = torch.cat([x, cond], dim=1)
        x = self.fc1(x)
        x = F.leaky_relu(x, 0.2)
        x = self.dropout(x)
        return self.out(x)


# TRAIN

* **`train_wgan_gp_l1`**

  * Training loop for WGAN-GP with extra **L1 loss** (encourages signals to match real ones).
  * Updates the **discriminator multiple times per generator update** (`n_critic`).
  * Supports **λL1 decay** and optional **validation with saving outputs**.

* **`gradient_penalty`**

  * Implements the **gradient penalty** term from WGAN-GP.
  * Enforces Lipschitz constraint by penalizing deviation of gradient norm from 1.

* **`validate_gan`**

  * Samples random real signals + conditions.
  * Generates fake signals under same conditions.
  * Plots **real vs generated** side by side for visual sanity check.


In [85]:
def train_wgan_gp_l1(G, D, train_loader, noise_dim, num_epochs=50,
                     lr_G=1e-4, lr_D=5e-5, device="cpu",
                     n_critic=2, lambda_gp=1.0, lambda_l1=10.0,
                     l1_decay_every=None,
                     val_dataset=None, run_dir=None, validate_every=10):
    """
    Hybrid training loop for conditional WGAN-GP with additional L1 reconstruction loss.

    Args:
        G (nn.Module): Generator model.
        D (nn.Module): Discriminator (critic) model.
        train_loader (DataLoader): Training data loader.
        noise_dim (int): Dimension of input noise vector.
        num_epochs (int): Number of training epochs.
        lr_G (float): Learning rate for generator.
        lr_D (float): Learning rate for discriminator.
        device (str): "cpu" or "cuda".
        n_critic (int): Number of discriminator updates per generator update.
        lambda_gp (float): Gradient penalty coefficient.
        lambda_l1 (float): Weight for L1 reconstruction loss.
        l1_decay_every (int or None): Halve lambda_l1 every given epochs.
        val_dataset (Dataset or None): Validation dataset.
        run_dir (str or None): Directory for saving validation outputs.
        validate_every (int): Run validation every N epochs.

    Returns:
        dict: Training history with loss_D, loss_G, lambda_l1.
    """
    optimizer_G = optim.Adam(G.parameters(), lr=lr_G, betas=(0.0, 0.9))
    optimizer_D = optim.Adam(D.parameters(), lr=lr_D, betas=(0.0, 0.9))

    history = {"loss_D": [], "loss_G": [], "lambda_l1": []}

    for epoch in range(num_epochs):
        loss_D_epoch, loss_G_epoch = 0.0, 0.0

        # decay λL1
        if l1_decay_every is not None and (epoch > 0) and (epoch % l1_decay_every == 0):
            lambda_l1 = lambda_l1 / 2.0
            print(f"[Epoch {epoch}] Decaying lambda_l1 → {lambda_l1}")

        for signals, conds in train_loader:
            signals = signals.to(device).float()
            conds = conds.to(device).float()
            b_size = signals.size(0)

            # --- Critic update ---
            for _ in range(n_critic):
                noise = torch.randn(b_size, noise_dim, device=device)
                fake_signals = G(noise, conds)

                D.zero_grad()
                real_validity = D(signals, conds)
                fake_validity = D(fake_signals.detach(), conds)

                gp = gradient_penalty(D, signals, fake_signals, conds, device, lambda_gp)
                loss_D = fake_validity.mean() - real_validity.mean() + gp

                loss_D.backward()
                optimizer_D.step()

            # --- Generator update ---
            noise = torch.randn(b_size, noise_dim, device=device)
            fake_signals = G(noise, conds)

            G.zero_grad()
            fake_validity = D(fake_signals, conds)

            adv_loss = -fake_validity.mean()
            l1_loss = F.l1_loss(fake_signals, signals)
            loss_G = adv_loss + lambda_l1 * l1_loss

            loss_G.backward()
            optimizer_G.step()

            loss_D_epoch += loss_D.item()
            loss_G_epoch += loss_G.item()

        # epoch averages
        loss_D_epoch /= len(train_loader)
        loss_G_epoch /= len(train_loader)
        history["loss_D"].append(loss_D_epoch)
        history["loss_G"].append(loss_G_epoch)
        history["lambda_l1"].append(lambda_l1)

        print(f"Epoch [{epoch+1}/{num_epochs}] "
              f"Loss_D: {loss_D_epoch:.4f} Loss_G: {loss_G_epoch:.4f} (λL1={lambda_l1:.2f})")

        # --- Validation ---
        if val_dataset is not None and run_dir is not None:
            if (epoch+1) % validate_every == 0 or (epoch+1) == num_epochs:
                val_df = validate_and_save_all(G, val_dataset, noise_dim=noise_dim,
                                               device=device, save_dir=run_dir, epoch=epoch+1)
                print(f"[Epoch {epoch+1}] Val mean MSE={val_df['MSE'].mean():.4f}, "
                      f"Corr={val_df['Corr'].mean():.4f}")

    return history


def gradient_penalty(D, real_data, fake_data, cond, device="cpu", lambda_gp=10):
    """
    Computes gradient penalty for WGAN-GP.

    Args:
        D (nn.Module): Discriminator (critic).
        real_data (Tensor): Batch of real signals.
        fake_data (Tensor): Batch of generated signals.
        cond (Tensor): Condition vectors.
        device (str): Device.
        lambda_gp (float): Penalty weight.

    Returns:
        Tensor: Gradient penalty term.
    """
    b_size = real_data.size(0)
    alpha = torch.rand(b_size, 1, device=device).expand_as(real_data)

    interpolates = alpha * real_data + (1 - alpha) * fake_data
    interpolates = interpolates.to(device)
    interpolates.requires_grad_(True)

    d_interpolates = D(interpolates, cond)

    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates, device=device),
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]

    gradients = gradients.view(b_size, -1)
    gradient_norm = gradients.norm(2, dim=1)
    gp = ((gradient_norm - 1) ** 2).mean() * lambda_gp
    return gp


def validate_gan(G, val_dataset, noise_dim=64, num_samples=5, device="cpu"):
    """
    Visual validation of GAN outputs.

    Args:
        G (nn.Module): Generator model.
        val_dataset (Dataset): Validation dataset.
        noise_dim (int): Noise dimension.
        num_samples (int): Number of samples to visualize.
        device (str): Device.

    Side effects:
        Displays matplotlib plots comparing real vs. generated signals.
    """
    G.eval()
    indices = np.random.choice(len(val_dataset), num_samples, replace=False)

    fig, axes = plt.subplots(num_samples, 1, figsize=(10, 8), sharex=True)
    for i, idx in enumerate(indices):
        real_signal, cond = val_dataset[idx]
        cond = cond.unsqueeze(0).to(device).float()
        noise = torch.randn(1, noise_dim, device=device)
        with torch.no_grad():
            fake_signal = G(noise, cond).cpu().numpy().flatten()

        axes[i].plot(real_signal.numpy(), label="Real", color="black")
        axes[i].plot(fake_signal, label="Generated", color="red", alpha=0.7)
        axes[i].legend()
    plt.suptitle("Validation: Real vs Generated Signals")
    plt.show()


# VALIDATE

* **`validate_gan_with_metrics`**

  * Picks a few random samples from the dataset.
  * Generates fake signals under the same conditions.
  * Computes **MSE** and **Pearson correlation** between real and fake signals.
  * Visualizes and optionally saves the plots.

* **`validate_and_save_all`**

  * Runs generator across the **entire validation dataset**.
  * Computes and saves per-sample metrics (**MSE, Corr**) into a CSV.
  * Saves arrays of real and generated signals (`.npy`).
  * Produces histograms of metric distributions for quick diagnostics.



In [86]:
def validate_gan_with_metrics(G, dataset, noise_dim=64, num_samples=3, device="cpu", save_dir=None):
    """
    Validate generator on a random subset of the dataset and compute metrics.

    Args:
        G (nn.Module): Generator model.
        dataset (Dataset): Dataset with real signals and conditions.
        noise_dim (int): Dimension of the noise vector.
        num_samples (int): Number of random samples to visualize.
        device (str): Device ("cpu" or "cuda").
        save_dir (str or None): If provided, save the plots to this directory.

    Returns:
        list[dict]: List of dictionaries with indices, MSE and correlation.
    """
    G.eval()
    indices = np.random.choice(len(dataset), num_samples, replace=False)
    results = []

    fig, axes = plt.subplots(num_samples, 1, figsize=(10, 6), sharex=True)

    for i, idx in enumerate(indices):
        real_signal, cond = dataset[idx]
        real_signal = real_signal.numpy()

        cond = cond.unsqueeze(0).to(device).float()
        noise = torch.randn(1, noise_dim, device=device)

        with torch.no_grad():
            fake_signal = G(noise, cond).cpu().numpy().flatten()

        mse = mean_squared_error(real_signal, fake_signal)
        corr, _ = pearsonr(real_signal, fake_signal)
        results.append({"idx": idx, "MSE": mse, "Corr": corr})

        axes[i].plot(real_signal, label="Real", color="black")
        axes[i].plot(fake_signal, label="Generated", color="red", alpha=0.7)
        axes[i].legend()
        axes[i].set_title(f"Sample {idx} | MSE={mse:.4f}, Corr={corr:.3f}")

    plt.suptitle("Validation: Real vs Generated signals (with metrics)")
    if save_dir is not None:
        plt.savefig(os.path.join(save_dir, "validation_examples.png"), dpi=200)
        plt.close()
    else:
        plt.show()

    return results


def validate_and_save_all(G, dataset, noise_dim=64, device="cpu", save_dir="results", epoch=None):
    """
    Evaluate generator on the entire validation dataset and save results.

    For each sample:
      - Generates a fake signal given the real condition.
      - Computes MSE and Pearson correlation with the real signal.
      - Saves signals and metrics to disk.

    Args:
        G (nn.Module): Generator model.
        dataset (Dataset): Validation dataset.
        noise_dim (int): Dimension of noise vector.
        device (str): Device ("cpu" or "cuda").
        save_dir (str): Directory for saving outputs.
        epoch (int or None): If provided, appended to filenames.

    Returns:
        pd.DataFrame: DataFrame with per-sample metrics (MSE, Corr).
    """
    G.eval()
    all_results = []
    real_signals, fake_signals = [], []

    for idx in range(len(dataset)):
        real_signal, cond = dataset[idx]
        real_signal_np = real_signal.numpy()

        cond = cond.unsqueeze(0).to(device).float()
        noise = torch.randn(1, noise_dim, device=device)

        with torch.no_grad():
            fake_signal = G(noise, cond).cpu().numpy().flatten()

        mse = mean_squared_error(real_signal_np, fake_signal)
        corr, _ = pearsonr(real_signal_np, fake_signal)

        all_results.append({"idx": idx, "MSE": mse, "Corr": corr})
        real_signals.append(real_signal_np)
        fake_signals.append(fake_signal)

    df = pd.DataFrame(all_results)

    # --- dynamic filenames depending on epoch ---
    suffix = f"_epoch{epoch}" if epoch is not None else ""
    df.to_csv(os.path.join(save_dir, f"validation_metrics{suffix}.csv"), index=False)

    np.save(os.path.join(save_dir, f"real_signals{suffix}.npy"), np.array(real_signals))
    np.save(os.path.join(save_dir, f"fake_signals{suffix}.npy"), np.array(fake_signals))

    # histograms of metrics
    plt.figure(figsize=(10, 4))
    plt.subplot(1, 2, 1)
    plt.hist(df["MSE"], bins=30, color="skyblue", edgecolor="black")
    plt.title("MSE distribution")
    plt.xlabel("MSE"); plt.ylabel("Count")

    plt.subplot(1, 2, 2)
    plt.hist(df["Corr"], bins=30, color="salmon", edgecolor="black")
    plt.title("Correlation distribution")
    plt.xlabel("Pearson corr"); plt.ylabel("Count")

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f"validation_metrics_hist{suffix}.png"), dpi=200)
    plt.close()

    return df


# MAIN

1. **Dataset split**

   * Splits experiments into train/val sets (80/20).
   * Saves or reloads the split from `dataset_split.csv`.

2. **Dataset + DataLoader**

   * Initializes `Signal1DDataset` for train and validation.
   * Builds PyTorch `DataLoader`.

3. **Model initialization**

   * Derives signal length and condition dimension from dataset.
   * Creates `Generator1D` and `Discriminator1D`.

4. **Grid search**

   * Iterates over combinations of `λL1` weights and number of epochs.
   * Each run has its own timestamped results directory.

5. **Training**

   * Calls `train_wgan_gp_l1` with given parameters.
   * Tracks losses and saves them in CSV/JSON.

6. **Saving**

   * Stores trained models (`.pth`).
   * Exports training curves.

7. **Validation**

   * Runs `validate_gan_with_metrics` for sample plots with metrics.
   * Runs `validate_and_save_all` for full validation evaluation.
   * Saves metrics, signals, and histogram plots.



In [88]:
if __name__ == "__main__":
    # ==========================
    # Setup
    # ==========================
    root_dir = "1D_spec"
    params_csv = os.path.join(root_dir, "params.csv")
    base_results_dir = "results"
    os.makedirs(base_results_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # ==========================
    # Dataset split
    # ==========================
    split_file = os.path.join(base_results_dir, "dataset_split.csv")
    if os.path.exists(split_file):
        split_df = pd.read_csv(split_file)
        train_exps = split_df.query("set == 'train'")["experiment"].tolist()
        val_exps   = split_df.query("set == 'val'")["experiment"].tolist()
    else:
        params_df = pd.read_csv(params_csv)
        all_experiments = params_df["experiment"].unique()
        np.random.seed(42)
        np.random.shuffle(all_experiments)

        split_ratio = 0.8
        split_idx = int(len(all_experiments) * split_ratio)
        train_exps = all_experiments[:split_idx]
        val_exps   = all_experiments[split_idx:]

        split_df = pd.DataFrame({
            "experiment": list(train_exps) + list(val_exps),
            "set": ["train"] * len(train_exps) + ["val"] * len(val_exps)
        })
        split_df.to_csv(split_file, index=False)

    # dataset and dataloader
    train_dataset = Signal1DDataset(root_dir, params_csv, allowed_experiments=train_exps)
    val_dataset   = Signal1DDataset(root_dir, params_csv, allowed_experiments=val_exps)
    train_loader  = DataLoader(train_dataset, batch_size=32, shuffle=True)

    # ==========================
    # Signal dimensions
    # ==========================
    signal, cond = train_dataset[0]
    noise_dim = 64
    cond_dim = cond.numel()
    signal_len = signal.numel()

    print(f"Init: noise_dim={noise_dim}, cond_dim={cond_dim}, signal_len={signal_len}")

    # ==========================
    # Grid search parameters
    # ==========================
    lambda_list = [5, 10, 15, 20, 25]
    epoch_list = [10, 20, 30, 40, 50, 60, 70, 80]

    # ==========================
    # Run all combinations
    # ==========================
    for lambda_l1, num_epochs in itertools.product(lambda_list, epoch_list):
        run_name = f"lambda{lambda_l1}_epochs{num_epochs}_" + datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        run_dir = os.path.join(base_results_dir, run_name)
        os.makedirs(run_dir, exist_ok=True)
        print(f"\n=== Running: λ={lambda_l1}, epochs={num_epochs}, output: {run_dir} ===")

        # models
        G = Generator1D(noise_dim=noise_dim, cond_dim=cond_dim, signal_len=signal_len).to(device)
        D = Discriminator1D(cond_dim=cond_dim, signal_len=signal_len).to(device)

        # config
        config = {
            "lambda_l1": lambda_l1,
            "num_epochs": num_epochs,
            "lr_G": 1e-4,
            "lr_D": 5e-5,
            "n_critic": 2,
            "lambda_gp": 1.0,
            "noise_dim": noise_dim,
            "cond_dim": cond_dim,
            "signal_len": signal_len,
            "device": str(device)
        }
        with open(os.path.join(run_dir, "config.json"), "w") as f:
            json.dump(config, f, indent=4)

        # training
        history = train_wgan_gp_l1(
            G, D, train_loader, noise_dim,
            num_epochs=num_epochs,
            device=device,
            lr_G=config["lr_G"], lr_D=config["lr_D"],
            n_critic=config["n_critic"],
            lambda_gp=config["lambda_gp"], lambda_l1=lambda_l1,
            val_dataset=val_dataset, run_dir=run_dir, validate_every=10
        )

        # save training history
        pd.DataFrame(history).to_csv(os.path.join(run_dir, "training_history.csv"), index=False)
        with open(os.path.join(run_dir, "training_history.json"), "w") as f:
            json.dump(history, f)

        # save models
        torch.save(G.state_dict(), os.path.join(run_dir, "generator.pth"))
        torch.save(D.state_dict(), os.path.join(run_dir, "discriminator.pth"))

        # loss curves
        plt.figure(figsize=(8, 5))
        plt.plot(history["loss_D"], label="Discriminator")
        plt.plot(history["loss_G"], label="Generator")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Training progress (WGAN-GP + L1)")
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.savefig(os.path.join(run_dir, "loss_curve.png"), dpi=200)
        plt.close()

        # validation
        val_results = validate_gan_with_metrics(G, val_dataset, noise_dim=noise_dim,
                                                num_samples=5, device=device, save_dir=run_dir)
        pd.DataFrame(val_results).to_csv(os.path.join(run_dir, "validation_metrics.csv"), index=False)

        val_df = validate_and_save_all(G, val_dataset, noise_dim=noise_dim,
                                       device=device, save_dir=run_dir)
        print(val_df.describe())

        print("Run completed. Results in folder:", run_dir)


Použité zařízení: cpu
Načteno 991 souborů (experiments=[ 1 14  9  2 16  6 21 12  4  5 18 13 19 17  3 10 22])
Globální normalizace: min=0.0015, max=6.9025
Načteno 288 souborů (experiments=[ 8 11 15 20  7])
Globální normalizace: min=0.0039, max=3.2035
Init: noise_dim=64, cond_dim=5, signal_len=450

=== Spouštím běh: λ=5, epochs=10, výstup: results/lambda5_epochs10_20250926_044242 ===
Epoch [1/10] Loss_D: -1.8329 Loss_G: 3.5305 (λL1=5.00)
Epoch [2/10] Loss_D: -2.5101 Loss_G: -5.5305 (λL1=5.00)
Epoch [3/10] Loss_D: 0.1727 Loss_G: -4.6210 (λL1=5.00)
Epoch [4/10] Loss_D: 0.3351 Loss_G: 2.7134 (λL1=5.00)
Epoch [5/10] Loss_D: -0.0079 Loss_G: -0.9835 (λL1=5.00)
Epoch [6/10] Loss_D: -0.1014 Loss_G: 1.9514 (λL1=5.00)
Epoch [7/10] Loss_D: -0.4894 Loss_G: 2.3617 (λL1=5.00)
Epoch [8/10] Loss_D: -1.3326 Loss_G: 0.9576 (λL1=5.00)
Epoch [9/10] Loss_D: -1.1614 Loss_G: 0.7955 (λL1=5.00)
Epoch [10/10] Loss_D: -1.1623 Loss_G: 0.5747 (λL1=5.00)
[Epoch 10] Val mean MSE=0.1507, Corr=0.3332
Nejlepší epocha pod

In [89]:
"""

def collect_runs(parent_dir="results"):

    runs = []
    for run in os.listdir(parent_dir):
        run_dir = os.path.join(parent_dir, run)
        files = sorted(glob.glob(os.path.join(run_dir, "validation_metrics_epoch*.csv")))
        if not files:
            continue

        last_file = files[-1]  # poslední epocha
        if os.path.getsize(last_file) == 0:
            print(f"Soubor {last_file} je prázdný – přeskočeno.")
            continue

        try:
            df = pd.read_csv(last_file)
        except Exception as e:
            print(f"Chyba při čtení {last_file}: {e}")
            continue

        if df.empty:
            print(f"Soubor {last_file} neobsahuje data – přeskočeno.")
            continue

        mean_mse = df["MSE"].mean()
        mean_corr = df["Corr"].mean()
        runs.append({
            "run": run,
            "mean_MSE": mean_mse,
            "mean_Corr": mean_corr,
            "epoch": int(last_file.split("epoch")[-1].split(".")[0])
        })

    if not runs:
        print("Nebyl nalezen žádný platný výsledek.")
        return pd.DataFrame()

    return pd.DataFrame(runs)



def compare_runs(parent_dir="results", save_name="comparison.png"):
    all_dfs = []
    for run in os.listdir(parent_dir):
        run_dir = os.path.join(parent_dir, run)
        files = sorted(glob.glob(os.path.join(run_dir, "validation_metrics_epoch*.csv")))
        if files:
            # vyber nejlepší epochu podle MSE
            best, _ = find_best_epoch(run_dir, metric="MSE", mode="min")
            best_file = os.path.join(run_dir, f"validation_metrics_epoch{int(best['epoch'])}.csv")

            # kontrola, zda soubor není prázdný
            if os.path.exists(best_file) and os.path.getsize(best_file) > 0:
                try:
                    df = pd.read_csv(best_file)
                    if not df.empty:
                        df["run"] = run
                        df["epoch"] = int(best["epoch"])
                        all_dfs.append(df)
                    else:
                        print(f"Soubor {best_file} je prázdný – přeskočeno.")
                except Exception as e:
                    print(f"Chyba při čtení {best_file}: {e}")
            else:
                print(f"Soubor {best_file} neexistuje nebo je prázdný – přeskočeno.")

    if not all_dfs:
        print("Nebyly nalezeny žádné výsledky k porovnání.")
        return None

    big_df = pd.concat(all_dfs)

    # boxploty
    plt.figure(figsize=(14, 6))
    plt.subplot(1, 2, 1)
    sns.boxplot(x="run", y="MSE", data=big_df)
    plt.xticks(rotation=45, ha="right")
    plt.title("MSE podle běhů (nejlepší epochy)")

    plt.subplot(1, 2, 2)
    sns.boxplot(x="run", y="Corr", data=big_df)
    plt.xticks(rotation=45, ha="right")
    plt.title("Korelace podle běhů (nejlepší epochy)")

    plt.tight_layout()
    plt.savefig(os.path.join(parent_dir, save_name), dpi=200)
    plt.close()

    print(f"Porovnání uloženo do {os.path.join(parent_dir, save_name)}")
    return big_df



if __name__ == "__main__":
    parent = "results"
    summary = collect_runs(parent)
    print("Souhrn běhů:")
    print(summary)

    big_df = compare_runs(parent)
    if big_df is not None:
        # histogramy všech běhů dohromady
        plt.figure(figsize=(12,5))
        plt.subplot(1,2,1)
        sns.histplot(big_df, x="MSE", hue="run", bins=40, kde=True, element="step")
        plt.title("Distribuce MSE")

        plt.subplot(1,2,2)
        sns.histplot(big_df, x="Corr", hue="run", bins=40, kde=True, element="step")
        plt.title("Distribuce korelace")

        plt.tight_layout()
        plt.savefig(os.path.join(parent, "comparison_histograms.png"), dpi=200)
        plt.close()

        print(f"Histogramy uloženy do {os.path.join(parent, 'comparison_histograms.png')}")
"""

Souhrn běhů:
                                  run  mean_MSE  mean_Corr  epoch
0   lambda10_epochs70_20250926_050105  0.052753   0.856714     70
1    lambda5_epochs30_20250926_044340  0.044945   0.901697     30
2   lambda20_epochs70_20250926_055553  0.043052   0.967255     70
3   lambda25_epochs60_20250926_082021  0.045667   0.967781     60
4   lambda15_epochs50_20250926_050919  0.042845   0.961768     50
5    lambda5_epochs70_20250926_044924  0.056843   0.722515     70
6   lambda20_epochs80_20250926_060026  0.049340   0.975182     80
7   lambda25_epochs50_20250926_074714  0.045906   0.956209     50
8   lambda15_epochs10_20250926_050554  0.055645   0.759857     10
9    lambda5_epochs50_20250926_044554  0.052858   0.847549     50
10  lambda20_epochs30_20250926_052954  0.044441   0.968381     30
11  lambda15_epochs80_20250926_052631  0.043211   0.965560     80
12  lambda25_epochs70_20250926_082247  0.048417   0.961242     70
13  lambda15_epochs60_20250926_051129  0.040948   0.946157     

  plt.tight_layout()


Histogramy uloženy do results/comparison_histograms.png


## ATTENTION!

## --> find best run/epoch functions
this part is still duplicited - it should fixed later

In [90]:
def find_best_epoch(run_dir, metric="MSE", mode="min"):
    """Find the best epoch within a single run based on the selected metric criterion."""
    files = sorted(glob.glob(os.path.join(run_dir, "validation_metrics_epoch*.csv")))
    if not files:
        return None, None

    results = []
    for f in files:
        epoch = int(os.path.basename(f).replace("validation_metrics_epoch", "").replace(".csv", ""))
        df = pd.read_csv(f)
        mean_val = df[metric].mean()
        results.append({"epoch": epoch, metric: mean_val, "file": f})

    df_all = pd.DataFrame(results)
    if mode == "min":
        best = df_all.loc[df_all[metric].idxmin()]
    else:
        best = df_all.loc[df_all[metric].idxmax()]

    return best, df_all


def find_best_run(parent_dir="results", metric="MSE", mode="min"):
    """Find the globally best run and epoch based on the selected metric."""
    best_overall = None
    all_runs = []

    for run in os.listdir(parent_dir):
        run_dir = os.path.join(parent_dir, run)
        if not os.path.isdir(run_dir):
            continue

        best, all_vals = find_best_epoch(run_dir, metric=metric, mode=mode)
        if best is None:
            continue

        # load configuration
        config_file = os.path.join(run_dir, "config.json")
        if os.path.exists(config_file):
            with open(config_file, "r") as f:
                config = json.load(f)
        else:
            config = {}

        record = {
            "run": run,
            "epoch": int(best["epoch"]),
            metric: float(best[metric]),
            "file": best["file"],
            "lambda_l1": config.get("lambda_l1"),
            "num_epochs": config.get("num_epochs")
        }
        all_runs.append(record)

        # update best run
        if best_overall is None:
            best_overall = record
        else:
            if (mode == "min" and record[metric] < best_overall[metric]) or \
               (mode == "max" and record[metric] > best_overall[metric]):
                best_overall = record

    df_all = pd.DataFrame(all_runs)
    return best_overall, df_all



best, summary = find_best_run("results", metric="MSE", mode="min")

print(f"Summary of all runs: \n {summary}")
print(f"Best run according to MSE: {best}")


Souhrn všech běhů:
                                  run  epoch       MSE  \
0   lambda10_epochs70_20250926_050105     30  0.044194   
1    lambda5_epochs30_20250926_044340     30  0.044945   
2   lambda20_epochs70_20250926_055553     70  0.043052   
3   lambda25_epochs60_20250926_082021     50  0.044733   
4   lambda15_epochs50_20250926_050919     40  0.042237   
5    lambda5_epochs70_20250926_044924     50  0.044944   
6   lambda20_epochs80_20250926_060026     40  0.043259   
7   lambda25_epochs50_20250926_074714     40  0.043244   
8   lambda15_epochs10_20250926_050554     10  0.055645   
9    lambda5_epochs50_20250926_044554     40  0.046894   
10  lambda20_epochs30_20250926_052954     30  0.044441   
11  lambda15_epochs80_20250926_052631     80  0.043211   
12  lambda25_epochs70_20250926_082247     60  0.043673   
13  lambda15_epochs60_20250926_051129     60  0.040948   
14  lambda10_epochs60_20250926_045909     50  0.042933   
15  lambda10_epochs20_20250926_045433     20  0.04489

In [91]:
import os
import glob
import json
import pandas as pd
from sklearn.metrics import mean_absolute_error
from scipy.stats import spearmanr
from fastdtw import fastdtw
from scipy.spatial.distance import euclidean


def compute_extra_metrics(real, fake):
    """Compute additional metrics for a single signal."""
    mae = mean_absolute_error(real, fake)
    spear, _ = spearmanr(real, fake)
    # DTW (Dynamic Time Warping)
    dtw_dist, _ = fastdtw(real, fake, dist=euclidean)
    return mae, spear, dtw_dist


def evaluate_epoch(file_path):
    """Load validation CSV for one epoch and compute average metrics."""
    df = pd.read_csv(file_path)
    metrics = {
        "MSE": df["MSE"].mean(),
        "Corr": df["Corr"].mean()
    }

    maes, spears, dtws = [], [], []
    for idx, row in df.iterrows():
        # If you have saved real_signals.npy and fake_signals.npy → you could load them here.
        # For now, we only use MSE and Corr from the CSV.
        # Spearman and DTW could be computed during validation → placeholder here.
        pass

    # If DTW is too expensive to compute for all samples,
    # it’s better to do it during validation and store it in the CSV.
    # For simplicity, we only return MSE and Corr here.
    return metrics


def find_best_epoch(run_dir, metric="MSE", mode="min"):
    """Find the best epoch within a single run based on the selected metric."""
    files = sorted(glob.glob(os.path.join(run_dir, "validation_metrics_epoch*.csv")))
    if not files:
        return None, None

    results = []
    for f in files:
        epoch = int(os.path.basename(f).replace("validation_metrics_epoch", "").replace(".csv", ""))
        metrics = evaluate_epoch(f)
        results.append({"epoch": epoch, **metrics, "file": f})

    df_all = pd.DataFrame(results)
    if mode == "min":
        best = df_all.loc[df_all[metric].idxmin()]
    else:
        best = df_all.loc[df_all[metric].idxmax()]

    return best, df_all


def compute_score(row, weights=None):
    """
    Composite score = weighted combination of metrics.
    Lower is better.
    """
    if weights is None:
        # Default: MSE and (1 - Corr) equally weighted
        weights = {"MSE": 0.5, "Corr": 0.5}

    score = 0.0
    if "MSE" in weights:
        score += weights["MSE"] * row["MSE"]
    if "Corr" in weights:
        score += weights["Corr"] * (1 - row["Corr"])
    return score


def find_best_run(parent_dir="results", metric="MSE", mode="min", weights=None):
    """Find the globally best run based on the composite score."""
    all_runs = []
    best_overall = None

    for run in os.listdir(parent_dir):
        run_dir = os.path.join(parent_dir, run)
        if not os.path.isdir(run_dir):
            continue

        best, all_vals = find_best_epoch(run_dir, metric=metric, mode=mode)
        if best is None:
            continue

        # load configuration
        config_file = os.path.join(run_dir, "config.json")
        if os.path.exists(config_file):
            with open(config_file, "r") as f:
                config = json.load(f)
        else:
            config = {}

        record = {
            "run": run,
            "epoch": int(best["epoch"]),
            "MSE": float(best["MSE"]),
            "Corr": float(best["Corr"]),
            "score": compute_score(best, weights=weights),
            "file": best["file"],
            "lambda_l1": config.get("lambda_l1"),
            "num_epochs": config.get("num_epochs")
        }
        all_runs.append(record)

        if best_overall is None:
            best_overall = record
        else:
            if record["score"] < best_overall["score"]:
                best_overall = record

    df_all = pd.DataFrame(all_runs)
    return best_overall, df_all


if __name__ == "__main__":
    # Weights: equal for MSE and Corr
    weights = {"MSE": 0.5, "Corr": 0.5}

    best, summary = find_best_run("results", metric="MSE", mode="min", weights=weights)
    print("Summary of all runs:")
    print(summary)

    print(f"Best run according to composite score (MSE + Corr): {best}")


Souhrn všech běhů:
                                  run  epoch       MSE      Corr     score  \
0   lambda10_epochs70_20250926_050105     30  0.044194  0.939912  0.052141   
1    lambda5_epochs30_20250926_044340     30  0.044945  0.901697  0.071624   
2   lambda20_epochs70_20250926_055553     70  0.043052  0.967255  0.037899   
3   lambda25_epochs60_20250926_082021     50  0.044733  0.968899  0.037917   
4   lambda15_epochs50_20250926_050919     40  0.042237  0.945651  0.048293   
5    lambda5_epochs70_20250926_044924     50  0.044944  0.950072  0.047436   
6   lambda20_epochs80_20250926_060026     40  0.043259  0.953632  0.044813   
7   lambda25_epochs50_20250926_074714     40  0.043244  0.971839  0.035703   
8   lambda15_epochs10_20250926_050554     10  0.055645  0.759857  0.147894   
9    lambda5_epochs50_20250926_044554     40  0.046894  0.896649  0.075122   
10  lambda20_epochs30_20250926_052954     30  0.044441  0.968381  0.038030   
11  lambda15_epochs80_20250926_052631     80 