In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from tqdm import tqdm


In [2]:
# Configuration
epochs      = 100
batch_size  = 64
sample_size = 100    # Length of the random noise vector fed to the generator
g_lr        = 1.0e-4 # Generator's learning rate
d_lr        = 1.0e-4 # Discriminator's learning rate
seed        = 0      # Seed for torch's global RNG (weight init, noise sampling)

# Seed torch's global RNG so that Restart & Run All reproduces the same run.
torch.manual_seed(seed)

In [3]:
# DataLoader for MNIST
# ToTensor() converts each image to a float tensor scaled to [0, 1], the same range
# as the generator's final Sigmoid output.
transform = transforms.ToTensor()
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# shuffle=True: visit the training set in a new random order each epoch instead of
# the fixed on-disk order. drop_last=True: every batch has exactly batch_size images,
# which the fixed-size real/fake target tensors used in training rely on.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 146890916.57it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 81081455.04it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 31369139.27it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 14399492.64it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [4]:
# Generator Network
class Generator(nn.Sequential):
    """MLP generator that maps Gaussian noise vectors to 1x28x28 images.

    Args:
        sample_size: length of the noise vector ``z`` fed to the first layer.

    The module is called with a batch *size* (an int), not a tensor: it samples
    fresh noise itself and returns images of shape ``(batch_size, 1, 28, 28)``
    with pixel values in [0, 1] because of the final Sigmoid.
    """

    def __init__(self, sample_size: int):
        super().__init__(
            nn.Linear(sample_size, 128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 784),
            nn.Sigmoid())

        # Random value vector size
        self.sample_size = sample_size

    def forward(self, batch_size: int):
        """Sample ``batch_size`` noise vectors and return the generated images.

        Returns:
            Tensor of shape ``(batch_size, 1, 28, 28)``.
        """
        # Draw the noise on the same device and dtype as the weights so the model
        # still works after ``.to(device)`` or ``.double()``.
        weight = next(self.parameters())
        z = torch.randn(batch_size, self.sample_size,
                        device=weight.device, dtype=weight.dtype)

        # Generator output: (batch_size, 784)
        output = super().forward(z)

        # Convert the output into a greyscale image (1x28x28)
        return output.reshape(batch_size, 1, 28, 28)

In [5]:
# Discriminator Network
class Discriminator(nn.Sequential):
    """MLP discriminator over flattened 28x28 images that returns its own loss.

    The layer stack maps each 784-pixel image to a single logit. Instead of
    returning those logits, ``forward`` scores them against ``targets`` with
    binary cross-entropy on logits.
    """

    def __init__(self):
        layers = [
            nn.Linear(784, 128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 1),
        ]
        super().__init__(*layers)

    def forward(self, images: torch.Tensor, targets: torch.Tensor):
        """Return the (mean-reduced) BCE-with-logits loss of ``images`` vs ``targets``."""
        flat_images = images.reshape(-1, 784)
        logits = super().forward(flat_images)
        return F.binary_cross_entropy_with_logits(logits, targets)


In [6]:
# To save images in grid layout
def save_image_grid(epoch: int, images: torch.Tensor, ncol: int):
    """Save a batch of images as one grid picture named ``generated_{epoch:03d}.jpg``.

    Args:
        epoch: epoch index used in the filename (zero-padded to 3 digits).
        images: image batch; presumably (N, C, H, W) with values in [0, 1]
            (TODO confirm for callers other than the training loop).
        ncol: number of images per grid row (``make_grid``'s ``nrow`` argument).
    """
    # Detach first so the numpy conversion cannot fail on tensors that still track
    # gradients (e.g. raw generator output), and move to CPU before gridding.
    images = images.detach().cpu()
    image_grid = make_grid(images, nrow=ncol)  # Images in a (C, H, W) grid
    image_grid = image_grid.permute(1, 2, 0)   # Move channel last for imshow
    image_grid = image_grid.numpy()            # To Numpy

    # Use an explicit figure so we never draw onto, or close, another current figure.
    fig, ax = plt.subplots()
    ax.imshow(image_grid)
    ax.set_xticks([])
    ax.set_yticks([])
    fig.savefig(f'generated_{epoch:03d}.jpg')
    plt.close(fig)


In [7]:
# Discriminator targets: 1 marks a real MNIST image, 0 marks a generated one.
# Shape (batch_size, 1) lines up with the discriminator's single-logit output.
real_targets = torch.ones(batch_size, 1)
fake_targets = torch.zeros_like(real_targets)


In [8]:
# Build both networks: generator first, then discriminator. Both draw their initial
# weights from torch's global RNG, so this order decides which values each one gets.
generator, discriminator = Generator(sample_size), Discriminator()

In [9]:
# Optimizers: an independent Adam instance per network, each with its own learning rate
d_optimizer, g_optimizer = (
    torch.optim.Adam(params=net.parameters(), lr=lr)
    for net, lr in ((discriminator, d_lr), (generator, g_lr))
)


In [10]:
# Training loop
for epoch in range(epochs):

    d_losses = []
    g_losses = []

    # Class labels are not used: the GAN is unconditional.
    for images, _ in tqdm(dataloader):
        #===============================
        # Discriminator Network Training
        #===============================

        # Loss with MNIST image inputs and real_targets as labels
        discriminator.train()
        d_loss_real = discriminator(images, real_targets)

        # Generate images in eval mode; no_grad because only D is updated in this step
        generator.eval()
        with torch.no_grad():
            generated_images = generator(batch_size)

        # Loss with generated image inputs and fake_targets as labels
        d_loss_fake = discriminator(generated_images, fake_targets)
        d_loss = d_loss_real + d_loss_fake

        # Optimizer updates the discriminator parameters
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        #===============================
        # Generator Network Training
        #===============================

        # Generate images in train mode so gradients flow back into the generator
        generator.train()
        generated_images = generator(batch_size)

        # The generator is rewarded when D labels its images as real. Gradients this
        # backward pass leaves on D's parameters are cleared by d_optimizer.zero_grad()
        # before D's next update; only g_optimizer steps here.
        discriminator.eval()
        g_loss = discriminator(generated_images, real_targets)

        # Optimizer updates the generator parameters
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        # Keep losses for logging
        d_losses.append(d_loss.item())
        g_losses.append(g_loss.item())

    # Print epoch index, mean discriminator loss, mean generator loss
    print(epoch, np.mean(d_losses), np.mean(g_losses))

    # Save a sample grid; no_grad avoids building an autograd graph nobody uses
    generator.eval()
    with torch.no_grad():
        sample_images = generator(batch_size)
    save_image_grid(epoch, sample_images, ncol=8)

100%|██████████| 937/937 [00:23<00:00, 40.71it/s]


0 0.4579072228236285 2.70171331786804


100%|██████████| 937/937 [00:15<00:00, 58.82it/s]


1 0.4140520644321513 2.0489276216276937


100%|██████████| 937/937 [00:15<00:00, 59.79it/s]


2 0.4671648385812277 1.6631388633711743


100%|██████████| 937/937 [00:16<00:00, 58.29it/s]


3 0.6227117387245915 1.1796640389884168


100%|██████████| 937/937 [00:17<00:00, 55.05it/s]


4 0.514058011700275 1.278915920786822


100%|██████████| 937/937 [00:16<00:00, 55.32it/s]


5 0.5439690992697326 1.4932737723231442


100%|██████████| 937/937 [00:16<00:00, 56.67it/s]


6 0.8187510459629169 1.4378300967567632


100%|██████████| 937/937 [00:16<00:00, 56.83it/s]


7 0.6043641148343794 1.7437858797824497


100%|██████████| 937/937 [00:16<00:00, 58.08it/s]


8 0.5701170886847864 1.8500073902634953


100%|██████████| 937/937 [00:17<00:00, 54.64it/s]


9 0.5310927195317463 1.855527222856259


100%|██████████| 937/937 [00:15<00:00, 59.25it/s]


10 0.5141374436997299 1.9036623935434835


100%|██████████| 937/937 [00:17<00:00, 55.11it/s]


11 0.5374759687876116 1.930261155202778


100%|██████████| 937/937 [00:16<00:00, 56.44it/s]


12 0.5307988069927679 1.9353771875100558


100%|██████████| 937/937 [00:16<00:00, 58.11it/s]


13 0.46827592612457886 1.9621095265497392


100%|██████████| 937/937 [00:16<00:00, 57.56it/s]


14 0.45831403821674965 2.0075947434314414


100%|██████████| 937/937 [00:16<00:00, 57.98it/s]


15 0.4136228948481691 2.230042278448571


100%|██████████| 937/937 [00:16<00:00, 57.93it/s]


16 0.43378442735783956 2.310522092062992


100%|██████████| 937/937 [00:17<00:00, 54.48it/s]


17 0.52244544967516 2.1005483242338214


100%|██████████| 937/937 [00:16<00:00, 55.99it/s]


18 0.5073893415635271 2.094763481756921


100%|██████████| 937/937 [00:16<00:00, 57.69it/s]


19 0.4231089561255281 2.2245569133707654


100%|██████████| 937/937 [00:16<00:00, 57.99it/s]


20 0.43209077915616073 2.2303906006136214


100%|██████████| 937/937 [00:17<00:00, 52.57it/s]


21 0.42310054010235515 2.286195949531034


100%|██████████| 937/937 [00:16<00:00, 55.65it/s]


22 0.41608458666753106 2.329468636782472


100%|██████████| 937/937 [00:16<00:00, 55.36it/s]


23 0.35928200361822815 2.537129626202863


100%|██████████| 937/937 [00:16<00:00, 57.97it/s]


24 0.3966139262838323 2.476226714402914


100%|██████████| 937/937 [00:16<00:00, 56.79it/s]


25 0.41152134275296454 2.4744925745777793


100%|██████████| 937/937 [00:16<00:00, 57.33it/s]


26 0.40804511332524623 2.4626830087272946


100%|██████████| 937/937 [00:16<00:00, 55.58it/s]


27 0.43464525657822256 2.4617728164127315


100%|██████████| 937/937 [00:17<00:00, 54.26it/s]


28 0.428321666419188 2.5079960181847962


100%|██████████| 937/937 [00:16<00:00, 57.29it/s]


29 0.3988233103855188 2.576816746875596


100%|██████████| 937/937 [00:16<00:00, 57.81it/s]


30 0.440773094492381 2.5072768763773214


100%|██████████| 937/937 [00:16<00:00, 57.37it/s]


31 0.4423282054534208 2.5461597908268363


100%|██████████| 937/937 [00:16<00:00, 56.38it/s]


32 0.44150221414601665 2.474729055783283


100%|██████████| 937/937 [00:17<00:00, 52.48it/s]


33 0.4322969248257935 2.484900736630663


100%|██████████| 937/937 [00:16<00:00, 57.09it/s]


34 0.43763028118628194 2.5014230912370388


100%|██████████| 937/937 [00:16<00:00, 57.63it/s]


35 0.39175891338697494 2.692176292138522


100%|██████████| 937/937 [00:16<00:00, 57.77it/s]


36 0.4907224420994806 2.488343217329638


100%|██████████| 937/937 [00:18<00:00, 51.74it/s]


37 0.4303843464579119 2.6056422823393994


100%|██████████| 937/937 [00:16<00:00, 55.58it/s]


38 0.49845570338573886 2.4242908196362767


100%|██████████| 937/937 [00:16<00:00, 58.32it/s]


39 0.4942520211511386 2.43549746357644


100%|██████████| 937/937 [00:16<00:00, 58.17it/s]


40 0.48654732521595545 2.4551452337232447


100%|██████████| 937/937 [00:16<00:00, 58.23it/s]


41 0.5009609998512929 2.5059844378855085


100%|██████████| 937/937 [00:16<00:00, 58.08it/s]


42 0.49502993533298323 2.480296754404473


100%|██████████| 937/937 [00:16<00:00, 56.28it/s]


43 0.45870322542485714 2.518501930328418


100%|██████████| 937/937 [00:16<00:00, 55.55it/s]


44 0.576205168933217 2.3143175328935603


100%|██████████| 937/937 [00:16<00:00, 58.52it/s]


45 0.4746688786921913 2.501320429011114


100%|██████████| 937/937 [00:15<00:00, 58.60it/s]


46 0.5106621240857062 2.3945229389369804


100%|██████████| 937/937 [00:16<00:00, 58.24it/s]


47 0.5121066655522349 2.4543277317838963


100%|██████████| 937/937 [00:16<00:00, 58.25it/s]


48 0.5246747833305896 2.4388889036158234


100%|██████████| 937/937 [00:16<00:00, 55.44it/s]


49 0.522003538389725 2.430914558748553


100%|██████████| 937/937 [00:16<00:00, 55.45it/s]


50 0.5171162594217119 2.3977399966505573


100%|██████████| 937/937 [00:16<00:00, 58.08it/s]


51 0.5238920623553855 2.4018105010212802


100%|██████████| 937/937 [00:16<00:00, 58.08it/s]


52 0.5367268621126386 2.4604235231049416


100%|██████████| 937/937 [00:16<00:00, 57.41it/s]


53 0.5437116488384971 2.401842393641029


100%|██████████| 937/937 [00:17<00:00, 54.95it/s]


54 0.5408848720337309 2.4486822006923794


100%|██████████| 937/937 [00:16<00:00, 55.28it/s]


55 0.743876531004524 2.1505669644191907


100%|██████████| 937/937 [00:16<00:00, 57.07it/s]


56 0.48217001781583213 2.623426608240337


100%|██████████| 937/937 [00:15<00:00, 58.61it/s]


57 0.5972370986241287 2.307932708153984


100%|██████████| 937/937 [00:15<00:00, 58.72it/s]


58 0.57177911786285 2.3369146114734094


100%|██████████| 937/937 [00:15<00:00, 60.24it/s]


59 0.5556138383031782 2.347597612388106


100%|██████████| 937/937 [00:16<00:00, 57.59it/s]


60 1.006628735215585 1.752723192965081


100%|██████████| 937/937 [00:18<00:00, 51.24it/s]


61 0.7437122745186949 2.2541105312115355


100%|██████████| 937/937 [00:17<00:00, 54.35it/s]


62 0.670419217682699 2.27786388799246


100%|██████████| 937/937 [00:16<00:00, 58.46it/s]


63 0.4974124828391294 2.6378360308603392


100%|██████████| 937/937 [00:16<00:00, 56.12it/s]


64 0.6433664385066343 2.2226583588721276


100%|██████████| 937/937 [00:17<00:00, 54.47it/s]


65 0.6398258178694652 2.201157865554826


100%|██████████| 937/937 [00:16<00:00, 57.35it/s]


66 0.6458843567328112 2.199188169827456


100%|██████████| 937/937 [00:17<00:00, 52.92it/s]


67 0.6349449242001027 2.2270704922579396


100%|██████████| 937/937 [00:17<00:00, 55.05it/s]


68 0.6236515946551092 2.2060317736170716


100%|██████████| 937/937 [00:15<00:00, 59.80it/s]


69 0.6159764362564718 2.2345899015251574


100%|██████████| 937/937 [00:19<00:00, 48.78it/s]


70 1.3339647223244608 1.5375188947232836


100%|██████████| 937/937 [00:15<00:00, 58.84it/s]


71 0.7343802686180351 2.073022483634338


100%|██████████| 937/937 [00:17<00:00, 54.18it/s]


72 0.9811686747801851 1.855482958869466


100%|██████████| 937/937 [00:16<00:00, 56.96it/s]


73 0.7849193365813065 2.0390411489419518


100%|██████████| 937/937 [00:16<00:00, 56.91it/s]


74 0.421137533978566 2.729421351482672


100%|██████████| 937/937 [00:15<00:00, 58.65it/s]


75 1.092602808072193 1.6948791134828183


100%|██████████| 937/937 [00:15<00:00, 59.07it/s]


76 0.5455046559792318 2.3293920542921556


100%|██████████| 937/937 [00:16<00:00, 56.90it/s]


77 0.6915304518305759 2.105606051875535


100%|██████████| 937/937 [00:16<00:00, 56.79it/s]


78 1.2397487270412222 1.5024962472941095


100%|██████████| 937/937 [00:16<00:00, 55.54it/s]


79 1.0728536989289388 1.7009911415671082


100%|██████████| 937/937 [00:17<00:00, 53.98it/s]


80 0.648447859662574 2.0469983188930416


100%|██████████| 937/937 [00:16<00:00, 55.33it/s]


81 0.6232244647490686 2.090659709088703


100%|██████████| 937/937 [00:16<00:00, 56.64it/s]


82 0.6664039740376692 2.0620413509733266


100%|██████████| 937/937 [00:17<00:00, 54.83it/s]


83 0.7078903509878425 2.0217941456187116


100%|██████████| 937/937 [00:17<00:00, 53.85it/s]


84 0.6955914645273886 2.013484339958321


100%|██████████| 937/937 [00:16<00:00, 55.64it/s]


85 0.7016798185945957 2.0293005950422907


100%|██████████| 937/937 [00:17<00:00, 53.07it/s]


86 0.698468359646446 2.0497497018653275


100%|██████████| 937/937 [00:15<00:00, 58.59it/s]


87 0.6799249726972814 2.0727604565269284


100%|██████████| 937/937 [00:15<00:00, 59.38it/s]


88 0.6917799796785845 2.0702690428069204


100%|██████████| 937/937 [00:15<00:00, 59.81it/s]


89 0.7047340293576904 2.049808291259163


100%|██████████| 937/937 [00:16<00:00, 57.88it/s]


90 0.7361588637755037 1.9951122288006728


100%|██████████| 937/937 [00:16<00:00, 57.16it/s]


91 0.6540527042228105 2.0833932060823876


100%|██████████| 937/937 [00:16<00:00, 58.51it/s]


92 0.715248158004775 1.961734428985905


100%|██████████| 937/937 [00:15<00:00, 58.89it/s]


93 0.7002288393811457 1.9881703602975054


100%|██████████| 937/937 [00:16<00:00, 58.28it/s]


94 0.7050084022918404 1.988326113913586


100%|██████████| 937/937 [00:15<00:00, 59.41it/s]


95 0.8318882970697979 1.8251747040509414


100%|██████████| 937/937 [00:15<00:00, 58.63it/s]


96 0.6345787151391631 2.090464726965099


100%|██████████| 937/937 [00:16<00:00, 57.39it/s]


97 0.6941051138058034 1.9601021632313602


100%|██████████| 937/937 [00:16<00:00, 57.23it/s]


98 0.6967646954153747 1.9735487464143475


100%|██████████| 937/937 [00:15<00:00, 59.22it/s]

99 0.7088932843447495 1.9566744627031472





In [13]:
!zip -r /content/file.zip /content/

  adding: content/ (stored 0%)
  adding: content/.config/ (stored 0%)
  adding: content/.config/active_config (stored 0%)
  adding: content/.config/logs/ (stored 0%)
  adding: content/.config/logs/2023.11.21/ (stored 0%)
  adding: content/.config/logs/2023.11.21/14.24.29.745469.log (deflated 58%)
  adding: content/.config/logs/2023.11.21/14.18.34.556141.log (deflated 91%)
  adding: content/.config/logs/2023.11.21/14.24.39.684965.log (deflated 57%)
  adding: content/.config/logs/2023.11.21/14.21.40.489438.log (deflated 86%)
  adding: content/.config/logs/2023.11.21/14.24.40.484653.log (deflated 56%)
  adding: content/.config/logs/2023.11.21/14.21.30.762319.log (deflated 58%)
  adding: content/.config/.last_update_check.json (deflated 22%)
  adding: content/.config/.last_survey_prompt.yaml (stored 0%)
  adding: content/.config/.last_opt_in_prompt.yaml (stored 0%)
  adding: content/.config/default_configs.db (deflated 98%)
  adding: content/.config/gce (stored 0%)
  adding: content/.confi