Credit : https://github.com/ParikhKadam/bidaf-keras

In [1]:
import tensorflow as tf
import json
import pandas as pd
import numpy as np
from tqdm import tqdm_notebook
import subprocess as sp
import os

In [2]:
SET_FLOAT_DTYPE = 'float32'

if SET_FLOAT_DTYPE == 'float32':
    TF_DATATYPE = tf.float32
else:
    TF_DATATYPE = tf.float16
    

In [3]:
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
    # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
    # Memory growth must be set before GPUs have been initialized
        print(e)

1 Physical GPUs, 1 Logical GPUs


In [4]:
tf.__version__

'2.0.0'

In [5]:
# Credit to : https://www.kaggle.com/sanjay11100/squad-stanford-q-a-json-to-pandas-dataframe
# Modified to include first answer_start and text for dev set

def squad_json_to_dataframe_train(input_file_path, record_path = ['data','paragraphs','qas','answers'],
                           verbose = 1):
    """
    input_file_path: path to the squad json file.
    record_path: path to deepest level in json file default value is
    ['data','paragraphs','qas','answers']
    verbose: 0 to suppress it default is 1
    """
    if verbose:
        print("Reading the json file")    
    file = json.loads(open(input_file_path).read())
    if verbose:
        print("processing...")
    # parsing different level's in the json file
    js = pd.io.json.json_normalize(file , record_path )
    m = pd.io.json.json_normalize(file, record_path[:-1] )
    r = pd.io.json.json_normalize(file,record_path[:-2])
    
    #combining it into single dataframe
    idx = np.repeat(r['context'].values, r.qas.str.len())
    ndx  = np.repeat(m['id'].values,m['answers'].str.len())
    m['context'] = idx
    js['q_idx'] = ndx
    main = pd.concat([ m[['id','question','context']].set_index('id'),js.set_index('q_idx')],1,sort=False).reset_index()
    main['c_id'] = main['context'].factorize()[0]
    if verbose:
        print("shape of the dataframe is {}".format(main.shape))
        print("Done")
    return main


def squad_json_to_dataframe_dev(input_file_path, record_path = ['data','paragraphs','qas','answers'],
                           verbose = 1):
    """
    input_file_path: path to the squad json file.
    record_path: path to deepest level in json file default value is
    ['data','paragraphs','qas','answers']
    verbose: 0 to suppress it default is 1
    """
    if verbose:
        print("Reading the json file")    
    file = json.loads(open(input_file_path).read())
    if verbose:
        print("processing...")
    # parsing different level's in the json file
    js = pd.io.json.json_normalize(file , record_path )
    m = pd.io.json.json_normalize(file, record_path[:-1] )
    r = pd.io.json.json_normalize(file,record_path[:-2])
    
    #combining it into single dataframe
    idx = np.repeat(r['context'].values, r.qas.str.len())
#     ndx  = np.repeat(m['id'].values,m['answers'].str.len())
    m['context'] = idx
#     js['q_idx'] = ndx
    main = m[['id','question','context','answers']].set_index('id').reset_index()
    main['c_id'] = main['context'].factorize()[0]
    answer_start = []
    answer_text = []

    for answers in tqdm_notebook(main['answers'].values):
        answer_start.append(answers[0]['answer_start'])
        answer_text.append(answers[0]['text'])

    main['answer_start'] = answer_start
    main['text'] = answer_text
    if verbose:
        print("shape of the dataframe is {}".format(main.shape))
        print("Done")
    return main

In [6]:
TRAIN_NUM_SAMPLES = 40000 #df_train.shape[0]
DEV_NUM_SAMPLES = 10000 #df_dev.shape[0]

if not os.path.isfile('./data/data.h5'):
    df_train = squad_json_to_dataframe_train('./data/train-v1.1.json')
    df_dev = squad_json_to_dataframe_dev('./data/dev-v1.1.json')
 
    df_train = df_train[:TRAIN_NUM_SAMPLES]
    df_dev = df_dev[:DEV_NUM_SAMPLES]
else:
    df_train = pd.read_hdf('./data/data.h5', 'train')
    df_dev = pd.read_hdf('./data/data.h5', 'dev')
    
idx = np.random.permutation(df_train.shape[0])
df_train = df_train.iloc[idx]

In [7]:
from tqdm import tqdm_notebook
from nltk import word_tokenize

if not os.path.isfile('./data/data.h5'):
    answer_start = []
    answer_end = []
    for i in tqdm_notebook(range(df_train.shape[0])): 
        context_split = word_tokenize(df_train.context.values[i][:df_train.answer_start.values[i]])
        answer_start.append(len(context_split))
        answer_end.append(len(context_split) + len(word_tokenize(df_train.text.values[i])) -1)
    df_train['answer_end'] = answer_end
    df_train['answer_start'] = answer_start

    answer_start = []
    answer_end = []
    for i in tqdm_notebook(range(df_dev.shape[0])): 
        context_split = word_tokenize(df_dev.context.values[i][:df_dev.answer_start.values[i]])
        answer_start.append(len(context_split))
        answer_end.append(len(context_split) + len(word_tokenize(df_dev.text.values[i])) -1)
    df_dev['answer_end'] = answer_end
    df_dev['answer_start'] = answer_start

    df_train.question = [word_tokenize(q) for q in tqdm_notebook(df_train.question.values)]
    df_train.context = [word_tokenize(q) for q in tqdm_notebook(df_train.context.values)]

    df_dev.question = [word_tokenize(q) for q in tqdm_notebook(df_dev.question.values)]
    df_dev.context = [word_tokenize(q) for q in tqdm_notebook(df_dev.context.values)]
    
    df_train.to_hdf('./data/data.h5', 'train')
    df_dev.to_hdf('./data/data.h5', 'dev')

In [8]:
from pymagnitude import MagnitudeUtils, Magnitude
#from scripts import MagnitudeVectors

#vectors = MagnitudeVectors(50).load_vectors()
vectors = Magnitude('./data/magnitude/glove.6B.100d.magnitude')

In [9]:
y_train = df_train.answer_start.values, df_train.answer_end.values
x_train = df_train.context.values, df_train.question.values

y_dev = df_dev.answer_start.values, df_dev.answer_end.values
x_dev = df_dev.context.values, df_dev.question.values

In [10]:
BATCH_SIZE = 8
num_batches_per_epoch_train = int(np.ceil(TRAIN_NUM_SAMPLES /float(BATCH_SIZE)))
num_batches_per_epoch_dev = int(np.ceil(DEV_NUM_SAMPLES /float(BATCH_SIZE))) 

In [11]:
import gc

def train_gen():
    for i in range(TRAIN_NUM_SAMPLES):
        if not (i % BATCH_SIZE):
            context_pad_length = max([len(t) for t in x_train[0][i:i+BATCH_SIZE]])
            question_pad_length = max([len(t) for t in x_train[1][i:i+BATCH_SIZE]])

        X_context_batch = vectors.query(x_train[0][i], pad_to_length = context_pad_length)
        X_question_batch = vectors.query(x_train[1][i], pad_to_length = question_pad_length)

        #Y_start_batch = tf.keras.utils.to_categorical(y_train[0][i],context_pad_length).reshape(1,-1)
        #Y_end_batch = tf.keras.utils.to_categorical(y_train[1][i],context_pad_length).reshape(1,-1)
        #Y_batch = np.concatenate([Y_start_batch, Y_end_batch])
        #Y_batch = np.concatenate([y_train[0][i], y_train[1][i]])
        
        yield ((tf.constant(X_context_batch, dtype = SET_FLOAT_DTYPE), tf.constant(X_question_batch, dtype = SET_FLOAT_DTYPE)), 
               tf.constant([y_train[0][i]], dtype = SET_FLOAT_DTYPE)) #, y_train[1][i]],
        
        
def dev_gen():
    for i in range(DEV_NUM_SAMPLES):
        if not (i%BATCH_SIZE):
            context_pad_length = max([len(t) for t in x_dev[0][i:i+BATCH_SIZE]])
            question_pad_length = max([len(t) for t in x_dev[1][i:i+BATCH_SIZE]])

        X_context_batch = vectors.query(x_dev[0][i], pad_to_length = context_pad_length)
        X_question_batch = vectors.query(x_dev[1][i], pad_to_length = question_pad_length)

        #Y_start_batch = tf.keras.utils.to_categorical(y_dev[0][i],context_pad_length)
        #Y_end_batch =tf.keras.utils.to_categorical(y_dev[1][i],context_pad_length)
        
        yield ((tf.constant(X_context_batch, dtype = SET_FLOAT_DTYPE), tf.constant(X_question_batch, dtype = SET_FLOAT_DTYPE)), 
               tf.constant([y_dev[0][i]], dtype = SET_FLOAT_DTYPE)) #, y_dev[1][i]]
        
    

In [12]:
dataset_train = tf.data.Dataset.from_generator(train_gen, ((TF_DATATYPE, TF_DATATYPE), TF_DATATYPE)).batch(BATCH_SIZE).repeat().prefetch(1)
dataset_dev = tf.data.Dataset.from_generator(dev_gen, ((TF_DATATYPE, TF_DATATYPE), TF_DATATYPE)).batch(BATCH_SIZE).repeat().prefetch(1)


In [13]:
from tensorflow.keras import backend as K


K.set_floatx(SET_FLOAT_DTYPE)
#K.set_epsilon(1e-4)

K.clear_session()

In [14]:
def prepare_for_end_prob(inputs):
    encoded_context, merged_context, modeled_context, span_begin_probabilities = inputs
    weighted_sum = K.sum(K.expand_dims(span_begin_probabilities, axis=-1) * modeled_context, -2)
    passage_weighted_by_predicted_span = K.expand_dims(weighted_sum, axis=1)
    tile_shape = K.concatenate([[1], [K.shape(encoded_context)[1]], [1]], axis=0)
    passage_weighted_by_predicted_span = K.tile(passage_weighted_by_predicted_span, tile_shape)
    multiply1 = modeled_context * passage_weighted_by_predicted_span
    span_end_representation = K.concatenate(
            [merged_context, modeled_context, passage_weighted_by_predicted_span, multiply1])

    return span_end_representation


In [15]:
EMBED_LENGTH = 100
DROPOUT_RATE = 0.2

from tensorflow.keras.layers import Input, Add, LSTM, Bidirectional, Concatenate, TimeDistributed, Dense, Softmax, Flatten, Lambda, Multiply, Add, Dropout, SpatialDropout1D
from tensorflow.keras.models import Model 
from tensorflow.keras.optimizers import Adam, Adadelta
from tensorflow.keras.activations import linear
from layers import Similarity, C2QAttention, Q2CAttention, MergedContext, SpanBegin, SpanEnd, Highway

######## INPUT LAYER #########
context_input = Input(shape = (None, EMBED_LENGTH), dtype = SET_FLOAT_DTYPE, name = 'context_input')
question_input = Input(shape = (None, EMBED_LENGTH), dtype = SET_FLOAT_DTYPE, name = 'question_input')

#highway_layer = Highway(name='highway_1')                       
skip_conn_question = question_input
skip_conn_context = context_input

resnet_nonskip = Dense(EMBED_LENGTH, name = 'resnet-nonskip')
context_resnet = resnet_nonskip(context_input)
question_resnet = resnet_nonskip(question_input)

question_embedding = Add()([question_resnet, skip_conn_question])                       
context_embedding = Add()([context_resnet, skip_conn_context])


#question_embedding = highway_layer(question_input)                       
#context_embedding = highway_layer(context_input)

#highway_layer_2 = Highway(name='highway_2')                       
#question_embedding = highway_layer_2(question_embedding)                       
#context_embedding = highway_layer_2(context_embedding)

#context_embedding = context_input
#question_embedding = question_input


######## CONTEXTUAL EMBEDDING LAYER ########
encoder_layer = Bidirectional(LSTM(EMBED_LENGTH, return_sequences=True, recurrent_initializer='glorot_uniform'))
encoded_question = encoder_layer(question_embedding)
encoded_context = encoder_layer(context_embedding)

######## SIMILARITY LAYER ########
similarity_matrix = Similarity(name='similarity_layer')([encoded_context, encoded_question])


####### ATTENTION LAYER #########
context_to_query_attention = C2QAttention(name='context_to_query_attention')([
           similarity_matrix, encoded_question])
query_to_context_attention = Q2CAttention(name='query_to_context_attention')([
            similarity_matrix, encoded_context])

###### MERGE ATTENTIONS ########
context_concat = K.concatenate([encoded_context, context_to_query_attention, query_to_context_attention], axis = -1)
merged_context = Dense(500, name = 'merged_context')(context_concat)

#merged_context = MergedContext(name='merged_context')(
#            [encoded_context, context_to_query_attention, query_to_context_attention])

###### MODELLING LAYER #########
modeled_context = Bidirectional(LSTM(EMBED_LENGTH,return_sequences=True, recurrent_initializer='glorot_uniform'), name='decoder')(merged_context)
modeled_context = SpatialDropout1D(0.2)(modeled_context)


span_begin_concat = K.concatenate([merged_context, modeled_context], axis = -1)
span_begin_weights = Dense(1, name = 'Dense_span_begin')(span_begin_concat)
span_begin_weights = Dropout(0.2)(span_begin_weights)
span_begin_probabilities = Softmax(name = 'span-begin-output')(K.squeeze(span_begin_weights, axis=-1))

#span_end_representation = Lambda(prepare_for_end_prob)((encoded_context, merged_context, modeled_context, span_begin_probabilities))
#span_end_representation = Bidirectional(LSTM(EMBED_LENGTH, return_sequences=True, recurrent_initializer='glorot_uniform'), name='output_end_prob_decoder')(span_end_representation)
#span_end_input = K.concatenate([merged_context, span_end_representation])
#span_end_weights = TimeDistributed(Dense(1), name = 'Dense_span_end')(span_end_input)
#span_end_weights = Dropout(0.2)(span_end_weights)
#span_end_probabilities = Softmax(name = 'span-end-output')(K.squeeze(span_end_weights, axis=-1))

#output = Lambda(lambda x : K.stack([x[0], x[1]], axis = 1), name = 'output', dtype = 'float32')([span_begin_probabilities, span_end_probabilities])

model = Model([context_input, question_input], span_begin_probabilities)

In [16]:
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
context_input (InputLayer)      [(None, None, 100)]  0                                            
__________________________________________________________________________________________________
question_input (InputLayer)     [(None, None, 100)]  0                                            
__________________________________________________________________________________________________
resnet-nonskip (Dense)          (None, None, 100)    10100       context_input[0][0]              
                                                                 question_input[0][0]             
__________________________________________________________________________________________________
add_1 (Add)                     (None, None, 100)    0           resnet-nonskip[0][0]         

In [17]:
import os 
!mkdir ./logs/temp_weights

class SaveWeights(tf.keras.callbacks.Callback):
    def on_train_batch_end(self, batch, logs=None):
        #print(self.model.layers[2].get_weights()[0])

        if not (batch% int(num_batches_per_epoch_train/10)):
            
            print('Saving Temp Weights')
            self.model.save_weights('./logs/temp_weights/weights-batch-{}.h5'.format(batch))


save_weights = SaveWeights()

mkdir: cannot create directory ‘./logs/temp_weights’: File exists


In [18]:
from tensorflow.keras.losses import sparse_categorical_crossentropy
def negative_avg_log_error(y_true, y_pred):
    
    def sum_of_log_probabilities(true_and_pred):
        y_true, y_pred_start, y_pred_end = true_and_pred

        start_probability = tf.gather(y_pred_start, K.cast(y_true[0], dtype= tf.int32))
        end_probability = tf.gather(y_pred_start, K.cast(y_true[1], dtype= tf.int32))
        return (K.log(start_probability) + K.log(end_probability))
        
        #return K.cast(true_and_pred[0], dtype = 'float32')
    
    #y_true_start = y_true[:,0]
    #y_true_end = y_true[:,1]
    #y_true = K.squeeze(y_true, axis=1)
    y_pred_start = y_pred[:,0, :]
    y_pred_end = y_pred[:,1,:]
    
    batch_probability_sum = K.map_fn(sum_of_log_probabilities, (y_true, y_pred_start, y_pred_end), dtype='float32')
    
    return -K.mean(batch_probability_sum, axis = 0) #sparse_categorical_crossentropy(y_true, y_pred)

In [19]:
from tensorflow.keras.losses import sparse_categorical_crossentropy
def accuracy(y_true, y_pred):
    
    def calculate_accuracy(true_and_pred):
        y_true, y_pred_start, y_pred_end = true_and_pred

        start_probability = tf.gather(y_pred_start, K.cast(y_true[0], dtype= tf.int32))
        end_probability = tf.gather(y_pred_start, K.cast(y_true[1], dtype= tf.int32))
        return (start_probability + end_probability)/2.0
        
        #return K.cast(true_and_pred[0], dtype = 'float32')
    
    #y_true_start = y_true[:,0]
    #y_true_end = y_true[:,1]
    #y_true = K.squeeze(y_true, axis=1)
    y_pred_start = y_pred[:,0, :]
    y_pred_end = y_pred[:,1,:]
    
    batch_probability_sum = K.map_fn(calculate_accuracy, (y_true, y_pred_start, y_pred_end), dtype='float32')
    
    return -K.mean(batch_probability_sum, axis = 0) #sparse_categorical_crossentropy(y_true, y_pred)

In [20]:
from tensorflow.keras.optimizers import Adam, Adadelta, Nadam

model.compile(optimizer = Adadelta(learning_rate = 0.5), loss = 'sparse_categorical_crossentropy', metrics = ['accuracy'])


In [21]:
memory_usage = []
class MemoryCheck(tf.keras.callbacks.Callback):
    def on_batch_end(self, batch, logs=None):
        mem = sp.check_output('nvidia-smi | grep python', shell=True).split()[-2].decode('utf-8')
        memory_usage.append(int(mem[:-3]))
        print(' ' + mem)

mem_check = MemoryCheck()

In [22]:
# CSVLogger, Model Checkpoint(Test save model), 2-highway, Dev set

In [23]:
from tensorflow.keras.callbacks import CSVLogger, ModelCheckpoint

checkpoint = ModelCheckpoint('./logs/saved_models/bidaf-weights-best.h5', save_best_only = True, save_weights_only = True, mode = 'min', monitor = 'val_loss', verbose = 1)
logger = CSVLogger('./logs/training.log', append = True)

In [24]:
history = model.fit_generator(generator = dataset_train,
                    steps_per_epoch = num_batches_per_epoch_train, 
                    epochs = 10, 
                    validation_data = dataset_dev, 
                    validation_steps = num_batches_per_epoch_dev,
                    #workers = 8,
                    #use_multiprocessing = True,
                    #shuffle = True,
                    callbacks = [mem_check])

Epoch 1/10
 795MiB
   1/5000 [..............................] - ETA: 2:20:55 - loss: 5.5707 - accuracy: 0.0000e+00 1307MiB
   2/5000 [..............................] - ETA: 1:26:20 - loss: 5.5303 - accuracy: 0.0000e+00 1307MiB
   3/5000 [..............................] - ETA: 1:08:23 - loss: 5.3872 - accuracy: 0.0000e+00 1307MiB
   4/5000 [..............................] - ETA: 1:02:08 - loss: 5.4156 - accuracy: 0.0000e+00 1307MiB
   5/5000 [..............................] - ETA: 58:23 - loss: 5.3521 - accuracy: 0.0000e+00   1307MiB
   6/5000 [..............................] - ETA: 54:45 - loss: 5.3589 - accuracy: 0.0000e+00 1307MiB
   7/5000 [..............................] - ETA: 52:33 - loss: 5.3173 - accuracy: 0.0000e+00 1307MiB
   8/5000 [..............................] - ETA: 50:14 - loss: 5.2525 - accuracy: 0.0000e+00 1307MiB
   9/5000 [..............................] - ETA: 49:07 - loss: 5.1895 - accuracy: 0.0000e+00 1307MiB
  10/5000 [..............................] - ETA: 47:

  83/5000 [..............................] - ETA: 33:22 - loss: 4.8646 - accuracy: 0.0497 1307MiB
  84/5000 [..............................] - ETA: 33:21 - loss: 4.8639 - accuracy: 0.0491 1307MiB
  85/5000 [..............................] - ETA: 33:15 - loss: 4.8579 - accuracy: 0.0485 1307MiB
  86/5000 [..............................] - ETA: 33:10 - loss: 4.8423 - accuracy: 0.0509 1307MiB
  87/5000 [..............................] - ETA: 33:07 - loss: 4.8419 - accuracy: 0.0517 1307MiB
  88/5000 [..............................] - ETA: 33:04 - loss: 4.8316 - accuracy: 0.0540 1307MiB
  89/5000 [..............................] - ETA: 33:06 - loss: 4.8233 - accuracy: 0.0562 1307MiB
  90/5000 [..............................] - ETA: 33:03 - loss: 4.8190 - accuracy: 0.0556 1307MiB
  91/5000 [..............................] - ETA: 33:05 - loss: 4.8151 - accuracy: 0.0549 1307MiB
  92/5000 [..............................] - ETA: 33:04 - loss: 4.8126 - accuracy: 0.0543 1307MiB
  93/5000 [.........

 166/5000 [..............................] - ETA: 30:42 - loss: 4.6384 - accuracy: 0.0685 2119MiB
 167/5000 [>.............................] - ETA: 30:40 - loss: 4.6371 - accuracy: 0.0681 2119MiB
 168/5000 [>.............................] - ETA: 30:37 - loss: 4.6359 - accuracy: 0.0677 2119MiB
 169/5000 [>.............................] - ETA: 30:34 - loss: 4.6350 - accuracy: 0.0673 2119MiB
 170/5000 [>.............................] - ETA: 30:30 - loss: 4.6296 - accuracy: 0.0684 2119MiB
 171/5000 [>.............................] - ETA: 30:30 - loss: 4.6248 - accuracy: 0.0694 2119MiB
 172/5000 [>.............................] - ETA: 30:28 - loss: 4.6288 - accuracy: 0.0690 2119MiB
 173/5000 [>.............................] - ETA: 30:24 - loss: 4.6228 - accuracy: 0.0694 2119MiB
 174/5000 [>.............................] - ETA: 30:24 - loss: 4.6224 - accuracy: 0.0697 2119MiB
 175/5000 [>.............................] - ETA: 30:21 - loss: 4.6229 - accuracy: 0.0700 2119MiB
 176/5000 [>........

 249/5000 [>.............................] - ETA: 29:22 - loss: 4.5551 - accuracy: 0.0693 2119MiB
 250/5000 [>.............................] - ETA: 29:21 - loss: 4.5530 - accuracy: 0.0695 2119MiB
 251/5000 [>.............................] - ETA: 29:19 - loss: 4.5535 - accuracy: 0.0692 2119MiB
 252/5000 [>.............................] - ETA: 29:17 - loss: 4.5520 - accuracy: 0.0694 2119MiB
 253/5000 [>.............................] - ETA: 29:17 - loss: 4.5511 - accuracy: 0.0697 2119MiB
 254/5000 [>.............................] - ETA: 29:16 - loss: 4.5486 - accuracy: 0.0699 2119MiB
 255/5000 [>.............................] - ETA: 29:14 - loss: 4.5490 - accuracy: 0.0696 2119MiB
 256/5000 [>.............................] - ETA: 29:12 - loss: 4.5469 - accuracy: 0.0698 2119MiB
 257/5000 [>.............................] - ETA: 29:11 - loss: 4.5492 - accuracy: 0.0696 2119MiB
 258/5000 [>.............................] - ETA: 29:08 - loss: 4.5472 - accuracy: 0.0698 2119MiB
 259/5000 [>........

 332/5000 [>.............................] - ETA: 28:14 - loss: 4.5210 - accuracy: 0.0681 2119MiB
 333/5000 [>.............................] - ETA: 28:13 - loss: 4.5211 - accuracy: 0.0679 2119MiB
 334/5000 [=>............................] - ETA: 28:12 - loss: 4.5196 - accuracy: 0.0681 2119MiB
 335/5000 [=>............................] - ETA: 28:11 - loss: 4.5185 - accuracy: 0.0679 2119MiB
 336/5000 [=>............................] - ETA: 28:10 - loss: 4.5157 - accuracy: 0.0681 2119MiB
 337/5000 [=>............................] - ETA: 28:09 - loss: 4.5146 - accuracy: 0.0682 2119MiB
 338/5000 [=>............................] - ETA: 28:07 - loss: 4.5124 - accuracy: 0.0680 2119MiB
 339/5000 [=>............................] - ETA: 28:06 - loss: 4.5106 - accuracy: 0.0678 2119MiB
 340/5000 [=>............................] - ETA: 28:04 - loss: 4.5114 - accuracy: 0.0680 2119MiB
 341/5000 [=>............................] - ETA: 28:02 - loss: 4.5111 - accuracy: 0.0678 2119MiB
 342/5000 [=>.......

 415/5000 [=>............................] - ETA: 26:56 - loss: 4.4844 - accuracy: 0.0690 2119MiB
 416/5000 [=>............................] - ETA: 26:55 - loss: 4.4832 - accuracy: 0.0688 2119MiB
 417/5000 [=>............................] - ETA: 26:53 - loss: 4.4830 - accuracy: 0.0689 2119MiB
 418/5000 [=>............................] - ETA: 26:52 - loss: 4.4825 - accuracy: 0.0688 2119MiB
 419/5000 [=>............................] - ETA: 26:51 - loss: 4.4811 - accuracy: 0.0689 2119MiB
 420/5000 [=>............................] - ETA: 26:51 - loss: 4.4811 - accuracy: 0.0690 2119MiB
 421/5000 [=>............................] - ETA: 26:50 - loss: 4.4807 - accuracy: 0.0692 2119MiB
 422/5000 [=>............................] - ETA: 26:49 - loss: 4.4805 - accuracy: 0.0690 2119MiB
 423/5000 [=>............................] - ETA: 26:48 - loss: 4.4798 - accuracy: 0.0689 2119MiB
 424/5000 [=>............................] - ETA: 26:47 - loss: 4.4800 - accuracy: 0.0687 2119MiB
 425/5000 [=>.......

 498/5000 [=>............................] - ETA: 25:38 - loss: 4.4550 - accuracy: 0.0690 2119MiB
 499/5000 [=>............................] - ETA: 25:37 - loss: 4.4548 - accuracy: 0.0691 2119MiB
 500/5000 [==>...........................] - ETA: 25:36 - loss: 4.4543 - accuracy: 0.0693 2119MiB
 501/5000 [==>...........................] - ETA: 25:36 - loss: 4.4566 - accuracy: 0.0691 2119MiB
 502/5000 [==>...........................] - ETA: 25:35 - loss: 4.4551 - accuracy: 0.0692 2119MiB
 503/5000 [==>...........................] - ETA: 25:34 - loss: 4.4543 - accuracy: 0.0691 2119MiB
 504/5000 [==>...........................] - ETA: 25:33 - loss: 4.4543 - accuracy: 0.0689 2119MiB
 505/5000 [==>...........................] - ETA: 25:33 - loss: 4.4533 - accuracy: 0.0691 2119MiB
 506/5000 [==>...........................] - ETA: 25:32 - loss: 4.4537 - accuracy: 0.0692 2119MiB
 507/5000 [==>...........................] - ETA: 25:31 - loss: 4.4534 - accuracy: 0.0690 2119MiB
 508/5000 [==>......

 581/5000 [==>...........................] - ETA: 24:32 - loss: 4.4396 - accuracy: 0.0676 2119MiB
 582/5000 [==>...........................] - ETA: 24:31 - loss: 4.4388 - accuracy: 0.0674 2119MiB
 583/5000 [==>...........................] - ETA: 24:30 - loss: 4.4387 - accuracy: 0.0673 2119MiB
 584/5000 [==>...........................] - ETA: 24:30 - loss: 4.4377 - accuracy: 0.0679 2119MiB
 585/5000 [==>...........................] - ETA: 24:29 - loss: 4.4375 - accuracy: 0.0677 2119MiB
 586/5000 [==>...........................] - ETA: 24:28 - loss: 4.4378 - accuracy: 0.0678 2119MiB
 587/5000 [==>...........................] - ETA: 24:29 - loss: 4.4375 - accuracy: 0.0677 2119MiB
 588/5000 [==>...........................] - ETA: 24:28 - loss: 4.4371 - accuracy: 0.0676 2119MiB
 589/5000 [==>...........................] - ETA: 24:28 - loss: 4.4365 - accuracy: 0.0677 2119MiB
 590/5000 [==>...........................] - ETA: 24:27 - loss: 4.4357 - accuracy: 0.0680 2119MiB
 591/5000 [==>......

 664/5000 [==>...........................] - ETA: 23:45 - loss: 4.4209 - accuracy: 0.0715 2119MiB
 665/5000 [==>...........................] - ETA: 23:45 - loss: 4.4203 - accuracy: 0.0716 2119MiB
 666/5000 [==>...........................] - ETA: 23:44 - loss: 4.4197 - accuracy: 0.0719 2119MiB
 667/5000 [===>..........................] - ETA: 23:44 - loss: 4.4189 - accuracy: 0.0722 2119MiB
 668/5000 [===>..........................] - ETA: 23:43 - loss: 4.4202 - accuracy: 0.0720 2119MiB
 669/5000 [===>..........................] - ETA: 23:43 - loss: 4.4197 - accuracy: 0.0719 2119MiB
 670/5000 [===>..........................] - ETA: 23:42 - loss: 4.4196 - accuracy: 0.0722 2119MiB
 671/5000 [===>..........................] - ETA: 23:41 - loss: 4.4189 - accuracy: 0.0723 2119MiB
 672/5000 [===>..........................] - ETA: 23:41 - loss: 4.4197 - accuracy: 0.0722 2119MiB
 673/5000 [===>..........................] - ETA: 23:40 - loss: 4.4182 - accuracy: 0.0723 2119MiB
 674/5000 [===>.....

 747/5000 [===>..........................] - ETA: 22:55 - loss: 4.4013 - accuracy: 0.0748 2119MiB
 748/5000 [===>..........................] - ETA: 22:54 - loss: 4.4001 - accuracy: 0.0749 2119MiB
 749/5000 [===>..........................] - ETA: 22:54 - loss: 4.4003 - accuracy: 0.0751 2119MiB
 750/5000 [===>..........................] - ETA: 22:53 - loss: 4.4007 - accuracy: 0.0750 2119MiB
 751/5000 [===>..........................] - ETA: 22:52 - loss: 4.3999 - accuracy: 0.0751 2119MiB
 752/5000 [===>..........................] - ETA: 22:52 - loss: 4.3992 - accuracy: 0.0750 2119MiB
 753/5000 [===>..........................] - ETA: 22:51 - loss: 4.4000 - accuracy: 0.0750 2119MiB
 754/5000 [===>..........................] - ETA: 22:51 - loss: 4.3999 - accuracy: 0.0749 2119MiB
 755/5000 [===>..........................] - ETA: 22:51 - loss: 4.3998 - accuracy: 0.0748 2119MiB
 756/5000 [===>..........................] - ETA: 22:51 - loss: 4.3992 - accuracy: 0.0751 2119MiB
 757/5000 [===>.....

 830/5000 [===>..........................] - ETA: 22:12 - loss: 4.3827 - accuracy: 0.0750 2119MiB
 831/5000 [===>..........................] - ETA: 22:12 - loss: 4.3819 - accuracy: 0.0751 2119MiB
 832/5000 [===>..........................] - ETA: 22:11 - loss: 4.3816 - accuracy: 0.0751 2119MiB
 833/5000 [===>..........................] - ETA: 22:11 - loss: 4.3820 - accuracy: 0.0750 2119MiB
 834/5000 [====>.........................] - ETA: 22:10 - loss: 4.3823 - accuracy: 0.0751 2119MiB
 835/5000 [====>.........................] - ETA: 22:10 - loss: 4.3815 - accuracy: 0.0753 2119MiB
 836/5000 [====>.........................] - ETA: 22:09 - loss: 4.3814 - accuracy: 0.0752 2119MiB
 837/5000 [====>.........................] - ETA: 22:08 - loss: 4.3808 - accuracy: 0.0753 2119MiB
 838/5000 [====>.........................] - ETA: 22:08 - loss: 4.3804 - accuracy: 0.0755 2119MiB
 839/5000 [====>.........................] - ETA: 22:08 - loss: 4.3806 - accuracy: 0.0755 2119MiB
 840/5000 [====>....

 913/5000 [====>.........................] - ETA: 21:39 - loss: 4.3685 - accuracy: 0.0757 2119MiB
 914/5000 [====>.........................] - ETA: 21:38 - loss: 4.3682 - accuracy: 0.0756 2119MiB
 915/5000 [====>.........................] - ETA: 21:38 - loss: 4.3686 - accuracy: 0.0757 2119MiB
 916/5000 [====>.........................] - ETA: 21:38 - loss: 4.3683 - accuracy: 0.0756 2119MiB
 917/5000 [====>.........................] - ETA: 21:37 - loss: 4.3686 - accuracy: 0.0755 2119MiB
 918/5000 [====>.........................] - ETA: 21:37 - loss: 4.3684 - accuracy: 0.0756 2119MiB
 919/5000 [====>.........................] - ETA: 21:36 - loss: 4.3682 - accuracy: 0.0755 2119MiB
 920/5000 [====>.........................] - ETA: 21:36 - loss: 4.3679 - accuracy: 0.0755 2119MiB
 921/5000 [====>.........................] - ETA: 21:35 - loss: 4.3677 - accuracy: 0.0755 2119MiB
 922/5000 [====>.........................] - ETA: 21:35 - loss: 4.3666 - accuracy: 0.0755 2119MiB
 923/5000 [====>....

 996/5000 [====>.........................] - ETA: 21:10 - loss: 4.3490 - accuracy: 0.0763 2119MiB
 997/5000 [====>.........................] - ETA: 21:09 - loss: 4.3493 - accuracy: 0.0765 2119MiB
 998/5000 [====>.........................] - ETA: 21:09 - loss: 4.3494 - accuracy: 0.0764 2119MiB
 999/5000 [====>.........................] - ETA: 21:09 - loss: 4.3486 - accuracy: 0.0766 2119MiB
1000/5000 [=====>........................] - ETA: 21:08 - loss: 4.3486 - accuracy: 0.0765 2119MiB
1001/5000 [=====>........................] - ETA: 21:08 - loss: 4.3478 - accuracy: 0.0765 2119MiB
1002/5000 [=====>........................] - ETA: 21:07 - loss: 4.3484 - accuracy: 0.0766 2119MiB
1003/5000 [=====>........................] - ETA: 21:07 - loss: 4.3479 - accuracy: 0.0766 2119MiB
1004/5000 [=====>........................] - ETA: 21:07 - loss: 4.3471 - accuracy: 0.0769 2119MiB
1005/5000 [=====>........................] - ETA: 21:07 - loss: 4.3463 - accuracy: 0.0770 2119MiB
1006/5000 [=====>...

1079/5000 [=====>........................] - ETA: 20:38 - loss: 4.3381 - accuracy: 0.0777 2119MiB
1080/5000 [=====>........................] - ETA: 20:38 - loss: 4.3368 - accuracy: 0.0780 2119MiB
1081/5000 [=====>........................] - ETA: 20:38 - loss: 4.3372 - accuracy: 0.0779 2119MiB
1082/5000 [=====>........................] - ETA: 20:38 - loss: 4.3374 - accuracy: 0.0779 2119MiB
1083/5000 [=====>........................] - ETA: 20:37 - loss: 4.3369 - accuracy: 0.0780 2119MiB
1084/5000 [=====>........................] - ETA: 20:37 - loss: 4.3365 - accuracy: 0.0781 2119MiB
1085/5000 [=====>........................] - ETA: 20:37 - loss: 4.3364 - accuracy: 0.0781 2119MiB
1086/5000 [=====>........................] - ETA: 20:36 - loss: 4.3362 - accuracy: 0.0782 2119MiB
1087/5000 [=====>........................] - ETA: 20:36 - loss: 4.3355 - accuracy: 0.0782 2119MiB
1088/5000 [=====>........................] - ETA: 20:35 - loss: 4.3363 - accuracy: 0.0784 2119MiB
1089/5000 [=====>...

1162/5000 [=====>........................] - ETA: 20:08 - loss: 4.3232 - accuracy: 0.0798 2119MiB
1163/5000 [=====>........................] - ETA: 20:07 - loss: 4.3235 - accuracy: 0.0799 2119MiB
1164/5000 [=====>........................] - ETA: 20:07 - loss: 4.3234 - accuracy: 0.0799 2119MiB
1165/5000 [=====>........................] - ETA: 20:06 - loss: 4.3232 - accuracy: 0.0799 2119MiB
1166/5000 [=====>........................] - ETA: 20:06 - loss: 4.3230 - accuracy: 0.0799 2119MiB


























































ResourceExhaustedError: OOM when allocating tensor with shape[187200,600] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc [Op:MatMul]

In [None]:
import matplotlib.pyplot as plt

plt.plot(memory_usage)