In [1]:
import os
import sys
import time
from copy import deepcopy

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from sklearn.preprocessing import StandardScaler

# =========================
# 0. 공통 설정
# =========================

def setup_seed(seed: int = 2025):
    """Seed every random number generator this notebook relies on.

    Seeds Python's built-in ``random`` module, NumPy's global RNG and PyTorch
    (CPU and every CUDA device), and configures cuDNN for reproducible runs.

    Args:
        seed: Seed value applied to every generator.
    """
    import random  # local import keeps this cell self-contained

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Per the PyTorch reproducibility notes, besides requesting deterministic
    # cuDNN algorithms, benchmarking must be off so the same convolution
    # algorithm is selected on every run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Seed every RNG before any data/model objects exist, then prefer the GPU when one is present.
setup_seed(2025)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# =========================
# 1. 실험 설정 (요청사항)
# =========================

# Datasets to run; every name is looked up in DATASET_SPECS (defined further below).
DATASET_NAMES = [
    'ETTh1',
    'ETTh2',
    'ETTm1',
    'ETTm2',
    'Electricity',
    'Weather',
    'Traffic',
]

# Look-back window lengths (seq_len) to evaluate.
INPUT_LENS = [96, 336]
# Forecast horizons (pred_len) to evaluate.
PRED_LENS = [96, 192, 336, 720]

# Hyperparameters shared by every model for a fair comparison (adjust only these if needed).
# NOTE(review): these are consumed by training code outside this view; the per-line notes
# below are inferred from the names — confirm against the training loop.
BATCH_SIZE = 32
EPOCHS = 30  # presumably the maximum number of training epochs
PATIENCE = 5  # presumably early-stopping patience in epochs
MIN_EPOCHS = 5  # presumably epochs trained before early stopping may trigger
LR = 1e-3  # presumably the optimizer learning rate
WEIGHT_DECAY = 1e-4
GRAD_CLIP = 1.0  # presumably the gradient-norm clipping threshold
USE_AMP = True  # presumably toggles mixed-precision training

# =========================
# Colab 경로 설정
# =========================
# Colab에서는 기본 작업 디렉토리가 /content 입니다.
# 데이터와 결과물이 섞이지 않게 /content/data, /content/artifacts 로 통일합니다.

# `sys` is imported in the first (imports) cell so every dependency is declared in one place.
# Colab is detected by its package having been imported into this interpreter.
IN_COLAB = 'google.colab' in sys.modules
# Colab's working directory is /content; elsewhere everything is relative to the cwd.
BASE_DIR = '/content' if IN_COLAB else '.'
DATA_DIR = os.path.join(BASE_DIR, 'data')
ARTIFACT_DIR = os.path.join(BASE_DIR, 'artifacts')

# Output folders for plots and result files.
PLOT_DIR = os.path.join(ARTIFACT_DIR, 'plots')
CSV_DIR = os.path.join(ARTIFACT_DIR, 'results')
os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(PLOT_DIR, exist_ok=True)
os.makedirs(CSV_DIR, exist_ok=True)

# =========================
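# 2. Dataset download
# =========================
# - ETT files (ETTh1/ETTh2/ETTm1/ETTm2) come from the zhouhaoyi/ETDataset GitHub repo.
# - Electricity/Weather/Traffic are NOT in that repo; they are fetched from a HuggingFace mirror
#   (see the NOTE in DATASET_SPECS). If a download fails, only the matching 'url' entry needs changing.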

# Per-dataset spec, keyed by the names used in DATASET_NAMES:
#   'filename' - CSV file name inside the data directory (used by ensure_datasets / build_dataset)
#   'url'      - source that maybe_download fetches the file from
#   'type'     - loader selector for build_dataset: 'ett_hour', 'ett_minute' or 'custom'
DATASET_SPECS = {
    'ETTh1': {
        'filename': 'ETTh1.csv',
        'url': 'https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh1.csv',
        'type': 'ett_hour',
    },
    'ETTh2': {
        'filename': 'ETTh2.csv',
        'url': 'https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh2.csv',
        'type': 'ett_hour',
    },
    'ETTm1': {
        'filename': 'ETTm1.csv',
        'url': 'https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTm1.csv',
        'type': 'ett_minute',
    },
    'ETTm2': {
        'filename': 'ETTm2.csv',
        'url': 'https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTm2.csv',
        'type': 'ett_minute',
    },
    'Electricity': {
        'filename': 'electricity.csv',
        # NOTE: zhouhaoyi/ETDataset does not contain the electricity/weather/traffic files,
        # so a HuggingFace mirror is used instead (directly downloadable, e.g. with wget, on Colab).
        'url': 'https://huggingface.co/datasets/dunzane/time-series-dataset/resolve/main/electricity/electricity.csv',
        'type': 'custom',
    },
    'Weather': {
        'filename': 'weather.csv',
        'url': 'https://huggingface.co/datasets/dunzane/time-series-dataset/resolve/main/weather/weather.csv',
        'type': 'custom',
    },
    'Traffic': {
        'filename': 'traffic.csv',
        'url': 'https://huggingface.co/datasets/dunzane/time-series-dataset/resolve/main/traffic/traffic.csv',
        'type': 'custom',
    },
}


def maybe_download(url: str, out_path: str, force: bool = False, retries: int = 3, timeout: float = 60.0):
    """Download ``url`` to ``out_path`` unless a non-empty copy already exists.

    The payload is streamed into a temporary ``*.part`` file in the destination
    directory. That file is moved onto ``out_path`` with ``os.replace`` only after
    the transfer finished and produced a non-empty file. As a result, an
    interrupted or failed download never leaves a truncated or empty
    ``out_path`` behind, and never clobbers an existing good copy when
    ``force=True``.

    Args:
        url: Source URL (anything ``urllib.request.urlopen`` accepts, e.g. https or file).
        out_path: Destination path; missing parent directories are created.
        force: Re-download even if ``out_path`` already exists and is non-empty.
        retries: Total number of attempts (values below 1 are treated as 1).
        timeout: Per-attempt socket timeout in seconds.

    Raises:
        RuntimeError: If every attempt fails or yields an empty payload.
    """
    # Local imports keep this notebook cell self-contained.
    import shutil
    import tempfile
    import urllib.request

    if (not force) and os.path.exists(out_path) and os.path.getsize(out_path) > 0:
        return

    out_dir = os.path.dirname(out_path) or '.'
    os.makedirs(out_dir, exist_ok=True)

    attempts = max(1, int(retries))
    last_err = None
    for attempt in range(attempts):
        fd, tmp_path = tempfile.mkstemp(dir=out_dir, suffix='.part')
        try:
            with os.fdopen(fd, 'wb') as dst, urllib.request.urlopen(url, timeout=timeout) as src:
                shutil.copyfileobj(src, dst)
            if os.path.getsize(tmp_path) == 0:
                raise OSError(f"empty response from {url}")
            os.replace(tmp_path, out_path)  # atomic rename within the same directory
            return
        except OSError as err:  # URLError, HTTPError and socket timeouts are all OSError subclasses
            last_err = err
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        if attempt + 1 < attempts:
            time.sleep(min(2 ** attempt, 10))  # simple exponential back-off between attempts
    raise RuntimeError(f"Failed to download {url} -> {out_path} after {attempts} attempt(s)") from last_err


def ensure_datasets(root: str = None, force: bool = False):
    """Make sure every dataset listed in DATASET_NAMES is present under ``root``.

    A file is (re)downloaded when it is missing, empty (0 bytes), or when
    ``force`` is True.

    Args:
        root: Directory holding the CSV files; falls back to DATA_DIR when falsy.
        force: Re-download every file unconditionally.

    Raises:
        RuntimeError: If a file is still missing or empty after the download step.
    """
    target_dir = root or DATA_DIR
    os.makedirs(target_dir, exist_ok=True)

    def _is_usable(path):
        # A usable dataset file exists and has at least one byte.
        return os.path.exists(path) and os.path.getsize(path) > 0

    for name in DATASET_NAMES:
        spec = DATASET_SPECS[name]
        out_path = os.path.join(target_dir, spec['filename'])

        if force or not _is_usable(out_path):
            print(f"Downloading {name} -> {out_path}")
            maybe_download(spec['url'], out_path, force=True)

        # Verify the final state whether or not a download just happened.
        if not _is_usable(out_path):
            raise RuntimeError(f"Dataset file is empty or missing: {out_path}")

        print(f"{name}: OK ({out_path})")


# Download (when missing) and verify every dataset; defaults to DATA_DIR (/content/data on Colab).
ensure_datasets()  # or: ensure_datasets('/content/data')

Using device: cuda
Downloading ETTh1 -> /content/data/ETTh1.csv
ETTh1: OK (/content/data/ETTh1.csv)
Downloading ETTh2 -> /content/data/ETTh2.csv
ETTh2: OK (/content/data/ETTh2.csv)
Downloading ETTm1 -> /content/data/ETTm1.csv
ETTm1: OK (/content/data/ETTm1.csv)
Downloading ETTm2 -> /content/data/ETTm2.csv
ETTm2: OK (/content/data/ETTm2.csv)
Downloading Electricity -> /content/data/electricity.csv
Electricity: OK (/content/data/electricity.csv)
Downloading Weather -> /content/data/weather.csv
Weather: OK (/content/data/weather.csv)
Downloading Traffic -> /content/data/traffic.csv
Traffic: OK (/content/data/traffic.csv)


In [2]:
# =========================
# 3. 데이터셋 클래스
# =========================

class Dataset_ETT_Hour(Dataset):
    """Hourly ETT dataset (ETTh1 / ETTh2) with the standard 12/4/4-month split.

    A "month" is 30 days of 24 hourly rows. The scaler is fitted on the training
    months only and then applied to the whole series. Item ``i`` is a pair
    ``(x, y)``: ``x`` holds ``seq_len`` consecutive rows starting at ``i`` and
    ``y`` the ``pred_len`` rows that immediately follow.
    """

    def __init__(
        self,
        root_path,
        flag='train',
        size=None,
        features='M',
        data_path='ETTh1.csv',
        target='OT',
        scale=True,
    ):
        # size = [seq_len, label_len, pred_len]; label_len is accepted but unused.
        self.seq_len, self.pred_len = size[0], size[2]

        splits = ('train', 'val', 'test')
        assert flag in splits
        self.set_type = splits.index(flag)

        self.features = features
        self.target = target
        self.scale = scale
        self.root_path = root_path
        self.data_path = data_path
        self.__read_data__()

    def __read_data__(self):
        self.scaler = StandardScaler()
        df_raw = pd.read_csv(os.path.join(self.root_path, self.data_path))

        rows_per_month = 30 * 24
        train_end = 12 * rows_per_month
        val_end = train_end + 4 * rows_per_month
        test_end = train_end + 8 * rows_per_month
        # val/test start seq_len rows early so their first input window can look back across the boundary.
        starts = [0, train_end - self.seq_len, val_end - self.seq_len]
        ends = [train_end, val_end, test_end]

        if self.features in ['M', 'MS']:
            df_data = df_raw[df_raw.columns[1:]]  # every column except the first (assumed timestamp)
        else:
            df_data = df_raw[[self.target]]

        if self.scale:
            self.scaler.fit(df_data[starts[0] : ends[0]].values)
            data = self.scaler.transform(df_data.values)
        else:
            data = df_data.values

        lo, hi = starts[self.set_type], ends[self.set_type]
        self.data_x = data[lo:hi]
        self.data_y = data[lo:hi]

    def __getitem__(self, index):
        # Input window [index, index + seq_len) followed directly by the target window.
        split = index + self.seq_len
        x_window = self.data_x[index:split]
        y_window = self.data_y[split : split + self.pred_len]
        return torch.FloatTensor(x_window), torch.FloatTensor(y_window)

    def __len__(self):
        return len(self.data_x) - (self.seq_len + self.pred_len - 1)

    def inverse_transform(self, data):
        return self.scaler.inverse_transform(data)


class Dataset_ETT_Minute(Dataset):
    """15-minute ETT dataset (ETTm1 / ETTm2) with the standard 12/4/4-month split.

    Same split as the hourly loader, expressed in 15-minute rows: one day is
    24 * 4 = 96 rows and a "month" is 30 such days. The scaler is fitted on the
    training months only. Item ``i`` is ``(x, y)`` with ``seq_len`` input rows
    starting at ``i`` followed by ``pred_len`` target rows.
    """

    def __init__(
        self,
        root_path,
        flag='train',
        size=None,
        features='M',
        data_path='ETTm1.csv',
        target='OT',
        scale=True,
    ):
        # size = [seq_len, label_len, pred_len]; label_len is accepted but unused.
        self.seq_len, self.pred_len = size[0], size[2]

        splits = ('train', 'val', 'test')
        assert flag in splits
        self.set_type = splits.index(flag)

        self.features = features
        self.target = target
        self.scale = scale
        self.root_path = root_path
        self.data_path = data_path
        self.__read_data__()

    def __read_data__(self):
        self.scaler = StandardScaler()
        df_raw = pd.read_csv(os.path.join(self.root_path, self.data_path))

        rows_per_month = 30 * 24 * 4  # 96 rows per day
        train_end = 12 * rows_per_month
        val_end = train_end + 4 * rows_per_month
        test_end = train_end + 8 * rows_per_month
        # val/test start seq_len rows early so their first input window can look back across the boundary.
        starts = [0, train_end - self.seq_len, val_end - self.seq_len]
        ends = [train_end, val_end, test_end]

        if self.features in ['M', 'MS']:
            df_data = df_raw[df_raw.columns[1:]]  # every column except the first (assumed timestamp)
        else:
            df_data = df_raw[[self.target]]

        if self.scale:
            self.scaler.fit(df_data[starts[0] : ends[0]].values)
            data = self.scaler.transform(df_data.values)
        else:
            data = df_data.values

        lo, hi = starts[self.set_type], ends[self.set_type]
        self.data_x = data[lo:hi]
        self.data_y = data[lo:hi]

    def __getitem__(self, index):
        # Input window [index, index + seq_len) followed directly by the target window.
        split = index + self.seq_len
        x_window = self.data_x[index:split]
        y_window = self.data_y[split : split + self.pred_len]
        return torch.FloatTensor(x_window), torch.FloatTensor(y_window)

    def __len__(self):
        return len(self.data_x) - (self.seq_len + self.pred_len - 1)

    def inverse_transform(self, data):
        return self.scaler.inverse_transform(data)


class Dataset_CustomCSV(Dataset):
    """Generic CSV dataset for Electricity / Weather / Traffic style files.

    - Only numeric columns are used (date / string columns are dropped). If no
      column is numeric, every column except the first is used as a fallback.
    - The split follows the Autoformer/Informer ``Dataset_Custom`` convention,
      which keeps test windows identical to published benchmarks:
      train = ``int(n * split_ratio[0])`` rows, test = the last
      ``int(n * split_ratio[2])`` rows, val = every row in between (so the
      rounding remainder goes to validation).
    - The scaler is fitted on the training rows only.
    - ``max_features`` optionally keeps only the first N numeric columns, which
      helps with very wide datasets.
    """

    def __init__(
        self,
        root_path,
        flag='train',
        size=None,
        data_path='electricity.csv',
        scale=True,
        split_ratio=(0.7, 0.1, 0.2),
        max_features=None,
    ):
        # size = [seq_len, label_len, pred_len]; label_len is accepted but unused.
        self.seq_len = size[0]
        self.pred_len = size[2]

        assert flag in ['train', 'val', 'test']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        self.root_path = root_path
        self.data_path = data_path
        self.scale = scale
        self.split_ratio = split_ratio
        self.max_features = max_features

        self.__read_data__()

    def __read_data__(self):
        """Load the CSV, compute split borders, fit the scaler on train rows and slice this split.

        Raises:
            ValueError: If ``split_ratio`` assigns more rows to train + test than the file has.
        """
        self.scaler = StandardScaler()
        df_raw = pd.read_csv(os.path.join(self.root_path, self.data_path))

        # Keep numeric columns only (drops date / string columns).
        df_num = df_raw.select_dtypes(include=[np.number])
        if df_num.shape[1] == 0:
            # Fallback: assume the first column is the date column.
            df_num = df_raw[df_raw.columns[1:]]

        if self.max_features is not None and df_num.shape[1] > int(self.max_features):
            df_num = df_num.iloc[:, : int(self.max_features)]

        data_all = df_num.values
        n = len(data_all)

        num_train = int(n * self.split_ratio[0])
        num_test = int(n * self.split_ratio[2])
        num_val = n - num_train - num_test  # remainder of the rounding goes to validation
        if num_val < 0:
            raise ValueError(
                f"split_ratio {self.split_ratio} assigns more than the available {n} rows to train+test"
            )

        # val/test start seq_len rows early so their first input window can look back across the boundary.
        border1s = [0, num_train - self.seq_len, n - num_test - self.seq_len]
        border2s = [num_train, num_train + num_val, n]

        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]

        if self.scale:
            train_data = data_all[border1s[0] : border2s[0]]
            self.scaler.fit(train_data)
            data = self.scaler.transform(data_all)
        else:
            data = data_all

        self.data_x = data[border1:border2]
        self.data_y = data[border1:border2]

    def __getitem__(self, index):
        s_begin = index
        s_end = s_begin + self.seq_len
        r_begin = s_end
        r_end = r_begin + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]
        return torch.FloatTensor(seq_x), torch.FloatTensor(seq_y)

    def __len__(self):
        return len(self.data_x) - self.seq_len - self.pred_len + 1

    def inverse_transform(self, data):
        return self.scaler.inverse_transform(data)


def build_dataset(
    name: str,
    root: str,
    flag: str,
    seq_len: int,
    pred_len: int,
    max_features_custom=None,
):
    """Instantiate split ``flag`` of dataset ``name`` using its entry in DATASET_SPECS.

    ETT loaders are built with their own defaults (features='M', target='OT',
    scale=True). ``max_features_custom`` is forwarded only to the generic CSV
    loader.

    Raises:
        ValueError: If the spec's ``type`` is not a known loader type.
    """
    spec = DATASET_SPECS[name]
    ds_type = spec['type']
    common = dict(root_path=root, flag=flag, size=[seq_len, 0, pred_len], data_path=spec['filename'])

    if ds_type in ('ett_hour', 'ett_minute'):
        loader_cls = Dataset_ETT_Hour if ds_type == 'ett_hour' else Dataset_ETT_Minute
        return loader_cls(**common)
    if ds_type == 'custom':
        return Dataset_CustomCSV(**common, max_features=max_features_custom)
    raise ValueError(f"Unknown dataset type: {ds_type}")

In [3]:
# =========================
# 4. 모델 정의 (요청하신 7개) - 기존 노트북 코드 그대로 이식
# =========================

# -------------------------------------------------
# (A) DecompOpTCN+TCN+Basic (1).ipynb 에 있던 4개
#   - TCN
#   - DecompOpTCN
#   - Decomp-OpTCN-v3(patch)
#   - DLinear
# -------------------------------------------------


class Chomp1d(nn.Module):
    """Drop the last ``chomp_size`` time steps of a ``[B, C, L]`` tensor.

    Used after a symmetrically padded Conv1d to keep only the causal part of its
    output. ``chomp_size == 0`` is a no-op. Without that guard, ``x[..., :-0]``
    would return an EMPTY tensor, e.g. for ``kernel_size=1`` where the padding
    is 0.
    """

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        if self.chomp_size == 0:
            return x.contiguous()
        return x[:, :, :-self.chomp_size].contiguous()


class RevIN(nn.Module):
    """Reversible instance normalization.

    ``forward(x, 'norm')`` works in three steps:
    - It computes each instance's mean and std over every axis except the first
      (batch) and last (features), i.e. the time axis for ``[B, L, C]`` input.
    - It stores those statistics (detached) and standardizes ``x``.
    - It optionally applies a learnable per-feature affine transform.

    ``forward(y, 'denorm')`` undoes the affine transform and restores the stored
    statistics, so it must follow a ``'norm'`` call on the matching batch. Any
    other ``mode`` returns the input unchanged.

    Args:
        num_features: Size of the last (feature) axis.
        eps: Variance floor and division stabilizer.
        affine: Whether to learn a per-feature scale (``affine_weight``) and bias (``affine_bias``).
    """

    def __init__(self, num_features: int, eps=1e-5, affine=True):
        super(RevIN, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        if self.affine:
            self._init_params()

    def _init_params(self):
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def forward(self, x, mode: str):
        if mode == 'norm':
            self._get_statistics(x)
            x = self._normalize(x)
        elif mode == 'denorm':
            x = self._denormalize(x)
        return x

    def _get_statistics(self, x):
        # Reduce over all axes between batch (0) and features (-1).
        dim2reduce = tuple(range(1, x.ndim - 1))
        self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
        self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()

    def _normalize(self, x):
        x = x - self.mean
        x = x / self.stdev
        if self.affine:
            x = x * self.affine_weight + self.affine_bias
        return x

    def _denormalize(self, x):
        if self.affine:
            x = x - self.affine_bias
            # Additive eps**2 keeps the division finite even when a learned weight reaches 0,
            # as in the reference RevIN implementation. The previous `w + eps * w` is zero
            # whenever w is zero, giving 0/0 = NaN.
            x = x / (self.affine_weight + self.eps * self.eps)
        x = x * self.stdev
        x = x + self.mean
        return x


class MovingAverage(nn.Module):
    """Edge-padded moving average along the time axis of a ``[B, L, C]`` tensor.

    The series is padded by repeating its first and last values, so with
    ``stride=1`` the output has the same length ``L`` as the input, for odd and
    even ``kernel_size`` alike. For an even kernel the extra padding step goes
    at the end. Symmetric ``(k-1)//2`` padding would make the output one step
    shorter and break ``SeriesDecomp``'s ``x - trend``.
    """

    def __init__(self, kernel_size, stride):
        super(MovingAverage, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # x: [Batch, Length, Channel] -> [Batch, Channel, Length]
        x = x.permute(0, 2, 1)
        pad_front = (self.kernel_size - 1) // 2
        pad_end = self.kernel_size - 1 - pad_front  # equals pad_front for odd kernels
        front = x[:, :, 0:1].repeat(1, 1, pad_front)
        end = x[:, :, -1:].repeat(1, 1, pad_end)
        x = torch.cat([front, x, end], dim=-1)
        x = self.avg(x)
        x = x.permute(0, 2, 1)
        return x


class SeriesDecomp(nn.Module):
    """Split a ``[B, L, C]`` series into ``(seasonal, trend)``.

    ``trend`` is the moving average (window ``kernel_size``), and ``seasonal`` is
    the residual ``x - trend``.
    """

    def __init__(self, kernel_size):
        super(SeriesDecomp, self).__init__()
        self.moving_avg = MovingAverage(kernel_size, stride=1)

    def forward(self, x):
        trend = self.moving_avg(x)
        seasonal = x - trend
        return seasonal, trend


class OperatorMixer(nn.Module):
    """Fixed causal operator bank followed by a learned 1x1 channel mixer.

    Each input channel is filtered by four fixed kernels: 1st difference, 2nd
    difference, ``kernel_size``-tap moving average, and seasonal difference
    ``x[t] - x[t - period]``. Each filter is depthwise, causal and dilated by
    ``dilation``, so the seasonal lag is ``period * dilation``. The
    ``4 * in_channels`` results are mixed into ``out_channels`` by a 1x1 Conv1d.
    Input ``[B, C, L]`` -> output ``[B, out_channels, L]``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, period=24):
        super(OperatorMixer, self).__init__()
        self.dilation = dilation
        self.kernel_size = kernel_size

        # Causal padding size; forward() trims its own padding, these are kept as attributes.
        self.padding = (kernel_size - 1) * dilation
        self.chomp = Chomp1d(self.padding)

        # Fixed (non-trainable) kernels, registered as buffers so they follow the module's device.
        self.register_buffer('diff1_kernel', torch.tensor([-1.0, 1.0]).view(1, 1, 2))
        self.register_buffer('diff2_kernel', torch.tensor([1.0, -2.0, 1.0]).view(1, 1, 3))
        self.register_buffer('ma_kernel', torch.ones(1, 1, kernel_size) / kernel_size)
        seasonal = torch.zeros(1, 1, period + 1)
        seasonal[0, 0, 0] = -1.0  # oldest tap
        seasonal[0, 0, -1] = 1.0  # current tap
        self.register_buffer('seasonal_kernel', seasonal)

        # diff1, diff2, moving average and seasonal outputs are concatenated -> in_channels * 4
        self.mixer = nn.Conv1d(in_channels * 4, out_channels, 1)

    def _causal_depthwise(self, x, kernel, channels):
        """Apply one shared 1-D ``kernel`` to every channel causally (output length == input length)."""
        lag = (kernel.shape[-1] - 1) * self.dilation
        y = F.conv1d(x, kernel.repeat(channels, 1, 1), padding=lag, dilation=self.dilation, groups=channels)
        return y[:, :, :-lag] if lag > 0 else y

    def forward(self, x):
        _, channels, _ = x.shape  # x: [B, C, L]
        kernels = (self.diff1_kernel, self.diff2_kernel, self.ma_kernel, self.seasonal_kernel)
        features = torch.cat([self._causal_depthwise(x, k, channels) for k in kernels], dim=1)
        return self.mixer(features)


# Model 1: TCN (Baseline) + RevIN
class TemporalBlock(nn.Module):
    """Residual block of two causal dilated convolutions (TCN block with BatchNorm).

    Each half is Conv1d -> Chomp1d (drop the right-hand padding so the conv stays
    causal) -> BatchNorm1d -> ReLU -> Dropout. When the channel count changes, a
    1x1 conv projects the residual. The sum goes through a final ReLU.
    """

    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super(TemporalBlock, self).__init__()
        conv_kw = dict(stride=stride, padding=padding, dilation=dilation)

        # Submodules stay as named attributes and are registered in exactly this order,
        # so state_dict keys and parameter order are stable.
        self.conv1 = nn.Conv1d(n_inputs, n_outputs, kernel_size, **conv_kw)
        self.chomp1 = Chomp1d(padding)
        self.bn1 = nn.BatchNorm1d(n_outputs)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = nn.Conv1d(n_outputs, n_outputs, kernel_size, **conv_kw)
        self.chomp2 = Chomp1d(padding)
        self.bn2 = nn.BatchNorm1d(n_outputs)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        stage_names = (
            'conv1', 'chomp1', 'bn1', 'relu1', 'dropout1',
            'conv2', 'chomp2', 'bn2', 'relu2', 'dropout2',
        )
        self.net = nn.Sequential(*(getattr(self, stage) for stage in stage_names))
        self.downsample = None if n_inputs == n_outputs else nn.Conv1d(n_inputs, n_outputs, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        residual = x if self.downsample is None else self.downsample(x)
        return self.relu(self.net(x) + residual)


class TCN_Baseline(nn.Module):
    """TCN forecaster with RevIN.

    The model stacks TemporalBlocks, then applies a linear map over time
    (``seq_len -> pred_len``) and one over channels
    (``num_channels[-1] -> num_inputs``). Block ``i`` uses dilation ``2**i`` with
    causal padding ``(kernel_size - 1) * 2**i``. The time axis of the input
    (dim 1) must have length ``seq_len``.
    """

    def __init__(self, num_inputs, seq_len, pred_len, num_channels=[32] * 4, kernel_size=5, dropout=0.3):
        super(TCN_Baseline, self).__init__()
        self.revin = RevIN(num_inputs)
        widths = [num_inputs] + list(num_channels)
        blocks = []
        for level, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            dilation = 2**level
            blocks.append(
                TemporalBlock(
                    c_in,
                    c_out,
                    kernel_size,
                    stride=1,
                    dilation=dilation,
                    padding=(kernel_size - 1) * dilation,
                    dropout=dropout,
                )
            )
        self.network = nn.Sequential(*blocks)
        self.linear_time = nn.Linear(seq_len, pred_len)
        self.projection = nn.Linear(num_channels[-1], num_inputs)

    def forward(self, x):
        z = self.revin(x, 'norm').permute(0, 2, 1)  # [B, C, L]
        z = self.linear_time(self.network(z))  # [B, C_hidden, pred_len]
        z = self.projection(z.permute(0, 2, 1))  # [B, pred_len, C]
        return self.revin(z, 'denorm')


# Model 2: DLinear (Baseline) + RevIN
class DLinear(nn.Module):
    """DLinear baseline with RevIN: moving-average decomposition plus one linear map per component.

    Seasonal and trend parts each get their own ``nn.Linear(seq_len, pred_len)``
    applied along time and shared across channels. The two forecasts are summed.
    """

    def __init__(self, seq_len, pred_len, num_inputs):
        super(DLinear, self).__init__()
        self.revin = RevIN(num_inputs)
        self.decompsition = SeriesDecomp(25)  # (sic) public attribute name kept unchanged
        self.Linear_Seasonal = nn.Linear(seq_len, pred_len)
        self.Linear_Trend = nn.Linear(seq_len, pred_len)

    def forward(self, x):
        seasonal, trend = self.decompsition(self.revin(x, 'norm'))
        # The linear layers act on the last axis, so move time there: [B, L, C] -> [B, C, L].
        forecast = self.Linear_Seasonal(seasonal.transpose(1, 2)) + self.Linear_Trend(trend.transpose(1, 2))
        return self.revin(forecast.transpose(1, 2), 'denorm')


# Model 3: Decomp-OpTCN (Proposed)
class OpTCNBlock(nn.Module):
    """Residual block with two OperatorMixer -> BatchNorm1d -> ReLU -> Dropout stages.

    When the channel count changes, a 1x1 conv projects the residual. The sum
    goes through a final ReLU.
    """

    def __init__(self, n_inputs, n_outputs, kernel_size, dilation, dropout=0.2):
        super(OpTCNBlock, self).__init__()
        # Stage 1
        self.op_mixer1 = OperatorMixer(n_inputs, n_outputs, kernel_size, dilation)
        self.bn1 = nn.BatchNorm1d(n_outputs)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)
        # Stage 2
        self.op_mixer2 = OperatorMixer(n_outputs, n_outputs, kernel_size, dilation)
        self.bn2 = nn.BatchNorm1d(n_outputs)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.downsample = None if n_inputs == n_outputs else nn.Conv1d(n_inputs, n_outputs, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = x
        stages = (
            (self.op_mixer1, self.bn1, self.relu1, self.dropout1),
            (self.op_mixer2, self.bn2, self.relu2, self.dropout2),
        )
        for mixer, norm, act, drop in stages:
            out = drop(act(norm(mixer(out))))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(out + shortcut)


class DecompOpTCN(nn.Module):
    """Decomposition forecaster: linear trend head plus an operator-TCN seasonal branch, wrapped in RevIN.

    The trend is forecast by one ``nn.Linear(seq_len, pred_len)`` along time. The
    seasonal part goes through stacked OpTCNBlocks (dilation ``2**i``), a linear
    map over time and a projection back to ``num_inputs`` channels.
    """

    def __init__(self, num_inputs, seq_len, pred_len, num_channels=[32] * 4, kernel_size=5, dropout=0.3):
        super(DecompOpTCN, self).__init__()
        self.revin = RevIN(num_inputs)
        self.decomp = SeriesDecomp(kernel_size=25)

        widths = [num_inputs] + list(num_channels)
        self.seasonal_backbone = nn.Sequential(
            *[
                OpTCNBlock(c_in, c_out, kernel_size, dilation=2**level, dropout=dropout)
                for level, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]))
            ]
        )

        self.seasonal_head = nn.Linear(seq_len, pred_len)
        self.seasonal_proj = nn.Linear(num_channels[-1], num_inputs)
        self.trend_head = nn.Linear(seq_len, pred_len)

    def forward(self, x):
        seasonal, trend = self.decomp(self.revin(x, 'norm'))

        # Trend: [B, L, C] -> [B, C, L] -> [B, C, H] -> [B, H, C]
        trend_pred = self.trend_head(trend.transpose(1, 2)).transpose(1, 2)

        # Seasonal: TCN features over time, linear map to the horizon, then project hidden channels back.
        features = self.seasonal_backbone(seasonal.transpose(1, 2))
        seasonal_pred = self.seasonal_proj(self.seasonal_head(features).transpose(1, 2))

        return self.revin(seasonal_pred + trend_pred, 'denorm')


# ---- Decomp-OpTCN-v3(Patch) 의존 모듈들 (원본 그대로) ----

class SpectralOperator(nn.Module):
    """Low-pass filter along the last axis: keep the first ``k`` rFFT bins and zero the rest.

    Returns a smooth ("global") version of a ``[B, C, L]`` input with the same
    shape and dtype. The FFT runs in float32 (the input is upcast and the result
    cast back), presumably so half-precision activations under AMP are handled.
    """

    def __init__(self, k=20):
        super().__init__()
        self.k = k

    def forward(self, x):
        _, _, seq_len = x.shape  # x: [B, C, L]
        orig_dtype = x.dtype

        spectrum = torch.fft.rfft(x.float(), dim=-1)
        spectrum[:, :, self.k:] = 0  # no-op when the spectrum has <= k bins
        smooth = torch.fft.irfft(spectrum, n=seq_len, dim=-1)
        return smooth.to(dtype=orig_dtype)


class SEScale(nn.Module):
    """Squeeze-and-excitation channel gating for ``[B, C, L]`` tensors.

    Each channel is averaged over time and passed through a bias-free
    bottleneck MLP (``C -> max(1, C // reduction) -> C``, ReLU then Sigmoid). The
    resulting per-channel weight in (0, 1) rescales that channel.
    """

    def __init__(self, channels, reduction=4):
        super().__init__()
        bottleneck = max(1, channels // reduction)
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, bottleneck, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels, _ = x.size()
        squeezed = self.avg_pool(x).reshape(batch, channels)  # [B, C]
        weights = self.fc(squeezed)  # [B, C], each in (0, 1)
        return x * weights.unsqueeze(-1)


class AdaptiveOperatorMixerV2(nn.Module):
    """Gated bank of fixed causal operators plus an optional FFT low-pass, mixed by a 1x1 conv.

    Each branch produces a ``[B, C, L]`` tensor. The branches are concatenated
    on the channel axis in this order:
    - identity;
    - 1st difference;
    - 2nd difference;
    - ``kernel_size``-tap moving average;
    - one seasonal difference per entry of ``periods``;
    - if ``use_spectral``, ``SpectralOperator(k=spectral_k)``.

    The difference, average and seasonal branches are depthwise causal
    convolutions dilated by ``dilation``. Every branch is scaled by its own
    learnable gate (``op_gates``, initialised to 1). The stack is then
    re-weighted by ``SEScale`` and projected to ``out_channels``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation=1,
        periods=(24, 168),
        use_spectral=True,
        spectral_k=20,
        se_reduction=4,
    ):
        super().__init__()
        self.dilation = dilation
        self.kernel_size = kernel_size
        self.periods = tuple(periods)
        self.use_spectral = use_spectral

        # Fixed kernels (buffers: not trained, but they move with the module).
        self.register_buffer('diff1_kernel', torch.tensor([-1.0, 1.0]).view(1, 1, 2))
        self.register_buffer('diff2_kernel', torch.tensor([1.0, -2.0, 1.0]).view(1, 1, 3))
        self.register_buffer('ma_kernel', torch.ones(1, 1, kernel_size) / kernel_size)

        for idx, lag in enumerate(self.periods):
            seasonal = torch.zeros(1, 1, lag + 1)
            seasonal[0, 0, 0] = -1.0
            seasonal[0, 0, lag] = 1.0
            self.register_buffer(f's_kernel_{idx}', seasonal)

        if self.use_spectral:
            self.spectral_op = SpectralOperator(k=spectral_k)

        # identity + (diff1, diff2) + moving average + one per period + optional spectral
        self.total_ops = 4 + len(self.periods) + (1 if self.use_spectral else 0)
        self.op_gates = nn.Parameter(torch.ones(self.total_ops))

        self.se_block = SEScale(in_channels * self.total_ops, reduction=se_reduction)
        self.mixer = nn.Conv1d(in_channels * self.total_ops, out_channels, 1)

    def _causal_depthwise(self, x, kernel, channels):
        """Apply one shared 1-D ``kernel`` to every channel causally (output length == input length)."""
        lag = (kernel.shape[-1] - 1) * self.dilation
        y = F.conv1d(x, kernel.repeat(channels, 1, 1), padding=lag, dilation=self.dilation, groups=channels)
        return y[:, :, :-lag] if lag > 0 else y

    def forward(self, x):
        _, channels, _ = x.shape  # x: [B, C, L]

        kernels = [self.diff1_kernel, self.diff2_kernel, self.ma_kernel]
        kernels += [getattr(self, f's_kernel_{idx}') for idx in range(len(self.periods))]

        branches = [x] + [self._causal_depthwise(x, kernel, channels) for kernel in kernels]
        if self.use_spectral:
            branches.append(self.spectral_op(x))

        stacked = torch.cat(branches, dim=1)  # [B, C * total_ops, L]

        # Branch b occupies channels [b*C, (b+1)*C), so each gate is repeated C times.
        gate = self.op_gates.repeat_interleave(channels).to(stacked.dtype)
        stacked = stacked * gate.view(1, -1, 1)

        return self.mixer(self.se_block(stacked))


class OpTCNBlockV2(nn.Module):
    """Residual block with two stages of operator-mixer -> GroupNorm -> GELU -> Dropout.

    Each stage is followed by a residual 1x1-conv feed-forward (``C -> 2C -> C``).
    The block input is added back at the end, through a 1x1 conv when the
    channel count changes. There is no final activation.
    """

    def __init__(self, n_inputs, n_outputs, kernel_size, dilation, dropout=0.2, gn_groups=8):
        super().__init__()

        def feed_forward():
            # Position-wise expansion MLP built from 1x1 convolutions.
            return nn.Sequential(
                nn.Conv1d(n_outputs, 2 * n_outputs, 1),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Conv1d(2 * n_outputs, n_outputs, 1),
            )

        groups = min(gn_groups, n_outputs)

        # Stage 1
        self.op_mixer1 = AdaptiveOperatorMixerV2(n_inputs, n_outputs, kernel_size, dilation=dilation)
        self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=n_outputs)
        self.act1 = nn.GELU()
        self.drop1 = nn.Dropout(dropout)
        self.ffn1 = feed_forward()

        # Stage 2
        self.op_mixer2 = AdaptiveOperatorMixerV2(n_outputs, n_outputs, kernel_size, dilation=dilation)
        self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=n_outputs)
        self.act2 = nn.GELU()
        self.drop2 = nn.Dropout(dropout)
        self.ffn2 = feed_forward()

        self.res_conv = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else nn.Identity()

    def forward(self, x):
        shortcut = self.res_conv(x)

        h = x
        stages = (
            (self.op_mixer1, self.norm1, self.act1, self.drop1, self.ffn1),
            (self.op_mixer2, self.norm2, self.act2, self.drop2, self.ffn2),
        )
        for mixer, norm, act, drop, ffn in stages:
            h = drop(act(norm(mixer(h))))
            h = h + ffn(h)

        return h + shortcut


class PatchingLayer(nn.Module):
    """Split the last axis into overlapping patches: ``[B, C, L] -> [B, C, N, P]``.

    ``P = patch_len`` and consecutive patches start ``stride`` steps apart. A
    trailing remainder shorter than ``patch_len`` is dropped, giving
    ``N = (L - patch_len) // stride + 1``.
    """

    def __init__(self, patch_len=16, stride=8):
        super().__init__()
        self.patch_len = patch_len
        self.stride = stride

    def forward(self, x):
        return x.unfold(-1, self.patch_len, self.stride)

    def num_patches(self, seq_len):
        span = seq_len - self.patch_len
        return span // self.stride + 1


class ChannelCorrMixer(nn.Module):
    """Channel-to-channel attention on ``[B, C, T]`` tensors, with a residual connection.

    Each channel is summarised by its mean over ``T``. That summary is turned into
    query and key embeddings (``[B, C, d]``). Their scaled dot products give a
    ``[B, C, C]`` softmax attention (with dropout), which mixes the channels.
    The input is added back at the end.
    """

    def __init__(self, channels, d=32, dropout=0.1):
        super().__init__()
        self.q = nn.Linear(channels, d)
        self.k = nn.Linear(channels, d)
        self.dropout = nn.Dropout(dropout)
        self.scale = d ** -0.5

    def forward(self, x):
        summary = x.mean(dim=-1)  # [B, C]
        emb_q = summary.unsqueeze(-1) * self.q(summary).unsqueeze(1)  # [B, C, d]
        emb_k = summary.unsqueeze(-1) * self.k(summary).unsqueeze(1)  # [B, C, d]

        attn = torch.softmax(emb_q @ emb_k.transpose(1, 2) * self.scale, dim=-1)  # [B, C, C]
        attn = self.dropout(attn)

        return torch.bmm(attn, x) + x


class DecompOpTCN_V3_Patch(nn.Module):
    """Patch-based Decomp-OpTCN (v3) with RevIN.

    Trend branch: a two-layer MLP along time,
    ``seq_len -> max(16, seq_len // 4) -> pred_len``.

    Seasonal branch:
    - Each channel is cut into patches (``PatchingLayer``), and every patch is
      reduced to one scalar token (``patch_len -> patch_embed_dim -> 1``).
    - Tokens are mixed across channels (``ChannelCorrMixer``) and processed by
      stacked ``OpTCNBlockV2`` (dilation ``2**i``).
    - A linear map takes the ``n_patches`` tokens to ``pred_len`` steps, and a
      projection goes back to ``num_inputs`` channels.

    The two branches are summed and de-normalised.
    """

    def __init__(
        self,
        num_inputs,
        seq_len,
        pred_len,
        num_channels=[32] * 4,
        kernel_size=5,
        dropout=0.3,
        patch_len=8,
        patch_stride=4,
        patch_embed_dim=8,
    ):
        super().__init__()
        self.revin = RevIN(num_inputs)
        self.decomp = SeriesDecomp(kernel_size=25)

        # Patch tokenizer.
        self.patching = PatchingLayer(patch_len=patch_len, stride=patch_stride)
        self.n_patches = self.patching.num_patches(seq_len)
        self.patch_embed = nn.Linear(patch_len, patch_embed_dim)
        self.patch_reduce = nn.Linear(patch_embed_dim, 1)

        self.channel_corr = ChannelCorrMixer(num_inputs, d=32, dropout=0.1)

        widths = [num_inputs] + list(num_channels)
        self.seasonal_backbone = nn.Sequential(
            *[
                OpTCNBlockV2(c_in, c_out, kernel_size, dilation=2**level, dropout=dropout)
                for level, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]))
            ]
        )

        self.seasonal_head = nn.Linear(self.n_patches, pred_len)
        self.seasonal_proj = nn.Linear(num_channels[-1], num_inputs)

        trend_hidden = max(16, seq_len // 4)
        self.trend_head = nn.Sequential(
            nn.Linear(seq_len, trend_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(trend_hidden, pred_len),
        )

    def forward(self, x):
        seasonal, trend = self.decomp(self.revin(x, 'norm'))

        # Trend: MLP along time, [B, L, C] -> [B, H, C].
        trend_pred = self.trend_head(trend.transpose(1, 2)).transpose(1, 2)

        # Seasonal: patchify every channel and collapse each patch to a single token.
        patches = self.patching(seasonal.transpose(1, 2))  # [B, C, N, P]
        tokens = self.patch_reduce(self.patch_embed(patches)).squeeze(-1)  # [B, C, N]
        tokens = self.channel_corr(tokens)  # [B, C, N]
        features = self.seasonal_backbone(tokens)  # [B, C_hidden, N]
        seasonal_pred = self.seasonal_proj(self.seasonal_head(features).transpose(1, 2))  # [B, H, C]

        return self.revin(seasonal_pred + trend_pred, 'denorm')


# -------------------------------------------------
# (B) ModernTCN2.ipynb 에 있던 3개
#   - ModernTCN
#   - ModernTCN+Decomp
#   - ModernTCN(seasonal)+Linear(trend)
# -------------------------------------------------


def _make_odd(k: int) -> int:
    return k if (k % 2 == 1) else (k + 1)


class moving_avg(nn.Module):
    """Moving average along the time axis with edge replication.

    Input and output are laid out as ``[B, L, C]`` (time on dim 1). The series is
    padded by repeating its first and last time steps for a total of
    ``kernel_size - 1`` extra steps. With ``stride=1`` the output therefore has
    the same length ``L`` as the input, for odd and even kernel sizes alike.

    Args:
        kernel_size: averaging window length.
        stride: pooling stride (``1`` keeps the length unchanged).
    """

    def __init__(self, kernel_size, stride=1):
        super().__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # x: [B, L, C]
        # Split the (kernel_size - 1) padding steps between both ends. For odd
        # kernels the split is symmetric, so the result is unchanged. For even
        # kernels the extra step goes to the end. The previous symmetric
        # `(k - 1) // 2` on both sides produced L - 1 outputs, which broke
        # `x - moving_mean` in series_decomp.
        pad_front = (self.kernel_size - 1) // 2
        pad_end = self.kernel_size - 1 - pad_front
        front = x[:, 0:1, :].repeat(1, pad_front, 1)
        end = x[:, -1:, :].repeat(1, pad_end, 1)
        x = torch.cat([front, x, end], dim=1)
        # AvgPool1d pools over the last axis, so temporarily move time there.
        x = self.avg(x.permute(0, 2, 1)).permute(0, 2, 1)
        return x


class series_decomp(nn.Module):
    """Split a series into ``(residual, trend)`` using a stride-1 moving average.

    ``forward(x)`` returns ``(x - trend, trend)``, where ``trend`` is
    ``moving_avg(kernel_size, stride=1)`` applied to ``x``.
    """

    def __init__(self, kernel_size):
        super().__init__()
        # Attribute name is part of the module tree / state_dict path; keep it.
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        trend = self.moving_avg(x)
        return x - trend, trend


class Flatten_Head(nn.Module):
    """Forecast head: flatten the last two axes and project linearly to ``target_window``.

    ``forward`` expects ``x`` shaped ``[B, n_vars, d_model, patch_num]`` and
    returns ``[B, n_vars, target_window]``. ``nf`` must equal
    ``d_model * patch_num``, the size of the flattened last two axes.

    Args:
        individual: if True, each variable gets its own Flatten/Linear/Dropout.
            Otherwise one shared set is applied to all variables.
        n_vars: number of variables (only used when ``individual`` is True).
        nf: flattened feature size fed to the Linear layer(s).
        target_window: forecast horizon.
        head_dropout: dropout probability after the projection.
    """

    def __init__(self, individual, n_vars, nf, target_window, head_dropout=0.0):
        super().__init__()
        self.individual = individual
        self.n_vars = n_vars

        if individual:
            # Attribute names form the state_dict keys; Linear layers are created
            # in variable order so their random init is reproducible.
            self.linears = nn.ModuleList([nn.Linear(nf, target_window) for _ in range(n_vars)])
            self.dropouts = nn.ModuleList([nn.Dropout(head_dropout) for _ in range(n_vars)])
            self.flattens = nn.ModuleList([nn.Flatten(start_dim=-2) for _ in range(n_vars)])
        else:
            self.flatten = nn.Flatten(start_dim=-2)
            self.linear = nn.Linear(nf, target_window)
            self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):
        # x: [B, nvars, d_model, patch_num]
        if not self.individual:
            return self.dropout(self.linear(self.flatten(x)))

        per_var = [
            self.dropouts[v](self.linears[v](self.flattens[v](x[:, v, :, :])))
            for v in range(self.n_vars)
        ]
        return torch.stack(per_var, dim=1)


class RevIN_Official(nn.Module):
    """Reversible instance normalization (RevIN).

    ``forward(x, 'norm')`` stores per-sample statistics of ``x`` and returns the
    normalized tensor. ``forward(x, 'denorm')`` inverts that transform using the
    statistics stored by the most recent ``'norm'`` call. Statistics are reduced
    over every axis except the first and the last, e.g. the time axis of a
    ``[B, L, C]`` tensor.

    Args:
        num_features: size of the last axis (length of the affine parameters).
        eps: added to the variance before the square root.
        affine: apply a learnable per-feature scale and shift after normalizing.
        subtract_last: center on the last time step instead of the mean.

    Raises:
        NotImplementedError: for a ``mode`` other than ``'norm'`` / ``'denorm'``.
    """

    def __init__(self, num_features: int, eps=1e-5, affine=True, subtract_last=False):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.subtract_last = subtract_last
        if affine:
            # Parameter names are part of the state_dict; keep them stable.
            self.affine_weight = nn.Parameter(torch.ones(num_features))
            self.affine_bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x, mode: str):
        if mode == 'norm':
            self._get_statistics(x)
            return self._normalize(x)
        if mode == 'denorm':
            return self._denormalize(x)
        raise NotImplementedError

    def _get_statistics(self, x):
        reduce_dims = tuple(range(1, x.ndim - 1))
        if self.subtract_last:
            self.last = x[:, -1, :].unsqueeze(1)
        else:
            self.mean = torch.mean(x, dim=reduce_dims, keepdim=True).detach()
        variance = torch.var(x, dim=reduce_dims, keepdim=True, unbiased=False)
        self.stdev = torch.sqrt(variance + self.eps).detach()

    def _center(self):
        # Reference point removed in 'norm' and added back in 'denorm'.
        return self.last if self.subtract_last else self.mean

    def _normalize(self, x):
        x = (x - self._center()) / self.stdev
        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        if self.affine:
            # eps * eps keeps the division finite if the learned scale reaches 0.
            x = (x - self.affine_bias) / (self.affine_weight + self.eps * self.eps)
        x = x * self.stdev
        return x + self._center()


def _get_conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias):
    return nn.Conv1d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        bias=bias,
    )


def _conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups, dilation=1, bias=False):
    if padding is None:
        padding = kernel_size // 2
    return nn.Sequential(
        _get_conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=bias),
        nn.BatchNorm1d(out_channels),
    )


def _fuse_bn(conv: nn.Conv1d, bn: nn.BatchNorm1d):
    kernel = conv.weight
    running_mean = bn.running_mean
    running_var = bn.running_var
    gamma = bn.weight
    beta = bn.bias
    eps = bn.eps
    std = (running_var + eps).sqrt()
    t = (gamma / std).reshape(-1, 1, 1)
    return kernel * t, beta - running_mean * gamma / std


def _pad_two_edge_1d(w, pad_left: int, pad_right: int):
    if pad_left == 0 and pad_right == 0:
        return w
    left = torch.zeros(w.shape[0], w.shape[1], pad_left, device=w.device, dtype=w.dtype)
    right = torch.zeros(w.shape[0], w.shape[1], pad_right, device=w.device, dtype=w.dtype)
    return torch.cat([left, w, right], dim=-1)


class ReparamLargeKernelConv(nn.Module):
    """Large-kernel 1-D conv with an optional small-kernel branch that can be merged into one conv.

    Unmerged form: ``forward`` returns ``BN(Conv_k(x))`` (``lkb_origin``) plus, when
    ``small_kernel`` is not None, ``BN(Conv_s(x))`` (``small_conv``).
    ``merge_kernel()`` folds both branches and their BatchNorms into a single
    biased ``nn.Conv1d`` stored as ``lkb_reparam`` and deletes the originals.
    Passing ``small_kernel_merged=True`` builds ``lkb_reparam`` directly instead.

    Args:
        in_channels, out_channels, stride, groups: forwarded to every conv.
        kernel_size: large kernel size, rounded up to the next odd value.
        small_kernel: small kernel size (rounded up to odd), or None for no branch.
        small_kernel_merged: build only the single merged conv.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        groups,
        small_kernel,
        small_kernel_merged=False,
    ):
        super().__init__()
        # Odd sizes make `k // 2` padding length-preserving at stride 1 and let the
        # small kernel be centred inside the large one when merging.
        kernel_size = _make_odd(int(kernel_size))
        small_kernel = _make_odd(int(small_kernel)) if small_kernel is not None else None

        self.kernel_size = kernel_size
        self.small_kernel = small_kernel

        padding = kernel_size // 2
        if small_kernel_merged:
            # Merged form: one conv with bias, no BatchNorm.
            self.lkb_reparam = nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=1,
                groups=groups,
                bias=True,
            )
        else:
            # Unmerged form: large-kernel conv (no bias) followed by BatchNorm.
            self.lkb_origin = _conv_bn(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                groups=groups,
                dilation=1,
                bias=False,
            )
            if small_kernel is not None:
                assert small_kernel <= kernel_size
                # Parallel small-kernel conv + BN. With both kernels odd and "half"
                # padding, its output length matches the large branch, so the two can be summed.
                self.small_conv = _conv_bn(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=small_kernel,
                    stride=stride,
                    padding=small_kernel // 2,
                    groups=groups,
                    dilation=1,
                    bias=False,
                )

    def forward(self, x):
        # The merged conv takes precedence once it exists.
        if hasattr(self, 'lkb_reparam'):
            return self.lkb_reparam(x)
        out = self.lkb_origin(x)
        if hasattr(self, 'small_conv'):
            out = out + self.small_conv(x)
        return out

    def get_equivalent_kernel_bias(self):
        """Return ``(kernel, bias)`` of one conv equivalent to the unmerged branches.

        Uses each BatchNorm's running statistics (via ``_fuse_bn``), so the result
        matches the module's eval-mode output.
        """
        eq_k, eq_b = _fuse_bn(self.lkb_origin[0], self.lkb_origin[1])
        if hasattr(self, 'small_conv'):
            small_k, small_b = _fuse_bn(self.small_conv[0], self.small_conv[1])
            eq_b = eq_b + small_b
            # Both sizes are odd, so the width difference splits evenly on each side.
            pad = (self.kernel_size - self.small_kernel) // 2
            eq_k = eq_k + _pad_two_edge_1d(small_k, pad, pad)
        return eq_k, eq_b

    @torch.no_grad()
    def merge_kernel(self):
        """Replace the unmerged branches with a single equivalent ``lkb_reparam`` conv (in place).

        No-op if already merged. Because BN running statistics are folded in,
        this is presumably meant to be called after training, in eval mode.
        """
        if hasattr(self, 'lkb_reparam'):
            return
        eq_k, eq_b = self.get_equivalent_kernel_bias()
        conv: nn.Conv1d = self.lkb_origin[0]
        self.lkb_reparam = nn.Conv1d(
            in_channels=conv.in_channels,
            out_channels=conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            groups=conv.groups,
            bias=True,
        )
        # Assigning .data also carries over the device/dtype of the fused tensors.
        self.lkb_reparam.weight.data = eq_k
        self.lkb_reparam.bias.data = eq_b
        delattr(self, 'lkb_origin')
        if hasattr(self, 'small_conv'):
            delattr(self, 'small_conv')

class ModernTCN_Block_Official(nn.Module):
    """One ModernTCN block: depthwise large-kernel conv -> BN -> two grouped ConvFFNs, plus a residual.

    ``forward`` takes and returns ``[B, M, D, N]`` with M = ``nvars`` and
    D = ``dmodel``. ConvFFN1 (``groups=nvars``) mixes the D features inside each
    variable. ConvFFN2 (``groups=dmodel``, applied after swapping the M and D
    axes) mixes the M variables inside each feature.

    Args:
        large_size, small_size: kernel sizes of the depthwise ReparamLargeKernelConv.
        dmodel: per-variable feature width D.
        dff: per-variable hidden width of the ConvFFNs.
        nvars: number of variables M.
        small_kernel_merged: build the depthwise conv in merged form.
        drop: dropout probability inside both ConvFFNs.
    """

    def __init__(self, large_size, small_size, dmodel, dff, nvars, small_kernel_merged=False, drop=0.1):
        super().__init__()
        # Depthwise: one group per (variable, feature) channel.
        self.dw = ReparamLargeKernelConv(
            in_channels=nvars * dmodel,
            out_channels=nvars * dmodel,
            kernel_size=large_size,
            stride=1,
            groups=nvars * dmodel,
            small_kernel=small_size,
            small_kernel_merged=small_kernel_merged,
        )
        # Normalizes the D features; applied with variables folded into the batch axis.
        self.norm = nn.BatchNorm1d(dmodel)

        # ConvFFN1 (variable-independent)
        self.ffn1pw1 = nn.Conv1d(nvars * dmodel, nvars * dff, kernel_size=1, stride=1, padding=0, groups=nvars)
        self.ffn1act = nn.GELU()
        self.ffn1pw2 = nn.Conv1d(nvars * dff, nvars * dmodel, kernel_size=1, stride=1, padding=0, groups=nvars)
        self.ffn1drop1 = nn.Dropout(drop)
        self.ffn1drop2 = nn.Dropout(drop)

        # ConvFFN2 (cross-variable)
        self.ffn2pw1 = nn.Conv1d(nvars * dmodel, nvars * dff, kernel_size=1, stride=1, padding=0, groups=dmodel)
        self.ffn2act = nn.GELU()
        self.ffn2pw2 = nn.Conv1d(nvars * dff, nvars * dmodel, kernel_size=1, stride=1, padding=0, groups=dmodel)
        self.ffn2drop1 = nn.Dropout(drop)
        self.ffn2drop2 = nn.Dropout(drop)

    def forward(self, x):
        # x: [B, M, D, N]
        shortcut = x
        B, M, D, N = x.shape

        # 1) DWConv
        x = x.reshape(B, M * D, N)
        x = self.dw(x)
        x = x.reshape(B, M, D, N)

        # 2) BN over D per-variable
        x = x.reshape(B * M, D, N)
        x = self.norm(x)
        x = x.reshape(B, M, D, N)

        # 3) ConvFFN1
        # Channel layout is [M, D], so groups=nvars keeps each variable's D features together.
        x = x.reshape(B, M * D, N)
        x = self.ffn1drop1(self.ffn1pw1(x))
        x = self.ffn1act(x)
        x = self.ffn1drop2(self.ffn1pw2(x))
        x = x.reshape(B, M, D, N)

        # 4) ConvFFN2
        # After the swap the layout is [D, M], so groups=dmodel mixes variables within each feature.
        x = x.permute(0, 2, 1, 3).contiguous()  # [B, D, M, N]
        x = x.reshape(B, D * M, N)
        x = self.ffn2drop1(self.ffn2pw1(x))
        x = self.ffn2act(x)
        x = self.ffn2drop2(self.ffn2pw2(x))
        x = x.reshape(B, D, M, N)
        x = x.permute(0, 2, 1, 3).contiguous()  # [B, M, D, N]

        return shortcut + x


class ModernTCN_Stage_Official(nn.Module):
    """A sequential stack of ``num_blocks`` identically configured ModernTCN blocks.

    Each block's ConvFFN hidden width is ``dmodel * ffn_ratio``. Tensors pass
    through the blocks in order, with the layout ``ModernTCN_Block_Official``
    expects.
    """

    def __init__(self, ffn_ratio, num_blocks, large_size, small_size, dmodel, nvars, small_kernel_merged=False, drop=0.1):
        super().__init__()
        block_kwargs = dict(
            large_size=large_size,
            small_size=small_size,
            dmodel=dmodel,
            dff=dmodel * ffn_ratio,
            nvars=nvars,
            small_kernel_merged=small_kernel_merged,
            drop=drop,
        )
        stack = []
        for _ in range(num_blocks):
            stack.append(ModernTCN_Block_Official(**block_kwargs))
        # Attribute name `blocks` is part of the state_dict keys.
        self.blocks = nn.ModuleList(stack)

    def forward(self, x):
        out = x
        for block in self.blocks:
            out = block(out)
        return out


class ModernTCN_Official(nn.Module):
    """ModernTCN backbone with a flatten head. Input: x [B, M, L] / Output: y [B, M, target_window].

    Pipeline:
    1. Optional RevIN over the time axis.
    2. Per-variable patch stem: Conv1d with kernel ``patch_size`` and stride ``patch_stride``.
    3. ``len(num_blocks)`` stages. From stage 1 on, each stage is preceded by a
       strided Conv1d that downsamples by ``downsample_ratio``.
    4. ``Flatten_Head``.
    5. Optional RevIN de-normalization.
    """

    def __init__(
        self,
        patch_size,
        patch_stride,
        downsample_ratio,
        ffn_ratio,
        num_blocks,
        large_size,
        small_size,
        dims,
        nvars,
        small_kernel_merged=False,
        backbone_dropout=0.1,
        head_dropout=0.1,
        use_multi_scale=False,
        revin=True,
        affine=True,
        subtract_last=False,
        seq_len=96,
        individual=False,
        target_window=96,
    ):
        super().__init__()

        self.revin = revin
        if self.revin:
            self.revin_layer = RevIN_Official(nvars, affine=affine, subtract_last=subtract_last)

        self.downsample_layers = nn.ModuleList()
        self.patch_size = patch_size
        self.patch_stride = patch_stride
        self.downsample_ratio = downsample_ratio

        # Stem: each variable is treated as a 1-channel series and patched by a strided conv.
        stem = nn.Sequential(
            nn.Conv1d(1, dims[0], kernel_size=patch_size, stride=patch_stride),
            nn.BatchNorm1d(dims[0]),
        )
        self.downsample_layers.append(stem)

        # NOTE(review): always builds 3 downsample layers, so `dims` needs at least 4
        # entries. forward_feature only uses the first `num_stage` layers.
        for i in range(3):
            downsample_layer = nn.Sequential(
                nn.BatchNorm1d(dims[i]),
                nn.Conv1d(dims[i], dims[i + 1], kernel_size=downsample_ratio, stride=downsample_ratio),
            )
            self.downsample_layers.append(downsample_layer)

        self.num_stage = len(num_blocks)
        self.stages = nn.ModuleList(
            [
                ModernTCN_Stage_Official(
                    ffn_ratio=ffn_ratio,
                    num_blocks=num_blocks[stage_idx],
                    large_size=large_size[stage_idx],
                    small_size=small_size[stage_idx],
                    dmodel=dims[stage_idx],
                    nvars=nvars,
                    small_kernel_merged=small_kernel_merged,
                    drop=backbone_dropout,
                )
                for stage_idx in range(self.num_stage)
            ]
        )

        # Stem output length. With the right-padding done in forward_feature this is
        # seq_len // patch_stride (assumes seq_len >= patch_stride — TODO confirm).
        patch_num = seq_len // patch_stride
        self.n_vars = nvars
        self.individual = individual
        d_model = dims[-1]

        self.use_multi_scale = use_multi_scale
        if self.use_multi_scale:
            # NOTE(review): forward_feature has no multi-scale path, so this size only
            # matches the features when no downsampling happens (num_stage == 1 or
            # downsample_ratio == 1) — confirm before enabling.
            self.head_nf = d_model * patch_num
        else:
            # Each later stage pads N up to a multiple of downsample_ratio, so the final
            # length is ceil(patch_num / downsample_ratio ** (num_stage - 1)).
            div = pow(downsample_ratio, (self.num_stage - 1))
            if patch_num % div == 0:
                self.head_nf = d_model * patch_num // div
            else:
                self.head_nf = d_model * (patch_num // div + 1)

        self.head = Flatten_Head(self.individual, self.n_vars, self.head_nf, target_window, head_dropout=head_dropout)

    def forward_feature(self, x):
        """Run the stem and all stages: [B, M, L] -> [B, M, dims[num_stage-1], N_final]."""
        # x: [B,M,L]
        B, M, L = x.shape
        x = x.unsqueeze(-2)  # [B,M,1,L]

        for i in range(self.num_stage):
            B, M, D, N = x.shape
            # Fold variables into the batch axis so the downsample convs see each variable separately.
            x = x.reshape(B * M, D, N)

            if i == 0:
                if self.patch_size != self.patch_stride:
                    # Overlapping patches: replicate the last step so the final patch is complete.
                    pad_len = self.patch_size - self.patch_stride
                    pad = x[:, :, -1:].repeat(1, 1, pad_len)
                    x = torch.cat([x, pad], dim=-1)
            else:
                if N % self.downsample_ratio != 0:
                    # Pad N to a multiple of downsample_ratio by re-appending the last pad_len positions.
                    pad_len = self.downsample_ratio - (N % self.downsample_ratio)
                    x = torch.cat([x, x[:, :, -pad_len:]], dim=-1)

            x = self.downsample_layers[i](x)
            _, D_, N_ = x.shape
            x = x.reshape(B, M, D_, N_)
            x = self.stages[i](x)

        return x

    def forward(self, x):
        if self.revin:
            # RevIN_Official reduces over dim 1, so move time there ([B, L, M]) and back.
            x = x.permute(0, 2, 1)
            x = self.revin_layer(x, 'norm')
            x = x.permute(0, 2, 1)

        feat = self.forward_feature(x)
        y = self.head(feat)

        if self.revin:
            y = y.permute(0, 2, 1)
            y = self.revin_layer(y, 'denorm')
            y = y.permute(0, 2, 1)

        return y

    def structural_reparam(self):
        """Merge every sub-module that defines ``merge_kernel`` (in place).

        Merging folds BN running statistics into the convs, so presumably call this in eval mode after training.
        """
        for m in self.modules():
            if hasattr(m, 'merge_kernel'):
                m.merge_kernel()


class ModernTCN_Official_Wrapper(nn.Module):
    """Wrapper around ModernTCN_Official: input [B, L, C] -> output [B, H, C].

    Without ``decomposition``, a single backbone (``self.model``) is applied to
    the channels-first input. With ``decomposition``, the input is split by
    ``series_decomp(decomp_kernel)`` into residual and trend parts. Each part gets
    its own backbone (``self.model_res`` / ``self.model_trend``) and the two
    forecasts are summed. All backbones share the configuration given here.
    ``enc_in`` becomes the backbone's ``nvars`` and ``pred_len`` its ``target_window``.
    """

    def __init__(
        self,
        seq_len: int,
        pred_len: int,
        enc_in: int,
        decomposition: bool = False,
        decomp_kernel: int = 25,
        patch_size: int = 16,
        patch_stride: int = 16,
        downsample_ratio: int = 2,
        ffn_ratio: int = 2,
        num_blocks=(1, 1, 1, 1),
        large_size=(51, 51, 51, 51),
        small_size=(7, 7, 7, 7),
        dims=(16, 16, 16, 16),
        small_kernel_merged: bool = False,
        dropout: float = 0.1,
        head_dropout: float = 0.0,
        use_multi_scale: bool = False,
        revin: bool = True,
        affine: bool = True,
        subtract_last: bool = False,
        individual: bool = False,
    ):
        super().__init__()

        self.decomposition = decomposition

        def _build_backbone():
            # Single source of truth for the backbone config (previously copy-pasted
            # three times). Each call gets fresh list copies of the tuple arguments.
            return ModernTCN_Official(
                patch_size=patch_size,
                patch_stride=patch_stride,
                downsample_ratio=downsample_ratio,
                ffn_ratio=ffn_ratio,
                num_blocks=list(num_blocks),
                large_size=list(large_size),
                small_size=list(small_size),
                dims=list(dims),
                nvars=enc_in,
                small_kernel_merged=small_kernel_merged,
                backbone_dropout=dropout,
                head_dropout=head_dropout,
                use_multi_scale=use_multi_scale,
                revin=revin,
                affine=affine,
                subtract_last=subtract_last,
                seq_len=seq_len,
                individual=individual,
                target_window=pred_len,
            )

        if self.decomposition:
            self.decomp_module = series_decomp(decomp_kernel)
            # Build order (residual first, then trend) determines which random draws
            # initialize which backbone; it matches the previous implementation.
            self.model_res = _build_backbone()
            self.model_trend = _build_backbone()
        else:
            self.model = _build_backbone()

    def forward(self, x):
        # x: [B, L, C]; backbones expect channels-first [B, C, L] and return [B, C, H].
        if self.decomposition:
            res_init, trend_init = self.decomp_module(x)
            # Residual backbone runs before the trend backbone (left operand first).
            y = self.model_res(res_init.permute(0, 2, 1)) + self.model_trend(trend_init.permute(0, 2, 1))
        else:
            y = self.model(x.permute(0, 2, 1))
        return y.permute(0, 2, 1)

    @torch.no_grad()
    def structural_reparam(self):
        """Merge the re-parameterizable convs of every backbone in place."""
        if self.decomposition:
            self.model_res.structural_reparam()
            self.model_trend.structural_reparam()
        else:
            self.model.structural_reparam()


class ModernTCN_Seasonal_LinearTrend(nn.Module):
    """Hybrid forecaster: ModernTCN on the seasonal part, a single Linear layer on the trend.

    Input ``[B, L, C]`` -> output ``[B, H, C]``. RevIN (when enabled) is applied
    once around the whole model. The inner ModernTCN is therefore built with
    ``revin=False``. The input is split by ``series_decomp(decomp_kernel)``. The
    trend goes through ``nn.Linear(seq_len, pred_len)`` along time, the residual
    goes through ModernTCN_Official, and the two forecasts are added.
    """

    def __init__(
        self,
        seq_len: int,
        pred_len: int,
        enc_in: int,
        decomp_kernel: int = 25,
        patch_size: int = 16,
        patch_stride: int = 16,
        downsample_ratio: int = 2,
        ffn_ratio: int = 2,
        num_blocks=(1, 1, 1, 1),
        large_size=(51, 51, 51, 51),
        small_size=(7, 7, 7, 7),
        dims=(16, 16, 16, 16),
        small_kernel_merged: bool = False,
        dropout: float = 0.1,
        head_dropout: float = 0.0,
        use_multi_scale: bool = False,
        revin: bool = True,
        affine: bool = True,
        subtract_last: bool = False,
        individual: bool = False,
    ):
        super().__init__()

        self.revin = revin
        if revin:
            self.revin_layer = RevIN_Official(enc_in, affine=affine, subtract_last=subtract_last)

        self.decomp = series_decomp(decomp_kernel)

        # Trend branch: one linear map from the input window to the horizon.
        self.trend_head = nn.Linear(seq_len, pred_len)

        # Seasonal branch; normalization is handled by this module, not the backbone.
        seasonal_cfg = dict(
            patch_size=patch_size,
            patch_stride=patch_stride,
            downsample_ratio=downsample_ratio,
            ffn_ratio=ffn_ratio,
            num_blocks=list(num_blocks),
            large_size=list(large_size),
            small_size=list(small_size),
            dims=list(dims),
            nvars=enc_in,
            small_kernel_merged=small_kernel_merged,
            backbone_dropout=dropout,
            head_dropout=head_dropout,
            use_multi_scale=use_multi_scale,
            revin=False,
            affine=affine,
            subtract_last=subtract_last,
            seq_len=seq_len,
            individual=individual,
            target_window=pred_len,
        )
        self.seasonal_model = ModernTCN_Official(**seasonal_cfg)

    def forward(self, x):
        # x: [B,L,C]
        if self.revin:
            x = self.revin_layer(x, 'norm')

        seasonal_in, trend_in = self.decomp(x)

        # Both branches work channels-first over time and return [B, C, H].
        trend_pred = self.trend_head(trend_in.transpose(1, 2)).transpose(1, 2)
        seasonal_pred = self.seasonal_model(seasonal_in.transpose(1, 2)).transpose(1, 2)

        combined = seasonal_pred + trend_pred
        return self.revin_layer(combined, 'denorm') if self.revin else combined

In [4]:
# =========================
# 5. 학습/평가 유틸 (MSE/MAE)  +  6. 실험 러너
# =========================
# Colab에서 위→아래로 실행해도 바로 돌아가도록,
# 원래 뒤쪽(셀 5)에 있던 학습/평가 유틸을 이 셀 상단으로 올려둡니다.


def _inverse_transform_3d(arr_3d, scaler):
    """arr_3d: [N, T, C]"""
    if scaler is None:
        return arr_3d
    N, T, C = arr_3d.shape
    return scaler.inverse_transform(arr_3d.reshape(-1, C)).reshape(N, T, C)


def _calc_metrics(preds, trues):
    mse = float(np.mean((preds - trues) ** 2))
    mae = float(np.mean(np.abs(preds - trues)))
    return mse, mae


def train_model(
    model,
    train_loader,
    val_loader,
    epochs=30,
    lr=1e-3,
    weight_decay=1e-4,
    patience=5,
    min_epochs=5,
    grad_clip=1.0,
    use_amp=True,
):
    """Train ``model`` with AdamW + MSE, early-stop on the monitored loss, and restore the best weights.

    Each epoch makes one pass over ``train_loader`` (optionally under CUDA AMP,
    with gradient clipping), then evaluates ``val_loader`` without gradients.
    The monitored loss is the mean validation batch loss. If ``val_loader``
    yields no batches, the mean training batch loss is used instead. Previously
    that case produced ``nan``, so no best state was ever kept. ReduceLROnPlateau
    halves the LR after 2 non-improving epochs, with a floor of 1e-5. Training
    stops once at least ``min_epochs`` epochs have run and the monitored loss has
    not improved for ``patience`` consecutive epochs.

    Args:
        model: module whose output on ``batch_x`` is compared to ``batch_y`` by MSELoss.
        train_loader, val_loader: iterables of ``(batch_x, batch_y)`` tensors.
        epochs: maximum number of epochs.
        lr, weight_decay: AdamW hyper-parameters.
        patience, min_epochs: early-stopping controls (see above).
        grad_clip: max global gradient norm; ``None`` disables clipping.
        use_amp: enable autocast + GradScaler (only takes effect when ``device`` is CUDA).

    Returns:
        The same ``model`` object, moved to the global ``device``, with the best
        state loaded if any epoch improved on the previous best.
    """
    # Move the model before constructing the optimizer, as the torch.optim docs recommend.
    model.to(device)

    criterion = nn.MSELoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, min_lr=1e-5)

    # Loop-invariant AMP setup (previously re-evaluated for every batch).
    amp_enabled = bool(use_amp and device.type == 'cuda')
    has_amp_autocast = hasattr(torch, 'amp') and hasattr(torch.amp, 'autocast')
    if hasattr(torch, 'amp') and hasattr(torch.amp, 'GradScaler'):
        scaler_amp = torch.amp.GradScaler(enabled=amp_enabled)
    else:
        scaler_amp = torch.cuda.amp.GradScaler(enabled=amp_enabled)

    def _autocast():
        # Fresh context manager per batch; older torch versions only have torch.cuda.amp.
        if has_amp_autocast:
            return torch.amp.autocast(device_type=device.type, enabled=amp_enabled)
        return torch.cuda.amp.autocast(enabled=amp_enabled)

    best_state = None
    best_val = float('inf')
    wait = 0

    for epoch in range(epochs):
        model.train()
        train_losses = []

        for batch_x, batch_y in train_loader:
            batch_x = batch_x.to(device).float()
            batch_y = batch_y.to(device).float()

            optimizer.zero_grad(set_to_none=True)

            with _autocast():
                outputs = model(batch_x)
                loss = criterion(outputs, batch_y)

            scaler_amp.scale(loss).backward()

            if grad_clip is not None:
                # Unscale first so the clipping threshold applies to the true gradient norm.
                scaler_amp.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

            scaler_amp.step(optimizer)
            scaler_amp.update()

            train_losses.append(loss.item())

        model.eval()
        val_losses = []
        with torch.no_grad():
            for batch_x, batch_y in val_loader:
                batch_x = batch_x.to(device).float()
                batch_y = batch_y.to(device).float()
                val_losses.append(criterion(model(batch_x), batch_y).item())

        # Fall back to the training loss when there is no validation data. With no
        # batches at all, treat the epoch as non-improving instead of producing nan.
        monitored = val_losses or train_losses
        avg_val = float(np.mean(monitored)) if monitored else float('inf')
        scheduler.step(avg_val)

        if avg_val < best_val:
            best_val = avg_val
            best_state = deepcopy(model.state_dict())
            wait = 0
        else:
            wait += 1

        if (epoch + 1) >= min_epochs and wait >= patience:
            break

    if best_state is not None:
        model.load_state_dict(best_state)

    return model


def evaluate_model(model, data_loader, scaler=None, capture_sample=True):
    """Run ``model`` over ``data_loader`` and report MSE/MAE in scaled and original units.

    Args:
        model: module mapping ``batch_x`` to predictions shaped like ``batch_y``.
        data_loader: iterable of ``(batch_x, batch_y)`` tensors.
        scaler: object passed to ``_inverse_transform_3d`` (None = identity).
        capture_sample: keep one example window for plotting.

    Returns:
        dict with keys:
        - 'mse_scaled', 'mae_scaled': metrics on the values as produced by the model.
        - 'mse_unscaled', 'mae_unscaled': metrics after inverse scaling.
        - 'sample': for the first window of the first batch, the last channel of
          input/true/pred in original units, or None when ``capture_sample`` is False.

    Raises:
        ValueError: from ``np.concatenate`` if ``data_loader`` yields no batches.
    """
    model.eval()
    pred_batches, true_batches = [], []
    sample = None

    with torch.no_grad():
        for batch_idx, (batch_x, batch_y) in enumerate(data_loader):
            inputs = batch_x.to(device).float()
            targets = batch_y.to(device).float()
            pred = model(inputs).detach().cpu().numpy()
            true = targets.detach().cpu().numpy()

            if capture_sample and batch_idx == 0:
                inp_orig = _inverse_transform_3d(inputs.detach().cpu().numpy(), scaler)
                pred_orig = _inverse_transform_3d(pred, scaler)
                true_orig = _inverse_transform_3d(true, scaler)
                # First window, last channel only.
                sample = {
                    'input': inp_orig[0, :, -1],
                    'true': true_orig[0, :, -1],
                    'pred': pred_orig[0, :, -1],
                }

            pred_batches.append(pred)
            true_batches.append(true)

    all_preds = np.concatenate(pred_batches, axis=0)
    all_trues = np.concatenate(true_batches, axis=0)

    mse_s, mae_s = _calc_metrics(all_preds, all_trues)
    mse_u, mae_u = _calc_metrics(
        _inverse_transform_3d(all_preds, scaler),
        _inverse_transform_3d(all_trues, scaler),
    )

    return {
        'mse_scaled': mse_s,
        'mae_scaled': mae_s,
        'mse_unscaled': mse_u,
        'mae_unscaled': mae_u,
        'sample': sample,
    }


# =========================
# 6. 실험 러너 (데이터셋 x 모델 x input_len x pred_len)
# =========================

# On large datasets (Electricity/Traffic, etc.) a very large number of variables can make
# the models very heavy, so by default the number of variables of the custom datasets is
# capped at 16 (the cap is applied by data-loading code outside this chunk — confirm).
# - To use all original variables: CUSTOM_MAX_FEATURES = None
CUSTOM_MAX_FEATURES = 16


def make_model(model_name: str, num_inputs: int, seq_len: int, pred_len: int):
    """Instantiate the forecasting model registered under ``model_name``.

    Args:
        model_name: one of the names in the dispatch table below (see MODEL_NAMES).
        num_inputs: number of input variables/channels.
        seq_len: input window length.
        pred_len: forecast horizon.

    Returns:
        A freshly constructed (untrained) ``nn.Module``. Only the requested model is built.

    Raises:
        ValueError: if ``model_name`` matches none of the registered names.
    """
    # Shared configuration for the three ModernTCN variants.
    modern_cfg = dict(
        seq_len=seq_len,
        enc_in=num_inputs,
        patch_size=16,
        patch_stride=16,
        downsample_ratio=2,
        ffn_ratio=2,
        num_blocks=(1, 1, 1, 1),
        large_size=(51, 51, 51, 51),
        small_size=(7, 7, 7, 7),
        dims=(16, 16, 16, 16),
        small_kernel_merged=False,
        dropout=0.1,
        head_dropout=0.0,
        use_multi_scale=False,
        revin=True,
        affine=True,
        subtract_last=False,
        individual=False,
    )

    # Builders are lazy so that only the selected model is constructed.
    builders = {
        'TCN': lambda: TCN_Baseline(num_inputs, seq_len, pred_len, num_channels=[32] * 4, kernel_size=5),
        'DLinear': lambda: DLinear(seq_len, pred_len, num_inputs),
        'DecompOpTCN': lambda: DecompOpTCN(num_inputs, seq_len, pred_len, num_channels=[32] * 4, kernel_size=5),
        'Decomp-OpTCN-v3(Patch)': lambda: DecompOpTCN_V3_Patch(
            num_inputs,
            seq_len,
            pred_len,
            num_channels=[32] * 4,
            kernel_size=5,
            patch_len=16,
            patch_stride=8,
            patch_embed_dim=16,
        ),
        'ModernTCN': lambda: ModernTCN_Official_Wrapper(
            pred_len=pred_len, decomposition=False, **modern_cfg
        ),
        'ModernTCN+Decomp': lambda: ModernTCN_Official_Wrapper(
            pred_len=pred_len, decomposition=True, decomp_kernel=25, **modern_cfg
        ),
        'ModernTCN(seasonal)+Linear(trend)': lambda: ModernTCN_Seasonal_LinearTrend(
            pred_len=pred_len, decomp_kernel=25, **modern_cfg
        ),
    }

    # Compare with == (not a dict lookup) to keep the original matching semantics.
    for registered_name, build in builders.items():
        if model_name == registered_name:
            return build()

    raise ValueError(f"Unknown model_name: {model_name}")


# Default model list for run_grid (used when its `model_names` argument is not given).
# Every entry must be a name handled by make_model, which raises ValueError otherwise.
MODEL_NAMES = [
    'TCN',
    'DecompOpTCN',
    'Decomp-OpTCN-v3(Patch)',
    'DLinear',
    'ModernTCN',
    'ModernTCN+Decomp',
    'ModernTCN(seasonal)+Linear(trend)',
]


def _grid_load_resume_state(running_csv: str):
    """Load finished rows and their (Dataset, Model, InputLen, PredLen) keys from a previous run.

    Returns ``(rows, done_keys)``. If the file is missing, unreadable, lacks the key columns,
    or contains unparsable key values, it returns ``([], set())`` so the grid starts fresh.
    Rows and keys are therefore always consistent with each other.
    """
    if not os.path.exists(running_csv):
        return [], set()

    key_cols = ['Dataset', 'Model', 'InputLen', 'PredLen']
    try:
        df_prev = pd.read_csv(running_csv)
        if not set(key_cols).issubset(set(df_prev.columns)):
            print(f"[resume] found {running_csv} but columns mismatch -> start fresh")
            return [], set()
        done_keys = {
            (str(d), str(m), int(i), int(p))
            for d, m, i, p in df_prev[key_cols].itertuples(index=False, name=None)
        }
        rows = df_prev.to_dict('records')
    except Exception as e:
        # Best-effort resume: a corrupt CSV or NaN keys simply mean a fresh run.
        print(f"[resume] failed to read {running_csv} ({e}) -> start fresh")
        return [], set()

    print(f"[resume] loaded {len(rows)} rows from {running_csv}")
    return rows, done_keys


def _grid_save_sample_npz(sample_dir: str, dname: str, mname: str, seq_len: int, pred_len: int, sample):
    """Write one captured sample to ``{Dataset}__{Model}__in{seq}__out{pred}.npz`` (no-op for None)."""
    if sample is None:
        return
    # Only '/' is replaced in the model name; load_samples_from_disk() parses this format.
    safe_m = mname.replace('/', '_')
    out_path = os.path.join(sample_dir, f"{dname}__{safe_m}__in{seq_len}__out{pred_len}.npz")
    np.savez_compressed(out_path, input=sample['input'], true=sample['true'], pred=sample['pred'])


def run_grid(
    data_root: str = None,
    dataset_names=None,
    model_names=None,
    input_lens=None,
    pred_lens=None,
    batch_size: int = BATCH_SIZE,
    epochs: int = EPOCHS,
    patience: int = PATIENCE,
    min_epochs: int = MIN_EPOCHS,
    lr: float = LR,
    weight_decay: float = WEIGHT_DECAY,
    grad_clip: float = GRAD_CLIP,
    use_amp: bool = USE_AMP,
    resume: bool = True,
):
    """Run the dataset x input_len x pred_len x model grid and return ``(df, samples)``.

    ``df`` holds one metrics row per finished run, including rows carried over on resume.
    ``samples`` maps (Dataset, Model, InputLen, PredLen) to the sample captured by
    evaluate_model(), for runs executed in *this* call only.

    - With ``resume=True`` the previous ``results_running.csv`` is read and combinations
      already listed there are skipped. A (dataset, in, out) cell whose models are all done
      does not even rebuild its datasets.
    - Resuming from the middle of a training run is not supported (there are no checkpoints);
      resume works at grid-cell granularity.
    - After every run the metrics CSV and one ``.npz`` sample are written, so progress
      survives a runtime restart.
    - None or empty list arguments fall back to the module-level defaults.
    """
    data_root = data_root or DATA_DIR  # Colab default: /content/data

    dataset_names = dataset_names or DATASET_NAMES
    model_names = model_names or MODEL_NAMES
    input_lens = input_lens or INPUT_LENS
    pred_lens = pred_lens or PRED_LENS

    sample_dir = os.path.join(ARTIFACT_DIR, 'samples')
    os.makedirs(sample_dir, exist_ok=True)

    running_csv = os.path.join(CSV_DIR, 'results_running.csv')

    if resume:
        results, done_keys = _grid_load_resume_state(running_csv)
    else:
        results, done_keys = [], set()

    samples = {}  # only samples computed in this call are kept in memory

    print(f"Running grid: datasets={len(dataset_names)}, models={len(model_names)}, in={input_lens}, out={pred_lens}")
    print(f"DATA_DIR={data_root}")
    print(f"CUSTOM_MAX_FEATURES={CUSTOM_MAX_FEATURES} (custom datasets only)")

    for dname in dataset_names:
        for seq_len in input_lens:
            for pred_len in pred_lens:
                pending = [
                    mname for mname in model_names
                    if (dname, mname, int(seq_len), int(pred_len)) not in done_keys
                ]
                if not pending:
                    # Resume: everything finished for this cell -> skip the costly dataset build.
                    continue

                train_ds = build_dataset(dname, root=data_root, flag='train', seq_len=seq_len, pred_len=pred_len, max_features_custom=CUSTOM_MAX_FEATURES)
                val_ds = build_dataset(dname, root=data_root, flag='val', seq_len=seq_len, pred_len=pred_len, max_features_custom=CUSTOM_MAX_FEATURES)
                test_ds = build_dataset(dname, root=data_root, flag='test', seq_len=seq_len, pred_len=pred_len, max_features_custom=CUSTOM_MAX_FEATURES)

                num_inputs = int(train_ds[0][0].shape[-1])
                scaler = train_ds.scaler

                train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)
                val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, drop_last=True)
                test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, drop_last=False)

                for mname in pending:
                    key = (dname, mname, int(seq_len), int(pred_len))
                    if key in done_keys:  # guards against duplicate entries in model_names
                        continue

                    setup_seed(2025)
                    model = make_model(mname, num_inputs, seq_len, pred_len)

                    print(f"[{dname}] in={seq_len} out={pred_len} vars={num_inputs} | {mname} ...", end=' ')
                    t0 = time.time()

                    model = train_model(
                        model,
                        train_loader,
                        val_loader,
                        epochs=epochs,
                        lr=lr,
                        weight_decay=weight_decay,
                        patience=patience,
                        min_epochs=min_epochs,
                        grad_clip=grad_clip,
                        use_amp=use_amp,
                    )

                    dt = time.time() - t0
                    metrics = evaluate_model(model, test_loader, scaler=scaler, capture_sample=True)

                    row = {
                        'Dataset': dname,
                        'Model': mname,
                        'InputLen': int(seq_len),
                        'PredLen': int(pred_len),
                        'NumVars': int(num_inputs),
                        'MSE_scaled': metrics['mse_scaled'],
                        'MAE_scaled': metrics['mae_scaled'],
                        'MSE_unscaled': metrics['mse_unscaled'],
                        'MAE_unscaled': metrics['mae_unscaled'],
                        'TrainSeconds': float(dt),
                    }
                    results.append(row)
                    done_keys.add(key)

                    print(f"done {dt:.1f}s | MSE={row['MSE_scaled']:.4f}, MAE={row['MAE_scaled']:.4f} (scaled)")

                    samples[(dname, mname, seq_len, pred_len)] = metrics['sample']

                    # 1) Checkpoint the metrics table after every run.
                    pd.DataFrame(results).to_csv(running_csv, index=False)

                    # 2) Persist the single captured sample so plots survive a runtime restart.
                    _grid_save_sample_npz(sample_dir, dname, mname, seq_len, pred_len, metrics['sample'])

                    # Drop the trained model before the next one is constructed.
                    del model

    df = pd.DataFrame(results)
    out_csv = os.path.join(CSV_DIR, 'results_final.csv')
    df.to_csv(out_csv, index=False)
    print(f"Saved: {out_csv}")
    return df, samples


# Colab run example (runs the grid in a single call):
# - datasets: DATASET_NAMES (7)
# - models:   MODEL_NAMES (7)
# - input:    only [336] here (INPUT_LENS also contains 96; drop the argument to run both)
# - output:   PRED_LENS (4)
# That is 7*7*1*4 = 196 train/eval runs (the full INPUT_LENS grid would be 7*7*2*4 = 392), so it takes a long time.
# Re-running this cell resumes: run_grid(resume=True) skips combinations already in results_running.csv.
ensure_datasets()  # download into /content/data
df, samples = run_grid(input_lens=[336])  # every dataset/model/pred combination, input length 336 only

# (debug) to run only a small subset quickly, e.g.:
# df, samples = run_grid(dataset_names=['ETTh1'], input_lens=[96], pred_lens=[96], model_names=['DLinear','TCN'])

ETTh1: OK (/content/data/ETTh1.csv)
ETTh2: OK (/content/data/ETTh2.csv)
ETTm1: OK (/content/data/ETTm1.csv)
ETTm2: OK (/content/data/ETTm2.csv)
Electricity: OK (/content/data/electricity.csv)
Weather: OK (/content/data/weather.csv)
Traffic: OK (/content/data/traffic.csv)
Running grid: datasets=7, models=7, in=[336], out=[96, 192, 336, 720]
DATA_DIR=/content/data
CUSTOM_MAX_FEATURES=16 (custom datasets only)
[ETTh1] in=336 out=96 vars=7 | TCN ... done 27.2s | MSE=0.4433, MAE=0.4600 (scaled)
[ETTh1] in=336 out=96 vars=7 | DecompOpTCN ... done 34.0s | MSE=0.4011, MAE=0.4231 (scaled)
[ETTh1] in=336 out=96 vars=7 | Decomp-OpTCN-v3(Patch) ... done 77.2s | MSE=0.4854, MAE=0.4837 (scaled)
[ETTh1] in=336 out=96 vars=7 | DLinear ... done 22.4s | MSE=0.3755, MAE=0.3956 (scaled)
[ETTh1] in=336 out=96 vars=7 | ModernTCN ... done 42.5s | MSE=0.4227, MAE=0.4377 (scaled)
[ETTh1] in=336 out=96 vars=7 | ModernTCN+Decomp ... done 66.6s | MSE=0.4163, MAE=0.4365 (scaled)
[ETTh1] in=336 out=96 vars=7 | Mode

In [8]:
# =========================
# 7. 결과 요약/시각화 유틸 (inference 출력)
# =========================


def plot_sample(sample, seq_len: int, pred_len: int, title: str = ''):
    """Show one forecast sample inline: past input, ground truth and prediction.

    Args:
        sample: dict with 'input', 'true' and 'pred' 1D arrays, or None (prints 'No sample').
        seq_len: input window length; 'input' is drawn at x = 0 .. seq_len-1.
        pred_len: horizon; 'true'/'pred' are drawn at x = seq_len .. seq_len+pred_len-1.
        title: figure title.
    """
    if sample is None:
        print('No sample')
        return

    past_axis = np.arange(0, seq_len)
    future_axis = np.arange(seq_len, seq_len + pred_len)

    fig, ax = plt.subplots(figsize=(14, 4))
    ax.plot(past_axis, sample['input'], label='Input (Past)', color='grey', alpha=0.6)
    ax.plot(future_axis, sample['true'], label='Ground Truth', color='black', linewidth=2)
    ax.plot(future_axis, sample['pred'], label='Prediction', color='tab:blue', linewidth=2)
    # Dotted marker at the boundary between history and forecast horizon.
    ax.axvline(x=seq_len, color='k', linestyle=':', alpha=0.3)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    ax.legend(loc='upper left')
    fig.tight_layout()
    plt.show()


def save_sample_plot(sample, seq_len: int, pred_len: int, out_path: str, title: str = ''):
    """Render one forecast sample (past input, ground truth, prediction) to an image file.

    Args:
        sample: dict with 'input', 'true' and 'pred' arrays, or None (then nothing happens).
        seq_len: input window length; 'input' is drawn at x = 0 .. seq_len-1.
        pred_len: horizon; 'true'/'pred' are drawn at x = seq_len .. seq_len+pred_len-1.
        out_path: destination file; missing parent directories are created.
        title: figure title.
    """
    if sample is None:
        return

    parent = os.path.dirname(out_path)
    if parent:  # a bare file name has no parent, and os.makedirs('') would raise
        os.makedirs(parent, exist_ok=True)

    x_past = np.arange(0, seq_len)
    x_future = np.arange(seq_len, seq_len + pred_len)

    fig, ax = plt.subplots(figsize=(14, 4))
    try:
        ax.plot(x_past, sample['input'], label='Input (Past)', color='grey', alpha=0.6)
        ax.plot(x_future, sample['true'], label='Ground Truth', color='black', linewidth=2)
        ax.plot(x_future, sample['pred'], label='Prediction', color='tab:blue', linewidth=2)
        ax.axvline(x=seq_len, color='k', linestyle=':', alpha=0.3)  # history / horizon boundary
        ax.set_title(title)
        ax.grid(True, alpha=0.3)
        ax.legend(loc='upper left')
        fig.tight_layout()
        fig.savefig(out_path, dpi=160)
    finally:
        # Always release this specific figure, even if saving fails; this runs in loops
        # over hundreds of samples.
        plt.close(fig)


def summarize(df: pd.DataFrame, metric: str = 'MSE_scaled'):
    """Pivot results into one row per (Dataset, Model) and one column per (InputLen, PredLen).

    Args:
        df: results table with Dataset/Model/InputLen/PredLen columns plus ``metric``.
        metric: column to aggregate; duplicate entries are averaged.

    Returns:
        The pivot table (columns indexed by (metric, InputLen, PredLen)), or None after
        printing 'Empty df' when ``df`` is None or has no rows.
    """
    has_rows = df is not None and len(df) != 0
    if not has_rows:
        print('Empty df')
        return None

    return df.pivot_table(
        values=[metric],
        index=['Dataset', 'Model'],
        columns=['InputLen', 'PredLen'],
        aggfunc='mean',
    )


def save_all_sample_plots(samples: dict, out_dir: str = PLOT_DIR):
    """Write every sample returned by run_grid() to a PNG under ``out_dir/<Dataset>/``.

    ``samples`` keys are (Dataset, Model, InputLen, PredLen); any '/' in the resulting
    file name is replaced by '_'.
    """
    for key, sample in samples.items():
        dname, mname, seq_len, pred_len = key
        stem = f"{dname}__{mname}__in{seq_len}__out{pred_len}"
        png_name = (stem + ".png").replace('/', '_')
        save_sample_plot(
            sample,
            seq_len,
            pred_len,
            os.path.join(out_dir, dname, png_name),
            title=f"{dname} | {mname} | in={seq_len} out={pred_len}",
        )


# =========================
# 디스크에서 결과/샘플 로드 (런타임 재시작 후에도 plot 가능)
# =========================

def load_results_from_disk(csv_dir: str = None):
    """Load the grid results table that run_grid() wrote to disk.

    Both ``results_final.csv`` and ``results_running.csv`` are considered when they exist
    and are non-empty, and the most recently modified one is returned (``results_final.csv``
    wins an mtime tie). ``results_running.csv`` is rewritten after every finished run, so
    after a later resumed or interrupted run it is newer, and more complete, than a
    ``results_final.csv`` left by an earlier run.

    Args:
        csv_dir: directory holding the CSVs; None means CSV_DIR (resolved at call time).

    Raises:
        FileNotFoundError: if neither CSV exists with content.
    """
    if csv_dir is None:
        csv_dir = CSV_DIR

    candidates = [os.path.join(csv_dir, name) for name in ('results_final.csv', 'results_running.csv')]
    available = [p for p in candidates if os.path.exists(p) and os.path.getsize(p) > 0]
    if not available:
        raise FileNotFoundError(f"No results CSV found in {csv_dir} (expected results_final.csv or results_running.csv)")

    # max() returns the first of equal keys, so results_final.csv wins a tie.
    newest = max(available, key=os.path.getmtime)
    return pd.read_csv(newest)


def load_samples_from_disk(sample_dir: str = None):
    """Rebuild run_grid()'s ``samples`` dict from the ``.npz`` files it saved.

    File name format: ``{Dataset}__{Model}__in{seq}__out{pred}.npz``. The model part may
    itself contain ``__``. Files that do not match the format are skipped.

    Args:
        sample_dir: directory with the .npz files; defaults to ``ARTIFACT_DIR/samples``.

    Returns:
        dict mapping (Dataset, Model, seq_len, pred_len) to {'input', 'true', 'pred'} arrays,
        filled in sorted file-name order.

    Raises:
        FileNotFoundError: if ``sample_dir`` does not exist.
    """
    sample_dir = sample_dir or os.path.join(ARTIFACT_DIR, 'samples')
    if not os.path.exists(sample_dir):
        raise FileNotFoundError(f"Sample dir not found: {sample_dir}")

    out = {}
    for fn in sorted(os.listdir(sample_dir)):  # sorted -> deterministic dict order
        if not fn.endswith('.npz'):
            continue
        parts = fn[:-4].split('__')
        if len(parts) < 4:
            continue
        dname, in_part, out_part = parts[0], parts[-2], parts[-1]
        mname = '__'.join(parts[1:-2])
        if not (in_part.startswith('in') and out_part.startswith('out')):
            continue
        try:
            seq_len = int(in_part[2:])
            pred_len = int(out_part[3:])
        except ValueError:
            continue

        # np.load returns an NpzFile that holds the archive open until closed; indexing it
        # reads each array into memory, so the file can be released right away.
        with np.load(os.path.join(sample_dir, fn)) as data:
            out[(dname, mname, seq_len, pred_len)] = {
                'input': data['input'],
                'true': data['true'],
                'pred': data['pred'],
            }
    return out


def save_all_sample_plots_from_disk(sample_dir: str = None, out_dir: str = None):
    """Render a PNG for every saved ``.npz`` sample, without needing ``df``/``samples`` in memory.

    File names must follow ``{Dataset}__{Model}__in{seq}__out{pred}.npz`` (the model part may
    contain ``__``); other files are skipped. Plots go to ``out_dir/<Dataset>/``.

    Args:
        sample_dir: directory with the .npz files; defaults to ``ARTIFACT_DIR/samples``.
        out_dir: plot root; None means PLOT_DIR (resolved at call time).

    Raises:
        FileNotFoundError: if ``sample_dir`` does not exist.
    """
    sample_dir = sample_dir or os.path.join(ARTIFACT_DIR, 'samples')
    if out_dir is None:
        out_dir = PLOT_DIR
    if not os.path.exists(sample_dir):
        raise FileNotFoundError(f"Sample dir not found: {sample_dir}")

    n = 0
    for fn in sorted(os.listdir(sample_dir)):
        if not fn.endswith('.npz'):
            continue

        parts = fn[:-4].split('__')
        if len(parts) < 4:
            continue

        dname, in_part, out_part = parts[0], parts[-2], parts[-1]
        mname = '__'.join(parts[1:-2])
        if not (in_part.startswith('in') and out_part.startswith('out')):
            continue

        try:
            seq_len = int(in_part[2:])
            pred_len = int(out_part[3:])
        except ValueError:
            continue

        # Close the NpzFile handle right after the arrays are read into memory.
        with np.load(os.path.join(sample_dir, fn)) as data:
            sample = {'input': data['input'], 'true': data['true'], 'pred': data['pred']}

        out_path = os.path.join(out_dir, dname, f"{dname}__{mname}__in{seq_len}__out{pred_len}.png".replace('/', '_'))
        title = f"{dname} | {mname} | in={seq_len} out={pred_len}"
        save_sample_plot(sample, seq_len, pred_len, out_path, title=title)
        n += 1

    print(f"Saved {n} plots -> {out_dir}")


# Usage example (Colab) - plotting only, e.g. after a runtime restart and restoring the Drive backup:
# 1) (if needed) restore_artifacts_from_drive()
# 2) load the results table from the CSV on disk and summarize it as a pivot table
df_ = load_results_from_disk()
piv = summarize(df_, metric='MSE_scaled'); print(piv)
# 3) write PNGs directly from the saved .npz samples
save_all_sample_plots_from_disk()  # -> /content/artifacts/plots

                                              MSE_scaled                      \
InputLen                                             336                       
PredLen                                              96        192       336   
Dataset     Model                                                              
ETTh1       DLinear                             0.375545  0.421809  0.440936   
            Decomp-OpTCN-v3(Patch)              0.485429  0.498916  0.519915   
            DecompOpTCN                         0.401059  0.449332  0.523952   
            ModernTCN                           0.422733  0.473717  0.500149   
            ModernTCN(seasonal)+Linear(trend)   0.403498  0.443086  0.463041   
            ModernTCN+Decomp                    0.416270  0.487477  0.589186   
            TCN                                 0.443319  0.594575  0.597250   
ETTh2       DLinear                             0.275135  0.350481  0.386020   
            Decomp-OpTCN-v3(Patch)      