In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.autograd as autograd
import torchvision
import os
import cv2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
!nvidia-smi

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(DEVICE)

Sat Jan 11 17:13:08 2025       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 522.06       Driver Version: 522.06       CUDA Version: 11.8     |
|-------------------------------+----------------------+----------------------+
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A6000   WDDM  | 00000000:01:00.0  On |                  Off |
| 62%   85C    P2   254W / 300W |   3233MiB / 49140MiB |     78%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [3]:
# Basic
RESOLUTION = 256 # If you want to change the resolution, you also need to change architecture of Generator and Discriminator as well.
CHANNELS = 1 # RGB images -> 3, gray scale images -> 1. Now only gray scale images are available.
BATCH_SIZE = 32
EPOCHS = 1000
SAVE_INTERVAL = 10

# Optional
LAMBDA_GP = 10
Z_DIM = 100
LEARNING_RATE = 0.0002

In [4]:
DATASET = "./dataset"
AUGMENTED_DATASET = "./augmented_dataset"

def data_augmentation(input_folder, output_folder):
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    transformations = [
        lambda x: x,
        lambda x: cv2.rotate(x, cv2.ROTATE_90_CLOCKWISE),
        lambda x: cv2.rotate(x, cv2.ROTATE_90_COUNTERCLOCKWISE),
        lambda x: cv2.rotate(x, cv2.ROTATE_180),
        lambda x: cv2.flip(x, 1),
        lambda x: cv2.flip(cv2.rotate(x, cv2.ROTATE_90_CLOCKWISE), 1),
        lambda x: cv2.flip(cv2.rotate(x, cv2.ROTATE_90_COUNTERCLOCKWISE), 1),
        lambda x: cv2.flip(cv2.rotate(x, cv2.ROTATE_180), 1),
    ]

    for folder_name in os.listdir(input_folder):
        input_folder_path = os.path.join(input_folder, folder_name)
        output_folder_path = os.path.join(output_folder, folder_name)

        if not os.path.exists(output_folder_path):
            os.makedirs(output_folder_path)

        for filename in os.listdir(input_folder_path):
            filepath = os.path.join(input_folder_path, filename)
            img = cv2.imread(filepath)
            if img is None:
                continue
            
            for i, transform in enumerate(transformations):
                augmented_img = transform(img)
                output_filepath = os.path.join(output_folder_path, f"{os.path.splitext(filename)[0]}_aug{i}.tif")
                cv2.imwrite(output_filepath, augmented_img)

def get_loader():
    transform = transforms.Compose([
        transforms.Resize((RESOLUTION, RESOLUTION)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    
    dataset = datasets.ImageFolder(AUGMENTED_DATASET, transform=transform)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    return loader

data_augmentation(DATASET, AUGMENTED_DATASET)

In [5]:
class Generator(nn.Module):
    def __init__(self, z_dim, img_channels):
        super(Generator, self).__init__()
        self.generator_networks = nn.Sequential(
            # 1st layer (1 pix to 4 pix)
            nn.ConvTranspose2d(z_dim, 1024, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),

            # 2nd layer (4 pix to 8 pix)
            nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),

            # 3rd layer (8 pix to 16 pix)
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            # 4th layer (16 pix to 32 pix)
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            # 5th layer (32 pix to 64 pix)
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            # 6th layer (64 pix to 128 pix)
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),

            # 7th layer (128 pix to 256 pix)
            nn.ConvTranspose2d(32, img_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.generator_networks(x)

In [6]:
class Discriminator(nn.Module):
    def __init__(self, img_channels):
        super(Discriminator, self).__init__()
        self.discriminator_networks = nn.Sequential(
            # 1st layer (256 pix to 128 pix)
            nn.Conv2d(img_channels, 16, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            # 2nd layer (128 pix to 64 pix)
            nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LayerNorm([32, 64, 64]),
            nn.LeakyReLU(0.2, inplace=True),

            # 3rd layer (64 pix to 32 pix)
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LayerNorm([64, 32, 32]),
            nn.LeakyReLU(0.2, inplace=True),

            # 4th layer (32 pix to 16 pix)
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LayerNorm([128, 16, 16]),
            nn.LeakyReLU(0.2, inplace=True),

            # 5th layer (16 pix to 8 pix)
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LayerNorm([256, 8, 8]),
            nn.LeakyReLU(0.2, inplace=True),

            # 6th layer (8 pix to 4 pix)
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LayerNorm([512, 4, 4]),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 7th layer (4 pix to 1 pix)
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
        )

    def forward(self, x):
        return self.discriminator_networks(x).view(-1, 1).squeeze(1)

In [7]:
def gradient_penalty(discriminator, real_imgs, fake_imgs, device=DEVICE):
    batch_size = real_imgs.size(0)

    alpha = torch.rand(batch_size, 1, 1, 1, device=device)
    alpha = alpha.expand_as(real_imgs)

    interpolated = alpha * real_imgs + (1 - alpha) * fake_imgs
    interpolated = interpolated.to(device)
    interpolated.requires_grad_(True)

    interpolated_scores = discriminator(interpolated)

    gradients = autograd.grad(
        outputs=interpolated_scores,
        inputs=interpolated,
        grad_outputs=torch.ones_like(interpolated_scores, device=device),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]

    gradients = gradients.view(gradients.size(0), -1)
    gradient_norm = gradients.norm(2, dim=1)

    penalty = ((gradient_norm - 1) ** 2).mean()
    return penalty

def discriminator_loss(discriminator, real_imgs, fake_imgs, lambda_gp=LAMBDA_GP, device=DEVICE):
    real_scores = discriminator(real_imgs)
    fake_scores = discriminator(fake_imgs)

    wasserstein_distance = real_scores.mean() - fake_scores.mean()

    gp = gradient_penalty(discriminator, real_imgs, fake_imgs, device=device)

    d_loss = -wasserstein_distance + lambda_gp * gp
    return d_loss

def generator_loss(discriminator, fake_imgs):
    fake_scores = discriminator(fake_imgs)
    g_loss = -fake_scores.mean()
    return g_loss

In [8]:
def weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, mean=1.0, std=0.02)
        nn.init.constant_(m.bias, 0)

In [9]:
generator = Generator(z_dim=Z_DIM, img_channels=CHANNELS).cuda()
discriminator = Discriminator(img_channels=CHANNELS).cuda()
generator.apply(weights_init)
discriminator.apply(weights_init)

optimizer_g = optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))

dataloader = get_loader()
real_imgs, labels = next(iter(dataloader))

print(real_imgs.shape)  # If the value matches [BATCH_SIZE, 1, 256, 256], it's correct.
print(real_imgs.min().item(), real_imgs.max().item())  # See the max normalizedd value and min normalized value.
print(f"Number of Training Images: {len(dataloader.dataset)}")

torch.Size([32, 1, 256, 256])
-1.0 1.0
Number of Training Images: 17456


In [10]:
print(generator)
print(discriminator)

Generator(
  (generator_networks): Sequential(
    (0): ConvTranspose2d(100, 1024, kernel_size=(4, 4), stride=(1, 1))
    (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (7): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (10): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (13): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True

In [11]:
def generate_examples(generator, epoch, n=1):
    generator.eval()
    for i in range(n):
        with torch.no_grad():
            noise = torch.randn(BATCH_SIZE, Z_DIM, 1, 1).cuda()
            generated_img = generator(noise).cuda()
            if not os.path.exists(f'saved_examples/epoch{epoch+1}'):
                os.makedirs(f'saved_examples/epoch{epoch+1}')
            for j, img in enumerate(generated_img):
                img_path = os.path.join(f"saved_examples/epoch{epoch+1}", f"generated_image_{i * BATCH_SIZE + j + 1}.tif")
                torchvision.utils.save_image(img, img_path, normalize=True)
    print(f"Saved examples of generated images.")
    generator.train()

In [12]:
def save_model(generator, discriminator, epoch):
    generator.eval()
    discriminator.eval()
    if not os.path.exists(f'saved_models/epoch{epoch+1}'):
        os.makedirs(f'saved_models/epoch{epoch+1}')
    torch.save(generator, f'saved_models/epoch{epoch+1}/generator_epoch{epoch+1}.pth')
    torch.save(discriminator, f'saved_models/epoch{epoch+1}/discriminator_epoch{epoch+1}.pth')
    print("Models were stored.")
    generator.train()
    discriminator.train()

In [13]:
def save_loss_value(loss_log):
    if not os.path.exists(f'saved_loss_value'):
        os.makedirs(f'saved_loss_value')
    log_path = os.path.join("./saved_loss_value", "loss_log.txt")
    with open(log_path, "w") as f:
        for epoch, (g_loss, d_loss) in enumerate(zip(loss_log["Generator Loss"], loss_log["Discriminator Loss"]), 1):
            f.write(f"Epoch {epoch}: Generator Loss: {g_loss:.10f}, Discriminator Loss: {d_loss:.10f}\n")
    print("Saved loss log.")

In [14]:
loss_log = {
    "Generator Loss": [],
    "Discriminator Loss": []
}

for epoch in range(EPOCHS):
    
    for i, (real_imgs, _) in enumerate (dataloader):
        batch_size = real_imgs.size(0)
        real_imgs = real_imgs.cuda()

        optimizer_d.zero_grad()

        noise = torch.randn(batch_size, Z_DIM, 1, 1).cuda()

        fake_imgs = generator(noise)
        
        d_loss = discriminator_loss(discriminator, real_imgs, fake_imgs)

        d_loss.backward()
        optimizer_d.step()

        optimizer_g.zero_grad()

        fake_imgs = generator(noise)

        g_loss = generator_loss(discriminator, fake_imgs)
        
        g_loss.backward()
        optimizer_g.step()

    loss_log["Generator Loss"].append(g_loss)
    loss_log["Discriminator Loss"].append(d_loss)

    print(f"Epoch: {epoch+1}/{EPOCHS}")
    print(f"Generator Loss: {g_loss:.10f}")
    print(f"Discriminator Loss: {d_loss:.10f}")

    if (epoch+1) % SAVE_INTERVAL == 0:
        generate_examples(generator, epoch)
        save_model(generator, discriminator, epoch)
        save_loss_value(loss_log)

Epoch: 1/1000
Generator Loss: 44.2636947632
Discriminator Loss: 6.4218907356
Epoch: 2/1000
Generator Loss: -155.1075439453
Discriminator Loss: 3.2641997337
Epoch: 3/1000
Generator Loss: -177.9539489746
Discriminator Loss: 7.5641145706
Epoch: 4/1000
Generator Loss: -122.3797988892
Discriminator Loss: -5.3227539062
Epoch: 5/1000
Generator Loss: 170.5669555664
Discriminator Loss: -5.7698464394
Epoch: 6/1000
Generator Loss: 21.1111907959
Discriminator Loss: 7.1528482437
Epoch: 7/1000
Generator Loss: -26.3372955322
Discriminator Loss: 11.3047246933
Epoch: 8/1000
Generator Loss: -136.0519104004
Discriminator Loss: -9.9814958572
Epoch: 9/1000
Generator Loss: -434.1154174805
Discriminator Loss: 7.6762843132
Epoch: 10/1000
Generator Loss: 101.6653213501
Discriminator Loss: -120.6568298340
Saved examples of generated images.
Models were stored.
Saved loss log.
Epoch: 11/1000
Generator Loss: -94.6670532227
Discriminator Loss: -62.7009582520
Epoch: 12/1000
Generator Loss: 411.6562500000
Discrimina