In [1]:
import os
import yaml
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.amp import GradScaler, autocast
from tqdm import tqdm
from metrics.dataloader import RainDataset, get_transform

# Import modules
from models.archs.DPENet_v1 import DPENet
from models.archs.losses import SSIMLoss_v2, EdgeLoss_v2, L1Loss
from models.CosineAnnealingRestartCyclicLR import CosineAnnealingRestartCyclicLR

In [2]:
def load_config(file_path):
    """Load the training configuration from a YAML file and normalize its types.

    Only the keys listed below are touched; every other key in the file
    (for example ``device`` or ``checkpoint_dir``) is returned unchanged.

    Args:
        file_path: Path to the YAML configuration file.

    Returns:
        dict: The parsed configuration with
            ``epochs``, ``batch_size``, ``num_workers``, ``log_interval`` as int,
            ``lr``, ``eta_min``, ``grad_clip`` as float,
            ``periods`` as a list of int, ``restart_weights`` as a list of float,
            ``use_amp`` as bool.

    Raises:
        KeyError: If a required key (anything above except ``grad_clip``) is missing.
        ValueError / TypeError: If a value cannot be converted to its target type.
    """
    # YAML is UTF-8 by specification; do not depend on the platform default encoding.
    with open(file_path, 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)

    # Ensure specific types. The explicit casts matter because a YAML loader may
    # hand back strings for values such as `1e-4` (exponent without a dot).
    for key in ("epochs", "batch_size", "num_workers", "log_interval"):
        config[key] = int(config[key])
    for key in ("lr", "eta_min"):
        config[key] = float(config[key])
    config["periods"] = [int(period) for period in config["periods"]]
    config["restart_weights"] = [float(weight) for weight in config["restart_weights"]]

    # bool("false") is True, so interpret string values by content instead of
    # by truthiness; real booleans/numbers keep the plain bool() semantics.
    use_amp = config["use_amp"]
    if isinstance(use_amp, str):
        use_amp = use_amp.strip().lower() in ("1", "true", "yes", "y", "on")
    config["use_amp"] = bool(use_amp)

    # Gradient clipping threshold; optional, defaults to 1.0 when not configured.
    config["grad_clip"] = float(config.get("grad_clip", 1.0))

    return config

In [5]:
if __name__ == "__main__":
    # Training entry point: builds the model, losses, optimizer, LR scheduler
    # and data loader from config.yml, trains with optional mixed precision,
    # and writes checkpoints into config["checkpoint_dir"].

    # Load configuration from YAML file
    config = load_config('config.yml')

    # Create the directory that will hold the saved checkpoints
    os.makedirs(config["checkpoint_dir"], exist_ok=True)

    # Build the model
    model = DPENet()
    model.to(config["device"])

    # Build the loss functions
    ssim_loss = SSIMLoss_v2().to(config["device"])
    edge_loss = EdgeLoss_v2().to(config["device"])
    # NOTE(review): l1_loss is constructed but is not part of the objective
    # computed in the loop below; kept in case other cells rely on it.
    l1_loss = L1Loss(loss_weight=1.0, reduction='mean').to(config["device"])

    # Build the optimizer & learning-rate scheduler
    optimizer = optim.AdamW(model.parameters(), lr=config["lr"],
                            betas=(0.9, 0.999), weight_decay=1e-4)
    # NOTE(review): eta_mins[0] is the base LR itself; presumably this keeps the
    # LR flat for the whole first period - confirm against the scheduler
    # implementation and the configured `periods`.
    scheduler = CosineAnnealingRestartCyclicLR(optimizer, periods=config["periods"],
                                            restart_weights=config["restart_weights"],
                                            eta_mins=[config["lr"], config["eta_min"]])

    # Set up the DataLoader. The worker count comes from the config file
    # (it was previously hard-coded, so the configured value had no effect).
    train_dataset = RainDataset(mode='train', dataset_name='Rain13K',
                                transform=get_transform(train=True))
    train_loader = DataLoader(train_dataset, batch_size=config["batch_size"],
                              shuffle=True, num_workers=config["num_workers"])

    # Set up AMP (mixed precision); the scaler is a pass-through when disabled
    scaler = GradScaler('cuda', enabled=config["use_amp"])

    # Training loop
    print(f"ÈñãÂßãË®ìÁ∑¥ DPENet_CFIMÔºå‰ΩøÁî®Ë®≠ÂÇôÔºö{config['device']}")

    for epoch in range(config["epochs"]):
        model.train()
        running_loss = 0.0
        pbar = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{config['epochs']}]")

        for batch_idx, (input_img, target_img) in enumerate(pbar):
            input_img, target_img = input_img.to(config["device"]), target_img.to(config["device"])

            optimizer.zero_grad()

            # Forward pass (under AMP autocast when enabled)
            with autocast(device_type='cuda', enabled=config["use_amp"]):
                output = model(input_img)
                ssim_val = ssim_loss(output, target_img)
                edge_val = edge_loss(output, target_img)
                # Report non-finite loss terms; training continues regardless.
                if torch.isnan(ssim_val) or torch.isinf(ssim_val):
                    print("‚ö† SSIM Loss Âá∫ÈåØ", ssim_val)

                if torch.isnan(edge_val) or torch.isinf(edge_val):
                    print("‚ö† Edge Loss Âá∫ÈåØ", edge_val)

                loss = ssim_val + 0.05 * edge_val

            # Backward pass on the scaled loss
            scaler.scale(loss).backward()

            # Gradient clipping, to guard against exploding/oscillating gradients.
            # After a scaled backward the gradients are still multiplied by the
            # loss scale, so they must be unscaled first or the threshold is
            # compared against inflated norms. scaler.step() detects the prior
            # unscale_ and does not unscale a second time.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), config["grad_clip"])

            scaler.step(optimizer)
            scaler.update()
            scheduler.step()  # Update the learning rate (stepped once per batch)

            # Track the loss; .item() syncs with the device, so read it once
            batch_loss = loss.item()
            running_loss += batch_loss
            if batch_idx % config["log_interval"] == 0:
                pbar.set_postfix(loss=batch_loss)

        avg_loss = running_loss / len(train_loader)
        # Read the LR the optimizer is actually using instead of re-deriving it
        # from the scheduler outside of step().
        current_lr = optimizer.param_groups[0]["lr"]
        print(f"üîπ Epoch [{epoch+1}/{config['epochs']}], Loss: {avg_loss:.6f}, LR: {current_lr:.8f}")

        # Save the model every 10 epochs, and always after the final epoch so
        # the trained weights are not lost when `epochs` is not a multiple of 10.
        is_last_epoch = (epoch + 1) == config["epochs"]
        if (epoch + 1) % 10 == 0 or is_last_epoch:
            save_path = os.path.join(config["checkpoint_dir"], f"DPENet_CFIM_epoch{epoch+1}.pth")
            torch.save(model.state_dict(), save_path)
            print(f"Ê®°ÂûãÂ∑≤‰øùÂ≠òÔºö{save_path}")

ÈñãÂßãË®ìÁ∑¥ DPENet_CFIMÔºå‰ΩøÁî®Ë®≠ÂÇôÔºöcuda


Epoch [1/10]:   0%|          | 0/857 [00:00<?, ?it/s]

Epoch [1/10]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 857/857 [05:28<00:00,  2.61it/s, loss=0.0283]


üîπ Epoch [1/10], Loss: 0.028200, LR: 0.00100000


Epoch [2/10]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 857/857 [05:28<00:00,  2.61it/s, loss=0.0264]


üîπ Epoch [2/10], Loss: 0.028542, LR: 0.00100000


Epoch [3/10]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 857/857 [05:28<00:00,  2.61it/s, loss=0.0291]


üîπ Epoch [3/10], Loss: 0.028502, LR: 0.00100000


Epoch [4/10]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 857/857 [05:28<00:00,  2.61it/s, loss=0.0317]


üîπ Epoch [4/10], Loss: 0.028511, LR: 0.00100000


Epoch [5/10]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 857/857 [05:28<00:00,  2.61it/s, loss=0.0296]


üîπ Epoch [5/10], Loss: 0.028319, LR: 0.00100000


Epoch [6/10]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 857/857 [05:28<00:00,  2.61it/s, loss=0.0266]


üîπ Epoch [6/10], Loss: 0.028358, LR: 0.00100000


Epoch [7/10]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 857/857 [05:28<00:00,  2.61it/s, loss=0.0268]


üîπ Epoch [7/10], Loss: 0.028388, LR: 0.00100000


Epoch [8/10]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 857/857 [05:28<00:00,  2.61it/s, loss=0.0202]


üîπ Epoch [8/10], Loss: 0.029198, LR: 0.00100000


Epoch [9/10]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 857/857 [05:28<00:00,  2.61it/s, loss=0.0279]


üîπ Epoch [9/10], Loss: 0.028328, LR: 0.00100000


Epoch [10/10]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 857/857 [05:28<00:00,  2.61it/s, loss=0.0296]

üîπ Epoch [10/10], Loss: 0.028519, LR: 0.00100000
Ê®°ÂûãÂ∑≤‰øùÂ≠òÔºöcheckpoints/DPENet_CFIM_epoch10.pth



