In [None]:
import numpy as np
import pandas as pd
import os
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.nn as nn
import torchvision
from torchvision import transforms
from PIL import Image
import gzip
%matplotlib inline

In [None]:
# Load data
def load_data(data_file):
    """Read the four gzip-compressed MNIST archives stored in ``data_file``.

    Args:
        data_file: Directory containing ``train-labels-idx1-ubyte.gz``,
            ``train-images-idx3-ubyte.gz``, ``t10k-labels-idx1-ubyte.gz`` and
            ``t10k-images-idx3-ubyte.gz``.

    Returns:
        Tuple ``(train_labels, train_images, test_labels, test_images)`` of
        ``uint8`` arrays. Label arrays are 1-D; each image array is reshaped
        to ``(len(labels), 784)``.
    """
    def _read_payload(file_name, header_size):
        # Decompress the archive in binary mode and view every byte after the
        # first ``header_size`` bytes as uint8.
        with gzip.open(os.path.join(data_file, file_name), 'rb') as handle:
            return np.frombuffer(handle.read(), np.uint8, offset=header_size)

    # Label files skip the first 8 bytes, image files the first 16 bytes.
    train_labels = _read_payload('train-labels-idx1-ubyte.gz', 8)
    train_images = _read_payload('train-images-idx3-ubyte.gz', 16).reshape(len(train_labels), 784)

    test_labels = _read_payload('t10k-labels-idx1-ubyte.gz', 8)
    test_images = _read_payload('t10k-images-idx3-ubyte.gz', 16).reshape(len(test_labels), 784)

    return train_labels, train_images, test_labels, test_images
 
# Directory holding the four raw MNIST .gz archives. The MNIST_RAW_DIR
# environment variable, when set, overrides the machine-specific default
# below so the notebook can run elsewhere without editing this cell.
data_dir = os.environ.get('MNIST_RAW_DIR', 'C:\\Users\\Fan\\JupyterFile\\data\\MNIST\\raw\\')
train_labels,train_images,test_labels,test_images = load_data(data_dir)
# Sanity check: report the shape of every loaded array.
print(train_labels.shape)
print(train_images.shape)
print(test_labels.shape)
print(test_images.shape)

# Dataset wrapper consumed by the data loaders
class ImageDataset(torch.utils.data.Dataset):
    """Dataset pairing each row of ``images`` with the matching entry of ``labels``.

    Args:
        images: Indexable collection whose items can be reshaped to 28x28.
        labels: Indexable collection of labels, aligned with ``images``.
        transform: Optional callable applied to each reshaped image.
    """

    def __init__(self, images, labels, transform=None):
        self.images, self.labels = images, labels
        self.transform = transform

    def __len__(self):
        """Return the number of images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``.

        The image is reshaped to 28x28 and, if a transform was supplied,
        passed through it; the label is returned unchanged.
        """
        picture = self.images[idx].reshape(28, 28)
        picture = self.transform(picture) if self.transform else picture
        return picture, self.labels[idx]
# Preprocessing pipeline: convert each image to a tensor, then normalise it
# with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)

batch_size = 100

# Wrap the raw arrays in datasets that share the same preprocessing.
train_dataset = ImageDataset(train_images, train_labels, transform)
test_dataset = ImageDataset(test_images, test_labels, transform)

# Both loaders draw shuffled mini-batches of `batch_size` samples.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
def show_image(tensor):
    """Display a single image tensor with matplotlib.

    Values are clamped to the range [0, 1], singleton dimensions are
    squeezed away, and the result is drawn without axes.
    """
    clipped = tensor.clamp(0, 1)
    pixels = clipped.detach().squeeze().numpy()
    plt.imshow(pixels)
    plt.axis('off')  # hide the axes
    plt.show()
def show_images_all(imgs):
    """Display a batch of images on a near-square grid of subplots.

    Args:
        imgs: Tensor whose first dimension indexes the images; every
            ``imgs[i]`` is drawn with a gray colormap.

    The grid has ``int(sqrt(n))`` rows and enough columns to hold all ``n``
    images. Cells left over when ``n`` does not fill the grid are shown
    blank. An empty batch draws nothing.
    """
    n = imgs.size(0)  # number of images
    if n == 0:
        return  # nothing to draw; also avoids a division by zero below

    rows = int(np.sqrt(n))
    cols = -(-n // rows)  # ceiling division, so rows * cols >= n

    # squeeze=False keeps a 2-D array of axes even for a 1x1 grid, so the
    # flatten() call below works for every batch size.
    _, axs = plt.subplots(rows, cols, figsize=(8, 8), squeeze=False)
    axs = axs.flatten()

    for i, ax in enumerate(axs):
        if i < n:
            # .cpu() is a no-op on CPU tensors and makes other devices work.
            ax.imshow(imgs[i].detach().cpu().numpy(), cmap='gray')
        ax.axis('off')  # hide the axes, including those of unused cells

    plt.tight_layout()
    plt.show()

In [None]:
# Model definition
# Weight initialisation function
def weights_init(m):
    """Re-initialise ``m`` in place according to its class name.

    * Class name contains ``'Conv'``: weights drawn from N(mean=0.0, std=0.02).
    * Class name contains ``'BatchNorm'``: weights drawn from
      N(mean=1.0, std=0.02) and biases set to 0.

    Modules whose class name contains neither substring (for example
    ``nn.Linear``) are left untouched.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
    elif 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0)
        
# Generator
class Generator(nn.Module):
    """Fully connected generator producing 28x28 outputs in [-1, 1].

    The input is expected to have 100 features per sample. It passes through
    three leaky-ReLU hidden layers and a tanh output layer, and the result is
    reshaped to ``(-1, 28, 28)``.
    """

    def __init__(self):
        super().__init__()
        # Layer widths: 100 -> 256 -> 512 -> 1024 -> 784 (= 28 * 28).
        self.fc1 = nn.Linear(in_features=100, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=512)
        self.fc3 = nn.Linear(in_features=512, out_features=1024)
        self.fc4 = nn.Linear(in_features=1024, out_features=784)
        # Registered here but not applied anywhere in forward().
        self.drop = nn.Dropout2d(p=0.3)

    def forward(self, x):
        """Map a batch of 100-feature inputs to a ``(-1, 28, 28)`` tensor."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.leaky_relu(layer(hidden))
        images = torch.tanh(self.fc4(hidden))
        return images.view(-1, 28, 28)

# Discriminator
class Discriminator(nn.Module):
    """Fully connected discriminator scoring 28x28 images.

    The input is flattened to ``(-1, 784)`` and passed through three
    leaky-ReLU hidden layers, each followed by dropout. A final sigmoid
    yields one value in [0, 1] per sample, with output shape ``(N, 1)``.
    """

    def __init__(self):
        super().__init__()
        # Layer widths: 784 (= 28 * 28) -> 1024 -> 512 -> 256 -> 1.
        self.fc1 = nn.Linear(784, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, 1)
        # Element-wise dropout. The activations here are 2-D
        # (batch, features); nn.Dropout2d targets 3-D/4-D feature maps and
        # emits a deprecation warning when given 2-D input.
        self.drop = nn.Dropout(p=0.3)

    def forward(self, x):
        """Return a ``(N, 1)`` tensor of scores in [0, 1] for the images ``x``."""
        # reshape (rather than view) also accepts non-contiguous inputs.
        x = x.reshape(-1, 28 * 28)
        for layer in (self.fc1, self.fc2, self.fc3):
            x = self.drop(F.leaky_relu(layer(x)))
        return torch.sigmoid(self.fc4(x))


In [None]:
# Seed the RNG *before* the networks are built so that the initial weights,
# the sampled noise and the batch shuffling are all reproducible.
manual_seed = 886
torch.manual_seed(manual_seed)

# Create the generator and the discriminator
netG = Generator()
netD = Discriminator()

# Apply the custom weight initialisation
netG.apply(weights_init)
netD.apply(weights_init)

beta1 = 0.5
lr = 0.0002
# Loss function and optimisers.
# BCELoss expects probabilities, which the discriminator's final sigmoid provides.
criterion = nn.BCELoss()
# beta1 is passed to Adam explicitly; 0.999 is Adam's default second-moment decay.
optimizerD = torch.optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = torch.optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

# Per-iteration loss history, stored as plain Python floats so that no
# autograd graph is kept alive and the lists can be plotted directly.
DLoss=[]
GLoss=[]
# Training loop
epochs = 10
for epoch in range(epochs):
    for i, (real_images,_) in enumerate(train_loader):
        train_size = real_images.shape[0]
        real_labels = torch.ones(train_size, 1)
        fake_labels = torch.zeros(train_size, 1)
        # Generate a batch of fake images from random noise
        z = torch.randn(train_size, 100)
        fake_images = netG(z)

        # ---- Update the discriminator D ----
        netD.zero_grad()
        d_real = netD(real_images)
        d_real_loss = criterion(d_real, real_labels)

        # detach(): the discriminator step must not back-propagate into the
        # generator; this also leaves the generator's graph intact for reuse.
        d_fake = netD(fake_images.detach())
        d_fake_loss = criterion(d_fake, fake_labels)
        d_real_loss.backward()
        d_fake_loss.backward()
        # Reported value is the mean of the two terms; the gradients applied
        # by the step are those of the two separate backward() calls above.
        d_loss = (d_real_loss + d_fake_loss) / 2
        optimizerD.step()
        DLoss.append(d_loss.item())

        # ---- Update the generator G ----
        netG.zero_grad()
        # The generator has not changed since fake_images was produced, so the
        # same batch is scored by the freshly updated discriminator.
        d_fake = netD(fake_images)
        # Real labels as targets: the generator is rewarded for fooling D.
        g_loss = criterion(d_fake, real_labels)
        g_loss.backward()
        optimizerG.step()
        GLoss.append(g_loss.item())

        if i % 100 == 0:
            print(f"[Epoch {epoch+1}/{epochs}] [Batch {i}/{len(train_loader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")
            show_images_all(fake_images)

In [None]:
# Plot the per-iteration discriminator and generator losses.
fig, ax = plt.subplots(figsize=(10, 5))
# float() accepts both plain numbers and 0-d tensors (even ones that require
# grad), so the curves plot regardless of how the losses were stored.
ax.plot([float(v) for v in DLoss], label='D Loss')
ax.plot([float(v) for v in GLoss], label='G Loss')
ax.set_xlabel('Iteration')
ax.set_ylabel('Loss')
ax.set_title('GAN training losses')
ax.legend()  # without this call the labels above are never displayed
plt.show()