In [None]:
############################################################################################################

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torch.amp import GradScaler, autocast
from tqdm import tqdm
import gc
import matplotlib.pyplot as plt
import os
import pickle
import psutil
import warnings
import math
from scipy.signal import butter, filtfilt, find_peaks
from scipy.interpolate import interp1d
import pywt

# Suppress transformer warning
warnings.filterwarnings("ignore", category=UserWarning, module="torch.nn.modules.transformer")

# VMD Dependency (simplified placeholder)
print("Using simplified VMD placeholder.")
def vmd_wrapper(signal, alpha, tau, K, DC, init, tol):
    modes = np.array([signal / K for _ in range(K)])
    return modes, None, None

# Bandpass Filter
def bandpass_filter(data, lowcut, highcut, fs, order=5):
    nyquist = fs / 2
    low = lowcut / nyquist
    high = highcut / nyquist
    b, a = butter(order, [low, high], btype='band')
    filtered = filtfilt(b, a, data, axis=0)
    return filtered

# MAD (Median Absolute Deviation)
def mad(data):
    return np.median(np.abs(data - np.median(data)))

# Wavelet Denoising with Enhanced Thresholding
def wavelet_denoise(data, wavelet='db8', level=6):
    coeffs = pywt.wavedec(data, wavelet, level=level)
    sigma = mad(coeffs[-level])
    uthresh = sigma * np.sqrt(2 * np.log(len(data)))
    for i in range(1, len(coeffs)):
        coeffs[i] = pywt.threshold(coeffs[i], value=uthresh, mode='soft')
    return pywt.waverec(coeffs, wavelet)

# False Nearest Neighbors
def false_nearest_neighbors(signal, max_dim=10, tau=1, R_T=15.0):
    def compute_distance(x, y):
        return np.sqrt(np.sum((x - y)**2))
    fnn_ratios = []
    signal = np.array(signal)
    for m in range(1, max_dim + 1):
        phase_space = []
        for i in range(0, len(signal) - (m - 1) * tau, tau):
            vec = signal[i:i + m * tau:tau]
            if len(vec) == m:
                phase_space.append(vec)
        if not phase_space:
            fnn_ratios.append(0)
            continue
        phase_space = np.array(phase_space)
        fnn_count = 0
        total_count = 0
        for i in range(len(phase_space)):
            distances = [compute_distance(phase_space[i], phase_space[j]) for j in range(len(phase_space)) if i != j]
            if distances:
                min_dist_idx = np.argmin(distances)
                nn_idx = min_dist_idx if min_dist_idx < i else min_dist_idx + 1
                if i + m * tau < len(signal) and nn_idx + m * tau < len(signal):
                    dist_m = compute_distance(phase_space[i], phase_space[nn_idx])
                    dist_m1 = abs(signal[i + m * tau] - signal[nn_idx + m * tau])
                    if dist_m > 0 and dist_m1 / dist_m >= R_T:
                        fnn_count += 1
                    total_count += 1
        fnn_ratio = fnn_count / total_count if total_count > 0 else 0
        fnn_ratios.append(fnn_ratio)
        if fnn_ratio < 0.01:
            break
    if not fnn_ratios:
        return [0], 1
    return fnn_ratios, np.argmin(fnn_ratios) + 1

# Mutual Information
def mutual_information(signal, max_tau=50):
    def entropy(data):
        counts, _ = np.histogram(data, bins=50, density=True)
        probs = counts / counts.sum()
        probs = probs[probs > 0]
        return -np.sum(probs * np.log2(probs))
    mi_values = []
    for tau in range(1, max_tau + 1):
        s = signal[:-tau]
        q = signal[tau:]
        joint = np.vstack((s, q)).T
        h_s = entropy(s)
        h_q = entropy(q)
        h_sq = entropy(joint)
        mi = h_s + h_q - h_sq
        mi_values.append(mi)
        if len(mi_values) > 1 and mi_values[-1] > mi_values[-2]:
            break
    return mi_values, np.argmin(mi_values) + 1

# VMD-PSR with Dimensionality Reduction
def vmd_psr(data, K=15, embedding_dim=5, tau=1, variate_indices=None, alpha=2000):
    num_samples, num_variates = data.shape
    enhanced_data = []
    for v in range(min(100, num_variates)):  # Limit to 100 variates
        signal = data[:, v]
        if variate_indices and v not in variate_indices:
            cache_file = f"vmd_psr_variate_{v}.pkl"
            if os.path.exists(cache_file):
                with open(cache_file, 'rb') as f:
                    modes = pickle.load(f)
            else:
                modes = np.array([signal / K for _ in range(K)])
        else:
            if len(signal) < 10 or np.all(signal == signal[0]):
                print(f"Warning: Skipping variate {v}")
                modes = np.array([signal / K for _ in range(K)])
            else:
                try:
                    modes, _, _ = vmd_wrapper(signal, alpha=alpha, tau=0, K=K, DC=0, init=1, tol=1e-6)
                    if modes.shape != (K, num_samples):
                        print(f"Warning: VMD shape mismatch for variate {v}")
                        modes_padded = np.zeros((K, num_samples))
                        modes_padded[:min(K, modes.shape[0]), :min(num_samples, modes.shape[1])] = modes
                        modes = modes_padded
                    if variate_indices:
                        with open(f"vmd_psr_variate_{v}.pkl", 'wb') as f:
                            pickle.dump(modes, f)
                except Exception as e:
                    print(f"Error in VMD for variate {v}: {e}")
                    modes = np.array([signal / K for _ in range(K)])
        for mode in modes:
            mode_psr = []
            for d in range(embedding_dim):
                delay = d * tau
                delayed = np.roll(mode, -delay)
                if delay > 0:
                    delayed[-delay:] = delayed[-delay-1]
                mode_psr.append(delayed)
            mode_psr = np.stack(mode_psr, axis=1)
            enhanced_data.append(mode_psr)
    return np.concatenate(enhanced_data, axis=1)

# Enhanced iTransformer Model
class EnhancediTransformer(nn.Module):
    def __init__(self, num_variates, seq_len, pred_len, hidden_size=512, num_heads=8, num_layers=6, dropout=0.2):
        super().__init__()
        self.pred_len = pred_len
        self.proj_in = nn.Linear(num_variates, hidden_size)
        self.pos_enc = PositionalEncoding(hidden_size)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_size,
                nhead=num_heads,
                dim_feedforward=4*hidden_size,
                dropout=dropout,
                activation='gelu',
                batch_first=True
            ),
            num_layers=num_layers
        )
        self.proj_out = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, pred_len * num_variates)  # Output for all variates
        )

    def forward(self, x):
        x = self.proj_in(x)  # [batch, seq_len, hidden]
        x = self.pos_enc(x)
        x = self.encoder(x)
        x = self.proj_out(x[:, -1, :])  # Use last token
        x = x.view(-1, self.pred_len, x.shape[-1] // self.pred_len)  # Reshape to [batch, pred_len, num_variates]
        return x  # [batch, pred_len, num_variates]

# Positional Encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[..., 0::2] = torch.sin(position * div_term)
        pe[..., 1::2] = torch.cos(position * div_term)
        self.pe = nn.Parameter(pe, requires_grad=False)

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :]

# Enhanced Loss Function with Confidence
class ECG_loss(nn.Module):
    def __init__(self, l2_weight=1e-4):
        super().__init__()
        self.mse = nn.MSELoss()
        self.mae = nn.L1Loss()
        self.l2_weight = l2_weight

    def forward(self, pred, target, model_parameters=None):
        # Ensure pred and target have compatible shapes
        if pred.shape != target.shape:
            print(f"Warning: Shape mismatch - pred: {pred.shape}, target: {target.shape}")
            pred = pred.expand_as(target[:, :, :pred.shape[-1]])  # Match to target variates
        weights = torch.ones_like(target)
        gradients = torch.abs(target[:, 1:] - target[:, :-1])
        weights[:, :-1] += 5 * gradients  # Weight high-gradient areas
        mse_loss = (weights * (pred - target)**2).mean()
        mae_loss = (weights * torch.abs(pred - target)).mean()
        l2_loss = sum(p.pow(2).sum() for p in model_parameters) * self.l2_weight if model_parameters else 0
        return 0.7 * mse_loss + 0.3 * mae_loss + l2_loss

# Robust Scaler with Clipping
class RobustScaler:
    def __init__(self, data):
        self.median = np.median(data, axis=0)
        self.iqr = np.percentile(data, 75, axis=0) - np.percentile(data, 25, axis=0)
        self.iqr[self.iqr == 0] = 1.0

    def normalize(self, data):
        return np.clip((data - self.median) / self.iqr, -5, 5)

    def denormalize(self, data):
        return data * self.iqr + self.median

# TimeSeriesDataset with Additional Augmentation
class TimeSeriesDataset(Dataset):
    def __init__(self, data, seq_len, pred_len):
        self.features = data.values.astype(np.float32)
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.num_samples = len(data) - seq_len - pred_len + 1
        if self.num_samples <= 0:
            raise ValueError(f"Insufficient data: num_samples = {self.num_samples}, seq_len = {seq_len}, pred_len = {pred_len}")
        rng = np.random.default_rng()
        noise1 = rng.normal(0, 0.01, (self.num_samples, seq_len, self.features.shape[1])).astype(np.float32)
        scale = rng.uniform(0.95, 1.05, (self.num_samples, 1, 1)).astype(np.float32)
        noise2 = rng.normal(0, 0.005, (self.num_samples, seq_len, self.features.shape[1])).astype(np.float32)
        warp_factor = rng.uniform(0.98, 1.02, self.num_samples).astype(np.float32)
        self.augmented = []
        time_orig = np.linspace(0, 1, seq_len)
        for idx in range(self.num_samples):
            x = self.features[idx:idx + seq_len]
            time_warped = np.linspace(0, 1, seq_len) * warp_factor[idx]
            time_warped = time_warped / time_warped[-1] * time_orig[-1]
            x_warped = np.zeros_like(x)
            for v in range(x.shape[1]):
                interp = interp1d(time_orig, x[:, v], kind='linear', fill_value="extrapolate")
                x_warped[:, v] = interp(time_warped)
            t = np.linspace(0, 1, seq_len)
            wander = 0.02 * np.sin(2 * np.pi * 0.5 * t)
            x = x_warped + wander[:, None] + noise1[idx]
            x = x * scale[idx] + noise2[idx]
            x = np.clip(x, -5, 5)
            self.augmented.append(x)

    def __len__(self):
        return max(0, self.num_samples)
    
    def __getitem__(self, idx):
        idx = idx % self.num_samples
        x = self.augmented[idx]
        y = self.features[idx + self.seq_len:idx + self.seq_len + self.pred_len]
        return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)

# Metrics with Confidence Interval
def calculate_metrics(true, pred, confidence=0.95):
    rmse = torch.sqrt(torch.mean((true - pred)**2, dim=(0, 1)))
    mae = torch.mean(torch.abs(true - pred), dim=(0, 1))
    std_error = torch.std(true - pred, dim=(0, 1))
    z_score = 1.96  # For 95% confidence
    conf_interval = z_score * std_error / np.sqrt(true.size(0))
    return {'RMSE': rmse.tolist(), 'MAE': mae.tolist(), 'Conf_Interval': conf_interval.tolist()}

# Plot Function
import matplotlib.pyplot as plt
import numpy as np
def plot_actual_vs_predicted(inputs, trues, preds, variates_to_plot=1, filename="ecg_prediction_qitransformer.png", fs=250):
    # Create figure with IEEE-compliant settings
    plt.figure(figsize=(8, 4), dpi=600)
    plt.rcParams['font.family'] = 'Times New Roman'
    plt.rcParams['font.size'] = 10
    plt.rcParams['axes.linewidth'] = 1.0
    plt.rcParams['lines.linewidth'] = 1.5
    
    # Select data segment to plot (focusing on a representative QRS complex)
    input_len = inputs.shape[1]
    pred_len = preds.shape[1]
    total_len = input_len + pred_len
    time_steps_full = np.arange(total_len) * (1000 / fs)  # Convert to milliseconds
    
    # Find the most dynamic segment (highest variance) to showcase
    combined = np.concatenate([inputs[0, :, variates_to_plot], preds[0, :, variates_to_plot]])
    variances = np.array([np.var(combined[i:i+100]) for i in range(len(combined)-100)])
    start_idx = np.argmax(variances)
    window_size = 200  # 800 ms window (200 samples at 250Hz)
    
    # Extract the segments
    actual_input = inputs[0, start_idx:start_idx + (window_size - pred_len), variates_to_plot]
    actual_trues = trues[0, :min(pred_len, window_size - len(actual_input)), variates_to_plot]
    actual = np.concatenate([actual_input, actual_trues])[:window_size]
    
    pred_input = inputs[0, start_idx:start_idx + (window_size - pred_len), variates_to_plot]
    pred_preds = preds[0, :min(pred_len, window_size - len(pred_input)), variates_to_plot]
    prediction = np.concatenate([pred_input, pred_preds])[:window_size]
    
    time_steps = time_steps_full[start_idx:start_idx + window_size]
    
    # Normalize for better visualization
    actual = (actual - np.min(actual)) / (np.max(actual) - np.min(actual)) * 1.2  # Scale to 0-1.2 mV
    prediction = (prediction - np.min(prediction)) / (np.max(prediction) - np.min(prediction)) * 1.2
    
    # Plot with enhanced styling
    plt.plot(time_steps, actual, label="Ground Truth", color="#2c7bb6", linewidth=1.8, alpha=0.9)
    plt.plot(time_steps, prediction, label="iTransformer Prediction", color="#d7191c", linewidth=1.8, linestyle='--', alpha=0.9)
    
    # Detect and annotate key features
    peaks, _ = find_peaks(actual, height=0.6, distance=20)
    if len(peaks) > 0:
        qrs_pos = peaks[0]
        plt.axvline(x=time_steps[qrs_pos], color='gray', linestyle=':', linewidth=1, alpha=0.7)
        plt.text(time_steps[qrs_pos], 1.25, 'QRS', ha='center', va='bottom', 
                fontsize=9, bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=1))
        
        # Highlight QRS complex
        qrs_start = max(0, qrs_pos - 10)
        qrs_end = min(window_size-1, qrs_pos + 10)
        plt.axvspan(time_steps[qrs_start], time_steps[qrs_end], 
                   color='#fdae61', alpha=0.2, label='QRS Complex')
    
    # Add shaded region for prediction area
    pred_start = input_len - start_idx
    if pred_start > 0 and pred_start < window_size:
        plt.axvspan(time_steps[pred_start], time_steps[-1], 
                   color='#abd9e9', alpha=0.15, label='Prediction Horizon')
    
    # Formatting
    plt.title("ECG Prediction: iTransformer vs Ground Truth", fontsize=11, pad=12)
    plt.xlabel("Time (ms)", fontsize=10, labelpad=5)
    plt.ylabel("Normalized Amplitude (mV)", fontsize=10, labelpad=5)
    plt.ylim(-0.1, 1.3)
    plt.xlim(time_steps[0], time_steps[-1])
    
    # IEEE-style legend
    legend = plt.legend(loc='upper right', fontsize=9, frameon=True, 
                       framealpha=1, edgecolor='black', facecolor='white')
    legend.get_frame().set_linewidth(0.8)
    
    # IEEE-style grid and ticks
    plt.grid(True, linestyle=':', linewidth=0.6, alpha=0.5)
    plt.tick_params(axis='both', which='both', direction='in', top=True, right=True, width=0.8)
    
    # Adjust layout and save
    plt.tight_layout(pad=1.5)
    plt.savefig(filename, dpi=600, bbox_inches='tight', format='png')
    plt.savefig(filename.replace('.png', '.eps'), dpi=600, bbox_inches='tight', format='eps')  # EPS for IEEE
    plt.close()


def main():
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device_ids = list(range(torch.cuda.device_count()))
    print(f"Using device: {device}, GPUs: {device_ids}")

    # Load and preprocess data
    raw_data = pd.read_csv("/kaggle/input/7309records/7309_arrhythmia_balanced.csv")
    raw_data.replace([np.inf, -np.inf], np.nan, inplace=True)
    raw_data.dropna(inplace=True)
    print(f"Raw data shape after dropping NaN: {raw_data.shape}")

    # Wavelet denoising
    filtered_data = np.apply_along_axis(wavelet_denoise, 0, raw_data.drop(columns=['type']).values)
    filtered_data = bandpass_filter(filtered_data, lowcut=0.5, highcut=45.0, fs=360.0)
    filtered_data = np.clip(filtered_data, np.percentile(filtered_data, 1), np.percentile(filtered_data, 99))
    filtered_df = pd.DataFrame(filtered_data, columns=raw_data.columns[:-1])

    # Adaptive VMD parameters
    signal = filtered_df.values[:, 0]
    entropy = mad(signal)
    K = max(8, min(15, int(entropy * 10)))
    alpha = 2000 + entropy * 5000

    # Dynamic embedding dimension
    _, embedding_dim = false_nearest_neighbors(signal, max_dim=10)
    _, tau = mutual_information(signal, max_tau=50)
    print(f"Selected K={K}, embedding_dim={embedding_dim}, tau={tau}")

    # VMD-PSR
    cache_file = f"vmd_psr_cache_K{K}_dim{embedding_dim}.pkl"
    if os.path.exists(cache_file):
        print("Loading cached VMD-PSR...")
        with open(cache_file, 'rb') as f:
            vmd_psr_data = pickle.load(f)
    else:
        print("Computing VMD-PSR...")
        vmd_psr_data = vmd_psr(filtered_df.values, K=K, embedding_dim=embedding_dim, tau=tau, alpha=alpha)
        with open(cache_file, 'wb') as f:
            pickle.dump(vmd_psr_data, f)

    modes, _, _ = vmd_wrapper(signal, alpha=alpha, tau=0, K=K, DC=0, init=1, tol=1e-6)
    # plot_vmd(signal, modes, K)
    # plot_fnn(signal)
    # plot_mi(signal)

    vmd_psr_df = pd.DataFrame(vmd_psr_data, columns=[f"vmd_psr_{i}" for i in range(vmd_psr_data.shape[1])])
    data_ma = filtered_df.rolling(window=5).mean().bfill()
    data_ma.columns = [f"{col}_ma" for col in data_ma.columns]
    enhanced_data = pd.concat([filtered_df, vmd_psr_df, data_ma], axis=1)
    print(f"Enhanced data shape: {enhanced_data.shape}, num_variates: {enhanced_data.shape[1]}")

    # Robust Scaler
    scaler = RobustScaler(enhanced_data.values[:int(0.7*len(enhanced_data))])
    norm_data = pd.DataFrame(scaler.normalize(enhanced_data.values), columns=enhanced_data.columns)

    # Config
    config = {
        'num_variates': norm_data.shape[1],
        'seq_len': 250,
        'pred_len': 50,
        'hidden_size': 512,
        'num_heads': 8,
        'num_layers': 6,
        'dropout': 0.2,
        'lr': 5e-5,
        'batch_size': 64,
        'epochs': 200,  # Increased epochs
        'patience': 10,  # Increased patience
        'grad_clip': 1.0,
    }

    # Data loaders
    total_len = len(norm_data)
    train_end = int(total_len * 0.7)
    val_end = train_end + int(total_len * 0.15)
    train_data = norm_data.iloc[:train_end]
    val_data = norm_data.iloc[train_end:val_end]
    test_data = norm_data.iloc[val_end:]

    train_dataset = TimeSeriesDataset(train_data, config['seq_len'], config['pred_len'])
    val_dataset = TimeSeriesDataset(val_data, config['seq_len'], config['pred_len'])
    test_dataset = TimeSeriesDataset(test_data, config['seq_len'], config['pred_len'])
    print(f"Train dataset num_samples: {train_dataset.num_samples}")
    print(f"Val dataset num_samples: {val_dataset.num_samples}")
    print(f"Test dataset num_samples: {test_dataset.num_samples}")

    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, pin_memory=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False, pin_memory=True, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, pin_memory=True, num_workers=4)
    print(f"Train loader length: {len(train_loader)}")
    print(f"Val loader length: {len(val_loader)}")
    print(f"Test loader length: {len(test_loader)}")

    if len(train_loader) == 0:
        raise ValueError("Train loader is empty. Check dataset size and batch size.")

    # Model
    model = EnhancediTransformer(
        num_variates=config['num_variates'], seq_len=config['seq_len'], pred_len=config['pred_len'],
        hidden_size=config['hidden_size'], num_heads=config['num_heads'], num_layers=config['num_layers'],
        dropout=config['dropout']
    ).to(device)

    if len(device_ids) > 1:
        model = nn.DataParallel(model, device_ids=device_ids)
        print(f"Using DataParallel on GPUs: {device_ids}")

    optimizer = optim.AdamW(model.parameters(), lr=config['lr'], weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['epochs'])
    criterion = ECG_loss(l2_weight=1e-4)
    grad_scaler = GradScaler('cuda')

    best_val_rmse = float('inf')
    patience_counter = 0
    train_losses, val_losses, rmses, maes = [], [], [], []
    cpu_usage, gpu_usage = [], []

    for epoch in range(config['epochs']):
        gc.collect()
        torch.cuda.empty_cache()

        # Train
        model.train()
        train_loss = 0
        for x, y in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            with autocast('cuda', dtype=torch.float16):
                pred = model(x)
                loss = criterion(pred, y)
            grad_scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config['grad_clip'])
            grad_scaler.step(optimizer)
            grad_scaler.update()
            train_loss += loss.item()
        train_loss /= len(train_loader)
        train_losses.append(train_loss)

        # Validate
        model.eval()
        val_loss = 0
        preds, trues, inputs = [], [], []
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                with autocast('cuda', dtype=torch.float16):
                    pred = model(x)
                    loss = criterion(pred, y)
                val_loss += loss.item()
                preds.append(pred.cpu())
                trues.append(y.cpu())
                inputs.append(x.cpu())
        val_loss /= len(val_loader)
        val_losses.append(val_loss)
        preds = torch.cat(preds)
        trues = torch.cat(trues)
        inputs = torch.cat(inputs)
        preds_denorm = scaler.denormalize(preds.numpy())
        trues_denorm = scaler.denormalize(trues.numpy())
        inputs_denorm = scaler.denormalize(inputs.numpy())
        val_metrics = calculate_metrics(torch.tensor(trues_denorm), torch.tensor(preds_denorm))
        mean_rmse = np.mean(val_metrics['RMSE'])
        mean_mae = np.mean(val_metrics['MAE'])
        rmses.append(val_metrics['RMSE'])
        maes.append(val_metrics['MAE'])

        scheduler.step()

        cpu_usage.append(psutil.cpu_percent())
        gpu_usage.append(torch.cuda.memory_reserved(0) / torch.cuda.get_device_properties(0).total_memory * 100 if torch.cuda.is_available() else 0)

        print(f"Epoch {epoch+1}: Train Loss={train_loss:.6f}, Val Loss={val_loss:.6f}, "
              f"Val RMSE={mean_rmse:.6f}, Val MAE={mean_mae:.6f}, Conf Interval={np.mean(val_metrics['Conf_Interval']):.6f}")

        if mean_rmse < best_val_rmse:
            best_val_rmse = mean_rmse
            torch.save(model.state_dict(), "best_model.pth")
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= config['patience']:
                print(f"Early stopping at epoch {epoch+1}")
                break

    # Test
    model.load_state_dict(torch.load("best_model.pth", weights_only=True))
    model.eval()
    test_loss = 0
    preds, trues, inputs = [], [], []
    with torch.no_grad():
        for batch_idx, (x, y) in enumerate(test_loader):
            x, y = x.to(device), y.to(device)
            with autocast('cuda', dtype=torch.float16):
                pred = model(x)
                loss = criterion(pred, y[:, :, :pred.shape[-1]])  # Match to predicted variates
            test_loss += loss.item()
            preds.append(pred.cpu())
            trues.append(y.cpu())
            inputs.append(x.cpu())
            # Enhanced debug
            pred = pred.to(dtype=torch.float32)  # Convert to float32 for comparison
            y = y.to(dtype=torch.float32)  # Convert to float32 for comparison
            print(f"Batch {batch_idx}: pred shape: {pred.shape}, min: {pred.min().item()}, max: {pred.max().item()}")
            print(f"Batch {batch_idx}: true shape: {y.shape}, min: {y.min().item()}, max: {y.max().item()}")
            if torch.allclose(pred, y[:, :, :pred.shape[-1]], rtol=1e-5, atol=1e-5):
                print(f"Warning: Batch {batch_idx} - Predictions are identical to ground truth for predicted variates!")
    test_loss /= len(test_loader)
    preds = torch.cat(preds)
    trues = torch.cat(trues)
    inputs = torch.cat(inputs)
    preds_denorm = scaler.denormalize(preds.numpy())
    trues_denorm = scaler.denormalize(trues.numpy())
    inputs_denorm = scaler.denormalize(inputs.numpy())
    test_metrics = calculate_metrics(torch.tensor(trues_denorm), torch.tensor(preds_denorm))
    mean_test_rmse = np.mean(test_metrics['RMSE'])
    mean_test_mae = np.mean(test_metrics['MAE'])

    print("\n=== FINAL TEST RESULTS ===")
    print(f"Test Loss: {test_loss:.6f}")
    print(f"Test RMSE: {mean_test_rmse:.6f}")
    print(f"Test MAE: {mean_test_mae:.6f}")
    print(f"Per-variate RMSE: {[f'{x:.6f}' for x in test_metrics['RMSE'][:3]]}")
    print(f"Confidence Interval: {np.mean(test_metrics['Conf_Interval']):.6f}")

    # Plot actual vs predicted
    print("Generating ECG prediction plot...")
    plot_actual_vs_predicted(
        inputs_denorm, trues_denorm, preds_denorm,
        variates_to_plot=1,
        filename="ecg_prediction_qitransformer.png"
    )

if __name__ == "__main__":
    main()

Using simplified VMD placeholder.
Using device: cuda, GPUs: [0, 1]
Raw data shape after dropping NaN: (7308, 33)
Selected K=8, embedding_dim=3, tau=1
Computing VMD-PSR...
Enhanced data shape: (7308, 832), num_variates: 832
Train dataset num_samples: 4816
Val dataset num_samples: 797
Test dataset num_samples: 798
Train loader length: 76
Val loader length: 13
Test loader length: 13
Using DataParallel on GPUs: [0, 1]


Epoch 1: 100%|██████████| 76/76 [00:14<00:00,  5.42it/s]


Epoch 1: Train Loss=1.022329, Val Loss=4.132458, Val RMSE=0.002041, Val MAE=0.001661, Conf Interval=0.000141


Epoch 2: 100%|██████████| 76/76 [00:11<00:00,  6.45it/s]


Epoch 2: Train Loss=1.015279, Val Loss=4.128566, Val RMSE=0.002040, Val MAE=0.001660, Conf Interval=0.000141


Epoch 3: 100%|██████████| 76/76 [00:11<00:00,  6.41it/s]


Epoch 3: Train Loss=1.014028, Val Loss=4.120420, Val RMSE=0.002038, Val MAE=0.001657, Conf Interval=0.000141


Epoch 4: 100%|██████████| 76/76 [00:11<00:00,  6.39it/s]


Epoch 4: Train Loss=0.999736, Val Loss=4.110842, Val RMSE=0.002036, Val MAE=0.001655, Conf Interval=0.000141


Epoch 5: 100%|██████████| 76/76 [00:11<00:00,  6.36it/s]


Epoch 5: Train Loss=0.986428, Val Loss=4.099949, Val RMSE=0.002035, Val MAE=0.001652, Conf Interval=0.000141


Epoch 6: 100%|██████████| 76/76 [00:12<00:00,  6.31it/s]


Epoch 6: Train Loss=0.960037, Val Loss=4.088598, Val RMSE=0.002030, Val MAE=0.001648, Conf Interval=0.000141


Epoch 7: 100%|██████████| 76/76 [00:12<00:00,  6.31it/s]


Epoch 7: Train Loss=0.945774, Val Loss=4.082528, Val RMSE=0.002027, Val MAE=0.001644, Conf Interval=0.000140


Epoch 8: 100%|██████████| 76/76 [00:12<00:00,  6.28it/s]


Epoch 8: Train Loss=0.928208, Val Loss=4.076137, Val RMSE=0.002023, Val MAE=0.001639, Conf Interval=0.000140


Epoch 9: 100%|██████████| 76/76 [00:12<00:00,  6.22it/s]


Epoch 9: Train Loss=0.906795, Val Loss=4.069910, Val RMSE=0.002019, Val MAE=0.001635, Conf Interval=0.000140


Epoch 10: 100%|██████████| 76/76 [00:12<00:00,  6.18it/s]


Epoch 10: Train Loss=0.893073, Val Loss=4.072704, Val RMSE=0.002018, Val MAE=0.001633, Conf Interval=0.000140


Epoch 11: 100%|██████████| 76/76 [00:12<00:00,  6.14it/s]


Epoch 11: Train Loss=0.878683, Val Loss=4.068005, Val RMSE=0.002014, Val MAE=0.001629, Conf Interval=0.000139


Epoch 12: 100%|██████████| 76/76 [00:12<00:00,  6.12it/s]


Epoch 12: Train Loss=0.871945, Val Loss=4.074531, Val RMSE=0.002014, Val MAE=0.001628, Conf Interval=0.000139


Epoch 13: 100%|██████████| 76/76 [00:12<00:00,  6.05it/s]


Epoch 13: Train Loss=0.858586, Val Loss=4.066952, Val RMSE=0.002011, Val MAE=0.001625, Conf Interval=0.000139


Epoch 14: 100%|██████████| 76/76 [00:12<00:00,  6.06it/s]


Epoch 14: Train Loss=0.848275, Val Loss=4.075235, Val RMSE=0.002011, Val MAE=0.001625, Conf Interval=0.000139


Epoch 15: 100%|██████████| 76/76 [00:12<00:00,  6.00it/s]


Epoch 15: Train Loss=0.838856, Val Loss=4.081541, Val RMSE=0.002011, Val MAE=0.001625, Conf Interval=0.000139


Epoch 16: 100%|██████████| 76/76 [00:12<00:00,  6.04it/s]


Epoch 16: Train Loss=0.827394, Val Loss=4.078063, Val RMSE=0.002008, Val MAE=0.001621, Conf Interval=0.000139


Epoch 17: 100%|██████████| 76/76 [00:12<00:00,  6.06it/s]


Epoch 17: Train Loss=0.820245, Val Loss=4.092028, Val RMSE=0.002008, Val MAE=0.001621, Conf Interval=0.000139


Epoch 18: 100%|██████████| 76/76 [00:12<00:00,  6.03it/s]


Epoch 18: Train Loss=0.810855, Val Loss=4.097919, Val RMSE=0.002006, Val MAE=0.001619, Conf Interval=0.000139


Epoch 19: 100%|██████████| 76/76 [00:12<00:00,  6.04it/s]


Epoch 19: Train Loss=0.806006, Val Loss=4.092917, Val RMSE=0.002004, Val MAE=0.001619, Conf Interval=0.000139


Epoch 20: 100%|██████████| 76/76 [00:12<00:00,  6.02it/s]


Epoch 20: Train Loss=0.795499, Val Loss=4.110623, Val RMSE=0.002006, Val MAE=0.001619, Conf Interval=0.000139


Epoch 21: 100%|██████████| 76/76 [00:12<00:00,  6.05it/s]


Epoch 21: Train Loss=0.787302, Val Loss=4.118742, Val RMSE=0.002005, Val MAE=0.001618, Conf Interval=0.000139


Epoch 22: 100%|██████████| 76/76 [00:12<00:00,  6.08it/s]


Epoch 22: Train Loss=0.779208, Val Loss=4.131795, Val RMSE=0.002006, Val MAE=0.001619, Conf Interval=0.000139


Epoch 23: 100%|██████████| 76/76 [00:12<00:00,  6.04it/s]


Epoch 23: Train Loss=0.770876, Val Loss=4.132094, Val RMSE=0.002005, Val MAE=0.001618, Conf Interval=0.000138


Epoch 24: 100%|██████████| 76/76 [00:12<00:00,  6.01it/s]


Epoch 24: Train Loss=0.764655, Val Loss=4.125976, Val RMSE=0.001999, Val MAE=0.001614, Conf Interval=0.000138


Epoch 25: 100%|██████████| 76/76 [00:12<00:00,  6.07it/s]


Epoch 25: Train Loss=0.757466, Val Loss=4.137862, Val RMSE=0.002001, Val MAE=0.001616, Conf Interval=0.000138


Epoch 26: 100%|██████████| 76/76 [00:12<00:00,  6.05it/s]


Epoch 26: Train Loss=0.752702, Val Loss=4.145444, Val RMSE=0.002001, Val MAE=0.001616, Conf Interval=0.000138


Epoch 27: 100%|██████████| 76/76 [00:12<00:00,  6.07it/s]


Epoch 27: Train Loss=0.743798, Val Loss=4.146188, Val RMSE=0.001998, Val MAE=0.001613, Conf Interval=0.000138


Epoch 28: 100%|██████████| 76/76 [00:12<00:00,  6.09it/s]


Epoch 28: Train Loss=0.740628, Val Loss=4.147155, Val RMSE=0.001998, Val MAE=0.001613, Conf Interval=0.000138


Epoch 29: 100%|██████████| 76/76 [00:12<00:00,  6.06it/s]


Epoch 29: Train Loss=0.736439, Val Loss=4.150264, Val RMSE=0.001998, Val MAE=0.001614, Conf Interval=0.000138


Epoch 30: 100%|██████████| 76/76 [00:12<00:00,  6.05it/s]


Epoch 30: Train Loss=0.731166, Val Loss=4.158862, Val RMSE=0.002000, Val MAE=0.001615, Conf Interval=0.000138


Epoch 31: 100%|██████████| 76/76 [00:12<00:00,  6.07it/s]


Epoch 31: Train Loss=0.728901, Val Loss=4.164444, Val RMSE=0.001999, Val MAE=0.001615, Conf Interval=0.000138


Epoch 32: 100%|██████████| 76/76 [00:12<00:00,  6.10it/s]


Epoch 32: Train Loss=0.729294, Val Loss=4.160769, Val RMSE=0.001998, Val MAE=0.001614, Conf Interval=0.000138


Epoch 33: 100%|██████████| 76/76 [00:12<00:00,  6.04it/s]


Epoch 33: Train Loss=0.721445, Val Loss=4.164396, Val RMSE=0.001997, Val MAE=0.001614, Conf Interval=0.000138


Epoch 34: 100%|██████████| 76/76 [00:12<00:00,  6.11it/s]


Epoch 34: Train Loss=0.717631, Val Loss=4.166675, Val RMSE=0.001997, Val MAE=0.001613, Conf Interval=0.000138


Epoch 35: 100%|██████████| 76/76 [00:12<00:00,  6.05it/s]


Epoch 35: Train Loss=0.715361, Val Loss=4.175329, Val RMSE=0.002001, Val MAE=0.001617, Conf Interval=0.000138


Epoch 36: 100%|██████████| 76/76 [00:12<00:00,  6.11it/s]


Epoch 36: Train Loss=0.713759, Val Loss=4.184813, Val RMSE=0.002000, Val MAE=0.001615, Conf Interval=0.000138


Epoch 37: 100%|██████████| 76/76 [00:12<00:00,  6.00it/s]


Epoch 37: Train Loss=0.706748, Val Loss=4.181702, Val RMSE=0.002000, Val MAE=0.001616, Conf Interval=0.000138


Epoch 38: 100%|██████████| 76/76 [00:12<00:00,  6.05it/s]


Epoch 38: Train Loss=0.706414, Val Loss=4.181736, Val RMSE=0.001999, Val MAE=0.001615, Conf Interval=0.000138


Epoch 39: 100%|██████████| 76/76 [00:12<00:00,  6.03it/s]


Epoch 39: Train Loss=0.705494, Val Loss=4.188712, Val RMSE=0.001998, Val MAE=0.001615, Conf Interval=0.000138


Epoch 40: 100%|██████████| 76/76 [00:12<00:00,  6.10it/s]


Epoch 40: Train Loss=0.703528, Val Loss=4.189374, Val RMSE=0.002000, Val MAE=0.001616, Conf Interval=0.000138


Epoch 41: 100%|██████████| 76/76 [00:12<00:00,  5.98it/s]


Epoch 41: Train Loss=0.699170, Val Loss=4.187392, Val RMSE=0.001997, Val MAE=0.001613, Conf Interval=0.000138


Epoch 42: 100%|██████████| 76/76 [00:12<00:00,  6.03it/s]


Epoch 42: Train Loss=0.692817, Val Loss=4.189233, Val RMSE=0.001998, Val MAE=0.001614, Conf Interval=0.000138


Epoch 43: 100%|██████████| 76/76 [00:12<00:00,  6.07it/s]


Epoch 43: Train Loss=0.694611, Val Loss=4.199841, Val RMSE=0.001999, Val MAE=0.001615, Conf Interval=0.000138


Epoch 44: 100%|██████████| 76/76 [00:12<00:00,  6.06it/s]


Epoch 44: Train Loss=0.689223, Val Loss=4.191870, Val RMSE=0.001996, Val MAE=0.001613, Conf Interval=0.000137


Epoch 45: 100%|██████████| 76/76 [00:12<00:00,  6.08it/s]


Epoch 45: Train Loss=0.690378, Val Loss=4.190897, Val RMSE=0.001996, Val MAE=0.001612, Conf Interval=0.000137


Epoch 46: 100%|██████████| 76/76 [00:12<00:00,  6.06it/s]


Epoch 46: Train Loss=0.682828, Val Loss=4.193616, Val RMSE=0.001994, Val MAE=0.001610, Conf Interval=0.000137


Epoch 47: 100%|██████████| 76/76 [00:12<00:00,  6.10it/s]


Epoch 47: Train Loss=0.682908, Val Loss=4.192631, Val RMSE=0.001996, Val MAE=0.001613, Conf Interval=0.000138


Epoch 48: 100%|██████████| 76/76 [00:12<00:00,  6.04it/s]


Epoch 48: Train Loss=0.681672, Val Loss=4.183069, Val RMSE=0.001992, Val MAE=0.001610, Conf Interval=0.000137


Epoch 49: 100%|██████████| 76/76 [00:12<00:00,  6.01it/s]


Epoch 49: Train Loss=0.682288, Val Loss=4.187618, Val RMSE=0.001992, Val MAE=0.001609, Conf Interval=0.000137


Epoch 50: 100%|██████████| 76/76 [00:12<00:00,  6.04it/s]


Epoch 50: Train Loss=0.676320, Val Loss=4.191694, Val RMSE=0.001992, Val MAE=0.001610, Conf Interval=0.000137


Epoch 51: 100%|██████████| 76/76 [00:12<00:00,  6.03it/s]


Epoch 51: Train Loss=0.674878, Val Loss=4.199371, Val RMSE=0.001997, Val MAE=0.001614, Conf Interval=0.000138


Epoch 52: 100%|██████████| 76/76 [00:12<00:00,  6.08it/s]


Epoch 52: Train Loss=0.672129, Val Loss=4.205533, Val RMSE=0.001996, Val MAE=0.001612, Conf Interval=0.000137


Epoch 53: 100%|██████████| 76/76 [00:12<00:00,  6.03it/s]


Epoch 53: Train Loss=0.675259, Val Loss=4.193388, Val RMSE=0.001992, Val MAE=0.001609, Conf Interval=0.000137


Epoch 54: 100%|██████████| 76/76 [00:12<00:00,  6.01it/s]


Epoch 54: Train Loss=0.670484, Val Loss=4.204041, Val RMSE=0.001995, Val MAE=0.001611, Conf Interval=0.000137


Epoch 55: 100%|██████████| 76/76 [00:12<00:00,  6.07it/s]


Epoch 55: Train Loss=0.666358, Val Loss=4.194360, Val RMSE=0.001992, Val MAE=0.001609, Conf Interval=0.000137


Epoch 56: 100%|██████████| 76/76 [00:12<00:00,  6.10it/s]


Epoch 56: Train Loss=0.665566, Val Loss=4.206556, Val RMSE=0.001995, Val MAE=0.001612, Conf Interval=0.000137


Epoch 57: 100%|██████████| 76/76 [00:12<00:00,  6.07it/s]


Epoch 57: Train Loss=0.667353, Val Loss=4.204909, Val RMSE=0.001997, Val MAE=0.001614, Conf Interval=0.000138


Epoch 58: 100%|██████████| 76/76 [00:12<00:00,  5.97it/s]


Epoch 58: Train Loss=0.664945, Val Loss=4.200916, Val RMSE=0.001992, Val MAE=0.001609, Conf Interval=0.000137


Epoch 59: 100%|██████████| 76/76 [00:12<00:00,  6.04it/s]


Epoch 59: Train Loss=0.666919, Val Loss=4.211998, Val RMSE=0.001997, Val MAE=0.001613, Conf Interval=0.000137


Epoch 60: 100%|██████████| 76/76 [00:12<00:00,  6.06it/s]


Epoch 60: Train Loss=0.663195, Val Loss=4.204804, Val RMSE=0.001994, Val MAE=0.001611, Conf Interval=0.000137


Epoch 61: 100%|██████████| 76/76 [00:12<00:00,  6.03it/s]


Epoch 61: Train Loss=0.661585, Val Loss=4.210944, Val RMSE=0.001997, Val MAE=0.001613, Conf Interval=0.000138


Epoch 62: 100%|██████████| 76/76 [00:12<00:00,  6.04it/s]


Epoch 62: Train Loss=0.661351, Val Loss=4.201164, Val RMSE=0.001994, Val MAE=0.001611, Conf Interval=0.000137


Epoch 63: 100%|██████████| 76/76 [00:12<00:00,  6.05it/s]


Epoch 63: Train Loss=0.658641, Val Loss=4.209354, Val RMSE=0.001996, Val MAE=0.001613, Conf Interval=0.000137


Epoch 64: 100%|██████████| 76/76 [00:12<00:00,  6.07it/s]


Epoch 64: Train Loss=0.658015, Val Loss=4.217575, Val RMSE=0.001998, Val MAE=0.001614, Conf Interval=0.000138


Epoch 65: 100%|██████████| 76/76 [00:12<00:00,  6.08it/s]


Epoch 65: Train Loss=0.656973, Val Loss=4.203051, Val RMSE=0.001995, Val MAE=0.001612, Conf Interval=0.000137
Early stopping at epoch 65
Batch 0: pred shape: torch.Size([64, 50, 832]), min: -1.94140625, max: 2.498046875
Batch 0: true shape: torch.Size([64, 50, 832]), min: -4.352345943450928, max: 5.0
Batch 1: pred shape: torch.Size([64, 50, 832]), min: -2.3671875, max: 1.396484375
Batch 1: true shape: torch.Size([64, 50, 832]), min: -4.326645374298096, max: 5.0
Batch 2: pred shape: torch.Size([64, 50, 832]), min: -2.607421875, max: 1.86328125
Batch 2: true shape: torch.Size([64, 50, 832]), min: -5.0, max: 5.0
Batch 3: pred shape: torch.Size([64, 50, 832]), min: -2.037109375, max: 1.2978515625
Batch 3: true shape: torch.Size([64, 50, 832]), min: -5.0, max: 5.0
Batch 4: pred shape: torch.Size([64, 50, 832]), min: -1.9013671875, max: 1.4091796875
Batch 4: true shape: torch.Size([64, 50, 832]), min: -5.0, max: 5.0
Batch 5: pred shape: torch.Size([64, 50, 832]), min: -1.4658203125, max: 1.1