In [1]:
import torch
from torch import nn
import torchvision.utils as vutils
from torch.optim import Adam
from torch.utils.data import DataLoader,Dataset
import glob
from torchvision import transforms, datasets
from torchvision.utils import save_image
import os
import random
from PIL import Image
import itertools
import matplotlib.pyplot as plt
from torchvision.models import vgg19, resnet18
from torch.autograd import Variable

In [2]:
class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(in_features)
        )

    def forward(self, x):
        return x + self.block(x)

class Generator(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=6):
        super(Generator, self).__init__()

        # Initial convolution block
        model = [
            nn.Conv2d(input_nc, 64, kernel_size=7, padding=3, bias=False),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        ]

        # Downsampling
        in_features = 64
        out_features = in_features*2
        for _ in range(2):
            model += [
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1, bias=False),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features*2

        # Residual blocks
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Upsampling
        out_features = in_features//2
        for _ in range(2):
            model += [
                nn.ConvTranspose2d(in_features, out_features, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features//2

        # Output layer
        model += [nn.Conv2d(64, output_nc, kernel_size=7, padding=3), nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


In [3]:
class Discriminator_patch(nn.Module):
    def __init__(self, input_nc):
        super(Discriminator_patch, self).__init__()

        model = [
            nn.Conv2d(input_nc, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        ]

        model += [
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True)
        ]

        model += [
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True)
        ]

        model += [
            nn.Conv2d(256, 512, kernel_size=4, padding=1, bias=False),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True)
        ]

        model += [nn.Conv2d(512, 1, kernel_size=4, padding=1)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)
    
class Discriminator_classify(nn.Module):
    def __init__(self, input_nc):
        super(Discriminator_classify, self).__init__()

        # 之前的卷积层保持不变
        self.model = nn.Sequential(
            nn.Conv2d(input_nc, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True)
        )

        # 添加一个全局平均池化层
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        # 最后一个卷积层，将特征图压缩为1个值
        self.final_conv = nn.Conv2d(512, 1, kernel_size=1)

        # 选择性添加，如果您需要输出概率
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.model(x)
        x = self.global_avg_pool(x)
        x = self.final_conv(x)
        x = self.sigmoid(x)  # 如果您需要输出概率
        return x.view(-1)  # 改变输出形状以匹配期望的输出形

In [4]:
transform = transforms.Compose([
    transforms.Resize(int(128*1.12), Image.BICUBIC),
    transforms.RandomCrop(128),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])


In [5]:
class ImageDataset(Dataset):
    def __init__(self, root='data', transforms_=None, unaligned=False, mode="train"):          ## (root = "./datasets/facades", unaligned=True:非对其数据)
        self.transform = transforms_                             ## transform变为tensor数据
        self.unaligned = unaligned

        self.files_A = sorted(glob.glob(os.path.join(root, "%sA" % mode) + "/*.*"))     ## "./datasets/facades/trainA/*.*"
        self.files_B = sorted(glob.glob(os.path.join(root, "%sB" % mode) + "/*.*"))     ## "./datasets/facades/trainB/*.*"

    def __getitem__(self, index):
        image_A = Image.open(self.files_A[index % len(self.files_A)])                   ## 在A中取一张照片

        if self.unaligned:                                                              ## 如果采用非配对数据，在B中随机取一张
            image_B = Image.open(self.files_B[random.randint(0, len(self.files_B) - 1)])
        else:
            image_B = Image.open(self.files_B[index % len(self.files_B)])

        # # 如果是灰度图，把灰度图转换为RGB图
        # if image_A.mode != "RGB":
        #     image_A = to_rgb(image_A)
        # if image_B.mode != "RGB":
        #     image_B = to_rgb(image_B)
        
        # 把RGB图像转换为tensor图, 方便计算，返回字典数据
        item_A = self.transform(image_A)
        item_B = self.transform(image_B)
        return item_A, item_B

    ## 获取A,B数据的长度
    def __len__(self):
        return max(len(self.files_A), len(self.files_B))


In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

G_AB = Generator(input_nc=3, output_nc=3).to(device)
G_BA = Generator(input_nc=3, output_nc=3).to(device)
D_A = Discriminator_classify(input_nc=3).to(device)
D_B = Discriminator_classify(input_nc=3).to(device)

# Optimizers
optimizer_G = Adam(itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=0.0005, betas=(0.5, 0.999))
optimizer_D_A = Adam(D_A.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizer_D_B = Adam(D_B.parameters(), lr=0.0005, betas=(0.5, 0.999))

# Losses
criterion_GAN = nn.BCELoss()
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()


output_dir = './cyclegan_images'
os.makedirs(output_dir, exist_ok=True)

### GAN Inversion

Train the model to convert A to A. Then interpolation can be used to generate intermediate image. 

Using contrastive loss to make the model to learn?

In [7]:
class ImageDataset(Dataset):
    def __init__(self, root='data', transforms_=None, unaligned=False, mode="train"):          ## (root = "./datasets/facades", unaligned=True:非对其数据)
        self.transform = transforms_                             ## transform变为tensor数据
        self.unaligned = unaligned

        self.files_A = sorted(glob.glob(os.path.join(root, '%sA' % mode) + "/*.*"))     ## "./datasets/facades/trainA/*.*"
        self.files_B = sorted(glob.glob(os.path.join(root, '%sB' % mode) + "/*.*"))     ## "./datasets/facades/trainB/*.*"

    def __getitem__(self, index):
        image_A = Image.open(self.files_A[index % len(self.files_A)])                   ## 在A中取一张照片

        if self.unaligned:                                                              ## 如果采用非配对数据，在B中随机取一张
            image_B = Image.open(self.files_B[random.randint(0, len(self.files_B) - 1)])
        else:
            image_B = Image.open(self.files_B[index % len(self.files_B)])

        # # 如果是灰度图，把灰度图转换为RGB图
        # if image_A.mode != "RGB":
        #     image_A = to_rgb(image_A)
        # if image_B.mode != "RGB":
        #     image_B = to_rgb(image_B)
        
        # 把RGB图像转换为tensor图, 方便计算，返回字典数据
        item_A = self.transform(image_A)
        item_B = self.transform(image_B)
        return item_A, item_B

    ## 获取A,B数据的长度
    def __len__(self):
        return max(len(self.files_A), len(self.files_B))


In [8]:
BATCH_SIZE = 16
train_dataloader = DataLoader(        ## 改成自己存放文件的目录
    ImageDataset("human_dog_colab", transforms_=transform, unaligned=True, mode="train"),  ## "./datasets/facades" , unaligned:设置非对其数据
    batch_size=BATCH_SIZE,                                                                  ## batch_size = 1
    shuffle=True,
)
test_dataloader = DataLoader(        ## 改成自己存放文件的目录
    ImageDataset("human_dog_colab", transforms_=transform, unaligned=True, mode="test"),  ## "./datasets/facades" , unaligned:设置非对其数据
    batch_size=BATCH_SIZE,                                                                  ## batch_size = 1
    shuffle=False,
)


In [9]:
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, bias=False)
        self.norm1 = nn.InstanceNorm2d(in_features)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, bias=False)
        self.norm2 = nn.InstanceNorm2d(in_features)

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.norm2(out)
        return identity + out

class DownsampleBlock(nn.Module):
    def __init__(self, in_features, out_features):
        super(DownsampleBlock, self).__init__()
        self.conv = nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1, bias=False)
        self.norm = nn.InstanceNorm2d(out_features)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        x = self.relu(x)
        return x

class UpsampleBlock(nn.Module):
    def __init__(self, in_features, out_features):
        super(UpsampleBlock, self).__init__()
        self.conv = nn.ConvTranspose2d(in_features, out_features, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False)
        self.norm = nn.InstanceNorm2d(out_features)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        x = self.relu(x)
        return x

class Generator(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        super(Generator, self).__init__()
        self.initial_conv = nn.Sequential(
            nn.Conv2d(input_nc, 64, kernel_size=7, padding=3, bias=False),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.downsample_blocks = nn.ModuleList([
            DownsampleBlock(64, 128),
            DownsampleBlock(128, 256)
        ])
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(256) for _ in range(n_residual_blocks)]
        )
        self.upsample_blocks = nn.ModuleList([
            UpsampleBlock(256, 128),
            UpsampleBlock(128, 64)
        ])
        self.output_conv = nn.Sequential(
            nn.Conv2d(64, output_nc, kernel_size=7, padding=3),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.initial_conv(x)
        # print("1",x.shape)
        for down_block in self.downsample_blocks:
            x = down_block(x)
        # print("2",x.shape)
        x = self.residual_blocks(x)
        # print("3",x.shape)
        for up_block in self.upsample_blocks:
            x = up_block(x)
        # print("4",x.shape)
        x = self.output_conv(x)
        # print("5",x.shape)
        return x


In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

G_AB = Generator(input_nc=3, output_nc=3).to(device)
G_BA = Generator(input_nc=3, output_nc=3).to(device)
D_A = Discriminator_classify(input_nc=3).to(device)
D_B = Discriminator_classify(input_nc=3).to(device)

D_A_P = Discriminator_patch(input_nc=3).to(device)
D_B_P = Discriminator_patch(input_nc=3).to(device)

# Optimizers
optimizer_G = Adam(G_AB.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizer_D_A = Adam(D_A.parameters(), lr=0.0004, betas=(0.5, 0.999))
optimizer_D_B = Adam(D_B.parameters(), lr=0.0004, betas=(0.5, 0.999))

optimizer_D_A_P = Adam(D_A_P.parameters(), lr=0.0004, betas=(0.5, 0.999))
optimizer_D_B_P = Adam(D_B_P.parameters(), lr=0.0004, betas=(0.5, 0.999))

G_AB.train()
G_BA.train()
D_A.train()
D_B.train()
D_A_P.train()
D_B_P.train()

# Losses
criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()


output_dir = './cyclegan_images'
os.makedirs(output_dir, exist_ok=True)

In [11]:
# 加载ResNet模型
resnet = resnet18(pretrained=True)

# 获取全连接层之前的特征提取部分
features = nn.Sequential(*list(resnet.children())[:-1])

# 定义新的全连接层和ReLU激活函数
num_ftrs = resnet.fc.in_features
fc_layer = nn.Linear(num_ftrs, 256)
relu = nn.ReLU(inplace=True)

# 定义模型结构
class CustomResNet(nn.Module):
    def __init__(self, features, fc_layer, relu, num_classes=1000):
        super(CustomResNet, self).__init__()
        self.features = features
        self.fc_layer = fc_layer
        self.relu = relu
        self.fc_out = nn.Linear(256, num_classes)  # 输出层

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.fc_layer(x)
        x = self.relu(x)
        x = self.fc_out(x)
        return x



In [12]:
resnet = resnet18(pretrained=True)
# num_ftrs = resnet.fc.in_features
# resnet.fc = nn.Linear(num_ftrs, 256)
# resnet.fc = nn.Linear(32168, 65536)
# resnet = torch.load("models\\animal_rec.pth")
# resnet.fc_out = torch.nn.Identity()
# resnet.relu = torch.nn.Identity()
resnet.fc = torch.nn.Identity()
resnet = resnet.to(device)
resnet.eval()
print(resnet)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [13]:


class PerceptualLoss(nn.Module):
    def __init__(self):
        super(PerceptualLoss, self).__init__()
        # self.vgg = vgg19(pretrained=True).features[:21]  # 只使用到第三个池化层之前的部分
        self.resnet = resnet
        self.loss = nn.MSELoss()

        # 冻结VGG参数
        for param in self.resnet.parameters():
            param.requires_grad = False

    def forward(self, generated, target):

        self.resnet = self.resnet.to(device)
        gen_features = self.resnet(generated)

        target_features = self.resnet(target)

        return self.loss(gen_features, target_features)

# 实例化感知损失
perceptual_loss = PerceptualLoss()


In [14]:
# resnet1 = torch.load("models\\my_animal_rec.pth")
# resnet1 = resnet1.to(device)
# resnet1.fc_out = torch.nn.Identity()
# resnet1.flatten = torch.nn.Identity()
# resnet1.fc_outpout = torch.nn.Identity()
# resnet1.average_pool1 = torch.nn.Identity()
# resnet1.average_pool2 = torch.nn.Identity()

# resnet1.eval()
# print(resnet1)

In [15]:
for epoch in range(1):
    for i, (real_A, real_B) in enumerate(train_dataloader):
        # print(real_A.shape)
        real_A = real_A.to(device)
        real_B = real_B.to(device)
        valid = torch.ones(real_A.size(0), 1, device=device)
        fake = torch.zeros(real_A.size(0), 1, device=device)

        # 训练生成器 G_A 和 G_B
        optimizer_G.zero_grad()
        
        loss_id_A = criterion_identity(G_BA(real_A), real_A)
        loss_id_B = criterion_identity(G_AB(real_B), real_B)
        
        # 对抗性损失
        fake_B = G_AB(real_A)
        pred_fake = D_B(fake_B)
        # pred_fake_patch = D_B_P(fake_B)
        loss_GAN_A2B = criterion_GAN(pred_fake, torch.ones_like(pred_fake))
        # loss_GAN_A2B_Patch = criterion_GAN(pred_fake_patch, torch.ones_like(pred_fake_patch))
        
        fake_A = G_BA(real_B)
        pred_fake = D_A(fake_A)
        # pred_fake_patch = D_A_P(fake_A)
        loss_GAN_B2A = criterion_GAN(pred_fake, torch.ones_like(pred_fake))
        # loss_GAN_B2A_Patch = criterion_GAN(pred_fake_patch, torch.ones_like(pred_fake_patch))
        
        # contain domain B
        # generated_B = G_AB(real_B)
        # loss_cycle_BB = criterion_cycle(generated_B, real_B)
        
        
        # 循环一致性损失
        recovered_A = G_BA(fake_B)
        loss_cycle_ABA = criterion_cycle(recovered_A, real_A) * 10.0
        loss_perceptual_ABA = perceptual_loss(fake_A,real_A) *2
        # 在训练循环中使用颜色一致性损失

                
        recovered_B = G_AB(fake_A)
        loss_cycle_BAB = criterion_cycle(recovered_B, real_B) * 10.0
        loss_perceptual_BAB = perceptual_loss(fake_B,real_B) *2


        # 总损失
        loss_G = loss_id_A+loss_id_B+ loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB +loss_perceptual_ABA+loss_perceptual_BAB
        loss_G.backward()
        optimizer_G.step()

        # 训练判别器 D_A
        optimizer_D_A.zero_grad()

        pred_real = D_A(real_A)
        loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

        pred_fake = D_A(fake_A.detach())
        loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

        # 总损失
        loss_D_A = (loss_D_real + loss_D_fake) * 0.5
        loss_D_A.backward()
        optimizer_D_A.step()

        # 训练判别器 D_B
        optimizer_D_B.zero_grad()

        pred_real = D_B(real_B)
        loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

        pred_fake = D_B(fake_B.detach())
        loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

        # 总损失
        loss_D_B = (loss_D_real + loss_D_fake) * 0.5
        loss_D_B.backward()
        optimizer_D_B.step()
        
        # # 训练判别器 D_B
        # optimizer_D_B_P.zero_grad()

        # pred_real = D_B_P(real_B)
        # loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

        # pred_fake = D_B_P(fake_B.detach())
        # loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

        # 总损失
        # loss_D_B_P = (loss_D_real + loss_D_fake) * 0.5
        # loss_D_B_P.backward()
        # optimizer_D_B_P.step()
        
        # # 训练判别器 D_B
        # optimizer_D_A_P.zero_grad()

        # pred_real = D_A_P(real_A)
        # loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

        # pred_fake = D_A_P(fake_A.detach())
        # loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

        # # 总损失
        # loss_D_A_P = (loss_D_real + loss_D_fake) * 0.5
        # loss_D_A_P.backward()
        # optimizer_D_A_P.step()
        
        if i % 10 == 0:
            print(f'Epoch: {epoch}, Batch: {i}, Sample: {i*BATCH_SIZE}, Loss_G: {loss_G.item()}, Loss_D_A: {loss_D_A.item()}, Loss_D_B: {loss_D_B.item()}')

    # with torch.no_grad():
    #     # 使用测试集中的数据生成图像
    #     for i, (real_A, real_B) in enumerate(test_dataloader):
    #         real_A = real_A.to(device)
    #         fake_B = G_AB(real_A)
    #         vutils.save_image(fake_B, f'{output_dir}/fake_B_epoch_{epoch}_batch_{i}.png', normalize=True)
        


Epoch: 0, Batch: 0, Sample: 0, Loss_G: 25.777156829833984, Loss_D_A: 0.2497054934501648, Loss_D_B: 0.25125768780708313
Epoch: 0, Batch: 10, Sample: 160, Loss_G: 23.995515823364258, Loss_D_A: 0.19202102720737457, Loss_D_B: 0.17317256331443787
Epoch: 0, Batch: 20, Sample: 320, Loss_G: 22.434913635253906, Loss_D_A: 0.09354539215564728, Loss_D_B: 0.16209763288497925
Epoch: 0, Batch: 30, Sample: 480, Loss_G: 22.679304122924805, Loss_D_A: 0.060604073107242584, Loss_D_B: 0.1714388132095337
Epoch: 0, Batch: 40, Sample: 640, Loss_G: 22.715530395507812, Loss_D_A: 0.04219002276659012, Loss_D_B: 0.175270676612854
Epoch: 0, Batch: 50, Sample: 800, Loss_G: 21.981338500976562, Loss_D_A: 0.02990548312664032, Loss_D_B: 0.1674671769142151
Epoch: 0, Batch: 60, Sample: 960, Loss_G: 22.20755386352539, Loss_D_A: 0.02948971465229988, Loss_D_B: 0.14459393918514252
Epoch: 0, Batch: 70, Sample: 1120, Loss_G: 21.58616828918457, Loss_D_A: 0.017695844173431396, Loss_D_B: 0.19405043125152588
Epoch: 0, Batch: 80, Sa

KeyboardInterrupt: 

In [None]:
save_model_path='models'
checkpoint_path = os.path.join(save_model_path, "Cycle_GAN_Human2Dog_PerceptualLoss2.ckpt")
# checkpoint_path = os.path.join(save_model_path, "Cycle_GAN_Monet2Photo_PerceptualLoss2.ckpt")
torch.save(G_AB.state_dict(), checkpoint_path)
print("Model saved at %s" % checkpoint_path)

Model saved at models\Cycle_GAN_Human2Dog_PerceptualLoss2.ckpt


In [20]:
G_AB = Generator(input_nc=3, output_nc=3).to(device)

In [21]:
G_AB.load_state_dict(torch.load("models\Cycle_GAN_Human2Dog_PerceptualLoss2.ckpt"))
G_AB.eval()

Generator(
  (initial_conv): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
    (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (2): ReLU(inplace=True)
  )
  (downsample_blocks): ModuleList(
    (0): DownsampleBlock(
      (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (relu): ReLU(inplace=True)
    )
    (1): DownsampleBlock(
      (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (norm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (relu): ReLU(inplace=True)
    )
  )
  (residual_blocks): Sequential(
    (0): ResidualBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (norm1): InstanceNorm2d(256, eps=

In [None]:
 for epoch in range(20):
        for real_A, real_B in enumerate:
            real_A = real_A.to(device)
            real_B = real_B.to(device)

            # 真实数据标签
            valid = torch.ones(real_A.size(0), 1, device=device)
            fake = torch.zeros(real_A.size(0), 1, device=device)

            # ------------------
            #  训练生成器
            # ------------------
            optimizer_G.zero_grad()

            # 身份损失
            loss_id_A = criterion_identity(G_B(real_A), real_A)
            loss_id_B = criterion_identity(G_A(real_B), real_B)

            # 对抗损失
            fake_B = G_A(real_A)
            loss_G_A = criterion_GAN(D_B(fake_B), valid)
            fake_A = G_B(real_B)
            loss_G_B = criterion_GAN(D_A(fake_A), valid)

            # 循环损失
            recovered_A = G_B(fake_B)
            loss_cycle_A = criterion_cycle(recovered_A, real_A)
            recovered_B = G_A(fake_A)
            loss_cycle_B = criterion_cycle(recovered_B, real_B)

            # 总损失
            loss_G = loss_id_A + loss_id_B + loss_G_A + loss_G_B + 10 * (loss_cycle_A + loss_cycle_B)
            loss_G.backward()
            optimizer_G.step()

            # ------------------
            #  训练判别器 D_A
            # ------------------
            optimizer_D_A.zero_grad()

            loss_real = criterion_GAN(D_A(real_A), valid)
            loss_fake = criterion_GAN(D_A(fake_A.detach()), fake)
            loss_D_A = (loss_real + loss_fake) / 2
            loss_D_A.backward()
            optimizer_D_A.step()

            # ------------------
            #  训练判别器 D_B
            # ------------------
            optimizer_D_B.zero_grad()

            loss_real = criterion_GAN(D_B(real_B), valid)
            loss_fake = criterion_GAN(D_B(fake_B.detach()), fake)
            loss_D_B = (loss_real + loss_fake) / 2
            loss_D_B.backward()
            optimizer_D_B.step()

In [22]:
def load_image(image_path):
    image = Image.open(image_path).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((256, 256)),  # 根据你的模型调整尺寸
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = transform(image).unsqueeze(0)  # 添加批次维度
    return image

In [23]:
def predict(model, image1,image2=None):
    # image1 = load_image(image_path1).to(device)
    # image2 = load_image(image_path2).to(device)

    model= model.to(device)
    for i in range(11):
        with torch.no_grad():  # 不计算梯度
            image=image1*(i/10)+image2*(1-i/10)
            output = model(image)
            vutils.save_image(output, f'test8 {i}.png', normalize=True)
    return output

In [25]:
image1 = load_image('human_dog_colab\\testA\\200601.jpg').to(device)
image2 = load_image('human_dog_colab\\testB\\flickr_dog_000043.jpg').to(device)

# image1_feature = resnet1(image1)
# image1_feature = image1_feature.unsqueeze(1).unsqueeze(1).expand(-1,-1,256,256)
# image1_with_feature = torch.cat([image1,image1_feature],dim=1)

# image2_feature = resnet1(image2)
# image2_feature = image2_feature.unsqueeze(1).unsqueeze(1).expand(-1,-1,256,256)
# image2_with_feature = torch.cat([image2,image2_feature],dim=1)

output_image = predict(G_AB,image1 ,image2)

# output_image = output_image - output_image.min()
# output_image = output_image / output_image.max()

# output_image = output_image.squeeze()  # 假设输出是图像格式，调整通道
# output_image = output_image.permute(1,2,0)
# output_image=output_image.to('cpu')
# # 步骤 5: 可视化输出图像
# plt.imshow(output_image.numpy())
# plt.title('Output Image')
# plt.show()

In [27]:
print(output_image.shape)

torch.Size([1, 3, 256, 256])


### Original Cycle GAN

In [23]:
for epoch in range(4):
    for i, (real_A, real_B) in enumerate(train_dataloader):
        real_A = real_A.to(device)
        real_B = real_B.to(device)

        # 训练生成器 G_A 和 G_B
        optimizer_G.zero_grad()
        
        # 对抗性损失
        fake_B = G_AB(real_A)
        pred_fake = D_B(fake_B)
        loss_GAN_A2B = criterion_GAN(pred_fake, torch.ones_like(pred_fake))
        
        fake_A = G_BA(real_B)
        pred_fake = D_A(fake_A)
        loss_GAN_B2A = criterion_GAN(pred_fake, torch.ones_like(pred_fake))
        
        # 循环一致性损失
        recovered_A = G_BA(fake_B)
        loss_cycle_ABA = criterion_cycle(recovered_A, real_A) * 10.0
        
        recovered_B = G_AB(fake_A)
        loss_cycle_BAB = criterion_cycle(recovered_B, real_B) * 10.0

        # 总损失
        loss_G = loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
        loss_G.backward()
        optimizer_G.step()

        # 训练判别器 D_A
        optimizer_D_A.zero_grad()

        pred_real = D_A(real_A)
        loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

        pred_fake = D_A(fake_A.detach())
        loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

        # 总损失
        loss_D_A = (loss_D_real + loss_D_fake) * 0.5
        loss_D_A.backward()
        optimizer_D_A.step()

        # 训练判别器 D_B
        optimizer_D_B.zero_grad()

        pred_real = D_B(real_B)
        loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

        pred_fake = D_B(fake_B.detach())
        loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

        # 总损失
        loss_D_B = (loss_D_real + loss_D_fake) * 0.5
        loss_D_B.backward()
        optimizer_D_B.step()
        
        if i % 10 == 0:
            print(f'Epoch: {epoch}, Batch: {i}, Loss_G: {loss_G.item()}, Loss_D_A: {loss_D_A.item()}, Loss_D_B: {loss_D_B.item()}')

    with torch.no_grad():
        # 使用测试集中的数据生成图像
        for i, (real_A, real_B) in enumerate(test_dataloader):
            real_A = real_A.to(device)
            fake_B = G_AB(real_A)
            vutils.save_image(fake_B, f'{output_dir}/fake_B_epoch_{epoch}_batch_{i}.png', normalize=True)
        


Epoch: 0, Batch: 0, Loss_G: 15.61037826538086, Loss_D_A: 1.1615573167800903, Loss_D_B: 0.9101475477218628
Epoch: 0, Batch: 10, Loss_G: 6.970175743103027, Loss_D_A: 0.3666701912879944, Loss_D_B: 0.24359501898288727
Epoch: 0, Batch: 20, Loss_G: 5.965689659118652, Loss_D_A: 0.24649454653263092, Loss_D_B: 0.24288181960582733
Epoch: 0, Batch: 30, Loss_G: 5.182547569274902, Loss_D_A: 0.24978424608707428, Loss_D_B: 0.2856913208961487
Epoch: 0, Batch: 40, Loss_G: 7.472266674041748, Loss_D_A: 0.3758556544780731, Loss_D_B: 0.32502102851867676
Epoch: 0, Batch: 50, Loss_G: 7.623085975646973, Loss_D_A: 0.23080357909202576, Loss_D_B: 0.24449694156646729
Epoch: 0, Batch: 60, Loss_G: 6.9660868644714355, Loss_D_A: 0.26649582386016846, Loss_D_B: 0.23975417017936707
Epoch: 0, Batch: 70, Loss_G: 5.563512325286865, Loss_D_A: 0.24589692056179047, Loss_D_B: 0.2512208819389343
Epoch: 0, Batch: 80, Loss_G: 6.209434986114502, Loss_D_A: 0.30268868803977966, Loss_D_B: 0.2863529324531555


In [None]:
for epoch in range(20):
    for i, (real_A, real_B) in enumerate(train_dataloader):
        # print(real_A.shape)
        real_A = real_A.to(device)
        real_B = real_B.to(device)

        # 训练生成器 G_A 和 G_B
        optimizer_G.zero_grad()
        
        # 对抗性损失
        fake_B = G_AB(real_A)
        pred_fake = D_B(fake_B)
        pred_fake_patch = D_B_P(fake_B)
        loss_GAN_A2B = criterion_GAN(pred_fake, torch.ones_like(pred_fake))
        loss_GAN_A2B_Patch = criterion_GAN(pred_fake_patch, torch.ones_like(pred_fake_patch))
        
        fake_A = G_BA(real_B)
        pred_fake = D_A(fake_A)
        pred_fake_patch = D_A_P(fake_A)
        loss_GAN_B2A = criterion_GAN(pred_fake, torch.ones_like(pred_fake))
        loss_GAN_B2A_Patch = criterion_GAN(pred_fake_patch, torch.ones_like(pred_fake_patch))
        
        # contain domain B
        # generated_B = G_AB(real_B)
        # loss_cycle_BB = criterion_cycle(generated_B, real_B)
        
        
        # 循环一致性损失
        recovered_A = G_BA(fake_B)
        loss_cycle_ABA = criterion_cycle(recovered_A, real_A) * 10.0
        loss_perceptual_ABA = perceptual_loss(fake_A,real_A) *2
        # 在训练循环中使用颜色一致性损失

                
        recovered_B = G_AB(fake_A)
        loss_cycle_BAB = criterion_cycle(recovered_B, real_B) * 10.0
        loss_perceptual_BAB = perceptual_loss(fake_B,real_B) *2


        # 总损失
        loss_G = loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB +loss_GAN_A2B_Patch+loss_GAN_B2A_Patch+loss_perceptual_ABA+loss_perceptual_BAB
        loss_G.backward()
        optimizer_G.step()

        # 训练判别器 D_A
        optimizer_D_A.zero_grad()

        pred_real = D_A(real_A)
        loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

        pred_fake = D_A(fake_A.detach())
        loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

        # 总损失
        loss_D_A = (loss_D_real + loss_D_fake) * 0.5
        loss_D_A.backward()
        optimizer_D_A.step()

        # 训练判别器 D_B
        optimizer_D_B.zero_grad()

        pred_real = D_B(real_B)
        loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

        pred_fake = D_B(fake_B.detach())
        loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

        # 总损失
        loss_D_B = (loss_D_real + loss_D_fake) * 0.5
        loss_D_B.backward()
        optimizer_D_B.step()
        
        # 训练判别器 D_B
        optimizer_D_B_P.zero_grad()

        pred_real = D_B_P(real_B)
        loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

        pred_fake = D_B_P(fake_B.detach())
        loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

        # 总损失
        loss_D_B_P = (loss_D_real + loss_D_fake) * 0.5
        loss_D_B_P.backward()
        optimizer_D_B_P.step()
        
        # 训练判别器 D_B
        optimizer_D_A_P.zero_grad()

        pred_real = D_A_P(real_A)
        loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

        pred_fake = D_A_P(fake_A.detach())
        loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

        # 总损失
        loss_D_A_P = (loss_D_real + loss_D_fake) * 0.5
        loss_D_A_P.backward()
        optimizer_D_A_P.step()
        
        if i % 10 == 0:
            print(f'Epoch: {epoch}, Batch: {i}, Sample: {i*BATCH_SIZE}, Loss_G: {loss_G.item()}, Loss_D_A: {loss_D_A.item()}, Loss_D_B: {loss_D_B.item()}')

    # with torch.no_grad():
    #     # 使用测试集中的数据生成图像
    #     for i, (real_A, real_B) in enumerate(test_dataloader):
    #         real_A = real_A.to(device)
    #         fake_B = G_AB(real_A)
    #         vutils.save_image(fake_B, f'{output_dir}/fake_B_epoch_{epoch}_batch_{i}.png', normalize=True)
        
