In [1]:
import os
import time
from AC_GAN import *
import numpy as np
from torchvision.utils import save_image
import torch.optim as optim


In [2]:
EPOCHS = 50
BATCH_SIZE = 100
N_WORKERS = 4
HIDDEN_DIM = 100
NOISE_STD = 0.1
learning_rates = [0.0001, 0.0002, 0.0003]
betas = (0.5, 0.999)
N_CLASS = 10

In [3]:
import torchvision.transforms as transforms
transform_train = transforms.Compose([
    transforms.ToTensor(),  
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])  
])

import torchvision.datasets as dset
from torch.utils.data import DataLoader

DATA_ROOT = "./data"
train_set = dset.CIFAR10(
    root=DATA_ROOT,
    train=True,
    download=True,
    transform=transform_train
)
train_loader = DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4
)

Files already downloaded and verified


In [5]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cuda':
    print("GPU USED")
else:
    print("CPU USED")

# define the model
generator = Generator(NOISE_STD).to(device)
discriminator = Discriminator().to(device)

# define the optimizers
optimizer_gen = optim.Adam(generator.parameters(), lr=0.0002, betas=betas)
optimizer_dis = optim.Adam(discriminator.parameters(), lr=0.0002, betas=betas)

# define the loss function
dis_loss = nn.CrossEntropyLoss().to(device)
cla_loss = nn.CrossEntropyLoss().to(device)

CPU USED


In [7]:
start = time.time()
print("==> Training starts!")
print("="*50)
lossG_history = []
lossD_history = []
for i in range(EPOCHS):
    epoch_start = time.time()
    print("Epoch %d:" %i)
    
    lossG_record = []
    lossD_record = []
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.float().to(device)
        targets = targets.type(torch.LongTensor).to(device)
        batch_size = inputs.shape[0]
        real_labels = torch.ones((batch_size, 1), device=device)
        fake_labels = torch.zeros((batch_size, 1), device=device)
        # Train Generator
        nos = torch.randn(batch_size, HIDDEN_DIM).to(device)
        class_to_gen = torch.randint(0, N_CLASS, (batch_size,)).to(device)
        one_hot_labels = torch.eye(N_CLASS)[class_to_gen].to(device)
        nw = torch.cat((nos, one_hot_labels), dim=1)
        class_to_gen_torch = class_to_gen.type(torch.LongTensor).to(device)
        
        img = generator(nw)
        fake_out = discriminator(img)
        lossG1 = dis_loss(fake_out[:, 0].view(-1, 1), real_labels)
        lossG2 = cla_loss(fake_out[:, 1:], class_to_gen_torch)
        avg_lossG = (lossG1 + lossG2) / 2
        
        # zero the gradients and update the weights
        optimizer_gen.zero_grad()
        avg_lossG.backward()
        optimizer_gen.step()
        
        #record the losses
        lossG_record.append(avg_lossG.cpu().detach().numpy())

        # Train Discriminator
        real_out = discriminator(inputs)
        lossD1 = dis_loss(real_out[:, 0].view(-1, 1), fake_labels)
        lossD2 = cla_loss(real_out[:, 1:], class_to_gen_torch)

        fake_out2 = discriminator(img.detach())
        lossD3 = dis_loss(fake_out2[:, 0].view(-1, 1), fake_labels)
        lossD4 = cla_loss(fake_out2[:, 1:], class_to_gen_torch)

        # zero the gradients and update the weights
        avg_lossD = (lossD1 + lossD2 + lossD3 + lossD4) / 4
        optimizer_dis.zero_grad()
        avg_lossD.backward()
        optimizer_dis.step()

        #record the losses
        lossD_record.append(avg_lossD.cpu().detach().numpy())

        if batch_idx % 500 == 0:
            batches_done = i * len(train_loader) + batch_idx
            generated_images = generator(torch.randn(BATCH_SIZE, HIDDEN_DIM+N_CLASS).to(device)).data[:50]
            save_image(generated_images, os.path.join('./images', "%d.png" % batches_done), nrow=5, normalize=True)
            
    avg_loss_g = np.sum(np.asarray(lossG_record)) / len(lossG_record)
    avg_loss_d = np.sum(np.asarray(lossD_record)) / len(lossD_record)
    print("Generator loss: %.4f, Discriminator loss: %.4f"%(avg_loss_g, avg_loss_d))

    lossG_history.append(avg_loss_g)
    lossD_history.append(avg_loss_d)

    # save the model checkpoint
    torch.save(generator.state_dict(), os.path.join('./model', 'generator.pth'))
    torch.save(discriminator.state_dict(), os.path.join('./model', 'discriminator.pth'))
    print(f"Epoch finished in {time.time() - epoch_start:.2f}s")
    print("")

print("="*50)
print(f"==> Optimization finished in {time.time() - start:.2f}s!")

==> Training starts!
Epoch 0:
Generator loss: 0.7137, Discriminator loss: 0.9336
Epoch finished in 268.45s

Epoch 1:
Generator loss: 0.0830, Discriminator loss: 0.6173
Epoch finished in 263.42s

Epoch 2:
Generator loss: 0.0002, Discriminator loss: 0.5758
Epoch finished in 273.50s

Epoch 3:


KeyboardInterrupt: 