In [None]:
import os
import pandas as pd
import torch
import pickle
import time
from torch.utils.data import DataLoader, ConcatDataset
from typing import Dict, List, Tuple

In [None]:
from ResampleGAN.core.TrainerFactory import TrainerFactory
from ResampleGAN.core.TrainingConfig import TrainingConfig
from ResampleGAN.core.TrainingUtils import TrainingUtils, quick_setup
from ResampleGAN.utils.DatasetGenerator import DatasetGenerator

In [None]:
# Fixed seed for this run; it is also embedded in the results directory name (f"Seed_{seed}").
seed = 101
# Seed through the project helper before any data loader or model is created.
# NOTE(review): presumably seeds the Python / NumPy / PyTorch RNGs — confirm in TrainingUtils.
TrainingUtils.set_seed(seed)

In [None]:
# All artefacts of this run live under a seed-specific folder with one
# "<phase>_<role>" sub-folder per training phase.
now = f"Seed_{seed}"
base_dir = f"../results/002_all/reform/{now}"
phase_names = {1: "generator", 2: "discriminator", 3: "generator"}
for phase in phase_names:
    os.makedirs(f"{base_dir}/{phase}_{phase_names[phase]}", exist_ok=True)

In [None]:
# Project helper that hands back the logger and device objects for this notebook.
# `device` is not read by the cells visible here (create_phase_config builds its own).
logger, device = quick_setup(seed=seed, log_file="joint_training.log")
# NOTE(review): CUDA_LAUNCH_BLOCKING is a debugging switch that is normally read when CUDA
# initialises; if quick_setup() already touched the GPU, setting it here may have no
# effect — confirm, and consider dropping it for full-length training runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Any value other than "overwrite" makes train_single_phase_experiment skip experiments
# whose best-generator checkpoint file already exists.
training_strategy = "skip"  # "skip" or "overwrite"
# Experiment key -> [n_self, n_conv, n_freq]: how many "original", "conv" and "freq"
# attention entries create_phase_config builds. "self+conv" and "self_conv" use the same
# counts; they differ only in that create_phase_config keeps the three groups as nested
# lists for "self_conv" and concatenates them into one flat list for every other key.
attention_types = {
    "self": [3, 0, 0],
    "conv": [0, 3, 0],
    "self+conv": [3, 3, 0],
    "self_conv": [3, 3, 0]
}
# Waveform names handled by process_waveform_unified.
waveforms = ["electric", "pv", "wind", "mpv"]

In [None]:
def _read_indexed_series(csv_path: str, time_col: str, value_col: str) -> pd.DataFrame:
    """Read a CSV and return a single-column frame indexed by the parsed `time_col`."""
    frame = pd.read_csv(csv_path)
    frame[time_col] = pd.to_datetime(frame[time_col])
    return frame.set_index(time_col)[[value_col]]


def _make_window_dataset(df_input: pd.DataFrame, df_output: pd.DataFrame):
    """Wrap an input/output frame pair in a DatasetGenerator with the notebook's fixed settings."""
    # NOTE(review): 97 and 289 look like one day of 15-min / 5-min samples plus one
    # end point — confirm against DatasetGenerator.
    return DatasetGenerator(
        df_input=df_input, df_output=df_output,
        input_length=97, output_length=289,
        s_in="15min", s_out="5min", use_window=True
    )


def _resampled_dataset(csv_path: str, time_col: str, value_col: str):
    """Load one series and build its dataset from 15-min (input) and 5-min (output) resamples."""
    series = _read_indexed_series(csv_path, time_col, value_col)
    # first() takes the first sample in each bin; ffill() fills bins that had no sample.
    df_input = series.resample("15min").first().ffill()
    df_output = series.resample("5min").first().ffill()
    return _make_window_dataset(df_input, df_output)


def process_waveform_unified(waveform: str, batch_size: int = 16) -> Tuple[DataLoader, DataLoader]:
    """Build the train and test DataLoaders for one waveform.

    "electric", "pv" and "wind" are read from ../dataset/preprocessed/ and resampled
    to 15-min inputs and 5-min targets. "wind" concatenates the datasets built from
    wind_1.csv and wind_2.csv. "mpv" reads the already-separated
    ../dataset/p_watt_15min.csv and ../dataset/p_watt_5min.csv without resampling.

    Args:
        waveform: One of "electric", "pv", "wind", "mpv".
        batch_size: Batch size used for both loaders.

    Returns:
        (train_loader, test_loader); the train loader shuffles, the test loader does not.

    Raises:
        ValueError: If `waveform` is not one of the supported names.
    """
    if waveform == "electric":
        dataset = _resampled_dataset("../dataset/preprocessed/electric.csv", "time", "P_1")

    elif waveform == "pv":
        dataset = _resampled_dataset("../dataset/preprocessed/pv.csv", "datetime", "Gg_pyr")

    elif waveform == "wind":
        # Wind data is stored in two files; build one dataset per file and concatenate.
        dataset = ConcatDataset([
            _resampled_dataset(f"../dataset/preprocessed/wind_{i}.csv", "time", "observed")
            for i in [1, 2]
        ])

    elif waveform == "mpv":
        # These files are used as-is: no resampling is applied.
        df_input = _read_indexed_series("../dataset/p_watt_15min.csv", "time", "1a")
        df_output = _read_indexed_series("../dataset/p_watt_5min.csv", "time", "1a")
        dataset = _make_window_dataset(df_input, df_output)

    else:
        raise ValueError(f"Invalid waveform: {waveform}")

    # Split the dataset 70 % / 30 % with no validation share.
    # NOTE(review): assumes split_dataset returns (train, test, valid) in that order — confirm.
    train_dataset, test_dataset, _ = DatasetGenerator.split_dataset(
        dataset, train_ratio=0.7, test_ratio=0.3, valid_ratio=0.0
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

In [None]:
def create_phase_config(phase: int, waveform: str, key: str, attention: List[int]) -> TrainingConfig:
    """Assemble the TrainingConfig for one phase / waveform / attention variant.

    Args:
        phase: Training phase; 1 uses the three-term loss weights, any other value
            additionally includes the 'feature_space' term.
        waveform: Waveform name; "wind" gets a batch size of 8, everything else 16.
        key: Attention variant name. "self_conv" keeps the per-type lists nested,
            every other key flattens them into one list.
        attention: Counts of "original", "conv" and "freq" entries, in that order.

    Returns:
        The TrainingConfig built from these choices plus the fixed hyper-parameters below.
    """
    n_self, n_conv, n_freq = attention[0], attention[1], attention[2]
    per_type = [["original"] * n_self, ["conv"] * n_conv, ["freq"] * n_freq]
    if key == "self_conv":
        attention_type = per_type
    else:
        attention_type = per_type[0] + per_type[1] + per_type[2]

    # "wind" runs with half the default batch size.
    batch_size = 16
    if waveform == "wind":
        batch_size = 8

    # Loss weights: 'feature_space' is only present after phase 1. Keys are inserted
    # in the same order as the original literals.
    weights = {'mse': 1}
    if phase != 1:
        weights['feature_space'] = 1
    weights['smoothness'] = 1
    weights['gradient'] = 1

    device_name = "cuda" if torch.cuda.is_available() else "cpu"

    settings = dict(
        device=torch.device(device_name),
        n_epochs=120,
        batch_size=batch_size,
        lr=2e-4,
        weight_decay=0.01,
        grad_clip_threshold=10,
        dim_input=1,
        dim_attention=128,
        num_heads=4,
        dim_feedforward=128,
        dropout=0.1,
        num_layers=6,
        attention_type=attention_type,
        with_bias=False,
        weights=weights,
        optimizer_type='AdamW',
        scheduler_type='WarmupCosine',
        lambda_gan=0.1,
        critic=1,
        use_early_stopping=False,
        patience=10,
    )
    return TrainingConfig(**settings)

In [None]:
def train_single_phase_experiment(phase: int, waveform: str, key: str, attention: List[int], now: str, training_strategy: str, logger) -> bool:
    """Train one (phase, waveform, attention variant) experiment and pickle its history.

    Args:
        phase: Training phase (1, 2 or 3); selects the "<phase>_<role>" output folder.
        waveform: Waveform name passed to the config and data-loader builders.
        key: Attention variant name, used in file names and passed to the trainer.
        attention: Counts of self / conv / freq attention entries, used in file names.
        now: Run identifier; results go under ../results/002_all/reform/<now>.
        training_strategy: Anything other than "overwrite" skips the experiment when
            its best-generator checkpoint file already exists.
        logger: Object with an `info` method, used for progress messages.

    Returns:
        bool: True when the experiment was skipped or finished. Training errors are
        not caught here, so a failure propagates as an exception instead of False.

    Raises:
        KeyError: If `phase` is not 1, 2 or 3.
    """
    phase_names = {1: "generator", 2: "discriminator", 3: "generator"}
    base_save_dir = f"../results/002_all/reform/{now}"
    save_dir = f"{base_save_dir}/{phase}_{phase_names[phase]}"
    # Shared suffix of the checkpoint and history file names for this experiment.
    run_tag = f"{key}_{waveform}_self_{attention[0]}_conv_{attention[1]}_freq_{attention[2]}"

    # Check if should skip
    if training_strategy != "overwrite":
        model_path = f"{save_dir}/best_generator_{run_tag}.pth"
        if os.path.exists(model_path):
            logger.info(f"⏭️ Skip existing model: Phase {phase} - {key}_{waveform}")
            return True

    start_time = time.time()
    logger.info(f"Start training Phase {phase}: {key}_{waveform} - {attention}")

    # Create configuration
    config = create_phase_config(phase, waveform, key, attention)

    # Create data loaders
    train_loader, test_loader = process_waveform_unified(waveform, config.batch_size)

    # Make sure the output folder exists before training starts, so a missing
    # directory cannot make the history write fail after the whole run.
    os.makedirs(save_dir, exist_ok=True)

    # Create trainer
    trainer = TrainerFactory.create_trainer(phase, config, save_dir, logger, key, attention, waveform, now, base_save_dir)

    # Start training
    history = trainer.train(train_loader, test_loader)

    # Save training history
    history_path = os.path.join(save_dir, f"losses_{run_tag}.pkl")
    with open(history_path, "wb") as f:
        pickle.dump(history, f)

    execution_time = time.time() - start_time
    logger.info(f"Phase {phase} - {key}_{waveform} training completed, time taken: {execution_time:.4f} seconds")

    return True

In [None]:
# Cell 6: Phase 1 - Generator Pretraining
def run_phase1():
    """Run phase 1 for every waveform / attention variant combination.

    Returns:
        tuple: (experiments that reported success, experiments planned).
    """
    banner = "="*60
    for line in (banner, "Starting Phase 1: Generator Pretraining", banner):
        logger.info(line)

    total_experiments = len(waveforms) * len(attention_types)
    success_count = 0

    for waveform in waveforms:
        logger.info(f"Processing waveform: {waveform}")

        # Run the variants in dictionary order, then count the truthy results.
        outcomes = [
            train_single_phase_experiment(1, waveform, key, attention, now, training_strategy, logger)
            for key, attention in attention_types.items()
        ]
        success_count += sum(1 for ok in outcomes if ok)

    logger.info(f"Phase 1 completed! Success rate: {success_count}/{total_experiments}")
    return success_count, total_experiments

# Execute Phase 1
phase1_success, phase1_total = run_phase1()

In [None]:
# Cell 7: Phase 2 - Joint Training
def run_phase2():
    """Run phase 2 for every waveform / attention variant combination.

    Returns:
        tuple: (experiments that reported success, experiments planned).
    """
    banner = "="*60
    for line in (banner, "Starting Phase 2: Joint Training", banner):
        logger.info(line)

    total_experiments = len(waveforms) * len(attention_types)
    success_count = 0

    for waveform in waveforms:
        logger.info(f"Processing waveform: {waveform}")

        # Run the variants in dictionary order, then count the truthy results.
        outcomes = [
            train_single_phase_experiment(2, waveform, key, attention, now, training_strategy, logger)
            for key, attention in attention_types.items()
        ]
        success_count += sum(1 for ok in outcomes if ok)

    logger.info(f"Phase 2 completed! Success rate: {success_count}/{total_experiments}")
    return success_count, total_experiments

# Execute Phase 2
phase2_success, phase2_total = run_phase2()

In [None]:
def run_phase3():
    """Run phase 3 for every waveform / attention variant combination.

    Returns:
        tuple: (experiments that reported success, experiments planned).
    """
    banner = "="*60
    for line in (banner, "Starting Phase 3: Generator Fine-tuning", banner):
        logger.info(line)

    total_experiments = len(waveforms) * len(attention_types)
    success_count = 0

    for waveform in waveforms:
        logger.info(f"Processing waveform: {waveform}")

        # Run the variants in dictionary order, then count the truthy results.
        outcomes = [
            train_single_phase_experiment(3, waveform, key, attention, now, training_strategy, logger)
            for key, attention in attention_types.items()
        ]
        success_count += sum(1 for ok in outcomes if ok)

    logger.info(f"Phase 3 completed! Success rate: {success_count}/{total_experiments}")
    return success_count, total_experiments

# Execute Phase 3
phase3_success, phase3_total = run_phase3()