# DCGAN

## code reference

The overall structure of the DCGAN model is based on the code at the following link, with many changes to the number of deconvolution layers and to the parameter settings.

https://github.com/HuiiJi/GAN?tab=readme-ov-file 

## Import some libraries

In [1]:
import torch
import torch.nn as nn
import torchvision
import os
from torch.utils.data import DataLoader
from tqdm import tqdm

## Parameter configuration

The code defines a class called Config that contains settings for some configuration parameters.

These parameters include the path to save the results, the model path of the discriminator and generator, the image path and the size of the image, the batch size, the maximum number of training rounds, the noise vector dimension, and the number of feature channels.

In addition, the code checks if the results folder and snapshot folder exist, and creates them if they do not.

In [2]:
class Config:
    """Central place for the script's hyper-parameters and file locations.

    Every setting is a plain class attribute, so it can be read from the
    class itself or from an instance.
    """

    # --- Model / training hyper-parameters ---
    img_size = 96        # side length used by the Resize and CenterCrop transforms
    batch_size = 256     # number of samples per batch
    max_epoch = 300      # number of passes over the dataset
    noise_dim = 100      # channel count of the generator's input noise
    feats_channel = 64   # base feature-map width used by both networks

    # --- File locations ---
    img_path = 'painting/'                     # root folder of the training images
    result_save_path = 'results/'              # where sample images are written
    g_net_path = 'DCGAN_snapshots/gnet.pth'    # generator weights
    d_net_path = 'DCGAN_snapshots/dnet.pth'    # discriminator weights

opt = Config()

# Make sure the output folders exist before training starts.
# The folders are derived from the configured paths (instead of being
# hard-coded a second time), so editing Config is enough to relocate the
# outputs.  os.makedirs(..., exist_ok=True) also creates missing parent
# folders and does not fail when the folder is already there.
for _out_dir in (
    opt.result_save_path,
    os.path.dirname(opt.d_net_path),
    os.path.dirname(opt.g_net_path),
):
    # An empty dirname means "current directory": nothing to create.
    if _out_dir:
        os.makedirs(_out_dir, exist_ok=True)

## Generator design

A generator is defined whose constructor __init__() contains a series of layers for the generator network structure, including a transposed convolution layer, a batch normalization layer, and an activation function layer.

The generator receives the input noise vector, passes it through a series of transposed convolution operations, and finally outputs a synthetic image similar to the target images.

The forward() function is used to perform the forward propagation operation of the generator, returning the generated image.

In [3]:
class Gnet(nn.Module):
    """DCGAN generator built from a stack of transposed convolutions.

    For a transposed convolution the spatial size evolves as
    ``out = (in - 1) * stride + kernel - 2 * padding``, so a ``1 x 1`` input
    grows ``1 -> 4 -> 8 -> 16 -> 32 -> 96`` and the network maps noise of
    shape ``(n, noise_dim, 1, 1)`` to images of shape ``(n, 3, 96, 96)``.

    Args:
        opt: configuration object; ``opt.noise_dim`` is the channel count of
            the input noise and ``opt.feats_channel`` the base feature width.
    """

    def __init__(self, opt):
        super().__init__()
        self.feats = opt.feats_channel
        width = self.feats

        def up_block(c_in, c_out, stride, padding):
            # Transposed convolution (kernel 4, no bias) -> batch norm -> ReLU.
            return [
                nn.ConvTranspose2d(in_channels=c_in, out_channels=c_out, kernel_size=4,
                                   stride=stride, padding=padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        stages = []
        # (n, noise_dim, 1, 1) -> (n, 8 * feats, 4, 4):  (1 - 1) * 1 + 4 - 0 = 4
        stages += up_block(opt.noise_dim, width * 8, stride=1, padding=0)
        # -> (n, 4 * feats, 8, 8):                      (4 - 1) * 2 + 4 - 2 = 8
        stages += up_block(width * 8, width * 4, stride=2, padding=1)
        # -> (n, 2 * feats, 16, 16):                    (8 - 1) * 2 + 4 - 2 = 16
        stages += up_block(width * 4, width * 2, stride=2, padding=1)
        # -> (n, feats, 32, 32):                        (16 - 1) * 2 + 4 - 2 = 32
        stages += up_block(width * 2, width, stride=2, padding=1)
        # -> (n, 3, 96, 96):                            (32 - 1) * 3 + 5 - 2 = 96
        stages.append(
            nn.ConvTranspose2d(in_channels=width, out_channels=3, kernel_size=5,
                               stride=3, padding=1, bias=False)
        )
        # Tanh keeps every output value inside [-1, 1].
        stages.append(nn.Tanh())

        self.generate = nn.Sequential(*stages)

    def forward(self, x):
        """Run the noise batch ``x`` through the network and return the images."""
        generated = self.generate(x)
        return generated

## Discriminator design

A simple convolution discriminator is defined.

The discriminator takes an image as input, passes through a series of convolution and activation function layers, and finally outputs a scalar value between 0 and 1 that represents the probability that the input image is a real image.

In [4]:
class Dnet(nn.Module):
    """DCGAN discriminator built from a stack of strided convolutions.

    For a convolution the spatial size evolves as
    ``out = floor((in - kernel + 2 * padding) / stride) + 1``, so a
    ``96 x 96`` input shrinks ``96 -> 32 -> 16 -> 8 -> 4 -> 1``.  The final
    sigmoid turns the single remaining value per image into a number in
    ``[0, 1]``.

    Args:
        opt: configuration object; ``opt.feats_channel`` is the base feature
            width.
    """

    def __init__(self, opt):
        super().__init__()
        self.feats = opt.feats_channel
        width = self.feats

        def down_block(c_in, c_out):
            # Strided convolution (kernel 4, stride 2, no bias) halves the
            # spatial size; it is followed by batch norm and LeakyReLU(0.2).
            return [
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=4,
                          stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, True),
            ]

        layers = [
            # (n, 3, 96, 96) -> (n, feats, 32, 32):  (96 - 5 + 2) / 3 + 1 = 32
            # The first stage has no batch norm.
            nn.Conv2d(in_channels=3, out_channels=width, kernel_size=5,
                      stride=3, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        ]
        # -> (n, 2 * feats, 16, 16)
        layers += down_block(width, width * 2)
        # -> (n, 4 * feats, 8, 8)
        layers += down_block(width * 2, width * 4)
        # -> (n, 8 * feats, 4, 4)
        layers += down_block(width * 4, width * 8)
        # -> (n, 1, 1, 1):  (4 - 4 + 0) / 1 + 1 = 1
        layers.append(
            nn.Conv2d(in_channels=width * 8, out_channels=1, kernel_size=4,
                      stride=1, padding=0, bias=True)
        )
        layers.append(nn.Sigmoid())

        self.discrim = nn.Sequential(*layers)

    def forward(self, x):
        """Score the image batch ``x``; the result is flattened to one dimension."""
        scores = self.discrim(x)
        return scores.view(-1)

# Build both networks from the shared configuration
# (generator first, then discriminator).
g_net = Gnet(opt)
d_net = Dnet(opt)

## Data preprocessing

1. A series of image preprocessing operations are defined and a data loader is created.

2. Set up the device (CPU or CUDA) and move the generator and discriminator to the device.

3. Optimizers, loss functions, and labels are defined.

4. The random noise for training is generated.

In [5]:
# Pre-processing applied to every training image.
transforms = torchvision.transforms.Compose([
    # Resize so that the shorter side equals img_size
    torchvision.transforms.Resize(opt.img_size),
    # Crop the central img_size x img_size square
    torchvision.transforms.CenterCrop(opt.img_size),
    # Convert the image to a float tensor with values in [0, 1]
    torchvision.transforms.ToTensor(),
    # Rescale every channel from [0, 1] to [-1, 1].  The generator ends with
    # Tanh, so its outputs live in [-1, 1]; real images must use the same
    # range, otherwise the discriminator can separate real from generated
    # images by value range alone.
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

dataset = torchvision.datasets.ImageFolder(root=opt.img_path, transform=transforms)

dataloader = DataLoader(
    dataset,
    batch_size=opt.batch_size,
    # Reshuffle every epoch so that batches are not always composed of the
    # same images in the same (folder-sorted) order.
    shuffle=True,
    num_workers=0,
    # Drop the last incomplete batch: the label tensors used by the training
    # loop have exactly batch_size entries.
    drop_last=True,
)

# Pick the compute device: the first CUDA GPU when one is available, else the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Move both networks onto that device
g_net.to(device)
d_net.to(device)

# One Adam optimizer per network, both with the same settings
_adam_settings = dict(lr=2e-4, betas=(0.5, 0.999))
optimize_g = torch.optim.Adam(g_net.parameters(), **_adam_settings)
optimize_d = torch.optim.Adam(d_net.parameters(), **_adam_settings)
# optimize_d = torch.optim.SGD(d_net.parameters(), lr= 2e-4)

# Binary cross-entropy between the discriminator's probabilities and the labels
criterions = nn.BCELoss().to(device)

# Target labels, one per sample of a batch: 1 = real image, 0 = generated image
true_labels = torch.ones(opt.batch_size, device=device)
fake_labels = torch.zeros(opt.batch_size, device=device)

# Standard normal N(0, 1) noise of shape (batch_size, noise_dim, 1, 1).
# `noises` is the generator's training input; `test_noises` is a second,
# independent draw.
_noise_shape = (opt.batch_size, opt.noise_dim, 1, 1)
noises = torch.randn(*_noise_shape).to(device)

test_noises = torch.randn(*_noise_shape).to(device)

## Start training

Try to load previously saved generator and discriminator weight files (if loading fails, training starts from scratch), and then execute the training process.

Within each training cycle, the training is done by looping through the images in the data loader. With the current settings, the discriminator and the generator are each updated once per batch.

During the training process, the losses of discriminator and generator are calculated and optimized, and the model parameters are updated.

At the end of each training cycle, the generative network generates a batch of images and saves a portion of them as a result. At the same time, the weights of the model are saved and the loss values of the discriminator and generator for the current training cycle are printed.

In [None]:
# Resume from previously saved weights when they are available.
# map_location=device lets a checkpoint that was saved on a GPU be loaded on
# a CPU-only machine (and the other way round).
try:
    # Read both files before touching either network, so that a missing or
    # unreadable file cannot leave only one of the two networks restored.
    g_state = torch.load(opt.g_net_path, map_location=device)
    d_state = torch.load(opt.d_net_path, map_location=device)
    g_net.load_state_dict(g_state)
    d_net.load_state_dict(d_state)
    print('Load successfully. Continue training')
except Exception as load_error:
    # `except Exception` (instead of a bare `except:`) keeps the best-effort
    # fallback but no longer swallows KeyboardInterrupt / SystemExit.
    print('Load failed, retrain')
    print('Reason:', load_error)

# How often (in batches) each network is updated.  With both set to 1 the
# discriminator and the generator are each updated on every batch; raise
# G_UPDATE_EVERY to update the generator less often than the discriminator.
D_UPDATE_EVERY = 1
G_UPDATE_EVERY = 1

for epoch in range(opt.max_epoch):

    # Wrap the dataloader itself (not enumerate(...)) so that tqdm knows the
    # number of batches and can show a real progress bar.
    for iteration, (img, _) in enumerate(tqdm(dataloader)):
        real_img = img.to(device)

        # ---- Discriminator update ----
        if iteration % D_UPDATE_EVERY == 0:

            optimize_d.zero_grad()

            # Real images: the discriminator should output 1
            output = d_net(real_img)
            d_real_loss = criterions(output, true_labels)

            # Generated images: the generator only supplies inputs here, so
            # run it without building an autograd graph.
            with torch.no_grad():
                fake_image = g_net(noises)

            # Generated images: the discriminator should output 0
            output = d_net(fake_image)
            d_fake_loss = criterions(output, fake_labels)

            # Average of the two loss terms
            d_loss = (d_fake_loss + d_real_loss) / 2

            d_loss.backward()
            optimize_d.step()

        # ---- Generator update ----
        if iteration % G_UPDATE_EVERY == 0:
            optimize_g.zero_grad()
            # Draw fresh noise in place (the tensor stays on `device`)
            noises.copy_(torch.randn(opt.batch_size, opt.noise_dim, 1, 1))

            # The generator is rewarded when the discriminator labels its
            # images as real, hence the comparison against true_labels.
            fake_image = g_net(noises)
            output = d_net(fake_image)
            g_loss = criterions(output, true_labels)

            g_loss.backward()
            optimize_g.step()

    # Render the fixed test noise (the same tensor every epoch); no gradients
    # are needed for sampling.
    with torch.no_grad():
        vid_fake_image = g_net(test_noises)
    # Save the first 16 images as <result_save_path>/<epoch>.png
    torchvision.utils.save_image(
        vid_fake_image[:16],
        os.path.join(opt.result_save_path, "%s.png" % epoch),
        normalize=True,
    )
    # Overwrite the snapshots with the current weights
    torch.save(d_net.state_dict(),  opt.d_net_path)
    torch.save(g_net.state_dict(),  opt.g_net_path)
    # Report the losses of the last batch of this epoch
    print('epoch:', epoch, '---D-loss:---', d_loss.item(), '---G-loss:---', g_loss.item())