### Imports

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision 
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

### Generator Class

In [17]:
class Generator(nn.Module):
    """DCGAN generator: upsamples a latent vector into an image.

    The network is a stack of five transposed convolutions. The first one
    (kernel 4, stride 1, padding 0) turns a 1x1 spatial input into 4x4; each
    of the remaining four (kernel 4, stride 2, padding 1) doubles the spatial
    size, so an input of shape ``(N, latent_dim, 1, 1)`` produces an output of
    shape ``(N, channels, 64, 64)``. The final ``Tanh`` bounds values to
    ``[-1, 1]``.

    Args:
        latent_dim: Number of channels of the latent input tensor.
        channels: Number of channels of the generated image.
    """

    # NOTE: the parameter used to be misspelled ``laten_dim`` while the body
    # read ``latent_dim``; the class then only worked if a global named
    # ``latent_dim`` happened to exist in the kernel, and ignored its argument.
    def __init__(self, latent_dim, channels):
        super(Generator, self).__init__()

        self.latent_dim = latent_dim
        self.channels = channels

        self.model = nn.Sequential(
            # (N, latent_dim, 1, 1) -> (N, 512, 4, 4)
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # -> (N, 256, 8, 8)
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # -> (N, 128, 16, 16)
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # -> (N, 64, 32, 32)
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # -> (N, channels, 64, 64), squashed to [-1, 1]
            nn.ConvTranspose2d(64, channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        """Run the latent batch ``input`` through the generator stack."""
        return self.model(input)

### Discriminator Class

In [18]:
class Discriminator(nn.Module):
    """DCGAN discriminator: scores each input image with a value in (0, 1).

    Four stride-2 convolutions (kernel 4, padding 1) each halve the spatial
    size while widening the feature maps 64 -> 128 -> 256 -> 512; a final
    kernel-4, stride-1, padding-0 convolution followed by ``Sigmoid`` reduces
    the result to one channel. For 64x64 inputs that leaves exactly one score
    per image.

    Args:
        channels: Number of channels of the input images.
    """

    def __init__(self, channels):
        super().__init__()
        self.channels = channels

        slope = 0.2
        widths = (64, 128, 256, 512)

        # Entry stage: no batch norm directly on the image-facing convolution.
        stack = [
            nn.Conv2d(channels, widths[0], 4, 2, 1, bias=False),
            nn.LeakyReLU(slope, inplace=True),
        ]

        # Downsampling stages: conv -> batch norm -> leaky ReLU.
        for n_in, n_out in zip(widths, widths[1:]):
            stack += [
                nn.Conv2d(n_in, n_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(n_out),
                nn.LeakyReLU(slope, inplace=True),
            ]

        # Head: collapse to a single-channel map and squash to a probability.
        stack += [
            nn.Conv2d(widths[-1], 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]

        self.model = nn.Sequential(*stack)

    def forward(self, input):
        """Return the discriminator scores for ``input`` as a flat 1-D tensor."""
        scores = self.model(input)
        return scores.view(-1)

### TRAIN DCGAN (Deep Convolutional Generative Adversarial Network)

In [19]:
def train_dcgan(dataloader, num_epochs, latent_dim, lr, beta1, channels=3):
    """Train a DCGAN generator/discriminator pair with alternating updates.

    Each iteration first updates the discriminator on a real batch and on a
    detached fake batch, then updates the generator so that the discriminator
    labels its samples as real. Progress is printed every 50 iterations, and
    at the end of every epoch a grid generated from a fixed noise batch is
    written to ``fake_samples_epoch_{epoch}.png`` in the current working
    directory.

    Args:
        dataloader: Iterable of batches; ``data[0]`` of each batch is used as
            the real image tensor. Image size and channel count must match
            what ``Discriminator(channels)`` accepts (assumed 64x64 here, as
            in this notebook's preprocessing -- confirm for other datasets).
        num_epochs: Number of passes over ``dataloader``.
        latent_dim: Number of channels of the generator's latent input.
        lr: Learning rate for both Adam optimizers.
        beta1: First Adam beta for both optimizers (second is fixed at 0.999).
        channels: Image channel count for both networks. Defaults to 3, the
            value that was previously hard-coded.

    Returns:
        Tuple ``(netG, netD)`` of the trained generator and discriminator,
        left on whichever device was used for training.
    """
    # Use the GPU when one is available; everything below follows this device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    netG = Generator(latent_dim, channels=channels).to(device)
    netD = Discriminator(channels=channels).to(device)

    criterion = nn.BCELoss()

    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

    # Same latent batch every epoch so the saved grids are comparable.
    fixed_noise = torch.randn(64, latent_dim, 1, 1, device=device)

    for epoch in range(num_epochs):
        for i, data in enumerate(dataloader, 0):
            ############################
            # Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            netD.zero_grad()
            real = data[0].to(device)
            # Taken from the batch itself: the last batch may be smaller.
            batch_size = real.size(0)
            label = torch.full((batch_size,), 1, dtype=torch.float, device=device)

            # Real batch, target label 1.
            output = netD(real)
            errD_real = criterion(output, label)
            errD_real.backward()
            D_x = output.mean().item()

            # Fake batch, target label 0. detach() keeps this backward pass
            # from reaching the generator's parameters.
            noise = torch.randn(batch_size, latent_dim, 1, 1, device=device)
            fake = netG(noise)
            label.fill_(0.)
            output = netD(fake.detach())
            errD_fake = criterion(output, label)
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            # Gradients of both halves have accumulated; errD is for logging.
            errD = errD_real + errD_fake
            optimizerD.step()

            ############################
            # Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            # The generator is rewarded when D calls its samples real.
            label.fill_(1.)
            # Re-score with the just-updated discriminator, this time
            # without detach so gradients flow back into netG.
            output = netD(fake)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()
            optimizerG.step()

            if i % 50 == 0:
                print(f'[{epoch}/{num_epochs}][{i}/{len(dataloader)}] Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f} D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f}/{D_G_z2:.4f}')

        # Save generated images
        with torch.no_grad():
            fake = netG(fixed_noise).detach().cpu()
            torchvision.utils.save_image(fake, f'fake_samples_epoch_{epoch}.png', normalize=True)

    return netG, netD

### Main Execution

In [21]:
if __name__ == "__main__":
    # Hyperparameters
    seed = 42            # fixed so that Restart & Run All gives comparable runs
    batch_size = 128
    image_size = 64      # the Generator/Discriminator stacks are built for 64x64
    num_epochs = 5
    lr = 0.0002
    beta1 = 0.5
    latent_dim = 100

    # Seed PyTorch's global RNG before anything stochastic happens (weight
    # initialisation, latent noise sampling, DataLoader shuffling). Some GPU
    # kernels may still be nondeterministic.
    torch.manual_seed(seed)

    # Data loading and preprocessing
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        # Map each RGB channel from [0, 1] to [-1, 1], matching the
        # generator's Tanh output range.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Downloads CIFAR-10 into ../data on first run.
    dataset = torchvision.datasets.CIFAR10(root='../data', download=True, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

    # Train the model
    generator, discriminator = train_dcgan(dataloader, num_epochs, latent_dim, lr, beta1)

    # Save the trained models
    torch.save(generator.state_dict(), 'generator.pth')
    torch.save(discriminator.state_dict(), 'discriminator.pth')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../data/cifar-10-python.tar.gz


100.0%


Extracting ../data/cifar-10-python.tar.gz to ../data
[0/5][0/391] Loss_D: 1.4553 Loss_G: 2.0123 D(x): 0.4029 D(G(z)): 0.3968/0.1388


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc09a2cf160>
Traceback (most recent call last):
  File "/Users/ali/Documents/Project/Creative-Art-GAN/env/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
    self._shutdown_workers()
  File "/Users/ali/Documents/Project/Creative-Art-GAN/env/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1443, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/opt/anaconda3/lib/python3.8/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/opt/anaconda3/lib/python3.8/multiprocessing/popen_fork.py", line 44, in wait
    if not wait([self.sentinel], timeout):
  File "/opt/anaconda3/lib/python3.8/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "/opt/anaconda3/lib/python3.8/selectors.py", line 415, in select
    fd_event_list = self._selector.poll(timeout)
Keyboard

KeyboardInterrupt: 