In [88]:
import numpy as np
from torch import Tensor, randn, ones, zeros, optim, save
from torch.nn import Module, Sequential, Linear, ReLU, Sigmoid, Tanh, Dropout, BCELoss
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchvision import transforms
from sklearn.preprocessing import MinMaxScaler

In [41]:
def loadData(path):
    data = np.genfromtxt(path, delimiter=',')
    scaler = MinMaxScaler(feature_range=(-1,1))
    data = scaler.fit_transform(data)
    return Tensor(data)
        
def generateNoise(N):
    noise = Variable(randn(N, 100))
    return noise

def createBatches(data):
    batches = DataLoader(data, batch_size=100, shuffle=True)
    return batches

In [27]:
class Generator(Module):
    def __init__(self):
        super(Generator, self).__init__()
        d_input = 100
        d_output = 28*28
        
        self.input = Sequential(
            Linear(d_input, 256),
            ReLU()
        )
        self.hidden1 = Sequential(            
            Linear(256, 512),
            ReLU()
        )
        self.hidden2 = Sequential(
            Linear(512, 1024),
            ReLU()
        )
        self.output = Sequential(
            Linear(1024, d_output),
            Tanh()
        )
        
    def forward(self, x):
        x = self.input(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.output(x)
        return x

In [142]:
class Discriminator(Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        d_input = 28*28
        d_output = 1
        
        self.input = Sequential( 
            Linear(d_input, 1024),
            ReLU(),
            Dropout(0.2)
        )
        self.hidden1 = Sequential(
            Linear(1024, 512),
            ReLU(),
            Dropout(0.2)
        )
        self.hidden2 = Sequential(
            Linear(512, 256),
            ReLU(),
            Dropout(0.2)
        )
        self.output = Sequential(
            Linear(256, d_output),
            Sigmoid()
        )
    def forward(self, x):
        x = self.input(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.output(x)
        return x
        

In [29]:
""" 
Train Generator 

"""
def train_d(optim, trueData, falseData):
    trueN = trueData.size(0)
    falseN = falseData.size(0)
        
    # Train on true data
    optim.zero_grad()
    truePred = discriminator(trueData)
    trueError = loss(truePred, Variable(ones(trueN, 1)))
    trueError.backward()

    # Train on generated data
    falsePred = discriminator(falseData)
    falseError = loss(falsePred, Variable(zeros(falseN, 1)))
    falseError.backward()
    
    optim.step()

    return trueError + falseError

""" 
Train Generator based on Discriminators response 

"""
def train_g(optim, falseData):
    falseN = falseData.size(0)
    optim.zero_grad()

    pred = discriminator(falseData) # Response from Discriminator
    error = loss(pred, Variable(ones(falseN, 1))) # Pretending that the false data is true data
    error.backward()
    optim.step()
    
    return error

In [14]:
data = loadData('fashion_mnist.csv')
batches = createBatches(data)

In [144]:
epochs = 200
generator = Generator()
discriminator = Discriminator()

g_optim = optim.Adam(generator.parameters(), lr=2e-4)
d_optim = optim.Adam(discriminator.parameters(), lr=2e-4)
loss = BCELoss()

errors_G = []
errors_D = []

print('-------- Training models --------')
for epoch in range(1, epochs):
    for n_batch, batch in enumerate(batches):
        # Train Discriminator
        trueData = Variable(batch)
        falseData = generator(generateNoise(batch.size(0))).detach() # detach to not calculate gradients
        error_d = train_d(d_optim, trueData, falseData)

        # Train Generator
        falseData = generator(generateNoise(batch.size(0)))
        error_g = train_g(g_optim, falseData)
    
    print('Epoch', epoch)
    print('Generator Error:', error_g.data[0])
    print('Discriminator Error:', error_d.data[0])
    print()
    
    errors_G.append(error_g.data[0])
    errors_D.append(error_d.data[0])
print('-------- Finished training --------')   

save(generator.state_dict(), 'generator.pt')
save(discriminator.state_dict(), 'discriminator.pt')

figG, axG = plt.subplots()
axG.plot(range(1, epochs), errors_G)
axG.set_xlabel('epoch')
axG.set_ylabel('error')
axG.set_title('Generator')
figG.savefig('errorG.png')


figD, axD = plt.subplots()
axD.plot(range(1, epochs), errors_D)
axD.set_xlabel('epoch')
axD.set_ylabel('error')
axD.set_title('Discriminator')
figD.savefig('errorD.png')


-------- Training models --------
Epoch 1
Generator Error: 6.585610866546631
Discriminator Error: 0.06657128781080246

Epoch 2
Generator Error: 6.290902614593506
Discriminator Error: 0.21591618657112122

Epoch 3
Generator Error: 4.079988479614258
Discriminator Error: 0.1162613108754158

Epoch 4
Generator Error: 4.1981401443481445
Discriminator Error: 0.2441917061805725

Epoch 5
Generator Error: 5.203634738922119
Discriminator Error: 0.25047945976257324

Epoch 6
Generator Error: 4.225180149078369
Discriminator Error: 0.28376469016075134

Epoch 7
Generator Error: 2.705981731414795
Discriminator Error: 0.2871397137641907

Epoch 8
Generator Error: 3.551867723464966
Discriminator Error: 0.3300721347332001

Epoch 9
Generator Error: 3.054715633392334
Discriminator Error: 0.3615506887435913

Epoch 10
Generator Error: 2.5832839012145996
Discriminator Error: 0.506671667098999

Epoch 11
Generator Error: 2.673696279525757
Discriminator Error: 0.42720502614974976

Epoch 12
Generator Error: 2.605177

Epoch 97
Generator Error: 1.2624363899230957
Discriminator Error: 1.0272928476333618

Epoch 98
Generator Error: 1.1193019151687622
Discriminator Error: 1.018295168876648

Epoch 99
Generator Error: 1.2085599899291992
Discriminator Error: 1.0281447172164917

Epoch 100
Generator Error: 1.174302101135254
Discriminator Error: 1.0751324892044067

Epoch 101
Generator Error: 1.2388075590133667
Discriminator Error: 0.9954402446746826

Epoch 102
Generator Error: 1.0610524415969849
Discriminator Error: 0.9604381322860718

Epoch 103
Generator Error: 1.0269018411636353
Discriminator Error: 1.0476651191711426

Epoch 104
Generator Error: 1.211050033569336
Discriminator Error: 1.0557975769042969

Epoch 105
Generator Error: 1.0241918563842773
Discriminator Error: 1.1133673191070557

Epoch 106
Generator Error: 1.1918383836746216
Discriminator Error: 1.141908884048462

Epoch 107
Generator Error: 1.148067831993103
Discriminator Error: 1.0034575462341309

Epoch 108
Generator Error: 1.171406626701355
Discri

Epoch 192
Generator Error: 1.2430617809295654
Discriminator Error: 0.9720213413238525

Epoch 193
Generator Error: 1.3644235134124756
Discriminator Error: 1.0251013040542603

Epoch 194
Generator Error: 1.3536027669906616
Discriminator Error: 0.8513145446777344

Epoch 195
Generator Error: 1.3220752477645874
Discriminator Error: 0.8891733288764954

Epoch 196
Generator Error: 1.3438664674758911
Discriminator Error: 0.867490828037262

Epoch 197
Generator Error: 1.2495372295379639
Discriminator Error: 0.9362081289291382

Epoch 198
Generator Error: 1.386513113975525
Discriminator Error: 1.0462803840637207

Epoch 199
Generator Error: 1.3919624090194702
Discriminator Error: 0.9366355538368225

-------- Finished training --------


In [1]:
generator = Generator()
model.load_state_dict(torch.load('generator.pt'))

testData = generator(generateNoise(100))

for image in testData:
    image = image.view(28,28).data
    plt.imshow(image, cmap='Greys')
    plt.axis('off')
    plt.show()

NameError: name 'Generator' is not defined