In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import random

In [2]:
from typing import List

def create_binary_list_from_int(number: int) -> List[int]:
    """Return the binary digits of a non-negative ``int``, most-significant bit first.

    ``0`` maps to ``[0]``; no leading zeros are added to other values
    (e.g. ``6 -> [1, 1, 0]``).

    Raises:
        ValueError: if ``number`` is not exactly of type ``int`` (so ``bool``,
            ``float`` and numpy integer scalars are rejected) or is negative.
    """
    # Check the type before comparing: ``number < 0`` on a str or None would raise
    # TypeError before the intended ValueError could be reached.
    if type(number) is not int or number < 0:
        raise ValueError("Only Positive integers are allowed")

    # bin() yields e.g. '0b110'; skip the '0b' prefix and convert each digit.
    return [int(digit) for digit in bin(number)[2:]]

In [7]:
def generate_even_data(max_int: int, batch_size: int=16):
    """Sample a batch of even integers below ``max_int`` as fixed-width bit lists.

    Args:
        max_int: Exclusive upper bound of the sampled values. Every returned row is
            ``(max_int - 1).bit_length()`` bits wide, which equals ``log2(max_int)``
            when ``max_int`` is a power of two.
        batch_size: Number of samples to draw.

    Returns:
        ``(labels, data)``: ``labels`` is a list of ``batch_size`` ones (every sample
        is a real even number) and ``data`` is a list of ``batch_size`` lists of 0/1
        ints, most-significant bit first, left-padded with zeros to a common width.

    Raises:
        ValueError: if ``max_int`` is smaller than 2 (no even value can be sampled).
    """
    if max_int < 2:
        raise ValueError("max_int must be at least 2")

    # Bits needed for any value in [0, max_int). Integer arithmetic avoids the float
    # rounding of int(math.log(max_int, 2)), which also under-counted the width when
    # max_int is not a power of two and produced rows of unequal length.
    max_length = (max_int - 1).bit_length()

    # Draw k in [0, max_int // 2) so that 2 * k is an even value strictly below max_int.
    sampled_integers = np.random.randint(0, max_int // 2, batch_size)

    # Every sample is real (even) data, so all labels are 1.
    labels = [1] * batch_size

    data = []
    for half in sampled_integers:
        bits = create_binary_list_from_int(int(half) * 2)
        data.append([0] * (max_length - len(bits)) + bits)

    return labels, data

In [4]:
class Generator(nn.Module):
    """One square fully connected layer followed by a sigmoid.

    Input and output both have ``int(input_length)`` features in the last
    dimension; every output entry lies in (0, 1).
    """

    def __init__(self, input_length: int):
        super().__init__()
        width = int(input_length)
        # Square layer: as many output features as input features.
        self.dense_layer = nn.Linear(width, width)
        self.activation = nn.Sigmoid()

    def forward(self, x):
        pre_activation = self.dense_layer(x)
        squashed = self.activation(pre_activation)
        return squashed

In [5]:
class Discriminator(nn.Module):
    """Linear layer down to a single unit followed by a sigmoid.

    Takes ``int(input_length)`` features in the last dimension and returns a
    tensor whose last dimension has size 1, with values in (0, 1).
    """

    def __init__(self, input_length: int):
        super().__init__()
        n_features = int(input_length)
        self.dense = nn.Linear(n_features, 1)
        self.activation = nn.Sigmoid()

    def forward(self, x):
        score = self.dense(x)
        probability = self.activation(score)
        return probability

In [119]:
import math

import torch
import torch.nn as nn


def _bits_to_int(bits: torch.Tensor) -> torch.Tensor:
    """Decode each row of an MSB-first 0/1 tensor of shape (batch, width) to its integer value."""
    width = bits.shape[1]
    # Place values 2**(width-1), ..., 2, 1 for the most- to least-significant bit.
    place_values = 2 ** torch.arange(width - 1, -1, -1, device=bits.device)
    return (bits * place_values[None, :]).sum(1)


def train(max_int: int = 128, batch_size: int = 16, training_steps: int = 500,
          log_every: int = 500, seed=None):
    """Train a Generator/Discriminator pair to produce even numbers in binary.

    Real samples come from ``generate_even_data``. The generator maps random 0/1
    noise to per-bit values in (0, 1). Each iteration takes one generator step and
    one discriminator step, both with Adam (lr=0.001) and binary cross-entropy.

    Args:
        max_int: Exclusive upper bound of the even numbers to learn. The bit width is
            ``(max_int - 1).bit_length()`` (``log2(max_int)`` for powers of two).
        batch_size: Samples per step, for both real and generated data.
        training_steps: Number of training iterations.
        log_every: Every ``log_every`` iterations (starting at 0), also print the
            decoded generated ("Pred") and real ("GT") values. Must be positive.
        seed: If not None, seed torch's and numpy's global RNGs first so that the
            run is reproducible.

    Prints one status line per iteration: generator loss, discriminator loss, and
    the fraction of generated samples that decode (after rounding) to an even number.

    Raises:
        ValueError: if ``log_every`` is not positive.
    """
    if log_every <= 0:
        raise ValueError("log_every must be positive")
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)

    input_length = (max_int - 1).bit_length()

    # Models
    generator = Generator(input_length)
    discriminator = Discriminator(input_length)

    # Optimizers
    generator_optimizer = torch.optim.Adam(generator.parameters(), lr=0.001)
    discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.001)

    # loss
    loss = nn.BCELoss()

    # BCELoss requires the target to have the same shape as the input. The
    # discriminator emits (batch_size, 1), so the targets are built with that shape.
    fake_labels = torch.zeros(batch_size, 1)

    for i in range(training_steps):
        generator_optimizer.zero_grad()

        # Random 0/1 noise; nn.Linear needs a float tensor, not randint's int64.
        noise = torch.randint(0, 2, size=(batch_size, input_length)).float()
        generated_data = generator(noise)

        # Real (even) samples, labelled 1.
        true_labels, true_data = generate_even_data(max_int, batch_size=batch_size)
        true_labels = torch.tensor(true_labels).float().unsqueeze(1)
        true_data = torch.tensor(true_data).float()

        # Generator step: score the fakes against the "real" labels so the generator
        # is pushed toward outputs the discriminator accepts. Only the generator's
        # optimizer steps here; the gradients this leaves on the discriminator are
        # cleared by discriminator_optimizer.zero_grad() below.
        generator_discriminator_out = discriminator(generated_data)
        generator_loss = loss(generator_discriminator_out, true_labels)
        generator_loss.backward()
        generator_optimizer.step()

        # Discriminator step: real samples -> 1, generated samples -> 0.
        discriminator_optimizer.zero_grad()
        true_discriminator_out = discriminator(true_data)
        true_discriminator_loss = loss(true_discriminator_out, true_labels)

        # detach() keeps this loss from backpropagating into the generator. That
        # graph was already freed by generator_loss.backward(), and the
        # discriminator update must not train the generator anyway.
        generator_discriminator_out = discriminator(generated_data.detach())
        generator_discriminator_loss = loss(generator_discriminator_out, fake_labels)
        discriminator_loss = (true_discriminator_loss + generator_discriminator_loss) / 2
        discriminator_loss.backward()
        discriminator_optimizer.step()

        # Logging only: no autograd graph is needed for the decoded values.
        with torch.no_grad():
            predicted_values = _bits_to_int(generated_data.round())
            even_fraction = (predicted_values % 2 == 0).float().mean().item()

        print(f"[Iter {i}] [GenLoss {generator_loss.item():0.4f}] [DisLoss {discriminator_loss.item():0.4f}] [Acc {even_fraction}]")
        if i % log_every == 0:
            print(f'[Pred: {predicted_values}]')
            print(f'[GT: {_bits_to_int(true_data)}]')

In [143]:
# Run training with the default configuration, spelled out explicitly:
# even numbers below 128, batches of 16, 500 iterations.
train(max_int=128, batch_size=16, training_steps=500)

[Iter 0] [GenLoss 0.7094] [DisLoss 0.7685] [Acc 0.0]
[Pred: tensor([73.,  1.,  3.,  3.,  3., 33.,  3.,  3.,  1.,  1.,  3., 33., 43.,  1.,
         1.,  9.], grad_fn=<SumBackward1>)]
[GT: tensor([116.,   0., 110.,   0., 126.,  30.,  82.,  84.,  76., 106.,  22., 110.,
        102.,  92., 102.,  10.])
[Iter 1] [GenLoss 0.7131] [DisLoss 0.7550] [Acc 0.0625]
[Iter 2] [GenLoss 0.7045] [DisLoss 0.7155] [Acc 0.0]
[Iter 3] [GenLoss 0.7064] [DisLoss 0.7263] [Acc 0.0]
[Iter 4] [GenLoss 0.7024] [DisLoss 0.7592] [Acc 0.125]
[Iter 5] [GenLoss 0.7098] [DisLoss 0.7532] [Acc 0.0]
[Iter 6] [GenLoss 0.6921] [DisLoss 0.7597] [Acc 0.0625]
[Iter 7] [GenLoss 0.7074] [DisLoss 0.7395] [Acc 0.0]
[Iter 8] [GenLoss 0.6894] [DisLoss 0.7929] [Acc 0.0]
[Iter 9] [GenLoss 0.7037] [DisLoss 0.7439] [Acc 0.0625]
[Iter 10] [GenLoss 0.6992] [DisLoss 0.7610] [Acc 0.0]
[Iter 11] [GenLoss 0.6982] [DisLoss 0.7638] [Acc 0.0]
[Iter 12] [GenLoss 0.6757] [DisLoss 0.7785] [Acc 0.0]
[Iter 13] [GenLoss 0.6907] [DisLoss 0.7133] [Acc 0