# ACGAN: Auxiliary Classifier GAN
### This network is similar to the CGAN, but instead of passing the one-hot label vector to both the generator and the discriminator, we pass it only to the generator and give the discriminator a second task: classifying the digit class of each image it sees, real or generated. This auxiliary classification task is believed to help the network learn better. We follow this architecture, but use a kernel size of 4 so that the layers produce the required output shapes.

![ACGAN](images/ACGAN.png)

In [None]:
import torch 
from torch import nn 
import torch.optim as optim
import numpy as np 
import sys
from PIL import Image
from tensorflow.keras.datasets import mnist
import os 
import matplotlib.pyplot as plt
import math 
import torchvision.datasets  as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


In [None]:
class Generator(nn.Module):
    """ACGAN generator: maps (noise, one-hot class) to a 1x28x28 image.

    ``forward`` takes a tuple ``(noise_vector, one_hot_vector)``. The two
    tensors are concatenated along dim 1, so their widths must sum to 110
    (``train`` uses 100-d noise plus a 10-class one-hot vector).

    Spatial size through ``seq``: 7 -> 14 -> 28 -> 29 -> 28.
    Output values lie in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        self.linear_1 = nn.Linear(110, 7 * 7 * 128)
        self.seq = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 128, 4, padding=1, stride=2),  # 7 -> 14
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, padding=1, stride=2),   # 14 -> 28
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, padding=1, stride=1),    # 28 -> 29
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 4, padding=2, stride=1),     # 29 -> 28
            # Tanh (not Sigmoid) so generated pixels cover the same [-1, 1]
            # range as real MNIST images after Normalize((0.5,), (0.5,)).
            # A [0, 1] output lets the discriminator spot fakes by sign alone.
            nn.Tanh(),
        )

    def forward(self, inputs):
        """Return images of shape (N, 1, 28, 28) for ``inputs=(noise, one_hot)``."""
        noise_vector, one_hot_vector = inputs
        X = torch.cat((noise_vector, one_hot_vector), dim=1)
        X = self.linear_1(X)
        X = X.view(-1, 128, 7, 7)
        return self.seq(X)

In [None]:
class Discriminator(nn.Module):
    """ACGAN discriminator with a real/fake head and an auxiliary class head.

    Input: images of shape (N, 1, 28, 28).
    Output: a tuple ``(p_real, class_logits)`` where ``p_real`` has shape
    (N, 1) and holds Sigmoid probabilities (for ``BCELoss``), and
    ``class_logits`` has shape (N, 10) and holds raw logits (for
    ``CrossEntropyLoss``).
    """

    def __init__(self):
        super().__init__()
        # Spatial size: 28 -> 14 -> 7 -> 3 -> 2 (kernel 4, padding 1).
        # Each LeakyReLU follows a conv, so no activation touches the raw
        # pixels and the last conv is nonlinear before the linear heads.
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, 4, padding=1, stride=2),
            nn.LeakyReLU(.2),
            nn.Conv2d(32, 64, 4, padding=1, stride=2),
            nn.LeakyReLU(.2),
            nn.Conv2d(64, 128, 4, padding=1, stride=2),
            nn.LeakyReLU(.2),
            nn.Conv2d(128, 256, 4, padding=1, stride=1),
            nn.LeakyReLU(.2),
            nn.Flatten(),
        )
        # Real/fake head: probability that the input is a real image.
        self.linear_out_1 = nn.Linear(2 * 2 * 256, 1)
        self.out_1 = nn.Sigmoid()

        # Auxiliary classifier head: logits over the 10 digit classes.
        self.linear_out_2 = nn.Linear(2 * 2 * 256, 10)

    def forward(self, inputs):
        """Return ``(p_real, class_logits)`` for a batch of images."""
        X = self.seq(inputs)
        output_1 = self.out_1(self.linear_out_1(X))
        output_2 = self.linear_out_2(X)
        return output_1, output_2

In [None]:
def train(epochs):
    """Train an ACGAN on the MNIST training set.

    For each of ``epochs`` epochs: make one pass over MNIST (batch size 64)
    with alternating discriminator/generator Adam steps, print the mean
    losses, then call ``plot_images`` to save a 4x4 grid of generated
    samples (class labels 0-9 followed by 0-5) into the "gan" directory.

    Args:
        epochs: number of passes over the training set.
    """
    loss_fn_1 = nn.BCELoss()           # real/fake head (outputs probabilities)
    loss_fn_2 = nn.CrossEntropyLoss()  # auxiliary class head (outputs logits)
    discriminator = Discriminator()
    dis_optimizer = optim.Adam(discriminator.parameters(), lr=2e-4)
    generator = Generator()
    gen_optimizer = optim.Adam(generator.parameters(), lr=2e-4)

    # Real images are scaled to [-1, 1]; the generator's output activation
    # must cover the same range, or fakes are trivially distinguishable.
    transforms_ = transforms.Compose(
        [
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
    )
    dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms_,
                             download=True)
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

    # Fixed inputs for the per-epoch sample grid, so grids from different
    # epochs show the same (noise, label) pairs and are directly comparable.
    noise_class = nn.functional.one_hot(torch.arange(16) % 10, num_classes=10).float()
    sample_noise = torch.randn(16, 100)

    for epoch in range(epochs):
        # plot_images() switches the generator to eval mode; switch back so
        # BatchNorm uses batch statistics (and updates running stats) here.
        generator.train()
        discriminator.train()
        gen_losses = []
        dis_losses = []

        for batch_images, labels in dataloader:
            # The final batch can be smaller than 64 (60000 % 64 == 32), so
            # size the fake batch to match the real one.
            batch_size = batch_images.size(0)
            noise = torch.randn(batch_size, 100)
            fake_classes = torch.randint(0, 10, (batch_size,))
            fake_labels = nn.functional.one_hot(fake_classes, num_classes=10).float()

            fake_images = generator((noise, fake_labels))

            # ---- Discriminator step (generator output detached) ----
            fake_preds_1, fake_preds_2 = discriminator(fake_images.detach())
            real_preds_1, real_preds_2 = discriminator(batch_images)
            fake_preds_1 = fake_preds_1.reshape(-1)
            real_preds_1 = real_preds_1.reshape(-1)
            dis_loss = (
                loss_fn_1(fake_preds_1, torch.zeros_like(fake_preds_1))
                + loss_fn_1(real_preds_1, torch.ones_like(real_preds_1))
                # Class-index targets: same value as one-hot probability
                # targets, and supported by every CrossEntropyLoss version.
                + loss_fn_2(fake_preds_2, fake_classes)
                + loss_fn_2(real_preds_2, labels)
            )

            discriminator.zero_grad()
            dis_loss.backward()
            dis_optimizer.step()
            dis_losses.append(dis_loss.item())

            # ---- Generator step: fool the real/fake head and make the
            # auxiliary classifier recover the requested class ----
            output_1, output_2 = discriminator(fake_images)
            output_1 = output_1.reshape(-1)
            generator_loss = (
                loss_fn_1(output_1, torch.ones_like(output_1))
                + loss_fn_2(output_2, fake_classes)
            )

            generator.zero_grad()
            generator_loss.backward()
            gen_optimizer.step()
            gen_losses.append(generator_loss.item())

        print(f'gen_loss:{np.mean(gen_losses)} , dis_loss:{np.mean(dis_losses)}')

        plot_images(generator,
                    sample_noise,
                    noise_class,
                    show=False,
                    step=epoch,
                    model_name="gan")
        
                 
    

In [None]:
def plot_images(generator,
                noise_input,
                noise_class,
                show=False,
                step=0,
                model_name="gan"):
    """Generate one image per (noise, class) row and save them as a grid PNG.

    The figure is written to ``<model_name>/<step as 5 digits>.png``; the
    directory is created if missing. The generator's train/eval mode is
    restored before returning.

    Args:
        generator: module called as ``generator((noise_input, noise_class))``.
        noise_input: noise tensor with one row per image.
        noise_class: one-hot class tensor with the same number of rows.
        show: if True, call ``plt.show()``; otherwise close all figures.
        step: integer formatted into the file name.
        model_name: output directory name.
    """
    os.makedirs(model_name, exist_ok=True)
    filename = os.path.join(model_name, "%05d.png" % step)

    # Sample in eval mode without tracking gradients, then put the generator
    # back in whatever mode the caller had. Leaving it in eval() would make
    # BatchNorm use running stats for the rest of training.
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            images = generator((noise_input, noise_class))
    finally:
        generator.train(was_training)

    # (N, 1, H, W) -> (N, H, W) numpy array for matplotlib.
    images = images.cpu().reshape(-1, *images.shape[-2:]).numpy()
    print(model_name , " labels for generated images: ", np.argmax(noise_class, axis=1))

    plt.figure(figsize=(2.2, 2.2))
    num_images = images.shape[0]
    # Near-square grid; ceil() on columns keeps subplot indices in range
    # when the image count is not a perfect square.
    rows = max(1, int(math.sqrt(num_images)))
    cols = max(1, math.ceil(num_images / rows))
    for i, image in enumerate(images):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(image, cmap='gray')
        plt.axis('off')
    plt.savefig(filename)
    if show:
        plt.show()
    else:
        plt.close('all')

In [None]:
# Run training for 10 epochs; each epoch's sample grid is saved under "gan/".
train(epochs=10)