In [1]:
import torch,torchvision
import torch.nn as nn
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from sklearn.preprocessing import LabelBinarizer
import random,numpy.random
import os
from torchvision.utils import save_image

In [2]:
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.n1 = nn.Sequential(nn.ConvTranspose2d(110,512,4,1,0,bias=False), 
                                 nn.BatchNorm2d(512), 
                                 nn.LeakyReLU(0.2,inplace=True))
        self.n2=nn.Sequential(nn.ConvTranspose2d(512,256,4,2,1,bias=False), 
                                 nn.BatchNorm2d(256), 
                                 nn.LeakyReLU(0.2,inplace=True))
        self.n3=nn.Sequential(nn.ConvTranspose2d(256,256,4,2,1,bias=False), 
                                 nn.BatchNorm2d(256), 
                                 nn.LeakyReLU(0.2,inplace=True))
        self.n4=nn.Sequential(nn.ConvTranspose2d(256,128,4,2,1,bias=False), 
                                 nn.BatchNorm2d(128), 
                                 nn.LeakyReLU(0.2,inplace=True))
        self.n5=nn.Sequential(nn.ConvTranspose2d(128,128,4,2,1,bias=False), 
                                 nn.BatchNorm2d(128), 
                                 nn.LeakyReLU(0.2,inplace=True))
        self.n6=nn.ConvTranspose2d(128,3,4,2,1,bias=False)
        self.n7=nn.Tanh()
    def forward(self, noise,label):
        x = torch.cat((noise, label),dim=1)  #将标签与数据拼接 (N,channels,128,128),(N,n_classes, 128,128)->(N,channels+nc_classes,128,128)
        x=self.n1(x)
        x=self.n2(x)
        x=self.n3(x)
        x=self.n4(x)
        x=self.n5(x)
        x=self.n6(x)
        x=self.n7(x)
        return x

In [3]:
# 鉴别器结构
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()

        self.n1=nn.Sequential(nn.Conv2d(13,128, 4,2,1, bias=False), 
                              nn.BatchNorm2d(128),
                              nn.LeakyReLU(0.2,inplace=True))
        self.n2=nn.Sequential(nn.Conv2d(128,128, 4,2,1, bias=False), 
                              nn.BatchNorm2d(128),
                              nn.LeakyReLU(0.2,inplace=True))
        self.n3=nn.Sequential(nn.Conv2d(128,256, 4,2,1, bias=False), 
                              nn.BatchNorm2d(256),
                              nn.LeakyReLU(0.2,inplace=True))
        self.n4=nn.Sequential(nn.Conv2d(256,256, 4,2,1, bias=False), 
                              nn.BatchNorm2d(256),
                              nn.LeakyReLU(0.2,inplace=True))
        self.n5=nn.Sequential(nn.Conv2d(256,512, 4,2,1, bias=False), 
                              nn.BatchNorm2d(512),
                              nn.LeakyReLU(0.2,inplace=True))
        self.n6=nn.Sequential(nn.Conv2d(512,1, 4,1,0, bias=False))
        self.n7=nn.Flatten()     #(N,1)
        self.n8=nn.Sigmoid()
  
    def forward(self, img, label):
        x = torch.cat((img, label),dim=1)
        x=self.n1(x)
        x=self.n2(x)
        x=self.n3(x)
        x=self.n4(x)
        x=self.n5(x)
        x=self.n6(x)
        x=self.n7(x)
        x=self.n8(x)
        return x




In [4]:
#加载数据集
my_transform = transforms.Compose([
        transforms.Resize((128,128)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

In [5]:
train_dataset = torchvision.datasets.CIFAR10(root='./Data', train=False, download=True,transform=my_transform)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./Data/cifar-10-python.tar.gz


100%|████████████████████████████████████████████| 170498071/170498071 [00:18<00:00, 9124687.65it/s]


Extracting ./Data/cifar-10-python.tar.gz to ./Data


In [6]:
batch_size=64
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [7]:

#将标签进行one-hot编码
def to_categrical(y: torch.FloatTensor):
    lb = LabelBinarizer()
    lb.fit(list(range(0,10)))
    y_one_hot = lb.transform(y.cpu())
    floatTensor = torch.FloatTensor(y_one_hot)
    return floatTensor
 

In [8]:
def trainer(batch, netD, netG, optimizerD, optimizerG, loss_func, device):
    # 将模型参数设为训练模式
    netD.train()  # 训练中需要求模型参数的梯度，所有参数都处于可训练模式
    netG.train()
    # 从batch中获取输入数据和标签(不一定有标签)
    x, y = batch
    # 将数据存入对应设备中
    x = x.to(device)
    y = y.to(device)
    
    target = to_categrical(y).unsqueeze(2).unsqueeze(3).float()  #加到噪声上 torch.Size([N, n_classes, 1, 1])
    target = target.to(device)
    label = target.repeat(1, 1, x.size(2), x.size(3))   #加到数据上(N,n_classes,128,128)
    label = label.to(device)
    label_r = torch.full((x.size(0),1), 1.0) # 按照shape，创建一模一样的向量
    label_r=label_r.to(device)
 
    #（1）训练判别器 
    #training real data
    netD.zero_grad()
    output1 = netD(x,label) #将标签与数据拼接 (N,channels,128,128),(N,n_classes, 128,128)->(N,channels+nc_classes,128,128)
    
    loss_D1 = loss_func(output1, label_r)
    
        
    #training fake data,拼接噪声和标签
    noise_z = torch.randn(x.size(0), 100, 1, 1) # (N,噪声向量维度100,1,1)
    noise_z = noise_z.to(device)
    
    fake_data = netG(noise_z,target) # 假数据来自噪声

    label_f = torch.full((x.size(0),1), 0.0) # (N,1)
    label_f = label_f.to(device)
    
    output2 = netD(fake_data.detach(),label) # (N,1)
    loss_D2 = loss_func(output2, label_f)
    
    loss_D=loss_D1+loss_D2
    loss_D.backward()
    
    #更新判别器
    optimizerD.step()
        
    #（2）训练生成器，首先清空梯度
    netG.zero_grad()
    output2 = netD(fake_data,label)   
    loss_G = loss_func(output2, label_r)  # 像真实图像靠近
    loss_G.backward()
        
    #更新生成器
    optimizerG.step()
        
    # 计算准确率
    correct_predictions1 = (output1>0.5).sum().item()  # 统计在正确数据中预测正确的数量
    correct_predictions2 = (output2<0.5).sum().item()  # 统计在假数据中预测正确的数量
    correct_predictions=correct_predictions1+correct_predictions2
    return loss_D.item() / y.shape[0], loss_G.item()/y.shape[0], correct_predictions

In [9]:
def generate_fakeimg(netG,target_label):
    sample_dir = "fakeimg2"
    # 创建生成图像的目录
    if not os.path.exists(sample_dir):
        os.makedirs(sample_dir)
    noise_z1 = torch.randn(64, 100, 1, 1).cuda()
    label = to_categrical(torch.full((64,1), target_label)).unsqueeze(2).unsqueeze(3).float() #将标签编码
    label=label.cuda()
    fake_data = netG(noise_z1,label)            
    #保存图片
    data = fake_data.detach().permute(0,2,3,1) # 通道数放最后
    data = data.cpu()
    data = np.array(data)
    #保存单张图片，将数据还原
    data = (data*0.5+0.5) # 缩放，否则某些像素点可能会小于0
    plt.imsave('./fakeimg2.png', data[0])
    torchvision.utils.save_image(fake_data[:64]*0.5+0.5,'./fakeimg2/class_%d.png'%target_label,nrow=8,normalize=True)
            

In [10]:
def save_model(netG,netD):
    sample_dir = "model_cGAN"
    # 创建生成图像的目录
    if not os.path.exists(sample_dir):
        os.makedirs(sample_dir)
    state = {'net_G': netG.state_dict(),'net_D': netD.state_dict()}
    torch.save(state, './model_cGAN/net.pth')
    

In [11]:
if __name__ == '__main__':
    netG = Generator().cuda()
    netD = Discriminator().cuda()
    loss_func = torch.nn.BCELoss()
    device = "cuda"
    # setup optimizer
    # optimizerD = torch.optim.Adam(netD.parameters(), lr=0.0002,betas=(0.5, 0.999))
    optimizerG = torch.optim.Adam(netG.parameters(), lr=0.0002,betas=(0.5,0.999))
    optimizerD = torch.optim.RMSprop(netD.parameters(),
                    lr=0.0002,
                    alpha=0.99,
                    eps=1e-08,
                    weight_decay=0,
                    momentum=0,)
    total_ep = 30
    for ep in range(total_ep):
        total_lossG = 0.0
        total_lossD = 0.0
        total_correct = 0
        total_samples = 0
        for batch in train_loader:
            loss_D, loss_G, correct_predictions = trainer(batch, netD, netG, optimizerD, optimizerG, loss_func, device)
            total_lossD += loss_D
            total_lossG += loss_G
            total_correct += correct_predictions
            total_samples += batch[1].shape[0]
            # print(total_lossD,total_lossG)
        average_lossD = total_lossD / len(train_loader)
        average_lossG = total_lossG / len(train_loader)
        accuracy = total_correct / (2*total_samples)
        print(f"Epoch: {ep} Training Loss: {average_lossD:.4f} | {average_lossG:.4f} ｜Accuracy: {accuracy * 100:.2f}%")
    generate_fakeimg(netG,5)   # 生成class5的图片
    save_model(netG,netD)      # 保存模型参数
        

Epoch: 0 Training Loss: 0.0221 | 0.0194 ｜Accuracy: 68.92%
Epoch: 1 Training Loss: 0.0206 | 0.0188 ｜Accuracy: 71.11%
Epoch: 2 Training Loss: 0.0209 | 0.0189 ｜Accuracy: 72.33%
Epoch: 3 Training Loss: 0.0206 | 0.0196 ｜Accuracy: 71.60%
Epoch: 4 Training Loss: 0.0211 | 0.0180 ｜Accuracy: 70.20%
Epoch: 5 Training Loss: 0.0208 | 0.0203 ｜Accuracy: 72.99%
Epoch: 6 Training Loss: 0.0201 | 0.0225 ｜Accuracy: 74.90%
Epoch: 7 Training Loss: 0.0200 | 0.0249 ｜Accuracy: 76.42%
Epoch: 8 Training Loss: 0.0191 | 0.0263 ｜Accuracy: 76.50%
Epoch: 9 Training Loss: 0.0191 | 0.0250 ｜Accuracy: 74.99%
Epoch: 10 Training Loss: 0.0194 | 0.0247 ｜Accuracy: 74.91%
Epoch: 11 Training Loss: 0.0191 | 0.0244 ｜Accuracy: 74.44%
Epoch: 12 Training Loss: 0.0184 | 0.0275 ｜Accuracy: 76.50%
Epoch: 13 Training Loss: 0.0181 | 0.0275 ｜Accuracy: 76.49%
Epoch: 14 Training Loss: 0.0158 | 0.0354 ｜Accuracy: 82.06%
Epoch: 15 Training Loss: 0.0146 | 0.0455 ｜Accuracy: 86.94%
Epoch: 16 Training Loss: 0.0128 | 0.0481 ｜Accuracy: 87.92%
Epoch: 