<a href="https://colab.research.google.com/github/HuanAII/GAN/blob/main/GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os

import numpy as np
import torch
from torch import nn, Tensor
from torchvision.utils import save_image

# Seed both global RNGs the notebook imports (torch and NumPy) so that random
# draws are repeatable across Restart Kernel -> Run All.
torch.manual_seed(0)
np.random.seed(0)

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
device

'cpu'

# 1. Dataset

In [2]:
import torchvision

img_size = 28  # side length, in pixels, of the square images used throughout

# Preprocessing pipeline: resize to img_size x img_size, convert to a tensor,
# then normalise with mean 0.5 / std 0.5 (the generator's Tanh output uses the
# same [-1, 1] range).
_preprocess_steps = [
    torchvision.transforms.Resize((img_size, img_size)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.5], std=[0.5]),
]
transform = torchvision.transforms.Compose(_preprocess_steps)

# MNIST training split, downloaded into ./mnist_data on first use.
images = torchvision.datasets.MNIST(
    root='./mnist_data',
    train=True,
    download=True,
    transform=transform,
)

100%|██████████| 9.91M/9.91M [00:00<00:00, 15.8MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 468kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.33MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 4.71MB/s]


In [3]:
BATCH_SIZE = 128  # samples per optimisation step

# Mini-batch iterator over the dataset; shuffle=True reorders the samples each epoch.
dataloader = torch.utils.data.DataLoader(
    dataset=images,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# 2. Model

In [4]:
hidden_dim = 100  # width of the noise vector sampled in the training loop
channels = 1      # number of image channels
# (C, H, W) layout of one image.
img_shape = (channels, img_size, img_size)

In [5]:
class Generator(nn.Module):
    """Fully connected generator: noise vector -> single-channel image.

    Three hidden blocks (Linear -> BatchNorm1d -> ReLU) widen the input from
    ``z_dim`` to 256, 512 and 1024 features. A final Linear + Tanh block
    produces ``img_size * img_size`` values in [-1, 1], which ``forward``
    reshapes to ``(batch, 1, img_size, img_size)``.

    Args:
        z_dim: Length of the input noise vector (default 100).
    """

    def __init__(self, z_dim=100):
        super().__init__()
        n_pixels = img_size * img_size
        self.block1 = self._hidden_block(z_dim, 256)
        self.block2 = self._hidden_block(256, 512)
        self.block3 = self._hidden_block(512, 1024)
        # Tanh squashes every pixel into [-1, 1].
        self.block4 = nn.Sequential(nn.Linear(1024, n_pixels), nn.Tanh())

    @staticmethod
    def _hidden_block(in_features, out_features):
        """Return one Linear -> BatchNorm1d -> ReLU stage."""
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.BatchNorm1d(out_features),
            nn.ReLU(),
        )

    def forward(self, noise):
        """Map a ``(batch, z_dim)`` noise tensor to ``(batch, 1, img_size, img_size)`` images."""
        hidden = noise
        for stage in (self.block1, self.block2, self.block3, self.block4):
            hidden = stage(hidden)
        # block4 emits flat img_size * img_size vectors; reshape them into
        # single-channel (grayscale) images.
        return hidden.view(-1, 1, img_size, img_size)


In [6]:
class Discriminator(nn.Module):
    """Fully connected discriminator: image -> probability that it is real.

    The input image is flattened to ``img_size * img_size`` features and
    narrowed through three Linear -> LeakyReLU(0.2) blocks (1024, 512, 256).
    A final Linear + Sigmoid block yields one score in [0, 1] per sample.
    """

    def __init__(self):
        super().__init__()
        n_pixels = img_size * img_size  # real and generated images are (1, img_size, img_size)
        self.block1 = self._leaky_block(n_pixels, 1024)
        self.block2 = self._leaky_block(1024, 512)
        self.block3 = self._leaky_block(512, 256)
        # Sigmoid turns the single logit into a real-vs-fake score.
        self.block4 = nn.Sequential(nn.Linear(256, 1), nn.Sigmoid())

    @staticmethod
    def _leaky_block(in_features, out_features):
        """Return one Linear -> LeakyReLU(0.2) stage."""
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.LeakyReLU(0.2),
        )

    def forward(self, image):
        """Score a batch of images; returns a ``(batch, 1)`` tensor of sigmoid outputs."""
        # Collapse everything except the batch dimension into one feature axis.
        flat = image.view(image.size(0), -1)
        return self.block4(self.block3(self.block2(self.block1(flat))))

In [7]:
# Tie the generator's input width to hidden_dim: the training loop samples noise
# of shape (batch, hidden_dim), so the two must agree if hidden_dim is changed.
generator = Generator(z_dim=hidden_dim)
discriminator = Discriminator()

In [8]:
# Move the generator onto the selected device; the returned module is the cell's displayed output.
generator.to(device)

Generator(
  (block1): Sequential(
    (0): Linear(in_features=100, out_features=256, bias=True)
    (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (block2): Sequential(
    (0): Linear(in_features=256, out_features=512, bias=True)
    (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (block3): Sequential(
    (0): Linear(in_features=512, out_features=1024, bias=True)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (block4): Sequential(
    (0): Linear(in_features=1024, out_features=784, bias=True)
    (1): Tanh()
  )
)

In [9]:
# Move the discriminator onto the selected device; the returned module is the cell's displayed output.
discriminator.to(device)

Discriminator(
  (block1): Sequential(
    (0): Linear(in_features=784, out_features=1024, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
  )
  (block2): Sequential(
    (0): Linear(in_features=1024, out_features=512, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
  )
  (block3): Sequential(
    (0): Linear(in_features=512, out_features=256, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
  )
  (block4): Sequential(
    (0): Linear(in_features=256, out_features=1, bias=True)
    (1): Sigmoid()
  )
)

# 3. Training

In [13]:
EPOCHS = 50

# Folder that receives one sample grid per epoch. It must match the folder the
# grid-assembly cell globs for epoch_*.png, and it has to exist before
# save_image writes into it.
OUTPUT = "gan_output"
os.makedirs(OUTPUT, exist_ok=True)

criterion = nn.BCELoss()  # binary cross-entropy on the discriminator's sigmoid output

# Adam with beta1 (momentum) lowered to 0.5: balances convergence speed against
# stability and makes the updates more responsive to gradient changes.
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Per-epoch mean losses, appended once per epoch.
hist = {
    "train_G_loss": [],
    "train_D_loss": [],
}

for epoch in range(EPOCHS):
    running_G_loss = 0.0
    running_D_loss = 0.0

    for imgs, _ in dataloader:  # the class labels are not used
        batch_size = imgs.shape[0]  # taken from the batch itself rather than BATCH_SIZE
        real_imgs = imgs.to(device)

        # Discriminator targets, one per sample: 1 = real image, 0 = generated image.
        valid = torch.ones(batch_size, 1, device=device)
        fake = torch.zeros(batch_size, 1, device=device)

        # --- Train Generator ---
        optimizer_G.zero_grad()
        # Latent noise from a standard normal, shape (batch_size, hidden_dim).
        # Sampled with torch (directly on the device) so torch.manual_seed governs it.
        z = torch.randn(batch_size, hidden_dim, device=device)
        gen_imgs = generator(z)

        # The generator's loss is low when the discriminator scores its samples as real (1).
        G_loss = criterion(discriminator(gen_imgs), valid)
        running_G_loss += G_loss.item()

        G_loss.backward()
        optimizer_G.step()

        # --- Train Discriminator ---
        # zero_grad also discards the gradients the generator's backward pass left on D.
        optimizer_D.zero_grad()
        real_loss = criterion(discriminator(real_imgs), valid)
        # detach() keeps this loss from back-propagating into the generator.
        fake_loss = criterion(discriminator(gen_imgs.detach()), fake)
        D_loss = (real_loss + fake_loss) / 2
        running_D_loss += D_loss.item()

        D_loss.backward()
        optimizer_D.step()

    epoch_G_loss = running_G_loss / len(dataloader)
    epoch_D_loss = running_D_loss / len(dataloader)

    print(f"Epoch [{epoch + 1}/{EPOCHS}], Train G Loss: {epoch_G_loss:.4f}, Train D Loss: {epoch_D_loss:.4f}")

    hist["train_G_loss"].append(epoch_G_loss)
    hist["train_D_loss"].append(epoch_D_loss)

    # Save a 5x5 grid of samples from the epoch's last batch (file index is zero-based).
    save_image(gen_imgs.detach()[:25], f"{OUTPUT}/epoch_{epoch}.png", nrow=5, normalize=True)


Epoch [1/50], Train G Loss: 1.3187, Train D Loss: 0.4461


KeyboardInterrupt: 

In [None]:
# Persist the weights of both networks as <name>.pth in the working directory.
for ckpt_name, ckpt_model in (("generator", generator), ("discriminator", discriminator)):
    torch.save(ckpt_model.state_dict(), f"{ckpt_name}.pth")

In [None]:
import os
import glob
from PIL import Image

# Assemble the per-epoch sample images written during training into one
# contact sheet, laid out row by row in epoch order.

# Folder containing images
image_folder = "gan_output"
cols = 10     # images per row
padding = 10  # space, in pixels, between neighbouring images

# Get all epoch images sorted by their numeric epoch index (so epoch_10 follows epoch_9)
image_files = sorted(
    glob.glob(os.path.join(image_folder, "epoch_*.png")),
    key=lambda path: int(os.path.basename(path).split("_")[1].split(".")[0]),
)
if not image_files:
    raise FileNotFoundError(
        f"No epoch_*.png images found in {image_folder!r}; run the training cell first."
    )

# Determine grid size: as many rows as needed to hold every image (ceiling division),
# so no image is pasted outside the canvas.
num_images = len(image_files)
rows = (num_images + cols - 1) // cols

# Image size (assumes all images are the same size as the first one)
with Image.open(image_files[0]) as first_img:
    img_width, img_height = first_img.size

# Calculate total canvas size
grid_width = cols * img_width + (cols - 1) * padding
grid_height = rows * img_height + (rows - 1) * padding

# Create a blank canvas
grid_img = Image.new("RGB", (grid_width, grid_height), "white")

# Paste images into the grid with padding. Each file is opened only while it is
# being pasted, so no file handles stay open and the dataset variable `images`
# defined earlier in the notebook is not overwritten.
for i, image_file in enumerate(image_files):
    x = (i % cols) * (img_width + padding)
    y = (i // cols) * (img_height + padding)
    with Image.open(image_file) as img:
        grid_img.paste(img, (x, y))

# Save and show the final grid image
grid_img.save("gan_grid.png")
grid_img.show()
