In [0]:
from google.colab import drive
drive.mount('/content/drive/', force_remount=True)

Mounted at /content/drive/


In [0]:
%cd '/content/drive/My Drive/shared-works/Thread-1'

/content/drive/My Drive/shared-works/Thread-1


In [0]:
!pip install ftfy



In [0]:
import tensorflow as tf
import pickle
from model import Transformer
from utils import iter_data, Logger
import time
import numpy as np

In [0]:
def save(model, train_loss_results, validation_loss_results, cnt):
    with open('./Data/interpretation/equal_segment/pickle/train_loss_results.pkl', 'wb') as pkl:
        pickle.dump(train_loss_results, pkl)

    with open('./Data/interpretation/equal_segment/pickle/validation_loss_results.pkl', 'wb') as pkl:
        pickle.dump(validation_loss_results, pkl)

    model.save_weights("./checkpoints/interpretation/equal_segment/model/cp-{}.ckpt".format(format(cnt)))

In [0]:
def format(x, max_len=4):
    x = str(x)
    return "0" * (max_len - len(x)) + x

In [0]:
def STLR(t, cut_frac = 0.1, ratio = 32, lr_max = 0.002, T = 1400000):
    cut = int(T * cut_frac)
    p = t / cut if t < cut else 1 - (t - cut) / (cut * (ratio - 1))
    lr = lr_max * (1 + p * (ratio - 1)) / ratio
    return lr

In [0]:
def decay(lr, t):
    lr -= lr * (1 / (t ** 0.5))
    return lr

In [0]:
def train(model, val_data_len, learning_rate=0.00025, 
          n_epochs=100, n_embd = 768, n_vocab = 40478, n_batch=64, n_ctx=512, n_special = 1, n_segment = 3,
          train_steps=100, validation_steps=5000, save_steps=5000, log_path='train.log', lr_fn = 'STLR', cnt=0):
    
    """
        X : (batch size, seq len, 3 (IDs and positions and segments)) -> tokens
        M1: (batch size, seq len) -> masks for getting 2nd paragraph in input tokens
        M2: (batch size, seq len) -> masks for getting 2nd paragraph in predicted tokens
    """
    
    
    def parse_function(data_record):
        features={
            'triple': tf.VarLenFeature(tf.int64),
            'tokens_mask': tf.VarLenFeature(tf.int64),
            'preds_mask': tf.VarLenFeature(tf.int64),
            'book_id': tf.FixedLenFeature([], tf.int64),
            'counter': tf.FixedLenFeature([], tf.int64),
            'file_id': tf.FixedLenFeature([], tf.int64)
        }
        
        example=tf.parse_single_example(data_record, features)
        tp=example['triple'].values
        tp=tp-1
        tm=example['tokens_mask'].values
        tm=tm-1
        pm=example['preds_mask'].values
        pm=pm-1
        md=(example['book_id'], example['counter'], example['file_id'])
        print(".............")
        print(type(tp), type(tm), type(pm), type(md))
        
        return (tp, tm, pm, md[0], md[1], md[2])
        
    def create_dataset(n_files, n_batch, train=True):
        tfrecord_files=[]
        if train:
            for b in range(n_files-1):
                f_name='/content/drive/My Drive/shared-works/Thread-1/Data/interpretation/equal_segment/'+str(b)+'.tfrecord'
                tfrecord_files.append(f_name)
                
        else:
            f_name='/content/drive/My Drive/shared-works/Thread-1/Data/interpretation/equal_segment/'+str(n_files-1)+'.tfrecord'
            tfrecord_files.append(f_name)

        dataset=tf.data.TFRecordDataset(tfrecord_files)
        dataset=dataset.map(parse_function)
#         dataset=dataset.shuffle(buffer_size=10)
        dataset=dataset.padded_batch(batch_size=n_batch, padded_shapes=((None,),(None,),(None,),(),(),()), drop_remainder=True)
        dataset=dataset.repeat()
        return dataset
    
    tf.reset_default_graph()
    
    t_dataset=create_dataset(7, n_batch)
    t_iterator=t_dataset.make_one_shot_iterator()
    X_t,M1_t,M2_t,md_t0, md_t1, md_t2=t_iterator.get_next()
    
    v_dataset=create_dataset(7, n_batch, train=False)
    v_iterator=v_dataset.make_one_shot_iterator()
    X_v,M1_v,M2_v,md_v0, md_v1, md_v2=v_iterator.get_next()
    
    t_saveable = tf.data.experimental.make_saveable_from_iterator(t_iterator)
    tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, t_saveable)
    
    v_saveable = tf.data.experimental.make_saveable_from_iterator(v_iterator)
    tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, v_saveable)
        
        

    
#     tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, model.variables)
        
    X = tf.placeholder(tf.int32, [None, None, 3])
    M1 = tf.placeholder(tf.int32, [None, None])
    M2 = tf.placeholder(tf.int32, [None, None])
    logits, losses = model([X, M1, M2])
    # Create model's graph

    

    lr = tf.placeholder(tf.float32, shape=[])
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=lr)
    # Define Optimizer
    
    
    real_grads = tf.gradients(losses, model.variables)
    real_grads[0] = tf.convert_to_tensor(real_grads[0])
#     fake_grads = [tf.zeros_like(grad, dtype = tf.float32) for grad in real_grads]
    
#     def set_zero():
#         return real_grads
    
#     def preserve():
#         return fake_grads
    
#     limit = tf.constant(5.64)
#     grads = tf.cond(tf.greater(losses, limit), set_zero, preserve)
    grads_and_vars = zip(real_grads, model.variables)
    capped_grads_and_vars = [(tf.clip_by_norm(grad, 0.25), var) for grad, var in grads_and_vars]
    train_op = optimizer.apply_gradients(capped_grads_and_vars)
    # Create nodes for applying gradients
    

    print(tf.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS))
    print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
    
    
    sess = tf.Session()
    tf.keras.backend.set_session(sess)
    sess.run(tf.global_variables_initializer())
    
    saver=tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS), max_to_keep=0)
    
    if cnt==0:
        
        model.load_weights('./checkpoints/Segment/cp-0001.ckpt')
        train_loss_results=[]
        validation_loss_results=[]
    
    if cnt>0:
        
#         tf.reset_default_graph()
#         imported_graph = tf.train.import_meta_graph('./checkpoints/language_model/my-model0.meta')
#         imported_graph.restore(sess, './checkpoints/language_model/my-model0')
        
        
        print('before')
        saver.restore(sess, "./checkpoints/interpretation/equal_segment/iterator/svble-{}".format(format(cnt)))
        print('after')
        load_path="./checkpoints/interpretation/equal_segment/model/cp-{}.ckpt".format(format(cnt))
        model.load_weights(load_path)
        with open('./Data/interpretation/equal_segment/pickle/train_loss_results.pkl', 'rb') as pkl:
            train_loss_results = pickle.load(pkl)
        with open('./Data/interpretation/equal_segment/pickle/validation_loss_results.pkl', 'rb') as pkl:
            validation_loss_results = pickle.load(pkl)
    

    train_losses = []
    

    

    step=370000
#     step = 510000
#     cnt = 85
    max_loss = 0
#     train_generator = iter_data(n_batch, n_epochs)
    logger = Logger(path=log_path)
    start = time.time()
    
    np.set_printoptions(precision = 4)

    while True:
        step += 1
        if step%100==0:
            print('step:',step)
        
        tokenss, maskss1, maskss2, m1, m2, m3=sess.run([X_t,M1_t,M2_t,md_t0, md_t1, md_t2])
        tokens=np.asarray(tokenss).reshape([n_batch,-1,3])
        masks1=np.asarray(maskss1)
        masks2=np.asarray(maskss2)
#         print('shapes')
#         print(tokens)
#         print(masks1)
#         print(masks2)
#         print(tokens.shape)
#         print(masks1.shape)
#         print(masks2.shape)
        metadata=(m1,m2,m3)
        if lr_fn == 'STLR':
            _, train_loss, gradients = sess.run([train_op, losses, real_grads], {X: tokens, M1: masks1, M2: masks2, lr: STLR(step)})

        else:
            _, train_loss, gradients = sess.run([train_op, losses, real_grads], {X: tokens, M1: masks1, M2: masks2, lr: learning_rate})   
            

        if train_loss > max_loss:
            max_loss = train_loss
            print('\n {} {} , {}\n'.format(step, metadata, np.power(2, train_loss)))
            
        else:
            train_losses.append(train_loss)
            
        if step % 1000 == 1:
            print('\n' + str(np.linalg.norm(gradients[3])) + "  " + str(np.linalg.norm(gradients[0])) + '\n')
        
        if step % train_steps == 0:
            if lr_fn != 'STLR':
                learning_rate = decay(learning_rate, step)

            train_loss_results.append(sum(train_losses) / len(train_losses))
            train_losses = []
            logger.log(step=step, train_loss=train_loss_results[-1], time=time.time() - start)
            print('Step: {} -- Time: {} => ppl: {}'.format(step, int(time.time() - start), np.power(2, train_loss_results[-1])))

        if step % validation_steps == 0:
            print('validation')
#             validation_generator = iter_data(n_batch, train=False)
            validation_losses = []
            for rr in range(val_data_len):
                validation_tokenss, validation_maskss1, validation_maskss2, _,_,_=sess.run([X_v,M1_v,M2_v,md_v0, md_v1, md_v2])
                validation_tokens=np.asarray(validation_tokenss).reshape([n_batch,-1,3])
                validation_masks1=np.asarray(validation_maskss1)
                validation_masks2=np.asarray(validation_maskss2)
                validation_losses.append(sess.run(losses, {X: validation_tokens, M1: validation_masks1, M2: validation_masks2}))

            validation_loss_results.append(sum(validation_losses) / len(validation_losses))
            logger.log(step=step, validation_loss=validation_loss_results[-1], time=time.time() - start)
            print('### Step: {} -- Time: {} => ppl: {}'.format(step, int(time.time() - start), np.power(2, validation_loss_results[-1])))

        if step % save_steps == 0:
            cnt += 1
            print('cnt',cnt)
            save(model, train_loss_results, validation_loss_results, cnt)
            save_path=saver.save(sess, "./checkpoints/interpretation/equal_segment/iterator/svble-{}".format(format(cnt)))
                

In [0]:
model = Transformer("Model", 40478)
val_data_len=500


In [0]:
train(model, val_data_len,n_batch = 2, n_epochs = 1, cnt=74)

W0705 13:58:37.042046 140359178696576 deprecation.py:323] From <ipython-input-8-a5a44df9df52>:56: DatasetV1.make_one_shot_iterator (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `for ... in dataset:` to iterate over a dataset. If using `tf.estimator`, return the `Dataset` object directly from your input function. As a last resort, you can use `tf.compat.v1.data.make_one_shot_iterator(dataset)`.


.............
<class 'tensorflow.python.framework.ops.Tensor'> <class 'tensorflow.python.framework.ops.Tensor'> <class 'tensorflow.python.framework.ops.Tensor'> <class 'tuple'>
.............
<class 'tensorflow.python.framework.ops.Tensor'> <class 'tensorflow.python.framework.ops.Tensor'> <class 'tensorflow.python.framework.ops.Tensor'> <class 'tuple'>


W0705 13:58:38.047585 140359178696576 deprecation_wrapper.py:119] From /content/drive/My Drive/shared-works/Thread-1/model.py:55: The name tf.keras.initializers.random_normal is deprecated. Please use tf.compat.v1.keras.initializers.random_normal instead.

W0705 13:58:38.049215 140359178696576 deprecation.py:506] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/initializers.py:143: calling RandomNormal.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
W0705 13:58:38.434679 140359178696576 deprecation_wrapper.py:119] From /content/drive/My Drive/shared-works/Thread-1/model.py:124: The name tf.matrix_band_part is deprecated. Please use tf.linalg.band_part instead.

W0705 13:58:38.536351 140359178696576 deprecation_wrapper.py:119] From /content/drive/My Drive/shared-works/Thread-1/model.py:36

[<tensorflow.python.data.experimental.ops.iterator_ops._Saveable object at 0x7fa7012d8f28>, <tensorflow.python.data.experimental.ops.iterator_ops._Saveable object at 0x7fa7012e4710>]
[<tf.Variable 'Model/embedding/we:0' shape=(40994, 768) dtype=float32>, <tf.Variable 'Model/h//attn/conv1d/w:0' shape=(1, 768, 2304) dtype=float32>, <tf.Variable 'Model/h//attn/conv1d/b:0' shape=(2304,) dtype=float32>, <tf.Variable 'Model/h//attn/conv1d_1/w:0' shape=(1, 768, 768) dtype=float32>, <tf.Variable 'Model/h//attn/conv1d_1/b:0' shape=(768,) dtype=float32>, <tf.Variable 'Model/h//ln_1/g:0' shape=(768,) dtype=float32>, <tf.Variable 'Model/h//ln_1/b:0' shape=(768,) dtype=float32>, <tf.Variable 'Model/h//mlp/conv1d_2/w:0' shape=(1, 768, 3072) dtype=float32>, <tf.Variable 'Model/h//mlp/conv1d_2/b:0' shape=(3072,) dtype=float32>, <tf.Variable 'Model/h//mlp/conv1d_3/w:0' shape=(1, 3072, 768) dtype=float32>, <tf.Variable 'Model/h//mlp/conv1d_3/b:0' shape=(768,) dtype=float32>, <tf.Variable 'Model/h//ln_2/

W0705 13:58:51.820018 140359178696576 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py:1276: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.


before
after

 370001 (array([ 1257, 10081]), array([436, 266]), array([1, 1])) , 6.468842877371718


1.563  3.1315894


 370002 (array([25762,  2038]), array([1155,  288]), array([1, 1])) , 14.131115044541147


 370003 (array([ 1366, 17870]), array([3204,  338]), array([1, 1])) , 16.299456964696443


 370013 (array([43235,  9785]), array([ 364, 2079]), array([1, 1])) , 17.217200718002687


 370015 (array([ 1921, 26536]), array([1200,   27]), array([1, 1])) , 32.88426943187013

step: 370100
Step: 370100 -- Time: 39 => ppl: 11.032262153247041

 370141 (array([18734, 51294]), array([ 678, 2747]), array([1, 1])) , 47.740846831727936

step: 370200
Step: 370200 -- Time: 68 => ppl: 11.234475561141181
step: 370300
Step: 370300 -- Time: 98 => ppl: 11.689617628251975
step: 370400
Step: 370400 -- Time: 127 => ppl: 11.460643995319549
step: 370500
Step: 370500 -- Time: 157 => ppl: 10.456807423179255
step: 370600
Step: 370600 -- Time: 186 => ppl: 10.262371862681386
step: 370700
Step: 370700 -- Time