In [1]:
import cv2
import os
import torch
import torch.nn as nn
import numpy as np
import torch.nn.init as init
from torchvision import transforms
import pandas as pd

In [2]:
trans = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((.5, .5, .5), (.5, .5, .5))
    ]
)

M1=96
K1=5
S1=3

G_LR = 0.0003
D_LR = 0.0003
BATCHSIZE = 90
EPOCHES = 8000
MAIN_PATH="./training2D/"

In [3]:
def get_imgs():
    files = os.listdir(MAIN_PATH)
    imgs = []
    for file in files:
        imgs.append(cv2.imread(MAIN_PATH + file,0))
    print("get_imgs")
    return imgs

In [4]:
def init_ws_bs(m):
    if isinstance(m, nn.ConvTranspose2d):
        init.normal_(m.weight.data, std=0.2)
        init.normal_(m.bias.data, std=0.2)

In [5]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.deconv1 = nn.Sequential(
            nn.ConvTranspose2d(  # stride(input_w-1)+k-2*Padding
                in_channels=100,
                out_channels=64 * 8,
                kernel_size=4,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(64 * 8),
            nn.ReLU(inplace=True),
        )
        self.deconv2 = nn.Sequential(
            nn.ConvTranspose2d(  # stride(input_w-1)+k-2*Padding
                in_channels=64 * 8,
                out_channels=64 * 4,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(64 * 4),
            nn.ReLU(inplace=True),
        )
        self.deconv3 = nn.Sequential(
            nn.ConvTranspose2d(  # stride(input_w-1)+k-2*Padding
                in_channels=64 * 4,
                out_channels=64 * 2,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(64 * 2),
            nn.ReLU(inplace=True),
        )
        self.deconv4 = nn.Sequential(
            nn.ConvTranspose2d(  # stride(input_w-1)+k-2*Padding
                in_channels=64 * 2,
                out_channels=64 * 1,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        self.deconv5 = nn.Sequential(
            nn.ConvTranspose2d(64, 1, K1, S1, 1, bias=False),
            nn.Tanh(),
        )
        
 
    def forward(self, x):
        x = self.deconv1(x)
        x = self.deconv2(x)
        x = self.deconv3(x)
        x = self.deconv4(x)
        x = self.deconv5(x)
        return x
 
 
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
                
        self.conv1 = nn.Sequential(
            nn.Conv2d(  # batchsize,1,96,96
                in_channels=1,
                out_channels=64,
                kernel_size=K1,
                padding=1,
                stride=S1,
                bias=False,
            ),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(.2, inplace=True),
 
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 64 * 2, 4, 2, 1, bias=False, ),  # batchsize,16,32,32
            nn.BatchNorm2d(64 * 2),
            nn.LeakyReLU(.2, inplace=True),
 
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(64 * 2, 64 * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.LeakyReLU(.2, inplace=True),
 
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(64 * 4, 64 * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 8),
            nn.LeakyReLU(.2, inplace=True),
 
        )
        self.output = nn.Sequential(
            nn.Conv2d(64 * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()  #
        )
 
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.output(x)
        return x

In [6]:
g = Generator().cuda()
d = Discriminator().cuda()
 
init_ws_bs(g), init_ws_bs(d)
 
g_optimizer = torch.optim.Adam(g.parameters(), betas=(.5, 0.999), lr=G_LR)
d_optimizer = torch.optim.Adam(d.parameters(), betas=(.5, 0.999), lr=D_LR)
 
g_loss_func = nn.BCELoss()
d_loss_func = nn.BCELoss()
 
label_real = torch.ones(BATCHSIZE).cuda()
label_fake = torch.zeros(BATCHSIZE).cuda()

real_img = get_imgs()

get_imgs


In [14]:
for epoch in range(EPOCHES):
    np.random.shuffle(real_img)
    count = 0
    generate=0
    batch_imgs = []
    
    if epoch > 200:
        g_optimizer = torch.optim.Adam(g.parameters(), betas=(.5, 0.999), lr=0.8*G_LR)
        d_optimizer = torch.optim.Adam(d.parameters(), betas=(.5, 0.999), lr=0.8*D_LR)
        
    if epoch > 500:
        g_optimizer = torch.optim.Adam(g.parameters(), betas=(.5, 0.999), lr=0.5*G_LR)
        d_optimizer = torch.optim.Adam(d.parameters(), betas=(.5, 0.999), lr=0.5*D_LR)
       
    for i in range(len(real_img)):
        count = count + 1
        batch_imgs.append(real_img[i])  # tensor类型#这里经过trans操作通道维度从第四个到第二个了
        if count % BATCHSIZE==0:
            #count = 0
            start=count-BATCHSIZE
            end=count
                         
            batch_real = torch.Tensor(batch_imgs).cuda()
            batch_imgs.clear()
            d_optimizer.zero_grad()
            pre_real = d(batch_real.unsqueeze(1)).squeeze()
            d_real_loss = d_loss_func(pre_real, label_real)
            d_real_loss.backward()
 
            batch_fake = torch.randn(BATCHSIZE, 100, 1, 1).cuda()
            #batch_fake = dataset[start:end].cuda()
            img_fake = g(batch_fake).detach()
            pre_fake = d(img_fake).squeeze()

            d_fake_loss = d_loss_func(pre_fake, label_fake)
            d_fake_loss.backward()
 
            d_optimizer.step()
 
            g_optimizer.zero_grad()
            batch_fake = torch.randn(BATCHSIZE, 100, 1, 1).cuda()
            #batch_fake = dataset[start:end].cuda()
            img_fake = g(batch_fake)
            pre_fake = d(img_fake).squeeze()
                            
            g_loss = g_loss_func(pre_fake, label_real)
            g_loss.backward()
            g_optimizer.step()
            
            if epoch>=200:
                generate=generate+1
                imgs=g(torch.randn(1,100,1,1).cuda())                
                image=imgs[0].permute(1,2,0).cpu().detach().numpy()*255
                cv2.imwrite("test_samples/"+str(epoch)+str(generate)+".jpg",image)
            
            if epoch%200==0:
                print(epoch,i,(d_real_loss + d_fake_loss).detach().cpu().numpy(), g_loss.detach().cpu().numpy())
                torch.save(g, "pkl/" + str(epoch) + "g.pkl")

0 59 0.0026599357 16.558338


  "type " + obj.__name__ + ". It won't be checked "


0 119 0.00087071507 16.093792
0 179 0.0033105134 17.931475
0 239 0.026379757 20.857218
0 299 0.016292412 20.035498
0 359 0.0025174045 19.365269
0 419 0.0007985751 18.418428
0 479 0.0044573806 17.745872
0 539 0.000894839 18.692402
200 59 0.00041018362 9.368069
200 119 0.0038076877 8.7699585
200 179 0.00046522258 9.186563
200 239 0.0013894469 9.141986
200 299 0.00030700106 9.447865
200 359 0.0007426002 8.688296
200 419 0.0016714488 7.842141
200 479 0.0017305925 6.780424
200 539 0.0018386698 7.428142
400 59 0.0013442208 26.084433
400 119 0.01496202 16.088007
400 179 1.8170633e-05 10.479563
400 239 0.0066532195 19.000536
400 299 7.751847e-06 19.022911
400 359 0.0024441434 16.865248
400 419 2.9286161e-06 15.1261635
400 479 4.0404408e-05 10.420789
400 539 0.0010903376 12.481567
600 59 8.066552e-07 18.211203
600 119 3.8018209e-06 24.200615
600 179 2.4934639e-07 27.087065
600 239 9.551416e-06 20.192785
600 299 0.00015002918 28.308939
600 359 0.0004644251 22.204851
600 419 1.2547292e-05 14.7211

KeyboardInterrupt: 