# Temporal Attention Unit (TAU) for Frost Prediction

**Paper**: "Temporal Attention Unit: Towards Efficient Spatiotemporal Predictive Learning" (CVPR 2023)

논문 구조 그대로 구현:
1. Spatial Encoder: 단순 2D CNN
2. Temporal Module: TAU 블록 스택
   - Intra-frame Statical Attention (SA)
   - Inter-frame Dynamical Attention (DA)
3. Spatial Decoder: 단순 2D CNN with skip connection
4. Loss: MSE + α * DDR (Differential Divergence Regularization)

In [22]:
import os
import numpy as np
import matplotlib.pyplot as plt
import cv2
import netCDF4 as nc
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torchvision import transforms

from utils import *
from utils.gk2a_dataset import GK2ADataset
from utils.projection import *

from sklearn.metrics import roc_auc_score, confusion_matrix
from torchvision.ops import sigmoid_focal_loss
from torch.nn.functional import binary_cross_entropy_with_logits

import random
import copy
import pandas as pd
from tqdm import tqdm

from model.frost_tau import FrostTAU

## 1. 실험 설정

In [None]:
# ====== 실험 설정 ======
seeds = [0, 1, 2]

# 데이터 설정
ASOS = True
AAFOS = False

channels = '16ch'
time_range_list = [[-18, -15, -12], [-21, -18, -15, -12], [-24, -21, -18, -15, -12], [-27, -24, -21, -18, -15, -12], [-30, -27, -24, -21, -18, -15, -12]]
time_range = time_range_list[0]  # 기본값 설정 (학습 루프에서 덮어씀)
resolution = '2km'
postfix = 'tau'

window_size = 1
window_stride = 1

# ASOS/AAFOS 가중치
asos_aafos_ratio = 5.0
asos_weight = asos_aafos_ratio / (asos_aafos_ratio + 1.0 * AAFOS) if ASOS else 0.0
aafos_weight = 1.0 / (asos_aafos_ratio * ASOS + 1.0) if AAFOS else 0.0
print(f"ASOS weight: {asos_weight:.2f}, AAFOS weight: {aafos_weight:.2f}")

# 채널 설정
channels_name = ['vi004','vi005','vi006','vi008','nr013','nr016','sw038','wv063','wv069','wv073','ir087','ir096','ir105','ir112','ir123','ir133']
channels_calib = channels_name.copy()

channels_mean = [1.1912e-01, 1.1464e-01, 1.0734e-01, 1.2504e-01, 5.4983e-02, 9.0381e-02,
            2.7813e+02, 2.3720e+02, 2.4464e+02, 2.5130e+02, 2.6948e+02, 2.4890e+02,
            2.7121e+02, 2.7071e+02, 2.6886e+02, 2.5737e+02]
channels_std = [0.1306,  0.1303,  0.1306,  0.1501,  0.0268,  0.0838, 15.8211,  6.1468,
            7.8054,  9.3251, 16.4265,  9.6150, 17.2518, 17.6064, 17.0090, 12.5026]

# output_path 설정
output_path = "results/"
output_path += 'asos_' if ASOS else ''
output_path += 'aafos_' if AAFOS else ''
output_path += channels + '_'
output_path += 'time' + str(time_range) + '_'
output_path += resolution
output_path += ('_' + postfix) if postfix else ''
output_path += f'_ws{window_size}_stride{window_stride}'
print(f"Output path: {output_path}")

ASOS weight: 1.00, AAFOS weight: 0.00
Output path: results/asos_16ch_time[-18, -15, -12]_2km_tau_ws1_stride1


In [24]:
# 데이터 정보
train_data_info_list = [{
    'label_type': 'asos',
    'start_date_str': '20200101',
    'end_date_str': '20230630',
    'hour_col_pairs': [(6,'AM')],
    'label_keys': ['93','108','112','119','131','133','136','143','146','156','177','102','104','115','138','152','155','159','165','168','169','184','189']
}] if ASOS else None

test_asos_data_info_list = [{
    'label_type': 'asos',
    'start_date_str': '20230701',
    'end_date_str': '20240630',
    'hour_col_pairs': [(6,'AM')],
    'label_keys': ['93','108','112','119','131','133','136','143','146','156','177','102','104','115','138','152','155','159','165','168','169','184','189']
}] if ASOS else None

# 이미지 설정
origin_size = 900 if resolution == '2km' else 1800
image_size = 192
patch_size = 192

# 데이터 경로
data_path = '/home/dm4/repo/data/kma_data/date_kst_URP'
lat_lon_path = 'assets/gk2a_ami_ko010lc_latlon.nc'

misc_channels = {
    'elevation':'elevation_1km_3600.npy',
    'vegetation':'vegetation_1km_3600.npy',
    'watermap':'watermap_1km_avg_3600.npy'
}

# 학습 설정
batch_size = 16
num_workers = 8
epochs = 25
lr = 1e-3
decay = [10, 20]
lr_decay = 0.1
weight_decay = 1e-5
threshold = [0.25]

# 관측소 좌표
asos_x_base, asos_y_base, asos_image_size = get_crop_base(image_size, label_type='asos')
seq_len = len(time_range)
total_channels = len(channels_name) + len(misc_channels.keys())

# 윈도우 설정 출력
actual_window_size = window_size if window_size is not None else seq_len
n_windows = (seq_len - actual_window_size) // window_stride + 1

print(f"seq_len: {seq_len}")
print(f"total_channels: {total_channels}")
print(f"image_size: {image_size}")
print(f"\nWindow config: size={actual_window_size}, stride={window_stride}")
print(f"Number of windows: {n_windows}")
print(f"Head input channels: {n_windows * actual_window_size * total_channels}")

seq_len: 3
total_channels: 19
image_size: 192

Window config: size=1, stride=1
Number of windows: 3
Head input channels: 57


## 2. 데이터 준비

In [25]:
# 관측소 좌표
asos_map_dict = get_station_map(image_size, label_type='asos', origin_size=origin_size)

asos_land_keys = list(ASOS_LAND_COORD.keys())
asos_coast_keys = list(ASOS_COAST_COORD.keys())
asos_land_map = {k: v for k, v in asos_map_dict.items() if k in asos_land_keys}
asos_coast_map = {k: v for k, v in asos_map_dict.items() if k in asos_coast_keys}

print(f"ASOS stations: {len(asos_map_dict)}")
print(f"  - Land: {len(asos_land_map)}")
print(f"  - Coast: {len(asos_coast_map)}")

ASOS stations: 23
  - Land: 11
  - Coast: 12


In [26]:
# Misc 이미지 로드
misc_images = []
for misc_channel, misc_path in misc_channels.items():
    misc_image = np.load(f'assets/misc_channels/{misc_path}', allow_pickle=True)
    misc_image = process_misc_image(misc_image, image_size, interpolation='cubic')
    misc_images.append(misc_image)

misc_images = np.stack(misc_images, axis=0)
misc_images = torch.tensor(misc_images, dtype=torch.float32)
print(f"MISC images shape: {misc_images.shape}")

# Normalize
misc_mean = misc_images.mean(dim=[1,2], keepdim=True)
misc_std = misc_images.std(dim=[1,2], keepdim=True)

total_mean = channels_mean + misc_mean.squeeze().tolist()
total_std = channels_std + misc_std.squeeze().tolist()

MISC images shape: torch.Size([3, 192, 192])


In [27]:
# Patch candidates (192x192에서는 전체 이미지 사용)
asos_patch_candidate = np.zeros([image_size, image_size], dtype=np.uint8)
for x, y in asos_map_dict.values():
    y_min = np.clip(y - 3*patch_size//4, 0, image_size - patch_size+1)
    y_max = np.clip(y - patch_size//4, 0, image_size - patch_size+1)
    x_min = np.clip(x - 3*patch_size//4, 0, image_size - patch_size+1)
    x_max = np.clip(x - patch_size//4, 0, image_size - patch_size+1)
    asos_patch_candidate[y_min:y_max, x_min:x_max] = 1

asos_patch_candidate = np.argwhere(asos_patch_candidate == 1)[:, [1, 0]]
patch_candidates = {'asos': asos_patch_candidate}
print(f"ASOS patch candidates: {asos_patch_candidate.shape}")

ASOS patch candidates: (1, 2)


In [28]:
# Transform
class SequenceNormalize:
    def __init__(self, mean, std):
        self.mean = torch.tensor(mean, dtype=torch.float32)
        self.std = torch.tensor(std, dtype=torch.float32)
    
    def __call__(self, tensor):
        if tensor.dim() == 4:  # (T, C, H, W)
            mean = self.mean.view(1, -1, 1, 1)
            std = self.std.view(1, -1, 1, 1)
        else:  # (C, H, W)
            mean = self.mean.view(-1, 1, 1)
            std = self.std.view(-1, 1, 1)
        return (tensor - mean) / std

transform = SequenceNormalize(mean=total_mean, std=total_std)

# Dataset
train_dataset = GK2ADataset(
    data_path=data_path, output_path=output_path, data_info_list=train_data_info_list,
    channels=channels, time_range=time_range, channels_calib=channels_calib,
    image_size=image_size, misc_images=misc_images,
    patch_size=patch_size, patch_candidates=patch_candidates,
    transform=transform, train=True
)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)

if ASOS:
    test_asos_dataset = GK2ADataset(
        data_path=data_path, output_path=output_path, data_info_list=test_asos_data_info_list,
        channels=channels, time_range=time_range, channels_calib=channels_calib,
        image_size=image_size, misc_images=misc_images,
        patch_size=patch_size, patch_candidates=None,
        transform=transform, train=False
    )
    test_asos_dataloader = DataLoader(test_asos_dataset, batch_size=batch_size//2, shuffle=False, num_workers=num_workers, drop_last=False)

== Preparing asos...



Processing asos:   1%|▌                                          | 16/1277 [00:00<00:08, 154.42it/s]

  - 20200101 AM skipped, 2019-12-31 12:00:00 not in date_table
  - 20200104 AM skipped, 2020-01-03 12:00:00 not in date_table


Processing asos:  51%|█████████████████████▍                    | 651/1277 [00:02<00:01, 432.23it/s]

  - 20211003 AM skipped, 2021-10-02 12:00:00 not in date_table


Processing asos: 100%|█████████████████████████████████████████| 1277/1277 [00:05<00:00, 246.29it/s]



  - Total 1274 image-label pairs prepared
  - train_asos_image_label_list.yaml saved

== asos dataset length synced to 1274
GK2A Dataset initialized


== Preparing asos...



Processing asos: 100%|███████████████████████████████████████████| 366/366 [00:01<00:00, 248.35it/s]



  - Total 366 image-label pairs prepared
  - test_asos_image_label_list.yaml saved

== asos dataset length synced to 366
GK2A Dataset initialized




## 3. TAU 모델 정의

In [29]:
# TAU 모델 설정
model = FrostTAU(
    seq_len=seq_len,
    in_channels=total_channels,
    img_size=image_size,
    out_channels=1,
    hid_S=64,        # Encoder/Decoder hidden channels
    N_S=4,           # Encoder/Decoder layers
    N_T=8,           # TAU blocks
    kernel_size=21,  # TAU attention kernel size
    mlp_ratio=4.0,
    drop=0.0,
    drop_path=0.1,
    # Window parameters
    window_size=window_size,
    window_stride=window_stride,
).cuda()

# 윈도우 설정 확인
print(f"Window config: size={model.window_size}, stride={model.window_stride}")
print(f"Number of windows: {model.n_windows}")

# Parameter count
n_params = sum(p.numel() for p in model.parameters())
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {n_params:,}")
print(f"Trainable parameters: {n_trainable:,}")

Window config: size=1, stride=1
Number of windows: 3
Total parameters: 1,511,130
Trainable parameters: 1,511,130


In [30]:
model

FrostTAU(
  (encoder): Encoder(
    (enc): Sequential(
      (0): ConvSC(
        (conv): Conv2d(19, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (norm): GroupNorm(2, 64, eps=1e-05, affine=True)
        (act): SiLU(inplace=True)
      )
      (1): ConvSC(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (norm): GroupNorm(2, 64, eps=1e-05, affine=True)
        (act): SiLU(inplace=True)
      )
      (2): ConvSC(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (norm): GroupNorm(2, 64, eps=1e-05, affine=True)
        (act): SiLU(inplace=True)
      )
      (3): ConvSC(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (norm): GroupNorm(2, 64, eps=1e-05, affine=True)
        (act): SiLU(inplace=True)
      )
    )
  )
  (temporal): TemporalModule(
    (blocks): ModuleList(
      (0): TAUBlock(
        (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, 

## 4. 학습

In [31]:
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def calc_measure_valid(Y_test, Y_test_hat, cutoff=0.5):
    Y_test = Y_test.ravel()
    Y_test_hat = Y_test_hat.ravel()
    
    Y_valid = (~np.isnan(Y_test))
    Y_test = Y_test[Y_valid]
    Y_test_hat = Y_test_hat[Y_valid]

    cfmat = confusion_matrix(Y_test, Y_test_hat > cutoff, labels=[0,1])
    acc = np.trace(cfmat) / np.sum(cfmat)
    csi = cfmat[1,1] / (np.sum(cfmat) - cfmat[0,0] + 1e-8)
    
    try:
        auroc = roc_auc_score(Y_test, Y_test_hat)
    except:
        auroc = 0.0

    return csi, acc, auroc

In [32]:
def train_step(model, images, labels, coords, map_dict, cls_idx=0, lambda_tau=0.1):
    """TAU 학습 스텝 (Sliding Window 지원)"""
    images = images.cuda(non_blocking=True)
    labels = labels.cuda(non_blocking=True)

    y_target = images.clone()
    pred_map, loss_dict = model(images, y_target)
    pred_map = pred_map[:, cls_idx]

    pred_vec = torch.zeros_like(labels)

    for b, (px, py) in enumerate(coords):
        for i, (x, y) in enumerate(map_dict.values()):
            if px <= x < px + patch_size and py <= y < py + patch_size:
                pred_vec[b, i] = pred_map[b, y - py, x - px]
            else:
                pred_vec[b, i] = labels[b, i]

    labels_valid = (~torch.isnan(labels)).float()
    labels = torch.nan_to_num(labels, 0.0)

    # Focal loss
    loss_focal_raw = sigmoid_focal_loss(pred_vec, labels, alpha=-1, gamma=2, reduction='none')
    loss_focal = (loss_focal_raw * labels_valid).sum() / labels_valid.sum().clamp(min=1.0)

    # Non-frost loss
    valid_any_batch = (labels_valid.sum() > 0)
    all_zero_batch = ((labels * labels_valid).sum() == 0)

    if not (valid_any_batch and all_zero_batch):
        loss_non_frost = torch.tensor(0.0, device=labels.device)
    else:
        loss_non_frost = binary_cross_entropy_with_logits(
            pred_map, torch.zeros_like(pred_map), reduction='mean'
        )

    # CSI loss
    prob = torch.sigmoid(pred_vec)
    tp = torch.sum(prob * labels * labels_valid, dim=0)
    fn = torch.sum((1 - prob) * labels * labels_valid, dim=0)
    fp = torch.sum(prob * (1 - labels) * labels_valid, dim=0)
    loss_csi = torch.mean(-torch.log(tp + 1e-10) + torch.log(tp + fn + fp + 1e-10))

    # Total loss
    loss_cls = loss_focal + loss_non_frost + loss_csi
    
    # TAU loss (MSE + DDR) - DDR은 window_size > 1일 때만 계산됨
    loss_mse = loss_dict['loss_mse']
    loss_ddr = loss_dict['loss_ddr']
    loss_tau = loss_mse + loss_ddr
    total_loss = loss_cls + lambda_tau * loss_tau

    return total_loss, {
        'cls': loss_cls.item(),
        'mse': loss_mse.item() if torch.is_tensor(loss_mse) else loss_mse,
        'ddr': loss_ddr.item() if torch.is_tensor(loss_ddr) else loss_ddr,
    }

In [33]:
# 학습 루프
lambda_tau = 0.1
accumulation_steps = 4

print(f"Batch size: {batch_size}")
print(f"Accumulation steps: {accumulation_steps}")
print(f"Effective batch size: {batch_size * accumulation_steps}")
print(f"Window config: size={window_size}, stride={window_stride}")

for time_range in time_range_list:
    print(f"\n{'='*60}")
    print(f"Running experiment with time_range={time_range}")
    print(f"{'='*60}\n")
    
    seq_len = len(time_range)
    
    # output_path 설정
    output_path = "results/"
    output_path += 'asos_' if ASOS else ''
    output_path += 'aafos_' if AAFOS else ''
    output_path += channels + '_'
    output_path += 'time' + str(time_range) + '_'
    output_path += resolution
    output_path += ('_' + postfix) if postfix else ''
    output_path += f'_ws{window_size}_stride{window_stride}'
    print(f"Output path: {output_path}")
    
    # Dataset 생성
    train_dataset = GK2ADataset(
        data_path=data_path, output_path=output_path, data_info_list=train_data_info_list,
        channels=channels, time_range=time_range, channels_calib=channels_calib,
        image_size=image_size, misc_images=misc_images,
        patch_size=patch_size, patch_candidates=patch_candidates,
        transform=transform, train=True
    )
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)

    test_asos_dataset = GK2ADataset(
        data_path=data_path, output_path=output_path, data_info_list=test_asos_data_info_list,
        channels=channels, time_range=time_range, channels_calib=channels_calib,
        image_size=image_size, misc_images=misc_images,
        patch_size=patch_size, patch_candidates=None,
        transform=transform, train=False
    )
    test_asos_dataloader = DataLoader(test_asos_dataset, batch_size=batch_size//2, shuffle=False, num_workers=num_workers, drop_last=False)

    for seed in seeds:
        if os.path.exists(f'{output_path}/{seed}/ckpt.pt'):
            print(f'Seed {seed} already done. Skipping...')
            continue

        set_seed(seed)
        os.makedirs(f'{output_path}/{seed}', exist_ok=True)

        model = FrostTAU(
            seq_len=seq_len,
            in_channels=total_channels,
            img_size=image_size,
            out_channels=1,
            hid_S=64,
            N_S=4,
            N_T=8,
            kernel_size=21,
            mlp_ratio=4.0,
            drop=0.0,
            drop_path=0.1,
            window_size=window_size,
            window_stride=window_stride,
        ).cuda()

        print(f"Model window_size={model.window_size}, n_windows={model.n_windows}")

        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=decay, gamma=lr_decay)

        results = {
            'loss': {'asos': [], 'total': []},
            'csi': {'asos': {}},
            'acc': {'asos': {}},
            'auroc': {'asos': []},
            'best_asos': {},
        }

        for cutoff_str in threshold:
            cutoff_str = str(cutoff_str)
            results['csi']['asos'][cutoff_str] = []
            results['acc']['asos'][cutoff_str] = []
            results['best_asos'][cutoff_str] = {'csi': 0.0, 'epoch': -1, 'model': None}

        for epoch in range(epochs):
            model.train()

            total_loss = 0.0
            total_loss_cls = 0.0
            total_loss_mse = 0.0
            total_loss_ddr = 0.0

            train_dataset.sync_dataset_length()

            optimizer.zero_grad()

            for batch_idx, batch in enumerate(tqdm(train_dataloader, desc=f"Epoch {epoch}")):
                images, label, coords = batch[0]
                loss, loss_breakdown = train_step(
                    model, images, label, coords, asos_map_dict,
                    cls_idx=0, lambda_tau=lambda_tau
                )

                loss = loss / accumulation_steps
                loss.backward()

                if (batch_idx + 1) % accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()

                total_loss += float(loss.item()) * accumulation_steps
                total_loss_cls += loss_breakdown['cls']
                total_loss_mse += loss_breakdown['mse']
                total_loss_ddr += loss_breakdown['ddr']

            if (batch_idx + 1) % accumulation_steps != 0:
                optimizer.step()
                optimizer.zero_grad()

            n_batches = len(train_dataloader)
            total_loss /= n_batches
            total_loss_cls /= n_batches
            total_loss_mse /= n_batches
            total_loss_ddr /= n_batches

            print(f"Epoch {epoch} - Total: {total_loss:.4f}")
            print(f"  L_cls: {total_loss_cls:.4f}, L_mse: {total_loss_mse:.4f}, L_ddr: {total_loss_ddr:.4f}")

            results['loss']['total'].append(total_loss)

            # Validation
            model.eval()
            with torch.no_grad():
                pred_vecs = []
                labels_list = []
                pred_vecs_land = []
                labels_land = []
                pred_vecs_coast = []
                labels_coast = []

                for batch in test_asos_dataloader:
                    images, label, coords = batch[0]
                    images = images.cuda(non_blocking=True)

                    pred_map = model(images)[:, 0]

                    for b, (px, py) in enumerate(coords):
                        pred_vec = []
                        label_vec = []
                        pred_vec_land = []
                        label_land = []
                        pred_vec_coast = []
                        label_coast = []

                        for i, (key, (x, y)) in enumerate(asos_map_dict.items()):
                            if px <= x < px + patch_size and py <= y < py + patch_size:
                                pred_val = pred_map[b, y - py, x - px].item()
                                label_val = label[b, i].item()

                                pred_vec.append(pred_val)
                                label_vec.append(label_val)

                                if key in asos_land_keys:
                                    pred_vec_land.append(pred_val)
                                    label_land.append(label_val)
                                elif key in asos_coast_keys:
                                    pred_vec_coast.append(pred_val)
                                    label_coast.append(label_val)

                        if pred_vec:
                            pred_vecs.append(pred_vec)
                            labels_list.append(label_vec)
                        if pred_vec_land:
                            pred_vecs_land.append(pred_vec_land)
                            labels_land.append(label_land)
                        if pred_vec_coast:
                            pred_vecs_coast.append(pred_vec_coast)
                            labels_coast.append(label_coast)

                pred_vecs = np.array([p for pv in pred_vecs for p in pv])
                labels_list = np.array([l for lv in labels_list for l in lv])
                pred_vecs_land = np.array([p for pv in pred_vecs_land for p in pv])
                labels_land = np.array([l for lv in labels_land for l in lv])
                pred_vecs_coast = np.array([p for pv in pred_vecs_coast for p in pv])
                labels_coast = np.array([l for lv in labels_coast for l in lv])

                for cutoff in threshold:
                    cutoff_str = str(cutoff)
                    csi, acc, auroc = calc_measure_valid(labels_list, pred_vecs, cutoff=cutoff)
                    results['csi']['asos'][cutoff_str].append(csi)
                    results['acc']['asos'][cutoff_str].append(acc)

                    csi_land, _, _ = calc_measure_valid(labels_land, pred_vecs_land, cutoff=cutoff)
                    csi_coast, _, _ = calc_measure_valid(labels_coast, pred_vecs_coast, cutoff=cutoff)

                    is_best = csi > results['best_asos'][cutoff_str]['csi']
                    results['auroc']['asos'].append(auroc)

                    if is_best:
                        results['best_asos'][cutoff_str]['csi'] = csi
                        results['best_asos'][cutoff_str]['epoch'] = epoch
                        results['best_asos'][cutoff_str]['model'] = copy.deepcopy(model.state_dict())

                    print(f"  ASOS (T={cutoff}): CSI {csi:.4f}, AUROC {auroc:.4f} {'*' if is_best else ''}")
                    print(f"    Land: {csi_land:.4f}, Coast: {csi_coast:.4f}")

            scheduler.step()

            torch.save({
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'results': results,
                'window_size': window_size,
                'window_stride': window_stride,
            }, f'{output_path}/{seed}/resume.pt')

        for cutoff_str in threshold:
            cutoff_str = str(cutoff_str)
            if results['best_asos'][cutoff_str]['model'] is not None:
                torch.save({
                    'model': results['best_asos'][cutoff_str]['model'],
                    'epoch': results['best_asos'][cutoff_str]['epoch'],
                    'csi': results['best_asos'][cutoff_str]['csi'],
                    'window_size': window_size,
                    'window_stride': window_stride,
                }, f'{output_path}/{seed}/ckpt.pt')
                print(f"Seed {seed} - Best CSI: {results['best_asos'][cutoff_str]['csi']:.4f}")

        if os.path.exists(f'{output_path}/{seed}/resume.pt'):
            os.remove(f'{output_path}/{seed}/resume.pt')

Batch size: 16
Accumulation steps: 4
Effective batch size: 64
Window config: size=1, stride=1

Running experiment with time_range=[-18, -15, -12]

Output path: results/asos_16ch_time[-18, -15, -12]_2km_tau_ws1_stride1
== Preparing asos...



Processing asos:   1%|▌                                          | 16/1277 [00:00<00:08, 155.09it/s]

  - 20200101 AM skipped, 2019-12-31 12:00:00 not in date_table
  - 20200104 AM skipped, 2020-01-03 12:00:00 not in date_table


Processing asos:  51%|█████████████████████▍                    | 652/1277 [00:02<00:01, 432.44it/s]

  - 20211003 AM skipped, 2021-10-02 12:00:00 not in date_table


Processing asos: 100%|█████████████████████████████████████████| 1277/1277 [00:05<00:00, 246.73it/s]



  - Total 1274 image-label pairs prepared
  - train_asos_image_label_list.yaml saved

== asos dataset length synced to 1274
GK2A Dataset initialized


== Preparing asos...



Processing asos: 100%|███████████████████████████████████████████| 366/366 [00:01<00:00, 248.26it/s]



  - Total 366 image-label pairs prepared
  - test_asos_image_label_list.yaml saved

== asos dataset length synced to 366
GK2A Dataset initialized


Seed 0 already done. Skipping...
Seed 1 already done. Skipping...
Seed 2 already done. Skipping...


## 5. 평가

In [34]:
N_LAND = 11
N_COAST = 12

for time_range in time_range_list:
    print(f"\n{'='*60}")
    print(f"Running experiment with time_range={time_range}")
    print(f"{'='*60}\n")
    
    seq_len = len(time_range)
    
    # output_path 설정
    output_path = "results/"
    output_path += 'asos_' if ASOS else ''
    output_path += 'aafos_' if AAFOS else ''
    output_path += channels + '_'
    output_path += 'time' + str(time_range) + '_'
    output_path += resolution
    output_path += ('_' + postfix) if postfix else ''
    output_path += f'_ws{window_size}_stride{window_stride}'
    print(f"Output path: {output_path}")
    
    # Test Dataset 생성
    test_asos_dataset = GK2ADataset(
        data_path=data_path, output_path=output_path, data_info_list=test_asos_data_info_list,
        channels=channels, time_range=time_range, channels_calib=channels_calib,
        image_size=image_size, misc_images=misc_images,
        patch_size=patch_size, patch_candidates=None,
        transform=transform, train=False
    )
    test_asos_dataloader = DataLoader(test_asos_dataset, batch_size=batch_size//2, shuffle=False, num_workers=num_workers, drop_last=False)

    results_dict = {}
    for cutoff in threshold:
        results_dict[str(cutoff)] = {
            'asos': {},
            'asos_land': {},
            'asos_coast': {},
        }

    for seed in seeds:
        ckpt_path = f'{output_path}/{seed}/ckpt.pt'
        if not os.path.exists(ckpt_path):
            print(f'Seed {seed} not found. Skipping...')
            continue
    
        ckpt = torch.load(ckpt_path)
        print(f'--- Seed {seed}: epoch {ckpt["epoch"]}, CSI {ckpt["csi"]:.4f} ---')
    
        # 저장된 window 설정 로드 (없으면 기본값)
        ckpt_window_size = ckpt.get('window_size', window_size)
        ckpt_window_stride = ckpt.get('window_stride', window_stride)

        model = FrostTAU(
            seq_len=seq_len,
            in_channels=total_channels,
            img_size=image_size,
            out_channels=1,
            hid_S=64,
            N_S=4,
            N_T=8,
            kernel_size=21,
            mlp_ratio=4.0,
            drop=0.0,
            drop_path=0.1,
            window_size=ckpt_window_size,
            window_stride=ckpt_window_stride,
        ).cuda()
    
        model.load_state_dict(ckpt['model'])
        model.eval()
    
        for cutoff in threshold:
            cutoff_str = str(cutoff)

            with torch.no_grad():
                asos_pred_vec_list = []
                asos_labels_list = []

                for batch in tqdm(test_asos_dataloader, desc=f"Eval Seed {seed}"):
                    images, label, coords = batch[0]
                    images = images.cuda()

                    pred_map = model(images)[:, 0]
                    pred_map = torch.sigmoid(pred_map)

                    pred_vec = []
                    for x, y in asos_map_dict.values():
                        if 0 <= y < pred_map.shape[1] and 0 <= x < pred_map.shape[2]:
                            pred_vec.append(pred_map[:, y, x])
                        else:
                            pred_vec.append(torch.zeros(pred_map.shape[0], device=pred_map.device))
                    pred_vec = torch.stack(pred_vec, dim=1)

                    asos_pred_vec_list.append(pred_vec.cpu().numpy())
                    asos_labels_list.append(label.numpy())

                pred_vecs = np.concatenate(asos_pred_vec_list, axis=0)
                labels = np.concatenate(asos_labels_list, axis=0)

                pred_vecs_land = pred_vecs[:, :N_LAND]
                pred_vecs_coast = pred_vecs[:, N_LAND:]
                labels_land = labels[:, :N_LAND]
                labels_coast = labels[:, N_LAND:]

                label_cols = [col for _, col in test_asos_data_info_list[0]['hour_col_pairs']]
                for col in label_cols:
                    indices = [i for i, data_dict in enumerate(test_asos_dataset.data_info_list[0]['data_dict_list']) 
                              if data_dict['label_col'] == col]

                    pred_vec_selected = pred_vecs[indices]
                    labels_selected = labels[indices]
                    results = calc_measure_valid(labels_selected, pred_vec_selected, cutoff=cutoff)

                    results_dict[cutoff_str]['asos'].setdefault(col, {}).setdefault('csi', []).append(results[0])
                    results_dict[cutoff_str]['asos'][col].setdefault('acc', []).append(results[1])
                    results_dict[cutoff_str]['asos'][col].setdefault('auroc', []).append(results[2])
    
                    # Land
                    pred_vec_land_selected = pred_vecs_land[indices]
                    labels_land_selected = labels_land[indices]
                    results_land = calc_measure_valid(labels_land_selected, pred_vec_land_selected, cutoff=cutoff)
                    results_dict[cutoff_str]['asos_land'].setdefault(col, {}).setdefault('csi', []).append(results_land[0])
                    results_dict[cutoff_str]['asos_land'][col].setdefault('acc', []).append(results_land[1])
                    results_dict[cutoff_str]['asos_land'][col].setdefault('auroc', []).append(results_land[2])

                    # Coast
                    pred_vec_coast_selected = pred_vecs_coast[indices]
                    labels_coast_selected = labels_coast[indices]
                    results_coast = calc_measure_valid(labels_coast_selected, pred_vec_coast_selected, cutoff=cutoff)
                    results_dict[cutoff_str]['asos_coast'].setdefault(col, {}).setdefault('csi', []).append(results_coast[0])
                    results_dict[cutoff_str]['asos_coast'][col].setdefault('acc', []).append(results_coast[1])
                    results_dict[cutoff_str]['asos_coast'][col].setdefault('auroc', []).append(results_coast[2])

                    print(f"  {col}: CSI={results[0]:.4f}, Land={results_land[0]:.4f}, Coast={results_coast[0]:.4f}")

    # 결과 집계 (모든 시드 평가 후 - 들여쓰기 수정됨)
    all_results_rows = []

    for cutoff_str, type_dict_for_thr in results_dict.items():
        for data_type, type_dict in type_dict_for_thr.items():
            if not type_dict:
                continue

            csi_mean_list = []
            acc_mean_list = []
            auroc_mean_list = []

            for label_col, metrics_dict in type_dict.items():
                if not metrics_dict.get('csi'):
                    continue

                csi_mean = np.mean(metrics_dict['csi'])
                csi_mean_list.append(csi_mean)
                
                acc_mean = np.mean(metrics_dict['acc']) if metrics_dict.get('acc') else None
                auroc_mean = np.mean(metrics_dict['auroc']) if metrics_dict.get('auroc') else None
            
                if acc_mean is not None:
                    acc_mean_list.append(acc_mean)
                if auroc_mean is not None:
                    auroc_mean_list.append(auroc_mean)

                row = {
                    'type': data_type,
                    'threshold': cutoff_str,
                    'label': label_col,
                    'csi': f'{csi_mean:.4f}',
                }
                if acc_mean is not None:
                    row['acc'] = f'{acc_mean:.4f}'
                if auroc_mean is not None:
                    row['auroc'] = f'{auroc_mean:.4f}'
            
                all_results_rows.append(row)

            if csi_mean_list:
                row = {
                    'type': data_type,
                    'threshold': cutoff_str,
                    'label': 'mean',
                    'csi': f'{np.mean(csi_mean_list):.4f}',
                }
                if acc_mean_list:
                    row['acc'] = f'{np.mean(acc_mean_list):.4f}'
                if auroc_mean_list:
                    row['auroc'] = f'{np.mean(auroc_mean_list):.4f}'
            
                all_results_rows.append(row)

    results_df = pd.DataFrame(all_results_rows)
    print("\n=== Final Results ===")
    print(results_df.to_string(index=False))

    output_csv_path = f'{output_path}/final_results_tau.csv'
    results_df.to_csv(output_csv_path, index=False)
    print(f"\nSaved to {output_csv_path}")


Running experiment with time_range=[-18, -15, -12]

Output path: results/asos_16ch_time[-18, -15, -12]_2km_tau_ws1_stride1
== Preparing asos...



Processing asos: 100%|███████████████████████████████████████████| 366/366 [00:01<00:00, 247.35it/s]



  - Total 366 image-label pairs prepared
  - test_asos_image_label_list.yaml saved

== asos dataset length synced to 366
GK2A Dataset initialized


--- Seed 0: epoch 17, CSI 0.5117 ---


Eval Seed 0: 100%|██████████| 46/46 [00:11<00:00,  3.92it/s]


  AM: CSI=0.5204, Land=0.5817, Coast=0.2657
--- Seed 1: epoch 4, CSI 0.5266 ---


Eval Seed 1: 100%|██████████| 46/46 [00:11<00:00,  3.92it/s]


  AM: CSI=0.4894, Land=0.5439, Coast=0.2715
--- Seed 2: epoch 18, CSI 0.5238 ---


Eval Seed 2: 100%|██████████| 46/46 [00:11<00:00,  3.89it/s]

  AM: CSI=0.5027, Land=0.5543, Coast=0.2837

=== Final Results ===
      type threshold label    csi    acc  auroc
      asos      0.25    AM 0.5042 0.9138 0.9099
      asos      0.25  mean 0.5042 0.9138 0.9099
 asos_land      0.25    AM 0.5600 0.8713 0.9522
 asos_land      0.25  mean 0.5600 0.8713 0.9522
asos_coast      0.25    AM 0.2736 0.9529 0.8276
asos_coast      0.25  mean 0.2736 0.9529 0.8276

Saved to results/asos_16ch_time[-18, -15, -12]_2km_tau_ws1_stride1/final_results_tau.csv



