In [1]:
import import_ipynb
from model import generator, discriminator
from config import(
    EPOCHS_LR,
    BATCH_SIZE,
    NOISE_LENGTH,
    DEVICE,
)
from dataset import train_loader, val_loader
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from torch import optim, nn

importing Jupyter notebook from model.ipynb
importing Jupyter notebook from config.ipynb
importing Jupyter notebook from dataset.ipynb


In [2]:
# Seed torch's global RNG (CPU and all CUDA devices) before the networks are built,
# so weight initialisation and latent-noise sampling repeat across Restart & Run All.
# NOTE(review): some cuDNN kernels are still non-deterministic on GPU.
SEED = 42
torch.manual_seed(SEED)

# Width argument shared by both networks (previously the literal 128 passed to each).
# NOTE(review): presumably the base feature-map count — confirm in model.ipynb.
MODEL_WIDTH = 128
G = generator(MODEL_WIDTH).to(DEVICE)
D = discriminator(MODEL_WIDTH).to(DEVICE)

In [3]:
# Re-initialise both networks through their own weight_init hook, G first and then D
# (order kept so any RNG draws happen in the same sequence).
# NOTE(review): assumes weight_init draws weights from N(mean, std) — confirm in model.ipynb.
for _net in (G, D):
    _net.weight_init(mean=0.0, std=0.02)
del _net  # keep the loop variable out of the notebook namespace

# Binary cross-entropy shared by the D and G objectives. nn.BCELoss expects inputs in
# [0, 1], so this assumes D ends in a sigmoid — TODO confirm in model.ipynb.
BCE_loss = nn.BCELoss()

In [None]:
# Staged adversarial training. EPOCHS_LR is iterated as (epochs, lr) pairs; each stage
# builds fresh Adam optimizers at that stage's learning rate (which also resets Adam's
# moment estimates between stages). Both networks are checkpointed after every stage.

# Checkpoint paths, formatted with the stage index.
# NOTE(review): move these to config.ipynb if they need to be configurable.
G_CKPT_TEMPLATE = "generator_stage{}.pth"
D_CKPT_TEMPLATE = "discriminator_stage{}.pth"


def _make_noise(batch_size):
    """Sample a latent batch of shape (batch_size, NOISE_LENGTH, 1, 1) on DEVICE."""
    return torch.randn(batch_size, NOISE_LENGTH, 1, 1, device=DEVICE)


def _make_labels(batch_size):
    """Return the BCE targets (y_real, y_fake): all-ones and all-zeros vectors on DEVICE."""
    return (
        torch.ones(batch_size, device=DEVICE),
        torch.zeros(batch_size, device=DEVICE),
    )


def _score(images):
    """Run D on `images` and flatten its output to a 1-D vector of scores.

    reshape(-1) replaces the former squeeze(), which collapsed a batch of size 1 to a
    0-d tensor and made BCE_loss reject it against the (1,)-shaped targets.
    Assumes D emits exactly one value per image — TODO confirm in model.ipynb.
    """
    return D(images).reshape(-1)


def _d_loss(real_img, fake_img, y_real, y_fake):
    """Discriminator loss: BCE of real images against 1 plus BCE of fakes against 0."""
    return BCE_loss(_score(real_img), y_real) + BCE_loss(_score(fake_img), y_fake)


def _g_loss(fake_img, y_real):
    """Generator loss: BCE of generated images scored against the *real* label."""
    return BCE_loss(_score(fake_img), y_real)


for stage, (epochs, lr) in enumerate(EPOCHS_LR):
    G_opt = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
    # Fixed: this optimizer was built over G.parameters(), so D was never trained and
    # D_opt.step() pushed G's weights along the discriminator loss.
    D_opt = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
    for epoch in range(epochs):
        D.train()
        G.train()
        D_train_sum = G_train_sum = 0.0
        n_train = 0
        for img, _ in tqdm(train_loader, desc=f"stage {stage} epoch {epoch} train"):
            img = img.to(DEVICE)
            mini_batch = img.size(0)
            y_real, y_fake = _make_labels(mini_batch)

            # Train D. detach() stops the discriminator loss from backpropagating into G.
            D_opt.zero_grad()
            img_fake = G(_make_noise(mini_batch)).detach()
            D_train_loss = _d_loss(img, img_fake, y_real, y_fake)
            D_train_loss.backward()
            D_opt.step()

            # Train G. This backward also leaves gradients on D's parameters; they are
            # cleared by D_opt.zero_grad() at the start of the next batch.
            G_opt.zero_grad()
            G_train_loss = _g_loss(G(_make_noise(mini_batch)), y_real)
            G_train_loss.backward()
            G_opt.step()

            D_train_sum += D_train_loss.item()
            G_train_sum += G_train_loss.item()
            n_train += 1

        # Report the epoch mean instead of only the last batch; max() guards an empty loader.
        n_train = max(n_train, 1)
        print('D train loss : {} , G train loss : {}'.format(
            D_train_sum / n_train, G_train_sum / n_train))

        D.eval()
        G.eval()
        D_val_sum = G_val_sum = 0.0
        n_val = 0
        with torch.no_grad():
            for img, _ in tqdm(val_loader, desc=f"stage {stage} epoch {epoch} val"):
                img = img.to(DEVICE)
                mini_batch = img.size(0)
                y_real, y_fake = _make_labels(mini_batch)
                D_val_loss = _d_loss(img, G(_make_noise(mini_batch)), y_real, y_fake)
                G_val_loss = _g_loss(G(_make_noise(mini_batch)), y_real)
                D_val_sum += D_val_loss.item()
                G_val_sum += G_val_loss.item()
                n_val += 1

        # Fixed: this used to print the *training* losses once per validation batch.
        n_val = max(n_val, 1)
        print('D val loss : {} , G val loss : {}'.format(
            D_val_sum / n_val, G_val_sum / n_val))

    # Fixed: `torch.save(G.state_dict)` passed the unbound-call method with no destination
    # (TypeError) and `torc.save()` was a NameError. Save both networks for this stage.
    torch.save(G.state_dict(), G_CKPT_TEMPLATE.format(stage))
    torch.save(D.state_dict(), D_CKPT_TEMPLATE.format(stage))