In [10]:
def setup_gpu_strategy():
    try:
        # Configure GPU memory growth
        physical_devices = tf.config.list_physical_devices('GPU')
        if physical_devices:
            for device in physical_devices:
                try:
                    tf.config.experimental.set_memory_growth(device, True)
                except RuntimeError as e:
                    print(f"Could not set memory growth for {device.name}: {e}")
            print(f"Found {len(physical_devices)} GPU(s). GPU configuration successful.")

            # Create and return GPU strategy
            strategy = tf.distribute.MirroredStrategy()
            print(f"Number of devices: {strategy.num_replicas_in_sync}")
            return strategy
        else:
            print("No GPUs found. Falling back to CPU strategy.")
            return tf.distribute.OneDeviceStrategy(device="/cpu:0")
    except RuntimeError as e:
        print(f"GPU configuration failed: {e}")
        print("Falling back to default strategy.")
        return tf.distribute.OneDeviceStrategy(device="/cpu:0")

In [11]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.regularizers import l2
from tensorflow.keras.layers import *
from datetime import datetime
import os
from transformers import AutoConfig, TFBertModel
import logging

In [12]:
# Configuration Constants
FILE_NAME = '/content/chatbot-med-df.tfrecord'
MAX_LENGTH = 256
BATCH_SIZE = 64
SUBSET_SIZE = 10000
VAL_SPLIT = 0.1
TEST_SPLIT = 0.1
NUM_EPOCHS = 5
BERT_MODEL_NAME = 'dmis-lab/biobert-base-cased-v1.1'
MODEL_SAVE_PATH = '/saved_model/model.keras'

In [13]:
def load_large_dataset(filename, batch_size, subset_size=None, val_split=0.1, test_split=0.1):
    # Feature description for parsing the TFRecord
    feature_description = {
        'query_input_ids': tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
        'query_attention_mask': tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
        'response_input_ids': tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
        'response_attention_mask': tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
    }

    def parse_example(example):
      parsed = tf.io.parse_single_example(example, feature_description)

      # Define inputs and targets
      inputs = {
          'query_input_ids': tf.cast(parsed['query_input_ids'], tf.int32),
          'query_attention_mask': tf.cast(parsed['query_attention_mask'], tf.int32),
      }

      targets = tf.cast(parsed['response_input_ids'], tf.int32)
      response_attention_mask = tf.cast(parsed['response_attention_mask'], tf.int32)

      # Apply response_attention_mask to mask out the padding tokens in the targets
      targets = targets * response_attention_mask

      return inputs, targets

    # Load the TFRecord dataset
    raw_dataset = tf.data.TFRecordDataset([filename])
    parsed_dataset = raw_dataset.map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)

    # Determine dataset size
    total_size = sum(1 for _ in parsed_dataset)
    print(f"Original dataset size: {total_size}")

    # Subset handling
    if subset_size and subset_size < total_size:
        total_size = subset_size
        parsed_dataset = parsed_dataset.take(subset_size)
        print(f"Taking subset of size: {subset_size}")

    # Calculate split sizes
    train_size = int(total_size * (1 - val_split - test_split))
    val_size = int(total_size * val_split)
    test_size = total_size - train_size - val_size

    # Splits
    train_dataset = parsed_dataset.take(train_size).shuffle(buffer_size=10000).repeat()
    train_dataset = train_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    remaining_dataset = parsed_dataset.skip(train_size)

    val_dataset = remaining_dataset.take(val_size).repeat()
    val_dataset = val_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    test_dataset = remaining_dataset.skip(val_size).take(test_size).repeat()
    test_dataset = test_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    print(f"Dataset successfully loaded:")
    print(f"  - Total size: {total_size}")
    print(f"  - Train size: {train_size}")
    print(f"  - Validation size: {val_size}")
    print(f"  - Test size: {test_size}")

    return train_dataset, val_dataset, test_dataset, train_size, val_size, test_size

In [14]:
@tf.keras.utils.register_keras_serializable(package="Custom")
class BioBertEncoder(tf.keras.layers.Layer):
    def __init__(self, bert_model_name, trainable=False, **kwargs):
        super().__init__(**kwargs)
        self.bert_model_name = bert_model_name
        self.trainable = trainable
        self.bert_model = None

    def build(self, input_shape):
        # Initialize the BERT model
        self.bert_config = AutoConfig.from_pretrained(self.bert_model_name)
        self.bert_model = TFBertModel.from_pretrained(
            self.bert_model_name, config=self.bert_config, from_pt=True
        )
        self.bert_model.trainable = self.trainable
        super().build(input_shape)

    def call(self, inputs, training=False):
        input_ids, attention_mask = inputs
        outputs = self.bert_model(input_ids=input_ids, attention_mask=attention_mask, training=training)
        return outputs.last_hidden_state

    def get_config(self):
        config = super().get_config()
        config.update({"bert_model_name": self.bert_model_name, "trainable": self.trainable})
        return config

In [15]:
class BioBertCnnBiLSTM:
    def __init__(self, bert_model_name, vocab_size):
        self.bert_model_name = bert_model_name
        self.vocab_size = vocab_size

    def build(self, config):
        # Define inputs
        query_input_ids = tf.keras.Input(shape=(MAX_LENGTH,), dtype=tf.int32, name='query_input_ids')
        query_attention_mask = tf.keras.Input(shape=(MAX_LENGTH,), dtype=tf.int32, name='query_attention_mask')

        # Encoder
        bert_layer = BioBertEncoder(self.bert_model_name, trainable=False)
        query_bert_output = bert_layer([query_input_ids, query_attention_mask])

        # CNN Layer
        kernel_sizes = config['kernels']
        query_cnn = []
        for kernel_size in kernel_sizes:
            conv = tf.keras.layers.Conv1D(
                filters=config['cnn_filters'],
                kernel_size=kernel_size,
                padding='same',
                activation=config['activation'],
                kernel_regularizer=l2(config['cnn_regularization'])
            )(query_bert_output)
            norm = tf.keras.layers.LayerNormalization()(conv)
            bn = tf.keras.layers.BatchNormalization()(norm)
            dropout = tf.keras.layers.Dropout(config['dropout_cnn'])(bn)
            query_cnn.append(dropout)
        query_cnn = tf.keras.layers.Concatenate()(query_cnn)

        # LSTM Layer with Bidirectional and Residual Connection
        query_lstm = tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(config['lstm_units'], dropout=config['dropout_lstm'], return_sequences=True)
        )(query_cnn)
        query_lstm = tf.keras.layers.LayerNormalization()(query_lstm)

        # Attention Mechanism
        attention = tf.keras.layers.Attention()([query_lstm, query_lstm])
        query_lstm = tf.keras.layers.Concatenate()([query_lstm, attention])

        output = tf.keras.layers.TimeDistributed(
            tf.keras.layers.Dense(
                self.vocab_size,
                activation='softmax'
            )
        )(query_lstm)

        # Model Creation
        model = tf.keras.Model(
            inputs={'query_input_ids': query_input_ids, 'query_attention_mask': query_attention_mask},
            outputs=output
        )

        # Optimizer with learning rate schedule
        lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=config['learning_rate'],
            decay_steps=100000,
            decay_rate=0.96,
            staircase=True
        )
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule, clipvalue=1.0)

        model.compile(
            optimizer=optimizer,
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='sum_over_batch_size'),
            metrics=['accuracy', tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5)]
        )

        return model

In [16]:
def save_model_artifacts(model, base_path):
    os.makedirs(base_path, exist_ok=True)
    model_path = os.path.join(base_path, 'model.keras')
    model.save(model_path, save_format='keras')
    print(f"Model saved to {model_path}")

In [17]:
def train_chatbot(file_path):
    config = {
        'cnn_filters': 64,
        'kernels': [3, 5, 7],
        'dropout_cnn': 0.3,
        'cnn_regularization': 0.02,
        'lstm_units': 256,
        'dropout_lstm': 0.5,
        'lstm_recurrent_regularization': 5e-5,
        'learning_rate': 5e-5,
        'activation': 'relu'
    }

    # Load and split dataset
    print("Loading datasets...")
    train_dataset, val_dataset, test_dataset, train_size, val_size, test_size = load_large_dataset(
        filename=file_path,
        batch_size=BATCH_SIZE,
        subset_size=SUBSET_SIZE,
        val_split=VAL_SPLIT,
        test_split=TEST_SPLIT
    )

    # Calculate steps per epoch
    train_size = (train_size // BATCH_SIZE) * BATCH_SIZE
    val_size = (val_size // BATCH_SIZE) * BATCH_SIZE

    train_steps = train_size // BATCH_SIZE
    val_steps = val_size // BATCH_SIZE

    # Initialize and build model
    model_builder = BioBertCnnBiLSTM(BERT_MODEL_NAME, vocab_size=28996)
    model = model_builder.build(config)
    model.summary()

    # Check if model weights already exist for resumption
    checkpoint_path = 'checkpoint/model_checkpoint.keras'
    if os.path.exists(checkpoint_path):
        print(f"Resuming from saved checkpoint: {checkpoint_path}")
        model.load_weights(checkpoint_path)
    else:
        print("No checkpoint found, starting training from scratch.")

    # Create callbacks
    callbacks = [
        tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=2, restore_best_weights=True),
        tf.keras.callbacks.ModelCheckpoint(checkpoint_path, save_best_only=True, save_weights_only=False),
    ]

    # Train model
    history = model.fit(
            train_dataset,
            validation_data=val_dataset,
            epochs=NUM_EPOCHS,
            callbacks=callbacks,
            verbose=1,
            steps_per_epoch=train_steps,
            validation_steps=val_steps
        )


    # Evaluate on test set
    test_results = model.evaluate(test_dataset)
    print(f"Test results: Loss = {test_results[0]:.4f}, Accuracy = {test_results[1]:.4f}")

    print("\nSaving model artifacts...")
    save_model_artifacts(model, MODEL_SAVE_PATH)

    return model, test_results

In [18]:
 if __name__ == "__main__":
    logging.getLogger("tensorflow").setLevel(logging.ERROR)
    logging.getLogger("transformers").setLevel(logging.CRITICAL)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    file_path = FILE_NAME
    model, test_results = train_chatbot(file_path)

Loading datasets...
Original dataset size: 262886
Taking subset of size: 30000
Dataset successfully loaded:
  - Total size: 30000
  - Train size: 24000
  - Validation size: 3000
  - Test size: 3000


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/313 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

No checkpoint found, starting training from scratch.
Epoch 1/3
[1m750/750[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1249s[0m 2s/step - accuracy: 0.4438 - loss: 10.8119 - sparse_top_k_categorical_accuracy: 0.4845 - val_accuracy: 0.5329 - val_loss: 6.0196 - val_sparse_top_k_categorical_accuracy: 0.5923
Epoch 2/3
[1m750/750[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1219s[0m 2s/step - accuracy: 0.5321 - loss: 5.5447 - sparse_top_k_categorical_accuracy: 0.5914 - val_accuracy: 0.5347 - val_loss: 4.4789 - val_sparse_top_k_categorical_accuracy: 0.5944
Epoch 3/3
[1m750/750[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1220s[0m 2s/step - accuracy: 0.5330 - loss: 4.2936 - sparse_top_k_categorical_accuracy: 0.5928 - val_accuracy: 0.5355 - val_loss: 3.8723 - val_sparse_top_k_categorical_accuracy: 0.5960
[1m94/94[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m83s[0m 874ms/step - accuracy: 0.5395 - loss: 3.8413 - sparse_top_k_categorical_accuracy: 0.5989


  self.gen.throw(typ, value, traceback)


Test results: Loss = 3.8588, Accuracy = 0.5364

Saving model artifacts...
Model saved to /saved_model/model.keras
