In [1]:
from sentence_transformers import SentenceTransformer, util
import warnings
warnings.filterwarnings('ignore')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the pre-trained Sentence Transformer model.
# 'all-MiniLM-L6-v2' is the checkpoint identifier passed to SentenceTransformer;
# presumably it is downloaded/cached on first use — confirm for offline runs.
# The resulting `model` object is what later cells call `.encode(...)` on.
model = SentenceTransformer('all-MiniLM-L6-v2')

In [3]:
# Define the question-answer dataset.
# Each entry is a (question, answer) tuple of plain strings: element 0 is the
# question text and element 1 is the answer text. The order of the list matters
# to any code that looks answers up by position, so append new pairs rather than
# reordering if embeddings have already been computed from this list.
qa_pairs = [
    ("What is supervised learning?", "Supervised learning is a type of machine learning where a model is trained on labeled data."),
    ("What is unsupervised learning?", "Unsupervised learning is a type of machine learning where a model is trained on data without labels."),
    ("What is overfitting?", "Overfitting occurs when a model learns noise in the training data and performs poorly on new data."),
    ("What is underfitting?", "Underfitting happens when a model is too simple and fails to capture patterns in the training data."),
    ("What is cross-validation?", "Cross-validation is a technique used to assess the performance of a model by splitting the data into multiple training and testing sets."),
    ("What is regularization?", "Regularization is a technique used to prevent overfitting by adding a penalty to the model's complexity."),
    ("What is a decision tree?", "A decision tree is a supervised learning algorithm that splits data based on feature values to make predictions."),
    ("What is logistic regression?", "Logistic regression is a classification algorithm used to predict the probability of a binary outcome."),
    ("What is gradient descent?", "Gradient descent is an optimization algorithm used to minimize the loss function in machine learning models."),
    ("What is an activation function?", "An activation function in neural networks decides whether a neuron should be activated or not."),
    ("What is backpropagation?", "Backpropagation is an algorithm used to train neural networks by adjusting weights based on error gradients."),
    ("What is feature scaling?", "Feature scaling is a preprocessing technique used to normalize numerical values in a dataset."),
    ("What is a confusion matrix?", "A confusion matrix is a table used to evaluate the performance of a classification model."),
    ("What is precision in classification?", "Precision is the ratio of correctly predicted positive observations to the total predicted positive observations."),
    ("What is recall in classification?", "Recall measures the ability of a classifier to find all relevant instances in a dataset."),
    ("What is an ROC curve?", "An ROC curve is a graphical representation of a classifier’s performance across different threshold values."),
    ("What is PCA in machine learning?", "Principal Component Analysis (PCA) is a dimensionality reduction technique used to transform data into fewer dimensions."),
    ("What is reinforcement learning?", "Reinforcement learning is a type of machine learning where an agent learns to make decisions by receiving rewards."),
    ("What is transfer learning?", "Transfer learning is a machine learning technique where a pre-trained model is fine-tuned for a different but related task."),
    ("What is a support vector machine?", "A Support Vector Machine (SVM) is a supervised learning algorithm used for classification and regression tasks."),
]

In [4]:
# Collect the question text (first element) of every (question, answer) pair,
# keeping the same order as qa_pairs so positions line up with the answers.
questions = [pair[0] for pair in qa_pairs]

# Embed all questions in one call; convert_to_tensor=True asks the model to
# hand the embeddings back as a tensor rather than the default return type.
question_embeddings = model.encode(
    questions,
    convert_to_tensor=True,
)

In [5]:
# Finds the most relevant answer for a given query using cosine similarity.
def find_best_answer(query, min_similarity=None):
    """Return the stored answer whose question is most similar to ``query``.

    The query is embedded with the notebook-level ``model`` and compared against
    the precomputed ``question_embeddings`` using cosine similarity. The answer
    paired with the highest-scoring stored question is returned.

    Args:
        query: The user's question as a string.
        min_similarity: Optional cut-off for the best cosine similarity score.
            When ``None`` (the default) the closest answer is always returned,
            even if the match is weak. When set, ``None`` is returned instead
            of an answer if the best score is below this value, so callers can
            detect questions the dataset does not cover.

    Returns:
        The answer string from ``qa_pairs`` for the best-matching question, or
        ``None`` when ``min_similarity`` is given and no stored question
        reaches it.
    """
    query_embedding = model.encode(query, convert_to_tensor=True)

    # Compute cosine similarity between the query and every stored question
    similarities = util.cos_sim(query_embedding, question_embeddings)

    # Index of the most similar question; with a single query the flattened
    # argmax position is the position of the question in qa_pairs.
    best_match_index = similarities.argmax().item()

    # Without a threshold the arg-max always "wins", which yields confident but
    # off-topic answers for queries outside the dataset. Let callers opt out.
    if min_similarity is not None:
        best_score = similarities.max().item()
        if best_score < min_similarity:
            return None

    return qa_pairs[best_match_index][1]

In [6]:
# Example Usage
# NOTE: none of the stored questions is about "generalization", and
# find_best_answer returns the answer of the highest-scoring stored question
# regardless of how weak that match is. The answer printed for this query is
# therefore the nearest stored topic, not a definition of generalization.
new_question = "Can you define generalization for a model?"
best_answer = find_best_answer(new_question)
print("Question:", new_question)
print("Best Answer:", best_answer)

Question: Can you define generalization for a model?
Best Answer: Regularization is a technique used to prevent overfitting by adding a penalty to the model's complexity.


“How does a model perform well on unseen data?”

“What makes a model generalize?”