In [11]:
import torch
import torch.nn as nn
from PIL import Image
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os
import torchvision
import matplotlib.pyplot as plt
import numpy as np

In [None]:
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [13]:
transform_hr = transforms.Compose([
    transforms.Resize((512, 512)),  
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

transform_lr = transforms.Compose([
    transforms.Resize((256, 256)),  
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

class FaceDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, transform_hr, transform_lr):
        self.root_dir = os.path.join(root_dir, 'raw')  # Update this line to include 'raw'
        self.transform_hr = transform_hr
        self.transform_lr = transform_lr
        # Ensure that we only list files (this filters out subdirectories if there are any)
        self.images = [f for f in os.listdir(self.root_dir) if os.path.isfile(os.path.join(self.root_dir, f))]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.images[idx])
        image_hr = Image.open(img_name).convert('RGB')  # Convert to RGB to avoid issues with RGBA or grayscale images
        image_lr = image_hr.copy()
        image_hr = self.transform_hr(image_hr)
        image_lr = self.transform_lr(image_lr)
        return image_lr, image_hr

dataset = FaceDataset('../dataset/face_imgs', transform_hr, transform_lr)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

In [14]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_channels),
            nn.PReLU(),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_channels)
        )

    def forward(self, x):
        return x + self.block(x)

class SRResNet(nn.Module):
    def __init__(self, in_channels=3, num_blocks=8):
        super(SRResNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=9, stride=1, padding=4)
        self.prelu = nn.PReLU()
        
        # Residual blocks
        self.residuals = nn.Sequential(*[ResidualBlock(64) for _ in range(num_blocks)])
        
        # Another conv layer after residual blocks
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        # Upsampling layer
        self.upsample = nn.Sequential(
            nn.Conv2d(64, 256, kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(2),
            nn.PReLU()
        )

        # Output layer
        self.conv3 = nn.Conv2d(64, in_channels, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        out1 = self.prelu(self.conv1(x))
        res = self.residuals(out1)
        out2 = self.bn2(self.conv2(res)) + out1
        out3 = self.upsample(out2)
        out = self.conv3(out3)
        return out
    
    
class Discriminator(nn.Module):
    def __init__(self, in_channels=3):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(1024, 1, kernel_size=1)
        )

    def forward(self, x):
        return self.model(x)

In [15]:
srresnet = SRResNet()
discriminator = Discriminator()

In [16]:
criterion = nn.BCELoss()  # 二元交叉熵损失
optimizer_srresnet = optim.Adam(srresnet.parameters(), lr=0.001)
optimizer_discriminator = optim.Adam(discriminator.parameters(), lr=0.001)

In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
srresnet.to(device)
discriminator.to(device)

sample_dir = './samples'
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

# 训练模型
num_epochs = 5  # 根据需要调整

for epoch in range(num_epochs):
    for i, (images_lr, images_hr) in enumerate(dataloader):
        images_lr = images_lr.to(device)
        images_hr = images_hr.to(device)

        # 真实和假的标签
        real_labels = torch.ones(images_hr.size(0), 1).to(device)
        fake_labels = torch.zeros(images_hr.size(0), 1).to(device)

        # 训练判别器
        optimizer_discriminator.zero_grad()
        real_outputs = discriminator(images_hr).squeeze().unsqueeze(1)
        d_loss_real = criterion(real_outputs, real_labels)
        fake_images = srresnet(images_lr)
        fake_outputs = discriminator(fake_images.detach()).squeeze().unsqueeze(1)
        d_loss_fake = criterion(fake_outputs, fake_labels)
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_discriminator.step()

        # 训练生成器 (SRResNet)
        optimizer_srresnet.zero_grad()
        fake_images = srresnet(images_lr)
        outputs = discriminator(fake_images).squeeze().unsqueeze(1)
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        optimizer_srresnet.step()

        # 打印损失信息
        if (i + 1) % 53 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(dataloader)}], d_loss: {d_loss.item()}, g_loss: {g_loss.item()}')

        # 每个epoch结束时保存一些生成的样本
        if (i + 1) % 53 == 0:
            fake_images = fake_images.reshape(fake_images.size(0), 3, 512, 512)
            save_image(fake_images, os.path.join(sample_dir, f'sample_epoch_{epoch}_{i}.png'))

RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
