## Image-to-Chinese Ink Painting Style Transfer Using CycleGAN

In this final project, we use CycleGAN to transfer images into the style of Chinese ink painting.

Along the way, we tried several different approaches to training the model.

The full process is shown below.

(1) 'The final model' refers to the third model we trained; its code is the final version we used to train CycleGAN.

(2) 'Previous attempt 1' is the second model we trained, using CycleGAN + VGG + self-attention + brush loss. However, its test results were poor: the converted images do not resemble Chinese ink paintings.

(3) 'Previous attempt 2' is the first model we trained, using CycleGAN + VGG + self-attention. According to the test results, the style-transfer effect was quite good, but all outputs had low image quality, so we decided not to use this model.

In [None]:
# Mount Google Drive so the zipped dataset under MyDrive/DLfinal is readable
# from this Colab runtime.
from google.colab import drive

drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!unzip /content/drive/MyDrive/DLfinal/CIP_dataset1_process.zip

Archive:  /content/drive/MyDrive/DLfinal/CIP_dataset1_process.zip
replace CIP_dataset1_process/testA/cliff_2608866102_1de76908a4.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

## The final model

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from torchvision.utils import save_image
from torchvision.models import vgg16, VGG16_Weights
from PIL import Image

# Set device: use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Dataset class
class ImageDataset(Dataset):
    """Unpaired two-domain image dataset for CycleGAN training.

    Each item is ``{"A": img_A, "B": img_B}`` with one image from each
    directory. The length is the size of the larger domain; the smaller one
    is cycled through via modulo indexing, so A and B images are paired only
    by index, not by content.

    Args:
        root_A: Directory of domain-A images; every regular file directly
            inside it is used, in sorted path order.
        root_B: Directory of domain-B images, listed the same way.
        transform: Optional callable applied to each RGB PIL image.

    Raises:
        ValueError: If either directory contains no files. Without this
            check ``__getitem__`` would fail with ZeroDivisionError
            (``idx % 0``) as soon as the other domain is non-empty.
    """

    def __init__(self, root_A, root_B, transform=None):
        self.files_A = self._list_files(root_A)
        self.files_B = self._list_files(root_B)
        if not self.files_A or not self.files_B:
            raise ValueError(
                "ImageDataset needs at least one file per domain; found "
                f"{len(self.files_A)} in {root_A!r} and {len(self.files_B)} in {root_B!r}"
            )
        self.transform = transform

    @staticmethod
    def _list_files(root):
        """Return sorted full paths of the regular files directly inside *root*."""
        paths = (os.path.join(root, name) for name in os.listdir(root))
        return sorted(path for path in paths if os.path.isfile(path))

    @staticmethod
    def _load_rgb(path):
        """Open *path* as an RGB image, releasing the file handle afterwards."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))

    def __getitem__(self, idx):
        img_A = self._load_rgb(self.files_A[idx % len(self.files_A)])
        img_B = self._load_rgb(self.files_B[idx % len(self.files_B)])

        if self.transform:
            img_A = self.transform(img_A)
            img_B = self.transform(img_B)

        return {"A": img_A, "B": img_B}

# Transforms: resize to 256x256, convert to a tensor in [0, 1], then rescale
# to [-1, 1] so inputs match the generators' Tanh output range.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Load datasets: trainA and trainB are independent (unpaired) folders.
dataset = ImageDataset(
    "/content/CIP_dataset1_process/trainA",
    "/content/CIP_dataset1_process/trainB",
    transform=transform,
)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Residual Block
class ResidualBlock(nn.Module):
    """Residual block whose residual branch is rescaled by channel attention.

    The residual branch is conv3x3-IN-ReLU-conv3x3-IN. A squeeze-and-excitation
    gate (global average pool -> 1x1 reduce -> ReLU -> 1x1 expand -> Sigmoid)
    produces one weight per channel that multiplies the residual before it is
    added back to the input. Spatial size and channel count are preserved.

    Args:
        channels: Number of input/output channels.
        reduction: Bottleneck ratio of the attention gate (default 16, the
            previously hard-coded value). The hidden width is clamped to at
            least 1 so blocks with fewer than ``reduction`` channels still get
            a working gate instead of a zero-width layer.
    """

    def __init__(self, channels, reduction=16):
        super(ResidualBlock, self).__init__()
        hidden = max(1, channels // reduction)
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(channels)
        )
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, hidden, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        residual = self.block(x)
        attention = self.se(residual)  # shape (N, C, 1, 1), broadcast over H, W
        return x + residual * attention

# Generator
class Generator(nn.Module):
    """ResNet-style image-to-image generator.

    Pipeline: 7x7 conv stem (64 ch) -> two stride-2 downsampling convs
    (128, 256 ch) -> ``num_residual_blocks`` ResidualBlocks at 256 ch -> two
    stride-2 transposed convs back to 64 ch -> 7x7 conv + Tanh. For inputs
    whose height and width are divisible by 4, the output has the input's
    spatial size and values in (-1, 1).

    The stage attribute names (initial/down/res_blocks/up/output) and their
    layer order determine the state_dict keys, so they are kept as-is.
    """

    @staticmethod
    def _norm_relu(channels):
        """InstanceNorm + ReLU pair that follows every conv in this network."""
        return [nn.InstanceNorm2d(channels), nn.ReLU(inplace=True)]

    def __init__(self, in_channels, out_channels, num_residual_blocks=12):
        super(Generator, self).__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=7, stride=1, padding=3),
            *self._norm_relu(64),
        )
        self.down = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            *self._norm_relu(128),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            *self._norm_relu(256),
        )
        self.res_blocks = nn.Sequential(
            *(ResidualBlock(256) for _ in range(num_residual_blocks))
        )
        self.up = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            *self._norm_relu(128),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            *self._norm_relu(64),
        )
        self.output = nn.Sequential(
            nn.Conv2d(64, out_channels, kernel_size=7, stride=1, padding=3),
            nn.Tanh(),
        )

    def forward(self, x):
        for stage in (self.initial, self.down, self.res_blocks, self.up, self.output):
            x = stage(x)
        return x

# Discriminator
class Discriminator(nn.Module):
    """PatchGAN-style discriminator returning a map of real/fake scores.

    Four 4x4 convs (64, 128, 256, 512 channels; the first three stride 2, the
    last stride 1) with LeakyReLU(0.2), InstanceNorm on all but the first,
    then a 4x4 stride-1 conv to a single channel. A 256x256 input yields a
    30x30 score map. Layer order is fixed because it defines the state_dict.
    """

    def __init__(self, in_channels):
        super(Discriminator, self).__init__()
        layers = [
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for c_in, c_out, stride in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=1),
                nn.InstanceNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers.append(nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

# Initialize models: generators A->B and B->A, plus one discriminator per
# domain (construction order A2B, B2A, D_A, D_B is preserved).
generator_A2B, generator_B2A = (Generator(3, 3).to(device) for _ in range(2))
discriminator_A, discriminator_B = (Discriminator(3).to(device) for _ in range(2))

# Loss functions: least-squares adversarial loss and L1 cycle-consistency loss.
criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()

# VGG Loss
class VGGPerceptualLoss(nn.Module):
    """L1 distance between frozen VGG16 feature maps (``features[:16]``, up to relu3_3).

    Inputs are expected in the [-1, 1] range produced by the dataset
    transform and the generators' Tanh. They are mapped to [0, 1] and
    standardized with the ImageNet mean/std the pretrained VGG16 weights
    expect. The original fed [-1, 1] tensors to VGG unnormalized.
    """

    def __init__(self):
        super(VGGPerceptualLoss, self).__init__()
        vgg = vgg16(weights=VGG16_Weights.DEFAULT).features[:16].eval()
        for param in vgg.parameters():
            param.requires_grad = False
        self.vgg = vgg
        self.criterion = nn.L1Loss()
        # Registered as (non-persistent) buffers so .to(device) moves them
        # with the module while the state_dict stays unchanged.
        self.register_buffer(
            "mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False
        )
        self.register_buffer(
            "std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False
        )

    def _preprocess(self, x):
        """Map a [-1, 1] image batch to ImageNet-normalized VGG input."""
        return ((x + 1) / 2 - self.mean) / self.std

    def forward(self, input, target):
        input_features = self.vgg(self._preprocess(input))
        target_features = self.vgg(self._preprocess(target))
        loss = self.criterion(input_features, target_features)
        return loss

vgg_loss_fn = VGGPerceptualLoss().to(device)

# Optimizers: one Adam over both generators' parameters, one per discriminator.
optimizer_G = optim.Adam(
    [*generator_A2B.parameters(), *generator_B2A.parameters()],
    lr=2e-4,
    betas=(0.5, 0.999),
)
optimizer_D_A, optimizer_D_B = (
    optim.Adam(disc.parameters(), lr=2e-4, betas=(0.5, 0.999))
    for disc in (discriminator_A, discriminator_B)
)

# Hyperparameters
epochs = 20          # passes over the dataloader
lambda_cycle = 10.0  # weight of the cycle-consistency terms in loss_G

# Training loop
for epoch in range(epochs):
    for i, batch in enumerate(dataloader):
        real_A = batch["A"].to(device)
        real_B = batch["B"].to(device)

        # ---------------- Generator step ----------------
        fake_B = generator_A2B(real_A)
        fake_A = generator_B2A(real_B)

        # GAN Loss (least squares): run each discriminator once on the fakes
        # and push its prediction toward the "real" label 1. Previously every
        # discriminator was evaluated twice per term, the second time only to
        # shape the target tensor.
        pred_fake_B = discriminator_B(fake_B)
        pred_fake_A = discriminator_A(fake_A)
        loss_GAN_A2B = criterion_GAN(pred_fake_B, torch.ones_like(pred_fake_B))
        loss_GAN_B2A = criterion_GAN(pred_fake_A, torch.ones_like(pred_fake_A))

        # Cycle Loss: A -> B -> A and B -> A -> B should reproduce the inputs.
        recov_A = generator_B2A(fake_B)
        recov_B = generator_A2B(fake_A)
        loss_cycle_A = criterion_cycle(recov_A, real_A)
        loss_cycle_B = criterion_cycle(recov_B, real_B)

        # VGG Perceptual (content) Loss between each translation and its own
        # source image. real_A and real_B come from separate, unpaired
        # folders, so the previous fake_B-vs-real_B comparison matched the
        # features of an unrelated picture instead of preserving content.
        loss_vgg_A2B = vgg_loss_fn(fake_B, real_A)
        loss_vgg_B2A = vgg_loss_fn(fake_A, real_B)

        # Total Generator Loss
        loss_G = (
            loss_GAN_A2B
            + loss_GAN_B2A
            + lambda_cycle * (loss_cycle_A + loss_cycle_B)
            + loss_vgg_A2B
            + loss_vgg_B2A
        )

        # Update Generators
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        # ---------------- Discriminator step ----------------
        # loss_G.backward() also left gradients in the discriminators; clear
        # them first. Fakes are detached so no gradient reaches the generators.
        optimizer_D_A.zero_grad()
        optimizer_D_B.zero_grad()

        pred_real_A = discriminator_A(real_A)
        pred_fake_A_detached = discriminator_A(fake_A.detach())
        loss_real_A = criterion_GAN(pred_real_A, torch.ones_like(pred_real_A))
        loss_fake_A = criterion_GAN(pred_fake_A_detached, torch.zeros_like(pred_fake_A_detached))
        loss_D_A = (loss_real_A + loss_fake_A) / 2

        pred_real_B = discriminator_B(real_B)
        pred_fake_B_detached = discriminator_B(fake_B.detach())
        loss_real_B = criterion_GAN(pred_real_B, torch.ones_like(pred_real_B))
        loss_fake_B = criterion_GAN(pred_fake_B_detached, torch.zeros_like(pred_fake_B_detached))
        loss_D_B = (loss_real_B + loss_fake_B) / 2

        # Update Discriminators
        loss_D_A.backward()
        optimizer_D_A.step()
        loss_D_B.backward()
        optimizer_D_B.step()

        print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(dataloader)}] [D loss: {(loss_D_A + loss_D_B).item()}] [G loss: {loss_G.item()}]")


[Epoch 0/20] [Batch 0/125] [D loss: 1.435420274734497] [G loss: 15.446497917175293]
[Epoch 0/20] [Batch 1/125] [D loss: 5.483593463897705] [G loss: 14.646185874938965]
[Epoch 0/20] [Batch 2/125] [D loss: 2.009331703186035] [G loss: 12.646941184997559]
[Epoch 0/20] [Batch 3/125] [D loss: 1.7499877214431763] [G loss: 12.519965171813965]
[Epoch 0/20] [Batch 4/125] [D loss: 1.3720650672912598] [G loss: 12.281742095947266]
[Epoch 0/20] [Batch 5/125] [D loss: 1.1149532794952393] [G loss: 11.182249069213867]
[Epoch 0/20] [Batch 6/125] [D loss: 0.6254368424415588] [G loss: 9.932400703430176]
[Epoch 0/20] [Batch 7/125] [D loss: 0.6335865259170532] [G loss: 9.074888229370117]
[Epoch 0/20] [Batch 8/125] [D loss: 0.7213040590286255] [G loss: 9.52000904083252]
[Epoch 0/20] [Batch 9/125] [D loss: 0.7534567713737488] [G loss: 9.733882904052734]
[Epoch 0/20] [Batch 10/125] [D loss: 0.8575993180274963] [G loss: 9.211196899414062]
[Epoch 0/20] [Batch 11/125] [D loss: 1.023082971572876] [G loss: 9.143610

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
from torchvision.transforms.functional import gaussian_blur
from PIL import Image

# Dataset Definition
class ImageDataset(Dataset):
    """Unpaired two-domain image dataset for CycleGAN training.

    Each item is ``{"A": img_A, "B": img_B}`` with one image from each
    directory. The length is the size of the larger domain; the smaller one
    is cycled through via modulo indexing, so A and B images are paired only
    by index, not by content.

    Args:
        root_A: Directory of domain-A images; every regular file directly
            inside it is used, in sorted path order.
        root_B: Directory of domain-B images, listed the same way.
        transform: Optional callable applied to each RGB PIL image.

    Raises:
        ValueError: If either directory contains no files. Without this
            check ``__getitem__`` would fail with ZeroDivisionError
            (``idx % 0``) as soon as the other domain is non-empty.
    """

    def __init__(self, root_A, root_B, transform=None):
        self.files_A = self._list_files(root_A)
        self.files_B = self._list_files(root_B)
        if not self.files_A or not self.files_B:
            raise ValueError(
                "ImageDataset needs at least one file per domain; found "
                f"{len(self.files_A)} in {root_A!r} and {len(self.files_B)} in {root_B!r}"
            )
        self.transform = transform

    @staticmethod
    def _list_files(root):
        """Return sorted full paths of the regular files directly inside *root*."""
        paths = (os.path.join(root, name) for name in os.listdir(root))
        return sorted(path for path in paths if os.path.isfile(path))

    @staticmethod
    def _load_rgb(path):
        """Open *path* as an RGB image, releasing the file handle afterwards."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))

    def __getitem__(self, idx):
        img_A = self._load_rgb(self.files_A[idx % len(self.files_A)])
        img_B = self._load_rgb(self.files_B[idx % len(self.files_B)])

        if self.transform:
            img_A = self.transform(img_A)
            img_B = self.transform(img_B)

        return {"A": img_A, "B": img_B}

# Residual Block
class ResBlock(nn.Module):
    """Plain residual block: ``x + (conv3x3 -> IN -> ReLU -> conv3x3 -> IN)(x)``.

    Channel count and spatial size are preserved.
    """

    def __init__(self, channels):
        super(ResBlock, self).__init__()

        def conv3x3():
            return nn.Conv2d(channels, channels, kernel_size=3, padding=1)

        self.block = nn.Sequential(
            conv3x3(),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            conv3x3(),
            nn.InstanceNorm2d(channels),
        )

    def forward(self, x):
        skip = x
        return skip + self.block(x)

# Generator
class Generator(nn.Module):
    """ResNet-style generator held in one flat ``nn.Sequential`` (``self.model``).

    7x7 conv stem (64 ch) -> two stride-2 convs (128, 256 ch) -> 6 ResBlocks
    at 256 ch -> two stride-2 transposed convs back to 64 ch -> 7x7 conv +
    Tanh. Each conv/transposed conv except the last is followed by
    InstanceNorm + ReLU. Layer order defines the ``model.<index>`` state_dict
    keys, so it must not change.
    """

    def __init__(self, in_channels, out_channels):
        super(Generator, self).__init__()

        def norm_relu(channels):
            return [nn.InstanceNorm2d(channels), nn.ReLU(inplace=True)]

        layers = [nn.Conv2d(in_channels, 64, kernel_size=7, stride=1, padding=3), *norm_relu(64)]
        for c_in, c_out in ((64, 128), (128, 256)):
            layers += [nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1), *norm_relu(c_out)]
        layers += [ResBlock(256) for _ in range(6)]
        for c_in, c_out in ((256, 128), (128, 64)):
            layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2, padding=1, output_padding=1),
                *norm_relu(c_out),
            ]
        layers += [nn.Conv2d(64, out_channels, kernel_size=7, stride=1, padding=3), nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

# Discriminator
class Discriminator(nn.Module):
    """PatchGAN-style discriminator returning a map of real/fake scores.

    Four 4x4 convs (64, 128, 256, 512 channels; the first three stride 2, the
    last stride 1) with LeakyReLU(0.2), InstanceNorm on all but the first,
    then a 4x4 stride-1 conv to a single channel. A 256x256 input yields a
    30x30 score map. Layer order is fixed because it defines the state_dict.
    """

    def __init__(self, in_channels):
        super(Discriminator, self).__init__()
        layers = [
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for c_in, c_out, stride in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=1),
                nn.InstanceNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers.append(nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

# Brushstroke Loss
def brushstroke_loss(fake, real):
    """Mean absolute difference between the edge maps of *fake* and *real*.

    The edge map of a (N, C, H, W) image is |horizontal neighbour difference|
    + |vertical neighbour difference|, each zero-padded at the far edge so it
    keeps the input's shape.

    NOTE(review): the training loop passes unpaired fake/real images, so this
    is a pixel-aligned comparison with an unrelated target; confirm intent.
    """
    def edge_map(img):
        horiz = (img[:, :, :, :-1] - img[:, :, :, 1:]).abs()
        vert = (img[:, :, :-1, :] - img[:, :, 1:, :]).abs()
        return nn.functional.pad(horiz, (0, 1, 0, 0)) + nn.functional.pad(vert, (0, 0, 0, 1))

    return (edge_map(fake) - edge_map(real)).abs().mean()

# Ink Wash Loss
def ink_wash_loss(fake, real):
    """Mean squared error between Gaussian-blurred *fake* and *real* images.

    Both images are blurred with a 5x5 kernel (sigma=2) before comparison, so
    mainly low-frequency tone differences are penalized.

    NOTE(review): ``gaussian_blur`` is resolved at call time. A later cell
    defines its own module-level ``gaussian_blur``, which replaces the
    torchvision import once that cell has run in the same kernel.
    NOTE(review): the training loop passes unpaired fake/real images, so this
    is a pixel-aligned comparison with an unrelated target; confirm intent.
    """
    blurred_fake = gaussian_blur(fake, kernel_size=5, sigma=2)
    blurred_real = gaussian_blur(real, kernel_size=5, sigma=2)
    return (blurred_fake - blurred_real).pow(2).mean()

# Loss Functions
criterion_GAN = nn.MSELoss()   # least-squares adversarial loss (targets 1/0)
criterion_cycle = nn.L1Loss()  # cycle-consistency reconstruction loss

# Hyperparameters
lr = 1e-4
batch_size = 8
epochs = 50
# Relative weights of the generator loss terms
lambda_cycle = 15.0
lambda_brushstroke = 5.0
lambda_inkwash = 0.5

# Dataset and DataLoader: resize to 256x256 and scale pixels to [-1, 1]
# (the generators' Tanh output range).
transform = Compose([
    Resize((256, 256)),
    ToTensor(),
    Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
])

dataset = ImageDataset(
    root_A="/content/CIP_dataset1_process/trainA",
    root_B="/content/CIP_dataset1_process/trainB",
    transform=transform,
)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize Models (construction order A2B, B2A, D_A, D_B is preserved)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator_A2B, generator_B2A = [Generator(3, 3).to(device) for _ in range(2)]
discriminator_A, discriminator_B = [Discriminator(3).to(device) for _ in range(2)]

# Optimizers: one Adam over both generators, one Adam per discriminator.
optimizer_G = optim.Adam(
    [*generator_A2B.parameters(), *generator_B2A.parameters()],
    lr=lr,
    betas=(0.5, 0.999),
)
optimizer_D_A, optimizer_D_B = [
    optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))
    for disc in (discriminator_A, discriminator_B)
]

# Training Loop
for epoch in range(epochs):
    for batch in dataloader:
        real_A = batch["A"].to(device)
        real_B = batch["B"].to(device)

        # ---------------- Train Generators ----------------
        optimizer_G.zero_grad()
        fake_B = generator_A2B(real_A)
        fake_A = generator_B2A(real_B)

        # Least-squares adversarial terms: run each discriminator once on the
        # fakes (previously twice, the second call only to shape the target).
        pred_fake_B = discriminator_B(fake_B)
        pred_fake_A = discriminator_A(fake_A)
        loss_GAN_A2B = criterion_GAN(pred_fake_B, torch.ones_like(pred_fake_B))
        loss_GAN_B2A = criterion_GAN(pred_fake_A, torch.ones_like(pred_fake_A))
        loss_cycle = lambda_cycle * (
            criterion_cycle(generator_B2A(fake_B), real_A)
            + criterion_cycle(generator_A2B(fake_A), real_B)
        )
        # NOTE(review): real_A and real_B are unpaired, so these pixel-aligned
        # comparisons tie each fake to an unrelated image of the target domain
        # rather than to a style statistic; confirm this is intended.
        loss_brushstroke = lambda_brushstroke * (brushstroke_loss(fake_B, real_B) + brushstroke_loss(fake_A, real_A))
        loss_inkwash = lambda_inkwash * (ink_wash_loss(fake_B, real_B) + ink_wash_loss(fake_A, real_A))

        loss_G = loss_GAN_A2B + loss_GAN_B2A + loss_cycle + loss_brushstroke + loss_inkwash
        loss_G.backward()
        # Clip each generator's gradient norm to 1.0 for stability.
        torch.nn.utils.clip_grad_norm_(generator_A2B.parameters(), max_norm=1.0)
        torch.nn.utils.clip_grad_norm_(generator_B2A.parameters(), max_norm=1.0)
        optimizer_G.step()

        # ---------------- Train Discriminators ----------------
        # loss_G.backward() also left gradients in the discriminators; clear
        # them first. Fakes are detached so no gradient reaches the generators.
        optimizer_D_A.zero_grad()
        optimizer_D_B.zero_grad()

        pred_real_A = discriminator_A(real_A)
        pred_fake_A_detached = discriminator_A(fake_A.detach())
        loss_real_A = criterion_GAN(pred_real_A, torch.ones_like(pred_real_A))
        loss_fake_A = criterion_GAN(pred_fake_A_detached, torch.zeros_like(pred_fake_A_detached))
        loss_D_A = (loss_real_A + loss_fake_A) * 0.5

        pred_real_B = discriminator_B(real_B)
        pred_fake_B_detached = discriminator_B(fake_B.detach())
        loss_real_B = criterion_GAN(pred_real_B, torch.ones_like(pred_real_B))
        loss_fake_B = criterion_GAN(pred_fake_B_detached, torch.zeros_like(pred_fake_B_detached))
        loss_D_B = (loss_real_B + loss_fake_B) * 0.5

        loss_D_A.backward()
        optimizer_D_A.step()
        loss_D_B.backward()
        optimizer_D_B.step()

    # Reports the losses of the epoch's last batch, not an epoch average.
    print(f"Epoch [{epoch + 1}/{epochs}] | G Loss: {loss_G.item():.4f} | D_A Loss: {loss_D_A.item():.4f} | D_B Loss: {loss_D_B.item():.4f}")


Epoch [1/50] | G Loss: 8.0374 | D_A Loss: 0.2207 | D_B Loss: 0.1752
Epoch [2/50] | G Loss: 8.9772 | D_A Loss: 0.2269 | D_B Loss: 0.2014
Epoch [3/50] | G Loss: 6.8893 | D_A Loss: 0.1086 | D_B Loss: 0.2151
Epoch [4/50] | G Loss: 8.6775 | D_A Loss: 0.1874 | D_B Loss: 0.1777
Epoch [5/50] | G Loss: 9.2408 | D_A Loss: 0.1910 | D_B Loss: 0.2174
Epoch [6/50] | G Loss: 7.6039 | D_A Loss: 0.2553 | D_B Loss: 0.1597
Epoch [7/50] | G Loss: 7.7090 | D_A Loss: 0.2076 | D_B Loss: 0.2273
Epoch [8/50] | G Loss: 6.9157 | D_A Loss: 0.1721 | D_B Loss: 0.0894
Epoch [9/50] | G Loss: 7.1992 | D_A Loss: 0.2538 | D_B Loss: 0.0898
Epoch [10/50] | G Loss: 6.4527 | D_A Loss: 0.1283 | D_B Loss: 0.1351
Epoch [11/50] | G Loss: 7.0470 | D_A Loss: 0.1158 | D_B Loss: 0.1587
Epoch [12/50] | G Loss: 6.1959 | D_A Loss: 0.1297 | D_B Loss: 0.1968
Epoch [13/50] | G Loss: 6.3391 | D_A Loss: 0.1882 | D_B Loss: 0.1673
Epoch [14/50] | G Loss: 5.8589 | D_A Loss: 0.1613 | D_B Loss: 0.1722
Epoch [15/50] | G Loss: 6.8535 | D_A Loss: 

Next, we save the trained generators and test them on the test sets.

In [None]:
import os
import torch
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
from PIL import Image
from torchvision.utils import save_image

# Save the trained generators' weights (state_dict only) under ./models/.
model_save_path = "./models/"
os.makedirs(model_save_path, exist_ok=True)

for net, weights_file in (
    (generator_A2B, "generator_A2B.pth"),
    (generator_B2A, "generator_B2A.pth"),
):
    torch.save(net.state_dict(), os.path.join(model_save_path, weights_file))
print("模型保存成功！")  # "Models saved successfully!"

# Test-set directories (domain A and domain B) and the root folder for outputs.
test_root_A = "/content/CIP_dataset1_process/testA"
test_root_B = "/content/CIP_dataset1_process/testB"
output_path = "./generated_images/"
os.makedirs(output_path, exist_ok=True)

# Load test set
class TestImageDataset(Dataset):
    """Single-folder dataset yielding ``(image, file_name)`` pairs.

    Every regular file directly inside ``root`` is used, in sorted path
    order. The base file name is returned alongside the image so generated
    outputs can be saved under the same name.
    """

    def __init__(self, root, transform=None):
        candidates = (os.path.join(root, entry) for entry in os.listdir(root))
        self.files = sorted(path for path in candidates if os.path.isfile(path))
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = self.files[idx]
        image = Image.open(path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, os.path.basename(path)

# Test set conversion: same preprocessing as training (resize to 256x256,
# then scale pixel values to [-1, 1]).
transform = Compose([
    Resize((256, 256)),
    ToTensor(),
    Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# batch_size=1, shuffle=False: images are processed one at a time in sorted order.
test_dataset_A = TestImageDataset(test_root_A, transform=transform)
test_loader_A = DataLoader(test_dataset_A, batch_size=1, shuffle=False)

test_dataset_B = TestImageDataset(test_root_B, transform=transform)
test_loader_B = DataLoader(test_dataset_B, batch_size=1, shuffle=False)

# Load the trained model.
# NOTE(review): `Generator` must be the same class the checkpoints were saved
# from (the most recently defined Generator in the running kernel).
generator_A2B = Generator(3, 3).to(device)
generator_B2A = Generator(3, 3).to(device)

# weights_only=True restricts unpickling to tensors and plain containers (all
# a state_dict needs) and avoids the torch.load FutureWarning shown in the
# cell output. map_location lets a checkpoint written on GPU load on a
# CPU-only runtime.
generator_A2B.load_state_dict(
    torch.load(os.path.join(model_save_path, "generator_A2B.pth"), map_location=device, weights_only=True)
)
generator_B2A.load_state_dict(
    torch.load(os.path.join(model_save_path, "generator_B2A.pth"), map_location=device, weights_only=True)
)

generator_A2B.eval()
generator_B2A.eval()
print("模型加载成功！")  # "Models loaded successfully!"

# The test set generates images
def generate_images(loader, generator, output_subfolder):
    """Translate every image in *loader* with *generator* and save the results.

    Outputs are written to ``os.path.join(output_path, output_subfolder)``
    (created if needed), one file per input image, named after that input.

    Args:
        loader: DataLoader yielding ``(images, file_names)`` batches, as
            produced by TestImageDataset; images normalized to [-1, 1].
        generator: Trained generator. The caller is responsible for putting
            it in eval mode; this function only disables autograd.
        output_subfolder: Name of the sub-directory under ``output_path``.
    """
    output_dir = os.path.join(output_path, output_subfolder)
    os.makedirs(output_dir, exist_ok=True)

    with torch.no_grad():
        for images, filenames in loader:
            # Undo Normalize(0.5, 0.5): map the Tanh output [-1, 1] back to [0, 1].
            fakes = (generator(images.to(device)) * 0.5 + 0.5).clamp(0, 1)
            # Save each sample separately. Previously only filenames[0] was
            # used, so a loader with batch_size > 1 would write the whole
            # batch as a single grid image under the first file's name.
            for fake, filename in zip(fakes, filenames):
                save_image(fake, os.path.join(output_dir, filename))

# Translate both test sets: A -> B with generator_A2B, B -> A with generator_B2A.
generate_images(loader=test_loader_A, generator=generator_A2B, output_subfolder="A_to_B")
generate_images(loader=test_loader_B, generator=generator_B2A, output_subfolder="B_to_A")

print("生成图像已保存到 ./generated_images/")  # "Generated images saved to ./generated_images/"



模型保存成功！
模型加载成功！
生成图像已保存到 ./generated_images/


  generator_A2B.load_state_dict(torch.load(os.path.join(model_save_path, "generator_A2B.pth")))
  generator_B2A.load_state_dict(torch.load(os.path.join(model_save_path, "generator_B2A.pth")))


In [None]:
# Archive the generated images and the saved generator weights for download.
# zip appends ".zip" to an archive name without an extension, so this writes
# /content/generated_images.zip and /content/models.zip.
# NOTE(review): output_path and model_save_path are relative ("./..."), so this
# assumes the working directory is /content (the Colab default) — verify.
!zip -r /content/generated_images /content/generated_images
!zip -r /content/models /content/models

  adding: content/generated_images/ (stored 0%)
  adding: content/generated_images/A_to_B/ (stored 0%)
  adding: content/generated_images/A_to_B/cliff_3537491673_a2939e5d44.jpg (deflated 3%)
  adding: content/generated_images/A_to_B/湖亭.jpg (deflated 1%)
  adding: content/generated_images/A_to_B/hengshan4.jpg (deflated 1%)
  adding: content/generated_images/A_to_B/hengshan1.jpg (deflated 1%)
  adding: content/generated_images/A_to_B/IMG_3909.jpg (deflated 1%)
  adding: content/generated_images/A_to_B/cliff_2608866102_1de76908a4.jpg (deflated 2%)
  adding: content/generated_images/A_to_B/cliff_3745745866_1c95a6a923.jpg (deflated 2%)
  adding: content/generated_images/B_to_A/ (stored 0%)
  adding: content/models/ (stored 0%)
  adding: content/models/generator_B2A.pth (deflated 8%)
  adding: content/models/generator_A2B.pth (deflated 8%)


In [None]:
# Delete the generated-images directory from the workspace; the
# /content/generated_images.zip archive is a different path and is kept.
!rm -rf /content/generated_images

## Previous attempt 1: CycleGAN + VGG + self-attention + brush loss

We tried this method, combining CycleGAN with a VGG perceptual loss, self-attention, and a brushstroke loss.

However, the results were not good.

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.models import vgg16, VGG16_Weights
from PIL import Image, ImageOps  # Fix: Import the ImageOps module

# Set device: use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Dataset class with dynamic resizing
class ImageDataset(Dataset):
    """Unpaired two-domain dataset that letterboxes images to a fixed size.

    Each item is ``{"A": img_A, "B": img_B}``. The length is the larger
    domain's size; the smaller domain cycles via modulo indexing. Before
    ``transform`` runs, each image is resized to fit inside ``target_size``
    with its aspect ratio preserved, then centered on a black canvas of
    exactly that size.

    Args:
        root_A: Directory of domain-A images; every regular file directly
            inside it is used, in sorted path order.
        root_B: Directory of domain-B images, listed the same way.
        transform: Optional callable applied to each padded PIL image.
        target_size: Output size as ``(width, height)``, PIL's order.

    Raises:
        ValueError: If either directory contains no files (``__getitem__``
            would otherwise hit a modulo by zero).
    """

    def __init__(self, root_A, root_B, transform=None, target_size=(256, 256)):
        self.files_A = self._list_files(root_A)
        self.files_B = self._list_files(root_B)
        if not self.files_A or not self.files_B:
            raise ValueError(
                "ImageDataset needs at least one file per domain; found "
                f"{len(self.files_A)} in {root_A!r} and {len(self.files_B)} in {root_B!r}"
            )
        self.transform = transform
        self.target_size = target_size

    @staticmethod
    def _list_files(root):
        """Return sorted full paths of the regular files directly inside *root*."""
        paths = (os.path.join(root, name) for name in os.listdir(root))
        return sorted(path for path in paths if os.path.isfile(path))

    @staticmethod
    def _load_rgb(path):
        """Open *path* as an RGB image, releasing the file handle afterwards."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))

    def __getitem__(self, idx):
        img_A = self._load_rgb(self.files_A[idx % len(self.files_A)])
        img_B = self._load_rgb(self.files_B[idx % len(self.files_B)])

        # Resize and pad to target size
        img_A = self.resize_and_pad(img_A)
        img_B = self.resize_and_pad(img_B)

        if self.transform:
            img_A = self.transform(img_A)
            img_B = self.transform(img_B)

        return {"A": img_A, "B": img_B}

    def resize_and_pad(self, img):
        """Resize *img* to fit inside ``target_size`` keeping its aspect ratio,
        then pad it with centered black borders to exactly ``target_size``."""
        target_w, target_h = self.target_size
        src_w, src_h = img.size
        # Fit whichever side is the binding constraint. The previous
        # "aspect > 1" test only worked for square targets; for non-square
        # ones it could overshoot the target and produce negative padding.
        # Integer arithmetic avoids float truncation surprises, and max(1, ...)
        # keeps extreme aspect ratios from collapsing a side to 0 pixels.
        if src_w * target_h >= src_h * target_w:
            new_w, new_h = target_w, max(1, src_h * target_w // src_w)
        else:
            new_w, new_h = max(1, src_w * target_h // src_h), target_h

        img = img.resize((new_w, new_h), Image.BICUBIC)
        pad_left = (target_w - new_w) // 2
        pad_top = (target_h - new_h) // 2
        padding = (pad_left, pad_top, target_w - new_w - pad_left, target_h - new_h - pad_top)
        return ImageOps.expand(img, border=padding, fill=0)

# Data augmentation, applied after resize_and_pad: random horizontal flip and a
# random rotation in [-10, 10] degrees, then scale pixel values to [-1, 1].
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=10),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
])

# Load data set (paths assume the dataset zip was extracted into /content)
dataset = ImageDataset(
    "/content/CIP_dataset1_process/trainA",
    "/content/CIP_dataset1_process/trainB",
    transform=transform,
    target_size=(256, 256),
)
dataloader = DataLoader(dataset, shuffle=True, batch_size=8)

# Residual Block with channel attention (squeeze-and-excitation)
class ResidualBlock(nn.Module):
    """Residual block whose residual branch is rescaled by channel attention.

    The residual branch is conv3x3-IN-ReLU-conv3x3-IN. A squeeze-and-excitation
    gate (global average pool -> 1x1 reduce -> ReLU -> 1x1 expand -> Sigmoid)
    produces one weight per channel that multiplies the residual before it is
    added back to the input. Spatial size and channel count are preserved.

    Args:
        channels: Number of input/output channels.
        reduction: Bottleneck ratio of the attention gate (default 16, the
            previously hard-coded value). The hidden width is clamped to at
            least 1 so blocks with fewer than ``reduction`` channels still get
            a working gate instead of a zero-width layer.
    """

    def __init__(self, channels, reduction=16):
        super(ResidualBlock, self).__init__()
        hidden = max(1, channels // reduction)
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(channels)
        )
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, hidden, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        residual = self.block(x)
        attention = self.se(residual)  # shape (N, C, 1, 1), broadcast over H, W
        return x + residual * attention

# Generator
class Generator(nn.Module):
    """ResNet-style image-to-image generator.

    Pipeline: 7x7 conv stem (64 ch) -> two stride-2 downsampling convs
    (128, 256 ch) -> ``num_residual_blocks`` ResidualBlocks at 256 ch -> two
    stride-2 transposed convs back to 64 ch -> 7x7 conv + Tanh. For inputs
    whose height and width are divisible by 4, the output has the input's
    spatial size and values in (-1, 1).

    The stage attribute names (initial/down/res_blocks/up/output) and their
    layer order determine the state_dict keys, so they are kept as-is.
    """

    @staticmethod
    def _norm_relu(channels):
        """InstanceNorm + ReLU pair that follows every conv in this network."""
        return [nn.InstanceNorm2d(channels), nn.ReLU(inplace=True)]

    def __init__(self, in_channels, out_channels, num_residual_blocks=9):
        super(Generator, self).__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=7, stride=1, padding=3),
            *self._norm_relu(64),
        )
        self.down = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            *self._norm_relu(128),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            *self._norm_relu(256),
        )
        self.res_blocks = nn.Sequential(
            *(ResidualBlock(256) for _ in range(num_residual_blocks))
        )
        self.up = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            *self._norm_relu(128),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            *self._norm_relu(64),
        )
        self.output = nn.Sequential(
            nn.Conv2d(64, out_channels, kernel_size=7, stride=1, padding=3),
            nn.Tanh(),
        )

    def forward(self, x):
        for stage in (self.initial, self.down, self.res_blocks, self.up, self.output):
            x = stage(x)
        return x

# Discriminator
class Discriminator(nn.Module):
    """PatchGAN-style discriminator returning a map of real/fake scores.

    Four 4x4 convs (64, 128, 256, 512 channels; the first three stride 2, the
    last stride 1) with LeakyReLU(0.2), InstanceNorm on all but the first,
    then a 4x4 stride-1 conv to a single channel. A 256x256 input yields a
    30x30 score map. Layer order is fixed because it defines the state_dict.
    """

    def __init__(self, in_channels):
        super(Discriminator, self).__init__()
        layers = [
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for c_in, c_out, stride in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=1),
                nn.InstanceNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers.append(nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


# Initialize models: generators A->B and B->A, plus one discriminator per
# domain (construction order A2B, B2A, D_A, D_B is preserved).
generator_A2B, generator_B2A = (Generator(3, 3).to(device) for _ in range(2))
discriminator_A, discriminator_B = (Discriminator(3).to(device) for _ in range(2))

# Loss functions: least-squares adversarial loss and L1 cycle-consistency loss.
criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()

# VGG Perceptual Loss
class VGGPerceptualLoss(nn.Module):
    """L1 distance between frozen VGG16 feature maps (``features[:16]``, up to relu3_3).

    Inputs are expected in the [-1, 1] range produced by the dataset
    transform and the generators' Tanh. They are mapped to [0, 1] and
    standardized with the ImageNet mean/std the pretrained VGG16 weights
    expect. The original fed [-1, 1] tensors to VGG unnormalized.
    """

    def __init__(self):
        super(VGGPerceptualLoss, self).__init__()
        vgg = vgg16(weights=VGG16_Weights.DEFAULT).features[:16].eval()
        for param in vgg.parameters():
            param.requires_grad = False
        self.vgg = vgg
        self.criterion = nn.L1Loss()
        # Registered as (non-persistent) buffers so .to(device) moves them
        # with the module while the state_dict stays unchanged.
        self.register_buffer(
            "mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False
        )
        self.register_buffer(
            "std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False
        )

    def _preprocess(self, x):
        """Map a [-1, 1] image batch to ImageNet-normalized VGG input."""
        return ((x + 1) / 2 - self.mean) / self.std

    def forward(self, input, target):
        input_features = self.vgg(self._preprocess(input))
        target_features = self.vgg(self._preprocess(target))
        return self.criterion(input_features, target_features)

# Additional Loss Functions
def compute_gradient(image):
    """
    Return an edge-strength map |dI/dx| + |dI/dy| with the same shape as *image*.

    *image* (N, C, H, W) is expected in [-1, 1] and is rescaled to [0, 1]
    first. Absolute neighbour differences along the width and height axes
    are zero-padded at the far edge (last column / last row) so both maps
    keep the input's size before they are summed.
    """
    scaled = (image + 1) / 2
    d_w = (scaled[:, :, :, :-1] - scaled[:, :, :, 1:]).abs()
    d_h = (scaled[:, :, :-1, :] - scaled[:, :, 1:, :]).abs()
    pad = nn.functional.pad
    return pad(d_w, (0, 1, 0, 0)) + pad(d_h, (0, 0, 0, 1))


def compute_histogram(image, bins=256):
    """Return the normalized value histogram of *image* (all elements pooled).

    *image* is rescaled from [-1, 1] to [0, 1] and counted into ``bins``
    equal-width bins over [0, 1]; counts are divided by their total so the
    result sums to 1. Note: ``torch.histc`` is not differentiable, so this
    histogram carries no gradient back to *image*.
    """
    counts = torch.histc((image + 1) / 2, bins=bins, min=0.0, max=1.0)
    return counts / counts.sum()

def brushstroke_loss(fake, real):
    """Mean absolute difference between the edge maps (compute_gradient) of *fake* and *real*.

    NOTE(review): if the caller passes unpaired images, this is a
    pixel-aligned comparison with an unrelated target; confirm intent.
    """
    return (compute_gradient(fake) - compute_gradient(real)).abs().mean()

def inkwash_loss(fake, real):
    hist_fake = compute_histogram(fake)
    hist_real = compute_histogram(real)
    return torch.mean((hist_fake - hist_real) ** 2)

def gaussian_blur(image, kernel_size=5, sigma=2):
    """
    Apply a Gaussian blur to every channel of ``image`` to simulate ink diffusion.

    Args:
        image: batched image tensor (assumed (N, C, H, W)); any channel count.
        kernel_size: odd kernel size, either an int (square kernel) or an
            (x, y) pair as accepted by torchvision.
        sigma: Gaussian standard deviation, a number or an (x, y) pair.

    Returns:
        Blurred tensor with exactly the same shape as ``image``.
    """
    from torchvision.transforms.functional import gaussian_blur as blur

    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    # torchvision blurs each channel independently, so the whole tensor is
    # handled in one call. The old approach (blur channels 0-2 separately,
    # then ``torch.stack`` on dim 1) added a spurious axis, giving
    # (N, 3, 1, H, W), and ignored any additional channels.
    return blur(image, list(kernel_size), sigma)



# Optimizers: one Adam over both generators, one Adam per discriminator.
optimizer_G = optim.Adam(
    [*generator_A2B.parameters(), *generator_B2A.parameters()],
    lr=0.0002,
    betas=(0.5, 0.999),
)
optimizer_D_A, optimizer_D_B = (
    optim.Adam(net.parameters(), lr=0.0001, betas=(0.5, 0.999))
    for net in (discriminator_A, discriminator_B)
)

# Schedulers: halve each learning rate every 10 scheduler steps
# (the training loop steps them once per epoch).
scheduler_G, scheduler_D_A, scheduler_D_B = (
    optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.5)
    for opt in (optimizer_G, optimizer_D_A, optimizer_D_B)
)

# Training hyperparameters
epochs = 50               # total number of epochs
lambda_cycle = 10.0       # weight for cycle-consistency loss
lambda_vgg = 0.1          # weight for perceptual (VGG) loss
lambda_brushstroke = 0.5  # weight for brushstroke (gradient) loss
lambda_inkwash = 0.5      # weight for ink-wash (histogram) loss
lambda_diffusion = 0.2    # weight for diffusion (blurred L1) loss

for epoch in range(epochs):
    for i, batch in enumerate(dataloader):
        # Load real images
        real_A = batch["A"].to(device)
        real_B = batch["B"].to(device)

        # ======================
        # Train Generators
        # ======================
        optimizer_G.zero_grad()

        # Generate fake images
        fake_B = generator_A2B(real_A)
        fake_A = generator_B2A(real_B)

        # GAN Loss (least squares): each generator wants its fakes scored as
        # real (target 1). Each discriminator runs once, and that output also
        # supplies the target shape. Previously a second, identical forward
        # pass was made just to build ``ones_like``.
        pred_fake_B = discriminator_B(fake_B)
        pred_fake_A = discriminator_A(fake_A)
        loss_GAN_A2B = criterion_GAN(pred_fake_B, torch.ones_like(pred_fake_B))
        loss_GAN_B2A = criterion_GAN(pred_fake_A, torch.ones_like(pred_fake_A))

        # Cycle Consistency Loss
        recov_A = generator_B2A(fake_B)
        recov_B = generator_A2B(fake_A)
        loss_cycle_A = criterion_cycle(recov_A, real_A)
        loss_cycle_B = criterion_cycle(recov_B, real_B)

        # NOTE(review): the VGG, brushstroke and diffusion terms below compare
        # a translated image with a *different* real image of the target domain
        # position by position (e.g. fake_B, made from real_A, vs. real_B).
        # trainA/trainB are read from separate folders and appear to be
        # unpaired, so these are not corresponding images. Confirm this
        # pairing is intended; a content loss would normally use real_A.

        # VGG Perceptual Loss
        loss_vgg_A2B = vgg_loss_fn(fake_B, real_B)
        loss_vgg_B2A = vgg_loss_fn(fake_A, real_A)

        # Brushstroke Loss
        loss_brushstroke_A2B = brushstroke_loss(fake_B, real_B)
        loss_brushstroke_B2A = brushstroke_loss(fake_A, real_A)

        # Ink Wash Loss
        loss_inkwash_A2B = inkwash_loss(fake_B, real_B)
        loss_inkwash_B2A = inkwash_loss(fake_A, real_A)

        # Diffusion Effect Loss (blurred outputs)
        blurred_fake_B = gaussian_blur(fake_B)
        blurred_real_B = gaussian_blur(real_B)
        loss_diffusion_A2B = criterion_cycle(blurred_fake_B, blurred_real_B)

        # Total Generator Loss
        loss_G = (
            loss_GAN_A2B + loss_GAN_B2A
            + lambda_cycle * (loss_cycle_A + loss_cycle_B)
            + lambda_vgg * (loss_vgg_A2B + loss_vgg_B2A)
            + lambda_brushstroke * (loss_brushstroke_A2B + loss_brushstroke_B2A)
            + lambda_inkwash * (loss_inkwash_A2B + loss_inkwash_B2A)
            + lambda_diffusion * loss_diffusion_A2B
        )

        # Backpropagation for generators
        loss_G.backward()
        optimizer_G.step()

        # ======================
        # Train Discriminators
        # ======================
        optimizer_D_A.zero_grad()
        optimizer_D_B.zero_grad()

        # Discriminator A loss: real -> 1, fake -> 0. Fakes are detached so
        # no gradient reaches the generators. The targets are shaped from
        # these same outputs, so no extra forward pass is made on the
        # non-detached fakes.
        d_real_A = discriminator_A(real_A)
        d_fake_A = discriminator_A(fake_A.detach())
        loss_real_A = criterion_GAN(d_real_A, torch.ones_like(d_real_A))
        loss_fake_A = criterion_GAN(d_fake_A, torch.zeros_like(d_fake_A))
        loss_D_A = (loss_real_A + loss_fake_A) / 2

        # Discriminator B loss
        d_real_B = discriminator_B(real_B)
        d_fake_B = discriminator_B(fake_B.detach())
        loss_real_B = criterion_GAN(d_real_B, torch.ones_like(d_real_B))
        loss_fake_B = criterion_GAN(d_fake_B, torch.zeros_like(d_fake_B))
        loss_D_B = (loss_real_B + loss_fake_B) / 2

        # Backpropagation for discriminators
        loss_D_A.backward()
        optimizer_D_A.step()
        loss_D_B.backward()
        optimizer_D_B.step()

        # ======================
        # Logging
        # ======================
        print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(dataloader)}] "
              f"[D loss: {(loss_D_A + loss_D_B).item()}] [G loss: {loss_G.item()}]")

    # ======================
    # Save Model Checkpoints
    # ======================
    if (epoch + 1) % 10 == 0:  # Save model every 10 epochs
        torch.save(generator_A2B.state_dict(), f"generator_A2B_epoch_{epoch + 1}.pth")
        torch.save(generator_B2A.state_dict(), f"generator_B2A_epoch_{epoch + 1}.pth")
        torch.save(discriminator_A.state_dict(), f"discriminator_A_epoch_{epoch + 1}.pth")
        torch.save(discriminator_B.state_dict(), f"discriminator_B_epoch_{epoch + 1}.pth")

    # ======================
    # Update Learning Rate
    # ======================
    scheduler_G.step()
    scheduler_D_A.step()
    scheduler_D_B.step()


[Output truncated: only the last 5000 lines of the streaming output are shown.]
[Epoch 10/50] [Batch 0/125] [D loss: 0.37099453806877136] [G loss: 4.141383647918701]
[Epoch 10/50] [Batch 1/125] [D loss: 0.30168700218200684] [G loss: 3.3745131492614746]
[Epoch 10/50] [Batch 2/125] [D loss: 0.4045707583427429] [G loss: 3.6180644035339355]
[Epoch 10/50] [Batch 3/125] [D loss: 0.5047250986099243] [G loss: 3.3772401809692383]
[Epoch 10/50] [Batch 4/125] [D loss: 0.3110966384410858] [G loss: 3.6385719776153564]
[Epoch 10/50] [Batch 5/125] [D loss: 0.3707459568977356] [G loss: 3.1487364768981934]
[Epoch 10/50] [Batch 6/125] [D loss: 0.3306986689567566] [G loss: 3.4431471824645996]
[Epoch 10/50] [Batch 7/125] [D loss: 0.42097383737564087] [G loss: 3.300662040710449]
[Epoch 10/50] [Batch 8/125] [D loss: 0.37541526556015015] [G loss: 3.24800968170166]
[Epoch 10/50] [Batch 9/125] [D loss: 0.4334828853607178] [G loss: 3.063694953918457]
[Epoch 10/50] [Batch 10/125] [D loss: 0.34918758273124695] [G loss: 3.3563663959503174]
[Epoch 10/50

Test the final model

In [None]:
import os
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image, ImageOps  # Used to load and preprocess images

# Set device: prefer CUDA when it is available, otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# TestImageDataset define
class TestImageDataset(Dataset):
    """Single-folder test dataset yielding ``(image, file_path)`` pairs.

    Every regular file directly inside ``root`` is used, sorted by path. Each
    image is opened as RGB and letterboxed to ``target_size`` with black
    padding. ``target_size`` is interpreted as (width, height), PIL's size
    order. The result is then passed through ``transform`` if one is given.
    """

    def __init__(self, root, transform=None, target_size=(256, 256)):
        self.files = sorted([os.path.join(root, f) for f in os.listdir(root) if os.path.isfile(os.path.join(root, f))])
        self.transform = transform
        self.target_size = target_size

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        file = self.files[idx]
        # Open via a context manager so the file handle is closed right away;
        # convert() returns an independent in-memory image.
        with Image.open(file) as src:
            img = src.convert("RGB")
        img = self.resize_and_pad(img)  # Resize and fill
        if self.transform:
            img = self.transform(img)
        return img, file

    def resize_and_pad(self, img):
        """Resize ``img`` to fit inside ``target_size`` keeping its aspect ratio,
        then pad it symmetrically with black to exactly ``target_size``.

        The limiting side is chosen by comparing the image's aspect ratio with
        the target's. This keeps non-square targets from being overflowed
        (which produced negative padding). Each resized side is at least
        1 pixel, so extreme aspect ratios still produce a valid resize.
        """
        target_w, target_h = self.target_size
        aspect = img.width / img.height
        if aspect > target_w / target_h:  # relatively wider than the target: width limits
            new_width = target_w
            new_height = max(1, int(target_w / aspect))
        else:  # relatively taller (or equal): height limits
            new_height = target_h
            new_width = max(1, int(target_h * aspect))

        img = img.resize((new_width, new_height), Image.BICUBIC)
        pad_width = (target_w - new_width) // 2
        pad_height = (target_h - new_height) // 2
        # ImageOps.expand border order: (left, top, right, bottom).
        padding = (pad_width, pad_height, target_w - new_width - pad_width, target_h - new_height - pad_height)
        return ImageOps.expand(img, border=padding, fill=0)

# Data pre-processing: PIL image -> tensor in [0, 1] -> normalized to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize to [-1, 1]
    ]
)

# Test data paths and output folder
testA_dir = "/content/CIP_dataset1_process/testA"
testB_dir = "/content/CIP_dataset1_process/testB"
output_dir = "result_improved"
os.makedirs(output_dir, exist_ok=True)

# Test datasets (letterboxed to 256x256) and one-image-at-a-time loaders in file order.
testA_dataset, testB_dataset = (
    TestImageDataset(folder, transform=transform, target_size=(256, 256))
    for folder in (testA_dir, testB_dir)
)
testA_loader, testB_loader = (
    DataLoader(ds, batch_size=1, shuffle=False)
    for ds in (testA_dataset, testB_dataset)
)

# Generator model definition
class Generator(nn.Module):
    """ResNet-style image-to-image generator.

    Pipeline:
    - 7x7 stem conv to 64 channels;
    - two stride-2 downsampling convs (64 -> 128 -> 256);
    - ``num_residual_blocks`` x ``ResidualBlock(256)``;
    - two stride-2 transposed convs (256 -> 128 -> 64);
    - 7x7 conv to ``out_channels`` followed by ``tanh``.

    Submodule names (``initial``, ``down``, ``res_blocks``, ``up``,
    ``output``) and the layer order inside each define the ``state_dict``
    keys of saved checkpoints, so they are kept fixed.
    """

    def __init__(self, in_channels, out_channels, num_residual_blocks=9):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, 64, 7, 1, 3),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
        )

        down_layers = []
        for c_in, c_out in ((64, 128), (128, 256)):
            down_layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.down = nn.Sequential(*down_layers)

        self.res_blocks = nn.Sequential(*(ResidualBlock(256) for _ in range(num_residual_blocks)))

        up_layers = []
        for c_in, c_out in ((256, 128), (128, 64)):
            up_layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.up = nn.Sequential(*up_layers)

        self.output = nn.Sequential(
            nn.Conv2d(64, out_channels, kernel_size=7, stride=1, padding=3),
            nn.Tanh(),
        )

    def forward(self, x):
        for stage in (self.initial, self.down, self.res_blocks, self.up, self.output):
            x = stage(x)
        return x

# Residual block with squeeze-and-excitation (SE) channel attention
class ResidualBlock(nn.Module):
    """Residual block whose residual branch is re-weighted per channel.

    ``block`` is two 3x3 conv + InstanceNorm layers with a ReLU in between.
    ``se`` works as follows:

    - it squeezes the branch output to one value per channel (global average
      pooling);
    - it passes that vector through a 1x1-conv bottleneck with sigmoid gating;
    - the branch is scaled by the resulting per-channel weights before the
      skip connection is added.

    This is channel attention (squeeze-and-excitation), not spatial
    self-attention.

    Args:
        channels: number of input/output channels.
        reduction: bottleneck ratio of the SE branch. The bottleneck width is
            ``max(1, channels // reduction)``, so blocks with fewer than
            ``reduction`` channels do not get a zero-width bottleneck. For
            ``channels >= reduction`` the layer shapes are unchanged, so
            existing checkpoints still load.
    """

    def __init__(self, channels, reduction=16):
        super(ResidualBlock, self).__init__()
        hidden = max(1, channels // reduction)
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(channels)
        )
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, hidden, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        residual = self.block(x)
        attention = self.se(residual)  # (N, C, 1, 1), broadcast over H and W
        return x + residual * attention

# Load the pre-trained generators.
# map_location loads the tensors onto this runtime's device; without it,
# checkpoints saved from CUDA fail to load on a CPU-only machine.
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict needs.
generator_A2B = Generator(3, 3, num_residual_blocks=9).to(device)
generator_B2A = Generator(3, 3, num_residual_blocks=9).to(device)
generator_A2B.load_state_dict(
    torch.load("/content/generator_A2B_epoch_50.pth", map_location=device, weights_only=True)
)
generator_B2A.load_state_dict(
    torch.load("/content/generator_B2A_epoch_50.pth", map_location=device, weights_only=True)
)
generator_A2B.eval()
generator_B2A.eval()

# test: outputs are mapped from [-1, 1] back to [0, 1] and saved under
# output_dir as "<direction>_<original file name>".
with torch.no_grad():
    # TestA -> FakeB
    for img, file_path in testA_loader:
        img = img.to(device)
        fake_B = generator_A2B(img)
        output_path = os.path.join(output_dir, f"A2B_{os.path.basename(file_path[0])}")
        save_image((fake_B + 1) / 2, output_path)

    # TestB -> FakeA
    for img, file_path in testB_loader:
        img = img.to(device)
        fake_A = generator_B2A(img)
        output_path = os.path.join(output_dir, f"B2A_{os.path.basename(file_path[0])}")
        save_image((fake_A + 1) / 2, output_path)


  generator_A2B.load_state_dict(torch.load("/content/generator_A2B_epoch_50.pth"))
  generator_B2A.load_state_dict(torch.load("/content/generator_B2A_epoch_50.pth"))


In [None]:
!zip -r /content/result_improved.zip /content/result_improved

  adding: content/result_improved/ (stored 0%)
  adding: content/result_improved/A2B_31.jpg (deflated 6%)
  adding: content/result_improved/A2B_huangshan2.jpg (deflated 3%)
  adding: content/result_improved/A2B_plantB_0_120.jpg (deflated 2%)
  adding: content/result_improved/A2B_guilin18.jpg (deflated 1%)
  adding: content/result_improved/A2B_湖亭.jpg (deflated 3%)
  adding: content/result_improved/A2B_changjiang3.jpg (deflated 1%)
  adding: content/result_improved/A2B_guilin26.jpg (deflated 3%)
  adding: content/result_improved/A2B_lake.png (deflated 0%)
  adding: content/result_improved/A2B_5.jpg (deflated 2%)
  adding: content/result_improved/A2B_huashan7.jpg (deflated 6%)
  adding: content/result_improved/B2A_mountain.png (deflated 0%)
  adding: content/result_improved/A2B_huangshan3.jpg (deflated 1%)
  adding: content/result_improved/A2B_changjiang11.jpg (deflated 4%)
  adding: content/result_improved/A2B_57.jpg (deflated 3%)
  adding: content/result_improved/A2B_guilin15.jpg (defla

In [None]:
!rm -rf /content/result_improved

## Previous attempt 2: CycleGAN + VGG + self-attention

This is the first model we trained. Its output images had low pixel quality, so we decided not to use it.

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from torchvision.utils import save_image
from torchvision.models import vgg16, VGG16_Weights
from PIL import Image

# Set device: prefer CUDA when it is available, otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Two-domain dataset (any augmentation comes from the ``transform`` argument)
class ImageDataset(Dataset):
    """Dataset yielding ``{"A": image_from_root_A, "B": image_from_root_B}``.

    Both folders are listed (regular files only) and sorted. The dataset
    length is the size of the larger folder. Index ``idx`` maps to
    ``files_A[idx % len(files_A)]`` and ``files_B[idx % len(files_B)]``, so
    the smaller folder is cycled. Images are converted to RGB and, if given,
    ``transform`` is applied to each image separately.

    Raises:
        ValueError: if either folder contains no files. Indexing would
            otherwise fail later with ``ZeroDivisionError``.
    """

    def __init__(self, root_A, root_B, transform=None):
        self.files_A = self._list_files(root_A)
        self.files_B = self._list_files(root_B)
        if not self.files_A or not self.files_B:
            raise ValueError(
                f"ImageDataset needs at least one file in each folder "
                f"({root_A!r}: {len(self.files_A)} files, {root_B!r}: {len(self.files_B)} files)"
            )
        self.transform = transform

    @staticmethod
    def _list_files(root):
        """Return the sorted full paths of the regular files directly inside ``root``."""
        return sorted([os.path.join(root, f) for f in os.listdir(root) if os.path.isfile(os.path.join(root, f))])

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` as an RGB PIL image, closing the file handle afterwards."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))

    def __getitem__(self, idx):
        img_A = self._load_rgb(self.files_A[idx % len(self.files_A)])
        img_B = self._load_rgb(self.files_B[idx % len(self.files_B)])

        if self.transform:
            img_A = self.transform(img_A)
            img_B = self.transform(img_B)

        return {"A": img_A, "B": img_B}

# Training transforms:
# - resize to 256x256;
# - random horizontal flip and a random rotation of up to +/-10 degrees;
# - convert to a tensor normalized to [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Training data: domain A (trainA) and domain B (trainB) folders, shuffled
# batches of 8.
dataset = ImageDataset(
    transform=transform,
    root_A="/content/CIP_dataset1_process/trainA",
    root_B="/content/CIP_dataset1_process/trainB",
)
dataloader = DataLoader(dataset, shuffle=True, batch_size=8)

# Residual block with squeeze-and-excitation (SE) channel attention
class ResidualBlock(nn.Module):
    """Residual block whose residual branch is re-weighted per channel.

    ``block`` is two 3x3 conv + InstanceNorm layers with a ReLU in between.
    ``se`` works as follows:

    - it squeezes the branch output to one value per channel (global average
      pooling);
    - it passes that vector through a 1x1-conv bottleneck with sigmoid gating;
    - the branch is scaled by the resulting per-channel weights before the
      skip connection is added.

    This is channel attention (squeeze-and-excitation), not spatial
    self-attention.

    Args:
        channels: number of input/output channels.
        reduction: bottleneck ratio of the SE branch. The bottleneck width is
            ``max(1, channels // reduction)``, so blocks with fewer than
            ``reduction`` channels do not get a zero-width bottleneck. For
            ``channels >= reduction`` the layer shapes are unchanged, so
            existing checkpoints still load.
    """

    def __init__(self, channels, reduction=16):
        super(ResidualBlock, self).__init__()
        hidden = max(1, channels // reduction)
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(channels)
        )
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, hidden, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        residual = self.block(x)
        attention = self.se(residual)  # (N, C, 1, 1), broadcast over H and W
        return x + residual * attention

# Generator
class Generator(nn.Module):
    """ResNet-style image-to-image generator (15 residual blocks by default).

    Pipeline:
    - 7x7 stem conv to 64 channels;
    - two stride-2 downsampling convs (64 -> 128 -> 256);
    - ``num_residual_blocks`` x ``ResidualBlock(256)``;
    - two stride-2 transposed convs (256 -> 128 -> 64);
    - 7x7 conv to ``out_channels`` followed by ``tanh``.

    Submodule names (``initial``, ``down``, ``res_blocks``, ``up``,
    ``output``) and the layer order inside each define the ``state_dict``
    keys of saved checkpoints, so they are kept fixed.
    """

    def __init__(self, in_channels, out_channels, num_residual_blocks=15):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, 64, 7, 1, 3),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
        )

        down_layers = []
        for c_in, c_out in ((64, 128), (128, 256)):
            down_layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.down = nn.Sequential(*down_layers)

        self.res_blocks = nn.Sequential(*(ResidualBlock(256) for _ in range(num_residual_blocks)))

        up_layers = []
        for c_in, c_out in ((256, 128), (128, 64)):
            up_layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.up = nn.Sequential(*up_layers)

        self.output = nn.Sequential(
            nn.Conv2d(64, out_channels, kernel_size=7, stride=1, padding=3),
            nn.Tanh(),
        )

    def forward(self, x):
        for stage in (self.initial, self.down, self.res_blocks, self.up, self.output):
            x = stage(x)
        return x

# Discriminator
class Discriminator(nn.Module):
    """Fully convolutional discriminator producing a spatial map of raw scores.

    Architecture:
    - four 4x4 convolutions (64 -> 128 -> 256 -> 512 channels, strides
      2, 2, 2, 1), each followed by LeakyReLU(0.2), with InstanceNorm on all
      but the first;
    - a final 4x4 convolution to one output channel.

    No sigmoid is applied. All layers live in one ``nn.Sequential`` named
    ``model`` in a fixed order, so ``state_dict`` keys (``model.0``,
    ``model.2``, ...) match saved checkpoints.
    """

    def __init__(self, in_channels):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for c_in, c_out, stride in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
            layers.extend([
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=1),
                nn.InstanceNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        layers.append(nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.model(x)
        return scores

# Initialize models: one generator per translation direction (A->B, B->A)
# and one discriminator per domain, all moved to the selected device.
generator_A2B, generator_B2A = [Generator(3, 3).to(device) for _ in range(2)]
discriminator_A, discriminator_B = [Discriminator(3).to(device) for _ in range(2)]

# Loss functions: mean-squared error for the adversarial terms, L1 for the
# cycle-consistency terms.
criterion_GAN, criterion_cycle = nn.MSELoss(), nn.L1Loss()

# VGG Perceptual Loss
class VGGPerceptualLoss(nn.Module):
    """L1 distance between frozen VGG-16 feature maps of two image batches.

    Features come from the first 16 layers of ``vgg16(...).features`` (up to
    and including relu3_3), with all VGG parameters frozen.

    Inputs are expected in the generator's range [-1, 1]. Before feature
    extraction they are mapped to [0, 1] and normalized with the ImageNet
    mean/std that torchvision's pretrained VGG weights were trained with.
    Feeding raw [-1, 1] tensors puts the network off its training
    distribution.

    NOTE(review): normalizing changes the magnitude of the features and hence
    of this loss; its weight in the generator objective may need retuning.
    """

    # ImageNet statistics expected by torchvision's pretrained VGG weights.
    _IMAGENET_MEAN = (0.485, 0.456, 0.406)
    _IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self):
        super(VGGPerceptualLoss, self).__init__()
        vgg = vgg16(weights=VGG16_Weights.DEFAULT).features[:16].eval()
        for param in vgg.parameters():
            param.requires_grad = False
        self.vgg = vgg
        self.criterion = nn.L1Loss()

    @classmethod
    def _normalize(cls, x):
        """Map a [-1, 1] RGB batch (channels on dim 1) to ImageNet-normalized VGG input."""
        mean = x.new_tensor(cls._IMAGENET_MEAN).view(1, -1, 1, 1)
        std = x.new_tensor(cls._IMAGENET_STD).view(1, -1, 1, 1)
        return ((x + 1) / 2 - mean) / std

    def forward(self, input, target):
        """Return the mean absolute difference between VGG features of ``input`` and ``target``."""
        input_features = self.vgg(self._normalize(input))
        target_features = self.vgg(self._normalize(target))
        loss = self.criterion(input_features, target_features)
        return loss

# Shared perceptual-loss module, placed on the same device as the models.
vgg_loss_fn = VGGPerceptualLoss()
vgg_loss_fn = vgg_loss_fn.to(device)

# Optimizers: one Adam over both generators, one Adam per discriminator.
optimizer_G = optim.Adam(
    [*generator_A2B.parameters(), *generator_B2A.parameters()],
    lr=0.0002,
    betas=(0.5, 0.999),
)
optimizer_D_A, optimizer_D_B = (
    optim.Adam(net.parameters(), lr=0.0001, betas=(0.5, 0.999))
    for net in (discriminator_A, discriminator_B)
)

# Learning rate schedulers: halve each learning rate every 10 scheduler steps.
scheduler_G, scheduler_D_A, scheduler_D_B = (
    optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.5)
    for opt in (optimizer_G, optimizer_D_A, optimizer_D_B)
)

# Hyperparameters
epochs = 50          # number of passes over the training set
lambda_cycle = 10.0  # weight of the cycle-consistency terms

# Training loop
for epoch in range(epochs):
    for i, batch in enumerate(dataloader):
        real_A = batch["A"].to(device)
        real_B = batch["B"].to(device)

        # Generate fake images
        fake_B = generator_A2B(real_A)
        fake_A = generator_B2A(real_B)

        # GAN Loss (least squares): generators want their fakes scored as real
        # (target 1). Each discriminator runs once, and that output also
        # supplies the target shape. Previously a second, identical forward
        # pass was made just to build ``ones_like``.
        pred_fake_B = discriminator_B(fake_B)
        pred_fake_A = discriminator_A(fake_A)
        loss_GAN_A2B = criterion_GAN(pred_fake_B, torch.ones_like(pred_fake_B))
        loss_GAN_B2A = criterion_GAN(pred_fake_A, torch.ones_like(pred_fake_A))

        # Cycle Loss
        recov_A = generator_B2A(fake_B)
        recov_B = generator_A2B(fake_A)
        loss_cycle_A = criterion_cycle(recov_A, real_A)
        loss_cycle_B = criterion_cycle(recov_B, real_B)

        # VGG Perceptual Loss
        # NOTE(review): this compares fake_B (translated from real_A) with
        # real_B, an image loaded independently from the trainB folder, so the
        # feature maps are not of corresponding images. Confirm this is
        # intended; a content loss would normally compare against real_A.
        loss_vgg_A2B = vgg_loss_fn(fake_B, real_B)
        loss_vgg_B2A = vgg_loss_fn(fake_A, real_A)

        # Total Generator Loss
        loss_G = loss_GAN_A2B + loss_GAN_B2A + lambda_cycle * (loss_cycle_A + loss_cycle_B) + loss_vgg_A2B + loss_vgg_B2A

        # Update Generators
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        # Discriminator Loss: real -> 1, detached fake -> 0. The targets are
        # shaped from these same outputs, so no extra forward pass is made on
        # the non-detached fakes.
        optimizer_D_A.zero_grad()
        optimizer_D_B.zero_grad()

        d_real_A = discriminator_A(real_A)
        d_fake_A = discriminator_A(fake_A.detach())
        loss_real_A = criterion_GAN(d_real_A, torch.ones_like(d_real_A))
        loss_fake_A = criterion_GAN(d_fake_A, torch.zeros_like(d_fake_A))
        loss_D_A = (loss_real_A + loss_fake_A) / 2

        d_real_B = discriminator_B(real_B)
        d_fake_B = discriminator_B(fake_B.detach())
        loss_real_B = criterion_GAN(d_real_B, torch.ones_like(d_real_B))
        loss_fake_B = criterion_GAN(d_fake_B, torch.zeros_like(d_fake_B))
        loss_D_B = (loss_real_B + loss_fake_B) / 2

        # Update Discriminators
        loss_D_A.backward()
        optimizer_D_A.step()
        loss_D_B.backward()
        optimizer_D_B.step()

        print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(dataloader)}] [D loss: {(loss_D_A + loss_D_B).item()}] [G loss: {loss_G.item()}]")

    # Step the learning rate scheduler
    scheduler_G.step()
    scheduler_D_A.step()
    scheduler_D_B.step()


[Output truncated: only the last 5000 lines of the streaming output are shown.]
[Epoch 10/50] [Batch 0/125] [D loss: 0.33433282375335693] [G loss: 5.92738676071167]
[Epoch 10/50] [Batch 1/125] [D loss: 0.3372704088687897] [G loss: 5.408411979675293]
[Epoch 10/50] [Batch 2/125] [D loss: 0.3442777395248413] [G loss: 5.128917217254639]
[Epoch 10/50] [Batch 3/125] [D loss: 0.32579344511032104] [G loss: 5.665841579437256]
[Epoch 10/50] [Batch 4/125] [D loss: 0.2729547619819641] [G loss: 5.278474807739258]
[Epoch 10/50] [Batch 5/125] [D loss: 0.3112141788005829] [G loss: 5.081557750701904]
[Epoch 10/50] [Batch 6/125] [D loss: 0.4907546937465668] [G loss: 4.922496795654297]
[Epoch 10/50] [Batch 7/125] [D loss: 0.4199795126914978] [G loss: 4.948441028594971]
[Epoch 10/50] [Batch 8/125] [D loss: 0.3856120705604553] [G loss: 4.825518608093262]
[Epoch 10/50] [Batch 9/125] [D loss: 0.415230393409729] [G loss: 4.615609169006348]
[Epoch 10/50] [Batch 10/125] [D loss: 0.4011486768722534] [G loss: 4.915618896484375]
[Epoch 10/50] [Batch 11

Test the model

In [None]:
# Save models after training
torch.save(generator_A2B.state_dict(), "generator_A2B.pth")
torch.save(generator_B2A.state_dict(), "generator_B2A.pth")

# Reload the trained models for testing.
# map_location loads the tensors onto this runtime's device; without it,
# CUDA-saved checkpoints fail to load on a CPU-only machine.
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict needs.
generator_A2B = Generator(3, 3).to(device)
generator_B2A = Generator(3, 3).to(device)

generator_A2B.load_state_dict(torch.load("generator_A2B.pth", map_location=device, weights_only=True))
generator_B2A.load_state_dict(torch.load("generator_B2A.pth", map_location=device, weights_only=True))

generator_A2B.eval()
generator_B2A.eval()

# Test-time preprocessing. The training ``transform`` includes
# RandomHorizontalFlip and RandomRotation, which would randomly flip and
# rotate the test inputs (and the saved "real" images). Use the same resize
# and normalization without the random augmentation.
test_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Update dataset paths for testing
test_dataset = ImageDataset(
    root_A="/content/CIP_dataset1_process/testA",  # Content images
    root_B="/content/CIP_dataset1_process/testB",  # Style images
    transform=test_transform
)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Create output directories
os.makedirs("results/testA_to_B", exist_ok=True)
os.makedirs("results/testB_to_A", exist_ok=True)

# Test and save results; tensors are mapped from [-1, 1] back to [0, 1] for saving.
with torch.no_grad():
    for i, batch in enumerate(test_loader):
        real_A = batch["A"].to(device)  # Content image
        real_B = batch["B"].to(device)  # Style image

        fake_B = generator_A2B(real_A)  # Generate style image from content
        fake_A = generator_B2A(real_B)  # Generate content image from style

        # Save generated images
        save_image((real_A + 1) / 2, f"results/testA_to_B/real_A_{i}.png")  # Original content
        save_image((real_B + 1) / 2, f"results/testB_to_A/real_B_{i}.png")  # Original style
        save_image((fake_B + 1) / 2, f"results/testA_to_B/fake_B_{i}.png")  # Styled content
        save_image((fake_A + 1) / 2, f"results/testB_to_A/fake_A_{i}.png")  # Contentized style

        print(f"Processed image pair {i}")


  generator_A2B.load_state_dict(torch.load("generator_A2B.pth"))
  generator_B2A.load_state_dict(torch.load("generator_B2A.pth"))


Processed image pair 0
Processed image pair 1
Processed image pair 2
Processed image pair 3
Processed image pair 4
Processed image pair 5
Processed image pair 6


In [None]:
!zip -r results.zip results

  adding: results/ (stored 0%)
  adding: results/testB_to_A/ (stored 0%)
  adding: results/testB_to_A/fake_A_4.png (deflated 0%)
  adding: results/testB_to_A/real_B_1.png (deflated 0%)
  adding: results/testB_to_A/real_B_4.png (deflated 0%)
  adding: results/testB_to_A/fake_A_6.png (deflated 0%)
  adding: results/testB_to_A/fake_A_1.png (deflated 0%)
  adding: results/testB_to_A/fake_A_3.png (deflated 0%)
  adding: results/testB_to_A/fake_A_5.png (deflated 0%)
  adding: results/testB_to_A/real_B_5.png (deflated 0%)
  adding: results/testB_to_A/fake_A_2.png (deflated 0%)
  adding: results/testB_to_A/real_B_0.png (deflated 0%)
  adding: results/testB_to_A/real_B_2.png (deflated 0%)
  adding: results/testB_to_A/real_B_3.png (deflated 0%)
  adding: results/testB_to_A/fake_A_0.png (deflated 0%)
  adding: results/testB_to_A/real_B_6.png (deflated 0%)
  adding: results/testA_to_B/ (stored 0%)
  adding: results/testA_to_B/fake_B_3.png (deflated 0%)
  adding: results/testA_to_B/fake_B_1.png (de

In [None]:
import torch
from torchvision.utils import save_image
import os

# Build the Generator (the class defined in an earlier cell of this notebook).
generator_A2B = Generator(3, 3).to(device)

# Load the pre-trained generator model (G_A).
# map_location loads the tensors onto this runtime's device, and
# weights_only=True restricts unpickling to tensors and plain containers.
# NOTE(review): "latest_net_G_A.pth" follows the naming of the official
# pytorch-CycleGAN-and-pix2pix checkpoints. If it comes from there, its
# state_dict keys/architecture will not match this Generator; confirm the
# file's source.
generator_A2B.load_state_dict(
    torch.load("/content/latest_net_G_A.pth", map_location=device, weights_only=True)
)
generator_A2B.eval()  # Set the model to evaluation mode

# Create output directory
os.makedirs("results/A2B", exist_ok=True)

# Test the model.
# NOTE(review): ``dataloader`` is the training loader (shuffled, randomly
# augmented, batch_size=8). Each saved file is therefore a grid of one batch
# of translated training images, not a single test image; confirm this is
# intended.
with torch.no_grad():  # No gradients are needed for inference
    for i, batch in enumerate(dataloader):
        real_A = batch["A"].to(device)  # Get real images from domain A

        # Generate fake images in domain B (style transfer)
        fake_B = generator_A2B(real_A)

        # Save generated images, mapped from [-1, 1] back to [0, 1]
        save_image((fake_B + 1) / 2, f"results/A2B/{i}.png")
        print(f"Processed image {i} and saved to 'results/A2B/{i}.png'")

