In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import math

## Gpytorch/MLP插值

### 所有原始训练数据从Lasair API获取
#### All of the original training data are obtained from the Lasair API

In [1]:
def safe_multiband_gp_interpolate(
    g_time, g_flux, r_time, r_flux,
    new_time_g, new_time_r,
    epochs=1000,
    device="cuda",
    fallback_threshold=0.95,
    patience=200,
    min_delta=1e-3,
    loss_threshold_for_accept=99,  # <<< 新增：GP 收敛质量标准
    verbose=True
):
    import numpy as np
    import torch
    import gpytorch
    #from interpolate_with_mlp import interpolate_with_mlp  # 确保你实现了这个函数

    # 预测退化判断函数
    def prediction_is_almost_constant(pred_values, tolerance=1e-2):
        return np.std(pred_values) < tolerance


    class MultibandGPModel(gpytorch.models.ExactGP):
        def __init__(self, train_x, train_y, likelihood):
            super().__init__(train_x, train_y, likelihood)
            self.mean_module = gpytorch.means.ConstantMean()
            self.covar_module = gpytorch.kernels.ScaleKernel(
                gpytorch.kernels.MaternKernel(nu=2.5, ard_num_dims=2)
            )

        def forward(self, x):
            mean_x = self.mean_module(x)
            covar_x = self.covar_module(x)
            return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

    # 1. 数据拼接
    g_band = np.zeros_like(g_time)
    r_band = np.ones_like(r_time)

    all_time = np.concatenate([g_time, r_time])
    all_flux = np.concatenate([g_flux, r_flux])
    all_band = np.concatenate([g_band, r_band])

    train_x_np = np.stack([all_time, all_band], axis=-1)
    time_min, time_max = train_x_np[:, 0].min(), train_x_np[:, 0].max()
    train_x_np[:, 0] = (train_x_np[:, 0] - time_min) / (time_max - time_min + 1e-8)

    train_x = torch.tensor(train_x_np, dtype=torch.float32).to(device)
    train_y = torch.tensor(all_flux, dtype=torch.float32).to(device)

    # 2. 模型初始化
    likelihood = gpytorch.likelihoods.GaussianLikelihood().to(device)
    model = MultibandGPModel(train_x, train_y, likelihood).to(device)

    model.train()
    likelihood.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

    best_loss = float('inf')
    patience_counter = 0
    losses = []
    converged = False

    # 3. 训练过程 + Early Stopping
    for epoch in range(epochs):
        optimizer.zero_grad()
        output = model(train_x)
        loss = -mll(output, train_y)
        loss.backward()
        optimizer.step()

        current_loss = loss.item()
        losses.append(current_loss)

        if verbose and epoch % 100 == 0:
            print(f"[Multiband GP][Epoch {epoch}] Loss: {current_loss:.4f}")

        if best_loss - current_loss > min_delta:
            best_loss = current_loss
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            if verbose:
                print(f"[Multiband GP] 提前早停于第 {epoch} 轮，Best Loss: {best_loss:.4f}")
            converged = True
            break

    # 若未早停，再根据下降程度判断是否算“收敛”
    if not converged:
        #if losses[-1] < fallback_threshold * losses[0] and losses[-1] <= 0:
        if losses[-1] < fallback_threshold * losses[0]:
            best_loss = losses[-1]
            converged = True
        else:
            converged = False

    # 4. 输入标准化后的时间坐标
    new_x_g = np.stack([new_time_g, np.zeros_like(new_time_g)], axis=-1)
    new_x_r = np.stack([new_time_r, np.ones_like(new_time_r)], axis=-1)

    new_x_g[:, 0] = (new_x_g[:, 0] - time_min) / (time_max - time_min + 1e-8)
    new_x_r[:, 0] = (new_x_r[:, 0] - time_min) / (time_max - time_min + 1e-8)

    # 5. 判断是否使用 GP 插值或 fallback 到 MLP
    if converged and best_loss < loss_threshold_for_accept:
        model.eval()
        likelihood.eval()
        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            new_x_g_tensor = torch.tensor(new_x_g, dtype=torch.float32).to(device)
            new_x_r_tensor = torch.tensor(new_x_r, dtype=torch.float32).to(device)

            pred_g = likelihood(model(new_x_g_tensor)).mean.cpu().numpy()
            pred_r = likelihood(model(new_x_r_tensor)).mean.cpu().numpy()

            # ==== 新增判断：预测是否退化为常数 ====
        if (prediction_is_almost_constant(pred_g) and
            prediction_is_almost_constant(pred_r)):
            if verbose:
                print(f"[Multiband GP] 预测退化为均值，Fallback 到 MLP。")
            pred_g = interpolate_with_mlp(np.array(g_time), np.array(g_flux), new_time_g, device=device)
            pred_r = interpolate_with_mlp(np.array(r_time), np.array(r_flux), new_time_r, device=device)
            return pred_g, pred_r

        return pred_g, pred_r

    else:
        if verbose:
            print(f"[Multiband GP] Fallback 到 MLP。收敛状态: {converged}, 最终 Loss: {best_loss:.4f}")
        pred_g = interpolate_with_mlp(np.array(g_time), np.array(g_flux), new_time_g, device=device)
        pred_r = interpolate_with_mlp(np.array(r_time), np.array(r_flux), new_time_r, device=device)
        return pred_g, pred_r




In [22]:
import os
import json
import numpy as np
from sklearn.model_selection import train_test_split
from collections import Counter
from scipy.interpolate import interp1d
from scipy.interpolate import PchipInterpolator
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import pywt
from scipy.signal import savgol_filter
from collections import Counter


def plot_interpolation_debug(
    g_time_sorted, g_flux_sorted,
    r_time_sorted, r_flux_sorted,
    new_time_g, g_interp,
    new_time_r, r_interp,
    idx=None, filename=None, save=False,
    save_dir="debug_plots", invert_y=True,
    label=0
):
    plt.figure(figsize=(8, 4))
    plt.plot(g_time_sorted, g_flux_sorted, '.', label='g-band (raw)', alpha=0.6)
    plt.plot(r_time_sorted, r_flux_sorted, '.', label='r-band (raw)', alpha=0.6)
    plt.plot(new_time_g, g_interp, '-', label='g-band (interp)', linewidth=1.2)
    plt.plot(new_time_r, r_interp, '-', label='r-band (interp)', linewidth=1.2)

    plt.xlabel("Time")
    plt.ylabel("Flux")
    plt.legend()

    title_str = f"Sample #{idx}, Label={label}" if idx is not None else ""
    if filename is not None:
        title_str += f" | {filename}"
    plt.title(title_str.strip())

    if invert_y:
        plt.gca().invert_yaxis()

    plt.tight_layout()

    if save:
        import os
        os.makedirs(save_dir, exist_ok=True)
        fname = filename.replace(".json", "") if filename else f"sample_{idx}"
        plt.savefig(f"{save_dir}/{fname}.png")
    else:
        plt.show()

    plt.close()


def wavelet_denoise(signal, wavelet='db4', level=2, threshold_scale=1.0):
    # 小波分解
    coeffs = pywt.wavedec(signal, wavelet, level=level)
    
    # 估算噪声阈值（使用最细层 detail 系数）
    sigma = np.median(np.abs(coeffs[-1])) / 0.6745
    threshold = threshold_scale * sigma

    # 软阈值处理高频部分
    coeffs_thresh = [coeffs[0]]  # 保留 approximation 部分
    for detail in coeffs[1:]:
        coeffs_thresh.append(pywt.threshold(detail, threshold, mode='soft'))

    # 重构信号
    return pywt.waverec(coeffs_thresh, wavelet)

# 示例：对 fg 和 fr 平滑
def smooth_two_band(tg, fg, tr, fr, wavelet='db4', level=3):
    fg_smooth = wavelet_denoise(fg, wavelet=wavelet, level=level)
    fr_smooth = wavelet_denoise(fr, wavelet=wavelet, level=level)

    # 修剪长度匹配（小波重构后可能略长）
    fg_smooth = fg_smooth[:len(fg)]
    fr_smooth = fr_smooth[:len(fr)]

    return tg, fg_smooth, tr, fr_smooth


def fill_data(time_g, flux_g, time_r, flux_r, seq_len=200):
    # 创建一个形状为 (seq_len, 4) 的全零数组
    filled_data = np.zeros((seq_len, 4))

    # 填充 time_g 和 flux_g 到第 1 和 2 维度
    filled_data[:len(time_g), 0] = time_g  # 填充 time_g 到第 1 维度
    filled_data[:len(flux_g), 1] = flux_g  # 填充 flux_g 到第 2 维度

    # 填充 time_r 和 flux_r 到第 3 和 4 维度
    filled_data[:len(time_r), 2] = time_r  # 填充 time_r 到第 3 维度
    filled_data[:len(flux_r), 3] = flux_r  # 填充 flux_r 到第 4 维度

    return filled_data


class MLPInterpolator(nn.Module):
    def __init__(self, hidden_size=64):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(1, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1)
        )

    def forward(self, x):
        return self.model(x)
        

class LSTMInterpolator(nn.Module):
    def __init__(self, hidden_size=64):
        super().__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # x shape: (batch, seq_len=1, input_size=1)
        out, _ = self.lstm(x)
        return self.fc(out)[:, -1, :]  # output shape: (batch, 1)


def interpolate_with_mlp(time_array, flux_array, new_time, epochs=2000, lr=1e-3, hidden_size=64, device=None):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 去除 NaN
    mask = ~np.isnan(flux_array)
    x = time_array[mask].reshape(-1, 1)
    y = flux_array[mask].reshape(-1, 1)

    if len(x) < 3:
        return PchipInterpolator(time_array, flux_array)(new_time)

    # 归一化
    #t_min, t_max = x.min(), x.max()
    #x_norm = (x - t_min) / (t_max - t_min)
    #new_time_norm = (new_time.reshape(-1, 1) - t_min) / (t_max - t_min)

    # 新 Gaussian 归一化
    t_mean, t_std = x.mean(), max(x.std(), 1e-3)
    x_norm = (x - t_mean) / t_std
    new_time_norm = (new_time.reshape(-1, 1) - t_mean) / t_std

    y_mean, y_std = y.mean(), max(y.std(), 1e-3)
    y_norm = (y - y_mean) / y_std

    x_tensor = torch.tensor(x_norm, dtype=torch.float32, device=device)
    y_tensor = torch.tensor(y_norm, dtype=torch.float32, device=device)

    model = MLPInterpolator(hidden_size=hidden_size).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    for _ in range(epochs):
        model.train()
        optimizer.zero_grad()
        output = model(x_tensor)
        loss = loss_fn(output, y_tensor)
        loss.backward()
        optimizer.step()

    # 预测
    new_time_tensor = torch.tensor(new_time_norm, dtype=torch.float32, device=device)
    with torch.no_grad():
        pred_norm = model(new_time_tensor).cpu().numpy()

    pred = pred_norm * y_std + y_mean
    return np.clip(pred.flatten(), 1e-6, np.max(flux_array) * 10)


def interpolate_with_lstm(time_array, flux_array, new_time, epochs=2000, lr=1e-3, hidden_size=64, device=None):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 去除 NaN
    mask = ~np.isnan(flux_array)
    x = time_array[mask].reshape(-1, 1)
    y = flux_array[mask].reshape(-1, 1)

    if len(x) < 3:
        return PchipInterpolator(time_array, flux_array)(new_time)

    # 高斯归一化
    t_mean, t_std = x.mean(), max(x.std(), 1e-3)
    x_norm = (x - t_mean) / t_std
    new_time_norm = (new_time.reshape(-1, 1) - t_mean) / t_std

    y_mean, y_std = y.mean(), max(y.std(), 1e-3)
    y_norm = (y - y_mean) / y_std

    # 转为 tensor，并增加序列维度
    x_tensor = torch.tensor(x_norm[:, None], dtype=torch.float32, device=device)  # (N, 1, 1)
    y_tensor = torch.tensor(y_norm, dtype=torch.float32, device=device)           # (N, 1)

    model = LSTMInterpolator(hidden_size=hidden_size).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    for _ in range(epochs):
        model.train()
        optimizer.zero_grad()
        output = model(x_tensor)
        loss = loss_fn(output, y_tensor)
        loss.backward()
        optimizer.step()

    # 预测
    new_time_tensor = torch.tensor(new_time_norm.reshape(-1, 1, 1), dtype=torch.float32, device=device)
    with torch.no_grad():
        pred_norm = model(new_time_tensor).cpu().numpy()

    pred = pred_norm * y_std + y_mean
    return np.clip(pred.flatten(), 1e-6, np.max(flux_array) * 10)



# 将星等转为流量的函数
def mag_to_flux(mag, m_ref=22.5):
    if float(mag) <= 0:  # 避免无效的星等值
        return 0
    return 10 ** (0.4 * (m_ref - float(mag)))

def magerr_to_fluxerr(flux,magerr):
    
    return 0.921*flux*magerr


import numpy as np
from collections import Counter

# 切分数据并插值
import numpy as np

# 修改后的函数
def sample_and_process(sequence, target_length=200, time_range=(0, 15)):
    time = sequence[:, 0]
    g_flux = sequence[:, 1]
    r_flux = sequence[:, 2]
    
    # 超出指定时间范围的部分设定为0
    mask = (time >= time_range[0]) & (time <= time_range[1])
    sampled_time = time[mask]
    sampled_g_flux = g_flux[mask]
    sampled_r_flux = r_flux[mask]

    # 创建一个全为0的目标数组，长度为target_length，列数为5（包括时间、g_flux、r_flux）
    result = np.zeros((target_length, 3))

    # 将有效时间数据映射到0到target_length的范围
    if len(sampled_time) > 0:
        # 计算时间的线性映射
        time_min, time_max = sampled_time.min(), sampled_time.max()
        time_scaled = (sampled_time - time_min) / (time_max - time_min) * (target_length - 1)

        # 将时间对应的索引映射到目标数组
        time_indices = np.round(time_scaled).astype(int)

        # 填充时间、g_flux和r_flux
        for i, idx in enumerate(time_indices):
            result[idx, 0] = sampled_time[i]  # 填充时间
            result[idx, 1] = sampled_g_flux[i]  # 填充g_flux
            result[idx, 2] = sampled_r_flux[i]  # 填充r_flux

    # 对所有行进行排序，按时间升序排列
    result = result[result[:, 0].argsort()]

    # 标准化g_flux和r_flux
    max_flux = max(result[:, 1].max(), result[:, 2].max())
    if max_flux > 0:  # 防止除以零
        result[:, 1] /= max_flux
        result[:, 2] /= max_flux

    return result


def sample_and_interpolate(sequence, target_length=200,plot=False, invert_y=True,max_plots=5,label=None):
    #在这里，选择g，r波段分别插值，不再进行统一插值了
    g_time = sequence[:, 0]
    g_flux = sequence[:, 1]
    r_time = sequence[:, 2]
    r_flux = sequence[:, 3]
    
    max_time_range = max(max(g_time), max(r_time)) - min(min(r_time), min(g_time))
    
    
    # 如果时间跨度太小，直接返回全零数组
    if max_time_range <= 0:
        print(max_time_range)
        return np.zeros((target_length, 5))  # 返回全零数组，长度为目标长度
        
    
    # 切分数据为多个阶段（按时间百分比）
    max_time = max(max(g_time), max(r_time))
    #print(max_time)
    gx = g_time[g_time>0]
    min_time = np.min(gx)


    stages = [

        (-np.inf, min_time+0.2*(max_time-min_time)),
       
        (-np.inf, min_time+0.4*(max_time-min_time)),
        
        (-np.inf, min_time+0.6*(max_time-min_time)),
        
        (-np.inf, min_time+0.8*(max_time-min_time)),
        
        (-np.inf, np.inf)
    ]
    
    augmented_sequences = []
    plot_count=0
    
    for start, end in stages:
        # 根据时间范围筛选子序列
        #print('seq - sub01')
        mask_g = (g_time >= start) & (g_time <= end) & (g_time != 0)
        mask_r = (r_time >= start) & (r_time <= end) & (r_time != 0)
        sampled_g_time = g_time[mask_g]
        sampled_g_flux = g_flux[mask_g]
        sampled_r_time = r_time[mask_r]
        sampled_r_flux = r_flux[mask_r]

        # 只保留非零数据点
        non_zero_g_flux = sampled_g_flux[sampled_g_flux != 0]
        non_zero_r_flux = sampled_r_flux[sampled_r_flux != 0]

        # 如果非零数值过少（不足三个数据点），跳过
        if len(non_zero_g_flux) < 4 or len(non_zero_r_flux) < 4:
            continue
        
        # 归一化flux
        max_flux_fix = max(np.max(non_zero_g_flux), np.max(non_zero_r_flux))
        sampled_g_flux = sampled_g_flux 
        sampled_r_flux = sampled_r_flux 
        s1=min( np.min(sampled_r_time),np.min(sampled_g_time))
        
        
        sampled_g_time_1 = sampled_g_time.copy()
        sampled_r_time_1 = sampled_r_time.copy()
        sampled_r_flux_1 = sampled_r_flux.copy()
        sampled_g_flux_1 = sampled_g_flux.copy()

        s3 = max(np.min(sampled_g_time_1),np.min(sampled_r_time_1)) #max min time         
        s4 = min(np.max(sampled_g_time_1),np.max(sampled_r_time_1)) #min max time         
        s2 = max(np.max(sampled_g_flux_1),np.max(sampled_r_flux_1)) #max max flux
        s7 = min(np.min(sampled_g_flux_1),np.min(sampled_r_flux_1)) #min min flux
        
       

        if len(sampled_g_flux) < 4 or len(sampled_r_flux) < 4:             
            continue

       

        tg = np.linspace(np.min(sampled_g_time_1), np.max(sampled_g_time_1), target_length)
        tr = np.linspace(np.min(sampled_r_time_1), np.max(sampled_r_time_1), target_length)

        
        
        upper_limit_flux = 10 * max(np.max(sampled_g_flux),np.max(sampled_r_flux))
        lower_limit_flux = 1e-6

        sampled_g_flux = (sampled_g_flux - 0 ) /(s2 - 0)
        sampled_r_flux = (sampled_r_flux - 0 ) /(s2 - 0)

        '''
        for i in range(len(sampled_g_flux)):
            sampled_g_flux[i] = 10**sampled_g_flux[i]

        for i in range(len(sampled_r_flux)):
            sampled_r_flux[i] = 10**sampled_r_flux[i]
        '''
        tg = tg-s1
        tr = tr-s1
        #tt = tt - s1
        sampled_g_time = sampled_g_time - s1
        sampled_r_time = sampled_r_time - s1
        
        
        

        
        
        sampled_g_time_1, sampled_g_flux_1, sampled_r_time_1,sampled_r_flux_1 = sampled_g_time, sampled_g_flux, sampled_r_time,sampled_r_flux
       

        try:
            g_interp, r_interp = safe_multiband_gp_interpolate(
                                        sampled_g_time, sampled_g_flux,
                                        sampled_r_time, sampled_r_flux,
                                        tg, tr,
                                        epochs=1000,
                                        device="cuda",
                                        min_delta=1e-3,
                                        fallback_threshold = 0.1,
                                        loss_threshold_for_accept=99,  # <<< 新增：GP 收敛质量标准
                                        verbose=False
                                    )
        except RuntimeError as e:
            print('GP Process Facing Uncorrectable Error, MLP Instead...')
            g_interp = interpolate_with_mlp(np.array(sampled_g_time), np.array(sampled_g_flux), tg,device="cuda",epochs=1000)
            r_interp = interpolate_with_mlp(np.array(sampled_r_time), np.array(sampled_r_flux), tr,device="cuda",epochs=1000)
        
                
        
        fg = np.clip(g_interp,lower_limit_flux, upper_limit_flux)
        fr = np.clip(r_interp,lower_limit_flux, upper_limit_flux)

        #f_max = max(np.max(fg),np.max(fr))
        
        #fg = fg/f_max;fr = fr/f_max
        
        #小波变换
        #tg, fg, tr, fr = smooth_two_band(tg, fg, tr, fr)

        #SV平滑
        #fg = savgol_filter(fg, window_length=30, polyorder=2)
        #fr = savgol_filter(fr, window_length=30, polyorder=2)

        # 组合插值结果
        augmented_sequences.append(np.vstack((tg, fg, tr, fr)).T)
        #print(start,end)

        # 只绘制前 max_plots 次

        #tg = [];tr = [];fg = [];fr = []

        '''
        if plot and plot_count < max_plots:
            #print(tg,fg)
            plot_interpolation_debug(
                sampled_g_time, sampled_g_flux,
                sampled_r_time, sampled_r_flux,
                tg, fg,
                tr, fr,
                idx=plot_count,
                invert_y=False,
                label=str(label)
            )
            plot_count += 1
        else:
            stop
        '''

    
    return augmented_sequences


# 扩充样本数据（随机增强机制）
from collections import Counter
import numpy as np
from tqdm import tqdm



def augment_samples(data, labels, augment_times=1):
    augmented_data = []
    augmented_labels = []

    total_samples = len(data)

    for i in tqdm(range(total_samples), desc="Augmenting data"):
        sequence = data[i]
        label = labels[i]

        augmented_sequences = sample_and_interpolate(sequence, plot=True, label=label)

        for augmented_sequence in augmented_sequences:
            if len(augmented_sequence) > 0 and np.array(augmented_sequence).shape[0] == 200:
                augmented_data.append(augmented_sequence)
                augmented_labels.append(label)

    print(f"Augmented counts: {Counter(augmented_labels)}")
    print(np.array(augmented_data).shape)
    return np.array(augmented_data), np.array(augmented_labels)



# 修改后的主函数
def load_and_modify_json(directories, labels, m_ref=22.5, target_counts=None,train_ratio = 0.6):
    data = []
    all_labels = []
    label_count = Counter()
    processed = 0
    for directory, label in zip(directories, labels):
        for file in os.listdir(directory):
            file_path = os.path.join(directory, file)
            if os.path.isfile(file_path) and file.endswith('.json'):
                with open(file_path, 'r') as f:
                   try: 
                        processed +=1
                        #print(f'processing No.{processed} Light curve')
                        json_data = json.load(f)

                        # 处理格式1（candidates数组）
                        if isinstance(json_data, dict) and "candidates" in json_data:
                            candidates = json_data["candidates"]
                            sequence = []
                            
                            for candidate in candidates:
                                if candidate.get("isdiffpos") == 't':
                                    jd = candidate.get("jd", 0)
                                    mag_g = candidate.get("magpsf") if candidate.get("fid") == 1 else None  # 根据fid选择g波段
                                    mag_r = candidate.get("magpsf") if candidate.get("fid") == 2 else None  # 根据fid选择r波段
                                    mag_g_err = candidate.get("sigmapsf") if candidate.get("fid") == 1 else None
                                    mag_r_err = candidate.get("sigmapsf") if candidate.get("fid") == 2 else None
                                    mag_g_zp = candidate.get("magzpsci") if candidate.get("fid") == 1 else None
                                    mag_r_zp = candidate.get("magzpsci") if candidate.get("fid") == 2 else None
                                    
                                    # 仅当有有效的g和r波段数据时才进行赋值
                                    if mag_g is not None:
                                        mag_g = mag_to_flux(mag_g, m_ref)
                                        #mag_g = float(mag_g)
                                        flux_g_err = magerr_to_fluxerr(mag_g,mag_g_err)
                                    if mag_r is not None:
                                        mag_r = mag_to_flux(mag_r, m_ref)
                                        #mag_r = float(mag_r)
                                        flux_r_err = magerr_to_fluxerr(mag_r,mag_r_err)
    
                                    sequence.append([jd-2400000.5, mag_g if mag_g is not None else 0, mag_r if mag_r is not None else 0])

                        # 处理格式2（数组形式）
                        elif isinstance(json_data, list) and "MJD" in json_data[0]:
                            sequence = []
                            for entry in json_data:
                                mjd = float(entry.get("MJD", 0))
                                if entry['unforced_mag_status']=='positive':
                                    if entry["filter"] == "g":
                                        #mag_g = float(entry.get("unforced_mag", 0))
                                        mag_g = mag_to_flux(entry.get("unforced_mag", 0), m_ref)
                                        mag_r = 0  # 如果当前是g波段，r波段为0
                                    elif entry["filter"] == "r":
                                        #mag_r = float(entry.get("unforced_mag", 0))
                                        mag_r = mag_to_flux(entry.get("unforced_mag", 0), m_ref)
                                        mag_g = 0  # 如果当前是r波段，g波段为0
                                    else:
                                        continue  # 如果不是g或r波段则跳过
                                else:
                                    continue

                                sequence.append([mjd, mag_g, mag_r])

                        else:
                            print(f"Skipping file {file_path}: unexpected format.")
                            continue

                        # 将序列转换为numpy数组
                        sequence = np.array(sequence)
                        if sequence.shape[0]!=0 and sequence.shape[0]>10:
                            
                            pass
                        else:
                            #print('An error file occurs')
                            continue
                        # 处理排序和插值
                        # 分别获取 g 波段和 r 波段的时间和流量
                        g_time = sequence[:, 0][sequence[:, 1] > 0]
                        g_flux = sequence[:, 1][sequence[:, 1] > 0]
                        r_time = sequence[:, 0][sequence[:, 2] > 0]
                        r_flux = sequence[:, 2][sequence[:, 2] > 0]

                        # 对 g 和 r 波段的时间和流量进行排序
                        sorted_g_indices = np.argsort(g_time)  # 获取 g_time 排序的索引
                        g_time_sorted = g_time[sorted_g_indices]  # 根据排序的索引排序 g_time
                        g_flux_sorted = g_flux[sorted_g_indices]  # 同时排序 g_flux
                        
                        sorted_r_indices = np.argsort(r_time)  # 获取 r_time 排序的索引
                        r_time_sorted = r_time[sorted_r_indices]  # 根据排序的索引排序 r_time
                        r_flux_sorted = r_flux[sorted_r_indices]  # 同时排序 r_flux
                        
                       
                        
                        # 删除 g 和 r 波段中时间相同的点
                        unique_g_times = []
                        unique_g_flux = []
                        unique_r_times = []
                        unique_r_flux = []
                        
                        for g_time_val, g_flux_val in zip(g_time_sorted, g_flux_sorted):
                            if g_time_val not in unique_g_times:
                                unique_g_times.append(g_time_val)
                                unique_g_flux.append(g_flux_val)
                        
                        for r_time_val, r_flux_val in zip(r_time_sorted, r_flux_sorted):
                            if r_time_val not in unique_r_times:
                                unique_r_times.append(r_time_val)
                                unique_r_flux.append(r_flux_val)
                        
                        # 重新组合已去除重复时间点的 g 和 r 波段
                        g_time_sorted = np.array(unique_g_times)
                        g_flux_sorted = np.array(unique_g_flux)
                        r_time_sorted = np.array(unique_r_times)
                        r_flux_sorted = np.array(unique_r_flux)

                        #try:
                        
                        if len(g_time_sorted)>3 and len(r_time_sorted)>3:
                            # 使用 g 和 r 波段时间的最大和最小值设置为时间的起止点
                            min_time = min(min(g_time_sorted), min(r_time_sorted))
                            max_time = max(max(g_time_sorted), max(r_time_sorted))

                            min_max_time = min(max(g_time_sorted), max(r_time_sorted))
                            max_min_time = max(min(g_time_sorted), min(r_time_sorted))

                            tg = np.zeros(200);tr = np.zeros(200);fg = np.zeros(200);fr = np.zeros(200)
                            if len(g_time_sorted)<=200:
                                for i in range(len(g_time_sorted)):
                                    tg[i] = g_time_sorted[i]
                                    fg[i] = g_flux_sorted[i]
                            else:
                                for i in range(200):
                                    tg[i] = g_time_sorted[i]
                                    fg[i] = g_flux_sorted[i]

                            if len(r_time_sorted)<=200:
                                for i in range(len(r_time_sorted)):
                                    tr[i] = r_time_sorted[i]
                                    fr[i] = r_flux_sorted[i]
                            else:
                                for i in range(200):
                                    tr[i] = r_time_sorted[i]
                                    fr[i] = r_flux_sorted[i]

                            #import matplotlib.pyplot as plt

                            #plt.plot(g_time_sorted,g_flux_sorted,'.')
                            #plt.plot(r_time_sorted,r_flux_sorted,'.')
                            #plt.plot(new_time_g,g_interp)
                            #plt.plot(new_time_r,r_interp)
                            
                            #stop
                            #if g_flux_sorted[-1]<np.max(g_flux_sorted) and r_flux_sorted[-1]<np.max(r_flux_sorted):
                             
                            sequence_full = fill_data(tg, fg, tr,fr)

                            
                            
                        
                            # 填充或截断序列到最大长度200
                            if len(sequence) < 200:
                                padding = np.zeros((200 - len(sequence), 3))
                                sequence = np.vstack((sequence, padding))
                            elif len(sequence) > 200:
                                sequence = sequence[:200]
    
                            # 如果第二列和第三列都为0，将第一列的值也设为0
                            for row in sequence:
                                if row[1] == 0 and row[2] == 0:
                                    row[0] = 0  # 将第一列也设为0
    
                            data.append(sequence_full)
                            all_labels.append(label)
                            label_count[label] += 1  # 更新标签计数
                        #except SyntaxError:
                            #print('>>')
                            pass
                   #except (json.JSONDecodeError, ValueError):
                   except SyntaxError:
                        print(f"Skipping file {file_path}: unable to decode JSON or invalid data.")    

    print("Initial Label counts:")
    for label, count in label_count.items():
        print(f"Label {label}: {count} samples")

    print(f"Final dataset shape: {np.array(data).shape}")
   
    labels = np.array(all_labels)
    data = np.array(data)
    # 扩充样本,仅针对训练集
    X_train, X_test, y_train, y_test = train_test_split(
    data, labels, test_size=(1-train_ratio), stratify=labels, random_state=42
    )

    print("Train label distribution:", Counter(y_train))
    print("Test label distribution:", Counter(y_test))

    print(X_test.shape)
    print(f"Training Data is Augmenting...")
    X_train, y_train = augment_samples(np.array(X_train), np.array(y_train), augment_times=1)
    np.save("X_train.npy", X_train)
    np.save("y_train.npy", y_train)
    
    print(f"Testing Data is Augmenting...")
    X_test,y_test = augment_samples(np.array(X_test), np.array(y_test),augment_times=1)
    np.save("X_test.npy", X_test)
    np.save("y_test.npy", y_test)

    print("Train shape",X_train.shape)
    print("Test shape",X_test.shape)

    return np.array(data),X_train


# 保存数据


directories = ['../ZTF-TDE/','../ZTF_SN_total/SN Ia/','../ZTF SN Ib Ic','../ZTF_SN_total/SN II_all/','../ZTF SLSN/','../ZTF AGN/']
               

labels = [0,1,2,3,4,5]





target_counts = {0: 1000, 1: 10000, 2: 10000}  # 每个类别的目标样本数，但是目前来说是摆设

# 加载和扩充数据
data,X_train= load_and_modify_json(directories, labels, target_counts=target_counts,train_ratio=0.65)


# 检查数据

print(X_train[1])

#print(np.max(data))


Initial Label counts:
Label 0: 50 samples
Label 1: 4448 samples
Label 2: 281 samples
Label 3: 1388 samples
Label 4: 187 samples
Label 5: 52 samples
Final dataset shape: (6406, 200, 4)
Train label distribution: Counter({np.int64(1): 2891, np.int64(3): 902, np.int64(2): 183, np.int64(4): 121, np.int64(5): 34, np.int64(0): 32})
Test label distribution: Counter({np.int64(1): 1557, np.int64(3): 486, np.int64(2): 98, np.int64(4): 66, np.int64(0): 18, np.int64(5): 18})
(2243, 200, 4)
Training Data is Augmenting...


Augmenting data: 100%|████████████████████████████████████████████████████████████████████████| 4163/4163 [23:13:01<00:00, 20.08s/it]


Augmented counts: Counter({np.int64(1): 11267, np.int64(3): 3906, np.int64(2): 725, np.int64(4): 486, np.int64(5): 147, np.int64(0): 142})
(16673, 200, 4)
Testing Data is Augmenting...


Augmenting data: 100%|████████████████████████████████████████████████████████████████████████| 2243/2243 [12:26:21<00:00, 19.96s/it]

Augmented counts: Counter({np.int64(1): 6086, np.int64(3): 2106, np.int64(2): 385, np.int64(4): 283, np.int64(0): 77, np.int64(5): 73})
(9010, 200, 4)
Train shape (16673, 200, 4)
Test shape (9010, 200, 4)
[[ 0.          0.36586493  4.9541551   0.66401565]
 [ 0.11045209  0.36961544  5.04018572  0.66928136]
 [ 0.22090417  0.37371883  5.12621633  0.67454064]
 [ 0.33135626  0.37817287  5.21224695  0.67978984]
 [ 0.44180835  0.38297313  5.29827756  0.68502575]
 [ 0.55226043  0.38811189  5.38430818  0.6902445 ]
 [ 0.66271252  0.39358348  5.4703388   0.69544381]
 [ 0.77316461  0.39937907  5.55636941  0.70061946]
 [ 0.8836167   0.40549055  5.64240003  0.70576859]
 [ 0.99406878  0.41190594  5.72843064  0.71088672]
 [ 1.10452087  0.418616    5.81446126  0.71597081]
 [ 1.21497296  0.42560887  5.90049188  0.72101629]
 [ 1.32542504  0.4328731   5.98652249  0.72602057]
 [ 1.43587713  0.44039544  6.07255311  0.73097926]
 [ 1.54632922  0.44816363  6.15858373  0.7358886 ]
 [ 1.6567813   0.45616525  6.2


