In [None]:
!pip install wandb
import numpy as np
import wandb
from wandb.keras import WandbCallback

from keras.models import Sequential
from keras.layers import LSTM, TimeDistributed, RepeatVector, Dense

# Start a Weights & Biases run (called with no arguments here).
wandb.init()
# Short alias for the wandb config object; hyperparameters are set on it as attributes.
config = wandb.config



[34m[1mwandb[0m: Currently logged in as: [33mjotaro[0m (use `wandb login --relogin` to force relogin)


In [None]:
class CharacterTable(object):
    """Bidirectional mapping between characters and one-hot rows.

    The vocabulary is the sorted set of the characters passed to the
    constructor; a character's position in that sorted list is its index.
    """

    def __init__(self, chars):
        """Build the vocabulary and both lookup tables from `chars`."""
        # Sorted, de-duplicated vocabulary.
        self.chars = sorted(set(chars))
        self.char_indices = {ch: idx for idx, ch in enumerate(self.chars)}
        self.indices_char = {idx: ch for idx, ch in enumerate(self.chars)}

    def encode(self, C, num_rows):
        """One-hot encode the string `C`.

        Returns a float array of shape (num_rows, vocabulary size). Row k has
        a single 1.0 in the column of the k-th character of `C`; rows beyond
        the end of `C` stay all zero.
        """
        one_hot = np.zeros((num_rows, len(self.chars)))
        for row, ch in enumerate(C):
            column = self.char_indices[ch]
            one_hot[row, column] = 1.0
        return one_hot

    def decode(self, x, calc_argmax=True):
        """Turn `x` back into a string.

        With `calc_argmax` true, `x` is reduced with an argmax over its last
        axis first; otherwise `x` is iterated as a sequence of vocabulary
        indices directly.
        """
        indices = x.argmax(axis=-1) if calc_argmax else x
        return ''.join(self.indices_char[idx] for idx in indices)


In [None]:
# Parameters for the model and dataset, kept in one place on the wandb config object.
# Each value is assigned exactly once: a second set of assignments placed after these
# would silently override them, and with digits=1 there are fewer than 200 distinct
# operand pairs, so a training_size of 200 could never be reached by the generator.
config.training_size = 4000  # number of unique questions to generate
config.digits = 2            # maximum number of digits in each operand
config.hidden_size = 100     # units in each LSTM layer
config.batch_size = 100      # samples per training batch

In [None]:
# Every character that may appear in a question or an answer; the blank is used as padding.
chars = '0123456789+- '
# Longest possible question: two operands of up to `config.digits` digits plus the operator.
maxlen = 2 * config.digits + 1
ctable = CharacterTable(chars)

In [None]:
# Accumulators for the generated dataset.
questions, expected = [], []  # padded question strings and their padded answers
seen = set()                  # operand pairs already emitted, used to skip duplicates

In [None]:
print('generating data')


def _random_operand(max_digits):
    """Return a random non-negative int built from 1..max_digits random decimal digits."""
    n_digits = np.random.randint(1, max_digits + 1)
    return int(''.join(np.random.choice(list('0123456789')) for _ in range(n_digits)))


# Start from a clean slate so re-running this cell rebuilds the dataset from the current config.
questions, expected, seen = [], [], set()

# Operands lie in 0 .. 10**digits - 1, so only a finite number of distinct questions exists.
# Fail fast instead of looping forever when the config asks for more than that.
n_operands = 10 ** config.digits
max_questions = n_operands ** 2
if config.training_size > max_questions:
    raise ValueError(
        'training_size={} exceeds the {} distinct questions available with digits={}'.format(
            config.training_size, max_questions, config.digits))

while len(questions) < config.training_size:
    a, b = _random_operand(config.digits), _random_operand(config.digits)
    # Subtraction is not commutative: 'a-b' and 'b-a' are different questions with
    # different answers, so duplicates are detected on the ordered pair.
    key = (a, b)
    if key in seen:
        continue
    seen.add(key)
    # Pad the question and the answer with blanks to their fixed lengths.
    q = '{}-{}'.format(a, b)
    query = q + ' ' * (maxlen - len(q))
    ans = str(a - b)
    ans += ' ' * (config.digits + 1 - len(ans))
    questions.append(query)
    expected.append(ans)


generating data


In [None]:
# The generated dataset consists of subtraction questions, so label the count accordingly.
print('Total subtraction questions:', len(questions))


Total addition questions: 4000


In [None]:
print('Vectorization...')
# One-hot tensors: x is (n_questions, maxlen, len(chars)), y is (n_questions, digits + 1, len(chars)).
# The builtin `bool` is used as dtype because the `np.bool` alias was deprecated in
# NumPy 1.20 and removed in 1.24, where it raises AttributeError.
x = np.zeros((len(questions), maxlen, len(chars)), dtype=bool)
y = np.zeros((len(questions), config.digits + 1, len(chars)), dtype=bool)
for i, sentence in enumerate(questions):
    x[i] = ctable.encode(sentence, maxlen)
for i, sentence in enumerate(expected):
    y[i] = ctable.encode(sentence, config.digits + 1)


Vectorization...


In [None]:
# Hold out the final tenth of the samples as validation data that is never trained on.
split_at = len(x) - len(x) // 10
x_train, y_train = x[:split_at], y[:split_at]
x_val, y_val = x[split_at:], y[split_at:]

In [None]:
# Encoder-decoder network: question characters in, answer characters out.
model = Sequential([
    # Encoder: consumes the one-hot question, input_shape = (maxlen timesteps, len(chars) features).
    LSTM(config.hidden_size, input_shape=(maxlen, len(chars))),
    # Feed the encoder output to the decoder once per answer position.
    RepeatVector(config.digits + 1),
    # Decoder: return_sequences=True keeps one output per answer position.
    LSTM(config.hidden_size, return_sequences=True),
    # Softmax over the vocabulary, applied independently at every answer position.
    TimeDistributed(Dense(len(chars), activation='softmax')),
])
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
lstm (LSTM)                  (None, 100)               45600     
_________________________________________________________________
repeat_vector (RepeatVector) (None, 3, 100)            0         
_________________________________________________________________
lstm_1 (LSTM)                (None, 3, 100)            80400     
_________________________________________________________________
time_distributed (TimeDistri (None, 3, 13)             1313      
Total params: 127,313
Trainable params: 127,313
Non-trainable params: 0
_________________________________________________________________


In [None]:
# Train one epoch per iteration so that a few validation predictions can be printed after each epoch.
for iteration in range(1, 200):
    print()
    print('-' * 50)
    print('Iteration', iteration)
    model.fit(x_train, y_train,
              batch_size=config.batch_size,
              epochs=1,
              validation_data=(x_val, y_val),
              callbacks=[WandbCallback()])
    # Select 10 samples from the validation set at random so we can visualize
    # errors.
    for i in range(10):
        ind = np.random.randint(0, len(x_val))
        # Index with a one-element array to keep the leading batch axis.
        rowx, rowy = x_val[np.array([ind])], y_val[np.array([ind])]
        # `Sequential.predict_classes` was removed in TensorFlow/Keras 2.6; the argmax over
        # the predicted class probabilities yields the same class indices.
        preds = np.argmax(model.predict(rowx, verbose=0), axis=-1)
        q = ctable.decode(rowx[0])
        correct = ctable.decode(rowy[0])
        # `preds` already holds class indices, so skip the argmax inside decode.
        guess = ctable.decode(preds[0], calc_argmax=False)
        print('Q', q, end=' ')
        print('A', correct, end=' ')
        if correct == guess:
            print('✅', end=' ')
        else:
            print('❌', end=' ')
        print(guess)


--------------------------------------------------
Iteration 1




Q 31-47 A -16 ❌    
Q 54-33 A 21  ❌    
Q 38-80 A -42 ❌    
Q 60-63 A -3  ❌    
Q 57-22 A 35  ❌    
Q 75-16 A 59  ❌    
Q 46-90 A -44 ❌    
Q 64-12 A 52  ❌    
Q 33-42 A -9  ❌    
Q 37-94 A -57 ❌    

--------------------------------------------------
Iteration 2
Q 65-17 A 48  ❌ -4 
Q 43-10 A 33  ❌ -4 
Q 50-34 A 16  ❌ -4 
Q 51-14 A 37  ❌ -4 
Q 17-46 A -29 ❌ -4 
Q 41-68 A -27 ❌ -4 
Q 67-66 A 1   ❌ -4 
Q 24-90 A -66 ❌ -4 
Q 51-14 A 37  ❌ -4 
Q 64-12 A 52  ❌ -4 

--------------------------------------------------
Iteration 3
Q 54-54 A 0   ❌ -1 
Q 26-74 A -48 ❌ -1 
Q 79-81 A -2  ❌ -1 
Q 43-58 A -15 ❌ -1 
Q 51-14 A 37  ❌ -1 
Q 52-36 A 16  ❌ -1 
Q 71-91 A -20 ❌ -1 
Q 87-79 A 8   ❌ -1 
Q 87-61 A 26  ❌ -1 
Q 11-12 A -1  ✅ -1 

--------------------------------------------------
Iteration 4
Q 48-99 A -51 ❌ -1 
Q 11-12 A -1  ❌ -  
Q 27-42 A -15 ❌ -1 
Q 66-56 A 10  ❌ -  
Q 93-78 A 15  ❌ -1 
Q 47-97 A -50 ❌ -1 
Q 52-11 A 41  ❌ -  
Q 65-80 A -15 ❌ -1 
Q 12-48 A -36 ❌ -1 
Q 32-60 A -28 ❌ -1 

-------

In [None]:
# Display the one-hot encoding of a single validation question.
# NOTE(review): `ind` is not set in this cell; it is whatever value the sampling loop
# in the training cell assigned last, so this cell depends on that cell having run.
_x = x_val[np.array([ind])]
_x

array([[[False, False, False, False,  True, False, False, False, False,
         False, False, False, False],
        [False, False, False, False, False, False, False, False,  True,
         False, False, False, False],
        [False, False,  True, False, False, False, False, False, False,
         False, False, False, False],
        [False, False, False, False, False, False, False, False, False,
          True, False, False, False],
        [False, False, False, False, False, False,  True, False, False,
         False, False, False, False]]])

In [None]:
# Run the model on the first two encoded questions.
# NOTE(review): this rebinds `y`, which until now held the one-hot targets; re-running the
# train/validation split cell after this one would slice these predictions instead.
y = model.predict(x[:2])

In [None]:
# Decode the first row of `y` (argmax per position) back into a string.
ctable.decode(y[0])

'96 '

In [None]:
# Decode the first encoded question back into its padded text form.
ctable.decode(x[0])

'97-1 '