In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch.optim as optim
import torchvision

# Probe the Apple-Silicon (MPS) backend, then derive the compute device from it:
# 'mps:0' when the backend reports itself available, plain CPU otherwise.
mps = torch.backends.mps.is_available()
device = torch.device('mps:0' if mps else 'cpu')

In [2]:
# --- Data geometry ---
IMAGE_SIZE = 28 * 28     # length of one flattened 28x28 image (784)
NOISE_DIM = 10           # length of the latent noise vector fed to the generator

# --- Optimisation settings ---
# NOTE(review): the training loop is not part of this excerpt; confirm there whether
# EPOCHS counts full passes over the data or individual mini-batch steps.
BATCH_SIZE = 128
EPOCHS = 300_000
LEARNING_RATE = 0.1

In [3]:
# Fetch the MNIST training split into the current directory (downloaded on first use).
MNIST = torchvision.datasets.MNIST(root='.', train=True, download=True)

# Flatten every image to one row of IMAGE_SIZE values and divide by 255.
# NOTE(review): presumably the raw pixels are 0..255 so this yields values in [0, 1],
# matching the generator's sigmoid output range -- confirm against torchvision docs.
X_train, y_train = MNIST.data.reshape(-1, IMAGE_SIZE) / 255, MNIST.targets

# Only relocate the tensors when the MPS backend is in use; on CPU they stay put.
if mps:
  X_train, y_train = X_train.to(device), y_train.to(device)

In [4]:
# Module-level activation retained so that any other cell still referring to
# `activation` keeps working. Generator itself now builds one LeakyReLU per layer
# instead of sharing this single instance.
activation = nn.LeakyReLU(0.2, inplace=True)
class Generator(nn.Module):
  """Fully-connected generator: noise vector -> flattened image in [0, 1].

  The network is ``Linear -> LeakyReLU(0.2)`` for each hidden width, followed by a
  final ``Linear -> Sigmoid``. Sub-modules live in ``self.model`` with Linear layers
  at the even indices (0, 2, 4, ...), so state-dict keys are ``model.<idx>.weight``
  and ``model.<idx>.bias``.

  Args:
    input_size: length of the input (noise) vector.
    output_size: length of the produced vector.
    hidden_size: list of hidden layer widths, in order. May be empty, in which
      case the network is a single ``Linear -> Sigmoid``.
    weights_data: optional sequence of weight tensors, one per Linear layer in
      network order (fewer entries than Linear layers is allowed; the remaining
      layers keep PyTorch's default initialisation). When falsy, every Linear
      weight is drawn from N(0, 0.1^2) instead.

  Raises:
    ValueError: if ``weights_data`` has more entries than there are Linear
      layers, or an entry's shape does not match its layer's weight shape.
  """

  def __init__(self, input_size: int, output_size: int, hidden_size: list, weights_data):
    super().__init__()
    self.input_size = input_size

    # Consecutive pairs of `sizes` are the (in, out) features of each Linear layer.
    sizes = [input_size, *hidden_size, output_size]
    layers = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
      layers.append(nn.Linear(n_in, n_out))
      layers.append(nn.LeakyReLU(0.2, inplace=True))
    layers[-1] = nn.Sigmoid()  # last activation squashes the output to 0..1
    self.model = nn.Sequential(*layers)

    if weights_data:
      # Pair the supplied weights with the Linear layers only. Indexing the
      # Sequential by raw position would land on activation modules, which have
      # no `.weight` attribute.
      linear_layers = [m for m in self.model if isinstance(m, nn.Linear)]
      if len(weights_data) > len(linear_layers):
        raise ValueError(
          f"got {len(weights_data)} weight tensors for {len(linear_layers)} Linear layers"
        )
      for layer, w in zip(linear_layers, weights_data):
        if tuple(w.shape) != tuple(layer.weight.shape):
          raise ValueError(
            f"weight shape {tuple(w.shape)} does not match layer shape {tuple(layer.weight.shape)}"
          )
        layer.weight.data = w
      print("weights loaded")
    else:
      self.model.apply(self.init_weights)

  @staticmethod
  def init_weights(m):
    """Re-draw the weight of a Linear module from N(0, 0.1^2); ignore other modules.

    Defined as a static method rather than a lambda stored on the instance so the
    module remains picklable. Biases are left at their default initialisation.
    """
    if isinstance(m, nn.Linear):
      return torch.nn.init.normal_(m.weight, mean=0, std=0.1)
    return None

  def forward(self, x):
    """Flatten `x` to rows of `input_size` values and run it through the network."""
    return self.model(x.view(-1, self.input_size))
  
class Discriminator(nn.Module):
  """Fully-connected discriminator: flattened input -> score(s) in [0, 1].

  The network is ``Linear -> LeakyReLU(0.2)`` for each hidden width, followed by a
  final ``Linear -> Sigmoid``. Sub-modules live in ``self.ll`` with Linear layers at
  the even indices (0, 2, 4, ...), so state-dict keys are ``ll.<idx>.weight`` and
  ``ll.<idx>.bias``.

  Args:
    input_size: length of each flattened input vector.
    output_size: number of output units.
    hidden_size: list of hidden layer widths, in order. May be empty, in which
      case the network is a single ``Linear -> Sigmoid``.
    weights_data: optional sequence of weight tensors, one per Linear layer in
      network order (fewer entries than Linear layers is allowed; the remaining
      layers keep PyTorch's default initialisation). When falsy, every Linear
      weight is drawn from N(0, 0.1^2) instead.

  Raises:
    ValueError: if ``weights_data`` has more entries than there are Linear
      layers, or an entry's shape does not match its layer's weight shape.
  """

  def __init__(self, input_size: int, output_size: int, hidden_size: list, weights_data):
    super().__init__()
    self.input_size = input_size

    # Consecutive pairs of `sizes` are the (in, out) features of each Linear layer.
    sizes = [input_size, *hidden_size, output_size]
    layers = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
      layers.append(nn.Linear(n_in, n_out))
      # A fresh activation per layer: no module instance is shared between layers
      # or with other networks.
      layers.append(nn.LeakyReLU(0.2, inplace=True))
    layers[-1] = nn.Sigmoid()  # last activation squashes the output to 0..1
    self.ll = nn.Sequential(*layers)

    if weights_data:
      # Pair the supplied weights with the Linear layers only. Indexing the
      # Sequential by raw position would land on activation modules, which have
      # no `.weight` attribute.
      linear_layers = [m for m in self.ll if isinstance(m, nn.Linear)]
      if len(weights_data) > len(linear_layers):
        raise ValueError(
          f"got {len(weights_data)} weight tensors for {len(linear_layers)} Linear layers"
        )
      for layer, w in zip(linear_layers, weights_data):
        if tuple(w.shape) != tuple(layer.weight.shape):
          raise ValueError(
            f"weight shape {tuple(w.shape)} does not match layer shape {tuple(layer.weight.shape)}"
          )
        layer.weight.data = w
    else:
      self.ll.apply(self.init_weights)

  @staticmethod
  def init_weights(m):
    """Re-draw the weight of a Linear module from N(0, 0.1^2); ignore other modules.

    Defined as a static method rather than a lambda stored on the instance so the
    module remains picklable. Biases are left at their default initialisation.
    """
    if isinstance(m, nn.Linear):
      return torch.nn.init.normal_(m.weight, mean=0, std=0.1)
    return None

  def forward(self, x):
    """Flatten `x` to rows of `input_size` values and run it through the network."""
    return self.ll(x.view(-1, self.input_size))

In [5]:
def generator_loss(fake_output, discriminator, eps=1e-7):
  """Minimax generator loss: mean(log(1 - D(fake_output))).

  The value gets smaller (more negative) as the discriminator scores the generated
  samples closer to 1, so minimising it with an optimiser pushes the generator
  towards fooling the discriminator. (The previous version returned the negation,
  which rewarded the generator for being detected.)

  Args:
    fake_output: batch of generated samples to be scored.
    discriminator: callable mapping samples to probabilities in [0, 1].
    eps: lower bound applied to the log argument so that D(fake) == 1 yields a
      large finite value instead of -inf.

  Returns:
    Scalar tensor holding the loss.
  """
  p_fake = discriminator(fake_output)
  return torch.mean(torch.log(torch.clamp(1 - p_fake, min=eps)))
def discriminator_loss(real_output, fake_output, discriminator, eps=1e-7):
  """Discriminator loss: -mean(log D(real_output) + log(1 - D(fake_output))).

  The discriminator wants to maximise log D(real) + log(1 - D(fake)); returning
  the negation makes this a quantity to minimise, as optimisers expect. (The
  previous version returned the un-negated value, so minimising it trained the
  discriminator to be wrong.)

  Args:
    real_output: batch of real samples to be scored.
    fake_output: batch of generated samples to be scored.
    discriminator: callable mapping samples to probabilities in [0, 1].
    eps: lower bound applied to both log arguments so saturated scores give a
      large finite loss instead of inf.

  Returns:
    Scalar tensor holding the loss (non-negative for scores in [0, 1]).
  """
  p_real = discriminator(real_output)
  p_fake = discriminator(fake_output)
  log_real = torch.log(torch.clamp(p_real, min=eps))
  log_not_fake = torch.log(torch.clamp(1 - p_fake, min=eps))
  return -torch.mean(log_real + log_not_fake)

In [6]:
# Bind both loss names to independent binary cross-entropy criteria.
# NOTE: this rebinds `generator_loss` / `discriminator_loss`, replacing the
# hand-written functions of the same names defined in the previous cell. From here
# on both are called as criterion(prediction, target).
generator_loss, discriminator_loss = nn.BCELoss(), nn.BCELoss()

In [13]:
# Rebuild both networks (hidden widths presumably match the saved checkpoints --
# load_state_dict raises if they do not) and place them on the compute device.
generator = Generator(NOISE_DIM, IMAGE_SIZE, [128, 256, 256, 512, 1024], None).to(device)
discriminator = Discriminator(IMAGE_SIZE, 1, [512, 256], None).to(device)

# map_location remaps the stored tensors onto `device` while loading, so a
# checkpoint written on one backend (e.g. MPS) can still be read on a machine
# that only has CPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
generator.load_state_dict(torch.load('model/generator.pth', map_location=device))
discriminator.load_state_dict(torch.load('model/discriminator.pth', map_location=device))

<All keys matched successfully>

In [14]:
# Sanity check: draw one batch of latent noise, generate samples from it, and
# measure how confidently the discriminator labels them as fake (target = 0).
z = torch.randn(BATCH_SIZE, NOISE_DIM).to(device)
fake_output = generator(z)
loss = discriminator_loss(
  discriminator(fake_output),
  torch.zeros(BATCH_SIZE, 1, device=device),  # all-zeros "fake" targets, created on the device
)
loss

tensor(0.0106, device='mps:0', grad_fn=<BinaryCrossEntropyBackward0>)