In [1]:
import os

In [2]:
%pwd

'c:\\Users\\ankit.rohilla\\Documents\\fake_news_classification\\research'

In [3]:
# Move from research/ to the project root so relative paths (config/, artifacts/) resolve.
# Guarded so re-running this cell does not keep walking up the directory tree.
if os.path.basename(os.getcwd()) == "research":
    os.chdir("../")

In [4]:
%pwd

'c:\\Users\\ankit.rohilla\\Documents\\fake_news_classification'

In [5]:
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class ModelTrainerConfig:
    """Immutable settings for the model-training stage.

    Instances are built by ``ConfigurationManager.get_model_trainer_config``
    from the project's YAML config (paths) and params (training arguments).
    """

    # Output directory for this stage's artifacts (trained model, history).
    root_dir: Path
    # Directory holding the pickled training encodings (train_encodings.pkl).
    data_path: Path
    # Pre-trained checkpoint name/path handed to TFBertModel.from_pretrained.
    model_ckpt: Path
    # Sequence length of each of the two model inputs.
    maxlen: int
    # Number of training epochs passed to model.fit.
    n_epochs: int
    # Metric(s) forwarded to model.compile(metrics=...).
    metrics: str

In [6]:
from text_classifier.constants import *
from text_classifier.utils.common import read_yaml, create_directories

In [7]:
class ConfigurationManager:
    """Load the project's YAML config and params files and build per-stage configs."""

    def __init__(
        self,
        config_filepath = CONFIG_FILE_PATH,
        params_filepath = PARAMS_FILE_PATH):
        """Read both YAML files and make sure the artifacts root directory exists.

        Args:
            config_filepath: Path of the paths/config YAML file.
            params_filepath: Path of the hyper-parameter YAML file.
        """
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

        # Every stage writes beneath artifacts_root, so create it up front.
        create_directories([self.config.artifacts_root])

    def get_model_trainer_config(self) -> ModelTrainerConfig:
        """Assemble the training-stage config and create its output directory.

        Paths come from the ``model_trainer`` section of the config file and
        hyper-parameters from the ``TrainingArguments`` section of the params file.
        """
        trainer_section = self.config.model_trainer
        training_args = self.params.TrainingArguments

        create_directories([trainer_section.root_dir])

        return ModelTrainerConfig(
            root_dir=trainer_section.root_dir,
            data_path=trainer_section.data_path,
            model_ckpt=trainer_section.model_ckpt,
            maxlen=training_args.maxlen,
            n_epochs=training_args.n_epochs,
            metrics=training_args.metrics,
        )

In [8]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Dropout,Input
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.python.ops.numpy_ops import np_utils
from transformers import BertModel, TFBertModel 
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import regularizers
import torch
import pickle

  from .autonotebook import tqdm as notebook_tqdm


In [9]:
class ModelTrainer:
    """Fine-tune a pre-trained TF BERT encoder with a small binary-classification head.

    Paths and hyper-parameters come from a ``ModelTrainerConfig``. Training data is
    read from ``<data_path>/train_encodings.pkl``; the trained model and its
    training history are written under ``<root_dir>``.
    """

    # Explicit input-layer names. Keras auto-generates ``input_1``/``input_2`` only
    # for the first model built in a process; re-running the cell in the same
    # kernel would yield ``input_3``/``input_4`` and break the name-keyed dict
    # passed to ``fit``. Pinning them keeps the first-run names stable.
    INPUT_IDS_NAME = "input_1"
    ATTENTION_MASK_NAME = "input_2"

    # Training hyper-parameters (previously inline magic numbers).
    # NOTE(review): 0.01 is far above the learning rates usually used to fine-tune
    # BERT with Adam; consider moving these into params.yaml and lowering it.
    INITIAL_LEARNING_RATE = 0.01
    LR_DECAY_STEPS = 10000
    LR_DECAY_RATE = 0.9
    BATCH_SIZE = 30
    VALIDATION_SPLIT = 0.2

    def __init__(self, config: ModelTrainerConfig):
        """Store the trainer configuration.

        Args:
            config: Paths and hyper-parameters for this training stage.
        """
        self.config = config

    def create_model(self, bert_model):
        """Build BERT pooled output -> Dropout -> Dense(tanh) -> Dropout -> Dense(sigmoid).

        Args:
            bert_model: A ``TFBertModel`` (or compatible) callable on
                ``[input_ids, attention_mask]``.

        Returns:
            An uncompiled ``tf.keras.Model`` with two int32 inputs of length
            ``config.maxlen`` and one sigmoid output per example.
        """
        input_ids = Input(shape=(self.config.maxlen,), dtype=tf.int32, name=self.INPUT_IDS_NAME)
        input_mask = Input(shape=(self.config.maxlen,), dtype=tf.int32, name=self.ATTENTION_MASK_NAME)
        # Index 1 of the HF BERT output is the pooled ([CLS]) representation.
        pooled_output = bert_model([input_ids, input_mask])[1]
        x = Dropout(0.5)(pooled_output)
        x = Dense(64, activation="tanh")(x)
        x = Dropout(0.2)(x)
        x = Dense(1, activation="sigmoid")(x)
        return Model(inputs=[input_ids, input_mask], outputs=x)

    def train(self):
        """Build, compile and fit the model, then persist the model and its history.

        Side effects:
            Writes ``<root_dir>/bert-news-classify-model`` (Keras ``Model.save``)
            and ``<root_dir>/history`` (pickled ``History.history`` dict).
        """
        bert_model = TFBertModel.from_pretrained(self.config.model_ckpt)
        with tf.device(self._select_device()):
            model = self.create_model(bert_model)
        # summary() prints by itself and returns None; print() around it emitted "None".
        model.summary()

        model.compile(
            optimizer=self._build_optimizer(),
            loss="binary_crossentropy",
            metrics=self.config.metrics,
        )

        X_train, y_train = self._load_training_data()

        # NOTE(review): validation_split takes the *last* 20% of rows without
        # shuffling; confirm the pickled data is not ordered by label.
        history = model.fit(
            x={
                self.INPUT_IDS_NAME: X_train["input_ids"],
                self.ATTENTION_MASK_NAME: X_train["attention_mask"],
            },
            y=y_train,
            epochs=self.config.n_epochs,
            validation_split=self.VALIDATION_SPLIT,
            batch_size=self.BATCH_SIZE,
            callbacks=[self._build_early_stopping()],
        )

        self._save_artifacts(model, history)

    @staticmethod
    def _select_device():
        """Return a TensorFlow device string: the first visible GPU, otherwise the CPU."""
        # tf.device does not accept torch-style "cuda" (it raises ValueError);
        # TF device specs look like "/GPU:0" / "/CPU:0".
        return "/GPU:0" if tf.config.list_physical_devices("GPU") else "/CPU:0"

    def _build_optimizer(self):
        """Adam with an exponentially decaying learning rate and gradient-norm clipping."""
        lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=self.INITIAL_LEARNING_RATE,
            decay_steps=self.LR_DECAY_STEPS,
            decay_rate=self.LR_DECAY_RATE,
        )
        return Adam(learning_rate=lr_schedule, epsilon=1e-08, clipnorm=1.0)

    @staticmethod
    def _build_early_stopping():
        """Early-stopping callback on validation loss."""
        # Loss is minimised, so mode must be "min"; the previous mode="max"
        # treated a rising val_loss as an improvement.
        return EarlyStopping(
            monitor="val_loss",
            mode="min",
            verbose=1,
            patience=50,
            baseline=0.4,
            min_delta=0.0001,
            restore_best_weights=False,
        )

    def _load_training_data(self):
        """Load training encodings and labels from ``<data_path>/train_encodings.pkl``.

        Returns:
            ``(X_train, y_train)`` from the pickle's ``'X_train'`` / ``'y_train'``
            keys; ``X_train`` is indexed by ``'input_ids'`` and ``'attention_mask'``.
        """
        # NOTE(review): pickle.load can execute arbitrary code; only point data_path
        # at artifacts produced by this pipeline.
        with open(os.path.join(self.config.data_path, "train_encodings.pkl"), "rb") as file:
            loaded_data = pickle.load(file)
        return loaded_data["X_train"], loaded_data["y_train"]

    def _save_artifacts(self, model, history):
        """Persist the trained model and its per-epoch metrics under ``config.root_dir``."""
        # A functional Keras Model has no save_pretrained (that is a Hugging Face
        # PreTrainedModel method), so the old call raised AttributeError after
        # training. In TF 2.x / Keras 2, Model.save on an extension-less path
        # writes a SavedModel directory.
        model.save(os.path.join(self.config.root_dir, "bert-news-classify-model"))

        # Pickle only the metrics dict: the History callback also holds a
        # reference to the whole model, which is large and not reliably picklable.
        with open(os.path.join(self.config.root_dir, "history"), "wb") as file:
            pickle.dump(history.history, file)




In [10]:
# Stage entry point: build the trainer config from YAML, then train.
# No try/except wrapper: `except Exception as e: raise e` only added a traceback frame.
config = ConfigurationManager()
model_trainer_config = config.get_model_trainer_config()
model_trainer = ModelTrainer(config=model_trainer_config)
model_trainer.train()

[2023-09-06 23:26:38,532: INFO: common: yaml file: config\config.yaml loaded successfully]
[2023-09-06 23:26:38,534: INFO: common: yaml file: params.yaml loaded successfully]
[2023-09-06 23:26:38,536: INFO: common: created directory at: artifacts]
[2023-09-06 23:26:38,537: INFO: common: created directory at: artifacts/model_trainer]


Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing TFBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFBertModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions w

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 120)]                0         []                            
                                                                                                  
 input_2 (InputLayer)        [(None, 120)]                0         []                            
                                                                                                  
 tf_bert_model (TFBertModel  TFBaseModelOutputWithPooli   1094822   ['input_1[0][0]',             
 )                           ngAndCrossAttentions(last_   40         'input_2[0][0]']             
                             hidden_state=(None, 120, 7                                           
                             68),                                                             

KeyboardInterrupt: 