In [1]:
import math

import torch
class Snake(torch.utils.data.Dataset):
    """Synthetic 2-D "snake" dataset made of two noisy half-circles.

    Every sample is a point ``(x, y)`` built from an angle ``theta`` drawn
    uniformly from ``[0, 2*pi)`` and a radius ``10 + r`` where ``r`` follows a
    standard normal distribution.  Points whose angle lies in
    ``[pi/2, 3*pi/2]`` are shifted up by 10, all other points are shifted
    down by 10, so the two half-circles meet near the origin.

    Attributes:
        size: number of samples in the dataset.
        x: x coordinates, stored with shape ``(1, size)``.
        y: y coordinates, stored with shape ``(1, size)``.
        data: all samples as a ``(size, 2)`` tensor; row ``i`` is ``(x_i, y_i)``.
    """

    def __init__(self, size = 100000):
        """Draw ``size`` random points.

        Args:
            size: number of samples to generate.
        """
        # torch.rand draws from the uniform distribution on [0, 1);
        # scaling by 2*pi turns the draw into an angle.
        theta = torch.rand(size) * 2 * math.pi
        # torch.randn draws from the standard normal distribution (mean 0,
        # variance 1); it perturbs the nominal radius of 10.
        r = torch.randn(size)
        radius = 10 + r

        self.size = size

        # Angles in [pi/2, 3*pi/2] are moved up by 10, the rest down by 10.
        shifted_up = (theta >= (0.5 * math.pi)) & (theta <= (1.5 * math.pi))
        offset = torch.where(shifted_up, torch.tensor(10.0), torch.tensor(-10.0))

        x = radius * torch.cos(theta)
        y = radius * torch.sin(theta) + offset

        # Keep the historical (1, size) layout of the per-axis attributes.
        self.x = x.reshape(1, -1)
        self.y = y.reshape(1, -1)
        # One row per sample, columns are (x, y).
        self.data = torch.stack([x, y], dim=1)

    def __getitem__(self, index):
        """Return the sample(s) at ``index`` from the ``(size, 2)`` data tensor."""
        return self.data[index]

    def __len__(self):
        """Return the number of samples."""
        return self.size

In [2]:
snake = Snake()
# Sanity check: the dataset holds one (x, y) row per sample.
assert snake.data.shape == (len(snake), 2)
# Show the first sample.
snake[0]

tensor([ 4.9092, -2.5244])

In [3]:
import matplotlib.pyplot as plt
from logger import Logger
import torch
from torch import nn, optim
from torchvision import transforms, datasets

# Build the training set and serve it in shuffled mini-batches of 100 points.
data = Snake()
data_loader = torch.utils.data.DataLoader(
    dataset=data,
    batch_size=100,
    shuffle=True,
)

# Number of mini-batches per epoch.
num_batches = len(data_loader)
print(num_batches)

1000


In [4]:
# Pull a single mini-batch out of the loader to inspect what it yields.
a = iter(data_loader)
next(a)

tensor([[  9.0884,  -9.6926],
        [  5.2206, -18.5835],
        [  0.7825,  -0.3294],
        [  8.5499,  -9.8603],
        [ -7.7150,   3.9016],
        [  6.2492,  -1.2316],
        [ -3.1417,  20.0697],
        [ 10.2096,  -7.4616],
        [  5.0309, -19.7464],
        [ -8.7476,   6.9066],
        [ -0.8005,  -1.0980],
        [ -4.0823,  20.3419],
        [ -9.4493,  14.9877],
        [ -8.3847,   1.3439],
        [  9.8372, -12.3451],
        [ -4.4446,   0.2590],
        [  2.3731, -20.0532],
        [ -9.0390,  11.7097],
        [ 10.9985,  -9.5341],
        [  1.7175,   0.2842],
        [ -5.7252,  19.1324],
        [ 10.2253, -14.6114],
        [ -1.4362,  21.0257],
        [  1.8595,  -0.3209],
        [ -6.4896,  16.7186],
        [  2.9492, -18.5098],
        [  2.2604, -19.8169],
        [ -4.2597,   1.5817],
        [-10.3568,   5.5983],
        [ -5.8867,  16.6154],
        [ -4.8191,   2.0937],
        [  8.8207,  -4.5310],
        [-10.0377,   7.0535],
        [ 

In [5]:
class GeneratorNet(torch.nn.Module):
    """
    A three hidden-layer generative neural network.

    Maps latent noise vectors of length ``n_features`` to samples with
    ``n_out`` coordinates through fully connected layers of width
    1024, 512 and 256 with LeakyReLU(0.2) activations.

    Args:
        n_features: length of the latent noise vector (default 100).
        n_out: number of coordinates per generated sample (default 2).
    """
    def __init__(self, n_features=100, n_out=2):
        super().__init__()

        self.hidden0 = nn.Sequential(
            nn.Linear(n_features, 1024),
            nn.LeakyReLU(0.2)
        )
        self.hidden1 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2)
        )
        self.hidden2 = nn.Sequential(
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2)
        )

        # The output layer is deliberately linear.  A Tanh here would confine
        # every generated coordinate to (-1, 1), while the Snake samples this
        # notebook trains on are built from a radius of about 10 plus a +/-10
        # vertical offset, so the generator could never reach the data.
        # The layer stays inside a Sequential so its parameters keep the
        # ``out.0.*`` names in the state dict.
        self.out = nn.Sequential(
            nn.Linear(256, n_out)
        )

    def forward(self, x):
        """Run a batch of noise vectors ``x`` through all four stages in order."""
        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.out(x)
        return x

# Function to create noise samples for the generator's input

def noise(batchSize, n_features=100):
    """Draw standard-normal noise to feed the generator.

    Args:
        batchSize: number of noise vectors to draw.
        n_features: length of each noise vector (default 100).

    Returns:
        A ``(batchSize, n_features)`` tensor, moved to the GPU when CUDA is
        available.
    """
    n = torch.randn(batchSize, n_features)
    if torch.cuda.is_available():
        return n.cuda()
    return n

In [6]:
class DiscriminatorNet(torch.nn.Module):
    """
    A three hidden-layer discriminative neural network.

    Takes 2-feature samples and returns one sigmoid score per sample.
    Each hidden stage is Linear -> LeakyReLU(0.2) -> Dropout(0.3), with
    widths 1024, 512 and 256.
    """
    def __init__(self):
        super().__init__()
        in_dim, out_dim = 2, 1
        slope, drop_p = 0.2, 0.3

        self.hidden0 = nn.Sequential(
            nn.Linear(in_dim, 1024), nn.LeakyReLU(slope), nn.Dropout(drop_p)
        )
        self.hidden1 = nn.Sequential(
            nn.Linear(1024, 512), nn.LeakyReLU(slope), nn.Dropout(drop_p)
        )
        self.hidden2 = nn.Sequential(
            nn.Linear(512, 256), nn.LeakyReLU(slope), nn.Dropout(drop_p)
        )
        # Final stage squashes the single logit into a score via a sigmoid.
        self.out = nn.Sequential(
            nn.Linear(256, out_dim), nn.Sigmoid()
        )

    def forward(self, x):
        """Apply the three hidden stages and the output stage in order."""
        for stage in (self.hidden0, self.hidden1, self.hidden2, self.out):
            x = stage(x)
        return x
    
# def images_to_vectors(images):
#     return images.view(images.size(0), 2)

# def vectors_to_images(vectors):
#     return vectors.view(vectors.size(0), 1, 28, 28)

In [7]:
# The two adversaries.
discriminator = DiscriminatorNet()
generator = GeneratorNet()

# Run both networks on the GPU when one is present.
if torch.cuda.is_available():
    discriminator.cuda()
    generator.cuda()

# One Adam optimizer per network, both with the same learning rate.
d_optimizer = optim.Adam(params=discriminator.parameters(), lr=2e-4)
g_optimizer = optim.Adam(params=generator.parameters(), lr=2e-4)

# Binary cross-entropy between discriminator scores and real/fake targets.
loss = nn.BCELoss()

# Number of passes over the dataset.
num_epochs = 200

# Discriminator steps per generator step (1 in Goodfellow et al.).
d_steps = 1

In [8]:
def real_data_target(size):
    '''
    Build the target for samples labelled "real": a (size, 1) tensor of
    ones, placed on the GPU when CUDA is available.
    '''
    ones = torch.ones(size, 1)
    return ones.cuda() if torch.cuda.is_available() else ones

def fake_data_target(size):
    '''
    Build the target for samples labelled "fake": a (size, 1) tensor of
    zeros, placed on the GPU when CUDA is available.
    '''
    zeros = torch.zeros(size, 1)
    return zeros.cuda() if torch.cuda.is_available() else zeros

def train_discriminator(optimizer, real_data, fake_data):
    '''
    Run one discriminator update on a real batch and a fake batch.

    Args:
        optimizer: optimizer holding the discriminator's parameters.
        real_data: batch of real samples.
        fake_data: batch of generated samples; the training loop passes it
            detached so this step does not backpropagate into the generator.

    Returns:
        (error_real + error_fake, prediction_real, prediction_fake)
    '''
    # Reset gradients
    optimizer.zero_grad()

    # Propagate real data; the target is all ones
    prediction_real = discriminator(real_data)
    error_real = loss(prediction_real, real_data_target(prediction_real.size(0)))
    error_real.backward()

    # Propagate fake data; the target is all zeros.  The target is sized
    # from the fake predictions themselves, so the two batches are not
    # required to have the same length.
    prediction_fake = discriminator(fake_data)
    error_fake = loss(prediction_fake, fake_data_target(prediction_fake.size(0)))
    error_fake.backward()

    # Take a step using the gradients accumulated by both backward passes
    optimizer.step()

    # Return error
    return error_real + error_fake, prediction_real, prediction_fake


def train_generator(optimizer, fake_data):
    '''
    Run one generator update.

    The discriminator scores the generated batch and the loss is taken
    against the "real" target: the generator improves when the
    discriminator is pushed towards outputting 1 for its samples.

    Args:
        optimizer: optimizer holding the generator's parameters.
        fake_data: generated batch, still attached to the generator's graph.

    Returns:
        The generator loss for this step.
    '''
    optimizer.zero_grad()

    scores = discriminator(fake_data)
    wanted = real_data_target(scores.size(0))
    g_loss = loss(scores, wanted)

    # Backpropagate through the discriminator into the generator, then
    # update the weights held by the optimizer.
    g_loss.backward()
    optimizer.step()

    return g_loss

In [None]:
# Fixed noise, reused at every preview so the scatter plots are comparable.
num_test_samples = 10000
test_noise = noise(num_test_samples)

# NOTE(review): data_name is still 'MNIST' although the data is the Snake
# dataset; left unchanged because it presumably determines where Logger
# writes its output - confirm before renaming.
logger = Logger(model_name='VGAN', data_name='MNIST')

for epoch in range(num_epochs):
    for n_batch, real_batch in enumerate(data_loader):

        # Train discriminator on a real batch and a fake batch.
        # The fake batch is detached so this step does not touch the generator.
        real_data = real_batch
        if torch.cuda.is_available(): real_data = real_data.cuda()
        fake_data = generator(noise(real_data.size(0))).detach()
        d_error, d_pred_real, d_pred_fake = train_discriminator(d_optimizer,
                                                                real_data, fake_data)

        # Train generator on a fresh fake batch that keeps its graph

        fake_data = generator(noise(real_data.size(0)))
        g_error = train_generator(g_optimizer, fake_data)

        # Log errors and display progress

        logger.log(d_error, g_error, epoch, n_batch, num_batches)
        if n_batch % 100 == 0:
            # Preview the generator on the fixed noise.  No gradients are
            # needed here, and the result has to be moved to the CPU before
            # converting to NumPy (Tensor.numpy() rejects CUDA tensors).
            with torch.no_grad():
                test_images = generator(test_noise)
            test_images = test_images.detach().cpu().numpy()
            print(test_images.shape)
            plt.scatter(test_images[:,0], test_images[:,1])
            plt.show()
            # Display status Logs
            logger.display_status(
                epoch, num_epochs, n_batch, num_batches,
                d_error, g_error, d_pred_real, d_pred_fake
            )

    # Save model checkpoints once per epoch instead of after every batch.
    # save_models only receives the epoch index, so the per-batch calls
    # presumably rewrote the same checkpoint - confirm against Logger.
    logger.save_models(generator, discriminator, epoch)