In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import matplotlib.pyplot as plt
import tqdm
from torchvision.utils import save_image
import itertools
from torchvision import models

In [2]:
class ImagePairDataset(Dataset):
    def __init__(self, folder1, folder2, pairs, transform=None, enhance_transform_1 = None, 
                 enhance_transform_2 = None, enhance_transform_3 = None):
        self.folder1 = folder1
        self.folder2 = folder2
        self.pairs = pairs
        self.transform = transform
        self.enhance_transform_1 = enhance_transform_1
        self.enhance_transform_2 = enhance_transform_2
        self.enhance_transform_3 = enhance_transform_3
        self.image_pairs = self.read_image_pairs()

    def read_image_pairs(self):
      image_pairs = []
      for image_pair in tqdm.tqdm(self.pairs):
        img1_path = os.path.join(self.folder1, image_pair[0])
        img2_path = os.path.join(self.folder2, image_pair[1])
        img1 = Image.open(img1_path).convert("RGB")
        img2 = Image.open(img2_path).convert("RGB")
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        image_pairs.append((img1, img2))
      return image_pairs
        
    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
      return self.image_pairs[idx][0], self.image_pairs[idx][1]

In [3]:
transform = transforms.Compose([
    transforms.Resize((128, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # 归一化
])

In [4]:
digital_dir = "D:/data/禄来红外400/数码"
film_dir = "D:/data/禄来红外400/样片"

In [5]:
def make_pairs(image1_dir, imgage2_dir):
    # Get the list of files in both folders
    images1 = sorted(os.listdir(image1_dir))
    images2 = sorted(os.listdir(imgage2_dir))
    
    print(len(images1))
    print(len(images2))

    # Ensure the number of files match
    if len(images1) != len(images2):
        raise ValueError("The two folders must have the same number of images.")

    # Create pairs of images (file1, file2)
    pairs = list(zip(images1, images2))
    
    return pairs

In [6]:
pairs = make_pairs(digital_dir, film_dir)

242
242


In [7]:
batch_size = 32

In [8]:
# Create datasets
train_dataset = ImagePairDataset(digital_dir, film_dir, pairs, transform=transform)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle=True)

100%|██████████| 242/242 [00:23<00:00, 10.28it/s]


In [9]:
# ============================
# 1. 定义生成器 (ResNet Generator)
# ============================
class ResNetGenerator(nn.Module):
    def __init__(self, input_nc=3, output_nc=3, num_residual_blocks=9):
        super(ResNetGenerator, self).__init__()
        
        # 使用预训练的 ResNet-18
        resnet = models.resnet18()
        resnet.load_state_dict(torch.load("./resnet18-5c106cde.pth", weights_only=False))
        
        # 编码器：获取 ResNet 的各层输出，构造 U-Net 风格的跳跃连接
        self.input_conv = nn.Sequential(
            resnet.conv1,   # 输出: 64通道, H/2 x W/2
            resnet.bn1,
            resnet.relu
        )
        self.maxpool = resnet.maxpool   # H/4 x W/4
        self.encoder1 = resnet.layer1   # 输出: 64通道, H/4 x W/4
        self.encoder2 = resnet.layer2   # 输出: 128通道, H/8 x W/8
        self.encoder3 = resnet.layer3   # 输出: 256通道, H/16 x W/16
        self.encoder4 = resnet.layer4   # 输出: 512通道, H/32 x W/32

        # 解码器：使用上采样（双线性插值）+ 卷积来恢复细节，并与跳跃连接特征融合
        self.up1 = self._up_block(512, 256)  # H/32 -> H/16
        self.up2 = self._up_block(512, 128)  # 256(from up1)+256(encoder3) -> H/16 -> H/8
        self.up3 = self._up_block(256, 64)   # 128(from up2)+128(encoder2) -> H/8 -> H/4
        self.up4 = self._up_block(128, 64)   # 64(from up3)+64(encoder1) -> H/4 -> H/2
        self.up5 = self._up_block(128, 64)   # 64(from up4)+64(from input_conv) -> H/2 -> H

        self.final_conv = nn.Conv2d(64, output_nc, kernel_size=1)

    def _up_block(self, in_channels, out_channels):
        """
        上采样块：先上采样（双线性），再卷积+ReLU
        """
        block = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        return block

    def forward(self, x):
        h = x.shape[2]
        w = x.shape[3]

        # 编码器部分
        x0 = self.input_conv(x)       # x0: [B, 64, H/2, W/2]
        x1 = self.maxpool(x0)         # x1: [B, 64, H/4, W/4]
        x1 = self.encoder1(x1)        # x1: [B, 64, H/4, W/4]
        x2 = self.encoder2(x1)        # x2: [B, 128, H/8, W/8]
        x3 = self.encoder3(x2)        # x3: [B, 256, H/16, W/16]
        x4 = self.encoder4(x3)        # x4: [B, 512, H/32, W/32]

        # 解码器部分：逐步上采样并与编码器对应层特征拼接
        d1 = self.up1(x4)             # d1: [B, 256, H/16, W/16]
        # 拼接 x3 (256通道)
        d1 = torch.cat([d1, x3], dim=1)  # [B, 256+256=512, H/16, W/16]

        d2 = self.up2(d1)             # d2: [B, 128, H/8, W/8]
        # 拼接 x2 (128通道)
        d2 = torch.cat([d2, x2], dim=1)  # [B, 128+128=256, H/8, W/8]

        d3 = self.up3(d2)             # d3: [B, 64, H/4, W/4]
        # 拼接 x1 (64通道)
        d3 = torch.cat([d3, x1], dim=1)  # [B, 64+64=128, H/4, W/4]

        d4 = self.up4(d3)             # d4: [B, 64, H/2, W/2]
        # 拼接 x0 (64通道)
        d4 = torch.cat([d4, x0], dim=1)  # [B, 64+64=128, H/2, W/2]

        d5 = self.up5(d4)             # d5: [B, 64, H, W]

        out = self.final_conv(d5)     # out: [B, output_nc, H, W]
        out = torch.tanh(out)

        out = torch.nn.functional.interpolate(out, size = (h, w), mode="bilinear")

        return out

In [10]:
# ============================
# 3. 定义判别器 (PatchGAN)
# ============================
class Discriminator(nn.Module):
    def __init__(self, input_nc):
        """
        使用 70×70 PatchGAN 判别器
        """
        super(Discriminator, self).__init__()
        model = [
            nn.Conv2d(input_nc, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        ]

        in_channels = 64
        out_channels = in_channels * 2
        # 增加几层卷积
        for _ in range(3):
            model += [
                nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
                nn.InstanceNorm2d(out_channels),
                nn.LeakyReLU(0.2, inplace=True)
            ]
            in_channels = out_channels
            out_channels = in_channels * 2

        # 最后一层卷积
        model += [
            nn.Conv2d(in_channels, 1, kernel_size=4, padding=1)
        ]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

In [11]:

# ============================
# 4. 定义 CycleGAN 模型类（加入 AMP 支持）
# ============================
class CycleGAN:
    def __init__(self, device, G_AB = "", G_BA = "", D_A = "", D_B = ""):
        self.device = device

        # 初始化两个生成器：G_AB（数码→胶片）、G_BA（胶片→数码）
        self.G_AB = ResNetGenerator(3, 3).to(device)
        if G_AB != "":
            self.G_AB.load_state_dict(torch.load(G_AB))
        self.G_BA = ResNetGenerator(3, 3).to(device)
        if G_BA != "":
            self.G_BA.load_state_dict(torch.load(G_BA))
        # 初始化两个判别器：D_A（判别真实数码图像）、D_B（判别真实胶片图像）
        self.D_A = Discriminator(3).to(device)
        if D_A != "":
            self.D_A.load_state_dict(torch.load(D_A))
        self.D_B = Discriminator(3).to(device)
        if D_B != "":
            self.D_B.load_state_dict(torch.load(D_B))
        # 定义损失函数：对抗损失、循环一致性损失、身份损失
        self.criterion_GAN = nn.MSELoss().to(device)
        self.criterion_cycle = nn.L1Loss().to(device)
        self.criterion_identity = nn.L1Loss().to(device)

        # 优化器（两个生成器共用一个优化器）
        self.optimizer_G = optim.Adam(itertools.chain(self.G_AB.parameters(), self.G_BA.parameters()),
                                    lr=0.0002, betas=(0.5, 0.999))
        self.optimizer_D_A = optim.Adam(self.D_A.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.optimizer_D_B = optim.Adam(self.D_B.parameters(), lr=0.0002, betas=(0.5, 0.999))

        # 创建 AMP GradScaler 对象
        self.scaler_G = torch.amp.GradScaler(device=device)
        self.scaler_D_A = torch.amp.GradScaler(device=device)
        self.scaler_D_B = torch.amp.GradScaler(device=device)

    def set_input(self, real_A, real_B):
        self.real_A = real_A.to(self.device)
        self.real_B = real_B.to(self.device)

    def forward(self):
        # A→B→A
        self.fake_B = self.G_AB(self.real_A)
        self.rec_A = self.G_BA(self.fake_B)
        # B→A→B
        self.fake_A = self.G_BA(self.real_B)
        self.rec_B = self.G_AB(self.fake_A)

    def backward_G(self):
        # 身份损失：要求生成器在目标域图像上保持不变
        self.idt_A = self.G_BA(self.real_A)
        self.loss_idt_A = self.criterion_identity(self.idt_A, self.real_A) * 5.0
        self.idt_B = self.G_AB(self.real_B)
        self.loss_idt_B = self.criterion_identity(self.idt_B, self.real_B) * 5.0

        # 对抗损失
        pred_fake_B = self.D_B(self.fake_B)
        target_real = torch.ones_like(pred_fake_B, device=self.device)
        loss_GAN_AB = self.criterion_GAN(pred_fake_B, target_real)

        pred_fake_A = self.D_A(self.fake_A)
        target_real = torch.ones_like(pred_fake_A, device=self.device)
        loss_GAN_BA = self.criterion_GAN(pred_fake_A, target_real)

        # 循环一致性损失
        loss_cycle_A = self.criterion_cycle(self.rec_A, self.real_A) * 10.0
        loss_cycle_B = self.criterion_cycle(self.rec_B, self.real_B) * 10.0

        # 总生成器损失（仅保存，不调用 backward）
        self.loss_G = self.loss_idt_A + self.loss_idt_B + loss_GAN_AB + loss_GAN_BA + loss_cycle_A + loss_cycle_B

    def backward_D_basic(self, netD, real, fake):
        target_real = torch.ones_like(netD(real), device=self.device)
        target_fake = torch.zeros_like(netD(fake.detach()), device=self.device)
        loss_real = self.criterion_GAN(netD(real), target_real)
        loss_fake = self.criterion_GAN(netD(fake.detach()), target_fake)
        loss_D = (loss_real + loss_fake) * 0.5
        return loss_D

    def backward_D_A(self):
        self.loss_D_A = self.backward_D_basic(self.D_A, self.real_A, self.fake_A)

    def backward_D_B(self):
        self.loss_D_B = self.backward_D_basic(self.D_B, self.real_B, self.fake_B)

    def optimize_parameters(self):
        # --------------------------
        # 更新生成器
        # --------------------------
        self.optimizer_G.zero_grad()
        with torch.cuda.amp.autocast():
            self.forward()       # 生成 fake_B、fake_A、rec_A、rec_B
            self.backward_G()    # 计算 self.loss_G
        self.scaler_G.scale(self.loss_G).backward()
        self.scaler_G.step(self.optimizer_G)
        self.scaler_G.update()

        # --------------------------
        # 更新判别器 D_A
        # --------------------------
        self.optimizer_D_A.zero_grad()
        with torch.cuda.amp.autocast():
            self.backward_D_A()  # 计算 self.loss_D_A
        self.scaler_D_A.scale(self.loss_D_A).backward()
        self.scaler_D_A.step(self.optimizer_D_A)
        self.scaler_D_A.update()

        # --------------------------
        # 更新判别器 D_B
        # --------------------------
        self.optimizer_D_B.zero_grad()
        with torch.cuda.amp.autocast():
            self.backward_D_B()  # 计算 self.loss_D_B
        self.scaler_D_B.scale(self.loss_D_B).backward()
        self.scaler_D_B.step(self.optimizer_D_B)
        self.scaler_D_B.update()

In [12]:
# ============================
# 5. 训练循环函数
# ============================
def train(cyclegan, dataloader, num_epochs=200, save_interval=10, pre_epoch = 0):
    for epoch in range(num_epochs):
        for i, (data_A, data_B) in enumerate(dataloader):
            real_A = data_A
            real_B = data_B
            cyclegan.set_input(real_A, real_B)
            cyclegan.optimize_parameters()

            if i % 100 == 0:
                print(f"Epoch [{epoch+1 + pre_epoch}/{num_epochs + pre_epoch}] Batch [{i}] | "
                      f"Loss_G: {cyclegan.loss_G.item():.4f} | "
                      f"Loss_D_A: {cyclegan.loss_D_A.item():.4f} | "
                      f"Loss_D_B: {cyclegan.loss_D_B.item():.4f}")

        # 定期保存输出结果及模型
        if (epoch + 1) % save_interval == 0:
            os.makedirs("output", exist_ok=True)
            with torch.no_grad():
                fake_B = cyclegan.G_AB(real_A.cuda())
                fake_A = cyclegan.G_BA(real_B.cuda())
            save_image(fake_B, f"output/fake_B_epoch_{epoch+1+pre_epoch}.png", normalize=True)
            save_image(fake_A, f"output/fake_A_epoch_{epoch+1+pre_epoch}.png", normalize=True)
            torch.save(cyclegan.G_AB.state_dict(), f"output/G_AB_epoch_{epoch+1+pre_epoch}.pth")
            torch.save(cyclegan.G_BA.state_dict(), f"output/G_BA_epoch_{epoch+1+pre_epoch}.pth")
            torch.save(cyclegan.D_A.state_dict(), f"output/D_A_epoch_{epoch+1+pre_epoch}.pth")
            torch.save(cyclegan.D_B.state_dict(), f"output/D_B_epoch_{epoch+1+pre_epoch}.pth")


In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [14]:
# 初始化 CycleGAN 模型
cyclegan = CycleGAN(device)

In [14]:
cyclegan = CycleGAN(device, G_AB="output/G_AB_epoch_200.pth", 
                    G_BA="output/G_BA_epoch_200.pth", 
                    D_A="output/D_A_epoch_200.pth", 
                    D_B="output/D_B_epoch_200.pth")

  self.scaler_G = torch.cuda.amp.GradScaler()
  self.scaler_D_A = torch.cuda.amp.GradScaler()
  self.scaler_D_B = torch.cuda.amp.GradScaler()


In [15]:
# 开始训练
train(cyclegan, train_loader, num_epochs=200, pre_epoch = 0, save_interval=100)

  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


Epoch [1/200] Batch [0] | Loss_G: 17.1859 | Loss_D_A: 0.7276 | Loss_D_B: 0.6201
Epoch [2/200] Batch [0] | Loss_G: 15.1006 | Loss_D_A: 0.4649 | Loss_D_B: 0.3228
Epoch [3/200] Batch [0] | Loss_G: 10.0335 | Loss_D_A: 0.2056 | Loss_D_B: 0.2335
Epoch [4/200] Batch [0] | Loss_G: 7.4754 | Loss_D_A: 0.1524 | Loss_D_B: 0.1834
Epoch [5/200] Batch [0] | Loss_G: 6.4414 | Loss_D_A: 0.1901 | Loss_D_B: 0.1219
Epoch [6/200] Batch [0] | Loss_G: 5.3579 | Loss_D_A: 0.1497 | Loss_D_B: 0.1242
Epoch [7/200] Batch [0] | Loss_G: 5.5880 | Loss_D_A: 0.1069 | Loss_D_B: 0.1188
Epoch [8/200] Batch [0] | Loss_G: 6.0971 | Loss_D_A: 0.2743 | Loss_D_B: 0.1996
Epoch [9/200] Batch [0] | Loss_G: 5.3487 | Loss_D_A: 0.1404 | Loss_D_B: 0.0966
Epoch [10/200] Batch [0] | Loss_G: 4.8820 | Loss_D_A: 0.0975 | Loss_D_B: 0.0938
Epoch [11/200] Batch [0] | Loss_G: 6.0928 | Loss_D_A: 0.2463 | Loss_D_B: 0.3102
Epoch [12/200] Batch [0] | Loss_G: 4.6406 | Loss_D_A: 0.1275 | Loss_D_B: 0.1369
Epoch [13/200] Batch [0] | Loss_G: 5.0668 | Lo

# 测试

In [13]:
device = torch.device("cpu")

In [14]:
cyclegan = CycleGAN(device, G_AB="output/G_AB_epoch_100.pth", 
                    G_BA="output/G_BA_epoch_100.pth", 
                    D_A="output/D_A_epoch_100.pth", 
                    D_B="output/D_B_epoch_100.pth")

  self.scaler_G = torch.cuda.amp.GradScaler()
  self.scaler_D_A = torch.cuda.amp.GradScaler()
  self.scaler_D_B = torch.cuda.amp.GradScaler()


In [15]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # 归一化
])

In [16]:
test_digital_dir = "D:/data/柯达金200/微调数据集/测试照片"
test_film_dir = "D:/data/柯达金200/微调数据集/测试胶片"

In [17]:
pairs = make_pairs(test_digital_dir, test_film_dir)

1
1


In [18]:
test_dataset = ImagePairDataset(test_digital_dir, test_film_dir, pairs, transform=transform)

100%|██████████| 1/1 [00:01<00:00,  1.32s/it]


In [19]:
# 测试阶段：展示结果
with torch.no_grad():
    img1, img2 = test_dataset[0]  # 取一个示例
    img1 = img1.unsqueeze(0).to(device)
    img2 = img2.unsqueeze(0).to(device)  # 增加批量维度
    with torch.no_grad():
        cyclegan.G_AB.eval()
        outputA = cyclegan.G_AB(img1)
        outputB = cyclegan.G_BA(img2)
    # 将结果从 Tensor 转换回图片
    outputA = outputA.squeeze(0).cpu().numpy().transpose(1, 2, 0)
    outputB = outputB.squeeze(0).cpu().numpy().transpose(1, 2, 0)
    img1 = img1.squeeze(0).cpu().numpy().transpose(1, 2, 0)
    img2 = img2.squeeze(0).cpu().numpy().transpose(1, 2, 0)

    # 使用 matplotlib 显示图片
    fig, axes = plt.subplots(1, 4, figsize=(12, 4))
    axes[0].imshow(img1)
    axes[0].set_title("Input Image 1")
    axes[1].imshow(img2)
    axes[1].set_title("Input Image 2")
    axes[2].imshow(outputA)
    axes[2].set_title("Predicted Image A")
    axes[3].imshow(outputB)
    axes[3].set_title("Predicted Image B")
    plt.show()

  x = torch.nn.functional.upsample(x, size = (h, w), mode="bilinear")
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-1.0..1.0].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-1.0..1.0].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.970074..0.9999682].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.99999946..0.99990547].


: 