In [1]:
import os
import json
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.utils import class_weight
import tensorflow as tf
from transformers import TFBertModel, BertTokenizer, BertConfig
import pickle
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import  Input, Embedding, LSTM, Dense, Bidirectional, Dropout, MultiHeadAttention
from tensorflow.keras.metrics import Precision, Recall, BinaryAccuracy

In [2]:
# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        # Memory growth can only be configured before the GPU runtime is
        # initialised (e.g. this cell re-run after a model was built); report
        # it instead of aborting the notebook.
        print(f"Could not enable memory growth on {gpu.name}: {err}")

## Data Preprocessing

In [3]:
# Input JSON Lines files (concatenated by load_data), the checkpoint directory
# used by train_model, and the directory for cached tensors / saved models.
training_data_path = [
    "dataset/train1.jsonl",
    "dataset/train2.jsonl",
]
checkpoint_dir = "logs"
output_dir = "output/"  # keep the trailing slash: later cells build paths by string concatenation

In [4]:
# Data loading

def load_data(file_path):
    """Read one or more JSON Lines files into a single list of records.

    Args:
        file_path: A single path (str, bytes or os.PathLike) or an iterable of
            paths. A lone path is wrapped in a list; previously a string was
            iterated character by character.

    Returns:
        list: One parsed JSON value per non-blank line, files read in order.
    """
    if isinstance(file_path, (str, bytes, os.PathLike)):
        file_path = [file_path]

    data = []
    for path in file_path:
        with open(path, 'r', encoding='utf-8') as f:
            # Skip blank lines (e.g. extra empty lines at the end of a file)
            # rather than letting json.loads fail on an empty string.
            data.extend(json.loads(line) for line in f if line.strip())
    return data

In [5]:
# Sequence length that preprocess_data truncates / pads every tokenized
# sentence to (and re-pads to after removing [MASK] tokens).
MAX_LENGTH = 60


In [6]:
def clean_data(data):
    """Split records into sentences and keep only usable, placeholder-bearing ones.

    Each record's 'stripped_sentence' text is split on '. '; empty fragments
    are dropped and every fragment except the last gets its period back.
    A sentence is kept when it has 5 to 50 whitespace-separated words,
    contains a '<X>' placeholder (alone or as '~<X>' / '(<X>)'), and does not
    contain the two-spaces-plus-two-backslashes sequence.

    Args:
        data: Iterable of dicts with a 'stripped_sentence' string field.

    Returns:
        list[str]: The surviving sentences, in input order.
    """
    markers = ('~<X>', '(<X>)', '<X>')
    excluded_sequence = "  \\\\"

    def split_into_sentences(text):
        pieces = [piece.strip() for piece in text.split('. ')]
        pieces = [piece for piece in pieces if piece]
        # Restore the period consumed by the split on all but the final piece.
        return [piece + '.' for piece in pieces[:-1]] + pieces[-1:]

    def is_valid(sentence):
        word_count = len(sentence.split())
        if not 4 < word_count <= 50:
            return False
        if not any(marker in sentence for marker in markers):
            return False
        return excluded_sequence not in sentence

    candidates = []
    for record in data:
        candidates.extend(split_into_sentences(record['stripped_sentence']))
    return [sentence for sentence in candidates if is_valid(sentence)]


In [7]:
# Data Preprocessing

def preprocess_data(data, tokenizer=None, max_length=None):
    """Tokenize placeholder sentences and build per-token insertion labels.

    After lower-casing, every '~<x>', '(<x>)' and '<x>' placeholder is
    replaced by '[MASK]'. The sentence is encoded, the [MASK] tokens are
    removed from the encoded row, and the token immediately before each
    removed [MASK] is labelled 1 (all other positions 0). Rows are
    right-padded with zeros back to ``max_length`` after the removal.

    Args:
        data: Iterable of sentence strings (e.g. the output of clean_data).
        tokenizer: Optional pre-loaded BertTokenizer; when omitted the
            "bert-base-uncased" tokenizer is loaded for this call.
        max_length: Sequence length to truncate/pad to; defaults to the
            notebook-level MAX_LENGTH, read at call time.

    Returns:
        tuple(tf.Tensor, tf.Tensor, tf.Tensor): input_ids, attention_masks and
        target_tags, each of shape (num_sentences, max_length). An empty
        ``data`` yields three (0, max_length) tensors instead of raising.
    """
    if tokenizer is None:
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    if max_length is None:
        max_length = MAX_LENGTH

    # Loop-invariant: look the [MASK] id up once instead of once per token.
    mask_token_id = tokenizer.convert_tokens_to_ids('[MASK]')

    def _drop_and_pad(row, keep):
        # Remove the [MASK] positions, then right-pad with 0 to restore the length.
        kept = row[keep]
        return np.pad(kept, (0, max_length - len(kept)), mode='constant')

    input_ids, attention_masks, target_tags = [], [], []

    for sentence in data:
        masked_sentence = (sentence.lower()
                           .replace('~<x>', '[MASK]')
                           .replace('(<x>)', '[MASK]')
                           .replace('<x>', '[MASK]'))

        encoded = tokenizer(masked_sentence, max_length=max_length, padding='max_length',
                            truncation=True, return_tensors="tf")
        ids = encoded["input_ids"].numpy()[0]
        attention = encoded["attention_mask"].numpy()[0]

        is_mask = ids == mask_token_id
        keep = ~is_mask

        # Mark the token preceding each [MASK] as the insertion point. With the
        # tokenizer's default special tokens position 0 is [CLS], so the index
        # never wraps around.
        labels = np.zeros(max_length, dtype=int)
        labels[np.flatnonzero(is_mask) - 1] = 1

        input_ids.append(_drop_and_pad(ids, keep))
        attention_masks.append(_drop_and_pad(attention, keep))
        target_tags.append(_drop_and_pad(labels, keep))

    if not input_ids:
        # tf.concat cannot handle an empty list; return empty batches instead.
        empty = np.zeros((0, max_length), dtype=np.int32)
        return (tf.constant(empty), tf.constant(empty),
                tf.constant(np.zeros((0, max_length), dtype=int)))

    return (tf.constant(np.stack(input_ids)),
            tf.constant(np.stack(attention_masks)),
            tf.constant(np.stack(target_tags)))
    

## Model Implementation

In [8]:
def build_model(fine_tune_bert=False, learning_rate=2e-5, dropout_rate=0.5,
                hidden_units=(512, 256, 128)):
    """Build and compile a per-token binary tagger on top of bert-base-uncased.

    Architecture: BERT's last hidden state -> one Dense(relu) + Dropout pair
    per entry of ``hidden_units`` -> Dense(1, sigmoid), i.e. one probability
    per input token.

    Note: ``F1Score`` is defined in a later notebook cell; that cell must be
    executed before this function is called.

    Args:
        fine_tune_bert: If True, BERT's weights are trained together with the
            head; if False, BERT is frozen and only the dense head is trained.
        learning_rate: Adam learning rate.
        dropout_rate: Dropout rate applied after every hidden Dense layer.
        hidden_units: Widths of the hidden Dense layers, in order.

    Returns:
        tf.keras.Model: takes [input_ids, attention_mask] and is compiled with
        binary cross-entropy plus F1 / precision / recall metrics.
    """
    bert_model = TFBertModel.from_pretrained("bert-base-uncased")
    # Setting trainable on the model propagates to all of its sub-layers.
    bert_model.trainable = bool(fine_tune_bert)

    input_ids = Input(shape=(None,), dtype=tf.int32, name="input_ids")
    attention_mask = Input(shape=(None,), dtype=tf.int32, name="attention_mask")

    # Element 0 of the BERT output is the per-token last hidden state.
    hidden = bert_model([input_ids, attention_mask])[0]

    for units in hidden_units:
        hidden = Dense(units, activation='relu')(hidden)
        hidden = Dropout(dropout_rate)(hidden)

    prediction = Dense(1, activation="sigmoid", name="prediction")(hidden)

    model = tf.keras.Model(inputs=[input_ids, attention_mask], outputs=prediction)

    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    loss = tf.keras.losses.BinaryCrossentropy()
    metrics = [F1Score(name='F1Score'), Precision(name='precision'), Recall(name='recall')]

    # Compile model with binary cross-entropy loss
    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

    return model


In [9]:
def train_model(model, train_preprocessed_data, val_preprocessed_data, epochs=5, resume_training=False,
                save_checkpoint=False, early_stopping=False, patience=2):
    """Fit ``model`` on preprocessed (input_ids, attention_masks, target_tags) triples.

    Args:
        model: Compiled Keras model taking [input_ids, attention_mask]
            (as returned by build_model).
        train_preprocessed_data: (input_ids, attention_masks, target_tags) for training.
        val_preprocessed_data: The same triple for validation.
        epochs: Maximum number of training epochs.
        resume_training: If True, load the latest checkpoint from the global
            ``checkpoint_dir`` before training (if one exists).
        save_checkpoint: If True, save weights to ``checkpoint_dir/ckpt``
            whenever the checkpoint callback sees an improvement.
        early_stopping: If True, stop once validation F1 has not improved for
            ``patience`` epochs.
        patience: Epochs without validation-F1 improvement tolerated by early stopping.

    Returns:
        The History object returned by ``model.fit``.
    """
    train_input_ids, train_attention_masks, train_target_tags = train_preprocessed_data
    val_input_ids, val_attention_masks, val_target_tags = val_preprocessed_data

    # Seeds TF ops run from here on (e.g. dropout); the model's weights were
    # already initialised when it was built, so they are not affected.
    tf.random.set_seed(12)

    callbacks = []
    if early_stopping:
        # Monitor the *validation* F1 and maximise it. Keras' default mode='auto'
        # minimises any monitor whose name does not end in acc/accuracy/auc, so
        # the previous monitor="F1Score" treated a falling training F1 as an
        # improvement. Assumes a metric named 'F1Score', as build_model compiles.
        callbacks.append(tf.keras.callbacks.EarlyStopping(
            monitor="val_F1Score", mode="max", patience=patience))

    # Load weights from the latest checkpoint if resume_training is True
    if resume_training:
        latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
        if latest_checkpoint:
            model.load_weights(latest_checkpoint)
            print(f"Loaded weights from checkpoint: {latest_checkpoint}")
        else:
            print("No checkpoint found for resuming training.")

    # Create checkpoint callback if saving is enabled
    if save_checkpoint:
        checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
            filepath=os.path.join(checkpoint_dir, "ckpt"),
            save_weights_only=True,
            save_best_only=True,
            verbose=1
        )
        callbacks.append(checkpoint_callback)

    # Train the model with callbacks
    history = model.fit(
        x=[train_input_ids, train_attention_masks],
        y=train_target_tags,
        epochs=epochs,
        validation_data=([val_input_ids, val_attention_masks], val_target_tags),
        callbacks=callbacks,
        shuffle=True
    )

    return history


In [10]:
# Define F1 score metric

class F1Score(tf.keras.metrics.Metric):
    """Binary F1 score: the harmonic mean of Keras Precision and Recall.

    Args:
        name: Metric name reported in training logs.
        thresholds: Optional decision threshold(s) forwarded to the wrapped
            Precision and Recall metrics; None keeps their default.
        **kwargs: Forwarded to tf.keras.metrics.Metric (e.g. dtype).
    """

    def __init__(self, name='f1_score', thresholds=None, **kwargs):
        super().__init__(name=name, **kwargs)
        self.thresholds = thresholds
        self.precision = Precision(thresholds=thresholds)
        self.recall = Recall(thresholds=thresholds)

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Forward sample_weight so F1 weights examples exactly like the
        # standalone Precision/Recall metrics (it used to be silently dropped).
        self.precision.update_state(y_true, y_pred, sample_weight=sample_weight)
        self.recall.update_state(y_true, y_pred, sample_weight=sample_weight)

    def result(self):
        precision_result = self.precision.result()
        recall_result = self.recall.result()
        # epsilon avoids 0/0 when both precision and recall are zero.
        return 2 * ((precision_result * recall_result)
                    / (precision_result + recall_result + tf.keras.backend.epsilon()))

    def reset_state(self):
        self.precision.reset_state()
        self.recall.reset_state()

    def get_config(self):
        # Lets Keras re-create the metric when a saved model is reloaded with
        # custom_objects={'F1Score': F1Score}.
        config = super().get_config()
        config.update({'thresholds': self.thresholds})
        return config

## Data Loading and Model Training

In [11]:
# Read every JSON Lines file listed in training_data_path into one list of records.
data = load_data(file_path=training_data_path)

In [12]:
# Split records into filtered sentences, then hold out 20% for validation
# (fixed random_state so the split is reproducible across runs).
cleaned_data = clean_data(data)
train_data, val_data = train_test_split(
    cleaned_data,
    test_size=0.2,
    random_state=42,
)

In [13]:
def _load_or_preprocess(sentences, cache_name):
    """Return preprocessed tensors for one split, computing and caching them on a miss.

    The cache is keyed only by file name inside ``output_dir``: delete the
    .pkl files after changing the data, clean_data or MAX_LENGTH, otherwise
    stale tensors are silently reused.

    Args:
        sentences: Sentences to pass to preprocess_data on a cache miss.
        cache_name: File name of the pickle cache inside ``output_dir``.

    Returns:
        Whatever preprocess_data returns (or the previously pickled value).
    """
    cache_path = os.path.join(output_dir, cache_name)
    if os.path.isfile(cache_path):
        # NOTE: pickle.load can execute arbitrary code; only load caches this
        # notebook wrote itself.
        with open(cache_path, 'rb') as f:
            return pickle.load(f)

    preprocessed = preprocess_data(sentences)
    # On a fresh checkout output_dir may not exist yet; create it so the write succeeds.
    os.makedirs(output_dir, exist_ok=True)
    with open(cache_path, 'wb') as f:
        pickle.dump(preprocessed, f)
    return preprocessed


train_preprocessed_data = _load_or_preprocess(train_data, 'train_tokenized_data.pkl')
val_preprocessed_data = _load_or_preprocess(val_data, 'val_tokenized_data.pkl')

In [14]:
# Build the tagger with BERT left trainable (fine-tuned end to end) and
# print the layer / parameter summary.
model = build_model(fine_tune_bert=True)
model.summary()

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing TFBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFBertModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions w

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_ids (InputLayer)         [(None, None)]       0           []                               
                                                                                                  
 attention_mask (InputLayer)    [(None, None)]       0           []                               
                                                                                                  
 tf_bert_model (TFBertModel)    TFBaseModelOutputWi  109482240   ['input_ids[0][0]',              
                                thPoolingAndCrossAt               'attention_mask[0][0]']         
                                tentions(last_hidde                                               
                                n_state=(None, None                                           

In [17]:
# Resume from the latest checkpoint in checkpoint_dir (if any), keep saving the
# best weights, and enable early stopping. NOTE(review): with only 2 epochs and
# train_model's patience of 2, early stopping cannot trigger here.
history = train_model(
    model,
    train_preprocessed_data,
    val_preprocessed_data,
    epochs=2,
    resume_training=True,
    save_checkpoint=True,
    early_stopping=True,
)

Loaded weights from checkpoint: logs\ckpt
Epoch 1/2
Epoch 1: val_loss improved from inf to 0.01479, saving model to logs\ckpt
Epoch 2/2
Epoch 2: val_loss improved from 0.01479 to 0.01449, saving model to logs\ckpt


In [18]:
# Persist the trained model twice: as a TensorFlow SavedModel directory and as
# a single HDF5 file that includes the optimizer state.
# NOTE(review): custom classes (F1Score, TFBertModel) are not bundled; reloading
# will presumably need them passed via `custom_objects`.

model.save(os.path.join(output_dir, 'model/token_insertion_model'))

tf.keras.models.save_model(
    model,
    filepath=os.path.join(output_dir, 'model/token_insertion_model.h5'),
    include_optimizer=True,
    save_format='h5',
)



INFO:tensorflow:Assets written to: output/model/token_insertion_model\assets


INFO:tensorflow:Assets written to: output/model/token_insertion_model\assets
