In [5]:
import wandb
from lightning.pytorch.loggers import WandbLogger
#47080269e7b1b5a51a89830cb24c495498237e77

wandb.login()





True

In [6]:
wandb_logger = WandbLogger(project="JANGAN")

In [7]:
import datetime
now = datetime.datetime.now()
now_str = now.strftime("%Y-%m-%d_%H-%M-%S")
print(now_str)

# make a new folder with the name now_str
import os
os.makedirs(now_str, exist_ok=True)



2024-11-01_11-45-32


In [8]:
import os
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image
from PIL import Image

torch.set_float32_matmul_precision("medium")

class CustomImageDataset(Dataset):
    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.images = [f for f in os.listdir(img_dir) if f.endswith(('.png', '.jpg', '.jpeg'))]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.images[idx])
        image = Image.open(img_path).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
        
        return image

class CustomDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "archive", img_size: int = 128, batch_size: int = 256, num_workers: int = 8):
        super().__init__()
        self.data_dir = data_dir
        self.img_size = img_size
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.transform = transforms.Compose([
            transforms.Resize(self.img_size),
            transforms.CenterCrop(self.img_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    def setup(self, stage=None):
        self.dataset = CustomImageDataset(img_dir=self.data_dir, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

# class Generator(nn.Module):
#     def __init__(self, latent_dim, img_shape):
#         super().__init__()
#         self.img_shape = img_shape

#         def block(in_feat, out_feat, normalize=True):
#             layers = [nn.Linear(in_feat, out_feat)]
#             if normalize:
#                 layers.append(nn.BatchNorm1d(out_feat, 0.8))
#             layers.append(nn.LeakyReLU(0.2, inplace=True))
#             return layers

#         self.model = nn.Sequential(
#             *block(latent_dim, 128, normalize=False),
#             *block(128, 256),
#             *block(256, 128),
#             nn.Linear(128, int(np.prod(img_shape))),
#             nn.Tanh()
#         )

#     def forward(self, z):
#         img = self.model(z)
#         img = img.view(img.size(0), *self.img_shape)
#         return img
class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.01, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh(),
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *self.img_shape)
        return img

class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super().__init__()

        # self.model = nn.Sequential(
        #     nn.Linear(int(np.prod(img_shape)), 128),
        #     nn.LeakyReLU(0.2, inplace=True),
        #     nn.Linear(128, 256),
        #     nn.LeakyReLU(0.2, inplace=True),
        #     nn.Linear(256, 1),
        # )
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )
        

    def forward(self, img):
        img_flat = img.view(img.shape[0], -1)
        validity = self.model(img_flat)
        return validity

class GAN(pl.LightningModule):
    def __init__(
        self,
        channels,
        width,
        height,
        latent_dim: int = 128,
        lr: float = 0.01,
        n_critic: int = 5,
        clip_value: float = 0.01,
        sample_interval: int = 400,
        **kwargs
    ):
        super().__init__()
        self.save_hyperparameters()
        self.automatic_optimization = False

        # networks
        data_shape = (channels, width, height)
        self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=data_shape)
        self.discriminator = Discriminator(img_shape=data_shape)

        self.validation_z = torch.randn(8, self.hparams.latent_dim)
        self.example_input_array = torch.zeros(2, self.hparams.latent_dim)

    def forward(self, z):
        return self.generator(z)

    def training_step(self, batch, batch_idx):
        imgs = batch

        optimizer_g, optimizer_d = self.optimizers()
        
        # train discriminator
        self.toggle_optimizer(optimizer_d)
        z = torch.randn(imgs.shape[0], self.hparams.latent_dim, device=self.device)
        fake_imgs = self(z).detach()

        real_validity = self.discriminator(imgs)
        fake_validity = self.discriminator(fake_imgs)

        d_loss = -torch.mean(real_validity) + torch.mean(fake_validity)
        self.log("d_loss", d_loss, prog_bar=False)
        self.manual_backward(d_loss)
        optimizer_d.step()
        self.log("lr", self.hparams.lr)
        optimizer_d.zero_grad()

        # Clip weights of discriminator
        for p in self.discriminator.parameters():
            p.data.clamp_(-self.hparams.clip_value, self.hparams.clip_value)

        self.untoggle_optimizer(optimizer_d)

        # train generator
        if batch_idx % self.hparams.n_critic == 0:
            self.toggle_optimizer(optimizer_g)
            z = torch.randn(imgs.shape[0], self.hparams.latent_dim, device=self.device)
            gen_imgs = self(z)
            gen_validity = self.discriminator(gen_imgs)
            g_loss = -torch.mean(gen_validity)
            self.log("g_loss", g_loss, prog_bar=False)
            self.log("total_loss", g_loss + d_loss, prog_bar=False)
            self.manual_backward(g_loss)
            optimizer_g.step()
            optimizer_g.zero_grad()
            self.untoggle_optimizer(optimizer_g)

        if self.global_step % self.hparams.sample_interval == 0:
            self.sample_images()
        # self.log("lr", self.lr)

    def sample_images(self):
        z = torch.randn(25, self.hparams.latent_dim, device=self.device)
        gen_imgs = self(z)
        save_image(gen_imgs, f"{now_str}/{self.global_step}.png", nrow=5, normalize=True)
        wandb.log({"examples": [wandb.Image(image) for image in gen_imgs]})

    def configure_optimizers(self):
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=0.0001)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=0.0001)
        return [opt_g, opt_d], []

# Set up data
dm = CustomDataModule()

# Set up model
model = GAN(
    channels=3,
    width=128,
    height=128,
    latent_dim=128,
    lr=0.0001,
    n_critic=5,
    clip_value=0.01,
    sample_interval=400
)

# Set up trainer
trainer = pl.Trainer(
    accelerator="auto",
    devices=1 if torch.cuda.is_available() else None,
    max_epochs=1000,
    logger=wandb_logger,
    precision="bf16",
    default_root_dir={now_str}
)

# Train the model
trainer.fit(model, dm)

# save the model


/home/fil/miniconda3/envs/ML/lib/python3.12/site-packages/lightning_fabric/connector.py:571: `precision=bf16` is supported for historical reasons but its usage is discouraged. Please set your precision to bf16-mixed instead!
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/fil/miniconda3/envs/ML/lib/python3.12/site-packages/lightning/pytorch/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type          | Params | Mode  | In sizes | Out sizes       
--------------------------------------------------------------------------------------
0 | generator     | Generator     | 51.1 M | train | [2, 128] | [2, 3, 128, 128]
1 | discriminator | Discriminator

Epoch 517:  18%|█▊        | 50/274 [00:02<00:11, 20.34it/s, v_num=k5ko] 

/home/fil/miniconda3/envs/ML/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [9]:
wandb.finish()

0,1
d_loss,▁▂▄▅▃▁▂▁▅▇▇▃▂▄▄▂▃▂▄▄▂█▅▁▁▄▆▆▃▁▃▅▂▂▄▄▇▅▃▂
epoch,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇███████
g_loss,▆█▆█▇▇▆▆▇▁▆▇▃▇█▇██▇▅██▇██▆▅▇▅▆▆██▆█▅▆▇▃▃
lr,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total_loss,▅▃▃▅▅▅▅▆▃▆▃▅▄▅█▅▃▅▄▅▄▄▃▃▅▃▄▃▅▃▃▅▅▃▁▅▃▃▄▃
trainer/global_step,▁▁▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▇▇▇▇█████

0,1
d_loss,-0.87891
epoch,517.0
g_loss,-0.1001
lr,0.0001
total_loss,-0.99219
trainer/global_step,141699.0
