# FewShotLearner class walkthrough

This notebook demonstrates the use of the `FewShotX` package (version 0.1.2), available for download [here](https://github.com/RenatoVassallo/BSE-ForecastNLP/releases/download/0.1.2/fewshotx-0.1.2-py3-none-any.whl).

In [None]:
import pandas as pd

# Creating a Toy Dataset
support_data = {
    'text': ['Cats are cute', 'Dogs are loyal', 'Birds are awesome',
             'I love programming', 'I like coding', 'I am data scientist'],
    'label': ['Pets', 'Pets', 'Pets',
              'Code', 'Code', 'Code']
}
support_set = pd.DataFrame(support_data)
support_set

In [None]:
from FewShotX import Embeddings, FewShotLearner

# Instantiate the Embeddings class
embedding_model = Embeddings(model_name='all-MiniLM-L6-v2')

# Instantiate the FewShotLearner class with the toy dataset
learner = FewShotLearner(support_set, text_col='text', label_col='label', embedding_model=embedding_model)

In [None]:
# Prepare the training data using the _prepare_training_data method
(X_train, y_train), (X_val, y_val), input_dim, output_dim = learner._prepare_training_data(val_split=0.2)
print("X_train shape:", X_train.shape)
print("y_train shape:", y_train.shape)

+ The shape of `X_train` is (4, 384): 4 samples, each a 384-dimensional embedding.
+ The labels "Pets" and "Code" are embedded with the same model.
+ These 2 unique label embeddings are then mapped to their respective examples, resulting in a `y_train` of shape (4, 384).
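
The label-to-target mapping described above can be sketched without the package. This is a minimal illustration, not the `FewShotX` implementation: `fake_embed` is a hypothetical stand-in that returns random vectors instead of real sentence embeddings, and the four labels are chosen only to reproduce the (4, 384) shape.

```python
import numpy as np

EMBED_DIM = 384  # embedding size reported above for all-MiniLM-L6-v2


def fake_embed(texts, dim=EMBED_DIM, seed=0):
    """Hypothetical stand-in for a sentence encoder: one random vector per text."""
    rng = np.random.default_rng(seed)
    return rng.normal(size=(len(texts), dim))


# Labels of four training examples (illustrative values only)
labels = ['Pets', 'Pets', 'Code', 'Code']

# Embed each unique label once
unique_labels = sorted(set(labels))
label_vectors = fake_embed(unique_labels)  # shape (2, 384)

# Look up the row of each example's label to build the target matrix
label_to_row = {name: i for i, name in enumerate(unique_labels)}
y = label_vectors[[label_to_row[name] for name in labels]]

print("label_vectors shape:", label_vectors.shape)
print("y shape:", y.shape)
```

Examples that share a label receive identical target rows, so the model is trained to map each text embedding toward the embedding of its label.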

In [None]:
learner._train_model((X_train, y_train), (X_val, y_val), input_dim, output_dim, lam=0.1, lr=0.1, 
                     epochs=20, early_stop=5, verbose=True)

If the validation loss does **not improve** for 5 consecutive epochs (`early_stop=5`), training is stopped early.
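
The stopping rule can be sketched as a patience counter. This is an assumption about how such a rule is typically written, not the package's actual code: the helper `epochs_run` is hypothetical, it replays a precomputed list of validation losses, and it counts an epoch as "not improving" when its loss is not below the best loss seen so far.

```python
def epochs_run(val_losses, patience=5):
    """Return how many epochs run before patience-based early stopping triggers."""
    best = float('inf')  # best validation loss seen so far
    stale = 0            # consecutive epochs without improvement
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best = loss
            stale = 0
        else:
            stale += 1
            if stale >= patience:
                return epoch  # stop early at this epoch
    return len(val_losses)    # patience never exhausted


# Illustrative losses: the best value (0.7) is reached at epoch 3
losses = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.9, 0.8, 0.6, 0.5]
stopped_at = epochs_run(losses, patience=5)
print("Stopped after epoch:", stopped_at)
```

In this illustrative run, epochs 4 through 8 fail to beat the loss from epoch 3, so training stops after epoch 8 and the later, lower losses are never reached.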

In [None]:
query_data = {
    'text': ['Parrots can talk and mimic sounds',
             'Developing machine learning models is fascinating'],
    'label': ['Pets', 'Code']
}
query_set = pd.DataFrame(query_data)
query_set

In [None]:
# Compute predictions
predictions, acc = learner.predict(query_set, k=3, return_accuracy=True)
print("Accuracy: ", acc)
predictions