In [1]:
import torch
from torch import nn, optim
from torch.autograd.variable import Variable
#from torchvision import transforms, datasets
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

In [31]:
class Generator(torch.nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        
        self.encoder0 = nn.Sequential(
            nn.Conv2d(13, 64, (4,4), stride=(2,2), padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(negative_slope=0.2)
        )
        
        self.encoder1 = nn.Sequential(
            nn.Conv2d(64, 128, (4,4), stride=(2,2), padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(negative_slope=0.2)
        )
        
        self.hidden0 = nn.Sequential(
            nn.Conv2d(128, 512, (4,4), stride=(2,2), padding=1),
            nn.ReLU()
        )
        
        self.decoder0 = nn.Sequential(
            nn.ConvTranspose2d(512, 128, (4,4), stride=(2,2), padding=1),
            nn.BatchNorm2d(128),
            nn.Dropout()
            # CONCATENATE OUTPUT WITH OUTPUT OF ENC1
        )
        
        self.decoder1 = nn.Sequential(
            nn.ConvTranspose2d(256, 64, (4,4), stride=(2,2), padding=1), # was 128
            nn.BatchNorm2d(64),
            nn.Dropout()
            # CONCATENATE OUTPUT WITH OUTPUT OF ENC0
        )
        
        self.hidden1 = nn.Sequential(
            nn.ConvTranspose2d(128, 13, (4,4), stride=(2,2), padding=1),
            nn.Softmax()
        )
        
    def forward(self, x):
        print("Generator starting with:", x.shape)
        # Assume x is concatenated X,y
        # Only take [batch_size:13:8:8]
        x = x[:,:,:,:13].permute(0,3,1,2)
        x0 = x.clone().detach()
        print("Cut size:", x.shape)
        
        x1 = self.encoder0(x)
        print(x1.shape)
        x2 = self.encoder1(x1)
        print(x2.shape)
        x = self.hidden0(x2)
        print(x.shape)
        x = self.decoder0(x)
        print(x.shape)
        # concatenate output of decoder0 with output of encoder1
        x = torch.cat((x,x2), dim=1)
        print(x.shape)
        x = self.decoder1(x)
        print(x.shape)
        # concatenate output of decoder1 with output of encoder0
        x = torch.cat((x,x1), dim=1)
        print(x.shape)
        x = self.hidden1(x)
        print(x.shape)
        out_tensor = torch.cat((x0,x), dim=1)
        print("Generator output:", out_tensor.shape)
        return out_tensor

In [42]:
class Discriminator(torch.nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        self.hidden0 = nn.Sequential( 
            nn.Conv2d(26, 64, (4,4), stride=(2,2), padding=1),
            nn.LeakyReLU(negative_slope=0.2)
        )
        self.hidden1 = nn.Sequential(
            nn.Conv2d(64, 128, (4,4), stride=(2,2), padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(negative_slope=0.2)
        )
        self.hidden2 = nn.Sequential(
            nn.Conv2d(128, 256, (4,4), stride=(2,2), padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(negative_slope=0.2)
        )
        self.hidden3 = nn.Sequential(
            nn.Conv2d(256, 512, (4,4), stride=(2,2), padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(negative_slope=0.2)
        )
        self.hidden4 = nn.Sequential(
            nn.Conv2d(512, 512, (4,4), padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(negative_slope=0.2)
        )
        self.out = nn.Sequential(
            # Output is a single variable representing probability between 0 and 1 that the input is real
            nn.Conv2d(512, 1, (4,4), padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Note: assuming x is already concatenated in form [initial board, next board]."""
        
        print(x.shape)
        x = self.hidden0(x)
        print(x.shape)
        x = self.hidden1(x)
        print(x.shape)
        x = self.hidden2(x)
        print(x.shape)
        x = F.pad(x,(1,0,1,0))
        print(x.shape)
        x = self.hidden3(x)
        print(x.shape)
        x = F.pad(x,(1,0,1,0))
        print(x.shape)
        x = self.hidden4(x)
        print(x.shape)
        x = F.pad(x,(1,0,1,0))
        print(x.shape)
        x = self.out(x)
        print("Disc output:",x.shape)
        return x

In [45]:
discriminator = Discriminator()
generator = Generator()

# Optimizers
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002)
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002)

# Loss function
loss = nn.BCELoss()

# Number of steps to apply to the discriminator
d_steps = 1  # In Goodfellow et. al 2014 this variable is assigned to 1
# Number of epochs
num_epochs = 200
# Batch size
batch_size = 100

In [46]:
# changed torch.ones(size,1) to torch.ones(1,1) since output of model is a single number
def real_data_target(batch_size):
    '''
    Tensor containing ones, with shape = size
    '''
    data = Variable(torch.ones([batch_size,1,1,1]))
    if torch.cuda.is_available(): return data.cuda()
    return data

def fake_data_target(batch_size):
    '''
    Tensor containing zeros, with shape = size
    '''
    data = Variable(torch.zeros([batch_size,1,1,1]))
    if torch.cuda.is_available(): return data.cuda()
    return data

In [6]:
def gen_real_samples(X,y):
    """Returns a list of tensors in the form [initial board, next board]"""
    data = []
    for i in range(X.shape[0]):
        x_elt = torch.Tensor(X[i])
        y_elt = torch.Tensor(y[i])
        elt = torch.cat((x_elt, y_elt),dim=-1)
        data.append(elt)
    return data

In [7]:
# Create Dataset object
#   must inherit form torch.utils.data.Dataset
#   must override len(gamesDataSet) and gamesDataSet[i]
class gamesDataSet(Dataset):
    def __init__(self, data, transform=None):
        self.data = pd.DataFrame(data=data)
        #self.data.append([{"label": "fake", "data": d} for d in fake_data])
        
        self.transform = transform
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self,idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
            
            
        sample = self.data.iloc[idx,0]
        
        if self.transform:
            sample = self.transform(sample)
            
        return sample

# Load data

In [8]:
import os
import numpy as np
import pandas as pd
# utils is a python file in this folder with functions from the other jupyter notebooks
from utils import xy_split, generate_random_boards 

#all_data = generate_random_boards()

# load "real" data (in this case random) from the file in directory
with open('random_data.npy', 'rb') as f:
    all_data = np.load(f)

# split data into X and y sets (inputs and targets)
X,y = xy_split(all_data, True)

# Combine X and y into a dataset where each ith element is the merged ith element of X and y
realdata = gen_real_samples(X,y)
print(realdata[0].shape)
# Turn realdata into a gamesDataSet object to give it the interface of torch.utils.data.Dataset
# Essentially lets us do what we do in next cell
dataset = gamesDataSet(data=realdata)

183 183
torch.Size([8, 8, 26])


  values = np.array([convert(v) for v in values])
  values = np.array([convert(v) for v in values])


In [51]:
def train_discriminator(optimizer, real_data, fake_data):
    # Reset gradients
    optimizer.zero_grad()
    
    # 1.1 Train on Real Data
    prediction_real = discriminator(real_data.permute(0,3,1,2))
    # Calculate error and backpropagate
    error_real = loss(prediction_real, real_data_target(real_data.shape[0]))
    error_real.backward()

    # 1.2 Train on Fake Data
    print("Fake data:", fake_data.shape)
    prediction_fake = discriminator(fake_data)
    # Calculate error and backpropagate
    error_fake = loss(prediction_fake, fake_data_target(real_data.shape[0]))
    error_fake.backward()
    
    # 1.3 Update weights with gradients
    optimizer.step()
    
    # Return error
    return error_real + error_fake, prediction_real, prediction_fake

def train_generator(optimizer, fake_data):
    # 2. Train Generator
    # Reset gradients
    optimizer.zero_grad()
    # Sample noise and generate fake data
    prediction = discriminator(fake_data)
    # Calculate error and backpropagate
    error = loss(prediction, real_data_target(fake_data.shape[0]))
    error.backward()
    # Update weights with gradients
    optimizer.step()
    # Return error
    return error

In [52]:
# create dataloader object from dataset so that we can iterate thru batches & shuffle data
# https://pytorch.org/docs/stable/data.html
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
num_batches = len(data_loader)

In [53]:
#logger = Logger(model_name='VGAN', data_name='MNIST')

for epoch in range(num_epochs):
    for n_batch, real_batch in enumerate(data_loader):
        print(real_batch.shape)
        
        # 1. Train Discriminator
        real_data = Variable(real_batch)
        if torch.cuda.is_available(): real_data = real_data.cuda()
        
        # Generate fake data - form [initial board, next board]
        print("Real data:",real_data.shape)
        fake_data = generator(real_data).detach()
        
        # Train D
        d_error, d_pred_real, d_pred_fake = train_discriminator(d_optimizer,
                                                                real_data, fake_data)

        # 2. Train Generator
        # Generate fake data
        fake_data = generator(real_data)
        
        # Preprocess fake data (merge with previous board)
        # Train G
        g_error = train_generator(g_optimizer, fake_data)
        
        '''
        # Log error
        logger.log(d_error, g_error, epoch, n_batch, num_batches)
        
        
        # Display Progress
        if (n_batch) % 100 == 0:
            display.clear_output(True)
            # Display Images
            test_images = vectors_to_images(generator(test_noise)).data.cpu()
            logger.log_images(test_images, num_test_samples, epoch, n_batch, num_batches);
            # Display status Logs
            logger.display_status(
                epoch, num_epochs, n_batch, num_batches,
                d_error, g_error, d_pred_real, d_pred_fake
            )
        # Model Checkpoints
        logger.save_models(generator, discriminator, epoch)
        '''

torch.Size([100, 8, 8, 26])
Real data: torch.Size([100, 8, 8, 26])
Generator starting with: torch.Size([100, 8, 8, 26])
Cut size: torch.Size([100, 13, 8, 8])
torch.Size([100, 64, 4, 4])
torch.Size([100, 128, 2, 2])
torch.Size([100, 512, 1, 1])
torch.Size([100, 128, 2, 2])
torch.Size([100, 256, 2, 2])
torch.Size([100, 64, 4, 4])
torch.Size([100, 128, 4, 4])
torch.Size([100, 13, 8, 8])
Generator output: torch.Size([100, 26, 8, 8])
torch.Size([100, 26, 8, 8])
torch.Size([100, 64, 4, 4])
torch.Size([100, 128, 2, 2])
torch.Size([100, 256, 1, 1])
torch.Size([100, 256, 2, 2])
torch.Size([100, 512, 1, 1])
torch.Size([100, 512, 2, 2])
torch.Size([100, 512, 1, 1])
torch.Size([100, 512, 2, 2])
Disc output: torch.Size([100, 1, 1, 1])
Fake data: torch.Size([100, 26, 8, 8])
torch.Size([100, 26, 8, 8])
torch.Size([100, 64, 4, 4])
torch.Size([100, 128, 2, 2])
torch.Size([100, 256, 1, 1])
torch.Size([100, 256, 2, 2])
torch.Size([100, 512, 1, 1])
torch.Size([100, 512, 2, 2])
torch.Size([100, 512, 1, 1])

torch.Size([83, 8, 8, 26])
Real data: torch.Size([83, 8, 8, 26])
Generator starting with: torch.Size([83, 8, 8, 26])
Cut size: torch.Size([83, 13, 8, 8])
torch.Size([83, 64, 4, 4])
torch.Size([83, 128, 2, 2])
torch.Size([83, 512, 1, 1])
torch.Size([83, 128, 2, 2])
torch.Size([83, 256, 2, 2])
torch.Size([83, 64, 4, 4])
torch.Size([83, 128, 4, 4])
torch.Size([83, 13, 8, 8])
Generator output: torch.Size([83, 26, 8, 8])
torch.Size([83, 26, 8, 8])
torch.Size([83, 64, 4, 4])
torch.Size([83, 128, 2, 2])
torch.Size([83, 256, 1, 1])
torch.Size([83, 256, 2, 2])
torch.Size([83, 512, 1, 1])
torch.Size([83, 512, 2, 2])
torch.Size([83, 512, 1, 1])
torch.Size([83, 512, 2, 2])
Disc output: torch.Size([83, 1, 1, 1])
Fake data: torch.Size([83, 26, 8, 8])
torch.Size([83, 26, 8, 8])
torch.Size([83, 64, 4, 4])
torch.Size([83, 128, 2, 2])
torch.Size([83, 256, 1, 1])
torch.Size([83, 256, 2, 2])
torch.Size([83, 512, 1, 1])
torch.Size([83, 512, 2, 2])
torch.Size([83, 512, 1, 1])
torch.Size([83, 512, 2, 2])
Dis

torch.Size([100, 8, 8, 26])
Real data: torch.Size([100, 8, 8, 26])
Generator starting with: torch.Size([100, 8, 8, 26])
Cut size: torch.Size([100, 13, 8, 8])
torch.Size([100, 64, 4, 4])
torch.Size([100, 128, 2, 2])
torch.Size([100, 512, 1, 1])
torch.Size([100, 128, 2, 2])
torch.Size([100, 256, 2, 2])
torch.Size([100, 64, 4, 4])
torch.Size([100, 128, 4, 4])
torch.Size([100, 13, 8, 8])
Generator output: torch.Size([100, 26, 8, 8])
torch.Size([100, 26, 8, 8])
torch.Size([100, 64, 4, 4])
torch.Size([100, 128, 2, 2])
torch.Size([100, 256, 1, 1])
torch.Size([100, 256, 2, 2])
torch.Size([100, 512, 1, 1])
torch.Size([100, 512, 2, 2])
torch.Size([100, 512, 1, 1])
torch.Size([100, 512, 2, 2])
Disc output: torch.Size([100, 1, 1, 1])
Fake data: torch.Size([100, 26, 8, 8])
torch.Size([100, 26, 8, 8])
torch.Size([100, 64, 4, 4])
torch.Size([100, 128, 2, 2])
torch.Size([100, 256, 1, 1])
torch.Size([100, 256, 2, 2])
torch.Size([100, 512, 1, 1])
torch.Size([100, 512, 2, 2])
torch.Size([100, 512, 1, 1])

torch.Size([83, 8, 8, 26])
Real data: torch.Size([83, 8, 8, 26])
Generator starting with: torch.Size([83, 8, 8, 26])
Cut size: torch.Size([83, 13, 8, 8])
torch.Size([83, 64, 4, 4])
torch.Size([83, 128, 2, 2])
torch.Size([83, 512, 1, 1])
torch.Size([83, 128, 2, 2])
torch.Size([83, 256, 2, 2])
torch.Size([83, 64, 4, 4])
torch.Size([83, 128, 4, 4])
torch.Size([83, 13, 8, 8])
Generator output: torch.Size([83, 26, 8, 8])
torch.Size([83, 26, 8, 8])
torch.Size([83, 64, 4, 4])
torch.Size([83, 128, 2, 2])
torch.Size([83, 256, 1, 1])
torch.Size([83, 256, 2, 2])
torch.Size([83, 512, 1, 1])
torch.Size([83, 512, 2, 2])
torch.Size([83, 512, 1, 1])
torch.Size([83, 512, 2, 2])
Disc output: torch.Size([83, 1, 1, 1])
Fake data: torch.Size([83, 26, 8, 8])
torch.Size([83, 26, 8, 8])
torch.Size([83, 64, 4, 4])
torch.Size([83, 128, 2, 2])
torch.Size([83, 256, 1, 1])
torch.Size([83, 256, 2, 2])
torch.Size([83, 512, 1, 1])
torch.Size([83, 512, 2, 2])
torch.Size([83, 512, 1, 1])
torch.Size([83, 512, 2, 2])
Dis

torch.Size([100, 8, 8, 26])
Real data: torch.Size([100, 8, 8, 26])
Generator starting with: torch.Size([100, 8, 8, 26])
Cut size: torch.Size([100, 13, 8, 8])
torch.Size([100, 64, 4, 4])
torch.Size([100, 128, 2, 2])
torch.Size([100, 512, 1, 1])
torch.Size([100, 128, 2, 2])
torch.Size([100, 256, 2, 2])
torch.Size([100, 64, 4, 4])
torch.Size([100, 128, 4, 4])
torch.Size([100, 13, 8, 8])
Generator output: torch.Size([100, 26, 8, 8])
torch.Size([100, 26, 8, 8])
torch.Size([100, 64, 4, 4])
torch.Size([100, 128, 2, 2])
torch.Size([100, 256, 1, 1])
torch.Size([100, 256, 2, 2])
torch.Size([100, 512, 1, 1])
torch.Size([100, 512, 2, 2])
torch.Size([100, 512, 1, 1])
torch.Size([100, 512, 2, 2])
Disc output: torch.Size([100, 1, 1, 1])
Fake data: torch.Size([100, 26, 8, 8])
torch.Size([100, 26, 8, 8])
torch.Size([100, 64, 4, 4])
torch.Size([100, 128, 2, 2])
torch.Size([100, 256, 1, 1])
torch.Size([100, 256, 2, 2])
torch.Size([100, 512, 1, 1])
torch.Size([100, 512, 2, 2])
torch.Size([100, 512, 1, 1])

torch.Size([83, 8, 8, 26])
Real data: torch.Size([83, 8, 8, 26])
Generator starting with: torch.Size([83, 8, 8, 26])
Cut size: torch.Size([83, 13, 8, 8])
torch.Size([83, 64, 4, 4])
torch.Size([83, 128, 2, 2])
torch.Size([83, 512, 1, 1])
torch.Size([83, 128, 2, 2])
torch.Size([83, 256, 2, 2])
torch.Size([83, 64, 4, 4])
torch.Size([83, 128, 4, 4])
torch.Size([83, 13, 8, 8])
Generator output: torch.Size([83, 26, 8, 8])
torch.Size([83, 26, 8, 8])
torch.Size([83, 64, 4, 4])
torch.Size([83, 128, 2, 2])
torch.Size([83, 256, 1, 1])
torch.Size([83, 256, 2, 2])
torch.Size([83, 512, 1, 1])
torch.Size([83, 512, 2, 2])
torch.Size([83, 512, 1, 1])
torch.Size([83, 512, 2, 2])
Disc output: torch.Size([83, 1, 1, 1])
Fake data: torch.Size([83, 26, 8, 8])
torch.Size([83, 26, 8, 8])
torch.Size([83, 64, 4, 4])
torch.Size([83, 128, 2, 2])
torch.Size([83, 256, 1, 1])
torch.Size([83, 256, 2, 2])
torch.Size([83, 512, 1, 1])
torch.Size([83, 512, 2, 2])
torch.Size([83, 512, 1, 1])
torch.Size([83, 512, 2, 2])
Dis

KeyboardInterrupt: 