In [None]:
import pandas as pd
import os
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

In [None]:
from ResampleGAN.core.ModelManager import ModelManager
from ResampleGAN.core.TrainingUtils import TrainingUtils, quick_setup
from ResampleGAN.utils.DatasetGenerator import DatasetGenerator, get_aligned_input_output
from ResampleGAN.core.ErrorUtils import compute_metrics
from dataclasses import dataclass
from typing import Dict, List, Tuple

In [None]:
@dataclass
class BaselineConfig:
    """Settings shared by every baseline run: runtime, architecture, training, optimiser, plot colours.

    ``attention_type`` may be left as ``None``; ``__post_init__`` then fills it with a
    fresh list of three groups: three ``"original"`` entries followed by an empty
    ``"conv"`` group and an empty ``"freq"`` group.
    """

    # Runtime: use the GPU when torch can see one (evaluated once, at class definition).
    device: torch.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Model architecture
    dim_attention: int = 128
    num_heads: int = 4
    dim_feedforward: int = 128
    dropout: float = 0.1
    num_layers: int = 6
    with_bias: bool = False
    attention_type: List = None  # replaced in __post_init__ when left as None

    # Training loop
    n_epochs: int = 120
    batch_size: int = 64
    dim_input: int = 1
    hidden_dim: int = 16
    grad_clip_threshold: float = 10.0

    # Optimiser
    lr: float = 1e-3
    weight_decay: float = 1e-3

    # Plot colours
    blue: str = '#0C5DA5'
    green: str = '#00B945'
    orange: str = '#FF9500'
    red: str = '#FF2C00'

    def __post_init__(self):
        # Build the default per instance so no two configs share the same mutable list.
        if self.attention_type is not None:
            return
        original_group = ["original"] * 3
        conv_group: List[str] = []
        freq_group: List[str] = []
        self.attention_type = [original_group, conv_group, freq_group]

In [None]:
# Set this before anything below can touch CUDA: CUDA reads it when it initialises,
# so assigning it after set_seed / quick_setup may have no effect. It serialises
# kernel launches (slows training) and is meant as a debugging aid.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

config = BaselineConfig()
TrainingUtils.set_seed(42)
# NOTE(review): `device` returned here is not used; training/evaluation use config.device.
logger, device = quick_setup(log_file="baseline_training.log")

In [None]:
def train_baseline_model(model_type: str, waveform: str, train_loader, valid_loader, config, logger):
    """Simplified baseline model training function"""

    # Create model manager and model
    manager = ModelManager(logger, config)
    model = manager.create_model(model_type)

    # Create optimizer and loss function
    optimizer = manager.create_optimizer(model)
    scheduler = manager.create_scheduler(optimizer)
    criterion = nn.MSELoss()

    # Print model information
    manager.print_model_summary(model, f"{model_type}")

    best_loss = float('inf')
    train_losses, valid_losses = [], []

    logger.info(f"Starting training {model_type} - {waveform}")

    for epoch in range(config.n_epochs):
        # Training phase
        model.train()
        total_train_loss = 0

        for batch in train_loader:
            optimizer.zero_grad()

            # Parse batch data
            x_input, x_initial, x_mask, x_output, mask, condition = batch
            x_input = x_input.to(config.device)
            x_initial = x_initial.to(config.device)
            x_output = x_output.to(config.device) if x_output is not None else None
            s_in, s_out = condition[0], condition[1]

            # Select inference method based on model type
            if model_type == "TCN":
                pred = model(x_initial)
            elif model_type in ["LSTM"]:
                pred = model(x_initial)
            elif model_type == "Transformer":
                pred = model(x_input, x_initial, s_in, s_out, with_clamp=False)

            # Calculate loss
            loss = criterion(pred, x_output)
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip_threshold)
            optimizer.step()

            total_train_loss += loss.item()

        # Validation phase
        model.eval()
        total_valid_loss = 0

        with torch.no_grad():
            for batch in valid_loader:
                x_input, x_initial, x_mask, x_output, mask, condition = batch
                x_input = x_input.to(config.device)
                x_initial = x_initial.to(config.device)
                x_output = x_output.to(config.device) if x_output is not None else None
                s_in, s_out = condition[0], condition[1]

                if model_type == "TCN":
                    pred = model(x_initial)
                elif model_type in ["LSTM"]:
                    pred = model(x_initial)
                elif model_type == "Transformer":
                    pred = model(x_input, x_initial, s_in, s_out, with_clamp=False)

                loss = criterion(pred, x_output)
                total_valid_loss += loss.item()

        # Calculate average loss
        avg_train_loss = total_train_loss / len(train_loader)
        avg_valid_loss = total_valid_loss / len(valid_loader)

        train_losses.append(avg_train_loss)
        valid_losses.append(avg_valid_loss)

        # Update learning rate
        if scheduler:
            scheduler.step()

        # Save best model
        if avg_valid_loss < best_loss:
            best_loss = avg_valid_loss
            save_path = f"../results/001_baseline_selection/model/best_{model_type}_model_{waveform}.pth"
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            torch.save(model.state_dict(), save_path)

        # Print every 10 epochs
        if (epoch + 1) % 10 == 0:
            logger.info(f"Epoch {epoch+1}/{config.n_epochs} | Train: {avg_train_loss:.6f} | Valid: {avg_valid_loss:.6f}")

    logger.info(f"{model_type} - {waveform} training completed")
    return model, train_losses, valid_losses

In [None]:
# Output folders for checkpoints and figures.
for results_subdir in ("../results/001_baseline_selection/model",
                       "../results/001_baseline_selection/picture"):
    os.makedirs(results_subdir, exist_ok=True)

# Waveforms to train on (shorten this list for a quick debugging run) and the baselines to compare.
waveforms = ["line", "square", "sine", "triangle"]
model_types = ["Transformer", "TCN", "LSTM"]

In [None]:
# Train every baseline on every waveform. A failing run is logged with its traceback
# and recorded, but does not stop the remaining runs.
failed_runs = []

for waveform in waveforms:
    logger.info(f"Processing waveform: {waveform}")

    # Load data: datetime index, single "value" column.
    raw_df = pd.read_csv(f"../dataset/special_wave_{waveform}.csv")
    df = raw_df.assign(time=pd.to_datetime(raw_df["time"])).set_index("time")[["value"]]

    # Build the windowed input/output dataset (s_in="15min" -> s_out="5min").
    df_input, df_output = get_aligned_input_output(df, s_in="15min", s_out="5min")
    dataset = DatasetGenerator(
        df_input=df_input,
        df_output=df_output,
        input_length=97,
        output_length=289,
        s_in="15min",
        s_out="5min",
        use_window=True
    )

    # Split preprocessed data (split_dataset returns train, test, valid in that order).
    train_dataset, test_dataset, valid_dataset = DatasetGenerator.split_dataset(dataset)
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)
    valid_loader = DataLoader(valid_dataset, batch_size=config.batch_size, shuffle=False)

    for model_type in model_types:
        try:
            model, train_losses, valid_losses = train_baseline_model(model_type, waveform, train_loader, valid_loader, config, logger)
            logger.info(f"{model_type} training completed for {waveform}")
        except Exception:
            # Best-effort sweep: keep going, but do not hide the failure or its traceback.
            logger.exception(f"{model_type} training failed for {waveform}")
            failed_runs.append(f"{model_type}/{waveform}")
        finally:
            TrainingUtils.cleanup_memory()

if failed_runs:
    logger.warning(f"{len(failed_runs)} training run(s) failed: {', '.join(failed_runs)}")
else:
    logger.info("All models trained successfully")

In [None]:
# Imported for its side effect: presumably this is what makes the "science" style used
# on the next line available to matplotlib (scienceplots >= 2.0) — confirm before removing.
import scienceplots
# "no-latex" presumably disables LaTeX text rendering so no TeX installation is needed.
plt.style.use(['science', "no-latex"])
# Result tables: one row per baseline, one column per waveform. Row labels must equal
# the model_type strings used with .loc when the tables are filled: .loc with an
# unknown row label appends a new row instead of filling the intended one.
metric_models = ["TCN", "LSTM", "Transformer"]
metric_waveforms = ["line", "square", "sine", "triangle"]
df_rmse = pd.DataFrame(index=metric_models, columns=metric_waveforms, dtype=float)
df_pcc = pd.DataFrame(index=metric_models, columns=metric_waveforms, dtype=float)

In [None]:
# Evaluate each saved best checkpoint on the full validation split of every waveform,
# record RMSE/PCC into df_rmse/df_pcc, and plot the first validation window per waveform.
eval_waveforms = ["line", "square", "sine", "triangle"]
eval_models = ["TCN", "Transformer", "LSTM"]
color_map = {"TCN": config.green, "LSTM": config.orange, "Transformer": config.red}
picture_path = "../results/001_baseline_selection/picture/A_1_baseline_selection.pdf"

# squeeze=False keeps `axes` indexable even if only one waveform is evaluated.
fig, axes = plt.subplots(1, len(eval_waveforms), figsize=(14, 3), dpi=300, squeeze=False)
axes = axes.ravel()

for i, waveform in enumerate(eval_waveforms):
    ax = axes[i]
    print(f"Evaluating waveform: {waveform}")

    # Load data (same preprocessing as the training cell)
    raw_df = pd.read_csv(f"../dataset/special_wave_{waveform}.csv")
    df = raw_df.assign(time=pd.to_datetime(raw_df["time"])).set_index("time")[["value"]]

    df_input, df_output = get_aligned_input_output(df, s_in="15min", s_out="5min")
    dataset = DatasetGenerator(
        df_input=df_input,
        df_output=df_output,
        input_length=97,
        output_length=289,
        s_in="15min",
        s_out="5min",
        use_window=True,
    )

    # Only the validation split is needed; fetch it as one batch.
    _, _, valid_dataset = DatasetGenerator.split_dataset(dataset)
    valid_loader = DataLoader(valid_dataset, batch_size=len(valid_dataset), shuffle=False)
    x_input, x_initial, x_mask, x_output, mask, condition = next(iter(valid_loader))
    x_input = x_input.to(config.device)
    x_initial = x_initial.to(config.device)
    x_output = x_output.to(config.device)
    s_in, s_out = condition[0], condition[1]
    real = x_output.cpu().numpy()

    manager = ModelManager(logger, config)

    for model_type in eval_models:
        model_path = f"../results/001_baseline_selection/model/best_{model_type}_model_{waveform}.pth"
        # Check for the checkpoint before building a model that would go unused.
        if not os.path.exists(model_path):
            print(f"{model_type} model file does not exist")
            continue

        try:
            model = manager.load_model(manager.create_model(model_type), model_path)
            model.eval()

            with torch.no_grad():
                if model_type in ("TCN", "LSTM"):
                    output = model(x_initial)
                elif model_type == "Transformer":
                    output = model(x_input, x_initial, s_in, s_out, with_clamp=False)
                else:
                    raise ValueError(f"Unsupported model_type {model_type!r}")

            pred = output.cpu().numpy()
            rmse_model, pcc_model, _mag_model, _phase_model = compute_metrics(real, pred)
            df_rmse.loc[model_type, waveform] = rmse_model
            df_pcc.loc[model_type, waveform] = pcc_model

            ax.plot(pred[0], label=model_type, color=color_map[model_type], linewidth=1.5)
            print(f"{model_type}: RMSE={rmse_model:.4f}, PCC={pcc_model:.4f}")
        except Exception as e:
            # Best-effort: one broken model should not abort the whole figure.
            print(f"{model_type} evaluation failed: {e}")

    # Ground truth on top of the predictions
    ax.plot(real[0], label="Real", color=config.blue, linewidth=2)
    ax.set_title(waveform.capitalize(), fontsize=12)
    ax.grid(True, alpha=0.3)

    if i > 0:
        # NOTE(review): the y axes are not shared, so hiding tick labels here hides
        # panel-specific scales; consider sharey=True (and drop set_yticklabels) if
        # the waveforms share a value range.
        ax.set_ylabel('')
        ax.set_yticklabels([])

# Shared legend below the panels
legend_elements = [
    plt.Line2D([0], [0], color=config.blue, linewidth=2, label='Real'),
    plt.Line2D([0], [0], color=config.red, linewidth=1.5, label='Transformer'),
    plt.Line2D([0], [0], color=config.orange, linewidth=1.5, label='LSTM'),
    plt.Line2D([0], [0], color=config.green, linewidth=1.5, label='TCN'),
]

axes[0].legend(handles=legend_elements, loc='upper center', bbox_to_anchor=(2, -0.05),
               ncol=4, fontsize=10)

plt.tight_layout()
os.makedirs(os.path.dirname(picture_path), exist_ok=True)
plt.savefig(picture_path, dpi=300, bbox_inches='tight')
plt.show()

# Save results
df_rmse.to_csv("../results/001_baseline_selection/rmse_results.csv")
df_pcc.to_csv("../results/001_baseline_selection/pcc_results.csv")