AetherNet is an image restoration research framework built around classic design patterns. It integrates several state-of-the-art models for image super-resolution, denoising, and compression-artifact removal. The project uses the factory, strategy, template method, and registry patterns to provide a unified, extensible interface for training and testing.

Research status: in development

- Design-pattern driven: factory, strategy, template method, and registry patterns keep the code clear and easy to extend
- Plug-and-play models: SwinIR, MambaIRv2, CWFNet, EDSR, ELAN, SRFormer
- Configuration-driven training: YAML configs plus command-line arguments for flexible experiment control
- Multi-task support: classical SR, real-world denoising, and JPEG artifact removal

Global registries manage models, datasets, loss functions, and other components:

```python
import torch.nn as nn

from core import MODEL_REGISTRY

@MODEL_REGISTRY.register()
class MyModel(nn.Module):
    ...

# Retrieve a registered model class
model_cls = MODEL_REGISTRY.get('MyModel')
```

A unified interface for model creation:

```python
from core import ModelFactory

# Create from a configuration dict
config = {'model_name': 'SwinIR', 'upscale': 4, 'embed_dim': 180}
model = ModelFactory.create(config)

# Create by name
model = ModelFactory.create_by_name('CWFNet', upscale=4)
```

Flexible image degradation strategies:

```python
from core import ClassicalSRDegradation, ColorDenoiseDegradation

# Classical SR degradation (bicubic downsampling)
degradation = ClassicalSRDegradation(scale=4)
lr_image = degradation.apply(hr_image)

# Denoising degradation (Gaussian noise)
degradation = ColorDenoiseDegradation(sigma=25)
noisy_image = degradation.apply(clean_image)
```

A standardized training/testing workflow:

```python
from core import BaseTrainer
from models import SwinIR

class SwinIRTrainer(BaseTrainer):
    def build_model(self):
        return SwinIR(**self.config['model'])

    def train_step(self, batch):
        lr, hr = batch
        output = self.model(lr)
        loss = self.criterion(output, hr)
        ...  # backward pass, optimizer step, etc.
        return loss.item()

# config is a parsed configuration dict (see the YAML example below)
trainer = SwinIRTrainer(config)
trainer.train()
```

Requirements:

- Python 3.8+
- PyTorch 1.12+
- CUDA 11.0+ (recommended)
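
To quickly verify that the local environment meets these requirements, a minimal check such as the following can be run (illustrative only; it simply reports the installed PyTorch version and CUDA availability):

```python
# Quick environment sanity check (illustrative, not part of the repository).
import torch

print(f"PyTorch version: {torch.__version__}")   # expect 1.12 or newer
print(f"CUDA available:  {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device:    {torch.cuda.get_device_name(0)}")
```
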
```bash
git clone https://github.com/Jackksonns/AetherNet.git
cd AetherNet

# Create a virtual environment (recommended)
conda create -n aethernet python=3.9
conda activate aethernet

# Install PyTorch (choose the wheel matching your CUDA version)
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118

# Install the remaining dependencies
pip install -r requirements.txt

# MambaIRv2 support (optional)
pip install mamba-ssm
```

```bash
# Download the DIV2K dataset
cd datasets
bash organize_div2k.sh
```

```
# Directory structure
datasets/
├── DIV2K/
│   ├── train/
│   │   ├── HR/            # High-resolution images
│   │   └── LR_bicubic/    # Bicubic-downsampled images
│   │       ├── X2/
│   │       ├── X3/
│   │       └── X4/
│   └── val/
```

To use your own data, refer to datasets/dataset.py and implement a custom dataset class (a sketch is given below).
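
As a rough illustration only (the DATASET_REGISTRY import, the class name PairedImageDataset, and the constructor arguments are assumptions; the registration pattern mirrors the MODEL_REGISTRY example above, and the real interface lives in datasets/dataset.py), a paired HR/LR dataset might look like this:

```python
# Illustrative sketch only: DATASET_REGISTRY and the constructor arguments are
# assumptions; see datasets/dataset.py for the framework's real interface.
import os

from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms.functional import to_tensor

from core import DATASET_REGISTRY  # assumed analogue of MODEL_REGISTRY


@DATASET_REGISTRY.register()
class PairedImageDataset(Dataset):
    """Yields (LR, HR) tensor pairs from two directories of matching files."""

    def __init__(self, hr_dir, lr_dir):
        self.hr_paths = sorted(os.path.join(hr_dir, f) for f in os.listdir(hr_dir))
        self.lr_paths = sorted(os.path.join(lr_dir, f) for f in os.listdir(lr_dir))

    def __len__(self):
        return len(self.hr_paths)

    def __getitem__(self, idx):
        lr = to_tensor(Image.open(self.lr_paths[idx]).convert("RGB"))
        hr = to_tensor(Image.open(self.hr_paths[idx]).convert("RGB"))
        return lr, hr
```
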
```bash
# Train with a configuration file
python scripts/train/train.py --config configs/train_MambaIRv2_SR_x4.yml

# Train SwinIR
python scripts/train/train_swinir.py --scale 4 --batch_size 8 --epochs 100

# Train CWFNet
python scripts/train/train_cwfnet.py --scale 4 --use_wavelet --use_cvoca
```

```bash
# Test SwinIR
python scripts/test/test_swinir.py --checkpoint weights/swinir_x4.pth

# Test CWFNet
python scripts/test/test_cwfnet.py --checkpoint weights/cwfnet_x4.pth
```

```python
from models import SwinIR
from core import ModelFactory
import torch

# Option 1: direct import
model = SwinIR(upscale=4, in_chans=3, img_size=64,
               window_size=8, img_range=1., depths=[6]*6,
               embed_dim=180, num_heads=[6]*6, mlp_ratio=2,
               upsampler='pixelshuffle', resi_connection='1conv')

# Option 2: create via the factory
model = ModelFactory.create_by_name('SwinIR', upscale=4)

# Load weights
model.load_state_dict(torch.load('weights/model.pth'))
model.eval()

# Inference (lr_image: a low-resolution input tensor, e.g. shape [1, 3, H, W])
with torch.no_grad():
    sr_image = model(lr_image)
```

| Model | Set5 (PSNR/SSIM) | Set14 (PSNR/SSIM) | BSD100 (PSNR/SSIM) | Urban100 (PSNR/SSIM) | Params |
|---|---|---|---|---|---|
| EDSR | 32.46/0.8968 | 28.80/0.7876 | 27.71/0.7420 | 26.64/0.8033 | 43M |
| SwinIR | 32.72/0.9021 | 28.94/0.7914 | 27.83/0.7459 | 27.07/0.8164 | 12M |
| MambaIRv2 | 32.85/0.9031 | 29.05/0.7931 | 27.92/0.7481 | 27.35/0.8215 | 10M |
| CWFNet | TBD | TBD | TBD | TBD | TBD |
Experimental results are being updated continuously.
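
For reference, the PSNR reported above is derived from the mean squared error between the restored and ground-truth images. A minimal helper is sketched below (illustrative only; the repository's evaluation scripts may additionally crop borders or measure on the luminance channel, as is common in SR benchmarks):

```python
# Minimal PSNR helper (illustrative; not the repository's evaluation code).
import torch

def psnr(sr: torch.Tensor, hr: torch.Tensor, max_val: float = 1.0) -> float:
    """PSNR in dB for images scaled to [0, max_val]."""
    mse = torch.mean((sr - hr) ** 2)
    return float(10 * torch.log10(max_val ** 2 / mse))
```
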
```yaml
# configs/train_swinir_x4.yml
model:
  name: SwinIR
  upscale: 4
  in_chans: 3
  img_size: 64
  window_size: 8
  embed_dim: 180
  depths: [6, 6, 6, 6, 6, 6]
  num_heads: [6, 6, 6, 6, 6, 6]
  mlp_ratio: 2
  upsampler: pixelshuffle

training:
  batch_size: 8
  epochs: 100
  lr: 1e-4
  scheduler:
    type: cosine

data:
  train_dir: datasets/DIV2K/train
  val_dir: datasets/DIV2K/val
  patch_size: 64
```
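
For reference, a minimal sketch of how such a config file might be consumed in a script is shown below (the actual config handling lives in scripts/train/train.py; the use of yaml.safe_load and the mapping of the model section onto ModelFactory.create_by_name are assumptions based on the factory example above):

```python
# Illustrative only: the real entry point is scripts/train/train.py.
import yaml
from core import ModelFactory

with open("configs/train_swinir_x4.yml") as f:
    config = yaml.safe_load(f)

# Build the model from the 'model' section of the YAML file (assumed mapping).
model_cfg = dict(config["model"])
model = ModelFactory.create_by_name(model_cfg.pop("name"), **model_cfg)
```
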
If this project is helpful to your research, please cite the related works (citation list to be updated).

AetherNet - a design-pattern-based image restoration research framework
Developed by Jackksonns
