In [1]:
import os
import yaml
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.amp import GradScaler, autocast
from tqdm import tqdm
from metrics.dataloader_trace import RainDataset, get_transform, compute_mean_std, save_trace_dict_to_csv

# Import modules
from models.archs.DPENet_v1 import DPENet
from models.archs.DPENet_v3 import DPENet_v3
from models.archs.Network_test import DPENet_Traceable
from models.archs.losses import SSIMLoss_v2, EdgeLoss_v2, L1Loss
from models.CosineAnnealingRestartCyclicLR import CosineAnnealingRestartCyclicLR

import numpy as np
import matplotlib.pyplot as plt
from torchvision.transforms.functional import to_pil_image

def _parse_bool(value, key):
    """Convert a raw config value to ``bool``, accepting common boolean strings.

    Plain ``bool("false")`` is ``True`` in Python, so a quoted ``"false"`` in the
    YAML file would otherwise silently enable the feature.

    Args:
        value: Raw value read from the config (bool, int, str, None, ...).
        key: Name of the config key; used only in the error message.

    Returns:
        bool: The parsed flag.

    Raises:
        ValueError: If ``value`` is a string that is not a recognised boolean.
    """
    if isinstance(value, str):
        normalized = value.strip().lower()
        if normalized in ("true", "yes", "on", "1"):
            return True
        if normalized in ("false", "no", "off", "0", ""):
            return False
        raise ValueError(f"Config key {key!r} must be a boolean, got {value!r}")
    return bool(value)


def load_config(file_path):
    """Load the training configuration from a YAML file and coerce value types.

    Args:
        file_path: Path to the YAML configuration file.

    Returns:
        dict: The parsed configuration with numeric and boolean fields cast to
        their expected Python types. ``grad_clip`` defaults to 1.0 and
        ``save_trace`` defaults to False when the key is absent.

    Raises:
        KeyError: If a required key (e.g. ``epochs``, ``lr``, ``use_amp``) is missing.
        TypeError: If the file is empty or its top level is not a mapping.
        ValueError: If a value cannot be converted (e.g. an unrecognised
            boolean string).
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)

    # An empty file loads as None; anything that is not a mapping cannot be
    # indexed by key below, so fail with a clear message instead.
    if not isinstance(config, dict):
        raise TypeError(
            f"Config file {file_path!r} must contain a YAML mapping, "
            f"got {type(config).__name__}"
        )

    # Ensure specific types (PyYAML, for instance, loads `1e-3` without a dot
    # as a string, so numeric fields are cast explicitly).
    config["epochs"] = int(config["epochs"])
    config["batch_size"] = int(config["batch_size"])
    config["lr"] = float(config["lr"])
    config["eta_min"] = float(config["eta_min"])
    config["periods"] = [int(period) for period in config["periods"]]
    config["restart_weights"] = [float(weight) for weight in config["restart_weights"]]
    config["num_workers"] = int(config["num_workers"])
    config["log_interval"] = int(config["log_interval"])
    config["use_amp"] = _parse_bool(config["use_amp"], "use_amp")
    # Gradient-clipping threshold; defaults to 1.0 when not specified.
    config["grad_clip"] = float(config.get("grad_clip", 1.0))
    # Trace dumping is optional; a missing key means "do not save traces".
    config["save_trace"] = _parse_bool(config.get("save_trace", False), "save_trace")

    return config



In [2]:
if __name__ == "__main__":
    # Load configuration from YAML file
    config = load_config('config.yml')
    device = config["device"]

    # Create the directory that holds model checkpoints.
    os.makedirs(config["checkpoint_dir"], exist_ok=True)

    # Build the model and move it to the target device.
    model = DPENet_Traceable()
    model.to(device)

    # Build the loss functions.
    ssim_loss = SSIMLoss_v2().to(device)
    edge_loss = EdgeLoss_v2().to(device)
    # NOTE(review): l1_loss is constructed but never used, and the total loss
    # below adds a bare constant 0.2 (which does not affect gradients). It looks
    # like `0.2 * l1_loss(mid_output, target_img)` may have been intended —
    # confirm before changing the loss definition.
    l1_loss = L1Loss(loss_weight=1.0, reduction='mean').to(device)

    # Build the optimizer and the learning-rate scheduler.
    # NOTE(review): eta_min / periods / restart_weights from the config are not
    # used here; the MultiStepLR milestones are hard-coded.
    optimizer = optim.Adam(model.parameters(), lr=config["lr"])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 50, 80], gamma=0.2)

    train_dataset = RainDataset(mode='train', dataset_name=config["dataset_name"], transform=get_transform(train=True))
    # Worker count comes from the config (it used to be hard-coded to 4,
    # silently ignoring config["num_workers"]).
    train_loader = DataLoader(train_dataset, batch_size=config["batch_size"],
                              shuffle=True, drop_last=True,
                              num_workers=config["num_workers"], pin_memory=True)

    # AMP (mixed precision) is currently disabled; config["use_amp"] is read but unused.
    # TODO: re-enable via GradScaler('cuda', enabled=config["use_amp"]) + autocast.
    save_trace = config.get("save_trace", False)

    # Training loop
    print(f"ÈñãÂßãË®ìÁ∑¥ DPENet_CFIMÔºå‰ΩøÁî®Ë®≠ÂÇôÔºö{config['device']}")

    for epoch in range(config["epochs"]):
        model.train()
        running_loss = 0.0
        running_ssim = 0.0
        num_steps = 0  # batches that contributed an optimizer step
        pbar = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{config['epochs']}]")

        for batch_idx, (input_img, target_img) in enumerate(pbar):
            input_img, target_img = input_img.to(device), target_img.to(device)

            optimizer.zero_grad()

            # trace=True makes the model also return a dict of intermediate features.
            mid_output, trace_features = model(input_img, trace=True)

            # Optionally dump intermediate-layer traces (one directory per batch).
            # This runs right after the forward pass, before optimizer.step(), so
            # any captured *_weight entries match the activations they produced
            # in case the trace dict holds references to live parameters
            # (TODO: confirm whether DPENet_Traceable clones them).
            if save_trace:
                print("üß™ ÂïüÂãï trace Ë≥áÊñôÂÑ≤Â≠òÂäüËÉΩÔºÅ")
                trace_dir = os.path.join(config["trace_output_dir"], f"epoch{epoch+1}_batch{batch_idx}")
                os.makedirs(trace_dir, exist_ok=True)
                save_trace_dict_to_csv(trace_features, prefix_dir=trace_dir)

            ssim_val = ssim_loss(mid_output, target_img)
            edge_val = edge_loss(mid_output, target_img)

            if torch.isnan(ssim_val) or torch.isinf(ssim_val):
                print("‚ö† SSIM Loss Âá∫ÈåØ", ssim_val)

            if torch.isnan(edge_val) or torch.isinf(edge_val):
                print("‚ö† Edge Loss Âá∫ÈåØ", edge_val)

            loss = 1 - ssim_val + 0.05 * edge_val + 0.2

            # A non-finite loss would push NaN/Inf into every parameter through
            # backward() + optimizer.step(); skip the update for this batch.
            if not torch.isfinite(loss):
                print(f"Skipping batch {batch_idx}: non-finite loss {loss.item()}")
                continue

            # Backpropagation
            loss.backward()

            # Gradient clipping to prevent exploding / oscillating gradients.
            torch.nn.utils.clip_grad_norm_(model.parameters(), config["grad_clip"])
            optimizer.step()

            # Record loss statistics.
            running_loss += loss.item()
            running_ssim += ssim_val.item()
            num_steps += 1
            if batch_idx % config["log_interval"] == 0:
                pbar.set_postfix(loss=loss.item())

        # Average over the batches actually used; with drop_last=True and no
        # skipped batches this equals len(train_loader). max() avoids a
        # ZeroDivisionError when the loader yields no batches.
        denom = max(num_steps, 1)
        avg_loss = running_loss / denom
        avg_ssim = running_ssim / denom

        print(f"üîπ Epoch [{epoch+1}/{config['epochs']}], Loss: {avg_loss:.6f}, SSIM: {avg_ssim:.6f}")
        scheduler.step()  # update the learning rate



開始訓練 DPENet_CFIM，使用設備：cuda


Epoch [1/1]: 100%|██████████| 1/1 [00:00<00:00,  5.12it/s, loss=0.679]

🧪 啟動 trace 資料儲存功能！
📁 正在儲存: input, shape: (1, 3, 16, 16), ndim: 4
📁 正在儲存: experts_layer0_output, shape: (1, 3, 16, 16), ndim: 4
📁 正在儲存: experts_layer1_output, shape: (1, 3, 16, 16), ndim: 4
📁 正在儲存: experts_layer2_output, shape: (1, 3, 16, 16), ndim: 4
📁 正在儲存: experts_layer0_weight, shape: (3, 3, 3, 3), ndim: 4
📁 正在儲存: experts_layer1_weight, shape: (3, 3, 3, 3), ndim: 4
📁 正在儲存: experts_layer2_weight, shape: (3, 3, 3, 3), ndim: 4
📁 正在儲存: conv1_layer0_output, shape: (1, 3, 16, 16), ndim: 4
📁 正在儲存: conv1_layer0_weight, shape: (3, 9, 1, 1), ndim: 4
📁 正在儲存: conv1_layer0_bias, shape: (3,), ndim: 1
📁 正在儲存: pa_layer0_output, shape: (1, 1, 16, 16), ndim: 4
📁 正在儲存: pa_layer0_weight, shape: (1, 3, 3, 3), ndim: 4
📁 正在儲存: pa_layer1_output, shape: (1, 1, 16, 16), ndim: 4
📁 正在儲存: pa_layer2_output, shape: (1, 3, 16, 16), ndim: 4
ü


