In [None]:
import torch as pt
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from torchvision.datasets import CIFAR100
from torchvision import transforms

from matplotlib import pyplot as plt

%matplotlib inline

In [None]:
# Preprocessing pipeline: convert every image to a tensor, then normalise each
# RGB channel with the mean / std values below.
transforms_ = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.5071, 0.4867, 0.4408),
            std=(0.2675, 0.2565, 0.2761),
        ),
    ]
)

# Training split of CIFAR-100 stored under ./data (download requested).
train_ds = CIFAR100(root="./data", train=True, download=True, transform=transforms_)

# Class names; indexed by integer label elsewhere in the notebook.
CLASSES = train_ds.classes

In [None]:
# Shuffled mini-batches of 60 samples drawn from the training set.
train_loader = DataLoader(
    dataset=train_ds,
    batch_size=60,
    shuffle=True,
)

In [None]:
# Run-wide settings: hyper-parameters, compute device and TensorBoard writer.
batch_size = 60
gen_lr = 1e-4   # learning rate used for the generator's optimizer
desc_lr = 1e-5  # learning rate used for the descriminator's optimizer

# Prefer the GPU whenever CUDA is available.
if pt.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# All TensorBoard events of this notebook are written under ./logdir.
writer = SummaryWriter("./logdir")

In [None]:
def selected_n_random_dataset_samples(data, labels, n=200):
    """Pick up to ``n`` random, still-aligned (sample, label) pairs.

    Args:
        data: Samples, indexable along the first axis. Tensors, NumPy arrays
            and nested lists are accepted; everything is converted to a tensor.
        labels: One label per entry of ``data``. A plain Python list is
            accepted as well as a tensor or NumPy array.
        n: Maximum number of pairs to return.

    Returns:
        Tuple ``(data_subset, label_subset)`` of tensors whose first dimension
        is ``min(n, len(data))``. Row ``i`` of both tensors refers to the same
        original sample.
    """
    # Convert up front: a Python list cannot be indexed with an index tensor.
    data = pt.as_tensor(data)
    labels = pt.as_tensor(labels)

    # Cut the permutation down to n *before* indexing so only n rows are
    # gathered instead of a shuffled copy of the whole dataset.
    chosen = pt.randperm(len(data))[:n]
    return data[chosen], labels[chosen]

# Draw a random subset of the raw training images for the TensorBoard projector.
# Both inputs are converted to tensors first because the sampler indexes them
# with an index tensor, which a plain list of targets does not support.
images, labels = selected_n_random_dataset_samples(
    pt.as_tensor(train_ds.data), pt.as_tensor(train_ds.targets)
)

# One class name per sampled image (index with each label, not the whole batch).
class_labels = [CLASSES[label] for label in labels.tolist()]

writer.add_embedding(
    # One flattened feature row per image, keeping every pixel and channel.
    images.reshape(len(images), -1).float(),
    metadata=class_labels,
    # Thumbnails are expected as (N, C, H, W) floats in [0, 1].
    # NOTE(review): assumes train_ds.data is (N, H, W, C) uint8 - confirm.
    label_img=images.permute(0, 3, 1, 2).float() / 255.0,
)
writer.flush()

In [None]:
def plot_images(images, labels=None):
    """Draw a batch of channel-first images on the current matplotlib figure.

    Args:
        images: Tensor or array of shape (N, C, H, W). Tensors may live on any
            device and may require grad; they are detached and copied to CPU.
        labels: Optional sequence of N integer class indices (ints or scalar
            tensors). When given, the matching ``CLASSES`` name is used as the
            title of each image.

    Up to 10 images are laid out on a single row; more than 10 are laid out
    10 per row (two rows for 11-20 images, more rows beyond that).
    """
    if isinstance(images, pt.Tensor):
        # .cpu() is required before .numpy() for tensors living on the GPU.
        images = images.detach().cpu().numpy()

    n_images = len(images)
    if n_images == 0:
        return

    if n_images > 10:
        n_cols = 10
        n_rows = -(-n_images // n_cols)  # ceiling division
    else:
        n_rows, n_cols = 1, n_images

    for idx in range(n_images):
        plt.subplot(n_rows, n_cols, idx + 1)

        image = images[idx]
        if image.ndim == 3:
            # (C, H, W) -> (H, W, C). A bare .T would give (W, H, C) and draw
            # the picture mirrored along its diagonal.
            image = image.transpose(1, 2, 0)

        plt.imshow(image)
        plt.axis("off")

        if labels is not None:
            # int() accepts both plain ints and one-element tensors.
            plt.title(CLASSES[int(labels[idx])])

In [None]:
plt.figure(figsize=(20, 4), dpi=300)

# Count the samples of every class straight from the dataset's label list.
# Iterating the DataLoader would visit exactly the same labels, but would also
# load and transform every image just to throw it away.
class_counts = pt.bincount(pt.as_tensor(train_ds.targets), minlength=len(CLASSES))
n_samples_per_class = dict(zip(CLASSES, class_counts.tolist()))

plt.bar(list(n_samples_per_class.keys()), list(n_samples_per_class.values()))
# Rotate the tick labels after the bars exist, so the ticks being styled are
# the class-name ticks created by plt.bar.
plt.xticks(rotation=90, ha="center")
plt.margins(x=0.01)
plt.xlabel("Class")
plt.ylabel("Number of training samples")
plt.title("Training samples per class");

In [None]:
# Fetch one batch from the loader and keep only its first 20 samples.
images, labels = (part[:20] for part in next(iter(train_loader)))

# Show the samples inline ...
plt.figure(figsize=(20, 4), dpi=300)
plot_images(images, labels)

# ... and log the same 20 images to TensorBoard.
writer.add_images(
    tag="20 Images from Train Data Loader",
    img_tensor=images,
)
writer.flush()

In [None]:
# Number of training samples. `.size` on the underlying array would instead
# report the total number of stored values (samples x pixels x channels).
print(f"Train dataset size: {len(train_ds.data)}")

In [None]:
import numpy as np

from torch import nn
from torch.nn import functional as F
from torch.optim import Adam

from pytorch_model_summary import summary

from tqdm import tqdm

In [None]:
class Generator(nn.Module):
    """Map a 384-dimensional noise vector to a 3x32x32 image.

    Pipeline: a linear projection to a (384, 4, 4) feature map, three
    "bilinear upsample x2 + 5x5 convolution" stages that reach 32x32, and two
    4x4 convolutions that reduce the channels to 3. The output passes through
    ``tanh`` and therefore lies in [-1, 1].
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.linear = nn.Sequential(
            nn.Linear(in_features=384, out_features=384*4*4),
            nn.LeakyReLU(0.2),
            # The output is reshaped to (384, 4, 4) in forward().
        )

        # nn.Upsample(mode="bilinear", align_corners=True) is the documented
        # replacement for the deprecated nn.UpsamplingBilinear2d; it holds no
        # parameters, so the module layout and state-dict keys are unchanged.
        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),  # (384, 8, 8)
            nn.Conv2d(in_channels=384, out_channels=512, kernel_size=5, padding="same"),  # (512, 8, 8)
            nn.LeakyReLU(0.2),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),  # (512, 16, 16)
            nn.Conv2d(in_channels=512, out_channels=768, kernel_size=5, padding="same"),  # (768, 16, 16)
            nn.LeakyReLU(0.2),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),  # (768, 32, 32)
            nn.Conv2d(in_channels=768, out_channels=1024, kernel_size=5, padding="same"),  # (1024, 32, 32)
            nn.LeakyReLU(0.2),
        )

        self.down_sample = nn.Sequential(
            nn.Conv2d(1024, 512, kernel_size=4, padding="same"),  # (512, 32, 32)
            nn.LeakyReLU(0.2),
            nn.Conv2d(512, 3, kernel_size=4, padding="same"),  # (3, 32, 32)
        )

    def forward(self, x):
        """Generate images from noise.

        Args:
            x: Noise tensor whose last dimension is 384, e.g. (N, 384) or
                (N, 1, 384).

        Returns:
            Tensor of shape (N, 3, 32, 32) with values in [-1, 1].
        """
        x = self.linear(x)
        # Collapse every leading dimension into the batch dimension.
        x = x.view(-1, 384, 4, 4)
        x = self.upsample(x)
        x = self.down_sample(x)
        # torch.tanh replaces the deprecated torch.nn.functional.tanh.
        return pt.tanh(x)

In [None]:
# Build the generator on the selected device and attach its Adam optimizer.
generator = Generator()
generator = generator.to(device)
generator_optimizer = Adam(params=generator.parameters(), lr=gen_lr)

In [None]:
def generator_loss(pred, label):
    """Return the binary cross-entropy of ``pred`` against ``label``.

    Args:
        pred: Predicted probabilities.
        label: Target values, same shape as ``pred``.

    Returns:
        Scalar loss tensor (mean over all elements).
    """
    loss = F.binary_cross_entropy(input=pred, target=label)
    return loss

In [None]:
# Trace the generator with a zero-filled dummy batch of one noise vector.
# pt.Tensor(384) would hand the model uninitialised memory (possibly NaN/inf).
# Note that .to("cpu") moves the generator itself to the CPU.
print(summary(generator.to("cpu"), pt.zeros(1, 384, device="cpu")))

In [None]:
class Descriminator(nn.Module):
    """Binary classifier scoring a 3x32x32 image as real (1) or generated (0).

    Three 3x3 convolutions with LeakyReLU and dropout, two 2x2 max-pools, then
    a two-layer linear head. The output passes through a sigmoid, so it is a
    probability in [0, 1] of shape (N, 1).
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        # Attribute names are kept as-is: they determine the state-dict keys.
        self.__conv_block = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1),  # (32 - 3 + 2*0) / 1 + 1 = 30
            nn.LeakyReLU(0.2),
            nn.Dropout(0.4),
            nn.MaxPool2d(2),  # 2x2 max pool: 30 // 2 = 15
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1),  # (15 - 3 + 2*0) / 1 + 1 = 13
            nn.LeakyReLU(0.2),
            nn.Dropout(0.4),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1),  # (13 - 3 + 2*0) / 1 + 1 = 11
            nn.LeakyReLU(0.2),
            nn.Dropout(0.4),
            nn.MaxPool2d(2),  # 2x2 max pool: 11 // 2 = 5
        )
        self.__linear_block = nn.Sequential(
            nn.Linear(in_features=128*5*5, out_features=512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.5),
            nn.Linear(in_features=512, out_features=1),
        )

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: Image tensor of shape (N, 3, 32, 32).

        Returns:
            Tensor of shape (N, 1) with probabilities in [0, 1].
        """
        x = self.__conv_block(x)
        # Flatten the (128, 5, 5) feature maps for the linear head.
        x = x.view(-1, 128*5*5)
        x = self.__linear_block(x)
        # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
        return pt.sigmoid(x)

In [None]:
# Build the descriminator on the selected device and attach its Adam optimizer.
descriminator = Descriminator()
descriminator = descriminator.to(device)
descriminator_optimizer = Adam(params=descriminator.parameters(), lr=desc_lr)

In [None]:
# Layer-by-layer summary of the descriminator, traced on the CPU with a single
# all-zero 3x32x32 image. Note that .to("cpu") moves the model itself.
print(
    summary(
        descriminator.to("cpu"),
        pt.zeros((1, 3, 32, 32), device="cpu"),
    )
)

In [None]:
def descriminator_loss(pred, label):
    """Return the binary cross-entropy of ``pred`` against ``label``.

    Args:
        pred: Predicted probabilities.
        label: Target values, same shape as ``pred``.

    Returns:
        Scalar loss tensor (mean over all elements).
    """
    loss = F.binary_cross_entropy(input=pred, target=label)
    return loss

In [None]:
# Sanity check: push four Gaussian noise vectors through the generator on the
# CPU and display whatever comes out.
garbage = pt.from_numpy(np.random.normal(size=(4, 1, 384)))
garbage = garbage.to(dtype=pt.float32, device="cpu")

generator = generator.to("cpu")
generator = generator.eval()

plt.figure(figsize=(20, 4), dpi=300)
predicted = generator(garbage)
plot_images(predicted)

In [None]:
# Record the generator's computation graph in TensorBoard, traced with the
# noise batch used above.
writer.add_graph(model=generator, input_to_model=garbage)
writer.flush()

In [None]:
# Sanity check: score the previewed batch of images with the descriminator.
descriminator = descriminator.eval().to(device)

# The model now lives on `device`, so its input has to be moved there as well.
# No gradients are needed for this check.
with pt.no_grad():
    predicted = descriminator(images.to(device))

print(f"Dataset image size: {images.size()}")
# These are the descriminator's scores, not generated images.
print(f"Descriminator predictions: {predicted}")

In [None]:
# Record the descriminator's computation graph in TensorBoard. The example
# input is moved to `device` so it sits on the same device as the model.
writer.add_graph(descriminator, images.to(device))
writer.flush()

In [None]:
def train(current_epoch, total_epochs, gen_model=None, desc_model=None, latent_dim=384):
    """Train the GAN on ``train_loader``, yielding the losses after every batch.

    This is a generator function: nothing runs until it is iterated, and each
    ``next()`` performs exactly one descriminator update followed by one
    generator update on the next batch.

    Labels follow the convention real = 1, generated = 0.

    Args:
        current_epoch: Zero-based epoch index, used only in the progress print.
        total_epochs: Total number of epochs, used only in the progress print.
        gen_model: Generator to train. Defaults to the module-level
            ``generator``.
        desc_model: Descriminator to train. Defaults to the module-level
            ``descriminator``.
        latent_dim: Length of the noise vectors fed to the generator.

    Yields:
        dict with float entries ``"desc_loss"`` and ``"gen_loss"``.

    Note:
        Updates are applied through the module-level ``generator_optimizer``
        and ``descriminator_optimizer``; models passed explicitly are only
        updated if those optimizers were built from their parameters.
    """
    # Separate local names are used on purpose: assigning to `generator` or
    # `descriminator` inside the function would make them locals and break
    # every read that happens before the assignment.
    gen_model = generator if gen_model is None else gen_model
    desc_model = descriminator if desc_model is None else desc_model
    gen_model = gen_model.to(device)
    desc_model = desc_model.to(device)

    for images, _ in train_loader:
        real_images = images.to(device)
        # The last batch of the loader can be smaller than `batch_size`.
        n_samples = real_images.size(0)

        # ---------------- Descriminator update ----------------
        gen_model.eval()
        desc_model.train()
        descriminator_optimizer.zero_grad()

        # Generated images are inputs only here: no gradient must flow back
        # into the generator during the descriminator update.
        with pt.no_grad():
            fake_images = gen_model(pt.randn(n_samples, latent_dim, device=device))

        yhat_real = desc_model(real_images)
        yhat_fake = desc_model(fake_images)
        predictions = pt.concat([yhat_real, yhat_fake], dim=0)

        # Noisy targets that stay inside [0, 1]:
        # real in [0.8, 1.0], generated in [0.0, 0.2].
        target_real = pt.ones_like(yhat_real) - 0.2 * pt.rand_like(yhat_real)
        target_fake = pt.zeros_like(yhat_fake) + 0.2 * pt.rand_like(yhat_fake)
        targets = pt.concat([target_real, target_fake], dim=0)

        desc_loss = descriminator_loss(predictions, targets)
        desc_loss.backward()
        descriminator_optimizer.step()

        # ------------------ Generator update ------------------
        gen_model.train()
        desc_model.eval()
        generator_optimizer.zero_grad()

        generated_images = gen_model(pt.randn(n_samples, latent_dim, device=device))
        predicted_labels = desc_model(generated_images)

        # The generator is rewarded when its images are scored as real (1).
        gen_loss = generator_loss(predicted_labels, pt.ones_like(predicted_labels))
        gen_loss.backward()
        generator_optimizer.step()

        print (f"Epoch: {current_epoch+1}/{total_epochs} Descriminator loss: {desc_loss.item():.6f}", end="; ")
        print (f"Generator Training loss: {gen_loss.item():.6f}")

        yield {
            "desc_loss": desc_loss.item(),
            "gen_loss": gen_loss.item()
        }

In [None]:
total_epochs = 2000

# Loss history, one entry per iteration; plotted after training.
desc_loss, gen_loss = [], []

for epoch in range(total_epochs):
    # train() reads the models and optimizers from module scope, so only the
    # epoch counters are passed.
    # NOTE(review): next() consumes a single batch from a fresh pass over the
    # loader, so every "epoch" here is one training step - confirm intended.
    step_losses = next(train(epoch, total_epochs))

    desc_loss.append(step_losses["desc_loss"])
    gen_loss.append(step_losses["gen_loss"])

    # Log the latest scalar values (not the whole history lists) and give each
    # point its own step on the x-axis.
    writer.add_scalars(
        "Generator Loss vs Descriminator Loss",
        {
            "Generator Loss": step_losses["gen_loss"],
            "Descriminator Loss": step_losses["desc_loss"],
        },
        global_step=epoch,
    )
writer.flush()
writer.close()

In [None]:
# Training curves side by side: descriminator loss on the left, generator loss
# on the right (x-axis: training iteration).
fig, (desc_ax, gen_ax) = plt.subplots(1, 2, figsize=(12, 4))

desc_ax.plot(desc_loss, label="Descriminator Loss")
desc_ax.legend(loc="upper right")
desc_ax.set_xlabel("Iteration")
desc_ax.set_title("Descriminator Train Loss")

gen_ax.plot(gen_loss, label="Generator Loss")
gen_ax.legend(loc="upper right")
gen_ax.set_xlabel("Iteration")
gen_ax.set_title("Generator Train Loss")

fig.tight_layout();

In [None]:
# Persist both networks as whole pickled module objects, one file each.
# NOTE(review): files saved this way presumably need the class definitions to
# be importable at load time; saving state_dict() is the usual alternative.
for model, path in (
    (generator, "./generator.pt"),
    (descriminator, "./descriminator.pt"),
):
    pt.save(model, path)