In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

In [3]:
data_mean = 4
data_stddev = 1.25

In [4]:
g_input_size = 1
g_hidden_size = 50
g_output_size = 1
d_input_size = 100
d_hidden_size = 50
d_output_size = 1
minibatch_size = d_input_size

In [5]:
d_learning_rate = 2e-4
g_learning_rate = 2e-4
optim_betas = (0.9, 0.999)
num_epochs = 30000
print_interval = 200
d_steps = 1
g_steps = 1

In [6]:
# dataからdata[0]を引いて2乗したリストを結合している？
# 何のために？
def decorate_with_diffs(data, exponent):
    # data.data = Variable => Tensor
    mean = torch.mean(data.data, 1, keepdim=True)
    mean_broadcast = torch.mul(torch.ones(data.size()), mean.tolist()[0][0])
    diffs = torch.pow(data - Variable(mean_broadcast), exponent)
    return torch.cat([data, diffs], 1)

In [7]:
before = Variable(torch.rand(10, 1))
after = decorate_with_diffs(before, 2.0)
print(before)
print(after)

Variable containing:
 0.1432
 0.1751
 0.5798
 0.3055
 0.5783
 0.4844
 0.7658
 0.1784
 0.3332
 0.0914
[torch.FloatTensor of size 10x1]

Variable containing:
 0.1432  0.0000
 0.1751  0.0010
 0.5798  0.1906
 0.3055  0.0263
 0.5783  0.1892
 0.4844  0.1164
 0.7658  0.3876
 0.1784  0.0012
 0.3332  0.0361
 0.0914  0.0027
[torch.FloatTensor of size 10x2]



In [8]:
name = 'Data and variances'
preprocess = lambda data: decorate_with_diffs(data, 2.0)
d_input_func = lambda x: x * 2

In [9]:
print('Using data [%s]' % name)

Using data [Data and variances]


In [10]:
def get_distribution_sampler(mu, sigma):
    return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n)))

In [11]:
d_sampler = get_distribution_sampler(data_mean, data_stddev)

In [12]:
d_sampler(10)


 5.0650  7.6520  3.2748  3.9452  4.0405  3.3511  3.8134  5.8915  5.2322  3.8391
[torch.FloatTensor of size 1x10]

In [13]:
def get_generator_input_sampler():
    return lambda m, n: torch.rand(m, n)

In [14]:
gi_sampler = get_generator_input_sampler()

In [15]:
gi_sampler(3, 2)


 0.0010  0.9491
 0.6421  0.1529
 0.5812  0.0512
[torch.FloatTensor of size 3x2]

In [16]:
class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        x = F.elu(self.map1(x))
        x = F.sigmoid(self.map2(x))
        return self.map3(x)

In [17]:
G = Generator(input_size=g_input_size, hidden_size=g_hidden_size, output_size=g_output_size)

In [18]:
G

Generator (
  (map1): Linear (1 -> 50)
  (map2): Linear (50 -> 50)
  (map3): Linear (50 -> 1)
)

In [19]:
class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        # 本物画像のとき1、偽物画像のとき0 = 本物画像の確率を出力
        x = F.elu(self.map1(x))
        x = F.elu(self.map2(x))
        return F.sigmoid(self.map3(x))

In [20]:
D = Discriminator(input_size=d_input_func(d_input_size), hidden_size=d_hidden_size, output_size=d_output_size)

In [21]:
D

Discriminator (
  (map1): Linear (200 -> 50)
  (map2): Linear (50 -> 50)
  (map3): Linear (50 -> 1)
)

In [22]:
criterion = nn.BCELoss()

In [23]:
d_optimizer = optim.Adam(D.parameters(), lr=d_learning_rate, betas=optim_betas)

In [24]:
g_optimizer = optim.Adam(G.parameters(), lr=g_learning_rate, betas=optim_betas)

In [25]:
d_optimizer

<torch.optim.adam.Adam at 0x10b562390>

In [39]:
def extract(v):
    return v.data.storage().tolist()

def stats(d):
    return [np.mean(d), np.std(d)]

In [43]:
for epoch in range(num_epochs):
    for d_index in range(d_steps):
        # 1. Train D on real + fake
        D.zero_grad()
        
        # 1A. Train D on real
        # 指定した平均・標準偏差のランダムデータを生成
        # Gがこの分布を学習するのが最終目的
        d_real_data = Variable(d_sampler(d_input_size))
        d_real_decision = D(preprocess(d_real_data))
        # 本物画像を入力したときは1を出力する
        d_real_error = criterion(d_real_decision, Variable(torch.ones(1)))
        d_real_error.backward()

        # 1B. Train D on fake
        d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        d_fake_data = G(d_gen_input).detach()
        d_fake_decision = D(preprocess(d_fake_data.t()))
        d_fake_error = criterion(d_fake_decision, Variable(torch.zeros(1)))
        d_fake_error.backward()
        d_optimizer.step()  # update parameters

    for g_index in range(g_steps):
        G.zero_grad()
        
        gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        g_fake_data = G(gen_input)
        dg_fake_decision = D(preprocess(g_fake_data.t()))
        # Gは騙したいので本物=1に分類されるように学習する
        g_error = criterion(dg_fake_decision, Variable(torch.ones(1)))
        g_error.backward()
        g_optimizer.step()
    
    if epoch % print_interval == 0:
        print('%s: D: %s/%s G: %s (Real: %s, Fake: %s)' % (
            epoch,
            extract(d_real_error)[0],
            extract(d_fake_error)[0],
            extract(g_error)[0],
            stats(extract(d_real_data)),
            stats(extract(d_fake_data))))

  "Please ensure they have the same size.".format(target.size(), input.size()))


0: D: 0.05512115731835365/0.4082205593585968 G: 1.105771541595459 (Real: [4.1565990126132961, 1.2711233205301873], Fake: [0.67944827854633327, 0.038399155626412371])
200: D: 0.0028105496894568205/0.1890818178653717 G: 1.7559853792190552 (Real: [3.9789160776138304, 1.3551562150017891], Fake: [0.4557408770918846, 0.045124946576745199])
400: D: 0.004272036254405975/0.08585026860237122 G: 2.6675233840942383 (Real: [3.7743640394695102, 1.2353088212241417], Fake: [0.28140656068921088, 0.10168814482323617])
600: D: 6.556531388923759e-06/0.21870669722557068 G: 1.951550006866455 (Real: [3.9861805975437163, 1.3794096915423726], Fake: [0.051964230462908746, 0.28537573942563749])
800: D: 0.0017126555321738124/0.029417188838124275 G: 3.291212320327759 (Real: [4.0462950253486634, 1.0573092765582288], Fake: [0.96370529353618617, 0.40037159610085621])
1000: D: 0.06592345237731934/0.48201489448547363 G: 1.9867020845413208 (Real: [3.8995013171434403, 1.1036350894963858], Fake: [1.7540165454521774, 1.019

KeyboardInterrupt: 