In [3]:
import xml.etree.cElementTree as ET
import numpy as np
from pathlib import Path
from PIL import Image
import cv2
from DCGAN_dog import NetG, NetD
import config
import gzip
import pickle

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader, TensorDataset

In [5]:
# Pick the training device. is_available must be *called*: the bare function
# object is always truthy, which would select 'cuda:2' even on a machine
# without CUDA and make every later .to(device) fail.
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
# Hyper-parameters: the default epoch count is fixed here, while batch size
# and latent dimension are read from the project's config module.
MAX_EPOCH = 300
BATCH_SIZE, latent_dim = config.BATCH_SIZE, config.latent_dim

In [6]:
# Dataset layout: every folder hangs off data_root.
data_root = Path('stanford_dog/')
annotation_path = data_root.joinpath('annotation')
image_path = data_root.joinpath('images')

# Output folders for the cropped and the processed images; both are created
# (including missing parents) if they do not exist yet.
croped_path = data_root.joinpath('croped_images')
processed_path = data_root.joinpath('processed')
croped_path.mkdir(parents=True, exist_ok=True)
processed_path.mkdir(parents=True, exist_ok=True)

In [5]:
# Crop each image to the bounding box given in its annotation file
# for category_dir in annotation_path.iterdir():
#     (croped_path / category_dir.name).mkdir(parents=True, exist_ok=True)
#     for i in category_dir.iterdir():
#         root = ET.parse(str(i)).getroot()
#         x_min = int(root[5][4][0].text)
#         y_min = int(root[5][4][1].text)
#         x_max = int(root[5][4][2].text)
#         y_max = int(root[5][4][3].text)
#         img = cv2.imread(str(image_path / category_dir.name / i.stem) + '.jpg', -1)
#         cv2.imwrite(str(croped_path / category_dir.name / i.stem) + '.jpg', img[int(y_min):int(y_max), int(x_min):int(x_max),:])

In [6]:
# Resize the cropped images to 109x109x3 to match the network input
# for category_dir in croped_path.iterdir():
#     (processed_path / category_dir.name).mkdir(parents=True, exist_ok=True)
#     for i in category_dir.iterdir():
#         img = cv2.imread(str(i), -1)
#         img = cv2.resize(img,(109,109), interpolation = cv2.INTER_AREA)
#         cv2.imwrite(str(processed_path / category_dir.name / i.name), img)

In [7]:
# Training set: images under 'data/' are resized to 109x109, converted to
# tensors and normalised per channel.
# NOTE(review): the mean/std constants are presumably statistics of this
# dataset -- confirm before reusing them with other data.
ds_train = torchvision.datasets.ImageFolder(
    root='data/',
    transform=torchvision.transforms.Compose([
        torchvision.transforms.Resize((109, 109)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=(0.4211971, 0.40016708, 0.40234432),
            std=(0.27824208, 0.27396, 0.27874818),
        ),
    ]),
)

# Shuffled loader; drop_last=True discards a trailing partial batch, so every
# batch holds exactly BATCH_SIZE samples.
dl_train = DataLoader(
    ds_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    # num_workers=2,
    drop_last=True,
)

In [16]:
def model_init(lr=0.00002):
    """Create the generator/discriminator pair and one SGD optimizer for each.

    Args:
        lr: Learning rate applied to both optimizers. Defaults to 0.00002,
            the value that used to be hard-coded here. The original author's
            note says 0.0002 works better at the start of training.

    Returns:
        Tuple ``(G, optG, D, optD)``: both networks already moved to
        ``device``, each followed by its optimizer.
    """
    G, D = NetG().to(device), NetD().to(device)
    # NOTE(review): the original comment here discussed Adam (lowering beta1
    # from the default 0.9 to 0.5 so that momentum focuses on roughly the last
    # 2 iterations instead of the last 10), but the optimizers actually built
    # are plain SGD without momentum. Confirm which optimizer is intended.
    optG = torch.optim.SGD(G.parameters(), lr=lr)
    optD = torch.optim.SGD(D.parameters(), lr=lr)
    return G, optG, D, optD

In [11]:
# Target tensors handed to the criterion: +1 for real images, -1 for generated
# images when training D, and +1 for generated images when training G.
# They are sized (BATCH_SIZE, 1), so every batch must contain exactly
# BATCH_SIZE samples.
real_labels = torch.ones(BATCH_SIZE, 1).to(device)
fake_labels = -1.0 * torch.ones(BATCH_SIZE, 1).to(device)
g_labels = torch.ones(BATCH_SIZE, 1).to(device)


def fit(G, D, optG, optD, criterion, start_epoch = 0, max_epoch = MAX_EPOCH):
    """Alternately train discriminator ``D`` and generator ``G`` on ``dl_train``.

    ``D`` takes an optimizer step on every batch; ``G`` takes one on every
    second batch (``i % 2 == 0``). After each epoch the most recent generated
    batch is written to ``DCGAN_sampler/<epoch>_my.jpg``.

    Args:
        G: Generator network; called with noise of shape
            ``(BATCH_SIZE, latent_dim, 1, 1)``.
        D: Discriminator network; its output is compared against
            ``(BATCH_SIZE, 1)`` target tensors.
        optG: Optimizer stepping the generator's parameters.
        optD: Optimizer stepping the discriminator's parameters.
        criterion: Loss callable invoked as ``criterion(prediction, target)``.
        start_epoch: Zero-based index of the first epoch to run.
        max_epoch: Exclusive upper bound on the epoch index.

    Raises:
        RuntimeError: If ``dl_train`` yields no batches, in which case there
            are no losses or samples to report.
    """
    # Use the loader's real length in progress messages instead of a
    # hard-coded dataset size.
    n_batches = len(dl_train)
    loss_G = loss_D = fake = None

    for epoch in range(start_epoch, max_epoch):
        for i, (x_gt, _) in enumerate(dl_train):
            x_gt = x_gt.to(device)

            # ---- Discriminator step (only optD is stepped) ----
            z = torch.randn(BATCH_SIZE, latent_dim, 1, 1).to(device)

            out_x = D(x_gt)
            loss_x = criterion(out_x, real_labels)

            # detach() keeps this loss from back-propagating into G.
            out_z = D(G(z).detach())
            loss_z = criterion(out_z, fake_labels)
            loss_D = loss_z + loss_x

            optD.zero_grad()
            loss_D.backward()
            optD.step()

            # ---- Generator step (only optG is stepped), every 2nd batch ----
            if i % 2 == 0:
                z = torch.randn(BATCH_SIZE, latent_dim, 1, 1).to(device)
                fake = G(z)
                out_z = D(fake)
                loss_G = criterion(out_z, g_labels)

                optG.zero_grad()
                loss_G.backward()
                optG.step()

                print('Epoch:[%d/%d], Iteration:[%d/%d], Loss_G = %f, Loss_D = %f.\n' %(epoch+1, max_epoch, i, n_batches, loss_G.item(), loss_D.item()))

        if fake is None:
            # Without this guard the summary below would fail with an
            # unhelpful unbound-variable error.
            raise RuntimeError('dl_train yielded no batches; nothing was trained.')

        if (epoch+1) % 1 == 0:
            print('Epoch:[%d/%d], Loss_G = %f, Loss_D = %f.\n' %(epoch+1, max_epoch, loss_G.item(), loss_D.item()))
            # Make sure the output folder exists so a whole epoch of training
            # is not lost to a failed save.
            Path('DCGAN_sampler').mkdir(parents=True, exist_ok=True)
            # NOTE(review): save_image is called without normalize=True, so
            # values outside [0, 1] are clamped. Confirm that G's output range
            # matches what should be written to disk.
            torchvision.utils.save_image(fake, 'DCGAN_sampler/%d_my.jpg' % (epoch+1))

In [17]:
# Build both networks with their optimizers, then the mean-squared-error
# criterion, moved onto the training device.
G, optG, D, optD = model_init()
criterion = nn.MSELoss()
criterion = criterion.to(device)

In [18]:
# Restore generator and discriminator weights from the '*_481.pth' checkpoint
# files; each file is a dict whose 'state' entry holds the state_dict.
# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved on a GPU that is absent here (or loaded on a CPU-only
# machine) still loads.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
G.load_state_dict(torch.load('car_G_481.pth', map_location=device)['state'])
D.load_state_dict(torch.load('car_D_481.pth', map_location=device)['state'])

In [19]:
# Continue training over epoch indices 1000..1499.
# NOTE(review): the checkpoints loaded above are named for epoch 481 while
# training resumes at 1000 -- confirm the intended starting epoch.
fit(G, D, optG, optD, criterion, start_epoch=1000, max_epoch=1500)

Epoch:[1001/1500], Iteration:[0/15], Loss_G = 2.227064, Loss_D = 0.311804.

Epoch:[1001/1500], Iteration:[2/15], Loss_G = 2.410249, Loss_D = 0.331204.

Epoch:[1001/1500], Iteration:[4/15], Loss_G = 2.504159, Loss_D = 0.285710.

Epoch:[1001/1500], Iteration:[6/15], Loss_G = 2.472982, Loss_D = 0.239517.

Epoch:[1001/1500], Iteration:[8/15], Loss_G = 2.571881, Loss_D = 0.176914.

Epoch:[1001/1500], Iteration:[10/15], Loss_G = 2.708025, Loss_D = 0.180518.

Epoch:[1001/1500], Iteration:[12/15], Loss_G = 2.705430, Loss_D = 0.125228.

Epoch:[1001/1500], Iteration:[14/15], Loss_G = 2.928019, Loss_D = 0.123685.

Epoch:[1001/1500], Loss_G = 2.928019, Loss_D = 0.123685.

Epoch:[1002/1500], Iteration:[0/15], Loss_G = 2.939899, Loss_D = 0.133797.

Epoch:[1002/1500], Iteration:[2/15], Loss_G = 3.026256, Loss_D = 0.102908.

Epoch:[1002/1500], Iteration:[4/15], Loss_G = 3.078828, Loss_D = 0.083937.

Epoch:[1002/1500], Iteration:[6/15], Loss_G = 3.163186, Loss_D = 0.107242.

Epoch:[1002/1500], Iteratio

In [20]:
# Write one checkpoint per network: the weights under 'state' plus the epoch
# number the run was configured to stop at.
state_G = dict(state=G.state_dict(), epoch=1500)
torch.save(state_G, 'car_G_1500.pth')

state_D = dict(state=D.state_dict(), epoch=1500)
torch.save(state_D, 'car_D_1500.pth')

1

In [2]:
# Display the generator's repr. G must already exist in the kernel: the
# NameError recorded below came from running this cell before G was defined.
G

NameError: name 'G' is not defined