In [1]:
import torch
import torch.nn as nn
import torchvision

In [2]:
# Load the MNIST training split (downloaded to ./data/ if missing). Each image
# is converted to a tensor and then normalized as (x - 0.5) / 0.5.
dataset = torchvision.datasets.MNIST(
    root='./data/',
    train=True,
    download=True,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.5], std=[0.5]),
    ]),
)

# Sanity check: print the tensor shape of the first six samples.
for sample_idx in range(min(6, len(dataset))):
    print(dataset[sample_idx][0].shape)

torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])


In [3]:
# Per-sample image shape ([1, 28, 28], matching the shapes printed for the
# dataset), held as an int32 tensor and displayed as the cell output.
image_size = torch.tensor((1, 28, 28), dtype=torch.int32)
image_size

tensor([ 1, 28, 28], dtype=torch.int32)

In [4]:
# Shape of one generated image; read by the model definitions below.
image_size = torch.tensor([1, 28, 28])


class Generator(nn.Module):
    """Fully connected generator that maps latent vectors to images.

    The network is a stack of Linear + ReLU layers widening from ``in_dim`` to
    1024 features, followed by a Linear layer producing one value per pixel of
    ``image_size`` and a Tanh, so every output value lies in [-1, 1].

    Args:
        in_dim: Number of features of the latent vector ``z``.
    """

    def __init__(self, in_dim):
        super().__init__()

        # Convert to a plain Python int: nn.Linear stores the sizes it is given,
        # so passing a 0-dim tensor would leave a tensor in `out_features`.
        n_pixels = int(torch.prod(image_size))

        self.model = nn.Sequential(
            nn.Linear(in_dim, 64),
            torch.nn.ReLU(inplace=True),
            nn.Linear(64, 128),
            torch.nn.ReLU(inplace=True),
            nn.Linear(128, 256),
            torch.nn.ReLU(inplace=True),
            nn.Linear(256, 512),
            torch.nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            torch.nn.ReLU(inplace=True),
            nn.Linear(1024, n_pixels),
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate a batch of images from latent vectors.

        Args:
            z: Latent tensor of shape [bsz, in_dim].

        Returns:
            Tensor of shape [bsz, *image_size] (here [bsz, 1, 28, 28]).
        """
        out = self.model(z)
        # .tolist() yields Python ints rather than 0-dim tensors for reshape.
        image = out.reshape(z.shape[0], *image_size.tolist())
        return image
    

In [5]:
class Discriminator(nn.Module):
    """Fully connected discriminator that scores images as real or generated.

    Each image is flattened and passed through a stack of Linear + ReLU layers
    narrowing from 1024 to 128 features, followed by a single-output Linear
    layer and a Sigmoid, so the result is one value in (0, 1) per image.
    The expected number of input values per image is the product of the
    module-level ``image_size``.
    """

    def __init__(self):
        super().__init__()

        # Convert to a plain Python int: nn.Linear stores the sizes it is given,
        # so passing a 0-dim tensor would leave a tensor in `in_features`.
        n_pixels = int(torch.prod(image_size))

        self.model = nn.Sequential(
            nn.Linear(n_pixels, 1024),
            torch.nn.ReLU(inplace=True),
            nn.Linear(1024, 512),
            torch.nn.ReLU(inplace=True),
            nn.Linear(512, 256),
            torch.nn.ReLU(inplace=True),
            nn.Linear(256, 128),
            torch.nn.ReLU(inplace=True),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, image):
        """Score a batch of images.

        Args:
            image: Tensor of shape [bsz, 1, 28, 28].

        Returns:
            Tensor of shape [bsz, 1] holding the sigmoid output for each image.
        """
        # Flatten everything except the batch dimension before the MLP.
        prob = self.model(image.reshape(image.shape[0], -1))
        return prob

In [6]:
# Number of features in the noise vector fed to the generator.
latent_dim = 64

# Shuffled batches of 64 samples; drop_last=True discards a trailing partial
# batch so every batch has exactly 64 samples.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    drop_last=True,
)

# Each network gets its own Adam optimizer with the same learning rate.
generator = Generator(latent_dim)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4)

discriminator = Discriminator()
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4)

# Binary cross-entropy between the discriminator's output and the 0/1 targets.
loss_fn = nn.BCELoss()

In [7]:
# The dataloader was created with shuffle=True before the epoch loop, yet the
# data is still reshuffled for every epoch: print the labels of the first batch
# of three consecutive epochs to see that they differ.
for epoch in range(3):
    for _images, labels in dataloader:
        print(labels)
        break

tensor([7, 2, 6, 0, 7, 8, 1, 1, 4, 7, 5, 3, 2, 5, 1, 3, 3, 1, 6, 2, 0, 0, 8, 8,
        7, 1, 9, 6, 3, 9, 8, 4, 8, 7, 5, 1, 8, 1, 9, 1, 6, 8, 2, 6, 3, 3, 8, 7,
        1, 9, 4, 2, 3, 0, 4, 0, 7, 7, 3, 1, 5, 1, 7, 2])
tensor([9, 4, 4, 9, 4, 2, 5, 1, 5, 4, 8, 1, 6, 8, 4, 8, 9, 3, 2, 6, 7, 6, 4, 5,
        1, 2, 7, 8, 6, 4, 5, 5, 2, 6, 4, 5, 7, 0, 8, 1, 3, 6, 1, 6, 2, 2, 4, 5,
        2, 7, 1, 2, 8, 3, 5, 2, 0, 4, 9, 8, 7, 8, 9, 1])
tensor([2, 1, 9, 1, 6, 0, 2, 8, 0, 5, 6, 4, 6, 2, 7, 8, 8, 8, 8, 3, 8, 7, 5, 5,
        8, 5, 1, 1, 5, 1, 2, 0, 4, 7, 2, 7, 0, 3, 1, 8, 1, 1, 2, 7, 8, 2, 1, 3,
        8, 3, 6, 4, 3, 3, 0, 6, 2, 9, 4, 1, 5, 3, 7, 6])


In [None]:

# Adversarial training: for every batch of real images, update the generator
# once and then the discriminator once.
for epoch in range(100):
    for i, batch in enumerate(dataloader):
        gt_images, _ = batch

        # Derive the batch size from the data instead of hard-coding 64, so the
        # loop keeps working if the dataloader's batch_size or drop_last changes.
        batch_size = gt_images.shape[0]
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        z = torch.randn(batch_size, latent_dim)
        pred_images = generator(z)

        # Generator step: push the discriminator's score on generated images
        # towards the "real" label.
        discriminator.eval()
        g_optimizer.zero_grad()
        g_loss = loss_fn(discriminator(pred_images), real_labels)
        g_loss.backward()
        g_optimizer.step()

        # Discriminator step: real images should score 1, generated images 0.
        # detach() keeps this loss from back-propagating into the generator;
        # zero_grad() also clears the gradients the generator step left on the
        # discriminator's parameters.
        discriminator.train()
        d_optimizer.zero_grad()
        real_loss = loss_fn(discriminator(gt_images), real_labels)
        fake_loss = loss_fn(discriminator(pred_images.detach()), fake_labels)
        d_loss = 0.5 * (real_loss + fake_loss)
        d_loss.backward()
        d_optimizer.step()

        if epoch % 10 == 0 and i % 500 == 0:
            # The generator ends in Tanh, so its pixels are in [-1, 1], while
            # save_image treats its input as [0, 1] and clips the rest. Rescale
            # first so the negative half of the range is not saved as black.
            sample = (pred_images[0].detach() + 1) / 2
            torchvision.utils.save_image(sample, f"image_epoch{epoch}_step_{i}.png")
        