In [None]:
from data import sprites
# Latent size intended for hand-drawn noise further down the notebook.
latent_dim = 256

# Build the sprites loaders. With val_ratio=0.0 presumably everything goes to the
# training split (valloader is not used below) — confirm against sprites.get_loader.
trainloader, valloader = sprites.get_loader(
    path="../data/external/sprites",
    batch_size=64,
    workers=16,
    val_ratio=0.0,
)

In [None]:
# Peek at a single training batch to check tensor shapes. The names are chosen so
# the `data` package imported in the first cell is not shadowed.
sample_images, sample_labels = next(iter(trainloader))
sample_images.shape

In [None]:
# Smoke-test the loader: make one full pass over the training set, which exercises
# every worker process, and report how many batches it yields. Batches are discarded
# so no per-batch names leak into the notebook namespace.
n_train_batches = sum(1 for _ in trainloader)
n_train_batches

In [None]:
from pl_bolts.models.gans import GAN, DCGAN
from autoencoder.encoders import ConvEncoder
from autoencoder.decoders import ConvDecoder
from pl_bolts.datamodules import MNISTDataModule, CIFAR10DataModule
import torchmetrics

import torch 
from torch import Tensor

class CustomDCGAN(DCGAN):
    """DCGAN variant with discriminator accuracy logging, soft labels and label swapping.

    Only the discriminator loss is customised; everything else is inherited from
    ``pl_bolts.models.gans.DCGAN``.

    Args:
        swap_prob: Per-element probability that the soft real/fake targets are
            exchanged when computing the discriminator loss.
        *args, **kwargs: Forwarded unchanged to ``DCGAN.__init__``.
    """

    def __init__(self, swap_prob=0.3, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Separate metric objects so real and fake accuracies accumulate independently.
        self.real_accuracy = torchmetrics.Accuracy()
        self.fake_accuracy = torchmetrics.Accuracy()
        self.swap_prob = swap_prob

    def _get_disc_loss(self, real: Tensor) -> Tensor:
        """Discriminator loss for one real batch plus generated fakes.

        Logs epoch-level ``acc/real``, ``acc/fake`` and ``acc/total``. Returns
        ``criterion(real_pred, real_gt) + criterion(fake_pred, fake_gt)``, computed
        against soft targets that are randomly swapped per element.
        """
        # Real batch: the discriminator should predict 1.
        real_pred = self.discriminator(real)
        real_acc = self.real_accuracy(real_pred, torch.ones_like(real_pred, dtype=torch.long))
        self.log("acc/real", real_acc, on_epoch=True, on_step=False)

        # Fake batch: the discriminator should predict 0.
        fake_pred = self._get_fake_pred(real)
        fake_acc = self.fake_accuracy(fake_pred, torch.zeros_like(fake_pred, dtype=torch.long))
        self.log("acc/fake", fake_acc, on_epoch=True, on_step=False)

        # Soft labels: real targets ~ U(0.7, 1.2], fake targets ~ U(0.0, 0.3].
        # NOTE(review): real targets above 1.0 fall outside the [0, 1] target range that
        # BCE-style losses assume — confirm this is intended for self.criterion.
        real_gt = (0.7 - 1.2) * torch.rand_like(real_pred) + 1.2
        fake_gt = (0.0 - 0.3) * torch.rand_like(fake_pred) + 0.3

        # Randomly swap the real/fake targets element-wise. Both results must be
        # computed from the *original* targets. Updating real_gt first and then building
        # fake_gt from it would leave fake_gt unchanged, making the swap one-sided.
        # Assumes real_pred and fake_pred have the same shape (one fake per real sample).
        swap = torch.rand_like(real_pred) < self.swap_prob
        real_gt, fake_gt = (
            torch.where(swap, fake_gt, real_gt),
            torch.where(swap, real_gt, fake_gt),
        )

        real_loss = self.criterion(real_pred, real_gt)
        fake_loss = self.criterion(fake_pred, fake_gt)
        disc_loss = real_loss + fake_loss

        # Unweighted mean of the two batch accuracies. This equals the overall
        # accuracy whenever the real and fake halves contain the same number of samples.
        dis_acc = (real_acc + fake_acc) / 2
        self.log("acc/total", dis_acc, on_epoch=True, on_step=False)

        return disc_loss

class BenGAN(GAN):
    """GAN whose generator/discriminator are the project's ConvDecoder/ConvEncoder.

    Built on pl_bolts ``GAN`` rather than ``DCGAN``. ``GAN.__init__`` builds its
    networks by calling the ``init_generator(img_dim)`` / ``init_discriminator(img_dim)``
    hooks overridden below. ``DCGAN`` constructs its networks through other methods and
    would silently ignore these overrides. Construct it with the ``GAN`` signature,
    e.g. ``BenGAN(3, 96, 96, latent_dim=latent_dim)``.
    """

    def init_generator(self, img_dim):
        """Return a ConvDecoder mapping a ``latent_dim`` vector to an ``img_dim`` image.

        The output passes through ``torch.tanh``. ``img_dim`` is presumably
        (channels, height, width) — confirm against ConvDecoder.
        """
        return ConvDecoder(
            latent_shape=self.hparams.latent_dim,
            output_shape=img_dim,
            final_activation=torch.tanh,
        )

    def init_discriminator(self, img_dim):
        """Return a ConvEncoder mapping an ``img_dim`` image to one sigmoid score.

        Dropout rate is 0.1.
        """
        return ConvEncoder(
            input_shape=img_dim,
            latent_shape=1,
            final_activation=torch.sigmoid,
            dropout_rate=0.1,
        )

In [None]:
from pl_bolts.models.gans import GAN, DCGAN, SRResNet, SRGAN
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler
import torchvision

# Experiment tracking goes to the "gan_test" Weights & Biases project.
logger = WandbLogger(project="gan_test")

# Model under training.
# NOTE(review): the model's latent size is fixed to 64 here, while the global
# `latent_dim` from the data cell is 256. Code that samples noise by hand should size
# it from the model's own latent dimension, not from the global.
gan = CustomDCGAN(latent_dim=64, image_channels=3)
print(gan)

In [None]:
# Image-logging callbacks from pl_bolts:
#   - a grid of 64 generated samples (8 per row), min-max normalised;
#   - latent-space interpolations every 4 epochs.
# NOTE(review): these callbacks are TensorBoard-oriented (they appear to log via
# logger.experiment.add_image) — confirm they work with the WandbLogger above.
generative_callbacks = [
    TensorboardGenerativeModelImageSampler(num_samples=64, nrow=8, normalize=True),
    LatentDimInterpolator(interpolate_epoch_interval=4),
]

trainer = Trainer(
    max_epochs=500,
    accelerator="auto",
    devices=-1,  # use every available device
    strategy="dp",
    logger=logger,
    callbacks=generative_callbacks,
    log_every_n_steps=10,
    track_grad_norm=2,  # log the L2 norm of the gradients
)

In [None]:
# Train on the sprites loader only; no validation loader is passed (the data cell uses val_ratio=0.0).
trainer.fit(model=gan, train_dataloaders=trainloader)

In [None]:
import torch
# Sample 32 images from the trained generator.
# - Noise is sized from the model's own latent dimension. The global `latent_dim`
#   (256) does not match a model built with a different latent_dim and would
#   fail on shape.
# - Noise is standard normal (torch.randn), the usual DCGAN prior, not uniform.
# - eval() + no_grad(): inference only, with no autograd graph and no updates to
#   normalisation statistics during this forward pass.
gan.eval()
with torch.no_grad():
    noise = torch.randn(32, gan.hparams.latent_dim, device=gan.device)
    img = gan(noise)

In [None]:
# Sanity check: the generated batch converts to a NumPy array.
img.detach().cpu().numpy().__class__

In [None]:
# Take the first generated sample as a NumPy array and show its shape.
image = img[0].detach().cpu().numpy()
image.shape

In [None]:
# Reorder the sample from channels-first to channels-last, presumably
# (C, H, W) -> (H, W, C). This matches the two successive swapaxes(0,1), swapaxes(1,2)
# but writes to a new name, so re-running the cell does not keep permuting the same
# array.
image_hwc = image.transpose(1, 2, 0)
image_hwc.shape

In [None]:
import torch
# Sample 16 images from the trained generator.
# Noise is standard normal and sized from the model's own latent dimension. The
# global `latent_dim` (256) does not match a model built with a different latent_dim.
# Sampling runs in eval mode without autograd.
gan.eval()
with torch.no_grad():
    noise = torch.randn(16, gan.hparams.latent_dim, device=gan.device)
    img = gan(noise)

from PIL import Image
import numpy as np
# Convert each generated sample to an RGB PIL image and show it.
images = img.detach().cpu().numpy()
for image in images:
    image = np.squeeze(image)
    print(image.shape)
    # Channels-first -> channels-last: (C, H, W) -> (H, W, C).
    image = image.transpose(1, 2, 0)
    # Assumes generator output lies in [-1, 1] (tanh output layer; cf. the
    # Normalize((0.5,)*3, (0.5,)*3) convention used elsewhere) — confirm.
    # Map [-1, 1] to [0, 255]. The clip guards against out-of-range values;
    # casting negative floats straight to uint8 (as `image * 255` would)
    # wraps around and corrupts colours.
    pixels = np.clip((image + 1.0) * 127.5, 0, 255).astype(np.uint8)
    PIL_image = Image.fromarray(pixels, "RGB")
    PIL_image.show()