# 一、数据处理

# In [None]:
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms.functional import to_pil_image
from PIL import Image
from natsort import natsorted
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm
from skimage.metrics import structural_similarity as ssim
import os
import numpy as np


# Module-level configuration read as globals by the loader functions in this
# file: `time_steps` is the number of values/images expected per sample, and
# `image_size` is the side length (in pixels) every image is resized to.
time_steps = 6
image_size = 128

# 1.1 固定数值数据

def read_numeric_data(file_path):
    """Load the five fixed numeric features from a CSV file.

    Parameters
    ----------
    file_path : str
        Path to a CSV file that contains (at least) the columns
        ``grid_length``, ``grid_thickness``, ``depth``, ``temperature`` and
        ``pressure``.

    Returns
    -------
    numpy.ndarray
        ``float64`` array of shape ``(n_rows, 5)`` with the columns in the
        order listed above.

    Raises
    ------
    KeyError
        If any of the expected columns is missing from the file.
    ValueError
        If a selected column holds values that cannot be converted to float.
    """
    feature_columns = ['grid_length', 'grid_thickness', 'depth', 'temperature', 'pressure']
    df = pd.read_csv(file_path)
    # Force a floating-point dtype. When every selected column is integer,
    # ``.values`` yields an int64 array, and writing min-max scaled values
    # back into such an array silently truncates them to 0 or 1.
    return df[feature_columns].to_numpy(dtype=np.float64)


file_path = r'3000数值型数据.csv'
fixed_numeric_data = read_numeric_data(file_path)

# Min-max scale every feature column to [0, 1]. MinMaxScaler already works
# column by column, so one fit over the whole matrix is equivalent to fitting
# each column separately. fit_transform returns a new floating-point array,
# which avoids writing scaled values back into a possibly integer-typed array
# (that in-place write would truncate them to 0 or 1).
# `scaler_fixed` now holds the scaler fitted on all feature columns.
scaler_fixed = MinMaxScaler()
fixed_numeric_data = scaler_fixed.fit_transform(fixed_numeric_data)

# Convert to a float32 tensor and map [0, 1] -> [-1, 1].
fixed_numeric_data = torch.from_numpy(fixed_numeric_data).float()
fixed_numeric_data = fixed_numeric_data * 2 - 1

print("固定数据特征:", fixed_numeric_data.shape)
print("前六行特征值:")
print(fixed_numeric_data[:6, :])

# 1.2读取注入量数据
def read_injection_data(folder_path):
    """Read one whitespace-separated numeric series per ``.txt`` file.

    Files in ``folder_path`` are visited in natural sort order. A file is kept
    only if it parses into exactly ``time_steps`` floats; files that fail to
    parse or have a different number of values are reported on stdout and
    skipped.

    Parameters
    ----------
    folder_path : str
        Directory containing the ``.txt`` files.

    Returns
    -------
    numpy.ndarray
        The accepted series stacked row-wise (one row per accepted file).
    """
    ordered_names = natsorted(os.listdir(folder_path))
    collected = []
    print("正在读取前缘数据，文件顺序如下：")
    for file in ordered_names:
        if not file.endswith('.txt'):
            continue
        file_path = os.path.join(folder_path, file)
        print(file_path)
        try:
            with open(file_path, 'r') as handle:
                data = list(map(float, handle.read().strip().split()))
        except Exception as e:
            print(f"读取文件 {file_path} 时出现问题: {e}")
            continue
        if len(data) != time_steps:
            print(f"文件 {file_path} 中的数据数量不是 {time_steps}，请检查。")
            continue
        collected.append(np.array(data))

    all_data = np.array(collected)

    # Show a short preview of what was loaded.
    if all_data.size > 0:
        print("读取到的数据的前五行如下：")
        print(all_data[:5])

    return all_data

# 1.3读取压力数据
def read_pressure_data(folder_path):
    """Read one whitespace-separated numeric series per ``.txt`` file.

    Files in ``folder_path`` are processed in natural sort order. Only files
    that parse into exactly ``time_steps`` floats are kept; any other file is
    reported on stdout and skipped.

    Parameters
    ----------
    folder_path : str
        Directory containing the ``.txt`` files.

    Returns
    -------
    numpy.ndarray
        The accepted series stacked row-wise (one row per accepted file).
    """
    ordered_names = natsorted(os.listdir(folder_path))
    rows = []
    print("正在读取压力数据，文件顺序如下：")
    for file in ordered_names:
        if file.endswith('.txt'):
            file_path = os.path.join(folder_path, file)
            print(file_path)
            try:
                with open(file_path, 'r') as f:
                    data = [float(token) for token in f.read().strip().split()]
            except Exception as e:
                print(f"读取文件 {file_path} 时出现问题: {e}")
            else:
                if len(data) == time_steps:
                    rows.append(np.array(data))
                else:
                    print(f"文件 {file_path} 中的数据数量不是 {time_steps}，请检查。")
    return np.array(rows)

# 1.4读取饱和度变异系数数据
def read_saturation_variation_data(folder_path):
    """Read one value per line from every ``.txt`` file in a folder.

    Each line is either a plain number or a percentage such as ``12.5%``;
    percentages are converted to fractions (divided by 100). Lines that cannot
    be parsed are reported on stdout and ignored. A file is kept only if it
    yields exactly ``time_steps`` values; otherwise it is counted as skipped.

    Parameters
    ----------
    folder_path : str
        Directory containing the ``.txt`` files (visited in natural sort
        order).

    Returns
    -------
    numpy.ndarray
        The accepted series stacked row-wise (one row per accepted file).
    """
    data_files = natsorted(os.listdir(folder_path))
    accepted = []
    success_count = 0
    skip_count = 0
    print("正在读取饱和度变异系数数据，文件顺序如下：")
    for file in data_files:
        if not file.endswith('.txt'):
            continue
        file_path = os.path.join(folder_path, file)
        print(file_path)

        try:
            with open(file_path, 'r') as f:
                raw_lines = f.readlines()
        except Exception as e:
            print(f"读取文件 {file_path} 时出现问题: {e}")
            skip_count += 1
            continue

        data = []
        for raw in raw_lines:
            line = raw.strip()
            is_percentage = '%' in line
            if is_percentage:
                # Drop the percent sign before parsing the number.
                line = line.replace('%', '').strip()
            try:
                value = float(line)
            except ValueError:
                if is_percentage:
                    print(f"无法将 {line} 转换为有效的百分比数据，请检查文件 {file_path}。")
                else:
                    print(f"无法将 {line} 转换为有效的数值，请检查文件 {file_path}。")
                continue
            data.append(value / 100.0 if is_percentage else value)

        if len(data) == time_steps:
            accepted.append(np.array(data))
            success_count += 1
        else:
            print(f"文件 {file_path} 中的数据数量不是 {time_steps}，请检查。")
            skip_count += 1

    print(f"成功读取 {success_count} 个文件，跳过 {skip_count} 个文件。")
    all_data = np.array(accepted)

    # Show a short preview of what was loaded.
    if all_data.size > 0:
        print("读取到的数据的前五行如下：")
        print(all_data[:5])

    return all_data


# Load the three per-sample series (one .txt file per sample in each folder).
# Each reader skips files it cannot parse or whose length differs from
# `time_steps`, so the three arrays are only row-aligned if no file was
# skipped in any folder.
injection_data = read_injection_data(r'3000注入量')
pressure_data = read_pressure_data(r'3000压力')
saturation_data = read_saturation_variation_data(r'3000饱和度变异系数')


# 按所有样本的同一特征在所有时间步上进行归一化
def normalize_all_samples(data):
    """Min-max scale an array to [-1, 1] using a single global min and max.

    The minimum and maximum are taken over every element of ``data`` (all
    samples and all time steps together), so relative magnitudes between
    samples are preserved.

    Parameters
    ----------
    data : array_like
        Numeric values of any shape.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``data``, scaled to [-1, 1]. When every
        element is identical the result is filled with -1.

    Raises
    ------
    ValueError
        If ``data`` is empty.
    """
    data = np.asarray(data)
    if data.size == 0:
        raise ValueError(
            "normalize_all_samples() received an empty array; "
            "check that the input files were read successfully."
        )
    # Reduce over the array directly; flattening first would only copy it.
    min_val = data.min()
    max_val = data.max()
    diff = max_val - min_val
    if diff == 0:
        diff = 1  # constant input: avoid division by zero
    normalized_data = (data - min_val) / diff
    # Map [0, 1] -> [-1, 1].
    return normalized_data * 2 - 1


# Scale each series independently to [-1, 1] (one global min/max per series)
# and print the first time-step value of the first five samples as a check.
injection_data = normalize_all_samples(injection_data)
print("归一化后 injection_data 前五行第一个特征值:")
print(injection_data[:5, 0])

pressure_data = normalize_all_samples(pressure_data)
print("归一化后 pressure_data 前五行第一个特征值:")
print(pressure_data[:5, 0])

saturation_data = normalize_all_samples(saturation_data)
print("归一化后 saturation_data 前五行第一个特征值:")
print(saturation_data[:5, 0])

# Stack the three series along a new last axis, giving one 3-feature vector
# per sample and time step. np.stack requires all three arrays to have the
# same shape, so this raises if any reader skipped a file.
changing_numeric_data = np.stack([
    injection_data,
    pressure_data,
    saturation_data
], axis=2)

def read_time_series_images(root_dir):
    """Load one grayscale image sequence per sub-folder of ``root_dir``.

    Sub-folders and the files inside them are visited in natural sort order.
    Every image is converted to grayscale, resized to
    ``image_size`` x ``image_size`` and scaled from [0, 255] to [-1, 1].
    Files that cannot be read as images are reported and skipped; a folder
    is kept only if exactly ``time_steps`` images were loaded from it.

    :param root_dir: directory whose sub-folders each hold one sequence
    :return: float array of shape (B, T, 1, H, W), or ``None`` when the
        collected data does not form a 5-D array (for example when no folder
        was accepted)
    """
    image_time_series_data = []
    for folder in natsorted(os.listdir(root_dir)):
        folder_path = os.path.join(root_dir, folder)
        if not os.path.isdir(folder_path):
            continue

        time_series_images = []
        for img_file in natsorted(os.listdir(folder_path)):
            img_path = os.path.join(folder_path, img_file)
            print(f"正在读取文件: {img_path}")
            try:
                # The context manager closes the file handle deterministically;
                # convert()/resize() return in-memory copies that outlive it.
                with Image.open(img_path) as raw_img:
                    img = raw_img.convert('L').resize((image_size, image_size))
                img = np.array(img)
                img = img / 255.0       # [0, 255] -> [0, 1]
                img = img * 2 - 1       # [0, 1]   -> [-1, 1]
                img = np.expand_dims(img, axis=0)  # add channel axis: (1, H, W)
                time_series_images.append(img)
            except Exception as e:
                print(f"Failed to read image: {img_path}, Error: {e}")

        if len(time_series_images) != time_steps:
            print(f"Folder {folder} has {len(time_series_images)} images, expected {time_steps}. Skipping this folder.")
        else:
            image_time_series_data.append(time_series_images)

    # The nested list is already ordered (B, T, C, H, W), so no transpose is
    # needed after conversion.
    image_time_series_data = np.array(image_time_series_data)
    if image_time_series_data.ndim != 5:
        print(f"数组维度为 {image_time_series_data.ndim}，不等于 5，无法进行转置。")
        return None
    return image_time_series_data


# Load the target image sequences (one sub-folder per sample).
# Note: read_time_series_images returns None if the data is not 5-D.
root_dir = r'3000饱和度图片'
img_time_series = read_time_series_images(root_dir)

# 1.5固定图片数据：渗透率图片、井位图片
# 读取图片数据
def read_image_data(folder_path, image_size, data_type):
    """Load every image in a folder as a grayscale array scaled to [-1, 1].

    Parameters
    ----------
    folder_path : str
        Directory containing the image files (visited in natural sort order).
    image_size : int
        Side length, in pixels, each image is resized to.
    data_type : str
        Label used only in the progress message printed to stdout.

    Returns
    -------
    numpy.ndarray
        The loaded images stacked along the first axis. Files that cannot be
        read are reported on stdout and left out.
    """
    ordered_names = natsorted(os.listdir(folder_path))
    target_size = (image_size, image_size)
    loaded = []
    print(f"正在读取 {data_type} 图片数据，文件顺序如下：")
    for file in ordered_names:
        file_path = os.path.join(folder_path, file)
        print(file_path)
        try:
            # Grayscale, fixed size, then [0, 255] -> [0, 1] -> [-1, 1].
            gray = Image.open(file_path).convert('L').resize(target_size)
            unit_scaled = np.array(gray) / 255.0
            loaded.append(unit_scaled * 2 - 1)
        except Exception as e:
            print(f"读取图片 {file_path} 时出错: {e}")
    return np.array(loaded)


# Load the two static image sets: permeability maps and well-location maps.
permeability_folder = r'3000渗透率图集'
well_location_folder = r'3000井位可视化'

permeability_images = read_image_data(permeability_folder, image_size, '3000渗透率图集')
well_location_images = read_image_data(well_location_folder, image_size, '3000井位可视化')

# Insert a channel axis at position 1 (no axes are reordered).
permeability_images = np.expand_dims(permeability_images, axis=1)
well_location_images = np.expand_dims(well_location_images, axis=1)

# Convert to float32 tensors.
perm_images = torch.from_numpy(permeability_images).float()
well_images = torch.from_numpy(well_location_images).float()
print("转换为张量后渗透率图片数据形状:", perm_images.shape)
print("转换为张量后井位图片数据形状:", well_images.shape)

# Concatenate permeability and well-location images along the channel axis
# to form one 2-channel static input per sample.
combined_images = torch.cat([perm_images, well_images], dim=1)
print("合并后2通道特征矩阵形状:", combined_images.shape)

# Convert the remaining arrays to float32 tensors and print sanity checks.
changing_numeric_data = torch.from_numpy(changing_numeric_data).float()  # (B, T, 3)
print("转换为张量后可变数值数据形状:", changing_numeric_data.shape)
# NOTE: the slice below selects feature 0 of the FIRST sample over its first
# six time steps, which differs from what the printed label describes.
print("转换为张量后可变数值数据前五行第一个时间步的第一个特征值:")
print(changing_numeric_data[:1, :6, 0])

# torch.from_numpy raises here if read_time_series_images returned None.
img_time_series = torch.from_numpy(img_time_series).float()  # (B, T, C, H, W)
print("转换为张量后时间序列图像数据形状:", img_time_series.shape)


# Print the first six samples of the dynamic features for inspection.
first_six_rows = changing_numeric_data[:6]
print("changing_numeric_data 的前 6 行数据:")
print(first_six_rows)

# 二、模型训练+验证

# In [None]:
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms.functional import to_pil_image
from PIL import Image
from natsort import natsorted
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm
from skimage.metrics import structural_similarity as ssim
import os
import numpy as np
from sklearn.model_selection import train_test_split
import random
from pytorch_msssim import SSIM
import csv


# 残差块
class ResidualBlock(nn.Module):
    """Single-convolution residual block with dropout.

    Main path: 3x3 conv -> batch norm -> dropout -> ReLU. The result is added
    to a shortcut of the input and passed through ReLU once more. The shortcut
    is the input itself, or a 1x1 conv + batch norm projection whenever the
    channel count or the stride changes.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    stride : int, optional
        Stride of the 3x3 convolution (and of the projection shortcut).
    dropout_rate : float, optional
        Dropout probability applied after batch norm.
    """

    def __init__(self, in_channels, out_channels, stride=1, dropout_rate=0.1):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.dropout = nn.Dropout(dropout_rate)
        self.relu = nn.ReLU()

        # A projection is required when the shortcut would otherwise not match
        # the main path's output shape.
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.downsample = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels,
                    kernel_size=1, stride=stride, padding=0, bias=False,
                ),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.downsample = None

    def forward(self, x):
        # Dropout is applied before the first activation.
        main_path = self.relu(self.dropout(self.bn(self.conv(x))))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(main_path + shortcut)


# 可学习位置编码模块
class LearnablePositionalEncoding(nn.Module):
    """Add a learned per-position embedding to a sequence, then apply dropout.

    Parameters
    ----------
    d_model : int
        Feature dimension of the sequence.
    dropout : float, optional
        Dropout probability applied after the embedding is added.
    max_len : int, optional
        Number of positions for which an embedding is learned.
    """

    def __init__(self, d_model, dropout=0.1, max_len=6):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # Learnable table with one row per position: (max_len, d_model).
        self.pos_emb = nn.Parameter(torch.zeros(max_len, d_model))
        nn.init.trunc_normal_(self.pos_emb, std=0.02)  # truncated-normal init

    def forward(self, x):
        """Return ``dropout(x + pos_emb[:T])`` for ``x`` of shape (B, T, D)."""
        _, T, _ = x.shape  # (batch, time_steps, d_model); enforces 3-D input
        # Broadcasting adds the (T, D) table to every batch element, so no
        # batch-sized copy of the table has to be materialised.
        return self.dropout(x + self.pos_emb[:T])

# 多模态模型
class MultiModalGasModel(nn.Module):
    """Predict one image per time step from static images, fixed numeric
    features and a dynamic numeric time series.

    Pipeline:
      1. A convolutional encoder turns the static image into a grid of
         ``embed_dim``-wide tokens (four stride-2 convolutions, so an
         H x W input yields (H/16) * (W/16) tokens).
      2. The fixed numeric features are embedded into one extra token.
      3. Every time step's dynamic features are embedded, concatenated with
         the mean of the global tokens, position-encoded and passed through a
         Transformer encoder.
      4. Global tokens and encoded time tokens are fused by a second
         Transformer encoder.
      5. Each fused time token is decoded into one image by a stack of
         transposed convolutions ending in ``tanh``.

    Parameters
    ----------
    img_channels : int
        Channels of the static image input.
    time_features : int
        Number of dynamic features per time step.
    fixed_features : int
        Number of fixed numeric features.
    embed_dim : int
        Token width E; both Transformers operate on width ``2 * embed_dim``.
    nhead : int
        Attention heads in both Transformers.
    img_size : int
        Stored on the instance; the layers defined here do not read it. The
        decoder upsamples 1x1 -> 8 -> 16 -> 32 -> 64 -> 128, so the output
        is 128 x 128.
    time_steps : int
        Number of positions in the positional encoding; also read by the
        training utilities as ``model.time_steps``.
    """

    def __init__(self, img_channels=2, time_features=3, fixed_features=5,
                 embed_dim=128, nhead=8, img_size=128, time_steps=6):
        super().__init__()
        self.img_size = img_size
        self.embed_dim = embed_dim
        self.time_steps = time_steps

        # Spatial encoder: four stride-2 convolutions (each halves H and W),
        # interleaved with residual blocks and dropout.
        self.img_encoder = nn.Sequential(
            nn.Conv2d(img_channels, 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            ResidualBlock(32, 32, dropout_rate=0.1),
            nn.Dropout(0.1),  # extra dropout after the residual block
            nn.Conv2d(32, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            ResidualBlock(64, 64, dropout_rate=0.1),
            nn.Dropout(0.1),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            ResidualBlock(128, 128, dropout_rate=0.1),
            nn.Dropout(0.1),
            nn.Conv2d(128, embed_dim, 4, 2, 1, bias=False),
            nn.BatchNorm2d(embed_dim),
            nn.ReLU(),
        )

        # Temporal encoder over tokens of width 2E.
        self.time_embed = nn.Linear(time_features, embed_dim)
        self.time_pos = LearnablePositionalEncoding(2 * embed_dim, max_len=time_steps)
        self.time_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=2 * embed_dim,
                nhead=nhead,
                batch_first=True,
                dropout=0.1,
                activation='gelu'
            ),
            num_layers=6
        )

        # Embedding of the fixed numeric features.
        self.fixed_embed = nn.Linear(fixed_features, embed_dim)

        # Fusion Transformer over [global tokens ; encoded time tokens].
        self.fusion_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=2 * embed_dim,
                nhead=nhead,
                batch_first=True
            ),
            num_layers=4
        )

        # Decoder: (B, 2E, 1, 1) -> (B, 1, 128, 128). Spectral normalisation
        # is applied to the first transposed convolution only.
        self.decoder = nn.Sequential(
            nn.utils.spectral_norm(
                nn.ConvTranspose2d(2 * embed_dim, 128, 8, 1, 0, bias=False)
            ),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 4, 2, 1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, static_img, fixed_num, dynamic_ts):
        """Run the model.

        Parameters
        ----------
        static_img : torch.Tensor
            Static images, (B, img_channels, H, W).
        fixed_num : torch.Tensor
            Fixed numeric features, (B, fixed_features).
        dynamic_ts : torch.Tensor
            Dynamic features, (B, T, time_features).

        Returns
        -------
        torch.Tensor
            Predicted images of shape (B, T, 128, 128) with values in
            [-1, 1]. The decoder's single output channel becomes the time
            axis when the per-step outputs are concatenated, so there is no
            separate channel dimension.
        """
        _, T, _ = dynamic_ts.shape  # enforces a 3-D (B, T, F) input

        # Global context: image tokens plus one token for the fixed features.
        static_feat = self.img_encoder(static_img).flatten(2).transpose(1, 2)  # (B, N, E)
        fixed_feat = self.fixed_embed(fixed_num).unsqueeze(1)  # (B, 1, E)
        global_feat = torch.cat([fixed_feat, static_feat], dim=1)  # (B, 1 + N, E)
        # Number of global tokens (65 for a 128 x 128 input). Derived from the
        # tensor so that other input sizes slice the fused sequence correctly.
        num_global_tokens = global_feat.size(1)

        # Per time step: [mean global context ; embedded dynamic features].
        time_feat = self.time_embed(dynamic_ts)  # (B, T, E)
        global_summary = global_feat.mean(dim=1, keepdim=True).expand(-1, T, -1)  # (B, T, E)
        time_feat = torch.cat([global_summary, time_feat], dim=2)  # (B, T, 2E)
        time_feat = self.time_pos(time_feat)
        time_encoded = self.time_transformer(time_feat)  # (B, T, 2E)

        # Widen the global tokens to 2E by duplicating their features so they
        # can share a sequence with the time tokens.
        global_feat = torch.cat([global_feat, global_feat], dim=2)  # (B, 1 + N, 2E)

        # Fuse global and temporal tokens, then keep only the time positions.
        fused_input = torch.cat([global_feat, time_encoded], dim=1)  # (B, 1 + N + T, 2E)
        fused_feat = self.fusion_transformer(fused_input)
        time_fused_feat = fused_feat[:, num_global_tokens:, :]  # (B, T, 2E)

        # Decode each time step separately (one decoder pass per step, so
        # batch-norm statistics are computed per step).
        pred_imgs = []
        for t in range(T):
            latent = time_fused_feat[:, t, :, None, None]  # (B, 2E, 1, 1)
            pred_imgs.append(self.decoder(latent))  # (B, 1, 128, 128)

        return torch.cat(pred_imgs, dim=1)  # (B, T, 128, 128)


# 数据集
class GasDataset(Dataset):
    """Index-aligned dataset over the four model inputs/targets.

    Parameters
    ----------
    static_images : sequence
        Static images, indexed by sample; commented in the source as
        (N, C, H, W).
    fixed_numeric : sequence
        Fixed numeric features, (N, F).
    dynamic_numeric : sequence
        Dynamic numeric features, (N, T, F).
    target_images : sequence
        Target image sequences, (N, T, C, H, W).

    Raises
    ------
    ValueError
        If the four inputs do not contain the same number of samples.
    """

    def __init__(self, static_images, fixed_numeric, dynamic_numeric, target_images):
        # Samples are paired purely by index, so unequal lengths would either
        # mis-pair modalities silently or fail later with an IndexError.
        sizes = {
            'static_images': len(static_images),
            'fixed_numeric': len(fixed_numeric),
            'dynamic_numeric': len(dynamic_numeric),
            'target_images': len(target_images),
        }
        if len(set(sizes.values())) != 1:
            raise ValueError(f"All inputs must contain the same number of samples, got {sizes}")

        self.static_images = static_images  # (N, C, H, W)
        self.fixed_numeric = fixed_numeric  # (N, F)
        self.dynamic_numeric = dynamic_numeric  # (N, T, F)
        self.target_images = target_images  # (N, T, C, H, W)

    def __len__(self):
        return len(self.static_images)

    def __getitem__(self, idx):
        """Return ``(static_image, fixed_numeric, dynamic_numeric, target_images)``."""
        return (
            self.static_images[idx],
            self.fixed_numeric[idx],
            self.dynamic_numeric[idx],
            self.target_images[idx]
        )


def train_epoch(model, loader, optimizer, criterion_mse, criterion_ssim, device):
    """Train ``model`` for one pass over ``loader``.

    For every batch the loss is the mean over time steps of
    ``0.7 * MSE + 0.3 * (1 - SSIM)`` between prediction and target.

    Parameters
    ----------
    model : torch.nn.Module
        Model exposing ``time_steps`` and called as
        ``model(static_img, fixed_num, dynamic_ts)``.
    loader : iterable
        Yields batches of ``(static_img, fixed_num, dynamic_ts, target_img)``.
    optimizer : torch.optim.Optimizer
        Optimizer stepped once per batch.
    criterion_mse, criterion_ssim : callable
        Called as ``criterion(pred, gt)``; each returns a scalar tensor.
    device : str or torch.device
        Device the batch tensors are moved to.

    Returns
    -------
    tuple of float
        ``(mean loss, mean SSIM, mean MSE)`` averaged over batches.
    """
    model.train()
    total_loss = 0.0
    total_ssim = 0.0
    total_mse = 0.0
    pbar = tqdm(loader, desc="Training", dynamic_ncols=True)
    # Count batches explicitly: tqdm only refreshes ``pbar.n`` when it redraws
    # the bar, so it can lag behind the number of batches actually processed.
    for step, batch in enumerate(pbar, start=1):
        static_img, fixed_num, dynamic_ts, target_img = [x.to(device).float() for x in batch]
        pred_imgs = model(static_img, fixed_num, dynamic_ts)

        loss = 0.0
        batch_ssim = 0.0
        batch_mse = 0.0
        for t in range(model.time_steps):
            pred = pred_imgs[:, t]
            # Add a channel axis when the model output has none, so that the
            # prediction matches the (B, C, H, W) target.
            if pred.dim() == 3:
                pred = pred.unsqueeze(1)
            gt = target_img[:, t]
            loss_mse = criterion_mse(pred, gt)
            # Evaluate SSIM once and reuse it for both the loss and the metric.
            ssim_value = criterion_ssim(pred, gt)
            loss += 0.7 * loss_mse + 0.3 * (1 - ssim_value)
            batch_ssim += ssim_value.item()
            batch_mse += loss_mse.item()
        loss /= model.time_steps
        batch_ssim /= model.time_steps
        batch_mse /= model.time_steps

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_ssim += batch_ssim
        total_mse += batch_mse
        pbar.set_postfix({"loss": total_loss / step})
    return total_loss / len(loader), total_ssim / len(loader), total_mse / len(loader)


def val_epoch(model, loader, criterion_mse, criterion_ssim, device):
    """Evaluate ``model`` over ``loader`` without gradient tracking.

    Uses the same objective as training: the mean over time steps of
    ``0.7 * MSE + 0.3 * (1 - SSIM)``.

    Parameters
    ----------
    model : torch.nn.Module
        Model exposing ``time_steps`` and called as
        ``model(static_img, fixed_num, dynamic_ts)``.
    loader : iterable
        Yields batches of ``(static_img, fixed_num, dynamic_ts, target_img)``.
    criterion_mse, criterion_ssim : callable
        Called as ``criterion(pred, gt)``; each returns a scalar tensor.
    device : str or torch.device
        Device the batch tensors are moved to.

    Returns
    -------
    tuple of float
        ``(mean loss, mean SSIM, mean MSE)`` averaged over batches.
    """
    model.eval()
    total_loss = 0.0
    total_ssim = 0.0
    total_mse = 0.0
    with torch.no_grad():
        pbar = tqdm(loader, desc="Validation", dynamic_ncols=True)
        for batch in pbar:
            static_img, fixed_num, dynamic_ts, target_img = [x.to(device).float() for x in batch]
            pred_imgs = model(static_img, fixed_num, dynamic_ts)

            loss = 0.0
            batch_ssim = 0.0
            batch_mse = 0.0
            for t in range(model.time_steps):
                pred = pred_imgs[:, t]
                # Add a channel axis when the model output has none, so that
                # the prediction matches the (B, C, H, W) target.
                if pred.dim() == 3:
                    pred = pred.unsqueeze(1)
                gt = target_img[:, t]
                loss_mse = criterion_mse(pred, gt)
                # Evaluate SSIM once and reuse it for the loss and the metric.
                ssim_value = criterion_ssim(pred, gt)
                loss += 0.7 * loss_mse + 0.3 * (1 - ssim_value)
                batch_ssim += ssim_value.item()
                batch_mse += loss_mse.item()
            loss /= model.time_steps
            batch_ssim /= model.time_steps
            batch_mse /= model.time_steps

            total_loss += loss.item()
            total_ssim += batch_ssim
            total_mse += batch_mse
    return total_loss / len(loader), total_ssim / len(loader), total_mse / len(loader)


def _to_device_tensor(value, device):
    """Return ``value`` (ndarray or tensor) as a float32 tensor on ``device``."""
    if isinstance(value, np.ndarray):
        value = torch.from_numpy(value)
    return value.to(device).float()


def save_all_images(model, dataset, device, epoch, dataset_type):
    """Save ground-truth and predicted images for every sample in ``dataset``.

    Images are written to
    ``Result/epoch_<epoch>/<dataset_type>/sample_<idx>/{true_images,pred_images}/step_<t>.png``.
    Values are mapped from [-1, 1] to [0, 1] and rendered with the
    ``viridis`` colormap on a fixed [0, 1] colour scale, so ground truth and
    prediction can be compared visually.

    Parameters
    ----------
    model : torch.nn.Module
        Model exposing ``time_steps``; switched to eval mode here.
    dataset : sequence
        Indexable; each item is
        ``(static_img, fixed_num, dynamic_ts, target_img)`` as arrays or
        tensors.
    device : str or torch.device
        Device used for inference.
    epoch : int
        Epoch number used in the output path.
    dataset_type : str
        Sub-directory name, e.g. ``'train'``, ``'val'`` or ``'test'``.
    """
    model.eval()
    with torch.no_grad():
        for idx in tqdm(range(len(dataset)), desc=f"Saving {dataset_type} images for epoch {epoch}"):
            static_img, fixed_num, dynamic_ts, target_img = dataset[idx]
            # The model expects a leading batch axis; the target is only
            # indexed per time step, so it keeps its original shape.
            static_img = _to_device_tensor(static_img, device).unsqueeze(0)
            fixed_num = _to_device_tensor(fixed_num, device).unsqueeze(0)
            dynamic_ts = _to_device_tensor(dynamic_ts, device).unsqueeze(0)
            target_img = _to_device_tensor(target_img, device)

            pred_imgs = model(static_img, fixed_num, dynamic_ts)

            # Create the two output directories once per sample.
            sample_dir = os.path.join('Result', f'epoch_{epoch}', dataset_type, f'sample_{idx}')
            true_img_path = os.path.join(sample_dir, 'true_images')
            pred_img_path = os.path.join(sample_dir, 'pred_images')
            os.makedirs(true_img_path, exist_ok=True)
            os.makedirs(pred_img_path, exist_ok=True)

            for t in range(model.time_steps):
                # Ground-truth image. vmin/vmax pin the colour scale; without
                # them imsave rescales each image to its own min/max.
                true_img = target_img[t, 0].cpu().numpy()
                plt.imsave(
                    os.path.join(true_img_path, f'step_{t}.png'),
                    (true_img + 1) / 2,
                    cmap='viridis',
                    vmin=0.0,
                    vmax=1.0
                )

                # Predicted image (drop a leading channel axis if present).
                pred = pred_imgs[0, t].cpu().numpy()
                if pred.ndim == 3:
                    pred = pred.squeeze(0)
                plt.imsave(
                    os.path.join(pred_img_path, f'step_{t}.png'),
                    (pred + 1) / 2,
                    cmap='viridis',
                    vmin=0.0,
                    vmax=1.0
                )

# Aliases binding the tensors prepared in the data-processing section to the
# names used by the training script below.
static_images = combined_images  # static multi-channel images
fixed_numeric = fixed_numeric_data  # fixed numeric features
dynamic_numeric = changing_numeric_data  # dynamic time-series features
target_images = img_time_series  # target image sequences (single channel)


if __name__ == "__main__":
    # NOTE: anomaly detection is a debugging aid that slows every backward
    # pass; it can be switched off once training is known to be stable.
    torch.autograd.set_detect_anomaly(True)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    os.makedirs('Result', exist_ok=True)

    # ---- Split into train / validation / test = 60% / 20% / 20% ----
    # Samples are paired by index across the four tensors, so they must all
    # have the same length; the sample count is taken from the data instead
    # of being hard-coded.
    modality_sizes = {
        'static_images': len(static_images),
        'fixed_numeric': len(fixed_numeric),
        'dynamic_numeric': len(dynamic_numeric),
        'target_images': len(target_images),
    }
    if len(set(modality_sizes.values())) != 1:
        raise ValueError(f"Inputs are not aligned; sample counts differ: {modality_sizes}")
    num_samples = len(static_images)
    all_indices = np.arange(num_samples)
    # First hold out 20% for testing, then 25% of the remaining 80% (= 20% of
    # the total) for validation.
    train_val_idx, test_idx = train_test_split(all_indices, test_size=0.2, random_state=40)
    train_idx, val_idx = train_test_split(train_val_idx, test_size=0.25, random_state=40)

    train_dataset = GasDataset(static_images[train_idx], fixed_numeric[train_idx], dynamic_numeric[train_idx],
                               target_images[train_idx])
    val_dataset = GasDataset(static_images[val_idx], fixed_numeric[val_idx], dynamic_numeric[val_idx],
                             target_images[val_idx])
    test_dataset = GasDataset(static_images[test_idx], fixed_numeric[test_idx], dynamic_numeric[test_idx],
                              target_images[test_idx])

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

    model = MultiModalGasModel(
        img_channels=2,
        time_features=3,
        fixed_features=5,
        embed_dim=128,
        img_size=128,
        time_steps=6
    ).to(device)

    criterion_mse = nn.MSELoss()
    # pytorch-msssim expects non-negative images in [0, data_range]. Targets
    # and predictions live in [-1, 1], so they are mapped to [0, 1] before
    # SSIM is evaluated.
    ssim_module = SSIM(data_range=1.0, size_average=True, channel=1)

    def criterion_ssim(pred, target):
        """SSIM between two [-1, 1] image batches, evaluated on [0, 1]."""
        return ssim_module((pred + 1) / 2, (target + 1) / 2)

    optimizer = optim.AdamW(model.parameters(), lr=0.0002, weight_decay=3e-5)
    scheduler = StepLR(optimizer, step_size=20, gamma=0.5)

    # Create the CSV log with its header (test columns are filled in at the end).
    log_file = os.path.join('Result', 'training_log.csv')
    with open(log_file, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['epoch', 'train_loss', 'train_ssim', 'train_mse',
                         'val_loss', 'val_ssim', 'val_mse', 'test_loss', 'test_ssim', 'test_mse'])

    num_epochs = 50
    # Epochs after which train/validation images are written to disk
    # (the final epoch only).
    save_epochs = {num_epochs}
    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        train_loss, train_ssim, train_mse = train_epoch(model, train_loader, optimizer, criterion_mse, criterion_ssim, device)
        val_loss, val_ssim, val_mse = val_epoch(model, val_loader, criterion_mse, criterion_ssim, device)
        scheduler.step()
        print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train SSIM: {train_ssim:.4f}, Train MSE: {train_mse:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val SSIM: {val_ssim:.4f}, Val MSE: {val_mse:.4f}")

        with open(log_file, 'a', newline='') as f:
            writer = csv.writer(f)
            # Test columns carry a placeholder during training.
            writer.writerow([epoch + 1, train_loss, train_ssim, train_mse,
                             val_loss, val_ssim, val_mse, '-', '-', '-'])

        # Keep the checkpoint with the lowest validation loss.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), os.path.join('Result', 'best_model.pth'))

        if (epoch + 1) in save_epochs:
            save_all_images(model, train_dataset, device, epoch + 1, 'train')
            save_all_images(model, val_dataset, device, epoch + 1, 'val')

    # ---- Evaluate the best checkpoint on the held-out test set ----
    print("\n================= 测试集评估 ==================")
    # map_location keeps loading independent of the device the checkpoint
    # was saved from.
    model.load_state_dict(torch.load(os.path.join('Result', 'best_model.pth'), map_location=device))
    model.eval()

    test_loss, test_ssim, test_mse = val_epoch(model, test_loader, criterion_mse, criterion_ssim, device)
    print(f"Test SSIM: {test_ssim:.4f}  \tTest MSE: {test_mse:.4f}  \tTest Loss: {test_loss:.4f}")

    # Test images are stored under the final epoch number.
    save_all_images(model, test_dataset, device, num_epochs, 'test')

    # Append the test metrics as the last row of the log.
    with open(log_file, 'a', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['Test', '-', '-', '-', '-', '-', '-', test_loss, test_ssim, test_mse])