In [None]:
import torch.nn as nn
import torch
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Load MNIST through Keras and turn every split/array into a float32 torch tensor on
# `device`, in the order x_train, y_train, x_test, y_test. Labels are cast to float
# too; they are not used by the GAN below.
(x_train, y_train), (x_test, y_test) = (
    tuple(torch.from_numpy(arr).float().to(device) for arr in split)
    for split in tf.keras.datasets.mnist.load_data()
)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


In [None]:
# Sanity check on the raw pixels: expect tensor(False), i.e. no NaNs.
x_train.isnan().any()

tensor(False, device='cuda:0')

In [None]:
# Standardise the training images with a single global mean/std over all pixels.
# Per-pixel statistics (dim=0) are a poor fit for MNIST: border pixels are almost
# always 0, so a pixel lit in only k of N images gets std ~ v*sqrt(k/N) and z-scores
# around sqrt(N/k) (hundreds for k=1), huge outliers that destabilise GAN training.
# `mean`/`std` stay (0-dim) tensors, so code that un-normalises with
# `x * std + mean` keeps working via broadcasting.
# NOTE: this overwrites x_train; re-running the cell without reloading the data
# would standardise twice.
mean = torch.mean(x_train)
std = torch.std(x_train) + 1e-8  # epsilon guards against division by zero
x_train = (x_train - mean) / std

In [None]:
# The normalisation mean must be finite: expect tensor(False).
mean.isnan().any()

tensor(False, device='cuda:0')

In [None]:
# The normalisation std must be finite: expect tensor(False).
std.isnan().any()

tensor(False, device='cuda:0')

In [None]:
# Re-check after standardisation: expect tensor(False).
x_train.isnan().any()

tensor(False, device='cuda:0')

In [None]:
class UpsampleBlock(nn.Module):
  """Double the spatial size and halve the channels.

  Layers: ConvTranspose2d(k=4, s=2, p=1) -> BatchNorm2d -> LeakyReLU(0.1).
  With these settings an (H, W) input becomes (2H, 2W).

  Args:
    in_channels: input channel count; the output has ``in_channels // 2`` channels.
    ini: if True, re-initialise the transposed-conv weights from
      N(0, sqrt(2 / (in_channels * 4 * 4))) (He-style); if False, keep PyTorch's
      default initialisation.
  """

  def __init__(self, in_channels, ini=True):
    super(UpsampleBlock, self).__init__()
    out_channels = in_channels // 2
    up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, padding=1, stride=2)
    if ini:
      # Multiplying an nn.Module by a tensor raises TypeError; the intended
      # scale belongs in the weight initialisation instead.
      std = (2.0 / (in_channels * 4 * 4)) ** 0.5
      nn.init.normal_(up.weight, mean=0.0, std=std)
    self.seq = nn.Sequential(
        up,
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(0.1),
    )

  def forward(self, x):
    return self.seq(x)

In [None]:
class Noiser(nn.Module):
  """Add zero-mean Gaussian noise scaled by a step-indexed amplitude schedule.

  Args:
    amplitude_seq: indexable sequence of amplitudes; stored as a plain attribute
      (not a registered buffer), so ``.to(device)`` does not move it.
  """

  def __init__(self, amplitude_seq):
    super().__init__()
    self.seq = amplitude_seq

  def forward(self, x, iter):
    """Return ``x`` plus standard-normal noise times ``amplitude_seq[iter]``."""
    scale = self.seq[iter]
    gaussian = torch.randn_like(x)
    return x + gaussian * scale

In [None]:
class Generator(nn.Module):
  """Map a batch of 16-d latent vectors to single-channel 28x28 images.

  Shapes: (B, 16) -> Linear -> (B, 1024) -> Unflatten -> (B, 64, 4, 4)
  -> ConvTranspose2d(k=3, s=2, p=1) -> (B, 32, 7, 7) -> UpsampleBlock -> (B, 16, 14, 14)
  -> UpsampleBlock -> (B, 8, 28, 28) -> 1x1 Conv2d -> (B, 1, 28, 28).
  The output has no final activation.

  In training mode, Gaussian noise scaled by ``amplitude_seq[iter]`` is added after
  each upsampling block; in eval mode no noise is added.

  Args:
    amplitude_seq: indexable noise-amplitude schedule, passed to ``Noiser``.
  """

  def __init__(self, amplitude_seq):
    super(Generator, self).__init__()
    fc = nn.Linear(16, 16 * 64)
    up0 = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1, stride=2)
    # He-style init, std = sqrt(2 / fan_in), with fan_in = in_features for the
    # Linear and in_channels * k * k for the transposed conv (same convention as
    # UpsampleBlock). Multiplying a Module by a tensor raises TypeError, so the
    # scale is applied to the weight initialisation instead.
    nn.init.normal_(fc.weight, mean=0.0, std=(2.0 / 16) ** 0.5)
    nn.init.normal_(up0.weight, mean=0.0, std=(2.0 / (64 * 3 * 3)) ** 0.5)
    self.entry = nn.Sequential(
        fc,
        nn.Unflatten(1, (64, 4, 4)),
        up0,  # (4 - 1) * 2 - 2 * 1 + 3 = 7  ->  (B, 32, 7, 7)
        nn.BatchNorm2d(32),
        nn.LeakyReLU(0.1)
    )
    self.block1 = UpsampleBlock(32)
    self.block2 = UpsampleBlock(16)
    self.finisher = nn.Conv2d(8, 1, kernel_size=1, stride=1)
    self.Noiser = Noiser(amplitude_seq)

  def forward(self, x, iter):
    """Generate images from latents ``x``; ``iter`` indexes the noise schedule."""
    x = self.entry(x)
    x = self.block1(x)
    # The annealed noise schedule is indexed by the outer training iteration, so
    # the noise is a training-time regulariser. It is injected only while training,
    # not at evaluation/sampling time.
    if self.training:
      x = self.Noiser(x, iter)
    x = self.block2(x)
    if self.training:
      x = self.Noiser(x, iter)
    return self.finisher(x)

In [None]:
# Training configuration.
BATCH_SIZE = 256            # samples per optimiser step
ITERS = 100                 # optimiser steps per train_gen / train_dis call
N_SAMPLES = len(x_train)    # number of training images to sample indices from

In [None]:
# Noise-amplitude schedule with one entry per outer training iteration (300 total).
# A triple exponential of a 1 -> 0 ramp, shifted by exp(exp(1)) so the last entry is
# ~0, then rescaled so the first entry is 0.3. The amplitude falls off very steeply
# over roughly the first tenth of the schedule.
noise_ampli = torch.linspace(1, 0, steps=300).exp().exp().exp() - torch.tensor(1).exp().exp()
noise_ampli = (noise_ampli / noise_ampli[0] * 0.3).to(device)

In [None]:
# Discriminator: (B, 1, 28, 28) image -> (B, 1) raw logit (pair with BCEWithLogitsLoss).
Discriminator = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=4, padding=1, stride=2),      # -> (B, 32, 14, 14)
    # A non-linearity is required here: two convolutions with nothing in between
    # compose into a single linear map.
    nn.LeakyReLU(0.1),
    nn.Conv2d(32, 64, kernel_size=4, padding=1, stride=2),     # -> (B, 64, 7, 7)
    nn.BatchNorm2d(64),
    nn.LeakyReLU(0.1),
    nn.Flatten(),
    nn.Linear(7 * 7 * 64, 16),
    nn.BatchNorm1d(16),
    nn.LeakyReLU(0.1),
    nn.Linear(16, 1),
).to(device)

# He-style initialisation, std = sqrt(2 / fan_in). For Conv2d, fan_in = in_channels * k * k;
# for Linear, fan_in = in_features. Both equal weight[0].numel(). Modules cannot be
# multiplied by a tensor, so the scale is applied to the initial weights.
for layer in Discriminator:
  if isinstance(layer, (nn.Conv2d, nn.Linear)):
    nn.init.normal_(layer.weight, mean=0.0, std=(2.0 / layer.weight[0].numel()) ** 0.5)

# NOTE(review): this rebinds the class name `Generator` to an instance, so re-running
# this cell on its own fails; re-run the Generator class cell first.
Generator = Generator(noise_ampli).to(device)

SyntaxError: invalid syntax. Perhaps you forgot a comma? (<ipython-input-14-2a9329eef6d5>, line 2)

In [None]:
# Binary real/fake objective on raw discriminator logits (the sigmoid is folded in).
criterion = nn.BCEWithLogitsLoss()
# One Adam optimiser per network, so each step() only updates its own model.
optim_dis = torch.optim.Adam(params=Discriminator.parameters(), lr=3e-4)
optim_gen = torch.optim.Adam(params=Generator.parameters(), lr=3e-4)

In [None]:
def noisify_labels(labels):
    """Jitter binary targets towards the opposite class by a small random amount.

    Draws |N(0, 0.02)| noise per element from the default (CPU) generator. Entries
    equal to 0 are increased by it and entries equal to 1 are decreased by it; any
    other value passes through unchanged. The result is a float32 tensor with
    ``labels``' shape, placed on the global ``device``.
    """
    magnitude = torch.normal(mean=0, std=0.020, size=labels.shape).abs()

    flat = labels.clone().float().view(-1).to(device)
    ref = labels.view(-1).to(device)
    jitter = magnitude.view(-1).to(device)

    smoothed = torch.where(
        ref == 0,
        flat + jitter,
        torch.where(ref == 1, flat - jitter, flat),
    )
    return smoothed.reshape(labels.shape)

In [None]:
def train_gen(iter):
  """Run ITERS generator updates against the current discriminator.

  Uses the non-saturating objective: generated samples are scored against noisy
  "real" (~1) targets, i.e. minimise -log D(G(z)). Only optim_gen steps; the
  discriminator's weights are left unchanged here.

  Args:
    iter: outer training iteration, forwarded to the generator's noise schedule.
  """
  for _step in range(ITERS):
    z = torch.randn(BATCH_SIZE, 16).to(device)
    fakes = Generator(z, iter).to(device)
    logits = Discriminator(fakes).to(device)
    targets = noisify_labels(torch.ones_like(logits).to(device))
    loss = criterion(logits, targets)

    optim_gen.zero_grad()
    loss.backward()
    optim_gen.step()

def train_dis(iter):
  """Run ITERS discriminator updates on real vs. generated batches.

  Each step samples BATCH_SIZE training images (with replacement) and BATCH_SIZE
  fresh generator samples, then minimises
  BCE(D(x), ~1) + BCE(D(G(z)), ~0) = -[log D(x) + log(1 - D(G(z)))],
  with noisy targets from noisify_labels. Only optim_dis steps.

  Args:
    iter: outer training iteration, forwarded to the generator's noise schedule.
  """
  for _ in range(ITERS):
    # Real data: random training images, given a channel axis for the conv stack.
    ix = torch.randint(0, N_SAMPLES, (BATCH_SIZE,)).to(device)
    batch_x = x_train[ix].unsqueeze(1).to(device)
    preds_r = Discriminator(batch_x)
    r_loss = criterion(preds_r, noisify_labels(torch.ones_like(preds_r)))

    # Generated data. This step only updates D, so no autograd graph is built
    # through the generator (back-propagating into G's weights would be wasted work).
    with torch.no_grad():
      batch_x = Generator(torch.randn(BATCH_SIZE, 16).to(device), iter)
    preds_g = Discriminator(batch_x)
    g_loss = criterion(preds_g, noisify_labels(torch.zeros_like(preds_g)))

    loss = r_loss + g_loss

    optim_dis.zero_grad()
    loss.backward()
    optim_dis.step()

In [None]:
def sample_gen(num_images):
    """Plot ``num_images`` generator samples side by side, in pixel space.

    The generator runs in eval mode, in one batch, without building an autograd
    graph. The training-set standardisation is undone with the global ``mean`` and
    ``std``. The generator's previous train/eval mode is restored afterwards, even if
    plotting raises. Does nothing when ``num_images <= 0``.
    """
    if num_images <= 0:
        return
    was_training = Generator.training
    Generator.eval()
    try:
        with torch.no_grad():
            # The second argument (1) is forwarded to the generator's noise schedule, as before.
            samples = Generator(torch.randn(num_images, 16).to(device), 1)
            samples = samples * std.to(device) + mean.to(device)
        images = samples.cpu().numpy().reshape(num_images, 28, 28)

        # squeeze=False keeps `axes` 2-D even when num_images == 1.
        fig, axes = plt.subplots(1, num_images, figsize=(num_images * 3, 3), squeeze=False)
        for ax, image in zip(axes[0], images):
            ax.imshow(image, cmap='gray')
            ax.axis('off')
        plt.show()
    finally:
        Generator.train(was_training)

def sample_reg():
  """Show the first training image in pixel space, for comparison with sample_gen.

  x_train holds standardised values, so the stored ``mean``/``std`` are applied in
  reverse first, the same un-normalisation sample_gen uses for generated images.
  """
  image = (x_train[0] * std + mean).cpu().numpy().reshape(28, 28)
  plt.imshow(image, cmap='gray')
  plt.axis('off')
  plt.show()

In [None]:
MAX_ROUNDS = 50  # safety cap on train_gen/train_dis calls per phase (each = ITERS steps)


def _fake_batch_loss(iteration, target):
  """BCE of the discriminator on one fresh generated batch vs. a constant target.

  Evaluated under no_grad, so no autograd graph is kept. Both networks stay in
  their current mode, so BatchNorm running statistics may still be updated.
  """
  with torch.no_grad():
    batch = Generator(torch.randn(BATCH_SIZE, 16).to(device), iteration)
    logits = Discriminator(batch)
    return criterion(logits, torch.full_like(logits, target)).item()


def _train_until_stable(train_step, iteration, target, tol):
  """Call train_step(iteration) until the probe loss settles, then return it.

  "Settles" means two consecutive probes differ by less than ``tol``. The loop is
  bounded by MAX_ROUNDS so a noisy probe cannot keep training forever. Returns
  the last measured probe loss.
  """
  loss = _fake_batch_loss(iteration, target)
  for _ in range(MAX_ROUNDS):
    prev_loss = loss
    loss = _fake_batch_loss(iteration, target)
    if abs(prev_loss - loss) < tol:
      break
    train_step(iteration)
  return loss


gen_loss = []
dis_loss = []
# One outer iteration per noise-schedule entry, so `i` always indexes noise_ampli.
for i in range(len(noise_ampli)):
  print(f'iteration: {i}')
  sample_gen(6)
  # Generator phase: probe against "real" (1) targets. This measures how well G
  # fools D; the first probe now uses the same target as the later ones.
  gen_loss.append(_train_until_stable(train_gen, i, 1.0, 1e-2))
  # Discriminator phase: probe generated samples against "fake" (0) targets.
  dis_loss.append(_train_until_stable(train_dis, i, 0.0, 1e-3))

In [None]:
# Probe losses recorded at the end of each phase of every outer iteration.
fig, ax = plt.subplots()
ax.plot(gen_loss, label="generator phase (fakes vs. real targets)")
ax.plot(dis_loss, label="discriminator phase (fakes vs. fake targets)")
ax.set(xlabel="outer iteration", ylabel="BCE loss on a generated batch",
       title="Probe losses after each training phase")
ax.legend()
plt.show()