In [None]:
import os
import random

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, Resize, ToTensor


In [None]:
# Paths into the extracted VOC2007 train/val archive, relative to the working directory.
# Built with os.path.join rather than backslash literals: "\V", "\J" and "\S" are
# invalid escape sequences (a SyntaxWarning on recent Python) and a hard-coded "\"
# separator only works on Windows.
voc_root = os.path.join("VOCtrainval_06-Nov-2007", "VOCdevkit", "VOC2007")
image_directory = os.path.join(voc_root, "JPEGImages")
mask_directory = os.path.join(voc_root, "SegmentationClass")
batch_size = 32
num_epochs = 10

# Seed Python's and PyTorch's RNGs (weight init, DataLoader shuffling) so that
# re-running the notebook from a fresh kernel is reproducible.
random_seed = 42
random.seed(random_seed)
torch.manual_seed(random_seed)


In [None]:
import torch
import torch.nn as nn

# U-Net encoder stage: two 3x3 convolutions with ReLU, then a 2x2 max-pool.
class UNetEncoder(nn.Module):
    """One contracting stage of a U-Net.

    Args:
        in_channels: channel count of the incoming feature map.
        out_channels: channel count produced by both convolutions.

    ``forward(x)`` returns ``(pooled, skip)``: ``skip`` is the activated
    feature map before pooling (same height/width as ``x`` thanks to
    ``padding=1``), and ``pooled`` is ``skip`` max-pooled with kernel 2,
    stride 2.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Sub-modules are registered in this exact order (conv1, relu, conv2,
        # maxpool) so state_dict keys and parameter order stay stable.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        # The in-place ReLU only touches freshly produced conv outputs, never x.
        features = self.relu(self.conv1(x))
        features = self.relu(self.conv2(features))
        # Pooling returns a new tensor, so `features` survives as the skip.
        return self.maxpool(features), features

# U-Net decoder stage: upsample 2x, concatenate the matching encoder skip,
# then two 3x3 convolutions with ReLU.
class UNetDecoder(nn.Module):
    """One expanding stage of a U-Net.

    Args:
        in_channels: channel count of the incoming (coarser) feature map.
            The transposed convolution maps it to ``out_channels`` channels,
            which are then concatenated with the skip connection, so the skip
            must carry ``in_channels - out_channels`` channels for ``conv1``
            to accept the result (the usual U-Net choice is
            ``in_channels == 2 * out_channels``).
        out_channels: channel count produced by this stage.

    ``forward(x, skip_connection)`` returns a map with ``out_channels``
    channels and the same height/width as ``skip_connection``.
    """

    def __init__(self, in_channels, out_channels):
        super(UNetDecoder, self).__init__()
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x, skip_connection):
        x = self.upconv(x)
        # Max-pooling floors odd heights/widths, so after the exact 2x upsample
        # the map can be one pixel smaller than the skip and torch.cat would
        # fail. Zero-pad the difference (split between both sides) first.
        diff_h = skip_connection.size(-2) - x.size(-2)
        diff_w = skip_connection.size(-1) - x.size(-1)
        if diff_h or diff_w:
            x = nn.functional.pad(
                x,
                [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
            )
        x = torch.cat([x, skip_connection], dim=1)
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        return x

# Complete U-Net: four encoder stages, a bottleneck, four decoder stages and a
# 1x1 output projection.
class UNet(nn.Module):
    """U-Net producing a per-pixel output map.

    Args:
        in_channels: channel count of the input image (e.g. 3 for RGB).
        out_channels: channel count of the output map (e.g. number of classes).

    The output has ``out_channels`` channels, the same height/width as the
    input, and is returned without any final activation. Input heights and
    widths that are multiples of 16 pool evenly through all four stages.
    """

    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()
        self.encoder1 = UNetEncoder(in_channels, 64)
        self.encoder2 = UNetEncoder(64, 128)
        self.encoder3 = UNetEncoder(128, 256)
        self.encoder4 = UNetEncoder(256, 512)
        # Bottleneck at 1/16 resolution. It must NOT upsample: decoder4's
        # transposed convolution does that, and it expects 1024 input channels.
        self.center = nn.Sequential(
            nn.Conv2d(512, 1024, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        # Each decoder halves the channel count; its upsampled output is
        # concatenated with an encoder skip of the same width (512, 256, 128,
        # 64), which restores the decoder's in_channels for its first conv.
        self.decoder4 = UNetDecoder(1024, 512)
        self.decoder3 = UNetDecoder(512, 256)
        self.decoder2 = UNetDecoder(256, 128)
        self.decoder1 = UNetDecoder(128, 64)
        # 1x1 projection to the requested channels; no ReLU so outputs can be
        # negative (usable as logits).
        self.final = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        x, skip1 = self.encoder1(x)
        x, skip2 = self.encoder2(x)
        x, skip3 = self.encoder3(x)
        x, skip4 = self.encoder4(x)
        x = self.center(x)
        x = self.decoder4(x, skip4)
        x = self.decoder3(x, skip3)
        x = self.decoder2(x, skip2)
        x = self.decoder1(x, skip1)
        return self.final(x)


In [None]:
class CustomDataset(Dataset):
    """Dataset of (image, mask) pairs read from two directories.

    Files are paired by stem (file name without extension), e.g.
    ``000032.jpg`` with ``000032.png``. Only stems present in both
    directories are used, in sorted order, so pairing never depends on
    ``os.listdir`` order and an image that has no mask is skipped instead of
    being matched with an unrelated file. Hidden files (leading ``.``) and
    sub-directories are ignored.

    Args:
        image_dir: directory containing the input images.
        mask_dir: directory containing the masks.
        transform: optional callable applied to every opened image, and to
            every mask unless ``mask_transform`` is given. With ``None`` the
            PIL images are returned as-is.
        mask_transform: optional callable applied to masks instead of
            ``transform``.

    NOTE(review): ``torchvision.transforms.ToTensor`` rescales 8-bit data to
    [0, 1]. If the masks store class indices and are needed as integer
    targets, pass a ``mask_transform`` that keeps the raw label values.
    """

    def __init__(self, image_dir, mask_dir, transform=None, mask_transform=None):
        images_by_stem = self._index_by_stem(image_dir)
        masks_by_stem = self._index_by_stem(mask_dir)
        # The mask directory may only cover a subset of the images, so keep
        # just the stems that exist on both sides.
        stems = sorted(images_by_stem.keys() & masks_by_stem.keys())
        self.image_files = [images_by_stem[stem] for stem in stems]
        self.mask_files = [masks_by_stem[stem] for stem in stems]
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_transform = mask_transform

    @staticmethod
    def _index_by_stem(directory):
        """Map file stem -> file name for the visible regular files in ``directory``.

        If two files share a stem, the first one in sorted order wins.
        """
        index = {}
        for name in sorted(os.listdir(directory)):
            if name.startswith(".") or not os.path.isfile(os.path.join(directory, name)):
                continue
            index.setdefault(os.path.splitext(name)[0], name)
        return index

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, index):
        image_path = os.path.join(self.image_dir, self.image_files[index])
        mask_path = os.path.join(self.mask_dir, self.mask_files[index])
        mask_transform = self.transform if self.mask_transform is None else self.mask_transform
        image = self._load(image_path, self.transform)
        mask = self._load(mask_path, mask_transform)
        return image, mask

    @staticmethod
    def _load(path, transform):
        """Open ``path``, apply ``transform`` if given, and release the file handle."""
        # NOTE(review): `Image` is Pillow's `PIL.Image`, which is not imported
        # anywhere in this notebook; add `from PIL import Image` to the imports.
        with Image.open(path) as img:
            # copy() decodes into an independent image, so it stays usable
            # after the file is closed when no transform is applied.
            return img.copy() if transform is None else transform(img)


In [None]:
# Use the GPU when one is available. `device` is needed here and in the
# training loop but is not defined in any earlier cell.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# TODO(review): `generator` and `discriminator` are not created in any cell of
# this notebook and must be defined before this cell runs (the generator could
# presumably be built from the UNet class above). The training loop calls
# `discriminator(images, masks)` and feeds its output to BCELoss, which expects
# probabilities in [0, 1] (i.e. a sigmoid output).
generator = generator.to(device)
discriminator = discriminator.to(device)
generator_optimizer = optim.Adam(generator.parameters(), lr=0.001)
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=0.001)
criterion = nn.BCELoss()


In [None]:
# Target (height, width) for every sample. DataLoader's default collate stacks
# the samples of a batch into one tensor, which fails as soon as two images of
# a batch differ in size (presumably common in this dataset), so every image
# and mask is resized first. 256 is a multiple of 16, matching the four 2x
# down-samplings of the U-Net encoder.
image_size = (256, 256)
# Pillow always resizes palette ("P") and bilevel ("1") images with
# nearest-neighbour, so palette masks keep their label values.
# NOTE(review): assumes the masks are palette PNGs — confirm.
transform = Compose([Resize(image_size), ToTensor()])  # resize, then convert to a float tensor in [0, 1]
dataset = CustomDataset(image_directory, mask_directory, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


In [None]:
# Adversarial training loop. Each batch first updates the discriminator on real
# vs. detached synthetic images, then updates the generator through a fresh
# discriminator pass. Each loss gets its own discriminator forward pass, so
# calling backward() on one never frees the graph the other still needs.
for epoch in range(num_epochs):
    for images, masks in dataloader:
        images = images.to(device)
        masks = masks.to(device)

        # The generator maps the input images to synthetic images; the masks
        # are only used as the discriminator's condition below.
        synthetic_images = generator(images)

        # --- Discriminator step ---
        # detach() keeps this loss from backpropagating into the generator.
        real_outputs = discriminator(images, masks)
        fake_outputs = discriminator(synthetic_images.detach(), masks)
        discriminator_loss = (
            criterion(real_outputs, torch.ones_like(real_outputs))
            + criterion(fake_outputs, torch.zeros_like(fake_outputs))
        )
        discriminator_optimizer.zero_grad()
        discriminator_loss.backward()
        discriminator_optimizer.step()

        # --- Generator step ---
        # Re-score the (non-detached) synthetic images with the updated
        # discriminator; the generator is rewarded when they are judged real (1).
        synthetic_outputs = discriminator(synthetic_images, masks)
        generator_loss = criterion(synthetic_outputs, torch.ones_like(synthetic_outputs))
        generator_optimizer.zero_grad()
        generator_loss.backward()
        generator_optimizer.step()

    # Report the epoch number and the losses of the epoch's last batch.
    print(f"Epoch [{epoch+1}/{num_epochs}], Generator Loss: {generator_loss.item()}, Discriminator Loss: {discriminator_loss.item()}")
