In [56]:
import os
import numpy as np
import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
#import torchvision.transforms as transforms
from torchvision.transforms import v2
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image
from PIL import Image
import wandb
from lightning.pytorch.loggers import WandbLogger
import datetime
from torchvision.io import decode_image

torch.set_float32_matmul_precision("medium")

In [57]:
# experiment tracking
#47080269e7b1b5a51a89830cb24c495498237e77

wandb.login()
wandb_logger = WandbLogger(project="JANGAN")

In [58]:
now = datetime.datetime.now()
now_str = now.strftime("%Y-%m-%d_%H-%M-%S")
print(now_str)

# make a new folder with the name now_str and save images there
os.makedirs(now_str, exist_ok=True)

2024-11-03_18-31-00


In [59]:
# load iamges from the folder, 70k total 
# from kaggle, 128x128 ffhq from kaggle
class CustomImageDataset(Dataset):
    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.images = [f for f in os.listdir(img_dir) if f.endswith(('.png', '.jpg', '.jpeg'))]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.images[idx])
        image = Image.open(img_path).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
        
        return image


# 256 batch size, 128x128 images, 8 cpu cores for batches
class CustomDataModule(L.LightningDataModule):
    def __init__(self, data_dir: str = "archive", img_size: int = 128, batch_size: int = 256, num_workers: int = 8):
        super().__init__()
        self.data_dir = data_dir
        self.img_size = img_size
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.transform = v2.Compose([
            # v2.Resize(size=(128, 128)),
            # v2.CenterCrop(self.img_size),
            v2.ToImage(),
            v2.ToDtype(torch.float32),
            v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def setup(self, stage=None):
        self.dataset = CustomImageDataset(img_dir=self.data_dir, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,)

In [60]:
class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape

        # def block(in_feat, out_feat, normalize=True):
        #     layers = [nn.Linear(in_feat, out_feat)]
        #     if normalize:
        #         layers.append(nn.BatchNorm1d(out_feat, 0.8))
        #     layers.append(nn.LeakyReLU(0.01, inplace=True))
        #     return layers

        # self.model = nn.Sequential(
        #     *block(latent_dim, 128, normalize=False),
        #     *block(128, 256),
        #     *block(256, 512),
        #     *block(512, 1024),
        #     nn.Linear(1024, int(np.prod(img_shape))),
        #     nn.Tanh(),
        # )

        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.GELU(),
            nn.Linear(128, 256),
            nn.GELU(),
            nn.Linear(256, 512),
            nn.GELU(),
            nn.Linear(512, 1024),
            nn.GELU(),
            nn.Linear(1024, 2048),
            nn.GELU(),
            nn.Linear(2048, 4096),
            nn.GELU(),
            nn.Linear(4096, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *self.img_shape)
        return img



class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super().__init__()


        # self.model = nn.Sequential(
        #     nn.Linear(int(np.prod(img_shape)), 512),
        #     nn.LeakyReLU(0.2, inplace=True),
        #     nn.Linear(512, 256),
        #     nn.LeakyReLU(0.2, inplace=True),
        #     nn.Linear(256, 1),
        #     nn.Sigmoid(),
        # )

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 2048),
            nn.GELU(),
            nn.Linear(2048, 1024),
            nn.GELU(),
            nn.Linear(1024, 512),
            nn.GELU(),
            nn.Linear(512, 256),
            nn.GELU(),
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Linear(128, 1)
        )
        

    def forward(self, img):
        img_flat = img.view(img.shape[0], -1)
        validity = self.model(img_flat)
        return validity


In [61]:
# opt_g = torch.optim.Adam(generator.parameters(), lr=0.0001)
# opt_d = torch.optim.Adam(discriminator.parameters(), lr=0.0001)

# sch_g = ReduceLROnPlateau(opt_g,
#                 mode=scheduler_kwargs.get('mode', 'min'),
#                 factor=scheduler_kwargs.get('factor', 0.5),
#                 patience=scheduler_kwargs.get('patience', 10),
#                 min_lr=scheduler_kwargs.get('min_lr', 1e-6) )
# sch_d = ReduceLROnPlateau( opt_d,               
#                 mode=scheduler_kwargs.get('mode', 'min'),
#                 factor=scheduler_kwargs.get('factor', 0.5),
#                 patience=scheduler_kwargs.get('patience', 10),
#                 min_lr=scheduler_kwargs.get('min_lr', 1e-6))

In [62]:
class GAN(L.LightningModule):
    def __init__(
        self,
        channels,
        width,
        height,
        latent_dim: int = 100,
        lr: float = 0.00001,
        b1: float = 0.5,
        b2: float = 0.999,
        batch_size: int = 256,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.automatic_optimization = False

        # networks
        data_shape = (channels, width, height)
        self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=data_shape)
        self.discriminator = Discriminator(img_shape=data_shape)

        self.validation_z = torch.randn(8, self.hparams.latent_dim)

        self.example_input_array = torch.zeros(2, self.hparams.latent_dim)

    def forward(self, z):
        return self.generator(z)

    def adversarial_loss(self, y_hat, y):
        #return F.binary_cross_entropy(y_hat, y)
        return torch.nn.BCEWithLogitsLoss(y_hat, y)

    def training_step(self, batch):
        imgs = batch

        optimizer_g, optimizer_d = self.optimizers()

        # sample noise
        z = torch.randn(imgs.shape[0], self.hparams.latent_dim)
        z = z.type_as(imgs)

        # train generator
        # generate images
        self.toggle_optimizer(optimizer_g)
        self.generated_imgs = self(z)

        # log sampled images
        sample_imgs = self.generated_imgs[:6]
        grid = torchvision.utils.make_grid(sample_imgs)
        # self.logger.experiment.add_image("train/generated_images", grid, self.current_epoch)

        # ground truth result (ie: all fake)
        # put on GPU because we created this tensor inside training_loop
        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)

        # adversarial loss is binary cross-entropy
        g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid)
        self.log("g_loss", g_loss, prog_bar=True)
        self.manual_backward(g_loss)
        optimizer_g.step()
        optimizer_g.zero_grad()
        self.untoggle_optimizer(optimizer_g)

        # train discriminator
        # Measure discriminator's ability to classify real from generated samples
        self.toggle_optimizer(optimizer_d)

        # how well can it label as real?
        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)

        real_loss = self.adversarial_loss(self.discriminator(imgs), valid)

        # how well can it label as fake?
        fake = torch.zeros(imgs.size(0), 1)
        fake = fake.type_as(imgs)

        fake_loss = self.adversarial_loss(self.discriminator(self.generated_imgs.detach()), fake)

        # discriminator loss is the average of these
        d_loss = (real_loss + fake_loss) / 2
        self.log("d_loss", d_loss, prog_bar=True)
        self.manual_backward(d_loss)
        optimizer_d.step()
        optimizer_d.zero_grad()
        self.untoggle_optimizer(optimizer_d)

    def validation_step(self, batch, batch_idx):
        pass

    def configure_optimizers(self):
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        return [opt_g, opt_d], []

    def on_validation_epoch_end(self):
        z = self.validation_z.type_as(self.generator.model[0].weight)

        # log sampled images
        sample_imgs = self(z)
        grid = torchvision.utils.make_grid(sample_imgs)
        # self.logger.experiment.add_image("validation/generated_images", grid, self.current_epoch)
        save_image(sample_imgs, f"{now_str}/{self.global_step}.png", nrow=5)
        wandb.log({"examples": [wandb.Image(image) for image in sample_imgs]})

In [63]:
# def sample_images(self):
#         z = torch.randn(25, self.hparams.latent_dim, device=self.device)
#         gen_imgs = self(z)
#         save_image(gen_imgs, f"{now_str}/{self.global_step}.png", nrow=5, normalize=True)
#         wandb.log({"examples": [wandb.Image(image) for image in gen_imgs]})

In [64]:
# Set up data
dm = CustomDataModule()

# Set up model
model = GAN(
    channels=3,
    width=128,
    height=128,
    latent_dim=128,
    lr=0.0001,
    n_critic=5,
    clip_value=0.01,
    sample_interval=400
)

# Set up trainer
trainer = L.Trainer(
    accelerator="auto",
    devices=1 if torch.cuda.is_available() else None,
    max_epochs=1000,
    logger=wandb_logger,
    precision="bf16-mixed",
    # default_root_dir={now_str}
)

# Train the model
trainer.fit(model, dm)

# save the model

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type          | Params | Mode  | In sizes | Out sizes       
--------------------------------------------------------------------------------------
0 | generator     | Generator     | 212 M  | train | [2, 128] | [2, 3, 128, 128]
1 | discriminator | Discriminator | 103 M  | train | ?        | ?               
--------------------------------------------------------------------------------------
316 M     Trainable params
0         Non-trainable params
316 M     Total params
1,264.107 Total estimated model params size (MB)
29        Modules in train mode
0         Modules in eval mode


Epoch 0:   0%|          | 0/274 [00:00<?, ?it/s] 

RuntimeError: Boolean value of Tensor with more than one value is ambiguous

In [None]:
wandb.finish()