In [None]:
import json
import numpy as np
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input, LSTM, Dense, Embedding, Attention, Concatenate
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split

# Location of the raw dataset; kept as a constant so the notebook can be pointed
# at another copy without touching the loading code.
DATASET_PATH = '/content/datalabs-project/dataset.json'

with open(DATASET_PATH, 'r', encoding='utf-8') as json_file:
    dataset = json.load(json_file)

# Build parallel lists of (flattened conversation, tutor response) training pairs.
conversations = []
responses = []

for data in dataset["prepDataset"].values():
    # past_convo is joined with spaces, so it is presumably a list of utterance strings.
    flat_conversation = " ".join(data["past_convo"])

    # tutorResponses is treated as a list of candidate responses (the lookup code
    # indexes it with [0]). Passing the list itself to the Keras Tokenizer would make
    # each whole response a single token, so train on the first response instead.
    tutor_responses = data["tutorResponses"]
    if isinstance(tutor_responses, str):
        response = tutor_responses
    elif tutor_responses:
        response = tutor_responses[0]
    else:
        # No reference response to learn from; skip so both lists stay aligned.
        continue

    conversations.append(flat_conversation)
    responses.append(response)

# One shared vocabulary for both sides, so encoder and decoder agree on word ids.
tokenizer = Tokenizer()
tokenizer.fit_on_texts(conversations + responses)

# Text -> lists of word ids.
input_sequences = tokenizer.texts_to_sequences(conversations)
target_sequences = tokenizer.texts_to_sequences(responses)

# Common length = longest sequence on either side; shorter ones are right-padded with 0.
max_sequence_length = max(len(seq) for seq in input_sequences + target_sequences)
input_sequences = np.array(
    pad_sequences(input_sequences, maxlen=max_sequence_length, padding='post')
)
target_sequences = np.array(
    pad_sequences(target_sequences, maxlen=max_sequence_length, padding='post')
)

# Hold out 20% of the pairs for validation (there is no separate test split).
X_train, X_val, y_train, y_val = train_test_split(
    input_sequences, target_sequences, test_size=0.2, random_state=42
)

# Model hyper-parameters.
vocab_size = len(tokenizer.word_index) + 1  # +1: the Tokenizer never assigns id 0, which is used for padding
embedding_dim = 256
hidden_units = 512
attention_units = 64  # NOTE(review): not used by the architecture below

# Encoder. Sequences are right-padded with 0, so mask_zero=True lets the LSTM skip
# the padding: the final (state_h, state_c) handed to the decoder then come from the
# last real token instead of from a run of padding steps, and the mask carried on
# encoder_outputs keeps attention off padded encoder positions.
# NOTE(review): a conversation that tokenizes to nothing becomes a fully masked row;
# some cuDNN builds reject that — filter such rows if it occurs.
encoder_input = Input(shape=(max_sequence_length,))
encoder_embedding = Embedding(vocab_size, embedding_dim, mask_zero=True)(encoder_input)
encoder_lstm = LSTM(hidden_units, return_sequences=True, return_state=True)
encoder_outputs, state_h, state_c = encoder_lstm(encoder_embedding)
encoder_states = [state_h, state_c]

# Decoder, initialised with the encoder's final states. Its embedding is left
# unmasked on purpose: masking would also drop every step whose input id is 0 from
# the LSTM and from the loss.
decoder_input = Input(shape=(max_sequence_length,))
decoder_embedding = Embedding(vocab_size, embedding_dim)(decoder_input)
decoder_lstm = LSTM(hidden_units, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(decoder_embedding, initial_state=encoder_states)

# Attention takes [query, value, key]: decoder steps query the encoder outputs,
# scored against a tanh projection of those same outputs.
key_dense = Dense(hidden_units, activation='tanh')(encoder_outputs)
attention_layer = Attention(use_scale=True)
attention_outputs = attention_layer([decoder_outputs, encoder_outputs, key_dense])

# Per-timestep features = decoder state + attended encoder context.
decoder_combined_context = Concatenate(axis=-1)([decoder_outputs, attention_outputs])

# Per-timestep softmax over the vocabulary.
decoder_dense = Dense(vocab_size, activation='softmax')
decoder_outputs = decoder_dense(decoder_combined_context)

model = Model(inputs=[encoder_input, decoder_input], outputs=decoder_outputs)

# Targets are integer word ids, hence the sparse loss/metric.
model.compile(optimizer=Adam(), loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])

model.summary()

# Teacher forcing: at step t the decoder sees the target token from step t-1 and is
# trained to predict the target token at step t. Decoder inputs are therefore the
# target sequences shifted right by one position, with id 0 (never assigned to a
# word by the Tokenizer) in the first column as the start symbol.
def make_decoder_inputs(target_batch):
    """Shift a (batch, time) id array one step right along axis 1, filling column 0 with 0."""
    return np.pad(target_batch[:, :-1], ((0, 0), (1, 0)), constant_values=0)


decoder_train_inputs = make_decoder_inputs(y_train)
decoder_val_inputs = make_decoder_inputs(y_val)

EPOCHS = 10
BATCH_SIZE = 32

model.fit(
    [X_train, decoder_train_inputs],
    y_train,
    validation_data=([X_val, decoder_val_inputs], y_val),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
)

# Persist the trained weights and architecture.
model.save('trained_model.h5')

# Reload the saved model. NOTE(review): get_model_response does not use trained_model.
trained_model = load_model('trained_model.h5')

def get_model_response(user_input, conversation_context, records=None, text_tokenizer=None):
    """Return a stored tutor response for the conversation so far.

    NOTE: this does not run the trained seq2seq model. It looks for a dataset entry
    whose ``past_convo`` tokenizes to the same word ids as all user turns seen so far
    (in order), and returns that entry's first ``tutorResponses`` item.

    Args:
        user_input: The latest user utterance.
        conversation_context: List of previous user turns, each a list of token ids,
            as returned by an earlier call (start with ``[]``). It is not modified.
        records: Mapping whose values are dicts with ``"past_convo"`` (list of strings)
            and ``"tutorResponses"`` (list of responses) keys. Defaults to
            ``dataset["prepDataset"]``.
        text_tokenizer: Object with a Keras-style ``texts_to_sequences`` method.
            Defaults to the notebook's fitted ``tokenizer``.

    Returns:
        ``(model_response, updated_context)``, where ``updated_context`` is a new list
        with the tokenized ``user_input`` appended.
    """
    if records is None:
        records = dataset["prepDataset"]
    if text_tokenizer is None:
        text_tokenizer = tokenizer

    # Build a new list so the caller's context is never mutated in place.
    user_tokens = list(text_tokenizer.texts_to_sequences([user_input])[0])
    updated_context = list(conversation_context) + [user_tokens]

    # Flatten all turns into one id sequence. Dropping 0s makes contexts that hold
    # padded sequences compare the same as unpadded ones.
    context_tokens = [token for turn in updated_context for token in turn if token != 0]

    if context_tokens:
        entries = list(records.values())
        # Join past_convo with spaces (the same way the training pairs are built) and
        # tokenize it, so both sides go through identical normalization.
        convo_token_lists = text_tokenizer.texts_to_sequences(
            [" ".join(entry["past_convo"]) for entry in entries]
        )
        for entry, convo_tokens in zip(entries, convo_token_lists):
            candidate_responses = entry["tutorResponses"]
            if list(convo_tokens) == context_tokens and candidate_responses:
                return candidate_responses[0], updated_context

    return "I'm sorry, I don't understand your question.", updated_context





Model: "model_8"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_29 (InputLayer)       [(None, 206)]                0         []                            
                                                                                                  
 input_30 (InputLayer)       [(None, 206)]                0         []                            
                                                                                                  
 embedding_28 (Embedding)    (None, 206, 256)             882432    ['input_29[0][0]']            
                                                                                                  
 embedding_29 (Embedding)    (None, 206, 256)             882432    ['input_30[0][0]']            
                                                                                            

In [None]:
# Interactive loop: type a question, or "exit" (case- and whitespace-insensitive) to stop.
conversation_context = []  # threaded through get_model_response across turns

while True:
    try:
        user_input = input("Ask a question: ")
    except EOFError:
        # stdin closed (e.g. a non-interactive run): stop instead of raising.
        break
    if user_input.strip().lower() == 'exit':
        break

    model_response, conversation_context = get_model_response(user_input, conversation_context)

    print("User Question:", user_input)
    print("Model Response:", model_response)
    print("=" * 50)



User Question: tree
Model Response: I'm sorry, I don't understand your question.
User Question: pink in italian
Model Response: I'm sorry, I don't understand your question.
