In [None]:
!pip install transformers

In [None]:
# Mount Google Drive: the GIDS data and relation files are read from
# /content/drive/MyDrive/GIDS further down.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Authenticate this Colab session with Google Cloud.
# NOTE(review): presumably required for the gs:// checkpoint_path used later — confirm.
from google.colab import auth
auth.authenticate_user()

In [None]:
# Imports: standard library first, then third-party.
import datetime
import json
import os
import pickle as pkl
import warnings
from collections import OrderedDict

import numpy as np
import tensorflow as tf
from tqdm import tqdm

# Silence warnings before importing transformers so warnings raised while it
# is imported are silenced too.
warnings.filterwarnings("ignore")

from transformers import BertConfig, BertTokenizerFast
from transformers import TFBertModel

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

# Sequence configuration: every per-token feature list is padded/truncated to
# this many word-piece positions. MAX_SEQ_LENGTH is the same value under the
# name used by the model-building code.
max_length = 60
MAX_SEQ_LENGTH = max_length
BATCH_SIZE = 32

# Relative distances to an entity head are clipped to this value, so the
# distance embedding needs 2 * max + 1 ids (was assigned twice before).
max_word_arg_head_dist = 30
dist_vocab_size = 2 * max_word_arg_head_dist + 1

ctx_len = 5
word_embed_dim = 50
word_density = 10

# Relation names treated as "no relation" when computing precision/recall/F1.
ignore_rel_list = ['None', 'NA', 'Other']

In [None]:
# Detect a Colab TPU. Any failure (e.g. no TPU attached) is printed and the
# notebook falls back to non-TPU execution. In that case `tpu` is left as None
# so later cells never hit an unbound name.
USE_TPU = True
tpu = None
try:
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
  print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
except Exception as ex:
  print(ex)
  USE_TPU = False
  tpu = None

print("        USE_TPU:", USE_TPU)
print("Eager Execution:", tf.executing_eagerly())

In [None]:
# Build the distribution strategy used by every later cell. Previously this
# ran unconditionally and failed when no TPU was detected. Without a TPU we
# fall back to the default (single-device) strategy, which exposes the same
# scope()/run()/num_replicas_in_sync/dataset-distribution API.
if USE_TPU:
  tf.config.experimental_connect_to_cluster(tpu)
  tf.tpu.experimental.initialize_tpu_system(tpu)
  tpu_strategy = tf.distribute.TPUStrategy(tpu)
else:
  tpu_strategy = tf.distribute.get_strategy()

## Load Pretrained Model

In [None]:
with tpu_strategy.scope():
  # Pretrained uncased BERT, created under the strategy scope so its variables
  # are created by the strategy. output_attentions=True also returns the
  # per-layer attention maps.
  mlm_model = TFBertModel.from_pretrained(
      'bert-base-uncased',
      output_attentions=True,
  )

## Custom Layer Definitions for RExAS

In [None]:
class pool_sum(tf.keras.layers.Layer):
    """Sums its input over one axis (``tf.reduce_sum``).

    Args:
        axis: int, the axis to reduce; negative values count from the end.
    """

    def __init__(self, axis, **kwargs):
        super(pool_sum, self).__init__(**kwargs)
        self.axis = axis

    def call(self, x):
        return tf.reduce_sum(x, axis=self.axis)

    def compute_output_shape(self, input_shape):
        # Normalise negative axes first. The old slicing
        # `shape[:axis] + shape[axis + 1:]` produced a wrong shape for axis=-1.
        dims = tf.TensorShape(input_shape).as_list()
        axis = self.axis % len(dims)
        return tf.TensorShape(dims[:axis] + dims[axis + 1:])

    def get_config(self):
        # Required so models using this layer can be serialised and reloaded.
        config = super(pool_sum, self).get_config()
        config.update({"axis": self.axis})
        return config

In [None]:
class pool_mean(tf.keras.layers.Layer):
    """Averages its input over one axis (``tf.reduce_mean``).

    Args:
        axis: int, the axis to reduce; negative values count from the end.
    """

    def __init__(self, axis, **kwargs):
        super(pool_mean, self).__init__(**kwargs)
        self.axis = axis

    def call(self, x):
        return tf.reduce_mean(x, axis=self.axis)

    def compute_output_shape(self, input_shape):
        # Normalise negative axes first. The old slicing
        # `shape[:axis] + shape[axis + 1:]` produced a wrong shape for axis=-1.
        dims = tf.TensorShape(input_shape).as_list()
        axis = self.axis % len(dims)
        return tf.TensorShape(dims[:axis] + dims[axis + 1:])

    def get_config(self):
        # Required so models using this layer can be serialised and reloaded.
        config = super(pool_mean, self).get_config()
        config.update({"axis": self.axis})
        return config

In [None]:
class stack_layer(tf.keras.layers.Layer):
    """Stacks a list of same-shaped tensors along a new axis (``tf.stack``).

    Args:
        axis: int, position of the new axis in the output; negative values
            count from the end of the output shape.
    """

    def __init__(self, axis, **kwargs):
        super(stack_layer, self).__init__(**kwargs)
        self.axis = axis

    def call(self, x):
        return tf.stack(x, axis=self.axis)

    def compute_output_shape(self, input_shape):
        # tf.stack *inserts* an axis of size len(inputs). The previous
        # implementation removed an axis instead.
        shapes = [tf.TensorShape(s).as_list() for s in input_shape]
        dims = shapes[0]
        axis = self.axis if self.axis >= 0 else self.axis + len(dims) + 1
        return tf.TensorShape(dims[:axis] + [len(shapes)] + dims[axis:])

    def get_config(self):
        # Required so models using this layer can be serialised and reloaded.
        config = super(stack_layer, self).get_config()
        config.update({"axis": self.axis})
        return config

In [None]:
class Multiplication_Layer(tf.keras.layers.Layer):
  """Multiplies a feature tensor by a per-position mask.

  Called on ``[values, mask]``. The mask gets a trailing axis of size 1 so it
  broadcasts across the last dimension of ``values``. In get_model this is a
  ``(batch, seq, dim)`` tensor times a ``(batch, seq)`` entity mask.

  Args:
    activation: accepted for backward compatibility; currently unused.
  """

  def __init__(self, activation=None, **kwargs):
    super(Multiplication_Layer, self).__init__(**kwargs)
    self.activation = activation

  def call(self, x):
    values, mask = x[0], x[1]
    # Build a new tensor rather than assigning into x[1]: the original
    # overwrote an element of the caller's input list in place.
    # Cast so float features can be combined with an integer mask.
    mask = tf.cast(tf.expand_dims(mask, -1), values.dtype)
    return tf.multiply(values, mask)

  def get_config(self):
    config = super(Multiplication_Layer, self).get_config()
    config.update({"activation": self.activation})
    return config

In [None]:
class Reshape_Layer(tf.keras.layers.Layer):
    """Merges a list of attention tensors into one flattened attention map.

    For each tensor in the input list the layer:
    1. averages over axis 1;
    2. flattens everything after the batch axis;
    3. applies a softmax over the flattened positions.
    It then averages the per-tensor results.

    Args:
        axis: stored for backward compatibility; not used by the computation.
    """

    def __init__(self, axis=None, **kwargs):
        super(Reshape_Layer, self).__init__(**kwargs)
        self.axis = axis

    def call(self, x):
        attn_head = []
        for attn in x:
            pool_attn = tf.reduce_mean(attn, axis=1)
            # Flatten the trailing axes. The original hard-coded 60 * 60, which
            # only worked for 60-token sequences. The static size is used when
            # known, so the 60-token case reshapes exactly as before.
            trailing = pool_attn.shape[1:]
            if trailing.is_fully_defined():
                flat_dim = trailing.num_elements()
            else:
                flat_dim = tf.reduce_prod(tf.shape(pool_attn)[1:])
            flat = tf.reshape(pool_attn, shape=[-1, flat_dim])
            attn_head.append(tf.nn.softmax(flat))
        pool_attn_stacked = tf.stack(attn_head, axis=1)
        return tf.reduce_mean(pool_attn_stacked, axis=1)

    def compute_output_shape(self, input_shape):
        # Output is (batch, S1 * S2), where S1 and S2 are the trailing dims of
        # each input. The old slicing with axis=None raised a TypeError.
        first = tf.TensorShape(input_shape[0]).as_list()
        rows, cols = first[-2], first[-1]
        flat = rows * cols if rows is not None and cols is not None else None
        return tf.TensorShape([first[0], flat])

    def get_config(self):
        config = super(Reshape_Layer, self).get_config()
        config.update({"axis": self.axis})
        return config

In [None]:
class Attention_layer_graph(tf.keras.layers.Layer):
  """Multi-head bilinear self-attention producing soft adjacency matrices.

  For each head ``i`` it computes ``X @ W_i @ X^T`` with a learned square
  matrix ``W_i`` of size ``(dim, dim)``. The heads are stacked on axis 1, the
  scores are scaled, and a softmax is taken over the last axis.

  Args:
    head: number of attention heads.
  """

  def __init__(self, head, **kwargs):
    # Forward keyword arguments (name, dtype, ...) to the base Layer. The
    # original accepted them as `**kawargs` and silently dropped them.
    super(Attention_layer_graph, self).__init__(**kwargs)
    self.head = head

  def build(self, input_shape):
    self.bert_weight = self.add_weight(
        shape=(self.head, input_shape[-1], input_shape[-1]),
        initializer="glorot_uniform", name="bert_query", trainable=True)
    super(Attention_layer_graph, self).build(input_shape)

  def call(self, X):
    Y = [tf.matmul(tf.matmul(X, self.bert_weight[i]), X, transpose_b=True)
         for i in range(self.head)]
    Y = tf.stack(Y, axis=1)
    # NOTE(review): scores are divided by the square root of the last axis of
    # Y (the sequence length), not of the feature size as in standard scaled
    # dot-product attention. Kept as-is to preserve model behaviour.
    Y = Y / tf.sqrt(tf.cast(tf.shape(Y), tf.float32)[-1])
    return tf.nn.softmax(Y)

  def get_config(self):
    config = super(Attention_layer_graph, self).get_config()
    config.update({"head": self.head})
    return config

In [None]:
class Normal_adjacency(tf.keras.layers.Layer):
  """Degree-normalises adjacency matrices after adding self-loops.

  With ``B = A + I`` over the last two axes, returns ``D1 @ B @ D2``, where
  ``D1 = diag(colsum(B))^-1/2`` and ``D2 = diag(rowsum(B))^-1/2``. A small
  epsilon guards against zero degrees.
  """

  def __init__(self, **kwargs):
    # Forward name/dtype/... to the base Layer (previously dropped).
    super(Normal_adjacency, self).__init__(**kwargs)

  def call(self, A):
    n = A.get_shape().as_list()[-1]
    A = A + tf.eye(n, dtype=A.dtype)
    eps = tf.keras.backend.epsilon()
    d1_inv = tf.pow(tf.reduce_sum(A, axis=-2) + eps, -0.5)  # column sums
    d2_inv = tf.pow(tf.reduce_sum(A, axis=-1) + eps, -0.5)  # row sums
    # Row i is scaled by d1_inv[i] and column j by d2_inv[j], i.e. D1 @ A @ D2,
    # using broadcasting instead of materialising diagonal matrices.
    # The original second matmul used transpose_a=True, which transposed
    # D1 @ A and returned A^T D1 D2 rather than a degree normalisation.
    return d1_inv[..., :, tf.newaxis] * A * d2_inv[..., tf.newaxis, :]

  def compute_mask(self, inputs, mask=None):
    return mask

  def get_config(self):
    config = super().get_config().copy()
    return config

In [None]:
class Graph_Layer(tf.keras.layers.Layer):
  """Multi-head graph convolution whose head outputs are summed.

  Called on ``[adjacency, features]``, with adjacency ``(batch, heads, n, n)``
  and features ``(batch, n, dim)``. For each head ``h`` it computes
  ``relu(A_h @ X @ W_h)``, where the per-head kernel ``W_h`` has shape
  ``(dim, output_dim)``. The heads are then summed, giving
  ``(batch, n, output_dim)``.

  Args:
    output_dim: size of the output feature dimension.
    feature_regularizer: optional regularizer (instance, name or config) for
      the kernel.
  """

  def __init__(self, output_dim, feature_regularizer=None, **kwargs):
    super(Graph_Layer, self).__init__(**kwargs)
    self.output_dim = output_dim
    # Accept identifiers such as 'l2' as well as instances.
    self.feature_regularizer = tf.keras.regularizers.get(feature_regularizer)

  def build(self, input_shape):
    self.fkernel = self.add_weight(
        name='feature_kernel',
        shape=(input_shape[0][1], input_shape[1][-1], self.output_dim),
        initializer='glorot_uniform', regularizer=self.feature_regularizer,
        trainable=True)
    super(Graph_Layer, self).build(input_shape)

  def call(self, x):
    adjacency, features = x[0], x[1]
    # Neighbourhood aggregation per head: (b,h,n,n) x (b,n,d) -> (b,h,n,d).
    # This is the same contraction as the former batch_dot(axes=[-1, 1]).
    aggregated = tf.einsum('bhij,bjd->bhid', adjacency, features)
    # Per-head projection, vectorised instead of a Python loop over heads.
    projected = tf.einsum('bhid,hdo->bhio', aggregated, self.fkernel)
    return tf.reduce_sum(tf.nn.relu(projected), axis=1)

  def get_config(self):
    config = super(Graph_Layer, self).get_config()
    config.update({
        "output_dim": self.output_dim,
        "feature_regularizer": tf.keras.regularizers.serialize(self.feature_regularizer),
    })
    return config

In [None]:
def get_model(max_seq_length, num_classes=5):
  """Build the RExAS relation-classification model (Keras functional API).

  Inputs, each of length MAX_SEQ_LENGTH:
  - BERT word ids, token type ids and attention mask;
  - an entity-indicator id per token;
  - left/right entity distance ids;
  - left/right entity 0/1 masks.

  Args:
    max_seq_length: NOTE(review) not used; inputs are sized by the global
      MAX_SEQ_LENGTH. Callers may pass a value (e.g. 100) that differs from
      the padded data length, so honouring it would break training.
    num_classes: width of the softmax output (previously hard-coded to 5).

  Returns:
    An uncompiled tf.keras.Model producing per-relation probabilities.
  """
  input_word_ids = tf.keras.layers.Input(shape=(MAX_SEQ_LENGTH,), dtype=tf.int32, name="input_word_ids")
  token_type_ids = tf.keras.layers.Input(shape=(MAX_SEQ_LENGTH,), dtype=tf.int32, name="token_type_ids")
  attention_mask = tf.keras.layers.Input(shape=(MAX_SEQ_LENGTH,), dtype=tf.int32, name="attention_mask")
  input_entity_indicator = tf.keras.layers.Input(shape=(MAX_SEQ_LENGTH,), name="entity_indicator")
  input_entity_left_dist = tf.keras.layers.Input(shape=(MAX_SEQ_LENGTH,), name="entity_left_dist")
  input_entity_right_dist = tf.keras.layers.Input(shape=(MAX_SEQ_LENGTH,), name="entity_right_dist")
  input_entity_left_mask = tf.keras.layers.Input(shape=(MAX_SEQ_LENGTH,), name="entity_left_mask")
  input_entity_right_mask = tf.keras.layers.Input(shape=(MAX_SEQ_LENGTH,), name="entity_right_mask")

  # Entity-indicator ids (4-id vocabulary) -> 10-d embeddings.
  entity_indicator = tf.keras.layers.Embedding(output_dim=10, input_dim=4, input_length=MAX_SEQ_LENGTH, trainable=True)(input_entity_indicator)
  entity_indicator = tf.keras.layers.Dropout(.5)(entity_indicator)

  # Relative-distance ids to each entity head -> 5-d embeddings.
  entity_left_dist_embed = tf.keras.layers.Embedding(output_dim=5, input_dim=dist_vocab_size, input_length=MAX_SEQ_LENGTH, trainable=True)(input_entity_left_dist)
  entity_left_dist_embed = tf.keras.layers.Dropout(.5)(entity_left_dist_embed)
  entity_right_dist_embed = tf.keras.layers.Embedding(output_dim=5, input_dim=dist_vocab_size, input_length=MAX_SEQ_LENGTH, trainable=True)(input_entity_right_dist)
  entity_right_dist_embed = tf.keras.layers.Dropout(.5)(entity_right_dist_embed)

  # Freeze the first layer of the global pretrained BERT model (presumably its
  # main encoder layer), so the pretrained weights are not fine-tuned.
  mlm_model.layers[0].trainable = False
  sequence_output = mlm_model([input_word_ids, token_type_ids, attention_mask])
  # sequence_output[0]: presumably the last hidden states (batch, seq, hidden).
  embed_concat = tf.keras.layers.Concatenate()([sequence_output[0], entity_indicator, entity_left_dist_embed, entity_right_dist_embed])
  embed_concat = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64, return_sequences=True))(embed_concat)

  # NOTE(review): the graph branch (adjacency -> two Graph_Layers) and the
  # entity-masking branch below are built but NOT connected to the output.
  # The classifier pools `embed_concat` directly, so these layers are not part
  # of the returned Model and are never trained. Confirm whether the pooling
  # was meant to consume `graph_out_2` or `concatenated_output` instead.
  adjacency_ = Attention_layer_graph(4)(embed_concat)
  normalized_adjacency_ = Normal_adjacency()(adjacency_)
  graph_out_1 = Graph_Layer(256)([normalized_adjacency_, embed_concat])
  graph_out_2 = Graph_Layer(512)([normalized_adjacency_, graph_out_1])
  entity_1 = Multiplication_Layer()([embed_concat, input_entity_left_mask])
  entity_2 = Multiplication_Layer()([embed_concat, input_entity_right_mask])
  concatenated_output = tf.keras.layers.Concatenate()([entity_1, entity_2])

  # Sum over the sequence axis, then an MLP classifier head.
  output_1 = pool_sum(axis=1)(embed_concat)
  output_2 = tf.keras.layers.Dense(512, activation="relu")(output_1)
  output_f = tf.keras.layers.Dropout(0.5)(output_2)
  relation_probs = tf.keras.layers.Dense(num_classes, activation="softmax")(output_f)

  model = tf.keras.Model(
      inputs=[input_word_ids, token_type_ids, attention_mask,
              input_entity_indicator, input_entity_left_dist, input_entity_right_dist,
              input_entity_left_mask, input_entity_right_mask],
      outputs=relation_probs)
  return model

In [None]:
with tpu_strategy.scope():
  # get_model sizes its inputs from MAX_SEQ_LENGTH and ignores this argument,
  # so pass the real length instead of the misleading 100.
  model = get_model(MAX_SEQ_LENGTH)
  optimizer = tf.keras.optimizers.Adam()
  # Epoch-level running loss and categorical accuracy, updated on each replica
  # inside train_step.
  training_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
  training_accuracy = tf.keras.metrics.CategoricalAccuracy('training_accuracy', dtype=tf.float32)

## Data Preparation

In [None]:
def contains(sub, pri):
    """Find the first occurrence of the sequence ``sub`` inside ``pri``.

    Args:
        sub: sequence to look for (e.g. the word pieces of an entity).
        pri: sequence to search in (e.g. the word pieces of a sentence).

    Returns:
        ``[start, end]`` with inclusive indices of the first match, or
        ``False`` if ``sub`` is empty, longer than ``pri``, or absent.
    """
    n = len(sub)
    if n == 0 or n > len(pri):
        # An empty `sub` previously raised IndexError on sub[0]. A `sub`
        # longer than `pri` made the search bound negative, which list.index
        # interprets as counting from the end.
        return False
    last = len(pri) - n + 1  # last start position where a match could fit
    start = 0
    while True:
        try:
            found = pri.index(sub[0], start, last)  # next candidate first element
        except ValueError:
            return False
        if pri[found:found + n] == sub:
            return [found, found + n - 1]
        start = found + 1

In [None]:
print("-------------Data Preprocessing Started-------------")
# Read the JSON-lines splits as raw bytes, one record per line. The `with`
# blocks close the files; the original left all four handles open.
with open("/content/drive/MyDrive/GIDS/gids_data/gids_train.json", "rb") as fl_train:
  train_data = fl_train.readlines()
with open("/content/drive/MyDrive/GIDS/gids_data/gids_dev.json", "rb") as fl_valid:
  valid_data = fl_valid.readlines()
with open("/content/drive/MyDrive/GIDS/gids_data/gids_test.json", "rb") as fl_test:
  test_data = fl_test.readlines()

# Relation name -> class index: the relation on line i gets index i.
dict_rel = {}
with open("/content/drive/MyDrive/GIDS/Resource/relations_gids.txt", "r") as fl_rel:
  for i, j in enumerate(fl_rel):
    dict_rel[j.strip()] = i
relation_out = len(dict_rel)  # number of relation classes

configuration = BertConfig(num_labels=relation_out)


def bytes_to_list(fl):
  """Decode each JSON record in ``fl`` into a Python object.

  Args:
    fl: iterable of JSON strings or bytes, one record per item.

  Returns:
    list of decoded records, in input order.
  """
  return [json.loads(line) for line in fl]

In [None]:
def get_entity_indicator(data):
  """Build entity masks/indicators, sentence texts and labels for a split.

  The lower-cased sentence and both entity mentions are word-piece tokenized.
  Records whose subject or object cannot be located in the sentence tokens
  are skipped.

  Args:
    data: list of dicts with keys 'sent' (list of words), 'sub', 'obj' and
      'rel'.

  Returns:
    Tuple of parallel lists, one entry per kept record:
      entity_indicator_list: max_length ids; 1 = subject token, 2 = object
        token, 3 = other. The object wins where spans overlap.
      entity_left_mask_matrix: 0/1 mask of the subject tokens.
      entity_right_mask_matrix: 0/1 mask of the object tokens.
      data_text: the lower-cased sentence strings.
      rel_out: relation indices from dict_rel; unknown relations map to 'NA'.
      data_new: the kept records themselves.
  """
  entity_left_mask_matrix = []
  entity_right_mask_matrix = []
  entity_indicator_list = []
  data_text = []
  rel_out = []
  data_new = []
  for record in tqdm(data):
    sent = ' '.join(record['sent']).lower()
    # add_special_tokens=True is meant to keep positions aligned with the
    # ids from tokenizer(...) later (leading [CLS]) — verify with the tokenizer.
    words = tokenizer.tokenize(sent, add_special_tokens=True)
    sub_span = contains(tokenizer.tokenize(record['sub'].lower()), words)
    obj_span = contains(tokenizer.tokenize(record['obj'].lower()), words)
    if isinstance(sub_span, bool) or isinstance(obj_span, bool):
      continue  # an entity was not found in the sentence tokens
    arg1_mask = [0] * max_length
    arg2_mask = [0] * max_length
    ent_ind = [3] * max_length
    # Positions beyond max_length are truncated away.
    for k in range(sub_span[0], min(sub_span[1] + 1, max_length)):
      arg1_mask[k] = 1
      ent_ind[k] = 1
    for k in range(obj_span[0], min(obj_span[1] + 1, max_length)):
      arg2_mask[k] = 1
      ent_ind[k] = 2
    data_new.append(record)
    entity_left_mask_matrix.append(arg1_mask)
    entity_right_mask_matrix.append(arg2_mask)
    entity_indicator_list.append(ent_ind)
    data_text.append(sent)
    rel_name = record['rel']
    if rel_name in dict_rel:
      rel_out.append(dict_rel[rel_name])
    else:
      rel_out.append(dict_rel['NA'])
  return entity_indicator_list, entity_left_mask_matrix, entity_right_mask_matrix, data_text, rel_out, data_new

In [None]:
def _relative_head_dist(anchor, ind):
  """Map the offset of token `ind` from entity head `anchor` to a distance id.

  - The head itself maps to 1.
  - Tokens before the head map to 2..max_word_arg_head_dist (clipped).
  - Tokens after the head map to max_word_arg_head_dist + 1 ..
    2 * max_word_arg_head_dist (clipped).
  """
  offset = anchor - ind
  if offset >= 0:
    return min(offset + 1, max_word_arg_head_dist)
  return min(-offset, max_word_arg_head_dist) + max_word_arg_head_dist


def get_entity_dist(data):
  """Build per-token distance ids to the subject and object heads.

  Args:
    data: list of dicts with keys 'sent' (list of words), 'sub' and 'obj'.

  Returns:
    (entity_left_dist_matrix, entity_right_dist_matrix): one list of
    max_length ids per kept record. Tokens inside an entity span get id 1,
    and padding positions stay 0. Records whose entities are not found in the
    sentence tokens are skipped.
  """
  entity_left_dist_matrix = []
  entity_right_dist_matrix = []
  for record in tqdm(data):
    # Lower-case exactly like get_entity_indicator so both functions skip the
    # same records and their outputs stay row-aligned.
    sent = ' '.join(record['sent']).lower()
    words = tokenizer.tokenize(sent, add_special_tokens=True)
    sub_span = contains(tokenizer.tokenize(record['sub'].lower()), words)
    obj_span = contains(tokenizer.tokenize(record['obj'].lower()), words)
    if isinstance(sub_span, bool) or isinstance(obj_span, bool):
      continue
    arg1_head_dist_lst = [0] * max_length
    arg2_head_dist_lst = [0] * max_length
    for ind in range(min(len(words), max_length)):
      arg1_head_dist_lst[ind] = _relative_head_dist(sub_span[0], ind)
      arg2_head_dist_lst[ind] = _relative_head_dist(obj_span[0], ind)
    for ind in range(sub_span[0], min(sub_span[1] + 1, max_length)):
      arg1_head_dist_lst[ind] = 1
    for ind in range(obj_span[0], min(obj_span[1] + 1, max_length)):
      arg2_head_dist_lst[ind] = 1
    entity_left_dist_matrix.append(arg1_head_dist_lst)
    entity_right_dist_matrix.append(arg2_head_dist_lst)
  return entity_left_dist_matrix, entity_right_dist_matrix

In [None]:
# Decode every raw JSON-lines split into a list of record dicts.
training_input, validation_input, test_input = (
    bytes_to_list(split) for split in (train_data, valid_data, test_data))

In [None]:
def get_ragged_tensor_representation(encodings,entity_indicator,left_entity_dist,right_entity_dist,left_entity,right_entity):
  """Convert tokenizer encodings and per-token feature lists to int32 tensors.

  Despite the name, every result is a dense tensor. ``encodings`` is updated
  in place: its 'input_ids', 'token_type_ids' and 'attention_mask' entries
  are replaced by tensors, and it is returned as the first element.
  """
  def as_int32(values):
    return tf.dtypes.cast(tf.constant(values), tf.int32)

  for key in ('input_ids', 'token_type_ids', 'attention_mask'):
    encodings[key] = as_int32(encodings[key])
  features = (entity_indicator, left_entity_dist, right_entity_dist, left_entity, right_entity)
  return (encodings,) + tuple(as_int32(feature) for feature in features)

def dense_tensor_creation(data,entity_indicator,left_entity_dist,right_entity_dist,left_entity,right_entity,y):
  """Regroup one dataset element into the Keras ``(inputs, label)`` form.

  The inputs tuple follows the input order of the model built by get_model:
  word ids, token type ids, attention mask, entity indicator, left/right
  distance ids, left/right entity masks.
  """
  bert_inputs = (data['input_ids'], data['token_type_ids'], data['attention_mask'])
  entity_features = (entity_indicator, left_entity_dist, right_entity_dist, left_entity, right_entity)
  return bert_inputs + entity_features, y

In [None]:
# Per-split entity features. get_entity_indicator returns the records it kept
# (subject and object found in the sentence tokens). get_entity_dist applies
# the same check, so their outputs are expected to be row-aligned.
training_indicator,left_training_entity,right_training_entity,training_input_text,training_rel_input,training_input_=get_entity_indicator(training_input)
left_training_entity_dist,right_training_entity_dist=get_entity_dist(training_input)

validation_indicator,left_validation_entity,right_validation_entity,validation_input_text,validation_rel_input,validation_input_=get_entity_indicator(validation_input)
left_validation_entity_dist,right_validation_entity_dist=get_entity_dist(validation_input)

test_indicator,left_test_entity,right_test_entity,test_input_text,test_rel_input,test_input_=get_entity_indicator(test_input)
left_test_entity_dist,right_test_entity_dist=get_entity_dist(test_input)

# BERT inputs are padded/truncated to max_length, the same length as the
# per-token entity features (previously a separate literal 60).
training_encodings = tokenizer(training_input_text,truncation=True,max_length=max_length,padding='max_length')
training_output=tf.keras.utils.to_categorical(training_rel_input,num_classes=relation_out)

validation_encodings = tokenizer(validation_input_text,truncation=True,max_length=max_length,padding='max_length')
validation_output=tf.keras.utils.to_categorical(validation_rel_input,num_classes=relation_out)

test_encodings = tokenizer(test_input_text,truncation=True,max_length=max_length,padding='max_length')

# Convert everything to int32 tensors.
training_input_bert,training_indicator_,left_training_entity_dist_,right_training_entity_dist_,left_training_entity_,right_training_entity_=get_ragged_tensor_representation(training_encodings,training_indicator,left_training_entity_dist,right_training_entity_dist,left_training_entity,right_training_entity)
validation_input_bert,validation_indicator_,left_validation_entity_dist_,right_validation_entity_dist_,left_validation_entity_,right_validation_entity_=get_ragged_tensor_representation(validation_encodings,validation_indicator,left_validation_entity_dist,right_validation_entity_dist,left_validation_entity,right_validation_entity)
test_input_bert,test_indicator_,left_test_entity_dist_,right_test_entity_dist_,left_test_entity_,right_test_entity_=get_ragged_tensor_representation(test_encodings,test_indicator,left_test_entity_dist,right_test_entity_dist,left_test_entity,right_test_entity)

In [None]:
test_output = tf.keras.utils.to_categorical(test_rel_input, num_classes=relation_out)

# One tf.data element per kept record, in this order: encodings dict, entity
# indicator, left/right distance ids, left/right entity masks, one-hot label.
train_dataset = tf.data.Dataset.from_tensor_slices((
    training_input_bert, training_indicator_,
    left_training_entity_dist_, right_training_entity_dist_,
    left_training_entity_, right_training_entity_,
    training_output))
valid_dataset = tf.data.Dataset.from_tensor_slices((
    validation_input_bert, validation_indicator_,
    left_validation_entity_dist_, right_validation_entity_dist_,
    left_validation_entity_, right_validation_entity_,
    validation_output))
test_dataset = tf.data.Dataset.from_tensor_slices((
    test_input_bert, test_indicator_,
    left_test_entity_dist_, right_test_entity_dist_,
    left_test_entity_, right_test_entity_,
    test_output))
print("----------Data Preprocessing Completed-------------")

In [None]:
# From here on, each split refers to the records actually fed to the model
# (those kept by get_entity_indicator). get_F1 pairs these with predictions by
# position.
training_input, validation_input, test_input = training_input_, validation_input_, test_input_

In [None]:
def make_batches(ds,batch_size):
  """Training pipeline for `ds`.

  Steps, in order: cache, regroup into (inputs, label), shuffle with a
  1000-element buffer, repeat indefinitely, batch, prefetch.
  """
  pipeline = ds.cache()
  pipeline = pipeline.map(dense_tensor_creation, num_parallel_calls=tf.data.AUTOTUNE)
  pipeline = pipeline.shuffle(1000).repeat()
  return pipeline.batch(batch_size).prefetch(tf.data.AUTOTUNE)

In [None]:
def make_batches_test(ds,batch_size):
  """Evaluation pipeline for `ds`.

  Steps, in order: cache, regroup into (inputs, label), batch, prefetch.
  There is no shuffling and no repeat, so batches follow record order.
  """
  pipeline = ds.cache()
  pipeline = pipeline.map(dense_tensor_creation, num_parallel_calls=tf.data.AUTOTUNE)
  return pipeline.batch(batch_size).prefetch(tf.data.AUTOTUNE)

In [None]:
# Batch size seen by each replica. The global batch size is derived from the
# replica count; it was hard-coded to 32, which is only correct for 8
# replicas, and it scales the loss in train_step.
batch_size = 4
per_replica_batch_size = batch_size
global_batch_size = per_replica_batch_size * tpu_strategy.num_replicas_in_sync

In [None]:
# Build per-replica-batched pipelines. Each lambda receives a
# tf.distribute.InputContext, which is ignored.
# NOTE(review): with several input pipelines (multi-host), every pipeline
# would read the full dataset; shard on the context if that setup is used.
# `distribute_datasets_from_function` replaces the deprecated
# `experimental_distribute_datasets_from_function`.
train_data = tpu_strategy.distribute_datasets_from_function(
    lambda _: make_batches(train_dataset, per_replica_batch_size))
valid_data = tpu_strategy.distribute_datasets_from_function(
    lambda _: make_batches_test(valid_dataset, per_replica_batch_size))
test_data = tpu_strategy.distribute_datasets_from_function(
    lambda _: make_batches_test(test_dataset, per_replica_batch_size))

## GCS Bucket Path to Save the Model

In [None]:
# Directory used by tf.train.CheckpointManager below (keeps the 2 latest saves).
# NOTE(review): presumably a gs:// bucket because TPU runtimes cannot write to
# the Colab VM's local disk — confirm.
checkpoint_path="gs://nlp_4/RExAS_GIDs/"

## Training the RExAS Model


In [None]:
@tf.function
def train_step(iterator):
  """Run one distributed optimisation step on the next batch from `iterator`."""

  def step_fn(inputs):
    """Per-replica forward/backward pass, weight update and metric update."""
    features, labels = inputs
    with tf.GradientTape() as tape:
      # The model ends in a softmax, so these are probabilities, not logits.
      probs = model(features, training=True)
      per_example_loss = tf.keras.losses.categorical_crossentropy(labels, probs)
      # Average over the *global* batch so the cross-replica gradient sum
      # corresponds to the mean loss over the whole batch.
      loss = tf.nn.compute_average_loss(per_example_loss, global_batch_size=global_batch_size)
    variables = model.trainable_variables
    grads = tape.gradient(loss, variables)
    optimizer.apply_gradients(list(zip(grads, variables)))
    training_loss.update_state(loss * tpu_strategy.num_replicas_in_sync)
    training_accuracy.update_state(labels, probs)

  tpu_strategy.run(step_fn, args=(next(iterator),))

# Metrics updated by valid_step. They were not created anywhere else in the
# notebook, so calling valid_step raised a NameError. They are created under
# the strategy scope because they are updated inside tpu_strategy.run.
with tpu_strategy.scope():
  validation_loss = tf.keras.metrics.Mean('validation_loss', dtype=tf.float32)
  validation_accuracy = tf.keras.metrics.CategoricalAccuracy('validation_accuracy', dtype=tf.float32)

@tf.function
def valid_step(iterator):
  """Evaluate one distributed batch from `iterator` (no weight update)."""

  def step_fn(inputs):
    """Per-replica forward pass that accumulates validation loss/accuracy."""
    features, labels = inputs
    probs = model(features, training=False)
    loss = tf.keras.losses.categorical_crossentropy(labels, probs)
    loss = tf.nn.compute_average_loss(loss, global_batch_size=global_batch_size)
    validation_loss.update_state(loss * tpu_strategy.num_replicas_in_sync)
    validation_accuracy.update_state(labels, probs)

  tpu_strategy.run(step_fn, args=(next(iterator),))

In [None]:
def get_F1(data, preds, th=0.0, rel_index=None, ignore_rels=None):
    """Count positives for micro precision/recall/F1, ignoring "no relation".

    Args:
        data: records with a 'rel' key (gold relation name), aligned with
            ``preds`` by position.
        preds: per-record class-score vectors; argmax is the predicted class.
        th: a prediction counts as positive only if its max score exceeds
            this threshold.
        rel_index: relation name -> class index; defaults to global dict_rel.
        ignore_rels: relation names treated as negatives; defaults to global
            ignore_rel_list.

    Returns:
        (pred_pos, gt_pos, correct_pos): predicted positives, gold positives,
        and predicted positives that match the gold relation.
    """
    if rel_index is None:
        rel_index = dict_rel
    if ignore_rels is None:
        ignore_rels = ignore_rel_list
    # Invert the mapping once. The original rebuilt list(dict_rel) for every
    # record and relied on key order matching the stored indices.
    index_to_rel = {idx: name for name, idx in rel_index.items()}
    pred_pos = gt_pos = correct_pos = 0
    for i in range(len(data)):
        org_rel_name = data[i]['rel']
        scores = preds[i]
        pred_rel_name = index_to_rel[int(np.argmax(scores))]
        if org_rel_name not in ignore_rels:
            gt_pos += 1
        if pred_rel_name not in ignore_rels and np.max(scores) > th:
            pred_pos += 1
            if org_rel_name == pred_rel_name:
                correct_pos += 1
    return pred_pos, gt_pos, correct_pos

In [None]:
# Global batches per epoch so each training record is seen about once, using
# ceiling division. The old `int(n / 32) + 1` hard-coded the global batch size
# and added an extra step whenever n was a multiple of 32.
steps_per_epoch = -(-len(training_input) // global_batch_size)

In [None]:
# Checkpoint model + optimizer under checkpoint_path, keeping the two most
# recent saves, and resume from the latest save if one exists.
with tpu_strategy.scope():
  ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
  ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=2)
  if not ckpt_manager.latest_checkpoint:
    print('Training from scratch!')
  else:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print('Latest checkpoint restored; Model was trained for {} steps.'.format(
        ckpt.optimizer.iterations.numpy()))

Training from scratch!


In [None]:
def _precision_recall_f1(pred_pos, gt_pos, correct_pos):
  """Micro precision, recall and F1 from get_F1 counts; 1e-8 avoids /0."""
  p = float(correct_pos) / (pred_pos + 1e-8)
  r = float(correct_pos) / (gt_pos + 1e-8)
  return p, r, (2 * p * r) / (p + r + 1e-8)


def _eval_steps(records):
  """Global batches needed to cover `records` exactly once (ceiling division)."""
  return -(-len(records) // global_batch_size)


steps_per_eval = 10000 // batch_size  # NOTE(review): currently unused
best_dev_acc = -1.0  # best validation F1 so far; drives checkpointing
count = 0            # epochs since the validation F1 last improved
test_f1_score = []
train_iterator = iter(train_data)
valid_iterator = iter(valid_data)  # NOTE(review): unused; evaluation uses model.predict
test_iterator = iter(test_data)    # NOTE(review): unused; evaluation uses model.predict
for epoch in range(80):
  print('Epoch: {}/80'.format(epoch+1))
  start_time = datetime.datetime.now()
  for step in range(steps_per_epoch):
    train_step(train_iterator)
  print('Current step: {}, training loss: {}, accuracy: {}%'.format(
      optimizer.iterations.numpy(),
      round(float(training_loss.result()), 4),
      round(float(training_accuracy.result()) * 100, 2)))
  training_loss.reset_states()
  training_accuracy.reset_states()
  end_time = datetime.datetime.now()
  print("Time taken for training is {}s\n".format((end_time-start_time).total_seconds()))

  print("------Validation DataSet Performance------")
  prediction_valid = model.predict(valid_data, steps=_eval_steps(validation_input), verbose=1)
  pred_pos, gt_pos, correct_pos = get_F1(validation_input, prediction_valid)
  p, r, dev_acc = _precision_recall_f1(pred_pos, gt_pos, correct_pos)
  print("Now Validation Precision is {}, Recall is {},  F1-Score is  {}".format(round(p,4),round(r,4),round(dev_acc,4)))

  print("------Test DataSet Performance------")
  prediction = model.predict(test_data, steps=_eval_steps(test_input), verbose=1)
  pred_pos, gt_pos, correct_pos = get_F1(test_input, prediction)
  p, r, test_acc = _precision_recall_f1(pred_pos, gt_pos, correct_pos)
  print("Now test Precision is {}, Recall is {},  F1-Score is  {}".format(round(p,4),round(r,4),round(test_acc,4)))

  # Checkpoint selection and early stopping use the *validation* F1. The
  # original compared test_acc here, which selects the model on the test set
  # and makes the reported test F1 optimistically biased.
  if dev_acc > best_dev_acc:
    ckpt_save_path = ckpt_manager.save()
    print("Saved checkpoint for step {}: {}".format(int(ckpt.optimizer.iterations.numpy()), ckpt_save_path))
    best_dev_acc = dev_acc
    count = 0
  else:
    count += 1
  test_f1_score.append(test_acc)
  if count == 10:
    print("Last 10 epochs F1-score didn't improve for validation dataset Training completed")
    break