## (1) import 

In [1]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST #Training dataset
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0) #set for testing purposes, please do not change! 

lr = 0.1
device = 'cpu'


## (2) generator

### (2-1) generator block
* input_dim 과 output_dim을 파라미터로 받음
* linear layer 와 batch norm, ReLU 함수로 구성 

In [2]:
def gen_block(input_dim, output_dim):
    return nn.Sequential(
        nn.Linear(input_dim, output_dim),
        nn.BatchNorm1d(output_dim),
        nn.ReLU(inplace=True),
    )

### generator 의 구조
* 4개의 generator block와 FC layer, sigmoid함수로 구성
* 28x28 해상도의 MNIST dataset 생성 
- 입력: z_dim =10
- 출력: im_dim =784 

In [3]:
class Generator(nn.Module):
    def __init__(self, z_dim=10, im_dim=784, hidden_dim=128):
        super(Generator,self).__init__()
        #Build the neural network
        self.gen = nn.Sequential(
            gen_block(z_dim,hidden_dim),
            gen_block(hidden_dim,hidden_dim * 2),
            gen_block(hidden_dim * 2,hidden_dim * 4),
            gen_block(hidden_dim * 4,hidden_dim * 8),
            
            nn.Linear(hidden_dim *8, im_dim),
            nn.Sigmoid()
        )

## (3) discriminator

### (3-1) discriminator block
* input_dim 과 outpu_dim 을 parameter 로 받음
* linear layer 와 ReLU 함수로 구성

In [4]:
def discriminator_block(input_dim, output_dim):
    return nn.Sequential(
        nn.Linear(input_dim, output_dim),
        nn.LeakyReLU(0.2,inplace=True))

### (3-2)discriminator 의 구조
* 3개의 discriminator block와 FC layer로 구성
* 28x28 해상도의 MNIST dataset를 처리하는 discriminator 

- 입력: im_dim = 784
- 출력: 1

In [5]:
class Discriminator(nn.Module):
    def __init__(self, im_dim=784, hidden_dim=128):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            discriminator_block(im_dim, hidden_dim *4),
            discriminator_block(hidden_dim * 4, hidden_dim *2),
            discriminator_block(hidden_dim * 2, hidden_dim),
            nn.Linear(hidden_dim,1)
        )
        
        def forward(self, image):
            return self.disc(image)

        def get_disc(self):
            return self.disc




### (4) 기타

### (4-1) 노이즈 함수

In [6]:
def get_noise(n_samples, z_dim, device='cpu'):
    return torch.randn(n_samples, z_dim, device = device)

### (4-2) 파라미터 셋업

In [7]:
criterion = nn.BCEWithLogitsLoss()
n_epochs = 200
z_dim = 64
display_step = 500
batch_size = 128
lr = 0.00001
device = 'cpu'

### (4-3) 데이터 로딩

In [8]:
dataloader = DataLoader(
    MNIST(".", download=True, transform=transforms.ToTensor()),
    batch_size = batch_size,
    shuffle = True)

## (5) loss함수

### (5-1)optimizer

In [9]:
gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(),lr=lr)
disc = Discriminator().to(device)
disc_opt = torch.optim.Adam(disc.parameters(),lr=lr)

### (5-2) disc loss

In [11]:
def get_disc_loss(gen,disc,criterion,real,num_images, z_dim, devices):
    fake_noise = get_noise(num_images,z_dim,device=device) # z
    fake = gen(fake_noise) #G(z)
    disc_fake_pred = disc(fake.detach()) #D(G(z))
    #compare fake_pred &zero
    disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
    disc_loss = (disc_fake_loss + disc_real_loss) / 2
    
    return disc_loss


    

### (5-3) gen loss

In [12]:
def get_gen_loss(gen, disc, criterion, num_images, z_dim, device):
    fake_noise = get_noise(num_images, z_dim, device=device)
    fake = gen(fake_noise)
    disc_fake_pred = disc(fake.detach())
    # compare fake_pred & ones
    gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))

    return gen_loss


## (6) image display

In [14]:
def show_tensor_images(image_tensor, num_images = 25, size=(1, 28, 28)):
    image_unflat = image_tensor.detach().cpu().view(-1, *size)
    image_grid = make_grid(image_unfloat[:num_images], nrow = 5)
    plt.show(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

## (7) training

In [15]:
cur_step =0
mean_generator_loss =0
mean_discriminator_loss=0
test_generator= True
gen_loss=False
error=False

#image_tensor = the imagesto show

for epoch in range(n_epochs):
    for real, _in tqdm(dataloader):
        cur_batch_size = len(real)
        real = real.view(cur_batch_size, -1).to(device)
        
        # update discriminator
        disc_opt.zero_grad()
        disc_loss = get_disc_loss(gen, disc, criterion, real, cur_batch_size, z_dim, )
        disc_loss.backward(return_graph = True)
        disc_opt.step()
        
        if test_generator:
        
    

SyntaxError: invalid syntax (<ipython-input-15-1fc056f81a09>, line 11)