In [1]:
from dataset import create_dataset
from model.UNet import UNet
from utils.engine import RFDiffusionTrainer
from utils.tools import train_one_epoch, load_yaml, train_parse_option
import torch
from utils.callbacks import ModelCheckpoint, set_seed, EarlyStopping
from utils.RectifiedFlow import RectifiedFlow
import sys
import os

In [2]:
# Fix the random seed once, up front, so the training run below is reproducible.
SEED = 42
set_seed(SEED)

In [3]:
# Simulate command-line input: inside a notebook sys.argv holds the kernel's own
# arguments, so replace it with the flags train_parse_option() should see
# (presumably it parses sys.argv — confirm in utils.tools).
sys.argv = "train.py --trainer rf --model unet --scheduler ReduceLR".split()


In [4]:
def train(config, args):
    """Train the model selected by ``args``, optionally resuming from a checkpoint.

    Args:
        config: Parsed YAML config. Keys read here: ``consume``, ``rf_consume_path``,
            ``device``, ``Dataset``, ``Model``, ``lr``, ``RF_Callback`` and ``epochs``.
            When ``config["consume"]`` is truthy, the checkpoint's own ``"config"``
            entry replaces this dict, so e.g. ``epochs`` then comes from the checkpoint.
        args: Parsed command-line options with ``trainer``, ``model``,
            ``scheduler`` and ``earlystopping`` attributes.

    Raises:
        ValueError: If ``args.trainer``, ``args.model`` or ``args.scheduler`` names
            an option this function does not implement. This is checked before any
            checkpoint, dataset or model is loaded.
    """
    # Fail fast on unsupported options. Previously these surfaced later as a
    # NameError (unbound cp/model/scheduler) or, for a non-'rf' trainer, as a
    # silent no-op after building the dataset and model.
    supported = {
        "trainer": ("rf",),
        "model": ("unet",),
        "scheduler": ("StepLR", "ReduceLR"),
    }
    for option, choices in supported.items():
        value = getattr(args, option)
        if value not in choices:
            raise ValueError(f"Unsupported --{option} {value!r}; expected one of {choices}")

    consume = config["consume"]
    if consume:
        # map_location="cpu" lets a checkpoint saved on a GPU be resumed on any machine;
        # the load_state_dict calls below copy/cast the tensors onto the live device.
        # NOTE(review): torch.load unpickles the file, so only resume from trusted checkpoints.
        cp = torch.load(config["rf_consume_path"], map_location="cpu")
        config = cp["config"]
    print(config)

    device = torch.device(config["device"])
    loader = create_dataset(**config["Dataset"])
    start_epoch = 1

    model = UNet(**config["Model"]).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"], weight_decay=1e-4)
    trainer = RFDiffusionTrainer(model).to(device)
    model_checkpoint = ModelCheckpoint(**config["RF_Callback"])

    # Learning-rate scheduler; the option was validated above.
    if args.scheduler == 'StepLR':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    else:  # 'ReduceLR'
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.1)

    # Initialize early stopping.
    # NOTE(review): monitor='val_loss' is only a label here. The value passed to
    # early_stopping.step() below is the training loss from train_one_epoch.
    early_stopping = EarlyStopping(monitor='val_loss', mode='min', patience=10)

    if consume:
        model.load_state_dict(cp["model"])
        optimizer.load_state_dict(cp["optimizer"])
        model_checkpoint.load_state_dict(cp["model_checkpoint"])
        start_epoch = cp["start_epoch"] + 1
        # NOTE(review): scheduler and early-stopping state are not saved in the
        # checkpoint, so both restart from scratch on resume.

    for epoch in range(start_epoch, config["epochs"] + 1):
        loss = train_one_epoch(trainer, loader, optimizer, device, epoch, config["epochs"])
        # ReduceLROnPlateau needs the monitored metric. Epoch-based schedulers such as
        # StepLR must be stepped without arguments: their positional argument is the
        # deprecated `epoch`, so passing the loss would set the epoch counter to the
        # loss value and the LR would never decay.
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(loss)
        else:
            scheduler.step()
        # Save a checkpoint (ModelCheckpoint decides whether/what to write).
        model_checkpoint.step(loss, model=model.state_dict(), config=config,
                              optimizer=optimizer.state_dict(), start_epoch=epoch,
                              model_checkpoint=model_checkpoint.state_dict())
        # Check for early stopping only when it was requested on the command line.
        if args.earlystopping and early_stopping.step(loss):
            print(f"Early stopping at epoch {epoch}")
            break
        
    



In [5]:
CONFIG_PATH = "config.yml"

# Parse the (simulated) CLI flags that pick trainer/model/scheduler,
# then load the YAML config and launch training.
args = train_parse_option()
config = load_yaml(CONFIG_PATH, encoding="utf-8")
train(config=config, args=args)

{'Model': {'in_channels': 3, 'out_channels': 3, 'model_channels': 128, 'attention_resolutions': [2], 'num_res_blocks': 2, 'dropout': 0.1, 'channel_mult': [1, 2, 2, 2], 'conv_resample': True, 'num_heads': 8, 'num_classes': 10, 'image_w': 32, 'image_h': 32}, 'Classifier_Model': {'num_classes': 10}, 'Dataset': {'dataset': 'cwru', 'train': True, 'data_path': './data/cwru/BR1_200_train_set_balance', 'download': False, 'image_size': [32, 32], 'mode': 'RGB', 'suffix': ['png', 'jpg'], 'batch_size': 64, 'shuffle': True, 'drop_last': True, 'pin_memory': True, 'num_workers': 4}, 'Classifier_Dataset_train': {'dataset': 'classifier_train', 'train_data_path': './data/cwru_rf_result/cwru_sampler_br1_5_500epoch', 'image_size': [32, 32], 'mode': 'RGB', 'suffix': ['png', 'jpg'], 'batch_size': 64, 'drop_last': True, 'pin_memory': True, 'num_workers': 4}, 'Classifier_Dataset_test': {'dataset': 'classifier_test', 'test_data_path': './data/cwru/test_set', 'image_size': [32, 32], 'mode': 'RGB', 'suffix': ['p

Epoch: 1/10: 100%|[38;2;255;146;74m██████████[0m| 1/1 [00:22<00:00, 22.74s/it, train_loss=0.0324]
Epoch: 2/10: 100%|[38;2;255;146;74m██████████[0m| 1/1 [00:21<00:00, 21.07s/it, train_loss=0.0243]
Epoch: 3/10: 100%|[38;2;255;146;74m██████████[0m| 1/1 [00:20<00:00, 20.86s/it, train_loss=0.0204]
Epoch: 4/10: 100%|[38;2;255;146;74m██████████[0m| 1/1 [00:20<00:00, 20.94s/it, train_loss=0.0185]
Epoch: 5/10: 100%|[38;2;255;146;74m██████████[0m| 1/1 [00:20<00:00, 20.80s/it, train_loss=0.0172]
Epoch: 6/10: 100%|[38;2;255;146;74m██████████[0m| 1/1 [00:20<00:00, 20.82s/it, train_loss=0.0163]
Epoch: 7/10: 100%|[38;2;255;146;74m██████████[0m| 1/1 [00:20<00:00, 20.83s/it, train_loss=0.0157]
Epoch: 8/10: 100%|[38;2;255;146;74m██████████[0m| 1/1 [00:20<00:00, 20.83s/it, train_loss=0.015]
Epoch: 9/10: 100%|[38;2;255;146;74m██████████[0m| 1/1 [00:20<00:00, 20.86s/it, train_loss=0.0144]
Epoch: 10/10: 100%|[38;2;255;146;74m██████████[0m| 1/1 [00:20<00:00, 20.99s/it, train_loss=0.0138]
