In [1]:
pip install torch

Note: you may need to restart the kernel to use updated packages.


In [2]:
pip install torchvision

Note: you may need to restart the kernel to use updated packages.


In [3]:
from tqdm import tqdm
import torch
import torchvision as tv
from torch.utils.data import DataLoader
import torch.nn as nn

In [9]:
# config
class Config:
    """Hyper-parameters and file locations for GAN training and image generation.

    Every setting is a plain class attribute. ``train`` and ``generate``
    override individual values on the shared instance through ``setattr``.
    """

    # ---- data ---------------------------------------------------------
    data_path = './data'   # root folder handed to ImageFolder
    virs = 'result'        # NOTE(review): not referenced in the visible cells - confirm it is still needed
    num_workers = 4        # NOTE(review): the visible train() does not pass this to DataLoader - confirm
    img_size = 96          # images are resized and centre-cropped to this size
    batch_size = 256

    # ---- optimisation -------------------------------------------------
    max_epoch = 400
    lr1 = 2e-4             # learning rate of the generator's Adam optimizer
    lr2 = 2e-4             # learning rate of the discriminator's Adam optimizer
    beta1 = 0.5            # first Adam beta (the second is fixed to 0.999 in train)
    gpu = False            # True -> "cuda" device, False -> "cpu"

    # ---- network sizes ------------------------------------------------
    nz = 100               # number of channels of the latent noise tensor
    ngf = 64               # base feature-map width of the generator
    ndf = 64               # base feature-map width of the discriminator

    # ---- training schedule / checkpoints ------------------------------
    save_path = 'imgs3/'   # folder for the sample grids written during training
    d_every = 1            # update the discriminator every d_every batches
    g_every = 5            # update the generator every g_every batches
    save_every = 5         # write samples and checkpoints every save_every epochs
    netd_path = None       # optional discriminator checkpoint loaded before training
    netg_path = None       # optional generator checkpoint loaded before training

    # ---- generation ---------------------------------------------------
    gen_img = "result.png"  # output file of generate()
    gen_num = 64            # how many of the best-scored images are kept
    gen_search_num = 512    # how many candidates are generated and scored
    gen_mean = 0            # mean of the generation noise
    gen_std = 1             # standard deviation of the generation noise

In [10]:
# Shared configuration instance: the networks read their sizes from it, and
# train()/generate() override its attributes in place via setattr.
opt = Config()

#generator
class NetG(nn.Module):
    """DCGAN-style generator built from transposed convolutions.

    With output_padding left at 0 each layer's spatial size follows
    ``out = (in - 1) * stride - 2 * padding + kernel_size``, so a
    ``(N, opt.nz, 1, 1)`` input grows 1 -> 4 -> 8 -> 16 -> 32 -> 96 and the
    result is a ``(N, 3, 96, 96)`` tensor squashed into [-1, 1] by ``Tanh``.

    Args:
        opt: configuration object; ``opt.nz`` (latent channels) and
            ``opt.ngf`` (base feature-map width) are read.
    """

    def __init__(self, opt):
        super().__init__()
        self.ngf = opt.ngf
        width = self.ngf

        def stage(c_in, c_out, stride, padding):
            # One ConvTranspose2d -> BatchNorm2d -> ReLU group (4x4 kernel, bias-free conv).
            return [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=stride, padding=padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        modules = []
        modules += stage(opt.nz, width * 8, stride=1, padding=0)     # 1x1   -> 4x4
        modules += stage(width * 8, width * 4, stride=2, padding=1)  # 4x4   -> 8x8
        modules += stage(width * 4, width * 2, stride=2, padding=1)  # 8x8   -> 16x16
        modules += stage(width * 2, width, stride=2, padding=1)      # 16x16 -> 32x32
        # Final projection to RGB: 32x32 -> 96x96, no normalisation before Tanh.
        modules.append(nn.ConvTranspose2d(width, 3, kernel_size=5, stride=3, padding=1, bias=False))
        modules.append(nn.Tanh())

        # Flat Sequential so parameter names stay "Gene.<index>.*".
        self.Gene = nn.Sequential(*modules)

    def forward(self, x):
        """Run the latent tensor ``x`` through the generator stack."""
        return self.Gene(x)

In [11]:
# Discriminator
class NetD(nn.Module):
    """DCGAN-style discriminator that scores images with a value in [0, 1].

    For 96x96 RGB inputs the strided convolutions shrink the spatial size
    96 -> 32 -> 16 -> 8 -> 4 -> 1 while the channel count grows from
    ``opt.ndf`` to ``opt.ndf * 8``; the last convolution maps to one channel
    and ``Sigmoid`` turns it into a score.

    Args:
        opt: configuration object; only ``opt.ndf`` (base feature-map width)
            is read.
    """

    def __init__(self, opt):
        super().__init__()
        self.ndf = opt.ndf
        base = self.ndf

        # Stem: 3 x 96 x 96 -> ndf x 32 x 32 (no batch norm on the first layer).
        modules = [
            nn.Conv2d(3, base, kernel_size=5, stride=3, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        ]

        # Three halving stages, each doubling the channel count.
        for factor in (1, 2, 4):
            c_in, c_out = base * factor, base * factor * 2
            modules.extend([
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, True),
            ])

        # Head: (ndf*8) x 4 x 4 -> 1 x 1 x 1, the only convolution with a bias.
        modules.append(nn.Conv2d(base * 8, 1, kernel_size=4, stride=1, padding=0, bias=True))
        modules.append(nn.Sigmoid())

        # Flat Sequential so parameter names stay "Discrim.<index>.*".
        self.Discrim = nn.Sequential(*modules)

    def forward(self, x):
        """Return the scores for batch ``x`` flattened to a 1-D tensor."""
        scores = self.Discrim(x)
        return scores.view(-1)

In [12]:
def train(**kwargs):
    """Train the ``NetG``/``NetD`` pair on the images found under ``opt.data_path``.

    Keyword arguments override the attribute of the same name on the
    module-level ``opt`` before training starts; the override persists on
    ``opt`` afterwards.

    Every ``opt.save_every`` epochs a grid of samples generated from a fixed
    noise batch is written to ``opt.save_path`` and both state dicts are
    written to ``imgs2/`` as ``netd_<epoch>.pth`` / ``netg_<epoch>.pth``,
    where ``<epoch>`` is the zero-based epoch index.

    Args:
        **kwargs: configuration overrides, e.g. ``train(gpu=True, max_epoch=50)``.
    """
    import os  # local import: only needed to prepare the output folders

    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    device = torch.device("cuda") if opt.gpu else torch.device("cpu")

    # Resize + centre crop yields 3 x img_size x img_size; Normalize maps the
    # [0, 1] tensors to [-1, 1], the range of the generator's Tanh output.
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.img_size),
        tv.transforms.CenterCrop(opt.img_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = tv.datasets.ImageFolder(root=opt.data_path, transform=transforms)

    # drop_last keeps every batch at exactly batch_size so the pre-built label
    # tensors below always match the discriminator output.
    dataloader = DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        drop_last=True,
    )

    netg, netd = NetG(opt), NetD(opt)

    # Load checkpoints onto the CPU first; the models are moved to `device` below.
    # NOTE: torch.load unpickles the file - only point these paths at trusted checkpoints.
    map_location = lambda storage, loc: storage
    if opt.netg_path:
        netg.load_state_dict(torch.load(f=opt.netg_path, map_location=map_location))
    if opt.netd_path:
        netd.load_state_dict(torch.load(f=opt.netd_path, map_location=map_location))

    netd.to(device)
    netg.to(device)

    optimize_g = torch.optim.Adam(netg.parameters(), lr=opt.lr1, betas=(opt.beta1, 0.999))
    optimize_d = torch.optim.Adam(netd.parameters(), lr=opt.lr2, betas=(opt.beta1, 0.999))

    criterions = nn.BCELoss().to(device)

    # Targets: 1 for "real", 0 for "fake".
    true_labels = torch.ones(opt.batch_size, device=device)
    fake_labels = torch.zeros(opt.batch_size, device=device)

    # Fixed noise so the periodically saved sample grids are comparable across epochs.
    fix_noises = torch.randn(opt.batch_size, opt.nz, 1, 1, device=device)

    # Make sure both output folders exist before the first save; otherwise
    # save_image / torch.save fail on a fresh checkout.
    checkpoint_dir = 'imgs2/'
    os.makedirs(opt.save_path, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(opt.max_epoch):
        # Wrapping the loader (not the enumerate) lets tqdm show the batch total.
        for ii_, (img, _) in enumerate(tqdm(dataloader)):
            real_img = img.to(device)

            # ---- discriminator update ----
            if ii_ % opt.d_every == 0:
                optimize_d.zero_grad()

                # Real images should be classified as real.
                output = netd(real_img)
                error_d_real = criterions(output, true_labels)
                error_d_real.backward()

                # Fresh noise for every discriminator step (previously the batch
                # was reused until the next generator step). The fakes are detached
                # so this step does not back-propagate into the generator.
                noises = torch.randn(opt.batch_size, opt.nz, 1, 1, device=device)
                fake_image = netg(noises).detach()
                output = netd(fake_image)
                error_d_fake = criterions(output, fake_labels)
                error_d_fake.backward()
                optimize_d.step()

            # ---- generator update ----
            if ii_ % opt.g_every == 0:
                optimize_g.zero_grad()
                noises = torch.randn(opt.batch_size, opt.nz, 1, 1, device=device)
                fake_image = netg(noises)
                output = netd(fake_image)
                # The generator is rewarded when its fakes are scored as real.
                error_g = criterions(output, true_labels)
                error_g.backward()
                optimize_g.step()

        # ---- periodic samples + checkpoints ----
        if (epoch + 1) % opt.save_every == 0:
            # No autograd graph is needed just to render samples.
            with torch.no_grad():
                fix_fake_image = netg(fix_noises)
            tv.utils.save_image(
                fix_fake_image[:64],
                os.path.join(opt.save_path, "%s.png" % epoch),
                normalize=True,
            )

            torch.save(netd.state_dict(), os.path.join(checkpoint_dir, 'netd_{0}.pth'.format(epoch)))
            torch.save(netg.state_dict(), os.path.join(checkpoint_dir, 'netg_{0}.pth'.format(epoch)))


In [None]:
@torch.no_grad()
def generate(**kwargs):
    """Generate candidate images and save the ones the discriminator scores highest.

    Keyword arguments override the attribute of the same name on the
    module-level ``opt``. ``opt.gen_search_num`` noise vectors are pushed
    through the generator, the discriminator scores the results, and the
    ``opt.gen_num`` best-scored images are written as one grid to
    ``opt.gen_img``.

    Checkpoints default to ``imgs2/netd_399.pth`` and ``imgs2/netg_399.pth``.
    Passing ``netd_path=...`` / ``netg_path=...`` to this call selects other
    files; values left on ``opt`` by an earlier ``train`` call are ignored so
    that the default behaviour is unchanged.

    Args:
        **kwargs: configuration overrides, e.g. ``generate(gen_num=16)``.
    """
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    device = torch.device("cuda") if opt.gpu else torch.device("cpu")

    # Both networks are used for inference only.
    netg, netd = NetG(opt).eval(), NetD(opt).eval()
    map_location = lambda storage, loc: storage

    netd_path = kwargs.get('netd_path') or 'imgs2/netd_399.pth'
    netg_path = kwargs.get('netg_path') or 'imgs2/netg_399.pth'

    # strict=False tolerates missing/unexpected keys (kept from the original call).
    # NOTE: torch.load unpickles the file - only load trusted checkpoints.
    netd.load_state_dict(torch.load(netd_path, map_location=map_location), strict=False)
    netg.load_state_dict(torch.load(netg_path, map_location=map_location), strict=False)
    netd.to(device)
    netg.to(device)

    # normal_ overwrites the tensor in place with N(gen_mean, gen_std) samples.
    noise = torch.randn(opt.gen_search_num, opt.nz, 1, 1).normal_(opt.gen_mean, opt.gen_std).to(device)

    fake_image = netg(noise)
    score = netd(fake_image)

    # Never ask topk for more entries than there are candidates.
    keep = min(opt.gen_num, opt.gen_search_num)
    indexs = score.topk(keep)[1]

    # Indexing with the index tensor gathers the selected images in score order.
    result = fake_image[indexs]

    # torchvision renamed make_grid's `range` keyword to `value_range` and
    # later dropped the old name; try the current keyword first and fall back
    # for old installations that only know `range`.
    try:
        tv.utils.save_image(result, opt.gen_img, normalize=True, value_range=(-1, 1))
    except TypeError:
        tv.utils.save_image(result, opt.gen_img, normalize=True, range=(-1, 1))

def main():
    """Run training with the current ``opt`` settings, then run image generation."""
    train()
    generate()

# Entry point: when this code runs as __main__, start training followed by generation.
if __name__ == '__main__':
    main()

7it [00:56,  7.99s/it]