<a href="https://colab.research.google.com/github/alec-gironda/Pokemon-Variational-Autoencoder/blob/main/NewNotebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Imports

In [None]:
#imports

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data

import torchvision.transforms as transforms
import torchvision.datasets as datasets

import matplotlib.pyplot as plt
import numpy as np
import os
import torch.nn as nn

Load Data

In [None]:
# Directory (on the mounted Google Drive) that holds the Pokemon JPEG files.
data_dir = "drive/MyDrive/data/pokemon_jpg/pokemon_jpg/"

# List the directory once and sort the result: os.listdir returns entries in
# arbitrary order, so sorting keeps the image order (and any later slicing of
# `ims`) reproducible between runs.
im_names = sorted(os.listdir(data_dir))
num_pokemon = len(im_names)

# Decode every file into an image tensor.
ims = [torchvision.io.read_image(f"{data_dir}{im_name}") for im_name in im_names]

# Combine into a single tensor; torch.stack requires every image to have the
# same shape.
ims = torch.stack(ims)

# Rescale the pixel values by 1/255 (this also converts the data to float).
ims = ims/255

Encoder

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder that maps an image batch to 128-dimensional codes.

    Four stages of Conv2d (kernel 5, stride 2, padding 2) -> BatchNorm2d -> ReLU
    grow the channels 3 -> 16 -> 32 -> 64 -> 128 while halving the spatial size
    at every stage. The flattened result goes through ``nn.Linear(32768, 128)``;
    32768 = 128 * 16 * 16, i.e. the flattened size obtained for 256x256 inputs.
    """

    def __init__(self):
        super().__init__()

        self.relu = nn.ReLU()

        # Settings shared by every convolution in the encoder.
        conv_args = dict(kernel_size=5, stride=2, padding=2)

        self.conv_layer_1 = nn.Conv2d(3, 16, **conv_args)
        self.bn_1 = nn.BatchNorm2d(16)

        self.conv_layer_2 = nn.Conv2d(16, 32, **conv_args)
        self.bn_2 = nn.BatchNorm2d(32)

        self.conv_layer_3 = nn.Conv2d(32, 64, **conv_args)
        self.bn_3 = nn.BatchNorm2d(64)

        self.conv_layer_4 = nn.Conv2d(64, 128, **conv_args)
        self.bn_4 = nn.BatchNorm2d(128)

        self.flatten = nn.Flatten()

        self.fc_layer_1 = nn.Linear(32768, 128)

    def forward(self, x):
        """Encode ``x`` and return the output of the final linear layer."""
        stages = (
            (self.conv_layer_1, self.bn_1),
            (self.conv_layer_2, self.bn_2),
            (self.conv_layer_3, self.bn_3),
            (self.conv_layer_4, self.bn_4),
        )
        # Every stage applies convolution, batch norm and ReLU in that order.
        for conv, norm in stages:
            x = self.relu(norm(conv(x)))

        flat = self.flatten(x)
        return self.fc_layer_1(flat)

Decoder

In [None]:
class Decoder(nn.Module):
    """Transposed-convolution decoder that maps 128-dimensional codes to images.

    A linear layer expands each code to 32768 values, which are reshaped to
    (128, 16, 16) feature maps. Four ConvTranspose2d layers (kernel 4, stride 2,
    padding 1) then double the spatial size each time (16 -> 32 -> 64 -> 128 ->
    256) while reducing the channels 128 -> 64 -> 32 -> 16 -> 3. The first three
    are followed by BatchNorm2d and ReLU, the last one by ReLU only.

    Args:
        batch_size: Kept for backward compatibility and stored on the instance.
            ``forward`` no longer depends on it: the batch dimension is taken
            from the input, so batches of any size can be decoded.
    """

    def __init__(self, batch_size=None):
        super().__init__()

        self.relu = nn.ReLU()

        self.fc_layer_1 = nn.Linear(128, 32768)

        self.conv_layer_1 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)
        self.bn_1 = nn.BatchNorm2d(64)

        self.conv_layer_2 = nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1)
        self.bn_2 = nn.BatchNorm2d(32)

        self.conv_layer_3 = nn.ConvTranspose2d(32, 16, 4, stride=2, padding=1)
        self.bn_3 = nn.BatchNorm2d(16)

        self.conv_layer_4 = nn.ConvTranspose2d(16, 3, 4, stride=2, padding=1)

        # No longer used by forward(); retained so the attribute still exists.
        self.flatten = nn.Flatten()

        self.batch_size = batch_size

    def forward(self, x):
        """Decode codes ``x`` into a tensor of shape (batch, 3, 256, 256)."""
        x = self.relu(self.fc_layer_1(x))

        # Infer the batch dimension from the data instead of using the
        # constructor's batch_size, which made any other batch size fail.
        x = x.reshape(-1, 128, 16, 16)

        x = self.relu(self.bn_1(self.conv_layer_1(x)))
        x = self.relu(self.bn_2(self.conv_layer_2(x)))
        x = self.relu(self.bn_3(self.conv_layer_3(x)))

        # The last layer already yields (batch, 3, 256, 256), so the former
        # flatten-then-reshape round trip was a no-op and has been dropped.
        return self.relu(self.conv_layer_4(x))


Loss

In [None]:
def loss_fn(input_image,output_image):
    """Return the sum of squared differences between two image tensors.

    Both tensors are rearranged into rows of 256*256 values before being
    subtracted, so their element counts must be multiples of 65536.

    Args:
        input_image: Reference image tensor.
        output_image: Reconstructed image tensor.

    Returns:
        A zero-dimensional tensor holding the summed squared error.
    """
    # reshape (unlike view) also accepts non-contiguous tensors, e.g. the
    # result of a transpose/permute; for contiguous input it is still a view.
    input_image = input_image.reshape(-1, 256*256)
    output_image = output_image.reshape(-1, 256*256)
    return torch.sum((input_image-output_image)**2)


Set Up

In [None]:
batch_size = 10
new_ims = ims[:800]

# Keep only as many images as fill whole batches, so the reshape below also
# succeeds when the image count is not a multiple of batch_size.
num_batches = len(new_ims) // batch_size
new_ims = new_ims[:num_batches * batch_size]

batches = torch.reshape(new_ims,(num_batches,batch_size,3,256,256))

# Use the GPU when one is present, otherwise fall back to the CPU instead of
# raising a RuntimeError on .to("cuda").
device = "cuda" if torch.cuda.is_available() else "cpu"

encoder = Encoder().to(device)
decoder = Decoder(batch_size).to(device)

loss =  loss_fn # Step 2: loss
# Step 3: training method -- one Adam optimizer per network.
encoder_opt = torch.optim.Adam(encoder.parameters(), lr=.01)
decoder_opt = torch.optim.Adam(decoder.parameters(), lr=.01)

RuntimeError: ignored

Train

In [None]:
# Train on whichever device the models were placed on.
device = next(encoder.parameters()).device

# One entry per epoch: the mean per-sample loss over all batches of that epoch.
train_loss_history = []
for epoch in range(500):
    # Reset once per epoch so the loss accumulates across the epoch's batches.
    train_loss = 0.0
    for batch in batches:
        # Train on the current batch (previously the first batch_size images
        # were used on every step and the loop variable was ignored).
        batch = batch.to(device)
        encoder_opt.zero_grad()
        decoder_opt.zero_grad()
        encoded_out = encoder(batch)
        decoded_out = decoder(encoded_out)
        fit = loss(batch, decoded_out)
        fit.backward()
        encoder_opt.step()
        decoder_opt.step()
        train_loss += fit.item() / batch_size
    # Average over the batches; max() guards against an empty `batches`.
    train_loss /= max(len(batches), 1)
    train_loss_history.append(train_loss)
    if epoch % 10 == 0:
        print(f'Epoch {epoch}, Train loss {train_loss}')
print(train_loss_history[-1])

Test

In [None]:
# Show the first image of the dataset and save the figure to Drive.
to_pil = transforms.ToPILImage()
original_image = to_pil(ims[0])
plt.imshow(original_image)
plt.savefig("drive/MyDrive/original.jpg")

In [None]:
# Run on whichever device the models were placed on.
device = next(encoder.parameters()).device

# Inference mode: BatchNorm uses its running statistics rather than the
# statistics of this particular batch.
encoder.eval()
decoder.eval()

test = new_ims[:batch_size].to(device)
# No gradients are needed to produce a reconstruction.
with torch.no_grad():
    out = decoder(encoder(test))
out = torch.reshape(out,(batch_size,3,256,256))

# The decoder ends in a ReLU, so values may exceed 1; clamp to [0, 1] and move
# the tensor to the CPU before converting it to a PIL image.
reconstruction = out[0].detach().cpu().clamp(0, 1)
plt.imshow(transforms.ToPILImage()(reconstruction))
plt.savefig("drive/MyDrive/out.jpg")