# 对一个数字图片进行GAN训练

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
import random
from torchvision import transforms,datasets

## 定义超参数

In [None]:
# Hyperparameters for the single-image GAN experiment.
batch_size = 64          # batch size for the MNIST DataLoader
n_epoch = 1000           # number of outer training epochs
lr = 0.0002              # Adam learning rate shared by G and D
tar_domain = './data'    # root directory for the MNIST download
nlen = 100               # length of the generator's noise vector
kstep = 5                # inner D/G update steps per epoch

## 加载数据

In [None]:
def load_data(tar_domain, batch_size):
    """Build a shuffled DataLoader over the MNIST training split.

    Images are converted to tensors and normalised with mean 0.5 / std 0.5,
    i.e. pixel values in [0, 1] become values in [-1, 1], the same range as
    the generator's Tanh output.

    Args:
        tar_domain: root directory where MNIST is stored / downloaded.
        batch_size: number of samples per batch.

    Returns:
        A ``torch.utils.data.DataLoader`` that keeps the final partial batch.
    """
    pipeline = transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])
    # Alternative source: datasets.ImageFolder(root=tar_domain, transform=pipeline)
    mnist_train = datasets.MNIST(root=tar_domain, train=True, transform=pipeline, download=True)
    return torch.utils.data.DataLoader(
        mnist_train, batch_size=batch_size, shuffle=True, drop_last=False
    )

In [None]:
# Seed every random number generator in use so repeated runs give the same results.
SEED = 0
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(SEED)
# Ask cuDNN for deterministic kernels and turn off its benchmark autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

targetDataset = load_data(tar_domain, batch_size)

In [None]:
# Sanity-check the loader: take one batch and inspect its labels.
imgs, labels = next(iter(targetDataset))
print(labels, labels.shape, sep='\n')

In [None]:
# Display the first six images of the batch in a 2x3 grid, titled with their labels.
fig = plt.figure()
for slot in range(1, 7):
    plt.subplot(2, 3, slot)
    plt.tight_layout()
    plt.imshow(imgs[slot - 1][0], cmap='gray', interpolation='none')
    plt.title(f"Truth value:{labels[slot - 1]}")

## 生成器

In [None]:
class Generator(nn.Module):
    """Fully connected generator: noise vector -> 28x28 image in [-1, 1].

    Args:
        latent_dim: length of the input noise vector (default 100, which
            matches ``nlen``).
    """

    def __init__(self, latent_dim=100):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(latent_dim, 256),
            # Leaky ReLU with slope 0.1, like the discriminator's first
            # activation. Note: ``nn.ReLU(0.1)`` would treat 0.1 as the
            # ``inplace`` flag, not as a slope.
            nn.LeakyReLU(0.1),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 28 * 28),
            nn.Tanh(),  # pixel values in [-1, 1], same range as the normalised data
        )

    def forward(self, x):
        """Map noise rows of length ``latent_dim`` to images of shape (N, 28, 28, 1)."""
        img = self.main(x)
        # NOTE(review): channels-last layout, whereas MNIST batches are
        # (N, 1, 28, 28); the discriminator flattens its input so both work.
        return img.view(-1, 28, 28, 1)

## 判别器

In [None]:
class Discriminator(nn.Module):
    """Fully connected discriminator: 28x28 image -> probability that it is real."""

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(784, 512), nn.LeakyReLU(0.1),
            nn.Linear(512, 256), nn.ReLU(),
            nn.Linear(256, 1), nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Reshape the input into rows of 784 pixels and return (N, 1) probabilities."""
        flat = x.view(-1, 784)
        return self.main(flat)

## 初始化模型、优化器及损失函数

In [None]:
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
G = Generator().to(device)
D = Discriminator().to(device)

Doptimizer = torch.optim.Adam(D.parameters(), lr=lr)
Goptimizer = torch.optim.Adam(G.parameters(), lr=lr)

# Binary cross-entropy on the discriminator's sigmoid outputs.
loss = torch.nn.BCELoss()

# Per-epoch histories filled by the training loop and read by the plotting cells.
D_loss, G_loss, D_prob, G_prob = [], [], [], []

## 训练模型

In [None]:
import time

since = time.time()  # start time for the per-epoch elapsed-time printout

# Single training target: channel 0 of the first image in the sampled batch.
learnTarget = imgs[0, 0]
plt.imshow(learnTarget)
img = learnTarget.to(device)  # copy on the training device; fed to D as the real sample

In [None]:
### Training loop
# Train G and D against the single target image `img`. Each epoch performs
# `kstep` rounds of one discriminator step (real target vs. six generated
# samples) followed by one generator step.
for epoch in range(n_epoch):
    for k in range(kstep):
        # --- Discriminator step: minimise -(log D(x) + log(1 - D(G(z)))) ---
        Doptimizer.zero_grad()
        real = D(img)  # D's probability that the target image is real
        realDloss = loss(real, torch.ones_like(real))  # -log D(x)
        # No retain_graph: this graph is never backpropagated through again,
        # so retaining it only kept its buffers alive.
        realDloss.backward()
        randomnoise = torch.randn(6, nlen, device=device)
        Gimg = G(randomnoise)
        fake = D(Gimg.detach())  # detached: this step must not update G
        fakeDloss = loss(fake, torch.zeros_like(fake))  # -log(1 - D(G(z)))
        fakeDloss.backward()
        Doptimizer.step()
        dloss = realDloss.detach() + fakeDloss.detach()

        # --- Generator step: minimise -log D(G(z)) ---
        Goptimizer.zero_grad()
        fake = D(Gimg)
        gloss = loss(fake, torch.ones_like(fake))
        gloss.backward()
        Goptimizer.step()

    with torch.no_grad():
        # Record the last inner step's values. Everything is detached so the
        # history lists do not keep autograd graphs reachable across epochs.
        D_loss.append(dloss)
        G_loss.append(gloss.detach())
        D_prob.append(real.detach())
        G_prob.append(fake[0].detach())
        time_elapsed = time.time() - since
        print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
        print('epoch', epoch, 'dloss', dloss.item(), 'gloss', gloss.item())
        # Re-render the last noise batch with the current G, moved to the CPU
        # for plotting; show all six samples.
        Geneimg = np.squeeze(G(randomnoise).cpu().numpy())
        fig = plt.figure()
        for i in range(6):
            plt.subplot(2, 3, i + 1)
            plt.tight_layout()
            nimg = (Geneimg[i] + 1) / 2  # map [-1, 1] back to [0, 1] for display
            plt.imshow(nimg, cmap='gray', interpolation='none')
        plt.show()
        # Compare pixel-value histograms of one generated sample and the target.
        showdata = Geneimg[0]
        showda = [showdata.reshape(-1), learnTarget.reshape(-1).numpy()]
        label = ['G(z)', 'x']
        plt.hist(showda, bins=20, label=label)
        plt.legend(loc='upper right')
        plt.show()

## 绘制G,D的损失曲线

In [None]:
# Plot the per-epoch D and G losses. Stack into new names instead of
# rebinding D_loss / G_loss, so this cell can be re-run and the history
# lists stay appendable.
d_loss_curve = torch.stack(D_loss).detach().cpu().numpy()
g_loss_curve = torch.stack(G_loss).detach().cpu().numpy()

plt.plot(d_loss_curve, c='blue')
plt.plot(g_loss_curve, c='orange')
plt.axhline(y=0, ls=":", c='black')
plt.legend(['D_loss', 'G_loss'], loc='upper left')
plt.show()

## 绘制D(x),D(G(z))的概率曲线

In [None]:
# Plot D(x) and D(G(z)) per epoch; 0.5 marks the equilibrium where D cannot
# tell real from fake. Stack into new names instead of rebinding the
# history lists, so this cell can be re-run.
d_prob_curve = torch.stack(D_prob).view(-1).detach().cpu().numpy()
g_prob_curve = torch.stack(G_prob).view(-1).detach().cpu().numpy()

plt.plot(d_prob_curve, c='blue')
plt.plot(g_prob_curve, c='orange')
plt.axhline(y=0.5, ls=":", c='black')
plt.legend(['D_prob', 'G_prob'], loc='center right')
plt.show()

## GPT示例

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os

In [None]:
# Define the generator network
class Generator(nn.Module):
    """MLP generator mapping a latent vector to an image of shape ``img_shape``.

    Args:
        latent_dim: length of the input noise vector.
        img_shape: per-sample output shape, e.g. ``(1, 28, 28)``.
    """

    def __init__(self, latent_dim, img_shape):
        super(Generator, self).__init__()
        # Size the output layer from img_shape instead of hard-coding 784, so
        # the layer and the reshape in forward() always agree.
        n_pixels = torch.Size(img_shape).numel()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, n_pixels),
            nn.Tanh()
        )
        self.img_shape = img_shape

    def forward(self, z):
        """Return images of shape ``(batch, *img_shape)`` with values in [-1, 1]."""
        img = self.model(z)
        img = img.view(img.size(0), *self.img_shape)
        return img

# Define the discriminator network
class Discriminator(nn.Module):
    """MLP discriminator scoring images of shape ``img_shape`` as real (1) / fake (0).

    Args:
        img_shape: per-sample input shape, e.g. ``(1, 28, 28)``.
    """

    def __init__(self, img_shape):
        super(Discriminator, self).__init__()
        # Size the input layer from img_shape instead of hard-coding 784.
        n_pixels = torch.Size(img_shape).numel()
        self.model = nn.Sequential(
            nn.Linear(n_pixels, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )
        self.img_shape = img_shape

    def forward(self, img):
        """Flatten each sample and return a (batch, 1) tensor of probabilities."""
        img = img.view(img.size(0), -1)
        validity = self.model(img)
        return validity

# Hyperparameters
latent_dim = 100          # length of the generator's noise vector
img_shape = (1, 28, 28)   # (channels, height, width) of one MNIST image
lr = 0.0002               # Adam learning rate for both networks
batch_size = 64
epochs = 200

# Load the MNIST dataset
# ToTensor + Normalize(0.5, 0.5) maps pixels to [-1, 1], the generator's Tanh range.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

mnist_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# NOTE(review): num_workers=4 starts worker processes; confirm this works in the
# notebook environment being used.
dataloader = DataLoader(mnist_dataset, batch_size=batch_size, shuffle=True, num_workers = 4)

# Initialize the generator and discriminator
# No .to(device) calls anywhere in this cell: this example trains on the CPU.
generator = Generator(latent_dim, img_shape)
discriminator = Discriminator(img_shape)

# Define loss function and optimizers
adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)

# Training loop: per batch, one generator step followed by one discriminator step.
for epoch in range(epochs):
    for i, (imgs, _) in enumerate(dataloader):
        # Adversarial ground truths
        valid = torch.ones(imgs.size(0), 1)
        fake = torch.zeros(imgs.size(0), 1)

        # Train Generator
        # Non-saturating loss: push D(G(z)) towards the "real" label.
        optimizer_G.zero_grad()
        z = torch.randn(imgs.size(0), latent_dim)
        gen_imgs = generator(z)
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        # Train Discriminator
        # gen_imgs.detach() keeps this update from backpropagating into G.
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # Print progress
        if i % 100 == 0:
            print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(dataloader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")

    # Save generated images
    # Every 10 epochs, write the first 25 samples of the epoch's last batch as a
    # 5x5 grid to image/<epoch>.png.
    if epoch % 10 == 0:
        os.makedirs("image", exist_ok=True)
        save_image(gen_imgs.data[:25], f"image/{epoch}.png", nrow=5, normalize=True)



## 改良版num6

In [None]:
## 优化model架构后的
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os
import matplotlib.pyplot as plt
import random
import numpy as np

In [None]:
# Seed every random number generator in use so repeated runs give the same results.
SEED = 0
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(SEED)
# Ask cuDNN for deterministic kernels and turn off its benchmark autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Define the generator network
class Generator(nn.Module):
    """MLP generator mapping a latent vector to an image of shape ``img_shape``.

    Args:
        latent_dim: length of the input noise vector.
        img_shape: per-sample output shape, e.g. ``(1, 28, 28)``.
    """

    def __init__(self, latent_dim, img_shape):
        super(Generator, self).__init__()
        # Size the output layer from img_shape instead of hard-coding 784, so
        # the layer and the reshape in forward() always agree.
        n_pixels = torch.Size(img_shape).numel()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, n_pixels),
            nn.Tanh()
        )
        self.img_shape = img_shape

    def forward(self, z):
        """Return images of shape ``(batch, *img_shape)`` with values in [-1, 1]."""
        img = self.model(z)
        img = img.view(img.size(0), *self.img_shape)
        return img

# Define the discriminator network
class Discriminator(nn.Module):
    """MLP discriminator scoring images of shape ``img_shape`` as real (1) / fake (0).

    Args:
        img_shape: per-sample input shape, e.g. ``(1, 28, 28)``.
    """

    def __init__(self, img_shape):
        super(Discriminator, self).__init__()
        # Size the input layer from img_shape instead of hard-coding 784.
        n_pixels = torch.Size(img_shape).numel()
        self.model = nn.Sequential(
            nn.Linear(n_pixels, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )
        self.img_shape = img_shape

    def forward(self, img):
        """Flatten each sample and return a (batch, 1) tensor of probabilities."""
        img = img.view(img.size(0), -1)
        validity = self.model(img)
        return validity

In [None]:
# Hyperparameters
latent_dim = 100         # length of the noise vector fed to the generator
img_shape = (1, 28, 28)  # (channels, height, width) of one MNIST image
lr = 2e-4                # Adam learning rate for both networks
batch_size = 64          # images per DataLoader batch
epochs = 300             # you can set 2000

In [None]:
# Load the MNIST training split, normalised to [-1, 1] to match the generator's Tanh.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

mnist_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
dataloader = DataLoader(mnist_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

In [None]:
# Build both networks on the GPU when one is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
generator = Generator(latent_dim, img_shape).to(device)
discriminator = Discriminator(img_shape).to(device)

# BCE on sigmoid probabilities, plus one Adam optimizer per network (G first, then D).
adversarial_loss = nn.BCELoss()
optimizer_G, optimizer_D = (
    optim.Adam(net.parameters(), lr=lr) for net in (generator, discriminator)
)

In [None]:
import time

# Wall-clock start for the elapsed-time printouts, plus the per-epoch
# histories that the training cells append to.
since = time.time()
G_loss, D_loss, G_prob, D_prob = [], [], [], []

In [None]:
## G-first: update the generator, then the discriminator, on every batch
# Training loop
for epoch in range(epochs):
    g_loss = 0
    d_loss = 0
    g_prob = 0
    d_prob = 0
    for i, (imgs, _) in enumerate(dataloader):
        imgs = imgs.to(device)

        # --- Generator step: minimise -log D(G(z)) ---
        optimizer_G.zero_grad()
        z = torch.randn(imgs.size(0), latent_dim, device=device)
        gen_imgs = generator(z)
        g_prob = discriminator(gen_imgs)
        g_loss = adversarial_loss(g_prob, torch.ones_like(g_prob))
        g_loss.backward()
        optimizer_G.step()

        # --- Discriminator step: real batch vs. the samples generated above ---
        optimizer_D.zero_grad()
        d_prob = discriminator(imgs)
        real_loss = adversarial_loss(d_prob, torch.ones_like(d_prob))
        # detach(): this step must not backpropagate into the generator.
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), torch.zeros_like(d_prob))
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # Print progress
        if i % 300 == 0:
            print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(dataloader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")

    print(f'[Epoch{epoch}/{epochs}]')
    time_elapsed = time.time() - since
    print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    # Record the last batch's values for this epoch. The probabilities are
    # detached so the history lists do not keep autograd graphs reachable.
    G_loss.append(g_loss.item())
    D_loss.append(d_loss.item())
    G_prob.append(g_prob[0].detach())
    D_prob.append(d_prob[0].detach())
    if epoch % 10 == 0:
        # Every 10 epochs, show six samples from the last batch, mapped from
        # the Tanh range [-1, 1] to [0, 1] for display.
        fig = plt.figure()
        for i in range(6):
            plt.subplot(2, 3, i + 1)
            plt.tight_layout()
            nimg = (gen_imgs[i] + 1) / 2
            plt.imshow(nimg.detach().cpu().numpy()[0], cmap='gray', interpolation='none')
        plt.show()

In [None]:
# Plot the per-epoch D and G losses (already plain floats, so no stacking needed).
import numpy as np

for history, colour in ((D_loss, 'blue'), (G_loss, 'orange')):
    plt.plot(np.array(history), c=colour)
plt.axhline(y=0, ls=":", c='black')
plt.legend(['D_loss', 'G_loss'], loc='upper left')
plt.show()

In [None]:
# Plot D(x) and D(G(z)) per epoch; 0.5 marks the equilibrium. Stack into new
# names: rebinding D_prob / G_prob to tensors here made the next training
# cell's `G_prob.append(...)` fail with AttributeError.
d_prob_curve = torch.stack(D_prob).view(-1).detach().cpu().numpy()
g_prob_curve = torch.stack(G_prob).view(-1).detach().cpu().numpy()
plt.plot(d_prob_curve, c='blue')
plt.plot(g_prob_curve, c='orange')
plt.axhline(y=0.5, ls=":", c='black')
plt.legend(['D_prob', 'G_prob'], loc='center right')
plt.show()

In [None]:
## D-first: update the discriminator, then the generator, on every batch
# NOTE(review): generator/discriminator and the history lists are not reset
# here. If the G-first cell ran before this one, training continues from its
# weights and the curves are appended to the same lists.
# Training loop
for epoch in range(epochs):
    g_loss = 0
    d_loss = 0
    g_prob = 0
    d_prob = 0
    for i, (imgs, _) in enumerate(dataloader):
        imgs = imgs.to(device)
        z = torch.randn(imgs.size(0), latent_dim, device=device)

        # --- Discriminator step: real batch vs. freshly generated samples ---
        optimizer_D.zero_grad()
        d_prob = discriminator(imgs)
        real_loss = adversarial_loss(d_prob, torch.ones_like(d_prob))
        gen_imgs = generator(z)
        # detach(): this step must not backpropagate into the generator.
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), torch.zeros_like(d_prob))
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # --- Generator step on a new noise batch: minimise -log D(G(z)) ---
        optimizer_G.zero_grad()
        z = torch.randn(imgs.size(0), latent_dim, device=device)
        gen_imgs = generator(z)
        g_prob = discriminator(gen_imgs)
        g_loss = adversarial_loss(g_prob, torch.ones_like(g_prob))
        g_loss.backward()
        optimizer_G.step()

        # Print progress
        if i % 300 == 0:
            print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(dataloader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")

    print(f'[Epoch{epoch}/{epochs}]')
    time_elapsed = time.time() - since
    print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    # Record the last batch's values for this epoch. The probabilities are
    # detached so the history lists do not keep autograd graphs reachable.
    G_loss.append(g_loss.item())
    D_loss.append(d_loss.item())
    G_prob.append(g_prob[0].detach())
    D_prob.append(d_prob[0].detach())
    if epoch % 10 == 0:
        # Every 10 epochs, show six samples from the last batch, mapped from
        # the Tanh range [-1, 1] to [0, 1] for display.
        fig = plt.figure()
        for i in range(6):
            plt.subplot(2, 3, i + 1)
            plt.tight_layout()
            nimg = (gen_imgs[i] + 1) / 2
            plt.imshow(nimg.detach().cpu().numpy()[0], cmap='gray', interpolation='none')
        plt.show()

In [None]:
# Plot the per-epoch D and G losses (already plain floats, so no stacking needed).
import numpy as np

for history, colour in ((D_loss, 'blue'), (G_loss, 'orange')):
    plt.plot(np.array(history), c=colour)
plt.axhline(y=0, ls=":", c='black')
plt.legend(['D_loss', 'G_loss'], loc='upper left')
plt.show()

In [None]:
# Plot D(x) and D(G(z)) per epoch; 0.5 marks the equilibrium. Stack into new
# names instead of rebinding D_prob / G_prob to tensors, so re-running a
# training cell can still append to the history lists.
d_prob_curve = torch.stack(D_prob).view(-1).detach().cpu().numpy()
g_prob_curve = torch.stack(G_prob).view(-1).detach().cpu().numpy()
plt.plot(d_prob_curve, c='blue')
plt.plot(g_prob_curve, c='orange')
plt.axhline(y=0.5, ls=":", c='black')
plt.legend(['D_prob', 'G_prob'], loc='center right')
plt.show()

In [None]:
## Save the models
# Whole-module pickles (architecture + weights). Loading them back requires
# the Generator / Discriminator classes to be defined under the same names.
# state_dict alternative:
#   torch.save(generator.state_dict(), './data/generator.pth')
#   torch.save(discriminator.state_dict(), './data/discriminator.pth')
for net, target_path in ((generator, './data/ModelG.pth'), (discriminator, './data/ModelD.pth')):
    torch.save(net, target_path)

## 使用模型参数或者模型所有数据来进行验证

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os
import matplotlib.pyplot as plt

In [None]:
# Hyperparameters
latent_dim = 100         # length of the noise vector fed to the generator
img_shape = (1, 28, 28)  # (channels, height, width) of one MNIST image
lr = 2e-4                # Adam learning rate for both networks
batch_size = 64          # images per DataLoader batch
epochs = 300             # you can set 2000

In [None]:
# Define the generator network
class Generator(nn.Module):
    """MLP generator mapping a latent vector to an image of shape ``img_shape``.

    Args:
        latent_dim: length of the input noise vector.
        img_shape: per-sample output shape, e.g. ``(1, 28, 28)``.
    """

    def __init__(self, latent_dim, img_shape):
        super(Generator, self).__init__()
        # Size the output layer from img_shape instead of hard-coding 784, so
        # the layer and the reshape in forward() always agree.
        n_pixels = torch.Size(img_shape).numel()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, n_pixels),
            nn.Tanh()
        )
        self.img_shape = img_shape

    def forward(self, z):
        """Return images of shape ``(batch, *img_shape)`` with values in [-1, 1]."""
        img = self.model(z)
        img = img.view(img.size(0), *self.img_shape)
        return img

# Define the discriminator network
class Discriminator(nn.Module):
    """MLP discriminator scoring images of shape ``img_shape`` as real (1) / fake (0).

    Args:
        img_shape: per-sample input shape, e.g. ``(1, 28, 28)``.
    """

    def __init__(self, img_shape):
        super(Discriminator, self).__init__()
        # Size the input layer from img_shape instead of hard-coding 784.
        n_pixels = torch.Size(img_shape).numel()
        self.model = nn.Sequential(
            nn.Linear(n_pixels, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )
        self.img_shape = img_shape

    def forward(self, img):
        """Flatten each sample and return a (batch, 1) tensor of probabilities."""
        img = img.view(img.size(0), -1)
        validity = self.model(img)
        return validity

In [None]:
# Load the MNIST training split, normalised to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

mnist_dataset = datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
dataloader = DataLoader(mnist_dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Restore the whole-module pickles written by the "save models" cell.
# Generator / Discriminator must be defined before unpickling can succeed.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def _load_full_model(path):
    """Load a pickled nn.Module onto ``device`` and switch it to eval mode."""
    # weights_only=False: these files hold full module objects, which the
    # weights-only unpickler (the default since PyTorch 2.6) rejects. This
    # runs arbitrary pickle code, so use it only on files you created.
    # Requires PyTorch >= 1.13.
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    model = torch.load(path, map_location=device, weights_only=False)
    model.eval()
    return model


generator = _load_full_model('./data/ModelG.pth')
print(generator)
discriminator = _load_full_model('./data/ModelD.pth')
print(discriminator)

In [None]:
import time

since = time.time()
# Fetch one real batch, print its shape and first label, and show the first image.
imgs, labels = next(iter(dataloader))
print(imgs.shape, labels[0])
img = imgs[0, 0]
plt.imshow(img, cmap='gray')
plt.show()

In [None]:
# Score one batch of real images with the trained discriminator.
imgs = imgs.to(device)
with torch.no_grad():  # inference only; no autograd graph needed
    real = discriminator(imgs)
print(real.shape)
real_np = real.cpu().numpy()
plt.plot(real_np)
plt.axhline(y=0.5, ls=':', color='gray')
plt.show()
import numpy as np
# Average over the actual batch size rather than a hard-coded 64, so a
# smaller batch is not mis-averaged.
print('average prob:', np.mean(real_np))

In [None]:
# Generate 64 samples, score them with D, and show the first 12 titled with
# their discriminator probability.
nlen = 100
randomnoise = torch.randn(64, nlen, device=device)
Gimg = generator(randomnoise)
prob = discriminator(Gimg).detach().cpu().numpy()
print(prob.shape)
Gimg = Gimg.detach().cpu().numpy()

fig = plt.figure()
for cell in range(12):
    plt.subplot(3, 4, cell + 1)
    plt.tight_layout()
    plt.title(prob[cell])
    # Map the Tanh range [-1, 1] to [0, 1] for display.
    plt.imshow((Gimg[cell][0] + 1) / 2, cmap='gray', interpolation='none')
plt.show()

In [None]:
# Generate 1000 samples and display (in a 4x4 grid) up to 16 of those that
# the discriminator scores above `threshold`.
nlen = 100
threshold = 0.7   # minimum D(G(z)) for a sample to be shown
max_shown = 16    # number of cells in the 4x4 grid
randomnoise = torch.randn(1000, nlen, device=device)
with torch.no_grad():  # inference only
    Gimg = generator(randomnoise)
    prob = discriminator(Gimg)
prob = prob.cpu().numpy()
print(prob.shape)
Gimg = Gimg.cpu().numpy()

# Indices of confident samples, in generation order, capped at the grid size.
picked = [idx for idx in range(len(prob)) if prob[idx, 0] > threshold][:max_shown]
fig = plt.figure()
for num, idx in enumerate(picked, start=1):
    # Create an axes only for an image that is actually drawn, so no empty
    # panel is left when fewer than 16 samples qualify.
    plt.subplot(4, 4, num)
    plt.title(prob[idx])
    nimg = (Gimg[idx] + 1) / 2  # map [-1, 1] back to [0, 1] for display
    plt.imshow(nimg[0], cmap='gray', interpolation='none')
plt.tight_layout()  # once, after all axes exist
plt.show()