In [1]:
import torch
from torch import nn
import torchvision.utils as vutils
from torch.optim import Adam
from torch.utils.data import DataLoader,Dataset
import glob
from torchvision import transforms, datasets
from torchvision.utils import save_image
import os
import random
from PIL import Image
import itertools
import matplotlib.pyplot as plt

In [2]:
class ResidualBlock(nn.Module):
    """Residual unit: two 3x3 convolution + instance-norm stages plus a skip connection.

    Both convolutions use stride 1 and padding 1 and keep ``in_features`` channels,
    so the output has the same shape as the input.

    Args:
        in_features: Number of channels of the input (and output) feature map.
    """

    def __init__(self, in_features):
        super().__init__()

        def conv3x3():
            # Bias-free: each convolution is immediately followed by InstanceNorm2d.
            return nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, bias=False)

        # conv -> norm -> ReLU -> conv -> norm (no activation after the second norm).
        self.block = nn.Sequential(
            conv3x3(),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            conv3x3(),
            nn.InstanceNorm2d(in_features),
        )

    def forward(self, x):
        """Return ``x`` plus the residual computed by ``self.block``."""
        residual = self.block(x)
        return x + residual

class Generator(nn.Module):
    """Convolutional image-to-image generator.

    Layer sequence (all stored in ``self.model``):
      1. 7x7 convolution to 64 channels, instance norm, ReLU.
      2. Two stride-2 3x3 convolutions (64 -> 128 -> 256), each with norm + ReLU.
      3. ``n_residual_blocks`` ``ResidualBlock`` modules at 256 channels.
      4. Two stride-2 transposed convolutions (256 -> 128 -> 64), each with norm + ReLU.
      5. 7x7 convolution to ``output_nc`` channels followed by ``Tanh``.

    Args:
        input_nc: Number of channels of the input image.
        output_nc: Number of channels of the generated image.
        n_residual_blocks: How many residual blocks to place in the middle.
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        super().__init__()

        def norm_relu(width):
            return [nn.InstanceNorm2d(width), nn.ReLU(inplace=True)]

        # Stem
        channels = 64
        layers = [nn.Conv2d(input_nc, channels, kernel_size=7, padding=3, bias=False)]
        layers.extend(norm_relu(channels))

        # Downsampling: each step halves the resolution and doubles the width.
        for _ in range(2):
            wider = channels * 2
            layers.append(nn.Conv2d(channels, wider, kernel_size=3, stride=2, padding=1, bias=False))
            layers.extend(norm_relu(wider))
            channels = wider

        # Residual trunk at the bottleneck width.
        layers.extend(ResidualBlock(channels) for _ in range(n_residual_blocks))

        # Upsampling: mirror of the downsampling path.
        for _ in range(2):
            narrower = channels // 2
            layers.append(
                nn.ConvTranspose2d(
                    channels, narrower, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False
                )
            )
            layers.extend(norm_relu(narrower))
            channels = narrower

        # Output head; `channels` is back at 64 here. Tanh bounds the result to [-1, 1].
        layers.append(nn.Conv2d(channels, output_nc, kernel_size=7, padding=3))
        layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the full layer stack to ``x``."""
        return self.model(x)


In [3]:
class Discriminator(nn.Module):
    """Fully convolutional discriminator.

    Five 4x4 convolutions (``input_nc`` -> 64 -> 128 -> 256 -> 512 -> 1). The first
    three use stride 2, the last two stride 1. The result is a one-channel score
    map rather than a single scalar; no sigmoid is applied.

    Args:
        input_nc: Number of channels of the images being judged.
    """

    def __init__(self, input_nc):
        super().__init__()

        # First stage: biased convolution, no normalisation.
        layers = [
            nn.Conv2d(input_nc, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Middle stages as (in_channels, out_channels, stride): conv -> norm -> LeakyReLU.
        for c_in, c_out, stride in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
            layers.extend([
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=1, bias=False),
                nn.InstanceNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])

        # Final projection to a single score channel.
        layers.append(nn.Conv2d(512, 1, kernel_size=4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw score map for ``x``."""
        return self.model(x)

In [4]:
# Preprocessing shared by the datasets below, split into two stages for readability.
# Stage 1: random augmentation on the PIL image (resize to int(256*1.12) = 286, then a
# random 256x256 crop and a random horizontal flip).
_augment_steps = [
    transforms.Resize(int(256 * 1.12), Image.BICUBIC),
    transforms.RandomCrop(256),
    transforms.RandomHorizontalFlip(),
]
# Stage 2: convert to a tensor and apply (x - 0.5) / 0.5 to each of the three channels.
_tensor_steps = [
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
transform = transforms.Compose(_augment_steps + _tensor_steps)


In [5]:
class ImageDataset(Dataset):
    """Two-domain image dataset backed by ``<root>/<mode>A`` and ``<root>/<mode>B``.

    Args:
        root: Directory containing the ``<mode>A`` / ``<mode>B`` sub-directories
            (e.g. ``root/trainA`` and ``root/trainB``).
        transforms_: Callable applied to each PIL image. When ``None`` the RGB
            PIL images are returned unchanged.
        unaligned: If true, the B image is drawn uniformly at random instead of
            by index, so A and B are not paired.
        mode: Sub-directory prefix, e.g. ``"train"`` or ``"test"``.

    Each item is a tuple ``(item_A, item_B)``.
    """

    def __init__(self, root='data', transforms_=None, unaligned=False, mode="train"):
        self.transform = transforms_
        self.unaligned = unaligned

        # Only entries matching "*.*" are picked up; sorting keeps indexing deterministic.
        self.files_A = sorted(glob.glob(os.path.join(root, "%sA" % mode) + "/*.*"))
        self.files_B = sorted(glob.glob(os.path.join(root, "%sB" % mode) + "/*.*"))

    def _load(self, path):
        """Open ``path``, force 3-channel RGB and apply the transform (if any)."""
        # Grayscale / RGBA / palette files would otherwise produce tensors whose channel
        # count does not match a 3-channel Normalize or the other images in the batch.
        image = Image.open(path).convert("RGB")
        if self.transform is None:
            return image
        return self.transform(image)

    def __getitem__(self, index):
        # Indices wrap around so the shorter domain is reused when the sizes differ.
        path_A = self.files_A[index % len(self.files_A)]

        if self.unaligned:
            # Unpaired setting: pick any B image at random.
            path_B = self.files_B[random.randint(0, len(self.files_B) - 1)]
        else:
            path_B = self.files_B[index % len(self.files_B)]

        # A is loaded and transformed before B.
        item_A = self._load(path_A)
        item_B = self._load(path_B)
        return item_A, item_B

    def __len__(self):
        """Length of the larger of the two domains."""
        return max(len(self.files_A), len(self.files_B))


In [6]:
# Data loaders for the human <-> dog experiment. Change "human_dog" to the directory that
# holds your own trainA/trainB/testA/testB folders. unaligned=True requests unpaired sampling.
train_dataloader = DataLoader(
    dataset=ImageDataset(root="human_dog", transforms_=transform, unaligned=True, mode="train"),
    batch_size=4,
    shuffle=True,
)
test_dataloader = DataLoader(
    dataset=ImageDataset(root="human_dog", transforms_=transform, unaligned=True, mode="test"),
    batch_size=4,
    shuffle=False,
)


In [7]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Two generators (A->B and B->A) and one discriminator per domain.
# Construction order (G_AB, G_BA, D_A, D_B) is kept so seeded initialisation is unchanged.
G_AB, G_BA = (Generator(input_nc=3, output_nc=3).to(device) for _ in range(2))
D_A, D_B = (Discriminator(input_nc=3).to(device) for _ in range(2))

# Optimizers: both generators share one Adam instance; each discriminator has its own.
_adam_settings = dict(lr=0.0005, betas=(0.5, 0.999))
optimizer_G = Adam(itertools.chain(G_AB.parameters(), G_BA.parameters()), **_adam_settings)
optimizer_D_A = Adam(D_A.parameters(), **_adam_settings)
optimizer_D_B = Adam(D_B.parameters(), **_adam_settings)

# Losses: squared error for the adversarial terms, L1 for the image-space terms.
criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()

# Directory that receives the sample images written after every epoch.
output_dir = './cyclegan_images'
os.makedirs(output_dir, exist_ok=True)

In [23]:
def _discriminator_step(discriminator, optimizer, real_images, fake_images):
    """Run one optimiser update on ``discriminator`` and return its loss tensor.

    Real images are scored against a target of ones and generated images against a
    target of zeros using ``criterion_GAN``; the two terms are averaged.
    ``fake_images`` is detached so no gradient reaches the generator that made it.
    """
    optimizer.zero_grad()

    scores_real = discriminator(real_images)
    real_term = criterion_GAN(scores_real, torch.ones_like(scores_real))

    scores_fake = discriminator(fake_images.detach())
    fake_term = criterion_GAN(scores_fake, torch.zeros_like(scores_fake))

    loss = (real_term + fake_term) * 0.5
    loss.backward()
    optimizer.step()
    return loss


for epoch in range(4):
    for i, batch in enumerate(train_dataloader):
        real_A, real_B = (images.to(device) for images in batch)

        # ---- Generator update (G_AB and G_BA share optimizer_G) ----
        optimizer_G.zero_grad()

        # Adversarial terms: each generator wants its discriminator to output ones.
        fake_B = G_AB(real_A)
        scores_fake_B = D_B(fake_B)
        loss_GAN_A2B = criterion_GAN(scores_fake_B, torch.ones_like(scores_fake_B))

        fake_A = G_BA(real_B)
        scores_fake_A = D_A(fake_A)
        loss_GAN_B2A = criterion_GAN(scores_fake_A, torch.ones_like(scores_fake_A))

        # Cycle-consistency terms (A -> B -> A and B -> A -> B), weighted by 10.
        loss_cycle_ABA = criterion_cycle(G_BA(fake_B), real_A) * 10.0
        loss_cycle_BAB = criterion_cycle(G_AB(fake_A), real_B) * 10.0

        # Total generator loss
        loss_G = loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
        loss_G.backward()
        optimizer_G.step()

        # ---- Discriminator updates (D_A first, then D_B) ----
        loss_D_A = _discriminator_step(D_A, optimizer_D_A, real_A, fake_A)
        loss_D_B = _discriminator_step(D_B, optimizer_D_B, real_B, fake_B)

        if i % 10 == 0:
            print(f'Epoch: {epoch}, Batch: {i}, Loss_G: {loss_G.item()}, Loss_D_A: {loss_D_A.item()}, Loss_D_B: {loss_D_B.item()}')

    # After each epoch, translate the test-set A images and save them as image grids.
    with torch.no_grad():
        for i, (real_A, real_B) in enumerate(test_dataloader):
            real_A = real_A.to(device)
            fake_B = G_AB(real_A)
            vutils.save_image(fake_B, f'{output_dir}/fake_B_epoch_{epoch}_batch_{i}.png', normalize=True)
        


Epoch: 0, Batch: 0, Loss_G: 15.61037826538086, Loss_D_A: 1.1615573167800903, Loss_D_B: 0.9101475477218628
Epoch: 0, Batch: 10, Loss_G: 6.970175743103027, Loss_D_A: 0.3666701912879944, Loss_D_B: 0.24359501898288727
Epoch: 0, Batch: 20, Loss_G: 5.965689659118652, Loss_D_A: 0.24649454653263092, Loss_D_B: 0.24288181960582733
Epoch: 0, Batch: 30, Loss_G: 5.182547569274902, Loss_D_A: 0.24978424608707428, Loss_D_B: 0.2856913208961487
Epoch: 0, Batch: 40, Loss_G: 7.472266674041748, Loss_D_A: 0.3758556544780731, Loss_D_B: 0.32502102851867676
Epoch: 0, Batch: 50, Loss_G: 7.623085975646973, Loss_D_A: 0.23080357909202576, Loss_D_B: 0.24449694156646729
Epoch: 0, Batch: 60, Loss_G: 6.9660868644714355, Loss_D_A: 0.26649582386016846, Loss_D_B: 0.23975417017936707
Epoch: 0, Batch: 70, Loss_G: 5.563512325286865, Loss_D_A: 0.24589692056179047, Loss_D_B: 0.2512208819389343
Epoch: 0, Batch: 80, Loss_G: 6.209434986114502, Loss_D_A: 0.30268868803977966, Loss_D_B: 0.2863529324531555


### GAN Inversion

Train the model to convert A to A (i.e. to reconstruct its input). Interpolation can then be used to generate intermediate images.

Open question: could a contrastive loss be used to help the model learn?

In [24]:
class ImageDataset(Dataset):
    """Single-folder variant of the image dataset, reading ``<root>/combined``.

    This definition replaces any earlier ``ImageDataset`` in the notebook. Both the
    A and B file lists come from the same ``combined`` directory, and ``mode`` is
    accepted only for signature compatibility: it is ignored, so a "train" and a
    "test" instance see the same files.

    Args:
        root: Directory containing the ``combined`` sub-directory.
        transforms_: Callable applied to each PIL image. When ``None`` the RGB
            PIL images are returned unchanged.
        unaligned: If true, the second image is drawn uniformly at random
            instead of by index.
        mode: Unused.

    Each item is a tuple ``(item_A, item_B)``.
    """

    def __init__(self, root='data', transforms_=None, unaligned=False, mode="train"):
        self.transform = transforms_
        self.unaligned = unaligned

        # Only entries matching "*.*" are picked up; sorting keeps indexing deterministic.
        combined_files = sorted(glob.glob(os.path.join(root, 'combined') + "/*.*"))
        self.files_A = combined_files
        self.files_B = list(combined_files)

    def _load(self, path):
        """Open ``path``, force 3-channel RGB and apply the transform (if any)."""
        # Grayscale / RGBA / palette files would otherwise produce tensors whose channel
        # count does not match a 3-channel Normalize or the other images in the batch.
        image = Image.open(path).convert("RGB")
        if self.transform is None:
            return image
        return self.transform(image)

    def __getitem__(self, index):
        path_A = self.files_A[index % len(self.files_A)]

        if self.unaligned:
            # Second image is an independent random draw from the same folder.
            path_B = self.files_B[random.randint(0, len(self.files_B) - 1)]
        else:
            path_B = self.files_B[index % len(self.files_B)]

        # A is loaded and transformed before B.
        item_A = self._load(path_A)
        item_B = self._load(path_B)
        return item_A, item_B

    def __len__(self):
        """Length of the larger of the two (here identical) file lists."""
        return max(len(self.files_A), len(self.files_B))


In [25]:
# Loaders for the reconstruction (A -> A) experiment. Change "human_dog" to your own data
# directory. NOTE: the ImageDataset defined for this experiment reads <root>/combined and
# ignores `mode`, so the train and test loaders draw from the same files.
train_dataloader = DataLoader(
    dataset=ImageDataset(root="human_dog", transforms_=transform, unaligned=True, mode="train"),
    batch_size=12,
    shuffle=True,
)
test_dataloader = DataLoader(
    dataset=ImageDataset(root="human_dog", transforms_=transform, unaligned=True, mode="test"),
    batch_size=12,
    shuffle=False,
)


In [26]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# A single generator that maps domain A back to itself; no second generator is built here.
# Fresh discriminators replace any earlier D_A / D_B (D_B is not used by the training loop below).
G_AA = Generator(input_nc=3, output_nc=3).to(device)
D_A, D_B = (Discriminator(input_nc=3).to(device) for _ in range(2))

# Optimizers: one Adam instance per network, all with the same hyper-parameters.
_adam_settings = dict(lr=0.0005, betas=(0.5, 0.999))
optimizer_G = Adam(G_AA.parameters(), **_adam_settings)
optimizer_D_A = Adam(D_A.parameters(), **_adam_settings)
optimizer_D_B = Adam(D_B.parameters(), **_adam_settings)

# Losses: squared error for the adversarial terms, L1 for the image-space terms.
criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()

# Directory that receives the sample images written after every epoch.
output_dir = './cyclegan_images'
os.makedirs(output_dir, exist_ok=True)

In [27]:
# Reconstruction training: G_AA learns to reproduce its input while D_A judges realism.
for epoch in range(5):
    # real_B is produced by the dataset but not used in this experiment, so it is
    # left on the CPU instead of being copied to the device.
    for i, (real_A, real_B) in enumerate(train_dataloader):
        real_A = real_A.to(device)

        # ---- Generator update ----
        optimizer_G.zero_grad()

        # Adversarial term: G_AA wants D_A to output ones on its reconstructions.
        fake_A = G_AA(real_A)
        pred_fake = D_A(fake_A)
        loss_GAN = criterion_GAN(pred_fake, torch.ones_like(pred_fake))

        # Reconstruction term: L1 distance between output and input, weighted by 10.
        loss_recon = criterion_cycle(fake_A, real_A) * 10.0

        # Total generator loss
        loss_G = loss_GAN + loss_recon
        loss_G.backward()
        optimizer_G.step()

        # ---- Discriminator update (D_A) ----
        optimizer_D_A.zero_grad()

        pred_real = D_A(real_A)
        loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

        # detach(): the discriminator step must not push gradients into G_AA.
        pred_fake = D_A(fake_A.detach())
        loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

        # Total discriminator loss
        loss_D_A = (loss_D_real + loss_D_fake) * 0.5
        loss_D_A.backward()
        optimizer_D_A.step()

        if i % 10 == 0:
            print(f'Epoch: {epoch}, Batch: {i}, Loss_G: {loss_G.item()}, Loss_D_A: {loss_D_A.item()}')

    # After each epoch, reconstruct the test images and save them as image grids.
    with torch.no_grad():
        for i, (real_A, real_B) in enumerate(test_dataloader):
            real_A = real_A.to(device)
            fake_A = G_AA(real_A)
            # Saved under a "recon_A" prefix: the previous "fake_B_epoch_*" name collided with
            # (and overwrote) the samples written to the same directory by the CycleGAN loop.
            vutils.save_image(fake_A, f'{output_dir}/recon_A_epoch_{epoch}_batch_{i}.png', normalize=True)
        


Epoch: 0, Batch: 0, Loss_G: 6.2875518798828125, Loss_D_A: 0.6166337728500366
Epoch: 0, Batch: 10, Loss_G: 3.730794906616211, Loss_D_A: 0.24838754534721375
Epoch: 0, Batch: 20, Loss_G: 2.7318222522735596, Loss_D_A: 0.25245800614356995
Epoch: 0, Batch: 30, Loss_G: 2.937140464782715, Loss_D_A: 0.344274640083313
Epoch: 0, Batch: 40, Loss_G: 2.3008639812469482, Loss_D_A: 0.24993270635604858
Epoch: 0, Batch: 50, Loss_G: 2.1108765602111816, Loss_D_A: 0.24796178936958313
Epoch: 0, Batch: 60, Loss_G: 3.3187010288238525, Loss_D_A: 0.2761794924736023
Epoch: 0, Batch: 70, Loss_G: 2.0487430095672607, Loss_D_A: 0.2466328740119934
Epoch: 0, Batch: 80, Loss_G: 2.0849242210388184, Loss_D_A: 0.2562412917613983
Epoch: 0, Batch: 90, Loss_G: 2.538743257522583, Loss_D_A: 0.2563542127609253
Epoch: 0, Batch: 100, Loss_G: 2.0033926963806152, Loss_D_A: 0.24686303734779358
Epoch: 0, Batch: 110, Loss_G: 2.5875730514526367, Loss_D_A: 0.2730735242366791
Epoch: 0, Batch: 120, Loss_G: 2.3883023262023926, Loss_D_A: 0.

In [28]:
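# NOTE: checkpoint saving is currently commented out, although the recorded output of this
# cell ("Model saved at ...") comes from a run in which it was enabled. The later cell that
# loads Generator.ckpt from the models directory needs that file to exist; uncomment these
# lines to (re)create it.
# save_model_path='models'
# checkpoint_path = os.path.join(save_model_path, "Generator.ckpt")
# torch.save(G_AA.state_dict(), checkpoint_path)
# print("Model saved at %s" % checkpoint_path)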

Model saved at models\Generator.ckpt


In [29]:
# Build a fresh, randomly initialised 3-channel -> 3-channel generator and move it to `device`.
G_AA = Generator(3, 3)
G_AA = G_AA.to(device)

In [30]:
# Build the path with os.path.join instead of the literal "models\Generator.ckpt": that
# literal relied on "\G" being an invalid (and therefore preserved) escape sequence, which
# newer Python versions warn about, and it only worked with Windows path separators.
checkpoint_path = os.path.join("models", "Generator.ckpt")
# map_location lets a checkpoint saved on one device type be loaded on another
# (e.g. trained on GPU, loaded on a CPU-only machine).
G_AA.load_state_dict(torch.load(checkpoint_path, map_location=device))
# Switch to inference mode; as the last expression this also displays the module structure.
G_AA.eval()

Generator(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
    (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (4): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (5): ReLU(inplace=True)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (7): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (8): ReLU(inplace=True)
    (9): ResidualBlock(
      (block): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (2): ReLU(inplace=True)
        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1,

In [31]:
def load_image(image_path):
    """Load an image file as a normalised batch-of-one tensor for the generator.

    The image is converted to RGB, resized to 256x256, converted to a tensor and
    normalised with mean 0.5 / std 0.5 per channel. This is the same normalisation
    the training pipeline applies; the ImageNet statistics used previously did not
    match what the model was trained on.

    Args:
        image_path: Path of the image file to open.

    Returns:
        Tensor of shape ``(1, 3, 256, 256)``.
    """
    image = Image.open(image_path).convert('RGB')
    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),  # adjust to the input size your model expects
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    # unsqueeze(0) adds the leading batch dimension.
    return preprocess(image).unsqueeze(0)

In [40]:
def predict(model, image_path1, image_path2=None, steps=10):
    """Run ``model`` on one image, or on pixel-space blends of two images.

    With two paths, ``steps`` blended inputs are built as
    ``image1 * (i / steps) + image2 * (1 - i / steps)`` for ``i = 0 .. steps - 1``.
    ``i = 0`` is pure ``image2``; the pure-``image1`` end point (weight 1.0) is never
    reached. Each output is written to ``'test {i}.png'`` in the current working
    directory.

    With only ``image_path1`` the model is run once on that image and the result is
    written to ``'test 0.png'``. Previously the default ``image_path2=None`` was
    passed straight to ``load_image``.

    Args:
        model: Module to evaluate; it is moved to the global ``device``.
        image_path1: Path of the first image.
        image_path2: Optional path of the second image.
        steps: Number of blended inputs to generate (default 10).

    Returns:
        The model output for the last input that was processed.

    Raises:
        ValueError: If ``steps`` is smaller than 1.
    """
    if steps < 1:
        raise ValueError("steps must be at least 1, got %r" % (steps,))

    image1 = load_image(image_path1).to(device)
    image2 = None if image_path2 is None else load_image(image_path2).to(device)
    model = model.to(device)

    # Inference only: no gradients are needed for any of the forward passes.
    with torch.no_grad():
        if image2 is None:
            output = model(image1)
            vutils.save_image(output, 'test 0.png', normalize=True)
            return output

        for i in range(steps):
            weight = i / steps
            image = image1 * weight + image2 * (1 - weight)
            output = model(image)
            vutils.save_image(output, f'test {i}.png', normalize=True)
    return output

In [41]:
# Build the sample paths with os.path.join so the call is not tied to Windows "\\" separators.
output_image = predict(
    G_AA,
    os.path.join('human_dog', 'testA', '200600.jpg'),
    os.path.join('human_dog', 'testB', 'flickr_dog_000043.jpg'),
)
# Disabled: manual min-max rescale of the returned tensor followed by an inline matplotlib preview.
# output_image = output_image - output_image.min()
# output_image = output_image / output_image.max()

# output_image = output_image.squeeze()  # assuming the output is an image; drop the batch dimension
# output_image = output_image.permute(1,2,0)
# output_image=output_image.to('cpu')
# # Step 5: visualise the output image
# plt.imshow(output_image.numpy())
# plt.title('Output Image')
# plt.show()

In [38]:
# Sanity check on the tensor returned by predict(); the recorded output is torch.Size([1, 3, 256, 256]).
print(output_image.shape)

torch.Size([1, 3, 256, 256])
