## Same DCGAN premise as before; however, the "game" being played is changed.

In [None]:
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import pandas as pd
import statistics
# Fix the seed so repeated runs are reproducible.
# Swap in the commented-out line below to draw a fresh seed instead.
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)

# Seed Python's and PyTorch's random number generators with the same value.
for _seed_fn in (random.seed, torch.manual_seed):
    _seed_fn(manualSeed)

# Needed for reproducible results
torch.use_deterministic_algorithms(True)

In [None]:
# Root directory for dataset.
# The default is the original author's local path; set the DCGAN_DATAROOT
# environment variable to point at the data on another machine without
# editing this cell.
dataroot = os.environ.get(
    "DCGAN_DATAROOT",
    "/Users/senadkokic/Desktop/F2023/STAT940/Final Project/data",
)

# Number of workers for dataloader
workers = 2

# Batch size during training
batch_size = 128

# Spatial size of training images. All images will be resized to this
#   size using a transform.
image_size = 64

# Number of channels in the training images. For color images this is 3
nc = 3

# Size of z latent vector (i.e. size of generator input)
nz = 100

# Size of feature maps in generator
ngf = 64

# Size of feature maps in discriminator
ndf = 64

# Number of training epochs
num_epochs = 20

# Learning rate for optimizers
lr = 0.0002

# Beta1 hyperparameter for Adam optimizers
beta1 = 0.5

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 2

In [None]:
# Select the compute device used by every later cell.
# Apple's MPS backend is preferred; when it is missing we fall back to CUDA
# (if present) or the CPU so that ``device`` is always defined.
if torch.backends.mps.is_available():
    device = torch.device("mps")
    # Small smoke test: allocate a tensor on the MPS device and show it.
    x = torch.ones(1, device=device)
    print (x)
else:
    print ("MPS device not found.")
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Falling back to device:", device)

In [None]:
# Preprocessing applied to every image: resize, centre-crop to a square of
# ``image_size``, convert to a tensor, then normalise each of the three
# channels with mean 0.5 and std 0.5.
image_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Image-folder dataset rooted at ``dataroot``
dataset = dset.ImageFolder(root=dataroot, transform=image_transform)

# Shuffled, batched loader over the dataset
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
)

# Show up to 64 images from the first batch as a sanity check
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
preview_grid = vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True)
plt.imshow(np.transpose(preview_grid.cpu(), (1, 2, 0)))

In [None]:
# custom weights initialization called on ``netG`` and ``netD``
def weights_init(m):
    """Re-initialise a single module in place based on its class name.

    Modules whose class name contains ``Conv`` get weights drawn from
    N(0, 0.02). Modules whose class name contains ``BatchNorm`` get weights
    drawn from N(1, 0.02) and a zero bias. Any other module is left
    untouched. Intended to be passed to ``nn.Module.apply``.
    """
    layer_kind = type(m).__name__

    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return

    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, 0)

In [None]:
# Generator Code

class Generator(nn.Module):
    """DCGAN generator: maps a latent tensor to an ``nc``-channel image.

    The network is a stack of transposed convolutions. Following the state
    sizes of the layer stack, an ``nz x 1 x 1`` input is expanded to
    ``(ngf*8) x 4 x 4`` and then doubled in spatial size four times, ending
    at ``nc x 64 x 64`` after a ``Tanh``.

    Args:
        ngpu: stored on the instance as ``self.ngpu``; not otherwise used
            inside this class.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu

        # input is Z, going into a convolution -> ``(ngf*8) x 4 x 4``
        stages = [
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
        ]

        # Three upsampling stages, each halving the channel count:
        # ``(ngf*4) x 8 x 8`` -> ``(ngf*2) x 16 x 16`` -> ``(ngf) x 32 x 32``
        for mult_in, mult_out in ((8, 4), (4, 2), (2, 1)):
            stages.append(
                nn.ConvTranspose2d(ngf * mult_in, ngf * mult_out, 4, 2, 1, bias=False)
            )
            stages.append(nn.BatchNorm2d(ngf * mult_out))
            stages.append(nn.ReLU(True))

        # Output stage -> ``(nc) x 64 x 64``
        stages.append(nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))
        stages.append(nn.Tanh())

        self.main = nn.Sequential(*stages)

    def forward(self, input):
        """Run ``input`` through the layer stack and return the result."""
        generated = self.main(input)
        return generated

In [None]:
# Create the generator
netG = Generator(ngpu).to(device)

# Handle multi-GPU.
# ``nn.DataParallel`` splits batches across CUDA devices, so it is only
# enabled when running on CUDA with more than one GPU requested. (The
# previous check for ``'mps'`` wrapped the model on Apple hardware, where
# there are no multiple GPUs to spread the work over.)
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

# Apply the ``weights_init`` function to randomly initialize weights
netG.apply(weights_init)

# Print the model
print(netG)

In [None]:
class Discriminator(nn.Module):
    """DCGAN discriminator that also exposes its intermediate feature maps.

    Four stride-2 convolution stages (``layer1`` .. ``layer4``) are followed
    by ``final_layer``, a convolution plus ``Sigmoid``. ``forward`` returns
    the final output together with the output of each of the four stages.

    Args:
        ngpu: stored on the instance as ``self.ngpu``; not otherwise used
            inside this class.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu

        # Downsampling stages; the first one has no batch normalisation.
        self.layer1 = self._down_block(nc, ndf, batch_norm=False)
        self.layer2 = self._down_block(ndf, ndf * 2)
        self.layer3 = self._down_block(ndf * 2, ndf * 4)
        self.layer4 = self._down_block(ndf * 4, ndf * 8)

        # Single-channel output squashed into (0, 1)
        self.final_layer = nn.Sequential(
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    @staticmethod
    def _down_block(in_channels, out_channels, batch_norm=True):
        """Build conv(4, stride 2, pad 1) -> [BatchNorm] -> LeakyReLU(0.2)."""
        parts = [nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False)]
        if batch_norm:
            parts.append(nn.BatchNorm2d(out_channels))
        parts.append(nn.LeakyReLU(0.2, inplace=True))
        return nn.Sequential(*parts)

    def forward(self, input):
        """Return ``(final_output, output1, output2, output3, output4)``.

        ``output1`` .. ``output4`` are the activations after ``layer1`` ..
        ``layer4``; callers pick whichever intermediate outputs they need.
        """
        feats = []
        hidden = input
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            hidden = stage(hidden)
            feats.append(hidden)

        score = self.final_layer(hidden)
        return (score, *feats)

# adding the feedback loss layer
def feedback_loss_layer(disc_intermediate_output_fake, disc_intermediate_output_real):
    """Return the mean absolute (L1) difference between two feature tensors.

    Args:
        disc_intermediate_output_fake: intermediate discriminator output for
            generated images.
        disc_intermediate_output_real: intermediate discriminator output for
            real images.

    Returns:
        The value of ``F.l1_loss`` with mean reduction.
    """
    l1_distance = F.l1_loss(
        input=disc_intermediate_output_fake,
        target=disc_intermediate_output_real,
        reduction="mean",
    )
    return l1_distance

In [None]:
# Create the Discriminator
netD = Discriminator(ngpu).to(device)

# Handle multi-GPU.
# ``nn.DataParallel`` splits batches across CUDA devices, so it is only
# enabled when running on CUDA with more than one GPU requested. (The
# previous check for ``'mps'`` wrapped the model on Apple hardware, where
# there are no multiple GPUs to spread the work over.)
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))

# Apply the ``weights_init`` function to randomly initialize weights
netD.apply(weights_init)

# Print the model
print(netD)

In [None]:
# Binary cross-entropy is the adversarial loss for both networks
criterion = nn.BCELoss()

# Fixed batch of 64 latent vectors, reused during training to visualise
# how the generator's output evolves
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Label convention used during training: real -> 1, fake -> 0
real_label = 1.
fake_label = 0.

# Both networks use Adam with the same learning rate and betas
adam_settings = dict(lr=lr, betas=(beta1, 0.999))
optimizerD = optim.Adam(netD.parameters(), **adam_settings)
optimizerG = optim.Adam(netG.parameters(), **adam_settings)

In [None]:
# Ask PyTorch to release cached CUDA memory before training starts.
# NOTE(review): the device selected earlier in this file is MPS, so this
# CUDA-specific call presumably has no effect there -- confirm.
torch.cuda.empty_cache()

In [None]:
# Set up for anomaly detection (for debugging purposes)
# NOTE(review): this is a debugging aid and presumably adds overhead to every
# backward pass; consider disabling it for full training runs.
torch.autograd.set_detect_anomaly(True)

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

# lambdas for strength of mechanism
lambda_fb = 0.6
lambda_reward = 0.25
lambda_punish = 0.1

# thresholds for feedback mechanism: they decide which discriminator layer's
# features are matched by the generator's feedback loss in a given epoch
epoch_threshold_1 = 3
epoch_threshold_2 = 10
epoch_threshold_3 = 17

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    for i, (real_data, _) in enumerate(dataloader):
        ############################
        # Update D network
        ###########################
        netD.zero_grad()
        real_cpu = real_data.to(device)
        b_size = real_cpu.size(0)

        # Phase 1: Punish (every epoch except each fourth one)
        if (epoch+1) % 4 != 0:
            # Process real data
            label_real = torch.full((b_size,), real_label, dtype=torch.float, device=device)
            output_real,_,_,_,_ = netD(real_cpu)
            output_real = output_real.view(-1)  # Flatten the output for loss calculation
            errD_real = criterion(output_real, label_real)
            D_x = output_real.mean().item()

            # Process fake data (detached so the D update does not backprop into G)
            noise = torch.randn(b_size, nz, 1, 1, device=device)
            fake = netG(noise)
            label_fake = torch.full((b_size,), fake_label, dtype=torch.float, device=device)
            output_fake,_,_,_,_ = netD(fake.detach())
            output_fake = output_fake.view(-1)  # Flatten the output for loss calculation
            errD_fake = criterion(output_fake, label_fake)
            D_G_z1 = output_fake.mean().item()

            # Apply punishment: lambda_punish times the fraction of fake
            # samples whose thresholded prediction equals the fake label.
            # NOTE(review): the threshold/equality comparison is not
            # differentiable, so this term shifts the reported loss value but
            # contributes no gradient to D -- confirm this is intended.
            punish = lambda_punish * torch.mean(((output_fake > 0.5).float() == label_fake).float())
            errD = errD_real + errD_fake + punish

        # Phase 2: Reward (every fourth epoch)
        else:
            # Process a subset of real data
            split_index = int(b_size * 0.75)  # Adjust the split index as needed
            real_data_subset = real_cpu[:split_index]
            unused_real_data = real_cpu[split_index:]
            label_subset = torch.full((split_index,), real_label, dtype=torch.float, device=device)

            # Forward pass and compute loss for the subset of real data
            output_subset,_,_,_,_ = netD(real_data_subset)
            output_subset = output_subset.view(-1)
            errD_real_subset = criterion(output_subset, label_subset)
            D_x = output_subset.mean().item()

            # Generate fake data. It is detached, as in the punish phase, so
            # that ``errD.backward()`` does not needlessly backpropagate
            # through the generator (G's gradients are reset before its own
            # update anyway).
            noise = torch.randn(b_size - split_index, nz, 1, 1, device=device)
            fake_data = netG(noise).detach()

            # Combine and shuffle unused real data with fake data
            combined_data = torch.cat((unused_real_data, fake_data), dim=0)
            combined_labels = torch.cat((torch.ones(unused_real_data.size(0), device=device), torch.zeros(fake_data.size(0), device=device)), dim=0)
            indices = torch.randperm(combined_data.size(0))
            combined_data = combined_data[indices]
            combined_labels = combined_labels[indices]

            # Forward pass and compute loss for combined data
            output_combined,_,_,_,_ = netD(combined_data)
            output_combined = output_combined.view(-1)
            errD_combined = criterion(output_combined, combined_labels)
            # Note: in this phase the statistic is the mean D output over the
            # mixed real+fake batch, not over fake samples only.
            D_G_z1 = output_combined.mean().item()

            # Apply reward: lambda_reward times the fraction of the combined
            # batch whose thresholded prediction equals its label.
            # NOTE(review): as with the punishment term, this is not
            # differentiable and only changes the reported loss value.
            reward = lambda_reward * torch.mean(((output_combined > 0.5).float() == combined_labels).float())
            errD = errD_real_subset + errD_combined - reward

        # Update D
        errD.backward()
        optimizerD.step()


        ############################
        # Update G network
        ###########################
        netG.zero_grad()

        # Generate batch of latent vectors and fake images
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)

        # Create label tensor filled with real_label
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)


        # Since we just updated D, perform another forward pass of all-fake batch through D
        output, interm_output1, interm_output2, interm_output3, interm_output4 = netD(fake)
        output = output.view(-1)  # Flatten the output for loss calculation

        # Forward pass real batch through D to get real intermediate outputs.
        # These are fixed targets for the generator, so no autograd graph is
        # needed for them; gradients still reach G through the fake features.
        with torch.no_grad():
            _, interm_output_real1, interm_output_real2, interm_output_real3, interm_output_real4 = netD(real_cpu)

        # Select feedback loss based on current epoch: later epochs match
        # features from later discriminator layers
        if (epoch+1) < epoch_threshold_1:
            fb_loss = feedback_loss_layer(interm_output1, interm_output_real1)
        elif (epoch+1) < epoch_threshold_2:
            fb_loss = feedback_loss_layer(interm_output2, interm_output_real2)
        elif (epoch+1) < epoch_threshold_3:
            fb_loss = feedback_loss_layer(interm_output3, interm_output_real3)
        else:
            fb_loss = feedback_loss_layer(interm_output4, interm_output_real4)

        # Calculate traditional GAN loss
        errG = criterion(output, label)

        # Combine losses
        total_errG = errG + lambda_fb * fb_loss

        # Calculate gradients for G and update G
        D_G_z2 = output.mean().item()
        total_errG.backward()
        optimizerG.step()

        # Output training stats
        if i % 100 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f\tfb_loss: %.4f'
                  % (epoch+1, num_epochs, i, len(dataloader),
                     errD.item(), total_errG.item(), D_x, D_G_z1, D_G_z2, fb_loss.item()))

        # Save Losses for plotting later
        G_losses.append(total_errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1

In [None]:
# Plot the per-iteration generator and discriminator losses on one figure
plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
for curve_label, curve in (("G", G_losses), ("D", D_losses)):
    plt.plot(curve, label=curve_label, linewidth=0.2)
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
# Report the mean generator loss, then the mean discriminator loss
for _loss_history in (G_losses, D_losses):
    print(statistics.mean(_loss_history))

In [None]:
# Animate the image grids saved during training (one frame per saved grid)
fig = plt.figure(figsize=(8, 8))
plt.axis("off")
ims = [
    [plt.imshow(np.transpose(grid, (1, 2, 0)), animated=True)]
    for grid in img_list
]
ani = animation.ArtistAnimation(
    fig,
    ims,
    interval=1000,
    repeat_delay=1000,
    blit=True,
)

# Render the animation inline as JavaScript/HTML
HTML(ani.to_jshtml())

In [None]:
# Side-by-side comparison: a fresh batch of real images next to the last
# grid of generated images saved during training
real_batch = next(iter(dataloader))
real_grid = vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True)

plt.figure(figsize=(15, 15))

# Left panel: real images
plt.subplot(1, 2, 1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(real_grid.cpu(), (1, 2, 0)))

# Right panel: fake images from the last epoch
plt.subplot(1, 2, 2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1], (1, 2, 0)))

plt.show()