Credit : https://github.com/ParikhKadam/bidaf-keras

In [1]:
import tensorflow as tf
import json
import pandas as pd
import numpy as np
from tqdm import tqdm_notebook
import subprocess as sp
import os

In [2]:
SET_FLOAT_DTYPE = 'float32'

if SET_FLOAT_DTYPE == 'float32':
    TF_DATATYPE = tf.float32
else:
    TF_DATATYPE = tf.float16
    

In [3]:
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
    # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
    # Memory growth must be set before GPUs have been initialized
        print(e)

1 Physical GPUs, 1 Logical GPUs


In [4]:
tf.__version__

'2.0.0'

In [5]:
# Credit to : https://www.kaggle.com/sanjay11100/squad-stanford-q-a-json-to-pandas-dataframe
# Modified to include first answer_start and text for dev set

def squad_json_to_dataframe_train(input_file_path, record_path = ['data','paragraphs','qas','answers'],
                           verbose = 1):
    """
    input_file_path: path to the squad json file.
    record_path: path to deepest level in json file default value is
    ['data','paragraphs','qas','answers']
    verbose: 0 to suppress it default is 1
    """
    if verbose:
        print("Reading the json file")    
    file = json.loads(open(input_file_path).read())
    if verbose:
        print("processing...")
    # parsing different level's in the json file
    js = pd.io.json.json_normalize(file , record_path )
    m = pd.io.json.json_normalize(file, record_path[:-1] )
    r = pd.io.json.json_normalize(file,record_path[:-2])
    
    #combining it into single dataframe
    idx = np.repeat(r['context'].values, r.qas.str.len())
    ndx  = np.repeat(m['id'].values,m['answers'].str.len())
    m['context'] = idx
    js['q_idx'] = ndx
    main = pd.concat([ m[['id','question','context']].set_index('id'),js.set_index('q_idx')],1,sort=False).reset_index()
    main['c_id'] = main['context'].factorize()[0]
    if verbose:
        print("shape of the dataframe is {}".format(main.shape))
        print("Done")
    return main


def squad_json_to_dataframe_dev(input_file_path, record_path = ['data','paragraphs','qas','answers'],
                           verbose = 1):
    """
    input_file_path: path to the squad json file.
    record_path: path to deepest level in json file default value is
    ['data','paragraphs','qas','answers']
    verbose: 0 to suppress it default is 1
    """
    if verbose:
        print("Reading the json file")    
    file = json.loads(open(input_file_path).read())
    if verbose:
        print("processing...")
    # parsing different level's in the json file
    js = pd.io.json.json_normalize(file , record_path )
    m = pd.io.json.json_normalize(file, record_path[:-1] )
    r = pd.io.json.json_normalize(file,record_path[:-2])
    
    #combining it into single dataframe
    idx = np.repeat(r['context'].values, r.qas.str.len())
#     ndx  = np.repeat(m['id'].values,m['answers'].str.len())
    m['context'] = idx
#     js['q_idx'] = ndx
    main = m[['id','question','context','answers']].set_index('id').reset_index()
    main['c_id'] = main['context'].factorize()[0]
    answer_start = []
    answer_text = []

    for answers in tqdm_notebook(main['answers'].values):
        answer_start.append(answers[0]['answer_start'])
        answer_text.append(answers[0]['text'])

    main['answer_start'] = answer_start
    main['text'] = answer_text
    if verbose:
        print("shape of the dataframe is {}".format(main.shape))
        print("Done")
    return main

In [6]:
TRAIN_NUM_SAMPLES = 40000 #df_train.shape[0]
DEV_NUM_SAMPLES = 10000 #df_dev.shape[0]

if not os.path.isfile('./data/data.h5'):
    df_train = squad_json_to_dataframe_train('./data/train-v1.1.json')
    df_dev = squad_json_to_dataframe_dev('./data/dev-v1.1.json')
 
    df_train = df_train[:TRAIN_NUM_SAMPLES]
    df_dev = df_dev[:DEV_NUM_SAMPLES]
else:
    df_train = pd.read_hdf('./data/data.h5', 'train')
    df_dev = pd.read_hdf('./data/data.h5', 'dev')
    
idx = np.random.permutation(df_train.shape[0])
df_train = df_train.iloc[idx]

In [7]:
from tqdm import tqdm_notebook
from nltk import word_tokenize

if not os.path.isfile('./data/data.h5'):
    answer_start = []
    answer_end = []
    for i in tqdm_notebook(range(df_train.shape[0])): 
        context_split = word_tokenize(df_train.context.values[i][:df_train.answer_start.values[i]])
        answer_start.append(len(context_split))
        answer_end.append(len(context_split) + len(word_tokenize(df_train.text.values[i])) -1)
    df_train['answer_end'] = answer_end
    df_train['answer_start'] = answer_start

    answer_start = []
    answer_end = []
    for i in tqdm_notebook(range(df_dev.shape[0])): 
        context_split = word_tokenize(df_dev.context.values[i][:df_dev.answer_start.values[i]])
        answer_start.append(len(context_split))
        answer_end.append(len(context_split) + len(word_tokenize(df_dev.text.values[i])) -1)
    df_dev['answer_end'] = answer_end
    df_dev['answer_start'] = answer_start

    df_train.question = [word_tokenize(q) for q in tqdm_notebook(df_train.question.values)]
    df_train.context = [word_tokenize(q) for q in tqdm_notebook(df_train.context.values)]

    df_dev.question = [word_tokenize(q) for q in tqdm_notebook(df_dev.question.values)]
    df_dev.context = [word_tokenize(q) for q in tqdm_notebook(df_dev.context.values)]
    
    df_train.to_hdf('./data/data.h5', 'train')
    df_dev.to_hdf('./data/data.h5', 'dev')

In [8]:
from pymagnitude import MagnitudeUtils, Magnitude
#from scripts import MagnitudeVectors

#vectors = MagnitudeVectors(50).load_vectors()
vectors = Magnitude('./data/magnitude/glove.6B.100d.magnitude')

In [9]:
y_train = df_train.answer_start.values, df_train.answer_end.values
x_train = df_train.context.values, df_train.question.values

y_dev = df_dev.answer_start.values, df_dev.answer_end.values
x_dev = df_dev.context.values, df_dev.question.values

In [10]:
BATCH_SIZE = 8
num_batches_per_epoch_train = int(np.ceil(TRAIN_NUM_SAMPLES /float(BATCH_SIZE)))
num_batches_per_epoch_dev = int(np.ceil(DEV_NUM_SAMPLES /float(BATCH_SIZE))) 

In [11]:
import gc

def train_gen():
    for i in range(TRAIN_NUM_SAMPLES):
        if not (i % BATCH_SIZE):
            context_pad_length = max([len(t) for t in x_train[0][i:i+BATCH_SIZE]])
            question_pad_length = max([len(t) for t in x_train[1][i:i+BATCH_SIZE]])

        X_context_batch = vectors.query(x_train[0][i], pad_to_length = context_pad_length)
        X_question_batch = vectors.query(x_train[1][i], pad_to_length = question_pad_length)

        Y_start_batch = tf.keras.utils.to_categorical(y_train[0][i],context_pad_length)
        #Y_end_batch = tf.keras.utils.to_categorical(y_train[1][i],context_pad_length).reshape(1,-1)
        #Y_batch = np.concatenate([Y_start_batch, Y_end_batch])
        #Y_batch = np.concatenate([y_train[0][i], y_train[1][i]])
        
        yield ((tf.constant(X_context_batch, dtype = SET_FLOAT_DTYPE), tf.constant(X_question_batch, dtype = SET_FLOAT_DTYPE)), 
               tf.constant(Y_start_batch, dtype = SET_FLOAT_DTYPE)) #, y_train[1][i]],
        
        
def dev_gen():
    for i in range(DEV_NUM_SAMPLES):
        if not (i%BATCH_SIZE):
            context_pad_length = max([len(t) for t in x_dev[0][i:i+BATCH_SIZE]])
            question_pad_length = max([len(t) for t in x_dev[1][i:i+BATCH_SIZE]])

        X_context_batch = vectors.query(x_dev[0][i], pad_to_length = context_pad_length)
        X_question_batch = vectors.query(x_dev[1][i], pad_to_length = question_pad_length)

        #Y_start_batch = tf.keras.utils.to_categorical(y_dev[0][i],context_pad_length)
        #Y_end_batch =tf.keras.utils.to_categorical(y_dev[1][i],context_pad_length)
        
        yield ((tf.constant(X_context_batch, dtype = SET_FLOAT_DTYPE), tf.constant(X_question_batch, dtype = SET_FLOAT_DTYPE)), 
               tf.constant([y_dev[0][i]], dtype = SET_FLOAT_DTYPE)) #, y_dev[1][i]]
        
    

In [14]:
dataset_train = tf.data.Dataset.from_generator(train_gen, ((TF_DATATYPE, TF_DATATYPE), TF_DATATYPE)).batch(BATCH_SIZE).repeat().prefetch(1)
dataset_dev = tf.data.Dataset.from_generator(dev_gen, ((TF_DATATYPE, TF_DATATYPE), TF_DATATYPE)).batch(BATCH_SIZE).repeat().prefetch(1)


In [15]:
from tensorflow.keras import backend as K


K.set_floatx(SET_FLOAT_DTYPE)
#K.set_epsilon(1e-4)

K.clear_session()

In [16]:
def prepare_for_end_prob(inputs):
    encoded_context, merged_context, modeled_context, span_begin_probabilities = inputs
    weighted_sum = K.sum(K.expand_dims(span_begin_probabilities, axis=-1) * modeled_context, -2)
    passage_weighted_by_predicted_span = K.expand_dims(weighted_sum, axis=1)
    tile_shape = K.concatenate([[1], [K.shape(encoded_context)[1]], [1]], axis=0)
    passage_weighted_by_predicted_span = K.tile(passage_weighted_by_predicted_span, tile_shape)
    multiply1 = modeled_context * passage_weighted_by_predicted_span
    span_end_representation = K.concatenate(
            [merged_context, modeled_context, passage_weighted_by_predicted_span, multiply1])

    return span_end_representation


In [27]:
EMBED_LENGTH = 100
DROPOUT_RATE = 0.2

from tensorflow.keras.layers import Input, Add, LSTM, Bidirectional, Concatenate, TimeDistributed, Dense, Softmax, Flatten, Lambda, Multiply, Add, Dropout, SpatialDropout1D
from tensorflow.keras.models import Model 
from tensorflow.keras.optimizers import Adam, Adadelta
from tensorflow.keras.activations import linear
from layers import Similarity, C2QAttention, Q2CAttention, MergedContext, SpanBegin, SpanEnd, Highway

######## INPUT LAYER #########
context_input = Input(shape = (None, EMBED_LENGTH), dtype = SET_FLOAT_DTYPE, name = 'context_input')
question_input = Input(shape = (None, EMBED_LENGTH), dtype = SET_FLOAT_DTYPE, name = 'question_input')

#highway_layer = Highway(name='highway_1')                       
skip_conn_question = question_input
skip_conn_context = context_input

resnet_nonskip = Dense(EMBED_LENGTH, name = 'resnet-nonskip')
context_resnet = resnet_nonskip(context_input)
question_resnet = resnet_nonskip(question_input)

question_embedding = Add()([question_resnet, skip_conn_question])                       
context_embedding = Add()([context_resnet, skip_conn_context])

skip_conn_question = question_embedding
skip_conn_context = context_embedding

resnet_nonskip_2 = Dense(EMBED_LENGTH, name = 'resnet-nonskip-2')
context_resnet = resnet_nonskip_2(context_embedding)
question_resnet = resnet_nonskip_2(question_embedding)

question_embedding = Add()([question_resnet, skip_conn_question])                       
context_embedding = Add()([context_resnet, skip_conn_context])


#question_embedding = highway_layer(question_input)                       
#context_embedding = highway_layer(context_input)

#highway_layer_2 = Highway(name='highway_2')                       
#question_embedding = highway_layer_2(question_embedding)                       
#context_embedding = highway_layer_2(context_embedding)

#context_embedding = context_input
#question_embedding = question_input


######## CONTEXTUAL EMBEDDING LAYER ########
encoder_layer = Bidirectional(LSTM(EMBED_LENGTH, return_sequences=True, recurrent_initializer='glorot_uniform'))
encoded_question = encoder_layer(question_embedding)
encoded_context = encoder_layer(context_embedding)

######## SIMILARITY LAYER ########
similarity_matrix = Similarity(name='similarity_layer')([encoded_context, encoded_question])


####### ATTENTION LAYER #########
context_to_query_attention = C2QAttention(name='context_to_query_attention')([
           similarity_matrix, encoded_question])
query_to_context_attention = Q2CAttention(name='query_to_context_attention')([
            similarity_matrix, encoded_context])

###### MERGE ATTENTIONS ########
context_concat = K.concatenate([encoded_context, context_to_query_attention, query_to_context_attention], axis = -1)
merged_context = Dense(500, name = 'merged_context')(context_concat)

#merged_context = MergedContext(name='merged_context')(
#            [encoded_context, context_to_query_attention, query_to_context_attention])

###### MODELLING LAYER #########
modeled_context = Bidirectional(LSTM(EMBED_LENGTH,return_sequences=True, recurrent_initializer='glorot_uniform'), name='decoder')(merged_context)
modeled_context = SpatialDropout1D(0.2)(modeled_context)


span_begin_concat = K.concatenate([merged_context, modeled_context], axis = -1)
span_begin_weights = Dense(1, name = 'Dense_span_begin')(span_begin_concat)
span_begin_weights = Dropout(0.2)(span_begin_weights)
span_begin_probabilities = Softmax(name = 'span-begin-output')(K.squeeze(span_begin_weights, axis=-1))

#span_end_representation = Lambda(prepare_for_end_prob)((encoded_context, merged_context, modeled_context, span_begin_probabilities))
#span_end_representation = Bidirectional(LSTM(EMBED_LENGTH, return_sequences=True, recurrent_initializer='glorot_uniform'), name='output_end_prob_decoder')(span_end_representation)
#span_end_input = K.concatenate([merged_context, span_end_representation])
#span_end_weights = TimeDistributed(Dense(1), name = 'Dense_span_end')(span_end_input)
#span_end_weights = Dropout(0.2)(span_end_weights)
#span_end_probabilities = Softmax(name = 'span-end-output')(K.squeeze(span_end_weights, axis=-1))

#output = Lambda(lambda x : K.stack([x[0], x[1]], axis = 1), name = 'output', dtype = 'float32')([span_begin_probabilities, span_end_probabilities])

model = Model([context_input, question_input], span_begin_probabilities)

In [28]:
model.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
context_input (InputLayer)      [(None, None, 100)]  0                                            
__________________________________________________________________________________________________
question_input (InputLayer)     [(None, None, 100)]  0                                            
__________________________________________________________________________________________________
resnet-nonskip (Dense)          (None, None, 100)    10100       context_input[0][0]              
                                                                 question_input[0][0]             
__________________________________________________________________________________________________
add_3 (Add)                     (None, None, 100)    0           resnet-nonskip[0][0]       

In [19]:
import os 
!mkdir ./logs/temp_weights

class SaveWeights(tf.keras.callbacks.Callback):
    def on_train_batch_end(self, batch, logs=None):
        #print(self.model.layers[2].get_weights()[0])

        if not (batch% int(num_batches_per_epoch_train/10)):
            
            print('Saving Temp Weights')
            self.model.save_weights('./logs/temp_weights/weights-batch-{}.h5'.format(batch))


save_weights = SaveWeights()

mkdir: cannot create directory ‘./logs/temp_weights’: File exists


In [20]:
from tensorflow.keras.losses import sparse_categorical_crossentropy
def negative_avg_log_error(y_true, y_pred):
    
    def sum_of_log_probabilities(true_and_pred):
        y_true, y_pred_start, y_pred_end = true_and_pred

        start_probability = tf.gather(y_pred_start, K.cast(y_true[0], dtype= tf.int32))
        end_probability = tf.gather(y_pred_start, K.cast(y_true[1], dtype= tf.int32))
        return (K.log(start_probability) + K.log(end_probability))
        
        #return K.cast(true_and_pred[0], dtype = 'float32')
    
    #y_true_start = y_true[:,0]
    #y_true_end = y_true[:,1]
    #y_true = K.squeeze(y_true, axis=1)
    y_pred_start = y_pred[:,0, :]
    y_pred_end = y_pred[:,1,:]
    
    batch_probability_sum = K.map_fn(sum_of_log_probabilities, (y_true, y_pred_start, y_pred_end), dtype='float32')
    
    return -K.mean(batch_probability_sum, axis = 0) #sparse_categorical_crossentropy(y_true, y_pred)

In [21]:
from tensorflow.keras.losses import sparse_categorical_crossentropy
def accuracy(y_true, y_pred):
    
    def calculate_accuracy(true_and_pred):
        y_true, y_pred_start, y_pred_end = true_and_pred

        start_probability = tf.gather(y_pred_start, K.cast(y_true[0], dtype= tf.int32))
        end_probability = tf.gather(y_pred_start, K.cast(y_true[1], dtype= tf.int32))
        return (start_probability + end_probability)/2.0
        
        #return K.cast(true_and_pred[0], dtype = 'float32')
    
    #y_true_start = y_true[:,0]
    #y_true_end = y_true[:,1]
    #y_true = K.squeeze(y_true, axis=1)
    y_pred_start = y_pred[:,0, :]
    y_pred_end = y_pred[:,1,:]
    
    batch_probability_sum = K.map_fn(calculate_accuracy, (y_true, y_pred_start, y_pred_end), dtype='float32')
    
    return -K.mean(batch_probability_sum, axis = 0) #sparse_categorical_crossentropy(y_true, y_pred)

In [22]:
from tensorflow.keras.optimizers import Adam, Adadelta, Nadam

model.compile(optimizer = Adadelta(learning_rate = 0.5), loss = 'categorical_crossentropy', metrics = ['accuracy'])


In [23]:
memory_usage = []
class MemoryCheck(tf.keras.callbacks.Callback):
    def on_batch_end(self, batch, logs=None):
        mem = sp.check_output('nvidia-smi | grep python', shell=True).split()[-2].decode('utf-8')
        memory_usage.append(int(mem[:-3]))
        print(' ' + mem)

mem_check = MemoryCheck()

In [24]:
# CSVLogger, Model Checkpoint(Test save model), 2-highway, Dev set

In [25]:
from tensorflow.keras.callbacks import CSVLogger, ModelCheckpoint

checkpoint = ModelCheckpoint('./logs/saved_models/bidaf-weights-best.h5', save_best_only = True, save_weights_only = True, mode = 'min', monitor = 'val_loss', verbose = 1)
logger = CSVLogger('./logs/training.log', append = True)

In [26]:
history = model.fit_generator(generator = dataset_train,
                    steps_per_epoch = num_batches_per_epoch_train, 
                    epochs = 10, 
                    validation_data = dataset_dev, 
                    validation_steps = num_batches_per_epoch_dev,
                    #workers = 8,
                    #use_multiprocessing = True,
                    #shuffle = True,
                    callbacks = [mem_check])

Epoch 1/10
 795MiB
   1/5000 [..............................] - ETA: 2:18:51 - loss: 5.0192 - accuracy: 0.0000e+00 1307MiB
   2/5000 [..............................] - ETA: 1:26:26 - loss: 5.3385 - accuracy: 0.0000e+00 1307MiB
   3/5000 [..............................] - ETA: 1:16:10 - loss: 5.3742 - accuracy: 0.0417     1307MiB
   4/5000 [..............................] - ETA: 1:05:07 - loss: 5.3304 - accuracy: 0.0312 1307MiB
   5/5000 [..............................] - ETA: 59:43 - loss: 5.3278 - accuracy: 0.0250   1307MiB
   6/5000 [..............................] - ETA: 54:59 - loss: 5.2792 - accuracy: 0.0208 1307MiB
   7/5000 [..............................] - ETA: 52:05 - loss: 5.2548 - accuracy: 0.0179 1307MiB
   8/5000 [..............................] - ETA: 50:02 - loss: 5.2608 - accuracy: 0.0156 1307MiB
   9/5000 [..............................] - ETA: 48:19 - loss: 5.2495 - accuracy: 0.0139 1307MiB
  10/5000 [..............................] - ETA: 45:57 - loss: 5.2364 - accu

  84/5000 [..............................] - ETA: 33:54 - loss: 4.8127 - accuracy: 0.0640 1307MiB
  85/5000 [..............................] - ETA: 33:51 - loss: 4.8040 - accuracy: 0.0647 1307MiB
  86/5000 [..............................] - ETA: 33:48 - loss: 4.8015 - accuracy: 0.0654 1307MiB
  87/5000 [..............................] - ETA: 33:45 - loss: 4.8040 - accuracy: 0.0647 1307MiB
  88/5000 [..............................] - ETA: 33:44 - loss: 4.8027 - accuracy: 0.0653 1307MiB
  89/5000 [..............................] - ETA: 33:36 - loss: 4.7988 - accuracy: 0.0646 1307MiB
  90/5000 [..............................] - ETA: 33:35 - loss: 4.7992 - accuracy: 0.0653 1307MiB
  91/5000 [..............................] - ETA: 33:33 - loss: 4.7918 - accuracy: 0.0659 1307MiB
  92/5000 [..............................] - ETA: 33:30 - loss: 4.8007 - accuracy: 0.0652 1307MiB
  93/5000 [..............................] - ETA: 33:29 - loss: 4.7917 - accuracy: 0.0659 1307MiB
  94/5000 [.........

 167/5000 [>.............................] - ETA: 31:20 - loss: 4.6666 - accuracy: 0.0726 1307MiB
 168/5000 [>.............................] - ETA: 31:19 - loss: 4.6645 - accuracy: 0.0722 1307MiB
 169/5000 [>.............................] - ETA: 31:16 - loss: 4.6626 - accuracy: 0.0717 1307MiB
 170/5000 [>.............................] - ETA: 31:12 - loss: 4.6636 - accuracy: 0.0713 1307MiB
 171/5000 [>.............................] - ETA: 31:10 - loss: 4.6641 - accuracy: 0.0709 1307MiB
 172/5000 [>.............................] - ETA: 31:07 - loss: 4.6628 - accuracy: 0.0705 1307MiB
 173/5000 [>.............................] - ETA: 31:04 - loss: 4.6627 - accuracy: 0.0701 1307MiB
 174/5000 [>.............................] - ETA: 31:00 - loss: 4.6582 - accuracy: 0.0704 1307MiB
 175/5000 [>.............................] - ETA: 30:57 - loss: 4.6572 - accuracy: 0.0707 1307MiB
 176/5000 [>.............................] - ETA: 30:54 - loss: 4.6532 - accuracy: 0.0710 1307MiB
 177/5000 [>........

 250/5000 [>.............................] - ETA: 29:58 - loss: 4.5811 - accuracy: 0.0700 1307MiB
 251/5000 [>.............................] - ETA: 29:57 - loss: 4.5813 - accuracy: 0.0697 1307MiB
 252/5000 [>.............................] - ETA: 29:56 - loss: 4.5810 - accuracy: 0.0694 1307MiB
 253/5000 [>.............................] - ETA: 29:55 - loss: 4.5808 - accuracy: 0.0692 1307MiB
 254/5000 [>.............................] - ETA: 29:53 - loss: 4.5792 - accuracy: 0.0694 1307MiB
 255/5000 [>.............................] - ETA: 29:52 - loss: 4.5790 - accuracy: 0.0691 1307MiB
 256/5000 [>.............................] - ETA: 29:49 - loss: 4.5779 - accuracy: 0.0688 1307MiB
 257/5000 [>.............................] - ETA: 29:49 - loss: 4.5770 - accuracy: 0.0691 1307MiB
 258/5000 [>.............................] - ETA: 29:48 - loss: 4.5760 - accuracy: 0.0693 1307MiB
 259/5000 [>.............................] - ETA: 29:47 - loss: 4.5754 - accuracy: 0.0690 1307MiB
 260/5000 [>........

 333/5000 [>.............................] - ETA: 28:39 - loss: 4.5371 - accuracy: 0.0706 1307MiB
 334/5000 [=>............................] - ETA: 28:38 - loss: 4.5360 - accuracy: 0.0707 1307MiB
 335/5000 [=>............................] - ETA: 28:36 - loss: 4.5363 - accuracy: 0.0705 1307MiB
 336/5000 [=>............................] - ETA: 28:35 - loss: 4.5341 - accuracy: 0.0707 1307MiB
 337/5000 [=>............................] - ETA: 28:34 - loss: 4.5321 - accuracy: 0.0708 1307MiB
 338/5000 [=>............................] - ETA: 28:33 - loss: 4.5303 - accuracy: 0.0714 1307MiB
 339/5000 [=>............................] - ETA: 28:31 - loss: 4.5310 - accuracy: 0.0712 1307MiB
 340/5000 [=>............................] - ETA: 28:30 - loss: 4.5296 - accuracy: 0.0713 1307MiB
 341/5000 [=>............................] - ETA: 28:28 - loss: 4.5254 - accuracy: 0.0718 1307MiB
 342/5000 [=>............................] - ETA: 28:27 - loss: 4.5235 - accuracy: 0.0716 1307MiB
 343/5000 [=>.......

 416/5000 [=>............................] - ETA: 27:20 - loss: 4.4916 - accuracy: 0.0712 1307MiB
 417/5000 [=>............................] - ETA: 27:20 - loss: 4.4905 - accuracy: 0.0710 1307MiB
 418/5000 [=>............................] - ETA: 27:19 - loss: 4.4911 - accuracy: 0.0709 1307MiB
 419/5000 [=>............................] - ETA: 27:18 - loss: 4.4906 - accuracy: 0.0707 1307MiB
 420/5000 [=>............................] - ETA: 27:17 - loss: 4.4913 - accuracy: 0.0705 1307MiB
 421/5000 [=>............................] - ETA: 27:16 - loss: 4.4905 - accuracy: 0.0707 1307MiB
 422/5000 [=>............................] - ETA: 27:16 - loss: 4.4910 - accuracy: 0.0705 2241MiB
 423/5000 [=>............................] - ETA: 27:16 - loss: 4.4904 - accuracy: 0.0709 2241MiB
 424/5000 [=>............................] - ETA: 27:15 - loss: 4.4894 - accuracy: 0.0708 2241MiB
 425/5000 [=>............................] - ETA: 27:14 - loss: 4.4896 - accuracy: 0.0706 2241MiB
 426/5000 [=>.......

 499/5000 [=>............................] - ETA: 26:17 - loss: 4.4649 - accuracy: 0.0716 2241MiB
 500/5000 [==>...........................] - ETA: 26:17 - loss: 4.4653 - accuracy: 0.0715 2241MiB
 501/5000 [==>...........................] - ETA: 26:22 - loss: 4.4645 - accuracy: 0.0716 2241MiB
 502/5000 [==>...........................] - ETA: 26:22 - loss: 4.4636 - accuracy: 0.0720 2241MiB
 503/5000 [==>...........................] - ETA: 26:21 - loss: 4.4631 - accuracy: 0.0718 2241MiB
 504/5000 [==>...........................] - ETA: 26:21 - loss: 4.4628 - accuracy: 0.0717 2241MiB
 505/5000 [==>...........................] - ETA: 26:20 - loss: 4.4606 - accuracy: 0.0715 2241MiB
 506/5000 [==>...........................] - ETA: 26:19 - loss: 4.4613 - accuracy: 0.0714 2241MiB
 507/5000 [==>...........................] - ETA: 26:18 - loss: 4.4595 - accuracy: 0.0717 2241MiB
 508/5000 [==>...........................] - ETA: 26:17 - loss: 4.4586 - accuracy: 0.0719 2241MiB
 509/5000 [==>......

 582/5000 [==>...........................] - ETA: 25:27 - loss: 4.4401 - accuracy: 0.0739 2241MiB
 583/5000 [==>...........................] - ETA: 25:26 - loss: 4.4398 - accuracy: 0.0742 2241MiB
 584/5000 [==>...........................] - ETA: 25:25 - loss: 4.4389 - accuracy: 0.0741 2241MiB
 585/5000 [==>...........................] - ETA: 25:24 - loss: 4.4369 - accuracy: 0.0741 2241MiB
 586/5000 [==>...........................] - ETA: 25:24 - loss: 4.4359 - accuracy: 0.0744 2241MiB
 587/5000 [==>...........................] - ETA: 25:23 - loss: 4.4355 - accuracy: 0.0745 2241MiB
 588/5000 [==>...........................] - ETA: 25:22 - loss: 4.4341 - accuracy: 0.0746 2241MiB
 589/5000 [==>...........................] - ETA: 25:22 - loss: 4.4350 - accuracy: 0.0745 2241MiB
 590/5000 [==>...........................] - ETA: 25:22 - loss: 4.4334 - accuracy: 0.0748 2241MiB
 591/5000 [==>...........................] - ETA: 25:21 - loss: 4.4329 - accuracy: 0.0747 2241MiB
 592/5000 [==>......

 665/5000 [==>...........................] - ETA: 24:43 - loss: 4.4042 - accuracy: 0.0776 2241MiB
 666/5000 [==>...........................] - ETA: 24:43 - loss: 4.4042 - accuracy: 0.0777 2241MiB
 667/5000 [===>..........................] - ETA: 24:42 - loss: 4.4035 - accuracy: 0.0776 2241MiB
 668/5000 [===>..........................] - ETA: 24:42 - loss: 4.4033 - accuracy: 0.0775 2241MiB
 669/5000 [===>..........................] - ETA: 24:42 - loss: 4.4033 - accuracy: 0.0774 2241MiB
 670/5000 [===>..........................] - ETA: 24:41 - loss: 4.4044 - accuracy: 0.0772 2241MiB
 671/5000 [===>..........................] - ETA: 24:40 - loss: 4.4042 - accuracy: 0.0771 2241MiB
 672/5000 [===>..........................] - ETA: 24:39 - loss: 4.4047 - accuracy: 0.0770 2241MiB
 673/5000 [===>..........................] - ETA: 24:38 - loss: 4.4042 - accuracy: 0.0769 2241MiB
 674/5000 [===>..........................] - ETA: 24:37 - loss: 4.4040 - accuracy: 0.0770 2241MiB
 675/5000 [===>.....

 748/5000 [===>..........................] - ETA: 23:59 - loss: 4.3972 - accuracy: 0.0769 2241MiB
 749/5000 [===>..........................] - ETA: 23:59 - loss: 4.3968 - accuracy: 0.0769 2241MiB
 750/5000 [===>..........................] - ETA: 23:58 - loss: 4.3964 - accuracy: 0.0768 2241MiB
 751/5000 [===>..........................] - ETA: 23:58 - loss: 4.3956 - accuracy: 0.0769 2241MiB
 752/5000 [===>..........................] - ETA: 23:57 - loss: 4.3962 - accuracy: 0.0768 2241MiB
 753/5000 [===>..........................] - ETA: 23:56 - loss: 4.3956 - accuracy: 0.0769 2241MiB
 754/5000 [===>..........................] - ETA: 23:56 - loss: 4.3940 - accuracy: 0.0771 2241MiB
 755/5000 [===>..........................] - ETA: 23:55 - loss: 4.3948 - accuracy: 0.0772 2241MiB
 756/5000 [===>..........................] - ETA: 23:55 - loss: 4.3947 - accuracy: 0.0771 2241MiB
 757/5000 [===>..........................] - ETA: 23:54 - loss: 4.3947 - accuracy: 0.0771 2241MiB
 758/5000 [===>.....

 831/5000 [===>..........................] - ETA: 23:16 - loss: 4.3798 - accuracy: 0.0784 2241MiB
 832/5000 [===>..........................] - ETA: 23:16 - loss: 4.3794 - accuracy: 0.0784 2241MiB
 833/5000 [===>..........................] - ETA: 23:15 - loss: 4.3791 - accuracy: 0.0783 2241MiB
 834/5000 [====>.........................] - ETA: 23:14 - loss: 4.3786 - accuracy: 0.0782 2241MiB
 835/5000 [====>.........................] - ETA: 23:14 - loss: 4.3781 - accuracy: 0.0784 2241MiB
 836/5000 [====>.........................] - ETA: 23:13 - loss: 4.3778 - accuracy: 0.0785 2241MiB
 837/5000 [====>.........................] - ETA: 23:13 - loss: 4.3769 - accuracy: 0.0787 2241MiB
 838/5000 [====>.........................] - ETA: 23:12 - loss: 4.3776 - accuracy: 0.0786 2241MiB
 839/5000 [====>.........................] - ETA: 23:12 - loss: 4.3779 - accuracy: 0.0785 2241MiB
 840/5000 [====>.........................] - ETA: 23:11 - loss: 4.3773 - accuracy: 0.0784 2241MiB
 841/5000 [====>....

 914/5000 [====>.........................] - ETA: 22:38 - loss: 4.3617 - accuracy: 0.0785 2241MiB
 915/5000 [====>.........................] - ETA: 22:37 - loss: 4.3608 - accuracy: 0.0788 2241MiB
 916/5000 [====>.........................] - ETA: 22:36 - loss: 4.3601 - accuracy: 0.0789 2241MiB
 917/5000 [====>.........................] - ETA: 22:36 - loss: 4.3595 - accuracy: 0.0788 2241MiB
 918/5000 [====>.........................] - ETA: 22:35 - loss: 4.3595 - accuracy: 0.0787 2241MiB
 919/5000 [====>.........................] - ETA: 22:35 - loss: 4.3595 - accuracy: 0.0788 2241MiB
 920/5000 [====>.........................] - ETA: 22:34 - loss: 4.3594 - accuracy: 0.0787 2241MiB
 921/5000 [====>.........................] - ETA: 22:34 - loss: 4.3592 - accuracy: 0.0786 2241MiB
 922/5000 [====>.........................] - ETA: 22:33 - loss: 4.3595 - accuracy: 0.0786 2241MiB
 923/5000 [====>.........................] - ETA: 22:33 - loss: 4.3583 - accuracy: 0.0787 2241MiB
 924/5000 [====>....

 997/5000 [====>.........................] - ETA: 22:00 - loss: 4.3467 - accuracy: 0.0786 2241MiB
 998/5000 [====>.........................] - ETA: 22:00 - loss: 4.3462 - accuracy: 0.0787 2241MiB
 999/5000 [====>.........................] - ETA: 21:59 - loss: 4.3460 - accuracy: 0.0787 2241MiB
1000/5000 [=====>........................] - ETA: 21:59 - loss: 4.3465 - accuracy: 0.0786 2241MiB
1001/5000 [=====>........................] - ETA: 21:58 - loss: 4.3465 - accuracy: 0.0785 2241MiB
1002/5000 [=====>........................] - ETA: 21:58 - loss: 4.3462 - accuracy: 0.0785 2241MiB
1003/5000 [=====>........................] - ETA: 21:57 - loss: 4.3456 - accuracy: 0.0788 2241MiB
1004/5000 [=====>........................] - ETA: 21:57 - loss: 4.3454 - accuracy: 0.0787 2241MiB
1005/5000 [=====>........................] - ETA: 21:56 - loss: 4.3452 - accuracy: 0.0786 2241MiB
1006/5000 [=====>........................] - ETA: 21:55 - loss: 4.3448 - accuracy: 0.0785 2241MiB
1007/5000 [=====>...

1080/5000 [=====>........................] - ETA: 21:28 - loss: 4.3371 - accuracy: 0.0784 2241MiB
1081/5000 [=====>........................] - ETA: 21:27 - loss: 4.3365 - accuracy: 0.0783 2241MiB
1082/5000 [=====>........................] - ETA: 21:27 - loss: 4.3367 - accuracy: 0.0782 2241MiB
1083/5000 [=====>........................] - ETA: 21:27 - loss: 4.3368 - accuracy: 0.0781 2241MiB
1084/5000 [=====>........................] - ETA: 21:26 - loss: 4.3364 - accuracy: 0.0783 2241MiB
1085/5000 [=====>........................] - ETA: 21:26 - loss: 4.3359 - accuracy: 0.0782 2241MiB
1086/5000 [=====>........................] - ETA: 21:25 - loss: 4.3353 - accuracy: 0.0783 2241MiB
1087/5000 [=====>........................] - ETA: 21:25 - loss: 4.3354 - accuracy: 0.0783 2241MiB
1088/5000 [=====>........................] - ETA: 21:24 - loss: 4.3350 - accuracy: 0.0782 2241MiB
1089/5000 [=====>........................] - ETA: 21:24 - loss: 4.3353 - accuracy: 0.0783 2241MiB
1090/5000 [=====>...

1163/5000 [=====>........................] - ETA: 20:59 - loss: 4.3266 - accuracy: 0.0803 2241MiB
1164/5000 [=====>........................] - ETA: 20:59 - loss: 4.3271 - accuracy: 0.0802 2241MiB
1165/5000 [=====>........................] - ETA: 20:58 - loss: 4.3267 - accuracy: 0.0802 2241MiB
1166/5000 [=====>........................] - ETA: 20:58 - loss: 4.3264 - accuracy: 0.0803 2241MiB


















KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

plt.plot(memory_usage)