In [1]:
import pandas as pd
import torch
from torch.utils.data import DataLoader, ConcatDataset
from ResampleGAN.utils.DatasetGenerator import DatasetGenerator
from ResampleGAN.core.ModelManager import ModelManager
from ResampleGAN.TransGAN.Discriminator import Discriminator
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

In [2]:
from ResampleGAN.core.ErrorUtils import compute_metrics

In [3]:
def _read_series(csv_path, time_col, value_col):
    """Read one column of a CSV as a single-column DataFrame indexed by parsed timestamps."""
    df = pd.read_csv(csv_path)
    df[time_col] = pd.to_datetime(df[time_col])
    return df.set_index(time_col)[[value_col]]


def _resampled_pair(csv_path, time_col, value_col):
    """Return the (15-minute input, 5-minute target) grids of one CSV series.

    Each grid keeps the first sample of every bin and forward-fills empty bins.
    """
    df = _read_series(csv_path, time_col, value_col)
    return df.resample("15min").first().ffill(), df.resample("5min").first().ffill()


def _make_dataset(df_input, df_output):
    """Wrap a 15-minute input / 5-minute target pair in the project's windowed dataset."""
    return DatasetGenerator(
        df_input=df_input,
        df_output=df_output,
        input_length=97,
        output_length=289,
        s_in="15min",
        s_out="5min",
        use_window=True,
        extra_interpolate=True,
    )


def process_waveform(waveform, mc):
    """Build full-batch train/test DataLoaders for one waveform dataset.

    The selected series is turned into a 15-minute input / 5-minute target pair,
    wrapped by ``DatasetGenerator`` (97 input steps, 289 output steps, windowed,
    extra interpolation baselines enabled) and split 70/30 into train/test with
    an empty validation part.

    Args:
        waveform: One of ``"electric"``, ``"pv"``, ``"wind"`` or ``"mpv"``.
        mc: Model configuration; currently unused, kept for call compatibility.

    Returns:
        ``(train_loader, test_loader)``. Each loader yields its whole split as a
        single, unshuffled batch.

    Raises:
        ValueError: If ``waveform`` is not one of the supported names.
    """
    if waveform == "electric":
        dataset = _make_dataset(*_resampled_pair("../dataset/preprocessed/electric.csv", "time", "P_1"))
    elif waveform == "pv":
        dataset = _make_dataset(*_resampled_pair("../dataset/preprocessed/pv.csv", "datetime", "Gg_pyr"))
    elif waveform == "wind":
        # The two wind files are windowed separately and only then concatenated,
        # so no window straddles the boundary between the files.
        dataset = ConcatDataset([
            _make_dataset(*_resampled_pair(f"../dataset/preprocessed/wind_{site}.csv", "time", "observed"))
            for site in (1, 2)
        ])
    elif waveform == "mpv":
        # Provided at both resolutions already, so no resampling is applied.
        dataset = _make_dataset(
            _read_series("../dataset/p_watt_15min.csv", "time", "1a"),
            _read_series("../dataset/p_watt_5min.csv", "time", "1a"),
        )
    else:
        raise ValueError("Invalid waveform")

    print(len(dataset))
    train_dataset, test_dataset, _ = DatasetGenerator.split_dataset(
        dataset, train_ratio=0.7, test_ratio=0.3, valid_ratio=0.0
    )
    # batch_size == len(split): every loader yields its split as exactly one batch.
    train_loader = DataLoader(train_dataset, batch_size=len(train_dataset), shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False)
    return train_loader, test_loader

In [4]:
class ModelConfig:
    """Hyper-parameter bag passed to ``ModelManager`` and used to build the ``Discriminator``.

    The notebook additionally assigns ``attention_type`` on an instance before
    building each model variant.
    """

    def __init__(self):
        # Device and environment configuration
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Generator configuration
        self.dim_attention = 128
        self.num_heads = 4
        self.dim_feedforward = 128
        self.dropout = 0.1
        self.num_layers = 6
        self.use_noise = False
        self.restore = False
        self.use_mask = False
        self.with_bias = False
        self.init_weights = False

        # Training configuration
        self.n_epochs = 360
        self.batch_size = 64
        self.dim_input = 1
        self.hidden_dim = 16
        self.grad_clip_threshold = 10
        self.best_loss = float('inf')
        self.no_improve_count = 0
        # Early-stopping patience; presumably only consulted when use_early_stopping is True.
        self.early_stopping_patience = 10
        self.use_early_stopping = False

        # Optimizer configuration
        self.optimizer_type = 'AdamW'
        self.scheduler_type = 'WarmupCosine'
        self.lr = 1e-3
        self.weight_decay = 1e-3

        # Scheduler configuration, in ReduceLROnPlateau terms:
        # ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
        self.mode = 'min'
        self.factor = 0.5
        self.scheduler_patience = 5
        # Legacy single name that used to be assigned twice (10, then 5), so its
        # effective value was always 5; keep that value for existing consumers.
        self.patience = self.scheduler_patience
        # NOTE(review): the factor 10 looks like an assumed steps-per-epoch count — confirm.
        self.total_steps = self.n_epochs * 10
        self.warmup_steps = int(self.total_steps * 0.1)

        # Loss weights
        self.weights = {
            'cs_mse': 1,
            'cs_smoothness': 1,
            'cs_gradient': 1,
        }

In [5]:
# Single shared configuration object; the evaluation cells below reassign
# `mc.attention_type` on it before building each attention variant.
mc = ModelConfig()

In [6]:
# Attention variants to evaluate. Each value gives layer counts in the order
# [self ("original"), conv, freq]; the triple is also embedded in checkpoint
# file names. "self_conv" has the same counts as "self+conv" but is expanded
# into grouped (nested) layer lists instead of one flat list.
attention_types = dict([
    ("self", [3, 0, 0]),
    ("conv", [0, 3, 0]),
    ("self+conv", [3, 3, 0]),
    ("self_conv", [3, 3, 0]),
])

# Dataset names accepted by process_waveform.
waveforms = ["electric", "pv", "wind", "mpv"]

In [7]:
import os

def get_immediate_subdirectories(path):
    """Return the names of the direct subdirectories of ``path``, sorted by name.

    ``os.listdir`` returns entries in arbitrary, filesystem-dependent order.
    Sorting makes the run order (and therefore the result-table column order)
    reproducible across machines. Plain files are skipped and nested
    directories are not descended into.
    """
    return sorted(
        entry
        for entry in os.listdir(path)
        if os.path.isdir(os.path.join(path, entry))
    )

# Root folder containing one subdirectory per training run (e.g. "Seed_101").
# Alternative run root: '../results/007_all/backup'
folder_path = '../results/002_all/reform'
subdirs = get_immediate_subdirectories(folder_path)

In [8]:
# Run folders discovered under folder_path.
subdirs

['Seed_101']

In [9]:
# Evaluate every run folder found under folder_path. To restrict the set,
# assign an explicit list instead, e.g. result_order = ['Seed_49'].
result_order = subdirs
# Checkpoints must come from the same root that was listed to build
# result_order, so derive the path instead of hard-coding it a second time.
backup_path = folder_path

In [10]:
# Accumulates the metric tables (df_result) written by the evaluation loops below.
df_features = []

In [12]:
def _empty_table(metrics, column_keys):
    """Return an all-NaN table with rows (metric, phase, method) and columns (key, split)."""
    phases = ["phase 1", "phase 2", "phase 3", "static", "only"]
    methods = list(attention_types.keys()) + ["linear", "slinear", "cubic", "quad", "gp"]
    index = pd.MultiIndex.from_product([metrics, phases, methods])
    columns = pd.MultiIndex.from_product([column_keys, ["train", "test"]])
    return pd.DataFrame(index=index, columns=columns).sort_index()


def _load_frozen(module, checkpoint_path, device, verbose=True):
    """Load a state dict into ``module``, switch it to eval mode and freeze its parameters."""
    if verbose:
        print(checkpoint_path)
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    module.load_state_dict(torch.load(checkpoint_path, map_location=device))
    module.eval()
    for param in module.parameters():
        param.requires_grad = False
    return module


def _discriminator_metrics(discriminator, fake, s_fake, real, s_real):
    """Score the discriminator on real (label 1) vs generated (label 0) samples.

    Logits are thresholded at sigmoid > 0.5. Returns a dict with keys
    ACC, PC, RC and F1 (precision/recall/F1 use zero_division=0).
    """
    preds_on_fake = discriminator(fake, s_fake)
    preds_on_real = discriminator(real, s_real)
    fake_bin = (torch.sigmoid(preds_on_fake).cpu() > 0.5).int().view(-1).tolist()
    real_bin = (torch.sigmoid(preds_on_real).cpu() > 0.5).int().view(-1).tolist()
    all_preds = real_bin + fake_bin
    all_labels = [1] * len(real_bin) + [0] * len(fake_bin)
    return {
        "ACC": accuracy_score(all_labels, all_preds),
        "PC": precision_score(all_labels, all_preds, zero_division=0),
        "RC": recall_score(all_labels, all_preds, zero_division=0),
        "F1": f1_score(all_labels, all_preds, zero_division=0),
    }


def _evaluate_batch(batch, path, waveform, column, df_result, df_pred, config,
                    model_type="Transformer", verbose=True):
    """Write every metric for one batch into column ``column`` of both tables.

    RMSE/PCC rows are filled for the static baselines and for the phase-1/2/3
    generators of each variant in ``attention_types``. ACC/PC/RC/F1 rows are
    filled for the phase-2 discriminator, scored against the phase-2 and
    phase-3 generator outputs. Checkpoints are read from
    ``{path}/1_generator``, ``{path}/2_discriminator`` and ``{path}/3_generator``.
    Mutates ``config.attention_type`` and both tables in place.
    """
    x_input, x_initial, x_mask, x_output, mask, condition, interpolate = batch
    device = config.device
    x_input, x_initial, x_mask = x_input.to(device), x_initial.to(device), x_mask.to(device)
    x_output, mask = x_output.to(device), mask.to(device)
    s_in, s_out = condition[0], condition[1]
    real = x_output.cpu().numpy()

    def record_fit(phase, method, pred):
        # compute_metrics also returns magnitude/phase errors; they are not recorded.
        rmse, pcc, _mag, _ph = compute_metrics(real, pred)
        df_result.loc[("RMSE", phase, method), column] = rmse
        df_result.loc[("PCC", phase, method), column] = pcc

    # Static baselines: x_initial is stored as "linear"; interpolate[1] and
    # interpolate[2] are stored as "cubic" and "quad" (interpolate[0] is unused).
    record_fit("static", "linear", x_initial.cpu().numpy())
    record_fit("static", "cubic", interpolate[1].cpu().numpy())
    record_fit("static", "quad", interpolate[2].cpu().numpy())

    for key, (n_self, n_conv, n_freq) in attention_types.items():
        # "self_conv" keeps the attention kinds in separate groups; the other
        # variants use one flat list of layer types.
        if key == "self_conv":
            config.attention_type = [["original"] * n_self, ["conv"] * n_conv, ["freq"] * n_freq]
        else:
            config.attention_type = ["original"] * n_self + ["conv"] * n_conv + ["freq"] * n_freq
        suffix = f"{key}_{waveform}_self_{n_self}_conv_{n_conv}_freq_{n_freq}.pth"

        discriminator = None
        for phase, folder in (("phase 1", "1_generator"),
                              ("phase 2", "2_discriminator"),
                              ("phase 3", "3_generator")):
            generator = ModelManager(None, config).create_model(model_type)
            _load_frozen(generator, f"{path}/{folder}/best_generator_{suffix}", device, verbose)
            output = generator(x_input, x_initial, s_in, s_out, mask, x_mask)
            record_fit(phase, key, output.cpu().detach().numpy())
            if phase == "phase 1":
                continue
            if discriminator is None:
                # Phases 2 and 3 are both scored with the phase-2 discriminator
                # checkpoint, so it is built and loaded only once per variant.
                discriminator = Discriminator(
                    num_layers=config.num_layers, dim_input=config.dim_input,
                    dim_attention=config.dim_attention, num_heads=config.num_heads,
                    dim_feedforward=config.dim_feedforward, dropout=config.dropout,
                    attention_type=config.attention_type, with_bias=config.with_bias,
                ).to(device)
                _load_frozen(discriminator,
                             f"{path}/2_discriminator/best_discriminator_{suffix}",
                             device, verbose)
            # NOTE(review): the "real" samples are the low-rate inputs (x_input at
            # s_in), not x_output — confirm this matches discriminator training.
            scores = _discriminator_metrics(discriminator, output, s_out, x_input, s_in)
            for name, value in scores.items():
                df_pred.loc[(name, phase, key), column] = value


model_type = "Transformer"
for result_version in result_order:
    path = f"{backup_path}/{result_version}"
    df_result = _empty_table(["RMSE", "PCC", "MAG", "PH"], waveforms)
    df_pred = _empty_table(["ACC", "PC", "RC", "F1"], waveforms)
    for waveform in waveforms:
        train_loader, test_loader = process_waveform(waveform, mc)
        with torch.no_grad():
            for split, data_loader in (("train", train_loader), ("test", test_loader)):
                print(result_version, waveform, split)
                for batch in data_loader:
                    _evaluate_batch(batch, path, waveform, (waveform, split),
                                    df_result, df_pred, mc, model_type, verbose=True)

    # Keep only rows filled for every waveform/split (MAG/PH are never written).
    df_result = df_result.dropna()
    df_result.to_csv(f"{path}/features.csv")
    df_features.append(df_result)
    df_pred = df_pred.dropna()
    df_pred.to_csv(f"{path}/pred.csv")

235
Seed_101 electric train
../results/002_all/reform/Seed_101/1_generator/best_generator_self_electric_self_3_conv_0_freq_0.pth
../results/002_all/reform/Seed_101/2_discriminator/best_generator_self_electric_self_3_conv_0_freq_0.pth
../results/002_all/reform/Seed_101/2_discriminator/best_discriminator_self_electric_self_3_conv_0_freq_0.pth
../results/002_all/reform/Seed_101/3_generator/best_generator_self_electric_self_3_conv_0_freq_0.pth
../results/002_all/reform/Seed_101/2_discriminator/best_discriminator_self_electric_self_3_conv_0_freq_0.pth
../results/002_all/reform/Seed_101/1_generator/best_generator_conv_electric_self_0_conv_3_freq_0.pth
../results/002_all/reform/Seed_101/2_discriminator/best_generator_conv_electric_self_0_conv_3_freq_0.pth
../results/002_all/reform/Seed_101/2_discriminator/best_discriminator_conv_electric_self_0_conv_3_freq_0.pth
../results/002_all/reform/Seed_101/3_generator/best_generator_conv_electric_self_0_conv_3_freq_0.pth
../results/002_all/reform/Seed_

In [13]:
def _empty_table(metrics, column_keys):
    """Return an all-NaN table with rows (metric, phase, method) and columns (key, split)."""
    phases = ["phase 1", "phase 2", "phase 3", "static", "only"]
    methods = list(attention_types.keys()) + ["linear", "slinear", "cubic", "quad", "gp"]
    index = pd.MultiIndex.from_product([metrics, phases, methods])
    columns = pd.MultiIndex.from_product([column_keys, ["train", "test"]])
    return pd.DataFrame(index=index, columns=columns).sort_index()


def _load_frozen(module, checkpoint_path, device, verbose=True):
    """Load a state dict into ``module``, switch it to eval mode and freeze its parameters."""
    if verbose:
        print(checkpoint_path)
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    module.load_state_dict(torch.load(checkpoint_path, map_location=device))
    module.eval()
    for param in module.parameters():
        param.requires_grad = False
    return module


def _discriminator_metrics(discriminator, fake, s_fake, real, s_real):
    """Score the discriminator on real (label 1) vs generated (label 0) samples.

    Logits are thresholded at sigmoid > 0.5. Returns a dict with keys
    ACC, PC, RC and F1 (precision/recall/F1 use zero_division=0).
    """
    preds_on_fake = discriminator(fake, s_fake)
    preds_on_real = discriminator(real, s_real)
    fake_bin = (torch.sigmoid(preds_on_fake).cpu() > 0.5).int().view(-1).tolist()
    real_bin = (torch.sigmoid(preds_on_real).cpu() > 0.5).int().view(-1).tolist()
    all_preds = real_bin + fake_bin
    all_labels = [1] * len(real_bin) + [0] * len(fake_bin)
    return {
        "ACC": accuracy_score(all_labels, all_preds),
        "PC": precision_score(all_labels, all_preds, zero_division=0),
        "RC": recall_score(all_labels, all_preds, zero_division=0),
        "F1": f1_score(all_labels, all_preds, zero_division=0),
    }


def _evaluate_batch(batch, path, waveform, column, df_result, df_pred, config,
                    model_type="Transformer", verbose=True):
    """Write every metric for one batch into column ``column`` of both tables.

    RMSE/PCC rows are filled for the static baselines and for the phase-1/2/3
    generators of each variant in ``attention_types``. ACC/PC/RC/F1 rows are
    filled for the phase-2 discriminator, scored against the phase-2 and
    phase-3 generator outputs. Checkpoints are read from
    ``{path}/1_generator``, ``{path}/2_discriminator`` and ``{path}/3_generator``.
    Mutates ``config.attention_type`` and both tables in place.
    """
    # The batch is unpacked and moved to the device once, then reused for every
    # variant and phase.
    x_input, x_initial, x_mask, x_output, mask, condition, interpolate = batch
    device = config.device
    x_input, x_initial, x_mask = x_input.to(device), x_initial.to(device), x_mask.to(device)
    x_output, mask = x_output.to(device), mask.to(device)
    s_in, s_out = condition[0], condition[1]
    real = x_output.cpu().numpy()

    def record_fit(phase, method, pred):
        # compute_metrics also returns magnitude/phase errors; they are not recorded.
        rmse, pcc, _mag, _ph = compute_metrics(real, pred)
        df_result.loc[("RMSE", phase, method), column] = rmse
        df_result.loc[("PCC", phase, method), column] = pcc

    # Static baselines: x_initial is stored as "linear"; interpolate[1] and
    # interpolate[2] are stored as "cubic" and "quad" (interpolate[0] is unused).
    record_fit("static", "linear", x_initial.cpu().numpy())
    record_fit("static", "cubic", interpolate[1].cpu().numpy())
    record_fit("static", "quad", interpolate[2].cpu().numpy())

    for key, (n_self, n_conv, n_freq) in attention_types.items():
        # "self_conv" keeps the attention kinds in separate groups; the other
        # variants use one flat list of layer types.
        if key == "self_conv":
            config.attention_type = [["original"] * n_self, ["conv"] * n_conv, ["freq"] * n_freq]
        else:
            config.attention_type = ["original"] * n_self + ["conv"] * n_conv + ["freq"] * n_freq
        suffix = f"{key}_{waveform}_self_{n_self}_conv_{n_conv}_freq_{n_freq}.pth"

        discriminator = None
        for phase, folder in (("phase 1", "1_generator"),
                              ("phase 2", "2_discriminator"),
                              ("phase 3", "3_generator")):
            generator = ModelManager(None, config).create_model(model_type)
            _load_frozen(generator, f"{path}/{folder}/best_generator_{suffix}", device, verbose)
            output = generator(x_input, x_initial, s_in, s_out, mask, x_mask)
            record_fit(phase, key, output.cpu().detach().numpy())
            if phase == "phase 1":
                continue
            if discriminator is None:
                # Phases 2 and 3 are both scored with the phase-2 discriminator
                # checkpoint, so it is built and loaded only once per variant.
                discriminator = Discriminator(
                    num_layers=config.num_layers, dim_input=config.dim_input,
                    dim_attention=config.dim_attention, num_heads=config.num_heads,
                    dim_feedforward=config.dim_feedforward, dropout=config.dropout,
                    attention_type=config.attention_type, with_bias=config.with_bias,
                ).to(device)
                _load_frozen(discriminator,
                             f"{path}/2_discriminator/best_discriminator_{suffix}",
                             device, verbose)
            # NOTE(review): the "real" samples are the low-rate inputs (x_input at
            # s_in), not x_output — confirm this matches discriminator training.
            scores = _discriminator_metrics(discriminator, output, s_out, x_input, s_in)
            for name, value in scores.items():
                df_pred.loc[(name, phase, key), column] = value


model_type = "Transformer"
for waveform in waveforms:
    df_result = _empty_table(["RMSE", "PCC", "MAG", "PH"], result_order)
    df_pred = _empty_table(["ACC", "PC", "RC", "F1"], result_order)
    # process_waveform's arguments do not depend on the run folder, so the data
    # is built once per waveform (assumes split_dataset is deterministic).
    train_loader, test_loader = process_waveform(waveform, mc)
    for result_version in result_order:
        path = f"{backup_path}/{result_version}"
        with torch.no_grad():
            for split, data_loader in (("train", train_loader), ("test", test_loader)):
                print(waveform, result_version, split)
                for batch in data_loader:
                    _evaluate_batch(batch, path, waveform, (result_version, split),
                                    df_result, df_pred, mc, model_type, verbose=False)

    # Keep only rows filled for every run/split (MAG/PH are never written).
    df_result = df_result.dropna()
    df_result.to_csv(f"{backup_path}/features_{waveform}.csv")
    # NOTE(review): these per-waveform tables (columns = runs) are appended to the
    # same list as the per-run tables (columns = waveforms) — confirm intended.
    df_features.append(df_result)
    df_pred = df_pred.dropna()
    df_pred.to_csv(f"{backup_path}/pred_{waveform}.csv")

235
electric Seed_101 train
electric Seed_101 test
363
pv Seed_101 train
pv Seed_101 test
81
wind Seed_101 train
wind Seed_101 test
364
mpv Seed_101 train
mpv Seed_101 test
