This example is based on - **`Learn Generative AI with Pytorch by Mark Liu`**, Chapter 3, Generative Adversial Networks: Shape and Number Generation

## Preparing Data

In [8]:
import torch
def onehot_encoder(position,depth):
 onehot=torch.zeros((depth,))
 onehot[position]=1
 return onehot

print(onehot_encoder(1,5))

tensor([0., 1., 0., 0., 0.])


In [9]:
def int_to_onehot(number):
 onehot=onehot_encoder(number,100)
 return onehot

onehot75=int_to_onehot(75)
print(onehot75)

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])


In [10]:
def onehot_to_int(onehot):
 num=torch.argmax(onehot)
 return num.item()

print(onehot_to_int(onehot75))

75


In [3]:
import torch

# generated multiples of 5 between 0 and 99
def gen_sequence():
 indices = torch.randint(0, 20, (10,))
 values = indices*5
 return values

In [4]:
sequence=gen_sequence()
print(sequence)

tensor([50, 15, 10, 75, 50, 80,  5, 20,  0, 10])


In [11]:
import numpy as np
def gen_batch():
 sequence=gen_sequence()
 batch=[int_to_onehot(i).numpy() for i in sequence]
 batch=np.array(batch)
 return torch.tensor(batch)

batch=gen_batch()

In [12]:
def data_to_num(data):
 num=torch.argmax(data,dim=-1)
 return num

numbers=data_to_num(batch)

#### Why one-hot-encode here?.

- we need to pass continuous data to neural networks
- images are continuous data (0-255), same way we embedd text to convert it to continuous nature data.

In [14]:
from torch import nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'

D=nn.Sequential(
 nn.Linear(100,1),
 nn.Sigmoid()).to(device)

In [15]:
G=nn.Sequential(
 nn.Linear(100,100),
 nn.ReLU()).to(device)

In [16]:
loss_fn=nn.BCELoss()
lr=0.0005
optimD=torch.optim.Adam(D.parameters(),lr=lr)
optimG=torch.optim.Adam(G.parameters(),lr=lr)

In [17]:
# Define Early Stopping Function

class EarlyStop:
  def __init__(self, patience=1000):
    self.patience = patience
    self.steps = 0
    self.min_gdif = float('inf')

  def stop(self, gdif):
    if gdif < self.min_gdif:
      self.min_gdif = gdif
      self.steps = 0
    elif gdif >= self.min_gdif:
      self.steps += 1
    if self.steps >= self.patience:
      return True
    else:
      return False

stopper=EarlyStop()

In [19]:
real_labels=torch.ones((10,1)).to(device)
fake_labels=torch.zeros((10,1)).to(device)

In [20]:
def train_D_G(D,G,loss_fn,optimD,optimG):
    # Generate examples of real data
    true_data=gen_batch().to(device)
    # use 1 as labels since they are real
    preds=D(true_data)
    loss_D1=loss_fn(preds,real_labels.reshape(10,1))
    optimD.zero_grad()
    loss_D1.backward()
    optimD.step()
    # train D on fake data
    noise=torch.randn(10,100).to(device)
    generated_data=G(noise)
    # use 0 as labels since they are fake
    preds=D(generated_data)
    loss_D2=loss_fn(preds,fake_labels.reshape(10,1))
    optimD.zero_grad()
    loss_D2.backward()
    optimD.step()

    # train G
    noise=torch.randn(10,100).to(device)
    generated_data=G(noise)
    # use 1 as labels since G wants to fool D
    preds=D(generated_data)
    loss_G=loss_fn(preds,real_labels.reshape(10,1))
    optimG.zero_grad()
    loss_G.backward()
    optimG.step()
    return generated_data

In [21]:
stopper=EarlyStop(800)
mse=nn.MSELoss()
real_labels=torch.ones((10,1)).to(device)
fake_labels=torch.zeros((10,1)).to(device)

def distance(generated_data):
  nums=data_to_num(generated_data)
  remainders=nums%5
  ten_zeros=torch.zeros((10,1)).to(device)
  mseloss=mse(remainders,ten_zeros)

  return mseloss

for i in range(10000):
  gloss=0
  dloss=0
  generated_data=train_D_G(D,G,loss_fn,optimD,optimG)
  dis=distance(generated_data)
  if stopper.stop(dis)==True:
    break
  if i % 50 == 0:
    print(data_to_num(generated_data))

  return F.mse_loss(input, target, reduction=self.reduction)


tensor([41, 12, 80, 59,  7, 72, 18,  2, 21, 20])
tensor([58, 16,  2, 74, 87, 54, 23, 98, 93, 57])
tensor([ 9, 19, 59, 17, 89,  4, 71, 82, 55, 86])
tensor([22, 54, 91, 89, 88, 57, 57, 34, 61, 77])
tensor([59, 88, 87, 55, 55, 21, 22, 75,  4, 21])
tensor([22, 81,  4, 86, 60, 95, 61, 37, 76, 77])
tensor([21,  9, 15,  9, 26, 37, 77, 77, 37, 61])
tensor([37, 74, 37, 82, 76, 86, 21, 21,  4,  4])


  return F.mse_loss(input, target, reduction=self.reduction)


tensor([ 4, 77, 61, 15, 15, 55, 15, 60,  9, 76])
tensor([70, 55, 15, 15, 22, 75,  0, 22, 55, 60])
tensor([22, 55, 20, 55,  4, 15, 15, 96, 82, 15])
tensor([ 4, 15,  4,  0, 59, 77, 55, 55, 55, 21])
tensor([40, 90, 15, 50, 15, 40, 55, 20, 59,  4])
tensor([55, 55, 20, 50, 85, 70, 33, 65, 15, 55])
tensor([70, 15,  0, 60, 35, 77, 15, 20, 40, 60])
tensor([40, 20, 35, 70, 55, 55, 20, 55, 20, 95])
tensor([20, 70, 40, 15, 40, 40,  5, 70, 65, 20])
tensor([95, 20, 25, 80, 90, 80,  0, 70, 25, 40])
tensor([40, 65, 35, 80, 35, 40, 35, 40, 25, 75])
tensor([25, 40, 80, 35, 95, 25, 40, 65,  5, 80])


  return F.mse_loss(input, target, reduction=self.reduction)


tensor([ 5, 35, 80, 95, 80, 95,  5, 10, 70, 65])
tensor([65, 80, 80, 95, 25,  5, 95, 85, 85, 95])
tensor([65, 95, 35, 90, 95, 25, 95, 35, 80, 85])


  return F.mse_loss(input, target, reduction=self.reduction)


tensor([50, 15, 35, 25, 20,  5, 85, 25, 25, 25])
tensor([45, 75, 35, 90, 90, 95, 90, 50, 50, 50])


  return F.mse_loss(input, target, reduction=self.reduction)


tensor([20, 10, 45, 50, 75, 90, 75,  5, 30, 10])
tensor([45, 35, 20, 90, 75, 85, 85, 95, 45, 90])
tensor([95, 45, 45, 75, 30, 20, 70, 50, 45, 50])


In [22]:
# Export to TorchScript
import os
os.makedirs("files", exist_ok=True)
scripted = torch.jit.script(G)
scripted.save('files/num_gen.pt')

In [23]:
new_G=torch.jit.load('/content/files/num_gen.pt',
                     map_location=device)
new_G.eval()

RecursiveScriptModule(
  original_name=Sequential
  (0): RecursiveScriptModule(original_name=Linear)
  (1): RecursiveScriptModule(original_name=ReLU)
)

In [24]:
# obtain inputs from the latent space
torch.manual_seed(42)
noise=torch.randn((10,100)).to(device)
# feed the input to the generator
new_data=new_G(noise)
print(data_to_num(new_data))

tensor([40, 90, 85, 45, 10, 40, 40, 45, 70, 90])
