In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import numpy as np
import glob
import multiprocessing
from good_tile_gan.generator import Generator
from good_tile_gan.discriminator import Discriminator

In [3]:
# Choose the compute device: CUDA if present, otherwise Apple MPS, otherwise CPU.
# n_gpu is kept as a float count of CUDA devices (0.0 when none are used).
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    device_name = torch.cuda.get_device_name(DEVICE)
    n_gpu = float(torch.cuda.device_count())
else:
    n_gpu = 0.0
    if torch.backends.mps.is_available():
        DEVICE, device_name = torch.device("mps"), "Apple Silicon"
    else:
        DEVICE, device_name = torch.device("cpu"), "CPU"

# Seed torch's global RNG so weight init and noise draws are repeatable.
torch.manual_seed(0)

n_cores = multiprocessing.cpu_count()
print(f"Number of GPUs: {n_gpu} / Number of CPU Cores: {n_cores}")
print(f"Training on {device_name} ({DEVICE})")

Number of GPUs: 1.0 / Number of CPU Cores: 24
Training on NVIDIA GeForce RTX 4090 (cuda)


In [None]:
# Source folders of defect-free ("good") tiles and preprocessing/batching sizes.
train_good_dir, test_good_dir = "./tile/train/good", "./tile/test/good"
image_size = 256  # images are resized to image_size x image_size
batch_size = 32


In [None]:
class TileDataset(Dataset):
    """Unlabelled dataset of 'good' tile images pooled from two directories.

    Every file matching ``pattern`` (``*.png`` by default) directly inside
    ``train_good_dir`` and ``test_good_dir`` is collected into a single list.
    Each item is the image converted to RGB, resized to
    ``image_size x image_size`` with bicubic interpolation, turned into a
    tensor and normalised per channel with mean 0.5 / std 0.5 (i.e. mapped
    to [-1, 1]).

    Args:
        train_good_dir: directory holding the training 'good' images.
        test_good_dir: directory holding the test 'good' images.
        image_size: side length, in pixels, of the square output images.
        pattern: glob pattern selecting image files in each directory.
    """

    def __init__(self, train_good_dir, test_good_dir, image_size, pattern='*.png'):
        super().__init__()
        self.image_size = image_size
        self.image_files = []

        train_files = self._find_images(train_good_dir, pattern)
        self.image_files.extend(train_files)
        test_files = self._find_images(test_good_dir, pattern)
        self.image_files.extend(test_files)
        print(f"- Found {len(train_files)} images in train_good_dir: {train_good_dir}")
        print(f"- Found {len(test_files)} images in test_good_dir: {test_good_dir}")
        print(f"-> Total {len(self.image_files)} 'good' samples collected.")

        self.transform = transforms.Compose([
                transforms.Resize((self.image_size, self.image_size),
                                  interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
            ])
        print("Using default transforms (Resize, ToTensor, Normalize to [-1, 1]).")

    @staticmethod
    def _find_images(directory, pattern='*.png'):
        """Return the paths in ``directory`` matching ``pattern``, sorted.

        glob returns files in arbitrary, filesystem-dependent order; sorting
        makes the index -> file mapping (and hence seeded shuffling)
        reproducible across machines.
        """
        return sorted(glob.glob(os.path.join(directory, pattern)))

    def __len__(self):
        """Number of collected image files."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return the transformed tensor."""
        img_path = self.image_files[idx]
        # The context manager closes the file handle promptly; convert() loads
        # the pixel data into a new image before the file is closed.
        with Image.open(img_path) as image:
            rgb_image = image.convert('RGB')
        return self.transform(rgb_image)

In [10]:
# Pool the 'good' tiles from both split directories into one dataset.
tile_dataset = TileDataset(
    train_good_dir=train_good_dir,
    test_good_dir=test_good_dir,
    image_size=image_size,
)

- Found 230 images in train_good_dir: ./tile/train/good
- Found 33 images in test_good_dir: ./tile/test/good
-> Total 263 'good' samples collected.
Using default transforms (Resize, ToTensor, Normalize to [-1, 1]).


In [None]:
# Worker processes: `num_workers` was previously never defined (NameError on a
# fresh kernel). Workers are only enabled under the "fork" start method: with
# "spawn"/"forkserver" each worker must unpickle the dataset, which means
# importing TileDataset from the notebook's __main__, and that typically fails
# for classes defined in a notebook. Capped at 4 since the dataset is small.
if multiprocessing.get_start_method() == "fork":
    num_workers = min(4, os.cpu_count() or 1)
else:
    num_workers = 0

# With drop_last=True a dataset smaller than one batch yields zero batches,
# and the training loop would silently do nothing.
assert len(tile_dataset) >= batch_size, (
    f"Need at least {batch_size} images for one batch with drop_last=True; "
    f"found {len(tile_dataset)}."
)

dataloader = DataLoader(
    tile_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    # Pinned host memory only speeds up host -> CUDA copies.
    pin_memory=DEVICE.type == "cuda",
    # Keep workers alive across epochs instead of re-spawning them each epoch.
    persistent_workers=num_workers > 0,
    drop_last=True,
)

In [None]:
# --- GAN hyper-parameters ---------------------------------------------------
latent_dim = 100    # length of the noise vector fed to the generator
num_channels = 3    # RGB (images are converted to RGB when loaded)
ngf, ndf = 64, 64   # presumably base feature widths of G / D (DCGAN convention)
epochs = 800

# Adam settings, currently the same for both networks.
lr_g = lr_d = 0.0002
beta1, beta2 = 0.5, 0.999

In [None]:
# Construct the generator first, then the discriminator (keeps the seeded RNG
# draw order for any parameter initialisation done in the constructors).
generator = Generator(latent_dim, num_channels, image_size, ngf)
discriminator = Discriminator(num_channels, image_size, ndf)
generator, discriminator = generator.to(DEVICE), discriminator.to(DEVICE)

In [None]:
def weights_init(m):
    """DCGAN-style initialisation, meant to be used via ``nn.Module.apply``.

    - Modules whose class name contains ``'Conv'`` (e.g. ``Conv2d``,
      ``ConvTranspose2d``): weight ~ N(0, 0.02).
    - Modules whose class name contains ``'BatchNorm'``: weight ~ N(1, 0.02),
      bias = 0.

    All other modules are left untouched. Modules that match by name but hold
    no weight tensor are skipped rather than raising ``AttributeError``. Examples
    are a custom ``ConvBlock`` container, or BatchNorm with ``affine=False``
    (where ``weight`` is ``None``).

    Args:
        m: the ``nn.Module`` being visited.
    """
    classname = m.__class__.__name__
    weight = getattr(m, "weight", None)
    if weight is None:
        return
    if "Conv" in classname:
        nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm" in classname:
        nn.init.normal_(weight, 1.0, 0.02)
        bias = getattr(m, "bias", None)
        if bias is not None:
            nn.init.constant_(bias, 0)

# DCGAN initialisation; generator first so the seeded RNG draws stay in order.
generator.apply(weights_init)
discriminator.apply(weights_init)

# BCE on raw scores: expects the discriminator to return logits.
adversarial_loss = nn.BCEWithLogitsLoss()

# One Adam optimiser per network.
optimizer_G = optim.Adam(generator.parameters(), betas=(beta1, beta2), lr=lr_g)
optimizer_D = optim.Adam(discriminator.parameters(), betas=(beta1, beta2), lr=lr_d)

In [None]:
# Fixed latent batch, drawn here so the seeded RNG sequence stays unchanged.
# NOTE(review): fixed_noise is not used below; presumably intended for
# periodic progress snapshots of the generator.
fixed_noise = torch.randn(64, latent_dim, 1, 1, device=DEVICE)

# BCE targets: 1 = real, 0 = fake.
real_label_val = 1.0
fake_label_val = 0.0

# The sampling cell further down calls generator.eval(). Put both networks back
# in training mode so re-running this cell does not train with inference-mode
# BatchNorm/Dropout (if the models contain such layers).
generator.train()
discriminator.train()

for epoch in range(epochs):
    for i, real_images in enumerate(dataloader):
        real_images = real_images.to(DEVICE, non_blocking=True)
        # Local name so the global `batch_size` config value is not clobbered.
        b_size = real_images.size(0)
        real_labels = torch.full((b_size,), real_label_val, device=DEVICE)
        fake_labels = torch.full((b_size,), fake_label_val, device=DEVICE)

        # ---- (1) Discriminator step: push real -> 1 and fake -> 0 ----
        discriminator.zero_grad()

        output_real = discriminator(real_images).view(-1)
        errD_real = adversarial_loss(output_real, real_labels)
        errD_real.backward()

        noise = torch.randn(b_size, latent_dim, 1, 1, device=DEVICE)
        # detach(): the discriminator step must not back-propagate into G.
        fake_images = generator(noise).detach()
        output_fake = discriminator(fake_images).view(-1)
        errD_fake = adversarial_loss(output_fake, fake_labels)
        errD_fake.backward()

        errD = errD_real + errD_fake
        optimizer_D.step()

        # ---- (2) Generator step: make D score fresh fakes as real ----
        generator.zero_grad()

        noise_for_G = torch.randn(b_size, latent_dim, 1, 1, device=DEVICE)
        fake_images_for_G = generator(noise_for_G)
        output_G = discriminator(fake_images_for_G).view(-1)
        errG = adversarial_loss(output_G, real_labels)
        errG.backward()
        optimizer_G.step()

        if i % 50 == 0:
            # adversarial_loss is BCEWithLogitsLoss, so discriminator outputs
            # are logits; apply sigmoid to report D(.) as probabilities.
            # NOTE(review): assumes Discriminator has no final Sigmoid; if it
            # does, the loss choice itself is wrong. Confirm.
            with torch.no_grad():
                D_x = torch.sigmoid(output_real).mean().item()
                D_G_z1 = torch.sigmoid(output_fake).mean().item()
                D_G_z2 = torch.sigmoid(output_G).mean().item()
            print(
                f'[{epoch+1}/{epochs}][{i}/{len(dataloader)}] '
                f'Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f} '
                f'D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f} / {D_G_z2:.4f}'
            )

print("Training Finished.")

In [None]:
# Sample a batch of tiles from the trained generator.
generator.eval()  # inference behaviour for BatchNorm/Dropout layers, if any

num_images_to_generate = 16
noise = torch.randn(num_images_to_generate, latent_dim, 1, 1, device=DEVICE)

with torch.no_grad():
    # Under no_grad the output carries no autograd history; move it to the CPU
    # for saving and plotting.
    fake_images = generator(noise).cpu()

    

In [None]:

import torchvision.utils as vutils
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F

output_filename = "generated_tiles.png"

# Build the 4-per-row grid once and reuse it for both the file and the figure.
# value_range=(-1, 1) inverts the Normalize(mean=0.5, std=0.5) applied to the
# training images (values outside are clamped). Without it, normalize=True
# min-max stretches each batch, distorting colours and contrast.
# NOTE(review): assumes the generator emits values in [-1, 1] (e.g. a Tanh
# head); confirm against Generator.
grid = vutils.make_grid(fake_images, nrow=4, normalize=True, value_range=(-1, 1))
vutils.save_image(grid, output_filename)
print(f"Saved generated images to {output_filename}")

img_to_show = F.to_pil_image(grid)

fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(img_to_show)
ax.set_axis_off()
ax.set_title("Generated Tiles")
plt.show()