In [1]:
"""
TAU model training script.
Based on the existing URP_baseline.ipynb; runnable as-is.
"""

import os
import math
import numpy as np
import cv2
import netCDF4 as nc

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.ops import sigmoid_focal_loss
from torch.nn.functional import binary_cross_entropy_with_logits
from sklearn.metrics import roc_auc_score, confusion_matrix
from tqdm import tqdm
from pathlib import Path

from utils import *
from tau_frost_model import get_tau_model

from torch.cuda.amp import autocast, GradScaler

In [None]:
import torch

In [2]:
# ====== Experiment settings ====== #
# One independent training run is performed for every entry in `seeds`.
seeds = [0]

# Let cuDNN autotune kernels for the fixed input size; favours speed over bit-exact reproducibility.
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

# Label sources — at least one must be enabled.
ASOS = True
AAFOS = False
assert ASOS or AAFOS, "At least one of ASOS or AAFOS must be True."

# Input configuration
channels = '16ch'
time_range = [-18, -15, -12]  # input time offsets stacked for TAU (3 steps; the baseline used only [-12])
n_times = len(time_range)
resolution = '2km'
postfix = 'TAU_early_fusion'

# TAU hyper-parameters (passed to get_tau_model)
tau_n_heads = 8
tau_n_layers = 2

# Output directory name encodes label sources, channel set, time offsets, resolution and postfix.
output_path = ''.join([
    'results/',
    'asos_' if ASOS else '',
    'aafos_' if AAFOS else '',
    channels + '_',
    'time' + str(time_range) + '_',
    resolution,
    ('_' + postfix) if postfix != '' else '',
])
print(f"Output path name: {output_path}")

# Loss weights: when both sources are enabled, ASOS weighs `asos_aafos_ratio` times AAFOS.
asos_aafos_ratio = 5.0
if ASOS:
    asos_weight = asos_aafos_ratio / (asos_aafos_ratio + 1.0 * AAFOS)
else:
    asos_weight = 0.0
if AAFOS:
    aafos_weight = 1.0 / (asos_aafos_ratio * ASOS + 1.0)
else:
    aafos_weight = 0.0
print(f"ASOS weight: {asos_weight:.2f}, AAFOS weight: {aafos_weight:.2f}")

# Miscellaneous options
latlon = False          # append lat/lon maps as two extra input channels
central_patch = False   # not referenced elsewhere in this notebook
use_patch = False       # not referenced elsewhere in this notebook

Output path name: results/asos_16ch_time[-18, -15, -12]_2km_TAU_early_fusion
ASOS weight: 1.00, AAFOS weight: 0.00


In [3]:
# ====== Channel settings ====== #
# Only the 16-channel configuration is supported.
if channels != '16ch':
    raise ValueError("Invalid channels.")

# Satellite channel names in input order; the calibration list holds the same names.
channels_name = [
    'vi004', 'vi005', 'vi006', 'vi008', 'nr013', 'nr016', 'sw038',
    'wv063', 'wv069', 'wv073', 'ir087', 'ir096', 'ir105', 'ir112', 'ir123', 'ir133',
]
channels_calib = list(channels_name)

# Per-channel normalisation statistics, listed in the same order as `channels_name`.
channels_mean = [
    1.1912e-01, 1.1464e-01, 1.0734e-01, 1.2504e-01, 5.4983e-02, 9.0381e-02,
    2.7813e+02, 2.3720e+02, 2.4464e+02, 2.5130e+02, 2.6948e+02, 2.4890e+02,
    2.7121e+02, 2.7071e+02, 2.6886e+02, 2.5737e+02,
]
channels_std = [
    0.1306, 0.1303, 0.1306, 0.1501, 0.0268, 0.0838, 15.8211, 6.1468,
    7.8054, 9.3251, 16.4265, 9.6150, 17.2518, 17.6064, 17.0090, 12.5026,
]

n_channels = len(channels_name)

In [None]:
# ====== Data settings ====== #
# ASOS station IDs used as label columns. Labels are paired with stations by
# position (see `train`), so this order must match `asos_map_dict` — checked below.
asos_label_keys = ['93', '108', '112', '119', '131', '133', '136', '143', '146', '156', '177',
                   '102', '104', '115', '138', '152', '155', '159', '165', '168', '169', '184', '189']

train_data_info_list = []
if ASOS:
    train_data_info_list.append({
        'label_type': 'asos',
        'start_date_str': '20200101',
        'end_date_str': '20230630',
        'hour_col_pairs': [(6, 'AM')],
        'label_keys': list(asos_label_keys),
    })
# NOTE(review): no AAFOS entry is ever added here, so AAFOS=True is not wired up in this notebook.

test_asos_data_info_list = [
    {
        'label_type': 'asos',
        'start_date_str': '20230701',
        'end_date_str': '20240630',
        'hour_col_pairs': [(6, 'AM')],
        'label_keys': list(asos_label_keys),
    },
] if ASOS else None

# Image geometry
origin_size = 900 if resolution == '2km' else 1800  # grid size given to coord_to_map and used for lat/lon
image_size = 512 if resolution == '2km' else 1024   # image size passed to the dataset
patch_size = 512                                    # training patch size

# Data location; set the URP_DATA_PATH environment variable to use another machine's copy.
data_path = os.environ.get('URP_DATA_PATH', '/home/dm4/repo/data/kma_data/date_kst_URP')

misc_channels = {
    'elevation': 'elevation_1km_3600.npy',
    'vegetation': 'vegetation_1km_3600.npy',
    'watermap': 'watermap_1km_avg_3600.npy'
}
lat_lon_path = 'assets/gk2a_ami_ko010lc_latlon.nc'

# Training settings
batch_size = 64
num_workers = 8
epochs = 25
lr = 1e-3
decay = [10, 20]      # MultiStepLR milestones (in epochs)
lr_decay = 0.1        # MultiStepLR gamma
weight_decay = 1e-5
threshold = [0.25]    # probability cut-offs used for CSI / accuracy

# TAU model settings. tau_n_heads / tau_n_layers are defined once, in the
# experiment-settings cell, so editing them there actually takes effect.
fusion_type = 'early'  # 'early' or 'late'
# NOTE(review): fusion_type is only printed; it is not passed to get_tau_model — confirm the model's fusion mode.

# ====== Derived values ====== #
asos_x_base, asos_y_base, asos_image_size = get_crop_base(image_size, label_type='asos')
aafos_x_base, aafos_y_base, aafos_image_size = get_crop_base(image_size, label_type='aafos')
# AAFOS crop origin expressed relative to the ASOS crop origin.
aafos_x_base -= asos_x_base
aafos_y_base -= asos_y_base

total_channels = len(channels_name) * len(time_range) + len(misc_channels.keys())
total_channels += 2 if latlon else 0
print(f"Total channels: {total_channels}")

# ====== Station pixel coordinates (relative to the ASOS crop) ====== #
asos_land_map = {k: coord_to_map(*v, origin_size) for k, v in ASOS_LAND_COORD.items()}
asos_land_map = {k: (v[0] - asos_x_base, v[1] - asos_y_base) for k, v in asos_land_map.items()}
asos_coast_map = {k: coord_to_map(*v, origin_size) for k, v in ASOS_COAST_COORD.items()}
asos_coast_map = {k: (v[0] - asos_x_base, v[1] - asos_y_base) for k, v in asos_coast_map.items()}
asos_map_dict = {**asos_land_map, **asos_coast_map}  # land stations first, then coastal
print(f"ASOS stations: {list(asos_map_dict.keys())}")

# Labels are matched to stations by column position, so the two orderings must agree.
assert [str(k) for k in asos_map_dict] == asos_label_keys, \
    "ASOS label_keys order does not match asos_map_dict station order."


Total channels: 51
ASOS stations: [93, 108, 112, 119, 131, 133, 136, 143, 146, 156, 177, 102, 104, 115, 138, 152, 155, 159, 165, 168, 169, 184, 189]


In [5]:
# ====== Mean/Std settings ====== #
# Tile the per-channel statistics once per time step (T consecutive blocks of C
# channels). Slicing to n_channels first keeps this cell idempotent: re-running
# it no longer keeps growing the lists.
channels_mean = channels_mean[:n_channels] * len(time_range)
channels_std = channels_std[:n_channels] * len(time_range)

# ====== Load misc (static) images ====== #
misc_images = []
for misc_channel, misc_path in misc_channels.items():
    # NOTE(review): allow_pickle=True lets np.load run code embedded in the file; only load trusted assets.
    misc_image = np.load(f'assets/misc_channels/{misc_path}', allow_pickle=True)
    # NOTE(review): the whole array is resized to the crop size without cropping, whereas lat/lon
    # below are resized to origin_size and cropped to the ASOS window first — confirm these
    # assets are already restricted to the ASOS region, otherwise they are spatially misaligned.
    misc_image = cv2.resize(misc_image, (asos_image_size, asos_image_size), interpolation=cv2.INTER_CUBIC)
    misc_images.append(misc_image)
misc_images = torch.tensor(np.stack(misc_images, axis=0), dtype=torch.float32)

if latlon:
    # The context manager closes the netCDF file once the arrays are read into memory.
    with nc.Dataset(lat_lon_path) as lat_lon_data:
        lat = lat_lon_data['lat'][:].data
        lon = lat_lon_data['lon'][:].data
    # Resize to the full grid, then cut out the ASOS window.
    lat = cv2.resize(lat, (origin_size, origin_size), interpolation=cv2.INTER_CUBIC)
    lon = cv2.resize(lon, (origin_size, origin_size), interpolation=cv2.INTER_CUBIC)
    asos_lat = lat[asos_y_base:asos_y_base + asos_image_size, asos_x_base:asos_x_base + asos_image_size]
    asos_lon = lon[asos_y_base:asos_y_base + asos_image_size, asos_x_base:asos_x_base + asos_image_size]

    lat_image = cv2.resize(asos_lat, (asos_image_size, asos_image_size), interpolation=cv2.INTER_CUBIC)
    lon_image = cv2.resize(asos_lon, (asos_image_size, asos_image_size), interpolation=cv2.INTER_CUBIC)
    lat_image = torch.tensor(lat_image, dtype=torch.float32).unsqueeze(0)
    lon_image = torch.tensor(lon_image, dtype=torch.float32).unsqueeze(0)
    misc_images = torch.cat([misc_images, lat_image, lon_image], dim=0)

print(f'MISC images shape: {misc_images.shape}')

# ====== n_misc ====== #
n_misc = len(misc_channels) + (2 if latlon else 0)

# Normalisation statistics for the misc channels, computed from the images themselves.
misc_mean = misc_images.mean(dim=[1, 2], keepdim=True)
misc_std = misc_images.std(dim=[1, 2], keepdim=True)

# reshape(-1) instead of squeeze(): with a single misc channel squeeze() yields a 0-d
# tensor whose .tolist() is a bare float, which cannot be concatenated to a list.
total_mean = channels_mean + misc_mean.reshape(-1).tolist()
total_std = channels_std + misc_std.reshape(-1).tolist()

print(f'Total mean length: {len(total_mean)}')
print(f'Total std length: {len(total_std)}')
assert len(total_mean) == len(total_std) == total_channels, \
    "normalisation statistics do not match total_channels"

MISC images shape: torch.Size([3, 512, 512])
Total mean length: 51
Total std length: 51


In [6]:
# ====== Patch candidates ====== #
def _central_half_corners(stations, full_size, patch):
    """Return an (K, 2) array of [x, y] patch corners that put some station in the patch's central half.

    For a station at (sx, sy), corners in [s - 3*patch//4, s - patch//4) on each axis qualify;
    both bounds are clipped to [0, full_size - patch + 1].
    """
    upper = full_size - patch + 1
    corner_mask = np.zeros([full_size, full_size], dtype=np.uint8)
    for sx, sy in stations:
        lo_y, hi_y = np.clip([sy - 3 * patch // 4, sy - patch // 4], 0, upper)
        lo_x, hi_x = np.clip([sx - 3 * patch // 4, sx - patch // 4], 0, upper)
        corner_mask[lo_y:hi_y, lo_x:hi_x] = 1
    # argwhere yields (row, col) = (y, x); swap columns to [x, y].
    return np.argwhere(corner_mask == 1)[:, [1, 0]]


asos_patch_candidate = _central_half_corners(asos_map_dict.values(), image_size, patch_size)
print(f'ASOS patch candidates shape: {asos_patch_candidate.shape}')

patch_candidates = {'asos': asos_patch_candidate}

ASOS patch candidates shape: (1, 2)


In [7]:
# ====== Transform ====== #
# Per-channel standardisation of the stacked (time x channel + misc) input.
_normalize = transforms.Normalize(mean=total_mean, std=total_std)
transform = transforms.Compose([_normalize])

In [8]:
def calc_measure_valid(Y_test, Y_test_hat, cutoff=0.5):
    """Compute CSI, accuracy and AUROC over the entries whose label is not NaN.

    Args:
        Y_test: array of 0/1 labels; NaN marks a missing label and is ignored.
        Y_test_hat: predicted probabilities, same number of elements as ``Y_test``.
        cutoff: a prediction counts as positive when it is strictly greater than this.

    Returns:
        (csi, acc, auroc), where CSI = hits / (hits + misses + false alarms).
        AUROC is 0.0 when it is undefined (e.g. only one class among the valid labels).
        With no valid labels at all, returns (0.0, nan, 0.0).
    """
    Y_test = np.ravel(Y_test)
    Y_test_hat = np.ravel(Y_test_hat)

    valid = ~np.isnan(Y_test)
    Y_test = Y_test[valid]
    Y_test_hat = Y_test_hat[valid]
    if Y_test.size == 0:
        # Nothing to score: accuracy is undefined (0/0); CSI and AUROC fall back to 0.0.
        return 0.0, float('nan'), 0.0

    cfmat = confusion_matrix(Y_test, Y_test_hat > cutoff, labels=[0, 1])
    acc = np.trace(cfmat) / np.sum(cfmat)
    # cfmat[1, 1] = hits; everything except true negatives = hits + misses + false alarms.
    csi = cfmat[1, 1] / (np.sum(cfmat) - cfmat[0, 0] + 1e-8)

    try:
        auroc = roc_auc_score(Y_test, Y_test_hat)
    except ValueError:
        # roc_auc_score raises ValueError when only one class is present in Y_test.
        auroc = 0.0

    return csi, acc, auroc

In [9]:
# ====== Training step for TAU (keeps the baseline loss terms) ====== #
def train(model, images, labels, coords, map_dict, cls_num=0):
    """Compute the training loss for one batch, using the same loss terms as the baseline.

    Args:
        model: model whose ``forward(images)`` returns (B, n_classes, H, W) logits.
        images: (B, T*C + misc, H, W) input; moved to the model's device here.
        labels: (B, N_stations) per-station 0/1 labels; NaN marks a missing label.
        coords: (B, 2) top-left patch corners [px, py], in the same pixel frame as ``map_dict``.
        map_dict: {station_id: (x, y)} station pixel positions; value order matches label columns.
        cls_num: output channel to train (ASOS=0, AAFOS=1).

    Returns:
        Scalar loss = focal loss + non-frost loss + CSI loss.

    Stations outside a sample's patch and stations with a NaN label are excluded
    from every loss term through the validity mask.
    """
    device = next(model.parameters()).device
    images = images.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)

    # Forward pass on the (B, T*C + misc, H, W) input; keep the requested class map.
    pred_map = model.forward(images)[:, cls_num]  # (B, H, W) logits
    n_batch, map_h, map_w = pred_map.shape

    # Each station's offset inside each sample's patch: (B, N) local x / y.
    stations = torch.as_tensor(list(map_dict.values()), dtype=torch.long, device=device).reshape(-1, 2)
    corners = torch.stack([torch.as_tensor(c) for c in coords]).to(device).long().reshape(-1, 2)
    local_x = stations[:, 0].unsqueeze(0) - corners[:, 0].unsqueeze(1)
    local_y = stations[:, 1].unsqueeze(0) - corners[:, 1].unsqueeze(1)
    in_patch = (local_x >= 0) & (local_x < map_w) & (local_y >= 0) & (local_y < map_h)

    # Gather all station logits in one vectorised lookup. Out-of-patch indices are
    # clamped only to stay in bounds; those entries are zeroed and masked out below.
    batch_idx = torch.arange(n_batch, device=device).unsqueeze(1).expand_as(local_x)
    pred_vec = pred_map[batch_idx, local_y.clamp(0, map_h - 1), local_x.clamp(0, map_w - 1)]
    pred_vec = torch.where(in_patch, pred_vec, torch.zeros_like(pred_vec)).to(labels.dtype)

    # Valid = label present AND station inside the patch. Out-of-patch stations must
    # not reach the loss (copying the label into the logit still added loss and
    # turned NaN labels into a NaN loss).
    labels_valid = ((~torch.isnan(labels)) & in_patch).float()
    labels = torch.nan_to_num(labels, 0.0)

    # 1. Focal loss, averaged over valid entries only.
    loss_focal_raw = sigmoid_focal_loss(pred_vec, labels, alpha=-1, gamma=2, reduction='none')
    loss_focal = (loss_focal_raw * labels_valid).sum() / labels_valid.sum().clamp(min=1.0)

    # 2. Non-frost loss: when the batch has valid labels and all of them are 0,
    #    apply BCE against an all-zero target over the whole predicted map.
    valid_any_batch = labels_valid.sum() > 0
    all_zero_batch = (labels * labels_valid).sum() == 0
    if valid_any_batch and all_zero_batch:
        loss_non_frost = binary_cross_entropy_with_logits(
            pred_map, torch.zeros_like(pred_map), reduction='mean'
        )
    else:
        loss_non_frost = torch.tensor(0.0, device=labels.device)

    # 3. CSI loss: differentiable CSI approximation, computed per station over the batch.
    pred_prob = torch.sigmoid(pred_vec)
    tp = torch.sum(pred_prob * labels * labels_valid, dim=0)
    fn = torch.sum((1 - pred_prob) * labels * labels_valid, dim=0)
    fp = torch.sum(pred_prob * (1 - labels) * labels_valid, dim=0)

    # CSI = TP / (TP + FN + FP); minimise -log(CSI).
    loss_csi = torch.mean(-torch.log(tp + 1e-10) + torch.log(tp + fn + fp + 1e-10))

    return loss_focal + loss_non_frost + loss_csi

In [10]:
# ====== Main training loop ====== #
def _snapshot_state_dict(model):
    """Return a detached CPU copy of ``model.state_dict()``.

    ``state_dict()`` returns references to the live tensors, so storing it as-is lets
    later optimizer steps silently overwrite the saved "best" weights.
    """
    return {
        k: v.detach().to('cpu', copy=True) if torch.is_tensor(v) else v
        for k, v in model.state_dict().items()
    }


def _predict_at_stations(model, dataloader, map_dict):
    """Run patch-wise inference over ``dataloader`` and read out per-station probabilities.

    Returns:
        (pred_vecs, labels) as numpy arrays concatenated over batches. pred_vecs is
        (n_samples, n_stations) sigmoid probabilities in ``map_dict`` order; labels are
        the dataset's label arrays (assumed to use the same column order).
    """
    station_pixels = list(map_dict.values())
    pred_chunks, label_chunks = [], []
    for batch in dataloader:
        images, label, _ = batch[0]
        images = images.cuda(non_blocking=True)

        # autocast during evaluation too, for faster inference
        with torch.cuda.amp.autocast():
            # forward_by_patch takes the full (B, T*C + misc, H, W) input
            pred_map = model.forward_by_patch(images, patch_size=patch_size)[:, 0]
            pred_map = torch.sigmoid(pred_map)

        pred_vec = torch.stack([pred_map[:, y, x] for x, y in station_pixels], dim=1)
        # Upcast before leaving the GPU so metrics are not computed on float16 arrays.
        pred_chunks.append(pred_vec.float().cpu().numpy())
        label_chunks.append(label.numpy())
    return np.concatenate(pred_chunks, axis=0), np.concatenate(label_chunks, axis=0)


if __name__ == "__main__":

    print(f"\n{'='*60}")
    print("TAU Frost Forecasting Model Training")
    print(f"{'='*60}")
    print(f"Fusion type: {fusion_type}")
    print(f"Time range: {time_range}")
    print(f"N times: {n_times}, N channels: {n_channels}")
    print(f"{'='*60}\n")

    # ====== Datasets ====== #
    print("Creating datasets...")
    train_dataset = GK2ADataset(
        data_path=data_path,
        output_path=output_path,
        data_info_list=train_data_info_list,
        channels=channels,
        time_range=time_range,
        channels_calib=channels_calib,
        image_size=image_size,
        misc_images=misc_images,
        patch_size=patch_size,
        patch_candidates=patch_candidates,
        transform=transform,
        train=True
    )
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
        pin_memory=True  # faster host-to-GPU copies
    )

    if ASOS:
        test_asos_dataset = GK2ADataset(
            data_path=data_path,
            output_path=output_path,
            data_info_list=test_asos_data_info_list,
            channels=channels,
            time_range=time_range,
            channels_calib=channels_calib,
            image_size=image_size,
            misc_images=misc_images,
            patch_size=patch_size,
            patch_candidates=None,
            transform=transform,
            train=False
        )
        test_asos_dataloader = DataLoader(
            test_asos_dataset,
            batch_size=batch_size // 2,
            shuffle=False,
            num_workers=num_workers,
            drop_last=False,
            pin_memory=True
        )
    # NOTE(review): there is no AAFOS test set, and `aafos_map_dict` (used when AAFOS=True)
    # is not defined in any visible cell — AAFOS training is not fully wired up.

    # ====== Per-seed training ====== #
    for seed in seeds:
        print(f'\n{"="*60}')
        print(f'Seed {seed}')
        print(f'{"="*60}')

        seed_dir = f'{output_path}/{seed}'
        if os.path.exists(f'{seed_dir}/ckpt.pt'):
            print(f'Seed {seed} already done. Skipping...')
            continue
        # Create the directory up front so the final ckpt save works even if no epoch runs.
        os.makedirs(seed_dir, exist_ok=True)

        torch.manual_seed(seed)
        np.random.seed(seed)

        model = get_tau_model(
            total_channels=total_channels,
            patch_size=patch_size,
            n_channels=n_channels,
            n_misc=n_misc,
            n_times=n_times,
            num_classes=2,
            tau_n_heads=tau_n_heads,
            tau_n_layers=tau_n_layers
        )
        model = model.cuda()

        print(f"모델 파라미터 수: {sum(p.numel() for p in model.parameters()):,}")

        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=decay, gamma=lr_decay)

        # AMP gradient scaler
        scaler = torch.cuda.amp.GradScaler()

        start_epoch = 0

        # Per-epoch metrics and best checkpoints (metrics stored as plain Python floats).
        results = dict(
            loss={'asos': [], 'aafos': [], 'total': []},
            csi={'asos': {}, 'aafos': {}},
            acc={'asos': {}, 'aafos': {}},
            auroc={'asos': [], 'aafos': []},
            best_asos={},
            best_aafos={},
            best_mean={},
        )

        for cutoff in threshold:
            cutoff_str = str(cutoff)
            results['csi']['asos'][cutoff_str] = []
            results['acc']['asos'][cutoff_str] = []
            results['best_asos'][cutoff_str] = {'csi': 0.0, 'epoch': -1, 'model': None}

            results['csi']['aafos'][cutoff_str] = []
            results['acc']['aafos'][cutoff_str] = []
            results['best_aafos'][cutoff_str] = {'csi': 0.0, 'epoch': -1, 'model': None}

            results['best_mean'][cutoff_str] = {'csi': 0.0, 'epoch': -1, 'model': None}

        # Resume from an interrupted run if a resume file exists.
        resume_path = f'{seed_dir}/resume.pt'
        if os.path.exists(resume_path):
            # Written by this script (trusted). weights_only=False because older resume
            # files hold NumPy scalars in `results`, which the weights-only loader rejects.
            resume = torch.load(resume_path, weights_only=False)
            model.load_state_dict(resume['model'])
            optimizer.load_state_dict(resume['optimizer'])
            scheduler.load_state_dict(resume['scheduler'])
            if 'scaler' in resume:
                scaler.load_state_dict(resume['scaler'])
            start_epoch = resume['epoch'] + 1
            results = resume['results']
            print(f'Resuming from epoch {start_epoch}...')

        # ====== Epoch loop ====== #
        for epoch in range(start_epoch, epochs):
            model.train()

            total_loss_asos = 0.0
            total_loss_aafos = 0.0
            total_loss = 0.0

            # Re-sync dataset lengths at the start of each epoch.
            train_dataset.sync_dataset_length()

            pbar = tqdm(train_dataloader, desc=f'Epoch {epoch:2d}')
            for batch in pbar:
                optimizer.zero_grad(set_to_none=True)

                with torch.cuda.amp.autocast():
                    if ASOS:
                        images, label, coords = batch[0]
                        loss_asos = train(model, images, label, coords, asos_map_dict, cls_num=0)
                    else:
                        loss_asos = torch.tensor(0.0).cuda()

                    if AAFOS:
                        images, label, coords = batch[1] if ASOS else batch[0]
                        loss_aafos = train(model, images, label, coords, aafos_map_dict, cls_num=1)
                    else:
                        loss_aafos = torch.tensor(0.0).cuda()

                    loss = asos_weight * loss_asos + aafos_weight * loss_aafos

                # Scaled backward pass and optimizer step.
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                total_loss_asos += loss_asos.item()
                total_loss_aafos += loss_aafos.item()
                total_loss += loss.item()

                pbar.set_postfix({
                    'loss': f'{loss.item():.4f}',
                    'asos': f'{loss_asos.item():.4f}'
                })

            total_loss_asos /= len(train_dataloader)
            total_loss_aafos /= len(train_dataloader)
            total_loss /= len(train_dataloader)

            results['loss']['asos'].append(total_loss_asos)
            results['loss']['aafos'].append(total_loss_aafos)
            results['loss']['total'].append(total_loss)

            print(f'Epoch {epoch:2d} - Total: {total_loss:.4f}, ASOS: {total_loss_asos:.4f}, AAFOS: {total_loss_aafos:.4f}')
            scheduler.step()

            # ====== Evaluation ====== #
            model.eval()
            with torch.no_grad():
                if ASOS:
                    pred_vecs, labels = _predict_at_stations(model, test_asos_dataloader, asos_map_dict)
                    # asos_map_dict lists land stations first, then coastal ones.
                    n_land = len(asos_land_map)

                    for cutoff in threshold:
                        csi, acc, auroc = (float(v) for v in calc_measure_valid(labels, pred_vecs, cutoff=cutoff))
                        cutoff_str = str(cutoff)

                        results['csi']['asos'][cutoff_str].append(csi)
                        results['acc']['asos'][cutoff_str].append(acc)

                        is_best = csi > results['best_asos'][cutoff_str]['csi']
                        print(f'\t - ASOS (T={cutoff:.2f}): CSI {csi:.4f}, AUROC {auroc:.4f} {"*" if is_best else ""}')

                        if is_best:
                            results['best_asos'][cutoff_str]['csi'] = csi
                            results['best_asos'][cutoff_str]['epoch'] = epoch
                            # Snapshot (not a live reference) so later epochs cannot overwrite it.
                            results['best_asos'][cutoff_str]['model'] = _snapshot_state_dict(model)

                        # Separate land / coast evaluation
                        asos_land_result = calc_measure_valid(labels[:, :n_land], pred_vecs[:, :n_land], cutoff=cutoff)
                        asos_coast_result = calc_measure_valid(labels[:, n_land:], pred_vecs[:, n_land:], cutoff=cutoff)
                        print(f'\t   - Land:  CSI {asos_land_result[0]:.4f}, AUROC {asos_land_result[2]:.4f}')
                        print(f'\t   - Coast: CSI {asos_coast_result[0]:.4f}, AUROC {asos_coast_result[2]:.4f}')

                    # AUROC does not depend on the cut-off; record it once per epoch.
                    results['auroc']['asos'].append(auroc)

            print()

            # Save resume state (including the AMP scaler).
            resume = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'scaler': scaler.state_dict(),
                'epoch': epoch,
                'results': results
            }
            torch.save(resume, resume_path)

        # Final save; the resume file is no longer needed.
        torch.save(results, f'{seed_dir}/ckpt.pt')
        if os.path.exists(resume_path):
            os.remove(resume_path)

        print(f'\nSeed {seed} completed!')
        if ASOS:
            for cutoff_str, best in results['best_asos'].items():
                print(f'Best ASOS CSI (T={cutoff_str}): {best["csi"]:.4f} at epoch {best["epoch"]}')

    print(f"\n{'='*60}")
    print("All seeds completed!")
    print(f"{'='*60}")


TAU Frost Forecasting Model Training
Fusion type: early
Time range: [-18, -15, -12]
N times: 3, N channels: 16

Creating datasets...
== Preparing asos...



Processing asos:   1%|▍                                          | 12/1277 [00:00<00:11, 113.98it/s]

  - 20200101 AM skipped, 2019-12-31 12:00:00 not in date_table
  - 20200104 AM skipped, 2020-01-03 12:00:00 not in date_table


Processing asos:  51%|█████████████████████▎                    | 648/1277 [00:03<00:01, 399.62it/s]

  - 20211003 AM skipped, 2021-10-02 12:00:00 not in date_table


Processing asos: 100%|█████████████████████████████████████████| 1277/1277 [00:06<00:00, 182.52it/s]



  - Total 1274 image-label pairs prepared
  - train_asos_image_label_list.yaml saved

== asos dataset length synced to 1274
GK2A Dataset initialized


== Preparing asos...



Processing asos: 100%|███████████████████████████████████████████| 366/366 [00:02<00:00, 182.75it/s]



  - Total 366 image-label pairs prepared
  - test_asos_image_label_list.yaml saved

== asos dataset length synced to 366
GK2A Dataset initialized



Seed 0
모델 파라미터 수: 12,305,490
== asos dataset length synced to 1274


Epoch  0: 100%|██████████| 637/637 [20:55<00:00,  1.97s/it, loss=0.9736, asos=0.9736]  


Epoch  0 - Total: 7.2798, ASOS: 7.2798, AAFOS: 0.0000
	 - ASOS (T=0.25): CSI 0.0549, AUROC 0.5285 *
	   - Land:  CSI 0.0700, AUROC 0.5208
	   - Coast: CSI 0.0000, AUROC 0.5000

== asos dataset length synced to 1274


Epoch  1: 100%|██████████| 637/637 [20:59<00:00,  1.98s/it, loss=4.5024, asos=4.5024]  


Epoch  1 - Total: 6.6243, ASOS: 6.6243, AAFOS: 0.0000
	 - ASOS (T=0.25): CSI 0.0202, AUROC 0.7021 
	   - Land:  CSI 0.0253, AUROC 0.6591
	   - Coast: CSI 0.0000, AUROC 0.5685

== asos dataset length synced to 1274


Epoch  2: 100%|██████████| 637/637 [20:15<00:00,  1.91s/it, loss=2.2354, asos=2.2354]  


Epoch  2 - Total: 6.0050, ASOS: 6.0050, AAFOS: 0.0000
	 - ASOS (T=0.25): CSI 0.2764, AUROC 0.7795 *
	   - Land:  CSI 0.3500, AUROC 0.8332
	   - Coast: CSI 0.0000, AUROC 0.5609

== asos dataset length synced to 1274


Epoch  3: 100%|██████████| 637/637 [23:45<00:00,  2.24s/it, loss=0.0070, asos=0.0070]  


Epoch  3 - Total: 5.1003, ASOS: 5.1003, AAFOS: 0.0000
	 - ASOS (T=0.25): CSI 0.4071, AUROC 0.7947 *
	   - Land:  CSI 0.4964, AUROC 0.8556
	   - Coast: CSI 0.0000, AUROC 0.5389

== asos dataset length synced to 1274


Epoch  4: 100%|██████████| 637/637 [23:28<00:00,  2.21s/it, loss=5.7848, asos=5.7848]  


Epoch  4 - Total: 4.7282, ASOS: 4.7282, AAFOS: 0.0000
	 - ASOS (T=0.25): CSI 0.3867, AUROC 0.8346 
	   - Land:  CSI 0.4778, AUROC 0.9179
	   - Coast: CSI 0.0000, AUROC 0.5064

== asos dataset length synced to 1274


Epoch  5: 100%|██████████| 637/637 [23:28<00:00,  2.21s/it, loss=0.0535, asos=0.0535]  


Epoch  5 - Total: 4.4513, ASOS: 4.4513, AAFOS: 0.0000
	 - ASOS (T=0.25): CSI 0.1742, AUROC 0.6862 
	   - Land:  CSI 0.1867, AUROC 0.5662
	   - Coast: CSI 0.0000, AUROC 0.5000

== asos dataset length synced to 1274


Epoch  6: 100%|██████████| 637/637 [23:23<00:00,  2.20s/it, loss=5.0439, asos=5.0439]  


Epoch  6 - Total: 4.6727, ASOS: 4.6727, AAFOS: 0.0000
	 - ASOS (T=0.25): CSI 0.2750, AUROC 0.8198 
	   - Land:  CSI 0.3404, AUROC 0.8765
	   - Coast: CSI 0.0000, AUROC 0.5500

== asos dataset length synced to 1274


Epoch  7: 100%|██████████| 637/637 [23:20<00:00,  2.20s/it, loss=0.4335, asos=0.4335]  


Epoch  7 - Total: 4.7558, ASOS: 4.7558, AAFOS: 0.0000
	 - ASOS (T=0.25): CSI 0.3055, AUROC 0.8254 
	   - Land:  CSI 0.3793, AUROC 0.8866
	   - Coast: CSI 0.0000, AUROC 0.5477

== asos dataset length synced to 1274


Epoch  8: 100%|██████████| 637/637 [23:19<00:00,  2.20s/it, loss=10.3959, asos=10.3959]


Epoch  8 - Total: 4.5642, ASOS: 4.5642, AAFOS: 0.0000
	 - ASOS (T=0.25): CSI 0.2746, AUROC 0.7600 
	   - Land:  CSI 0.3405, AUROC 0.7831
	   - Coast: CSI 0.0000, AUROC 0.5133

== asos dataset length synced to 1274


Epoch  9: 100%|██████████| 637/637 [23:22<00:00,  2.20s/it, loss=0.2829, asos=0.2829]  


Epoch  9 - Total: 4.6409, ASOS: 4.6409, AAFOS: 0.0000
	 - ASOS (T=0.25): CSI 0.2502, AUROC 0.7930 
	   - Land:  CSI 0.3160, AUROC 0.8405
	   - Coast: CSI 0.0000, AUROC 0.5331

== asos dataset length synced to 1274


Epoch 10: 100%|██████████| 637/637 [23:20<00:00,  2.20s/it, loss=0.0005, asos=0.0005]  


Epoch 10 - Total: 4.3123, ASOS: 4.3123, AAFOS: 0.0000
	 - ASOS (T=0.25): CSI 0.2919, AUROC 0.7882 
	   - Land:  CSI 0.3654, AUROC 0.8374
	   - Coast: CSI 0.0000, AUROC 0.5324

== asos dataset length synced to 1274


Epoch 11: 100%|██████████| 637/637 [23:25<00:00,  2.21s/it, loss=0.2024, asos=0.2024]  


Epoch 11 - Total: 4.1959, ASOS: 4.1959, AAFOS: 0.0000
	 - ASOS (T=0.25): CSI 0.3460, AUROC 0.8044 
	   - Land:  CSI 0.4283, AUROC 0.8497
	   - Coast: CSI 0.0000, AUROC 0.5433

== asos dataset length synced to 1274


Epoch 12: 100%|██████████| 637/637 [23:29<00:00,  2.21s/it, loss=0.1642, asos=0.1642]  


Epoch 12 - Total: 4.2040, ASOS: 4.2040, AAFOS: 0.0000
	 - ASOS (T=0.25): CSI 0.2371, AUROC 0.7859 
	   - Land:  CSI 0.2987, AUROC 0.8277
	   - Coast: CSI 0.0000, AUROC 0.5290

== asos dataset length synced to 1274


Epoch 13: 100%|██████████| 637/637 [23:36<00:00,  2.22s/it, loss=0.0005, asos=0.0005]  


Epoch 13 - Total: 4.0846, ASOS: 4.0846, AAFOS: 0.0000
	 - ASOS (T=0.25): CSI 0.1490, AUROC 0.7946 
	   - Land:  CSI 0.1901, AUROC 0.8415
	   - Coast: CSI 0.0000, AUROC 0.5312

== asos dataset length synced to 1274


Epoch 14: 100%|██████████| 637/637 [23:40<00:00,  2.23s/it, loss=5.0095, asos=5.0095]  


Epoch 14 - Total: 4.0948, ASOS: 4.0948, AAFOS: 0.0000
	 - ASOS (T=0.25): CSI 0.2341, AUROC 0.8020 
	   - Land:  CSI 0.2965, AUROC 0.8548
	   - Coast: CSI 0.0000, AUROC 0.5331

== asos dataset length synced to 1274


Epoch 15: 100%|██████████| 637/637 [23:30<00:00,  2.21s/it, loss=7.5213, asos=7.5213]  


Epoch 15 - Total: 4.0364, ASOS: 4.0364, AAFOS: 0.0000
	 - ASOS (T=0.25): CSI 0.1833, AUROC 0.7926 
	   - Land:  CSI 0.2323, AUROC 0.8341
	   - Coast: CSI 0.0000, AUROC 0.5356

== asos dataset length synced to 1274


Epoch 16: 100%|██████████| 637/637 [23:33<00:00,  2.22s/it, loss=2.0528, asos=2.0528]  


Epoch 16 - Total: 4.0561, ASOS: 4.0561, AAFOS: 0.0000
	 - ASOS (T=0.25): CSI 0.1738, AUROC 0.8105 
	   - Land:  CSI 0.2219, AUROC 0.8763
	   - Coast: CSI 0.0000, AUROC 0.5356

== asos dataset length synced to 1274


Epoch 17: 100%|██████████| 637/637 [23:24<00:00,  2.20s/it, loss=4.2481, asos=4.2481]  


Epoch 17 - Total: 4.0476, ASOS: 4.0476, AAFOS: 0.0000
	 - ASOS (T=0.25): CSI 0.2819, AUROC 0.7976 
	   - Land:  CSI 0.3561, AUROC 0.8581
	   - Coast: CSI 0.0000, AUROC 0.5335

== asos dataset length synced to 1274


Epoch 18:  51%|█████▏    | 328/637 [12:42<11:58,  2.33s/it, loss=12.8180, asos=12.8180]


KeyboardInterrupt: 