In [None]:
import torch as pt
from torch.utils.data import DataLoader
# from torch.utils.tensorboard import SummaryWriter

from torchvision.datasets import CIFAR100
from torchvision import transforms

from matplotlib import pyplot as plt

%matplotlib inline

In [None]:
# Map pixels from [0, 1] (ToTensor) to [-1, 1] so real images share the value
# range of the generator's tanh output. Normalizing with per-channel dataset
# mean/std instead leaves real pixels in roughly [-1.9, 2.0], a range the
# discriminator could use to tell real from generated images trivially.
transforms_ = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
train_ds = CIFAR100("./data", train=True, download=True, transform=transforms_)

CLASSES = train_ds.classes

In [None]:
# writer = SummaryWriter("./logdir")
import numpy as np

# Seed the RNGs this notebook draws from (torch: weight init and DataLoader
# shuffling; numpy: latent noise drawn with np.random) so that
# Restart-&-Run-All reproduces the same run.
SEED = 0
pt.manual_seed(SEED)
np.random.seed(SEED)

device = "cuda" if pt.cuda.is_available() else "cpu"
batch_size = 60  # images per training step
gen_lr = 1e-4    # Adam learning rate for the generator
desc_lr = 1e-5   # Adam learning rate for the discriminator

In [None]:
# Use the configured batch size rather than repeating the literal, so changing
# `batch_size` in the config cell actually takes effect.
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

In [None]:
import numpy as np

In [None]:
# def selected_n_random_dataset_samples(data, labels, n=200):
#     perm = pt.randperm(len(data))
#     return data[perm][:n], np.array(labels)[perm][:n].tolist()

# images, labels = selected_n_random_dataset_samples(train_ds.data, train_ds.targets)

# features = pt.from_numpy(images).reshape(shape=(-1, 32 * 32))
# class_labels = [CLASSES[label] for label in labels]

# print (f"len(features): {len(features)}")
# print (f"len(class_labels): {len(class_labels)}")

# writer.add_embedding(
#     features,
#     metadata=class_labels,
#     label_img=pt.from_numpy(images).unsqueeze(1)
# )
# writer.flush()

In [None]:
def plot_images(images, labels=None, max_cols=10):
    """Draw a grid of channels-first images on the current matplotlib figure.

    Args:
        images: tensor or array of shape (N, C, H, W), as produced by the
            DataLoader and the generator in this notebook. Tensors are
            detached and moved to CPU first.
        labels: optional sequence of class indices (ints or 0-d tensors);
            when given, each subplot is titled with ``CLASSES[label]``.
        max_cols: images per row; extra rows are added as needed. The default
            of 10 keeps the original layout: one row for up to 10 images, two
            rows of 10 for 11-20.
    """
    if isinstance(images, pt.Tensor):
        images = images.detach().cpu().numpy()

    n_images = len(images)
    if n_images == 0:
        return
    n_cols = min(n_images, max_cols)
    n_rows = -(-n_images // n_cols)  # ceiling division

    for idx in range(n_images):
        plt.subplot(n_rows, n_cols, idx + 1)

        # imshow wants channels last (H, W, C); the data is (C, H, W).
        image = np.transpose(images[idx], (1, 2, 0))
        if image.shape[-1] == 1:
            image = image[..., 0]  # grayscale: imshow takes (H, W), not (H, W, 1)

        # Pixels are normalized (not in [0, 1]); rescale each image to [0, 1]
        # so imshow does not clip them.
        low, high = image.min(), image.max()
        image = (image - low) / (high - low) if high > low else np.zeros_like(image)

        plt.imshow(image)
        plt.axis("off")

        if labels is not None:
            plt.title(CLASSES[int(labels[idx])])

In [None]:
from collections import Counter

# Count labels straight from the dataset's target list. Iterating the
# DataLoader would decode and normalize every training image just to read
# its label.
label_counts = Counter(train_ds.targets)
n_samples_per_class = {
    class_: label_counts.get(class_idx, 0) for class_idx, class_ in enumerate(CLASSES)
}

plt.figure(figsize=(20, 4), dpi=300)
plt.xticks(rotation=90, ha="center")
plt.margins(x=0.01)
plt.bar(list(n_samples_per_class.keys()), list(n_samples_per_class.values()))
plt.title("Training samples per class")
plt.ylabel("Number of images");

In [None]:
# Preview the first 20 images/labels of one shuffled training batch.
images, labels = (batch_part[:20] for batch_part in next(iter(train_loader)))

plt.figure(figsize=(20, 4), dpi=300)
plot_images(images, labels)

# writer.add_images(tag="20 Images from Train Data Loader", img_tensor=images)
# writer.flush()

In [None]:
# `train_ds.data.size` is numpy's total element count over the whole array,
# not the number of samples; report the sample count and the array shape.
print(f"Train dataset size: {len(train_ds)}")
print(f"Train data array shape: {train_ds.data.shape}")

In [None]:
import numpy as np

from torch import nn
from torch.nn import functional as F
from torch.optim import Adam

from pytorch_model_summary import summary

from tqdm import tqdm

In [None]:
class Generator(nn.Module):
    """Generator mapping a latent vector to a 3x32x32 image with values in [-1, 1].

    The latent vector is projected by a linear layer to 384*4*4 features and
    reshaped to (384, 4, 4). It is then upsampled by 2 three times
    (4 -> 8 -> 16 -> 32) with a 5x5 convolution after each step, and finally
    reduced to 3 channels and squashed with tanh.

    Args:
        latent_dim: size of the input noise vector (keyword-only). The default
            of 384 matches the noise the rest of the notebook feeds in.
    """

    def __init__(self, *args, latent_dim: int = 384, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.latent_dim = latent_dim

        self.linear = nn.Sequential(
            nn.Linear(in_features=latent_dim, out_features=384*4*4),
            nn.LeakyReLU(0.2),
            # forward() reshapes this output into (384, 4, 4)
        )

        # nn.Upsample(mode="bilinear", align_corners=True) is the documented
        # replacement for the deprecated nn.UpsamplingBilinear2d. It has no
        # parameters, so state_dict keys are unchanged.
        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),  # (384, 8, 8)
            nn.Conv2d(in_channels=384, out_channels=512, kernel_size=5, padding="same"),  # (512, 8, 8)
            nn.LeakyReLU(0.2),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),  # (512, 16, 16)
            nn.Conv2d(in_channels=512, out_channels=768, kernel_size=5, padding="same"),  # (768, 16, 16)
            nn.LeakyReLU(0.2),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),  # (768, 32, 32)
            nn.Conv2d(in_channels=768, out_channels=1024, kernel_size=5, padding="same"),  # (1024, 32, 32)
            nn.LeakyReLU(0.2),
        )

        self.down_sample = nn.Sequential(
            nn.Conv2d(1024, 512, kernel_size=4, padding="same"),  # (512, 32, 32)
            nn.LeakyReLU(0.2),
            nn.Conv2d(512, 3, kernel_size=4, padding="same"),  # (3, 32, 32)
        )

    def forward(self, x):
        """Map noise whose last dim is ``latent_dim`` (e.g. (N, latent_dim) or
        (N, 1, latent_dim)) to images of shape (N, 3, 32, 32)."""
        x = self.linear(x)
        x = x.view(-1, 384, 4, 4)
        x = self.upsample(x)
        x = self.down_sample(x)
        # torch.tanh replaces the deprecated torch.nn.functional.tanh.
        return pt.tanh(x)

In [None]:
# Build the generator, place it on the training device, and attach its optimizer.
generator = Generator()
generator.to(device)  # nn.Module.to moves parameters in place and returns self
generator_optimizer = Adam(params=generator.parameters(), lr=gen_lr)

In [None]:
# def generator_loss(pred, label):
#     return F.binary_cross_entropy(pred, label)

In [None]:
# Summarize layer shapes on CPU using a zero-filled batch of one latent vector.
# pt.Tensor(384) would hand the model uninitialized memory. Afterwards, move
# the model back so later cells find it on `device`.
print(summary(generator.to("cpu"), pt.zeros((1, 384), device="cpu")))
generator = generator.to(device)

In [None]:
class Descriminator(nn.Module):
    """Real/fake classifier for 3x32x32 images.

    ``forward`` returns a tensor of shape (N, 1) holding, for each image, the
    sigmoid probability that the image is real.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        # NOTE: the double-underscore attribute names are mangled to
        # `_Descriminator__conv_block` / `_Descriminator__linear_block`. Those
        # are the state_dict key prefixes, so renaming them breaks saved weights.
        self.__conv_block = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1),  # ((32 - 3) + 2*0) / 1 + 1 = 30
            nn.LeakyReLU(0.2),
            nn.Dropout(0.4),
            nn.MaxPool2d(2),  # 2x2 max pool: 30 / 2 = 15
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1),  # ((15 - 3) + 2*0) / 1 + 1 = 13
            nn.LeakyReLU(0.2),
            nn.Dropout(0.4),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1),  # ((13 - 3) + 2*0) / 1 + 1 = 11
            nn.LeakyReLU(0.2),
            nn.Dropout(0.4),
            nn.MaxPool2d(2),  # 2x2 max pool: floor(11 / 2) = 5
        )
        self.__linear_block = nn.Sequential(
            nn.Linear(in_features=128*5*5, out_features=512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.5),
            nn.Linear(in_features=512, out_features=1),
        )

    def forward(self, x):
        """Score a batch ``x`` of shape (N, 3, 32, 32); returns (N, 1) probabilities."""
        x = self.__conv_block(x)
        # Flatten per sample. Unlike view(-1, 128*5*5), this never regroups
        # features across samples: a wrongly sized input fails in the Linear
        # layer instead of silently producing a different batch size.
        x = pt.flatten(x, start_dim=1)
        x = self.__linear_block(x)
        # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
        return pt.sigmoid(x)

In [None]:
# Build the discriminator, place it on the training device, and attach its optimizer.
descriminator = Descriminator()
descriminator.to(device)  # nn.Module.to moves parameters in place and returns self
descriminator_optimizer = Adam(params=descriminator.parameters(), lr=desc_lr)

In [None]:
# Layer-by-layer summary for one zero-filled 3x32x32 dummy image, computed on CPU.
print(
    summary(
        descriminator.to("cpu"),
        pt.zeros((1, 3, 32, 32), device="cpu"),
    )
)

In [None]:
# def descriminator_loss(pred, label):
#     return F.binary_cross_entropy(pred, label)

In [None]:
plt.figure(figsize=(20, 4), dpi=300)

# Quick visual check of the (untrained) generator on 4 random latent vectors.
generator = generator.to(device).eval()
garbage = pt.from_numpy(np.random.normal(size=(4, 1, 384))).to(dtype=pt.float32, device=device)

# Inference only: don't build an autograd graph for the preview.
with pt.no_grad():
    predicted = generator(garbage)
plot_images(predicted)

In [None]:
# writer.add_graph(generator, garbage)
# writer.flush()

In [None]:
# Sanity-check the (untrained) discriminator on the preview batch of real images.
descriminator = descriminator.eval().to(device)
with pt.no_grad():
    predicted = descriminator(images.to(device))

print(f"Dataset image size: {images.size()}")
# These are the discriminator's sigmoid outputs for real images, not generated images.
print(f"Discriminator predictions (P(real)): {predicted}")

In [None]:
# Binary cross-entropy on probabilities; the discriminator's head ends in a sigmoid.
adversarial_loss = nn.BCELoss()

In [None]:
# writer.add_graph(descriminator, images.to(device))
# writer.flush()

In [None]:
def _sample_latent(n, latent_dim=384):
    """Draw ``n`` standard-normal latent vectors of shape (n, 1, latent_dim) on ``device``."""
    return pt.randn((n, 1, latent_dim), device=device)


def _show_generated_samples(generator, nrows=1, ncols=5):
    """Display a ``nrows`` x ``ncols`` grid of fresh generator samples."""
    generator.eval()
    with pt.no_grad():
        samples = generator(_sample_latent(nrows * ncols)).cpu()

    # squeeze=False keeps `axes` 2-D even for a single row or column.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(8, 2), squeeze=False)
    for ax, sample in zip(axes.flat, samples):
        # (C, H, W) -> (H, W, C) for imshow. `.T` would also swap H and W, and
        # it is deprecated for tensors with more than 2 dims. The generator
        # ends in tanh, so map [-1, 1] -> [0, 1] for display.
        ax.imshow(((sample.permute(1, 2, 0) + 1) / 2).clamp(0, 1).numpy())
        ax.axis("off")
    plt.show()
    plt.close(fig)


def train(current_epoch, total_epochs, generator, descriminator):
    """Run one epoch of adversarial training over ``train_loader``.

    Each batch does two updates:
      1. Generator step: update the generator so the discriminator labels its
         samples as real.
      2. Discriminator step: update the discriminator to label the real batch
         1 and the (detached) generated batch 0.
    After the epoch, print the mean losses and show a row of generated samples.

    Uses the notebook-level ``train_loader``, ``device``, ``adversarial_loss``,
    ``generator_optimizer`` and ``descriminator_optimizer``. The optimizers must
    have been built over the parameters of the models passed in.

    Args:
        current_epoch: zero-based epoch index (printed as ``current_epoch + 1``).
        total_epochs: total number of epochs; used only in the progress print.
        generator: generator module; moved to ``device``.
        descriminator: discriminator module; moved to ``device``.

    Returns:
        dict with per-batch loss lists under ``"generator_loss"`` and
        ``"descriminator_loss"``.
    """
    d_loss_list, g_loss_list = [], []

    # Device placement is loop-invariant; do it once per epoch.
    generator = generator.to(device).train()
    descriminator = descriminator.to(device)

    for images, _ in tqdm(train_loader):
        real_images = images.to(device)
        n_batch = real_images.shape[0]

        real_labels = pt.ones((n_batch, 1), device=device)
        fake_labels = pt.zeros((n_batch, 1), device=device)

        # --- Generator step: push D(G(z)) towards "real" ---------------------
        generator_optimizer.zero_grad()
        predicted_images = generator(_sample_latent(n_batch))

        descriminator.eval()  # dropout off while scoring fakes for the generator update
        g_loss = adversarial_loss(descriminator(predicted_images), real_labels)
        g_loss.backward()
        generator_optimizer.step()
        g_loss_list.append(g_loss.item())

        # --- Discriminator step: real -> 1, fake -> 0 ------------------------
        # The generator's backward pass also wrote gradients into the
        # discriminator's parameters; zero_grad discards them.
        descriminator_optimizer.zero_grad()
        descriminator.train()

        # Both BCE terms are minimized. Negating the fake term would make the
        # discriminator maximize its error on generated images, i.e. learn to
        # call them real. A constant scale on the loss only rescales gradients,
        # which Adam largely normalizes away, so plain averaging is used.
        d_real_loss = adversarial_loss(descriminator(real_images), real_labels)
        d_fake_loss = adversarial_loss(descriminator(predicted_images.detach()), fake_labels)
        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.backward()
        descriminator_optimizer.step()
        d_loss_list.append(d_loss.item())

    # Report epoch means; this also avoids a NameError if the loader is empty.
    mean_d_loss = sum(d_loss_list) / len(d_loss_list) if d_loss_list else float("nan")
    mean_g_loss = sum(g_loss_list) / len(g_loss_list) if g_loss_list else float("nan")
    print(f"Epoch: {current_epoch+1}/{total_epochs} Descriminator loss: {mean_d_loss:.6f}", end="; ")
    print(f"Generator Training loss: {mean_g_loss:.6f}")

    _show_generated_samples(generator, nrows=1, ncols=5)

    return {
        "generator_loss": g_loss_list,
        "descriminator_loss": d_loss_list
    }

In [None]:
total_epochs = 2000

# One entry per epoch; each entry is that epoch's list of per-batch losses.
gen_loss, desc_loss = [], []
for epoch in range(total_epochs):
    epoch_losses = train(epoch, total_epochs, generator, descriminator)
    gen_loss.append(epoch_losses["generator_loss"])
    desc_loss.append(epoch_losses["descriminator_loss"])

In [None]:
# descriminator(predicted_images)

In [None]:
# real_labels

In [None]:
# Each entry of desc_loss / gen_loss is one epoch's list of per-batch losses.
# Plot the per-epoch mean: handing the nested lists to plot() directly would
# treat them as a 2-D array and draw one line per batch position.
desc_epoch_loss = [np.mean(losses) for losses in desc_loss]
gen_epoch_loss = [np.mean(losses) for losses in gen_loss]

fig, (ax_desc, ax_gen) = plt.subplots(1, 2, figsize=(12, 4))

ax_desc.plot(desc_epoch_loss, label="Descriminator Loss")
ax_desc.set(title="Descriminator Train Loss", xlabel="Epoch", ylabel="Mean BCE loss")
ax_desc.legend(loc="upper right")

ax_gen.plot(gen_epoch_loss, label="Generator Loss")
ax_gen.set(title="Generator Train Loss", xlabel="Epoch", ylabel="Mean BCE loss")
ax_gen.legend(loc="upper right");

In [None]:
# Pickle both full modules (architecture + weights) next to the notebook.
# NOTE(review): full-module pickles need torch.load(..., weights_only=False)
# on recent PyTorch; consider saving state_dict() instead.
for model_to_save, checkpoint_path in (
    (generator, "./generator.pt"),
    (descriminator, "./descriminator.pt"),
):
    pt.save(model_to_save, checkpoint_path)