In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import matplotlib.pyplot as plt
import tqdm
from torchvision.utils import save_image
import itertools
from torchvision import models

In [None]:
class ImagePairDataset(Dataset):
    """Paired-image dataset that eagerly loads every (folder1, folder2) image pair.

    Every image is opened, converted to RGB and passed through ``transform`` once,
    inside ``__init__``; ``__getitem__`` then simply returns the cached pair.

    Args:
        folder1: directory holding the first image of each pair.
        folder2: directory holding the second image of each pair.
        pairs: sequence of ``(name_in_folder1, name_in_folder2)`` file-name tuples.
        transform: optional callable applied to each loaded RGB PIL image.
        enhance_transform_1, enhance_transform_2, enhance_transform_3: stored on the
            instance but never applied by this class.
    """

    def __init__(self, folder1, folder2, pairs, transform=None, enhance_transform_1=None,
                 enhance_transform_2=None, enhance_transform_3=None):
        self.folder1 = folder1
        self.folder2 = folder2
        self.pairs = pairs
        self.transform = transform
        # NOTE(review): the enhance_transform_* arguments are only stored; nothing in
        # this class uses them. Kept for interface compatibility.
        self.enhance_transform_1 = enhance_transform_1
        self.enhance_transform_2 = enhance_transform_2
        self.enhance_transform_3 = enhance_transform_3
        # Load everything up front (the whole dataset is held in memory)
        self.image_pairs = self.read_image_pairs()

    def _load_image(self, path):
        """Open ``path`` as an RGB image and apply ``self.transform`` when it is set."""
        # The context manager closes the underlying file handle. ``convert`` reads the
        # pixel data and returns a new image, so the result stays valid after closing.
        with Image.open(path) as img:
            rgb = img.convert("RGB")
        if self.transform:
            return self.transform(rgb)
        return rgb

    def read_image_pairs(self):
        """Load (and transform) every pair listed in ``self.pairs``.

        Returns:
            list of ``(img1, img2)`` tuples in the same order as ``self.pairs``.
        """
        image_pairs = []
        for image_pair in tqdm.tqdm(self.pairs):
            img1 = self._load_image(os.path.join(self.folder1, image_pair[0]))
            img2 = self._load_image(os.path.join(self.folder2, image_pair[1]))
            image_pairs.append((img1, img2))
        return image_pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        img1, img2 = self.image_pairs[idx]
        return img1, img2

In [None]:
# Preprocessing applied to every image: resize to (height, width) = (128, 256),
# convert to a float tensor, then normalize each channel with (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((128, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),  # normalization
    ]
)

In [None]:
# Input folders: digital photos (domain A) and the matching film images (domain B).
# Set CYCLEGAN_DIGITAL_DIR / CYCLEGAN_FILM_DIR to override these machine-specific
# absolute defaults without editing the notebook.
digital_dir = os.environ.get("CYCLEGAN_DIGITAL_DIR", "D:/data/禄来红外400/数码")
film_dir = os.environ.get("CYCLEGAN_FILM_DIR", "D:/data/禄来红外400/样片")

In [None]:
def make_pairs(image1_dir, imgage2_dir, extensions=None):
    """Pair the files of two folders by their sorted file-name order.

    The i-th file (in sorted order) of ``image1_dir`` is paired with the i-th
    file of ``imgage2_dir``, so corresponding images must sort into the same
    order in both folders. Sub-directories are ignored.

    Args:
        image1_dir: first folder; its file names become ``pair[0]``.
        imgage2_dir: second folder; its file names become ``pair[1]``. The
            misspelled name is kept so keyword callers keep working.
        extensions: optional iterable of extensions such as ``(".jpg", ".png")``,
            compared case-insensitively; other files are skipped. ``None``
            (default) keeps every regular file.

    Returns:
        list of ``(name1, name2)`` tuples.

    Raises:
        ValueError: if the two folders yield different numbers of files.
    """
    allowed = None if extensions is None else {ext.lower() for ext in extensions}

    def list_images(folder):
        # Only regular files can be opened as images later on; skip sub-directories.
        names = [name for name in os.listdir(folder)
                 if os.path.isfile(os.path.join(folder, name))]
        if allowed is not None:
            names = [name for name in names if os.path.splitext(name)[1].lower() in allowed]
        return sorted(names)

    images1 = list_images(image1_dir)
    images2 = list_images(imgage2_dir)

    print(len(images1))
    print(len(images2))

    # Positional pairing only makes sense when both folders have the same count
    if len(images1) != len(images2):
        raise ValueError("The two folders must have the same number of images.")

    # Pair the i-th sorted name of one folder with the i-th sorted name of the other
    return list(zip(images1, images2))

In [None]:
pairs = make_pairs(digital_dir, film_dir)
# Fail early with a clear message: an empty pair list would otherwise only surface
# later, far from the cause (e.g. when building the shuffled DataLoader).
assert pairs, f"No image pairs found in {digital_dir!r} and {film_dir!r}"

In [None]:
# Number of image pairs per training step
batch_size: int = 32

In [None]:
# Dataset of (digital, film) image pairs; every image is loaded and transformed
# once, when the dataset is constructed
train_dataset = ImagePairDataset(digital_dir, film_dir, pairs, transform=transform)

# Mini-batch loader, reshuffled every epoch
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

In [None]:
# ============================
# 1. Defining the Generator (ResNet Generator)
# ============================
class ResNetGenerator(nn.Module):
    """U-Net style generator: ResNet-18 encoder + bilinear-upsampling decoder.

    The output of every encoder stage is concatenated into the decoder as a skip
    connection. The result goes through ``tanh`` (range [-1, 1]) and is finally
    resized to the input's height and width.

    Args:
        input_nc: accepted for interface compatibility but unused; the encoder
            reuses ResNet-18's ``conv1`` as-is.
        output_nc: number of output channels.
        num_residual_blocks: accepted for interface compatibility but unused.
        pretrained_path: ResNet-18 ``state_dict`` file used to initialize the
            encoder, or ``None`` to skip loading it (e.g. when a complete
            generator checkpoint is loaded right afterwards anyway).
    """

    def __init__(self, input_nc=3, output_nc=3, num_residual_blocks=9,
                 pretrained_path="./resnet18-5c106cde.pth"):
        super(ResNetGenerator, self).__init__()

        resnet = models.resnet18()
        if pretrained_path is not None:
            # The file is fed straight into load_state_dict, i.e. it is a mapping of
            # tensors, so the restricted weights_only loader suffices and avoids
            # unpickling arbitrary objects.
            state_dict = torch.load(pretrained_path, map_location="cpu", weights_only=True)
            resnet.load_state_dict(state_dict)

        # Encoder: reuse the ResNet stages; their outputs feed the skip connections
        self.input_conv = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu)
        self.maxpool = resnet.maxpool   # H/4 x W/4
        self.encoder1 = resnet.layer1
        self.encoder2 = resnet.layer2
        self.encoder3 = resnet.layer3
        self.encoder4 = resnet.layer4

        # Decoder: upsample + conv; input channels include the concatenated skip features
        self.up1 = self._up_block(512, 256)
        self.up2 = self._up_block(512, 128)
        self.up3 = self._up_block(256, 64)
        self.up4 = self._up_block(128, 64)
        self.up5 = self._up_block(128, 64)

        self.final_conv = nn.Conv2d(64, output_nc, kernel_size=1)

    def _up_block(self, in_channels, out_channels):
        """Upsampling block: 2x bilinear upsampling, then 3x3 convolution + ReLU."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def _upsample_to(block, x, size):
        """Run an ``_up_block`` but resize ``x`` to exactly ``size`` instead of by 2x.

        When ``size`` is twice the spatial size of ``x`` this matches ``block(x)``
        (with align_corners=True the sampling grid depends only on the in/out sizes).
        It also handles inputs whose height/width are not multiples of 32, where the
        encoder rounds odd sizes up and a fixed 2x upsample no longer matches the
        skip connection, making ``torch.cat`` fail.
        """
        upsample = block[0]
        x = torch.nn.functional.interpolate(x, size=size, mode=upsample.mode,
                                            align_corners=upsample.align_corners)
        for layer in list(block)[1:]:
            x = layer(x)
        return x

    def forward(self, x):
        h, w = x.shape[2], x.shape[3]

        # Encoder (approximate spatial scale relative to the input)
        x0 = self.input_conv(x)                # 64 ch,  ~1/2
        x1 = self.encoder1(self.maxpool(x0))   # 64 ch,  ~1/4
        x2 = self.encoder2(x1)                 # 128 ch, ~1/8
        x3 = self.encoder3(x2)                 # 256 ch, ~1/16
        x4 = self.encoder4(x3)                 # 512 ch, ~1/32

        # Decoder: upsample to each skip feature's exact size, then concatenate it
        d1 = torch.cat([self._upsample_to(self.up1, x4, x3.shape[-2:]), x3], dim=1)
        d2 = torch.cat([self._upsample_to(self.up2, d1, x2.shape[-2:]), x2], dim=1)
        d3 = torch.cat([self._upsample_to(self.up3, d2, x1.shape[-2:]), x1], dim=1)
        d4 = torch.cat([self._upsample_to(self.up4, d3, x0.shape[-2:]), x0], dim=1)
        d5 = self.up5(d4)

        out = torch.tanh(self.final_conv(d5))

        # up5 doubles the ~1/2-scale map, which may be off by one pixel for odd sizes
        return torch.nn.functional.interpolate(out, size=(h, w), mode="bilinear")

In [None]:
# ============================
# 3. Defining the Discriminator (PatchGAN)
# ============================
class Discriminator(nn.Module):
    """PatchGAN discriminator: maps an image to a grid of real/fake scores.

    With the default arguments: four 4x4 stride-2 convolutions with 64/128/256/512
    filters (LeakyReLU 0.2, InstanceNorm on all but the first), followed by a 4x4
    stride-1 convolution down to one output channel.

    NOTE: because every intermediate convolution uses stride 2, each output score
    has a 94x94 receptive field, not the 70x70 of the pix2pix/CycleGAN reference
    PatchGAN (which uses stride 1 in its 512-filter layer). The architecture is kept
    unchanged so existing checkpoints stay valid.

    Args:
        input_nc: number of channels of the input image.
        ndf: number of filters of the first convolution (doubled at each layer).
        n_layers: number of normalized stride-2 convolutions after the first one.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3):
        super(Discriminator, self).__init__()
        model = [
            nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        ]

        in_channels = ndf
        for _ in range(n_layers):
            out_channels = in_channels * 2
            model += [
                nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
                nn.InstanceNorm2d(out_channels),
                nn.LeakyReLU(0.2, inplace=True)
            ]
            in_channels = out_channels

        # Final stride-1 convolution producing the one-channel score map
        model.append(nn.Conv2d(in_channels, 1, kernel_size=4, padding=1))
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

In [None]:

# ============================
# 4. Define the CycleGAN model class (add AMP support)
# ============================
class CycleGAN:
    """Holds CycleGAN's two generators, two discriminators, losses and optimizers.

    Domain A is digital images and domain B is film images. Training uses
    autocast mixed precision on CUDA devices, with one GradScaler per optimizer.

    Args:
        device: ``torch.device`` (or device string) the networks are moved to.
        G_AB, G_BA, D_A, D_B: optional paths to ``state_dict`` checkpoints for the
            corresponding network; an empty string keeps the fresh initialization.
    """

    def __init__(self, device, G_AB = "", G_BA = "", D_A = "", D_B = ""):
        self.device = device

        # Generators: G_AB (digital -> film), G_BA (film -> digital)
        self.G_AB = ResNetGenerator(3, 3).to(device)
        self.G_BA = ResNetGenerator(3, 3).to(device)
        # Discriminators: D_A judges digital images, D_B judges film images
        self.D_A = Discriminator(3).to(device)
        self.D_B = Discriminator(3).to(device)
        for net, path in ((self.G_AB, G_AB), (self.G_BA, G_BA), (self.D_A, D_A), (self.D_B, D_B)):
            if path != "":
                self._load_checkpoint(net, path)

        # Losses: least-squares adversarial loss (MSE), L1 cycle-consistency, L1 identity
        self.criterion_GAN = nn.MSELoss().to(device)
        self.criterion_cycle = nn.L1Loss().to(device)
        self.criterion_identity = nn.L1Loss().to(device)

        # Optimizers (the two generators share one optimizer)
        self.optimizer_G = optim.Adam(itertools.chain(self.G_AB.parameters(), self.G_BA.parameters()),
                                      lr=0.0002, betas=(0.5, 0.999))
        self.optimizer_D_A = optim.Adam(self.D_A.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.optimizer_D_B = optim.Adam(self.D_B.parameters(), lr=0.0002, betas=(0.5, 0.999))

        # One AMP GradScaler per optimizer
        self.scaler_G = torch.amp.GradScaler(device=device)
        self.scaler_D_A = torch.amp.GradScaler(device=device)
        self.scaler_D_B = torch.amp.GradScaler(device=device)

    def _load_checkpoint(self, net, path):
        """Load the ``state_dict`` stored at ``path`` into ``net``."""
        # map_location lets a checkpoint saved on a GPU be restored on a CPU-only
        # machine; weights_only refuses to unpickle anything but tensors/containers.
        net.load_state_dict(torch.load(path, map_location=self.device, weights_only=True))

    def _autocast(self):
        """Autocast context for ``self.device``: enabled on CUDA, disabled otherwise."""
        device_type = torch.device(self.device).type
        return torch.amp.autocast(device_type=device_type, enabled=device_type == "cuda")

    @staticmethod
    def _set_requires_grad(nets, requires_grad):
        """Set ``requires_grad`` on every parameter of every module in ``nets``."""
        for net in nets:
            for param in net.parameters():
                param.requires_grad_(requires_grad)

    def set_input(self, real_A, real_B):
        self.real_A = real_A.to(self.device)
        self.real_B = real_B.to(self.device)

    def forward(self):
        # A -> B -> A
        self.fake_B = self.G_AB(self.real_A)
        self.rec_A = self.G_BA(self.fake_B)
        # B -> A -> B
        self.fake_A = self.G_BA(self.real_B)
        self.rec_B = self.G_AB(self.fake_A)

    def backward_G(self):
        """Compute the total generator loss into ``self.loss_G`` (no backward call)."""
        # Identity loss: each generator should leave images of its target domain unchanged
        self.idt_A = self.G_BA(self.real_A)
        self.loss_idt_A = self.criterion_identity(self.idt_A, self.real_A) * 5.0
        self.idt_B = self.G_AB(self.real_B)
        self.loss_idt_B = self.criterion_identity(self.idt_B, self.real_B) * 5.0

        # Adversarial losses: the generators try to make D score fakes as real (1)
        pred_fake_B = self.D_B(self.fake_B)
        loss_GAN_AB = self.criterion_GAN(pred_fake_B, torch.ones_like(pred_fake_B))

        pred_fake_A = self.D_A(self.fake_A)
        loss_GAN_BA = self.criterion_GAN(pred_fake_A, torch.ones_like(pred_fake_A))

        # Cycle-consistency losses
        loss_cycle_A = self.criterion_cycle(self.rec_A, self.real_A) * 10.0
        loss_cycle_B = self.criterion_cycle(self.rec_B, self.real_B) * 10.0

        self.loss_G = self.loss_idt_A + self.loss_idt_B + loss_GAN_AB + loss_GAN_BA + loss_cycle_A + loss_cycle_B

    def backward_D_basic(self, netD, real, fake):
        """Discriminator loss 0.5 * (MSE(D(real), 1) + MSE(D(fake), 0)).

        ``fake`` is detached so no gradient reaches the generator. Each input is
        passed through ``netD`` exactly once. Returns the loss (no backward call).
        """
        pred_real = netD(real)
        pred_fake = netD(fake.detach())
        loss_real = self.criterion_GAN(pred_real, torch.ones_like(pred_real))
        loss_fake = self.criterion_GAN(pred_fake, torch.zeros_like(pred_fake))
        return (loss_real + loss_fake) * 0.5

    def backward_D_A(self):
        self.loss_D_A = self.backward_D_basic(self.D_A, self.real_A, self.fake_A)

    def backward_D_B(self):
        self.loss_D_B = self.backward_D_basic(self.D_B, self.real_B, self.fake_B)

    def optimize_parameters(self):
        """One training step: update both generators, then D_A, then D_B."""
        # --- Generators ---
        # Freeze the discriminators so loss_G.backward() does not compute gradients
        # for their weights; those would be cleared by the zero_grad() calls below.
        self._set_requires_grad((self.D_A, self.D_B), False)
        self.optimizer_G.zero_grad()
        with self._autocast():
            self.forward()       # fake_B, fake_A, rec_A, rec_B
            self.backward_G()    # self.loss_G
        self.scaler_G.scale(self.loss_G).backward()
        self.scaler_G.step(self.optimizer_G)
        self.scaler_G.update()
        self._set_requires_grad((self.D_A, self.D_B), True)

        # --- Discriminator D_A ---
        self.optimizer_D_A.zero_grad()
        with self._autocast():
            self.backward_D_A()  # self.loss_D_A
        self.scaler_D_A.scale(self.loss_D_A).backward()
        self.scaler_D_A.step(self.optimizer_D_A)
        self.scaler_D_A.update()

        # --- Discriminator D_B ---
        self.optimizer_D_B.zero_grad()
        with self._autocast():
            self.backward_D_B()  # self.loss_D_B
        self.scaler_D_B.scale(self.loss_D_B).backward()
        self.scaler_D_B.step(self.optimizer_D_B)
        self.scaler_D_B.update()

In [None]:
# ============================
# 5. Training loop function
# ============================
def train(cyclegan, dataloader, num_epochs=200, save_interval=10, pre_epoch = 0,
          log_interval=100, output_dir="output"):
    """Train ``cyclegan`` for ``num_epochs`` epochs, periodically saving samples and weights.

    Args:
        cyclegan: model exposing ``set_input``/``optimize_parameters``, the
            ``loss_G``/``loss_D_A``/``loss_D_B`` tensors, the ``G_AB``/``G_BA``/
            ``D_A``/``D_B`` networks and ``device``.
        dataloader: iterable of ``(batch_A, batch_B)`` tensor pairs.
        num_epochs: number of epochs to run in this call.
        save_interval: save every this many epochs (counted within this call).
        pre_epoch: epochs already completed before this call; only offsets the
            epoch numbers used in logs and file names.
        log_interval: print the losses every this many batches.
        output_dir: folder that receives sample images and checkpoints.
    """
    total_epochs = num_epochs + pre_epoch
    for epoch in range(num_epochs):
        epoch_no = epoch + 1 + pre_epoch
        real_A = real_B = None
        for i, (real_A, real_B) in enumerate(dataloader):
            cyclegan.set_input(real_A, real_B)
            cyclegan.optimize_parameters()

            if i % log_interval == 0:
                print(f"Epoch [{epoch_no}/{total_epochs}] Batch [{i}] | "
                      f"Loss_G: {cyclegan.loss_G.item():.4f} | "
                      f"Loss_D_A: {cyclegan.loss_D_A.item():.4f} | "
                      f"Loss_D_B: {cyclegan.loss_D_B.item():.4f}")

        # Save sample translations of the last batch plus all weights. Skipped when
        # the dataloader produced no batch, since there is nothing to translate.
        if (epoch + 1) % save_interval == 0 and real_A is not None:
            os.makedirs(output_dir, exist_ok=True)
            with torch.no_grad():
                # Use the model's own device rather than assuming CUDA is available
                fake_B = cyclegan.G_AB(real_A.to(cyclegan.device))
                fake_A = cyclegan.G_BA(real_B.to(cyclegan.device))
            save_image(fake_B, f"{output_dir}/fake_B_epoch_{epoch_no}.png", normalize=True)
            save_image(fake_A, f"{output_dir}/fake_A_epoch_{epoch_no}.png", normalize=True)
            for name in ("G_AB", "G_BA", "D_A", "D_B"):
                torch.save(getattr(cyclegan, name).state_dict(),
                           f"{output_dir}/{name}_epoch_{epoch_no}.pth")


In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Freshly initialized model (no checkpoints loaded)
cyclegan = CycleGAN(device=device)

In [None]:
# Optionally resume from the checkpoints saved at `resume_epoch`. When they are not
# present (e.g. on a first Restart & Run All) keep the freshly created model from the
# previous cell instead of crashing on a missing file.
resume_epoch = 200
resume_checkpoints = {
    name: f"output/{name}_epoch_{resume_epoch}.pth" for name in ("G_AB", "G_BA", "D_A", "D_B")
}
if all(os.path.exists(path) for path in resume_checkpoints.values()):
    cyclegan = CycleGAN(device, **resume_checkpoints)
else:
    print(f"No complete checkpoint set for epoch {resume_epoch}; training from scratch.")
    resume_epoch = 0
# Record how many epochs the current weights already have. Pass it to train() as
# `pre_epoch` so new checkpoints do not overwrite the ones loaded here.
cyclegan.completed_epochs = resume_epoch

In [None]:
# Continue the epoch numbering of a resumed model (0 for a fresh one). With a fixed
# pre_epoch=0, a run resumed from epoch 200 would save its own "epoch_200" files
# over the very checkpoints it was loaded from.
train(cyclegan, train_loader, num_epochs=200,
      pre_epoch=getattr(cyclegan, "completed_epochs", 0), save_interval=100)