In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset

import pandas as pd
import numpy as np
import random
import matplotlib.pyplot as plt
import h5py

# Prefer Apple's Metal (MPS) backend, but fall back to the CPU so that
# constructing/moving models does not fail on machines without MPS.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

device

device(type='mps')

In [2]:
def generate_random_image(size):
    """Return a CPU tensor of ``size`` samples drawn uniformly from [0, 1)."""
    return torch.rand(size)

def generate_random_seed(size):
    """Return a CPU tensor of ``size`` standard-normal samples (the generator's latent seed)."""
    return torch.randn(size)

class View(nn.Module):
    """Layer that reshapes its input with ``Tensor.view``.

    ``shape`` may be a single int (flatten to that many elements) or a tuple
    of dimensions.  It is stored wrapped in a 1-tuple so that ``forward`` can
    star-unpack it uniformly: both ``x.view(*(n,))`` and
    ``x.view(*((a, b, c),))`` are valid ``view`` calls.
    """

    def __init__(self, shape):
        super().__init__()
        self.shape = (shape,)

    def forward(self, x):
        target = self.shape
        return x.view(*target)
    
class Discriminator(nn.Module):
    """Single-image real/fake classifier trained with binary cross-entropy.

    The input image is flattened, passed through one hidden layer
    (Linear -> LeakyReLU -> LayerNorm) and squashed to one probability with a
    sigmoid.

    Args:
        image_shape: shape (or total element count) of one input image.
            Defaults to the 218 x 178 x 3 images used in this notebook.
        hidden_size: width of the hidden layer.
        lr: learning rate for the Adam optimiser.
    """

    def __init__(self, image_shape=(218, 178, 3), hidden_size=100, lr=0.0001):
        super().__init__()
        n_inputs = int(np.prod(image_shape))
        self.model = nn.Sequential(
            # Folds every dimension into one (same result as x.view(n_inputs)
            # for an input with n_inputs elements).
            nn.Flatten(start_dim=0),
            nn.Linear(n_inputs, hidden_size),
            nn.LeakyReLU(),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, 1),
            nn.Sigmoid(),
        )
        self.loss_function = nn.BCELoss()
        self.optimiser = torch.optim.Adam(self.parameters(), lr=lr)
        self.counter = 0    # number of training steps taken so far
        self.progress = []  # loss value recorded on every 10th step

    def forward(self, inputs):
        """Return the probability, shape ``(1,)``, that ``inputs`` is a real image."""
        return self.model(inputs)

    def train(self, inputs=True, targets=None):
        """Take one optimisation step pushing ``D(inputs)`` towards ``targets``.

        Args:
            inputs: one image tensor.
            targets: label tensor of shape ``(1,)`` (1.0 = real, 0.0 = fake).

        This method shadows ``nn.Module.train(mode)``, and ``nn.Module.eval()``
        calls ``self.train(False)``.  A bare boolean with no targets is
        therefore forwarded to ``nn.Module.train``, so ``D.eval()`` and
        ``D.train()`` switch modes and return ``self`` instead of raising
        ``TypeError``.

        Raises:
            TypeError: if ``inputs`` is not a boolean and ``targets`` is missing.
        """
        if targets is None and isinstance(inputs, bool):
            return super().train(inputs)
        if targets is None:
            raise TypeError("train() missing required argument: 'targets'")

        outputs = self.forward(inputs)
        loss = self.loss_function(outputs, targets)

        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss.item())
        if self.counter % 1000 == 0:
            print("counter = ", self.counter)

        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

    def plot_progress(self):
        """Plot the recorded loss history (one point per 10 training steps)."""
        df = pd.DataFrame(self.progress, columns=['loss'])
        # ylim=0 pins only the lower bound of the y-axis at zero.
        df.plot(ylim=0, figsize=(16, 8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.2, 0.5, 1))

class Generator(nn.Module):
    """Maps a 1-D latent seed to one sigmoid-squashed image tensor.

    Args:
        seed_size: length of the seed vector accepted by ``forward``.
        hidden_size: width of the hidden layer.
        image_shape: shape of the generated image (a tuple of ints).
        lr: learning rate for the Adam optimiser.

    The generator has no loss of its own; ``train`` uses the discriminator's
    ``loss_function``.
    """

    def __init__(self, seed_size=100, hidden_size=3 * 10 * 10,
                 image_shape=(218, 178, 3), lr=0.0001):
        super().__init__()
        n_outputs = int(np.prod(image_shape))
        self.model = nn.Sequential(
            nn.Linear(seed_size, hidden_size),
            nn.LeakyReLU(),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, n_outputs),
            nn.Sigmoid(),
            # Flatten-then-unflatten reproduces x.view(image_shape): all dims
            # are folded first, so a (1, seed_size) seed still yields a single
            # image_shape tensor.
            nn.Flatten(start_dim=0),
            nn.Unflatten(0, tuple(image_shape)),
        )
        self.optimiser = torch.optim.Adam(self.parameters(), lr=lr)
        self.counter = 0    # number of training steps taken so far
        self.progress = []  # loss value recorded on every 10th step

    def forward(self, inputs):
        """Return the image generated from the seed ``inputs``."""
        return self.model(inputs)

    def train(self, D=True, inputs=None, targets=None):
        """Take one step so that ``D`` scores ``G(inputs)`` closer to ``targets``.

        Args:
            D: the discriminator; must provide ``forward`` and ``loss_function``.
            inputs: latent seed for the generator.
            targets: label tensor compared with D's output (1.0 = "real").

        This method shadows ``nn.Module.train(mode)``, and ``nn.Module.eval()``
        calls ``self.train(False)``.  A bare boolean with no other arguments is
        therefore forwarded to ``nn.Module.train``, so ``G.eval()`` and
        ``G.train()`` switch modes and return ``self``.

        Only G's optimiser steps, but ``loss.backward()`` also writes gradients
        into D's parameters (the Discriminator in this notebook zeroes its
        gradients at the start of its own optimisation step).

        Raises:
            TypeError: if called for training without ``inputs``/``targets``.
        """
        if inputs is None and targets is None and isinstance(D, bool):
            return super().train(D)
        if inputs is None or targets is None:
            raise TypeError("train() requires D, inputs and targets")

        g_output = self.forward(inputs)
        d_output = D.forward(g_output)
        loss = D.loss_function(d_output, targets)

        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss.item())

        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

    def plot_progress(self):
        """Plot the recorded loss history (one point per 10 training steps)."""
        df = pd.DataFrame(self.progress, columns=['loss'])
        # ylim=0 pins only the lower bound of the y-axis at zero.
        df.plot(ylim=0, figsize=(16, 8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.2, 0.5, 1))

In [3]:
class CelebADataset(Dataset):
    """Map-style dataset over the ``img_align_celeba`` group of an HDF5 file.

    Image ``i`` is stored under the key ``'<i>.jpg'``.  Items are returned as
    float tensors divided by 255 (presumably uint8 pixels -> [0, 1]; confirm
    against the file).  Negative indices count from the end.  Call ``close()``
    or use the dataset as a context manager to release the HDF5 file handle.
    """

    def __init__(self, file):
        self.file_object = h5py.File(file, 'r')
        self.dataset = self.file_object['img_align_celeba']

    def __len__(self):
        return len(self.dataset)

    def _key(self, index):
        """Translate an int index (negative allowed) to its HDF5 key, or raise IndexError."""
        size = len(self.dataset)
        if index < 0:
            index += size
        if not 0 <= index < size:
            raise IndexError(index)
        return str(index) + '.jpg'

    def __getitem__(self, index):
        # The IndexError raised past the end is what terminates
        # ``for img in dataset`` loops (sequence iteration protocol).
        img = np.array(self.dataset[self._key(index)])
        return torch.FloatTensor(img) / 255.0

    def plot_image(self, index):
        """Show image ``index`` with matplotlib (raw stored values, no scaling)."""
        plt.imshow(np.array(self.dataset[self._key(index)]), interpolation='nearest')

    def close(self):
        """Close the underlying HDF5 file."""
        self.file_object.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

# Open the aligned-CelebA HDF5 file used for training.
celeba_dataset = CelebADataset(file='celeba/celeba_aligned_real_small.h5py')

In [4]:
%%time

# With the .to(device) calls commented out, both networks stay on the CPU.
# In the recorded run order, generate_random_seed and celeba_dataset are the
# CPU versions from In [2]/In [3], so this is presumably the CPU baseline for
# the MPS timings below.
D = Discriminator()
# D.to(device)
G = Generator()
# G.to(device)

epochs = 1

for epoch in range(epochs):
    # Iterating the dataset calls __getitem__(0), (1), ... until it raises IndexError.
    for image_data_tensor in celeba_dataset:
        # 1) D step on a real image, target 1.0
        D.train(image_data_tensor, torch.FloatTensor([1.0]))
        # 2) D step on a generated image, target 0.0; detach() stops this
        #    loss from backpropagating into G
        D.train(G.forward(generate_random_seed(100)).detach(), torch.FloatTensor([0.0]))
        # 3) G step: push D's score for a fresh generated image towards 1.0.
        # Note: the target tensors are re-created on every iteration.
        G.train(D, generate_random_seed(100), torch.FloatTensor([1.0]))
        pass
    pass

counter =  1000
counter =  2000
counter =  3000
counter =  4000
counter =  5000
counter =  6000
CPU times: user 30min 51s, sys: 35min 57s, total: 1h 6min 49s
Wall time: 10min 24s


In [5]:
def generate_random_image(size, device="mps"):
    """Return ``size`` samples drawn uniformly from [0, 1), placed on ``device``.

    ``device`` defaults to "mps" (the previous hard-coded target); pass the
    notebook's ``device`` or "cpu" to run on machines without MPS.  Samples
    are drawn on the CPU generator and then moved, so for a given
    ``torch.manual_seed`` the values do not depend on ``device``.
    """
    return torch.rand(size).to(device)

def generate_random_seed(size, device="mps"):
    """Return a latent seed of ``size`` standard-normal samples on ``device``.

    ``device`` defaults to "mps" (the previous hard-coded target); pass the
    notebook's ``device`` or "cpu" to run on machines without MPS.  Samples
    are drawn on the CPU generator and then moved, so for a given
    ``torch.manual_seed`` the values do not depend on ``device``.
    """
    return torch.randn(size).to(device)

class CelebADataset(Dataset):
    """Map-style dataset over the ``img_align_celeba`` group of an HDF5 file.

    Image ``i`` is stored under the key ``'<i>.jpg'``.  Items are returned as
    float tensors on ``device`` (default "mps"), divided by 255 (presumably
    uint8 pixels -> [0, 1]; confirm against the file).  Negative indices count
    from the end.  Call ``close()`` or use the dataset as a context manager to
    release the HDF5 file handle.
    """

    def __init__(self, file, device="mps"):
        self.file_object = h5py.File(file, 'r')
        self.dataset = self.file_object['img_align_celeba']
        self.device = device  # where __getitem__ places each image tensor

    def __len__(self):
        return len(self.dataset)

    def _key(self, index):
        """Translate an int index (negative allowed) to its HDF5 key, or raise IndexError."""
        size = len(self.dataset)
        if index < 0:
            index += size
        if not 0 <= index < size:
            raise IndexError(index)
        return str(index) + '.jpg'

    def __getitem__(self, index):
        # The IndexError raised past the end is what terminates
        # ``for img in dataset`` loops (sequence iteration protocol).
        img = np.array(self.dataset[self._key(index)])
        return torch.FloatTensor(img).to(self.device) / 255.0

    def plot_image(self, index):
        """Show image ``index`` with matplotlib (raw stored values, no scaling)."""
        plt.imshow(np.array(self.dataset[self._key(index)]), interpolation='nearest')

    def close(self):
        """Close the underlying HDF5 file."""
        self.file_object.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

# Re-open the aligned-CelebA HDF5 file with the dataset class defined in this cell.
celeba_dataset = CelebADataset(file='celeba/celeba_aligned_real_small.h5py')

In [6]:
%%time

# Both networks are moved to `device`.  Their Adam optimisers were created in
# __init__ before this move; Module.to() updates the parameters in place, so
# the optimisers still refer to the moved tensors.
D = Discriminator()
D.to(device)
G = Generator()
G.to(device)

epochs = 1

for epoch in range(epochs):
    # celeba_dataset (rebuilt in In [5]) already yields tensors on "mps".
    for image_data_tensor in celeba_dataset:
        # Target tensors are built on the CPU and copied to the device on every step.
        D.train(image_data_tensor, torch.FloatTensor([1.0]).to(device))
        # generate_random_seed from In [5] already returns an "mps" tensor, so when
        # `device` is MPS the extra .to(device) returns the same tensor unchanged.
        D.train(G.forward(generate_random_seed(100).to(device)).detach(), torch.FloatTensor([0.0]).to(device))
        G.train(D, generate_random_seed(100).to(device), torch.FloatTensor([1.0]).to(device))
        pass
    pass

counter =  1000
counter =  2000
counter =  3000
counter =  4000
counter =  5000
counter =  6000
CPU times: user 10min 48s, sys: 16min 3s, total: 26min 51s
Wall time: 8min 55s


In [7]:
def generate_random_image(size, device="mps"):
    """Return ``size`` samples drawn uniformly from [0, 1), placed on ``device``.

    ``device`` defaults to "mps" (the previous hard-coded target); pass the
    notebook's ``device`` or "cpu" to run on machines without MPS.  Samples
    are drawn on the CPU generator and then moved, so for a given
    ``torch.manual_seed`` the values do not depend on ``device``.
    """
    return torch.rand(size).to(device)

def generate_random_seed(size, device="mps"):
    """Return a latent seed of ``size`` standard-normal samples on ``device``.

    ``device`` defaults to "mps" (the previous hard-coded target); pass the
    notebook's ``device`` or "cpu" to run on machines without MPS.  Samples
    are drawn on the CPU generator and then moved, so for a given
    ``torch.manual_seed`` the values do not depend on ``device``.
    """
    return torch.randn(size).to(device)

class CelebADataset(Dataset):
    """Map-style dataset over the ``img_align_celeba`` group of an HDF5 file.

    Image ``i`` is stored under the key ``'<i>.jpg'``.  Items are returned as
    float tensors on ``device`` (default "mps"), divided by 255 (presumably
    uint8 pixels -> [0, 1]; confirm against the file).  Negative indices count
    from the end.  Call ``close()`` or use the dataset as a context manager to
    release the HDF5 file handle.
    """

    def __init__(self, file, device="mps"):
        self.file_object = h5py.File(file, 'r')
        self.dataset = self.file_object['img_align_celeba']
        self.device = device  # where __getitem__ places each image tensor

    def __len__(self):
        return len(self.dataset)

    def _key(self, index):
        """Translate an int index (negative allowed) to its HDF5 key, or raise IndexError."""
        size = len(self.dataset)
        if index < 0:
            index += size
        if not 0 <= index < size:
            raise IndexError(index)
        return str(index) + '.jpg'

    def __getitem__(self, index):
        # The IndexError raised past the end is what terminates
        # ``for img in dataset`` loops (sequence iteration protocol).
        img = np.array(self.dataset[self._key(index)])
        return torch.FloatTensor(img).to(self.device) / 255.0

    def plot_image(self, index):
        """Show image ``index`` with matplotlib (raw stored values, no scaling)."""
        plt.imshow(np.array(self.dataset[self._key(index)]), interpolation='nearest')

    def close(self):
        """Close the underlying HDF5 file."""
        self.file_object.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

# Re-open the aligned-CelebA HDF5 file with the dataset class defined in this cell.
celeba_dataset = CelebADataset(file='celeba/celeba_aligned_real_small.h5py')

In [8]:
%%time

# MPS run where the constant target tensors are built once, outside the loop,
# and the seeds are used as returned by generate_random_seed (no extra .to()).
D = Discriminator()
D.to(device)
G = Generator()
G.to(device)

epochs = 1

# Reused on every step: 0.0 = fake, 1.0 = real.
FT0 = torch.FloatTensor([0.0]).to(device)
FT1 = torch.FloatTensor([1.0]).to(device)

for epoch in range(epochs):
    for image_data_tensor in celeba_dataset:
        D.train(image_data_tensor, FT1)
        D.train(G.forward(generate_random_seed(100)).detach(), FT0)
        G.train(D, generate_random_seed(100), FT1)
        pass
    pass
# NOTE(review): the recorded wall time (23 min) is well above the In [6] run
# (~9 min) even though this cell allocates fewer tensors per step; nothing in
# the code explains the gap (possibly system load) — rerun before drawing conclusions.

counter =  1000
counter =  2000
counter =  3000
counter =  4000
counter =  5000
counter =  6000
CPU times: user 26min 13s, sys: 21min 33s, total: 47min 47s
Wall time: 23min 14s
