In [None]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms as transforms
from torchvision.datasets import CIFAR10

In [None]:
# Training hyperparameters.
batch_size = 64
lr = 1e-3
max_epochs = 11
latent_dim = 128
device = 'cpu'

# Seed torch's global RNG (used for weight initialisation, DataLoader
# shuffling and torch.randn_like) so Restart & Run All reproduces the run.
SEED = 42
torch.manual_seed(SEED)

In [None]:
# Convert PIL images to float tensors (ToTensor scales pixels into [0, 1],
# which binary cross-entropy against a sigmoid output requires).
transform = transforms.ToTensor()

# Dataset location: override with the CIFAR10_ROOT environment variable.
# Defaults to a relative ./data directory instead of a machine-specific
# absolute path so the notebook runs on other machines.
dataset_path = os.environ.get('CIFAR10_ROOT', 'data')

train_data = CIFAR10(root=dataset_path, train=True, transform=transform, download=True)
validation_data = CIFAR10(root=dataset_path, train=False, transform=transform, download=True)

# Shuffle the training set so every epoch visits the batches in a new order.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size)

In [None]:
class Vae(torch.nn.Module):
    """Convolutional variational autoencoder for 32x32 images.

    Architecture:
        encoder: conv(channels->32) -> ReLU -> conv(32->64) -> ReLU -> flatten
                 -> two linear heads producing ``mu`` and the log-variance
        decoder: linear(latent_dim->img_dim) -> ReLU -> reshape (B, 64, 32, 32)
                 -> deconv(64->32) -> ReLU -> deconv(32->channels) -> sigmoid

    Every convolution uses kernel 3, stride 1 and padding 1, so the spatial
    size is unchanged. ``img_dim`` must therefore equal 64 * 32 * 32 (65536)
    for 32x32 inputs.

    The value called ``sigma`` in this class is the log-variance
    ``log(sigma^2)``, not the standard deviation (see ``reparameterization``).

    Args:
        img_dim: size of the flattened 64-channel feature map.
        channels: number of image channels (3 for RGB).
        latent_dim: dimensionality of the latent space.
    """

    def __init__(self, img_dim=65536, channels=3, latent_dim=latent_dim):
        super().__init__()

        # Encoder
        self.encoder_conv_1 = torch.nn.Conv2d(channels, 32, 3, stride=1, padding=1)
        self.encoder_conv_2 = torch.nn.Conv2d(32, 64, 3, stride=1, padding=1)
        self.encoder_FC_1 = torch.nn.Linear(img_dim, latent_dim)  # -> mu
        self.encoder_FC_2 = torch.nn.Linear(img_dim, latent_dim)  # -> log-variance

        # Decoder: sized from the constructor arguments (not hard-coded
        # 128/65536) so a different latent_dim still yields matching layers.
        self.decoder_FC_1 = torch.nn.Linear(latent_dim, img_dim)
        self.decoder_conv_1 = torch.nn.ConvTranspose2d(64, 32, 3, stride=1, padding=1)
        self.decoder_conv_2 = torch.nn.ConvTranspose2d(32, channels, 3, stride=1, padding=1)

    def encoder(self, x, img_dim=(32, 32), channels=3, latent_dim=latent_dim):
        """Map an image batch to the parameters of q(z | x).

        Args:
            x: image batch; assumed shape (B, channels, 32, 32).
            img_dim, channels, latent_dim: unused, kept for backward
                compatibility (layer sizes are fixed in ``__init__``).

        Returns:
            ``(mu, logvar)``, each of shape (B, latent_dim).
        """
        # Nonlinearities between the convolutions; without them the whole
        # encoder would be a single linear map of x.
        h = torch.relu(self.encoder_conv_1(x))
        h = torch.relu(self.encoder_conv_2(h))
        h = h.flatten(start_dim=1)
        return self.encoder_FC_1(h), self.encoder_FC_2(h)

    def decoder(self, k, img_dim=(32, 32), channels=3, latent_dim=latent_dim):
        """Map latent samples ``k`` of shape (B, latent_dim) to images in (0, 1).

        ``img_dim``, ``channels`` and ``latent_dim`` are unused and kept for
        backward compatibility.
        """
        # Activate after the linear layer, not before it: ReLU on the Gaussian
        # latent sample would discard every negative coordinate.
        h = torch.relu(self.decoder_FC_1(k))
        # Reshape using the real batch size so short (e.g. final) batches work.
        h = h.view(h.size(0), 64, 32, 32)
        h = torch.relu(self.decoder_conv_1(h))
        return torch.sigmoid(self.decoder_conv_2(h))

    def reparameterization(self, mu, sigma):
        """Draw z = mu + exp(sigma / 2) * eps with eps ~ N(0, I).

        ``sigma`` is the log-variance, so ``exp(sigma / 2)`` is the standard
        deviation. Routing the randomness through ``eps`` keeps ``z``
        differentiable with respect to ``mu`` and ``sigma``.
        """
        eps = torch.randn_like(sigma)
        return mu + torch.exp(sigma / 2) * eps

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns:
            ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encoder(x)
        z = self.reparameterization(mu, logvar)
        return self.decoder(z), mu, logvar

# Shared ReLU module kept at module level; code in this notebook may refer
# to it by the name `relu`.
relu = torch.nn.ReLU()

# Build the model on the configured device and attach Adam to its parameters.
net = Vae()
net = net.to(device)
optimizer = torch.optim.Adam(params=net.parameters(), lr=lr)

In [None]:
def training_loop():
    """Train ``net`` on ``train_loader`` for ``max_epochs`` full passes.

    Each optimisation step minimises the batch-mean negative ELBO:
    the per-image *summed* binary cross-entropy between reconstruction and
    input, plus the KL divergence between N(mu, exp(logvar)) and N(0, I).
    Summing the reconstruction term per image and then averaging both terms
    over the batch keeps them on the same scale.

    Returns:
        list[float]: mean training loss of each epoch (NaN for an epoch in
        which no full batch was available).
    """
    net.train()
    history = []
    for epoch in range(max_epochs):
        total_loss = 0.0
        n_steps = 0
        # Iterate the loader itself. Calling next(iter(loader)) on every step
        # restarts it and returns the same first batch each time.
        for in_image, _ in train_loader:
            # Drop the short final batch (like DataLoader(drop_last=True)) so
            # every step uses exactly `batch_size` images and models that
            # assume a fixed batch size still work.
            if in_image.size(0) != batch_size:
                continue
            in_image = in_image.to(device)
            out_image, mu, logvar = net(in_image)

            recon = torch.nn.functional.binary_cross_entropy(
                out_image, in_image, reduction='none'
            ).flatten(start_dim=1).sum(dim=1)
            # The model's third output is treated as the log-variance.
            kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
            loss = torch.mean(recon + kl)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            n_steps += 1

        epoch_loss = total_loss / n_steps if n_steps else float('nan')
        history.append(epoch_loss)
        print('Epoch is:', epoch)
        print('loss: is', epoch_loss)
    return history


In [None]:
# Train the VAE with the settings from the configuration cell.
training_loop();