In [3]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import matplotlib.pyplot as plt
import tqdm
import math
from sklearn.model_selection import train_test_split

In [4]:
class ImagePairDataset(Dataset):
    def __init__(self, folder1, folder2, pairs, transform=None, enhance_transform_1 = None, 
                 enhance_transform_2 = None, enhance_transform_3 = None):
        self.folder1 = folder1
        self.folder2 = folder2
        self.pairs = pairs
        self.transform = transform
        self.enhance_transform_1 = enhance_transform_1
        self.enhance_transform_2 = enhance_transform_2
        self.enhance_transform_3 = enhance_transform_3
        self.image_pairs = self.read_image_pairs()

    def read_image_pairs(self):
      image_pairs = []
      for image_pair in tqdm.tqdm(self.pairs):
        img1_path = os.path.join(self.folder1, image_pair[0])
        img2_path = os.path.join(self.folder2, image_pair[1])
        img1 = Image.open(img1_path).convert("RGB")
        img2 = Image.open(img2_path).convert("RGB")
        img3 = img1
        img4 = img2
        img5 = img1
        img6 = img2
        img7 = img1
        img8 = img2
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        if self.enhance_transform_1:
            img3 = self.enhance_transform_1(img3)
            img4 = self.enhance_transform_1(img4)
        if self.enhance_transform_2:
            img5 = self.enhance_transform_2(img5)
            img6 = self.enhance_transform_2(img6)
        if self.enhance_transform_3:
            img7 = self.enhance_transform_3(img7)
            img8 = self.enhance_transform_3(img8)
        image_pairs.append((img1, img2))
        image_pairs.append((img3, img4))
        image_pairs.append((img5, img6))
        image_pairs.append((img7, img8))
      return image_pairs
        
    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
      return self.image_pairs[idx][0], self.image_pairs[idx][1]

In [5]:
# Define paths to the folders
digital_dir = "D:/data/数码"
film_dir = "D:/data/富士c200/胶片模拟/JPEG"

In [6]:
# Get the list of files in both folders
digital = sorted(os.listdir(digital_dir))
film = sorted(os.listdir(film_dir))

print(len(digital))
print(len(film))

# Ensure the number of files match
if len(digital) != len(film):
    raise ValueError("The two folders must have the same number of images.")

# Create pairs of images (file1, file2)
pairs = list(zip(digital, film))

# Split into training and testing sets
train_pairs, test_pairs = train_test_split(pairs, test_size=0.2, random_state=42)

1517
1517


In [7]:
# Define transformations (if needed)
transform = transforms.Compose([
  transforms.Resize((200, 320)),
  transforms.ToTensor()
])
enhance_transform_1 = transforms.Compose([
  transforms.RandomRotation(degrees=(-180, -180), expand=True),
  transforms.Resize((200, 320)),
  transforms.ToTensor(),
])
enhance_transform_2 = transforms.Compose([
  transforms.CenterCrop(330),
  transforms.Resize((200, 320)),
  transforms.ToTensor(),
])
enhance_transform_3 = transforms.Compose([
  transforms.RandomRotation(degrees=(-90, -90), expand=True),
  transforms.Resize((200, 320)),
  transforms.ToTensor(),
])

In [8]:
batch_size = 32

In [9]:
# Create datasets
train_dataset = ImagePairDataset(digital_dir, film_dir, train_pairs, transform=transform, 
                                 enhance_transform_1=enhance_transform_1,
                                 enhance_transform_2=enhance_transform_2,
                                 enhance_transform_3 = enhance_transform_3)
test_dataset = ImagePairDataset(digital_dir, film_dir, test_pairs, transform=transform, 
                                enhance_transform_1=enhance_transform_1,
                                enhance_transform_2=enhance_transform_2,
                                enhance_transform_3 = enhance_transform_3)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size = batch_size, shuffle=False, drop_last=True)

100%|██████████| 1213/1213 [04:15<00:00,  4.74it/s]
100%|██████████| 304/304 [01:00<00:00,  5.01it/s]


In [10]:
# 通道注意力机制 (Channel Attention Mechanism)
class ChannelAttention(nn.Module):
    def __init__(self, in_channels, reduction=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // reduction, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_channels // reduction, in_channels, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)

class FilmStyleTransfer(torch.nn.Module):
    def __init__(self):
        super(FilmStyleTransfer, self).__init__()

        # 定义卷积层
        self.conv1 = torch.nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = torch.nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv4 = torch.nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv5 = torch.nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.conv6 = torch.nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.conv7 = torch.nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.conv8 = torch.nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1)
        
        # 定义激活函数
        self.leakyrelu = torch.nn.LeakyReLU()
        
        # 定义BatchNorm层
        self.bn1 = torch.nn.BatchNorm2d(32)
        self.bn2 = torch.nn.BatchNorm2d(64)
        self.bn3 = torch.nn.BatchNorm2d(128)
        self.bn4 = torch.nn.BatchNorm2d(256)
        self.bn5 = torch.nn.BatchNorm2d(128)
        self.bn6 = torch.nn.BatchNorm2d(64)
        self.bn7 = torch.nn.BatchNorm2d(32)

        # 注意力机制
        self.ca2 = ChannelAttention(64)
        self.ca3 = ChannelAttention(128)
        self.ca4 = ChannelAttention(256)
        self.ca5 = ChannelAttention(128)
        self.ca6 = ChannelAttention(64)
        self.ca7 = ChannelAttention(32)
        
        # 跳跃连接的卷积层
        self.skip1 = torch.nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0)  # 将32通道映射到3通道
        self.skip2 = torch.nn.Conv2d(64, 3, kernel_size=1, stride=1, padding=0)  # 将64通道映射到3通道
        
    def forward(self, x):
        x1 = (self.leakyrelu(self.bn1(self.conv1(x))))
        x2 = self.ca2(self.leakyrelu(self.bn2(self.conv2(x1))))
        x3 = self.ca3(self.leakyrelu(self.bn3(self.conv3(x2))))
        x4 = self.ca4(self.leakyrelu(self.bn4(self.conv4(x3))))
        x5 = self.ca5(self.leakyrelu(self.bn5(self.conv5(x4))))
        x6 = self.ca6(self.leakyrelu(self.bn6(self.conv6(x5))))
        x7 = self.ca7(self.leakyrelu(self.bn7(self.conv7(x6))))
        x8 = self.conv8(x7)
        
        # 跳跃连接
        skip1 = self.skip1(x1)  # 将x1的32通道映射到3通道
        skip2 = self.skip2(x2)  # 将x2的64通道映射到3通道
        
        # 合并跳跃连接和最终输出
        output = x8 + skip1 + skip2

        return output

In [11]:
class ChannelStatLoss(nn.Module):
    def forward(self, pred, target):
        loss = 0.0
        # 遍历RGB通道
        for c in range(3):
            pred_channel = pred[:, c, :, :]
            target_channel = target[:, c, :, :]
            
            # 均值差异（L1）
            loss += torch.abs(pred_channel.mean() - target_channel.mean())
            # 方差差异（L1）
            loss += torch.abs(pred_channel.var() - target_channel.var())
        
        lossrgb = loss + torch.nn.functional.l1_loss(pred, target)
        
        return lossrgb

In [12]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [13]:
def psnr(target, prediction, max_pixel=1.0):
    mse = torch.nn.functional.mse_loss(target, prediction)
    if mse == 0:
        return float('inf')  # 如果 MSE 为 0，意味着两图完全相同
    return 20 * math.log10(max_pixel / math.sqrt(mse))

def evaluate_psnr_generate(data_loader, model, max_pixel=1.0):
    total_psnr = 0.0
    num_images = 0

    # 遍历数据集
    for generated_images, original_images in data_loader:
        # 确保输入输出图像在相同的设备上（例如：GPU）
        original_images = original_images.to(device)
        generated_images = generated_images.to(device)

        # 计算每一对图像的 PSNR
        for orig, gen in zip(original_images, generated_images):
            orig = orig.unsqueeze(0).cuda()  # 增加批量维度
            output = model(orig)
            psnr_value = psnr(gen, output[0], max_pixel=max_pixel)
            total_psnr += psnr_value
            num_images += 1

    # 计算整个数据集的平均 PSNR
    avg_psnr = total_psnr / num_images if num_images > 0 else 0.0
    return avg_psnr

def evaluate_psnr_pre_train(data_loader, model, max_pixel=1.0):
    total_psnr = 0.0
    num_images = 0

    # 遍历数据集
    for generated_images, original_images in data_loader:
        # 确保输入输出图像在相同的设备上（例如：GPU）
        original_images = original_images.to(device)
        generated_images = generated_images.to(device)

        # 计算每一对图像的 PSNR
        for orig, gen in zip(original_images, generated_images):
            gen = gen.unsqueeze(0).cuda()  # 增加批量维度
            output = model(gen)
            psnr_value = psnr(orig, output[0], max_pixel=max_pixel)
            total_psnr += psnr_value
            num_images += 1

    # 计算整个数据集的平均 PSNR
    avg_psnr = total_psnr / num_images if num_images > 0 else 0.0
    return avg_psnr

In [14]:
def postprocess_image(tensor):
    """将模型输出的张量转换为 PIL 图像"""
    tensor = tensor.squeeze(0)  # 去掉 batch 维度
    tensor = tensor.clamp(0, 1)  # 确保值在 [0, 1] 范围内
    transform = transforms.ToPILImage()  # 转换为 PIL 图像
    image = transform(tensor)
    return image

In [15]:
pre_train_model = FilmStyleTransfer()
pre_train_model = pre_train_model.to(device)
optimizer = optim.Adam(pre_train_model.parameters(), lr=0.0001)
loss_function = torch.nn.SmoothL1Loss(reduction='mean')
pre_train_pre_epochs = 0

In [26]:
checkpoint = torch.load("./models/预训练/pre_train_gold_200_training_200.pt")

# 恢复模型和优化器状态
pre_train_model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pre_train_model = pre_train_model.to(device)
pre_train_model.eval()
pre_train_pre_epochs = checkpoint['epoch']

# 恢复训练状态
epoch = checkpoint['epoch']
loss = checkpoint['loss']

In [16]:
# 训练模型
num_epochs = 1000
psnr_list = []

with tqdm.tqdm(total=num_epochs, desc="进度条") as pbar:
    for epoch in range(num_epochs):
        pre_train_model.train()
        for i, (img1, img2) in enumerate(train_loader):
            img1, img2 = img1.cuda(), img2.cuda()
            optimizer.zero_grad()
            reconstructed = pre_train_model(img1)
            loss = loss_function(reconstructed, img2)
            loss.backward()
            optimizer.step()
            pbar.set_postfix(loss=loss.item())
        pbar.update(1)
        if not ((epoch + 1 + pre_train_pre_epochs) % (100)):
            pre_train_model.eval()
            checkpoint = {
                'epoch': epoch + pre_train_pre_epochs + 1,
                'model_state_dict': pre_train_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
            }
            torch.save(checkpoint, 'pre_train_gold_200_training_' + str(epoch + pre_train_pre_epochs + 1) + '.pt')
            print("checkpoint: %s saved." % str(epoch + pre_train_pre_epochs + 1))
            avg_psnr = evaluate_psnr_pre_train(test_loader, pre_train_model)
            print(f'Average PSNR for the dataset: {avg_psnr:.2f} dB')

进度条:  10%|█         | 100/1000 [13:08<2:00:56,  8.06s/it, loss=0.001]  

checkpoint: 100 saved.
Average PSNR for the dataset: 27.53 dB


进度条:  20%|██        | 200/1000 [26:40<1:45:14,  7.89s/it, loss=0.000676]

checkpoint: 200 saved.
Average PSNR for the dataset: 29.94 dB


进度条:  30%|███       | 300/1000 [39:45<1:33:50,  8.04s/it, loss=0.000704]

checkpoint: 300 saved.
Average PSNR for the dataset: 29.25 dB


进度条:  40%|████      | 400/1000 [53:14<1:18:36,  7.86s/it, loss=0.000637]

checkpoint: 400 saved.
Average PSNR for the dataset: 30.80 dB


进度条:  50%|█████     | 500/1000 [1:06:40<1:07:13,  8.07s/it, loss=0.000564]

checkpoint: 500 saved.
Average PSNR for the dataset: 31.34 dB


进度条:  60%|██████    | 600/1000 [1:20:03<53:09,  7.97s/it, loss=0.00039]   

checkpoint: 600 saved.
Average PSNR for the dataset: 30.78 dB


进度条:  70%|███████   | 700/1000 [1:33:15<40:00,  8.00s/it, loss=0.00119] 

checkpoint: 700 saved.
Average PSNR for the dataset: 30.64 dB


进度条:  80%|████████  | 800/1000 [1:46:43<26:50,  8.05s/it, loss=0.000424]

checkpoint: 800 saved.
Average PSNR for the dataset: 31.15 dB


进度条:  90%|█████████ | 900/1000 [1:59:57<13:11,  7.92s/it, loss=0.0015]  

checkpoint: 900 saved.
Average PSNR for the dataset: 31.45 dB


进度条: 100%|██████████| 1000/1000 [2:13:13<00:00,  7.93s/it, loss=0.000685]

checkpoint: 1000 saved.


进度条: 100%|██████████| 1000/1000 [2:13:16<00:00,  8.00s/it, loss=0.000685]

Average PSNR for the dataset: 32.11 dB



