# FewShot Tutorial

This notebook demonstrates the use of the `FewShotX` package, available for download [here](https://github.com/RenatoVassallo/FewShotX).

## Classifying Pets and Code-Related Texts

In this tutorial, we will use the `FewShotLearner` class to classify texts into categories such as pets or code-related content. The workflow includes the following steps:

1. **Model Initialization:**  
   - Instantiate the `FewShotLearner` class with the support set (examples and observed labels) and the chosen encoder.

2. **Training the Model:**  
   - The `.fit()` method consists of two stages:  
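     - **Data Preparation:** Embeds the support texts and their labels, then splits them into training and validation subsets.  
     - **Model Training:** Fits a linear mapping $\mathbf{W}$ with a Bayesian MSE loss, which combines a squared-error term with an L2 penalty that pulls $\mathbf{W}$ toward the identity matrix. The key hyperparameters are `lam` ($\lambda$, the regularization strength), `lr` (the learning rate), and `early_stop` (the early-stopping patience, used to prevent overfitting).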

     $$\mathbf{W}^* = \arg \min_{\mathbf{W}} \left( \| \mathbf{X}^\top \mathbf{W} - \mathbf{Y} \|^2 + \lambda \| \mathbf{W} - \mathbb{I} \|^2 \right) $$

3. **Prediction:**  
   - Predict categories by projecting the query set embeddings through the learned mapping matrix ($\mathbf{W}^*$) and comparing the result with the label embeddings.
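
Since the objective is quadratic in $\mathbf{W}$, it also has a closed-form minimizer. Write $\mathcal{L}(\mathbf{W})$ for the objective and assume that $\mathbf{X}$ is the $d \times n$ matrix whose columns are the support embeddings (so $\mathbf{X}^\top$ matches the $n \times d$ `X_train` tensor), that $\mathbf{Y}$ is the $n \times d$ matrix of label embeddings, and that both norms are Frobenius norms. Setting the gradient to zero gives:

$$
\begin{aligned}
\nabla_{\mathbf{W}} \mathcal{L} &= 2\mathbf{X}\left(\mathbf{X}^\top \mathbf{W} - \mathbf{Y}\right) + 2\lambda\left(\mathbf{W} - \mathbb{I}\right) = \mathbf{0} \\
\left(\mathbf{X}\mathbf{X}^\top + \lambda \mathbb{I}\right)\mathbf{W}^* &= \mathbf{X}\mathbf{Y} + \lambda \mathbb{I} \\
\mathbf{W}^* &= \left(\mathbf{X}\mathbf{X}^\top + \lambda \mathbb{I}\right)^{-1}\left(\mathbf{X}\mathbf{Y} + \lambda \mathbb{I}\right)
\end{aligned}
$$

The inverse exists for any $\lambda > 0$ because $\mathbf{X}\mathbf{X}^\top$ is positive semidefinite. As $\lambda \to \infty$, $\mathbf{W}^* \to \mathbb{I}$ and queries keep their original embeddings. A smaller $\lambda$ lets $\mathbf{W}^*$ fit the support set more closely. The `lr` and `epochs` parameters suggest that FewShotX minimizes the objective iteratively rather than with this formula.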

In [1]:
import pandas as pd

# Creating a Toy Dataset
support_data = {
    'text': ['Cats are cute', 'Dogs are loyal', 'Birds are awesome',
             'I love programming', 'I like coding', 'I am data scientist'],
    'label': ['Pets', 'Pets', 'Pets',
              'Code', 'Code', 'Code']
}
support_set = pd.DataFrame(support_data)
support_set

|   | text | label |
|---|------|-------|
| 0 | Cats are cute | Pets |
| 1 | Dogs are loyal | Pets |
| 2 | Birds are awesome | Pets |
| 3 | I love programming | Code |
| 4 | I like coding | Code |
| 5 | I am data scientist | Code |


## 1. Step-by-step method

In [2]:
from FewShotX import Embeddings, FewShotLearner

# Instantiate the Embeddings class
embedding_model = Embeddings(model_name='all-MiniLM-L6-v2')

# Instantiate the FewShotLearner class
learner = FewShotLearner(support_set, text_col='text', label_col='label', embedding_model=embedding_model)

In [3]:
# Prepare the training data using the _prepare_training_data method
(X_train, y_train), (X_val, y_val), input_dim, output_dim = learner._prepare_training_data(val_split=0.2)
print("X_train shape:", X_train.shape)
print("y_train shape:", y_train.shape)

X_train shape: torch.Size([4, 384])
y_train shape: torch.Size([4, 384])


+ `X_train` has shape (4, 384): 4 training examples, each encoded as a 384-dimensional embedding. The remaining support examples form the validation set.
+ The labels "Pets" and "Code" are embedded with the same model.
+ Each training example is paired with the embedding of its own label, so `y_train` also has shape (4, 384).
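
The preparation step described above can be sketched with numpy. This is not FewShotX's code. `toy_embed` and `prepare_training_data` are hypothetical names, and `toy_embed` is a stand-in for the sentence encoder that returns deterministic random unit vectors, not meaningful embeddings. The shuffling and the rounding of `val_split` are also assumptions. The sketch uses `ceil`, which yields the 4 training examples seen above.

```python
import numpy as np

def toy_embed(texts, dim=8):
    # Hypothetical stand-in for a sentence encoder: each distinct string gets a
    # deterministic pseudo-random unit vector (NOT a semantic embedding).
    vecs = []
    for t in texts:
        rng = np.random.default_rng(sum(map(ord, t)))  # seed derived from the string
        v = rng.normal(size=dim)
        vecs.append(v / np.linalg.norm(v))
    return np.stack(vecs)

def prepare_training_data(texts, labels, val_split=0.2, seed=0, dim=8):
    """Sketch: embed texts (X) and labels (Y), then split into train/validation."""
    X = toy_embed(texts, dim)                         # (n, dim): one row per example
    classes = sorted(set(labels))                     # unique labels, e.g. ['Code', 'Pets']
    class_emb = dict(zip(classes, toy_embed(classes, dim)))
    Y = np.stack([class_emb[lab] for lab in labels])  # (n, dim): each row is its label's embedding
    order = np.random.default_rng(seed).permutation(len(texts))
    n_val = int(np.ceil(val_split * len(texts)))      # assumed rounding rule
    val_idx, train_idx = order[:n_val], order[n_val:]
    return (X[train_idx], Y[train_idx]), (X[val_idx], Y[val_idx]), classes

texts = ['Cats are cute', 'Dogs are loyal', 'Birds are awesome',
         'I love programming', 'I like coding', 'I am data scientist']
labels = ['Pets', 'Pets', 'Pets', 'Code', 'Code', 'Code']
(X_tr, Y_tr), (X_va, Y_va), classes = prepare_training_data(texts, labels)
print(X_tr.shape, Y_tr.shape, X_va.shape)  # (4, 8) (4, 8) (2, 8)
```

With the real 384-dimensional encoder, the same construction gives (4, 384) training tensors. `Y` contains only as many distinct rows as there are labels.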

In [4]:
# We train the model using a validation set and early stopping to prevent overfitting
learner._train_model((X_train, y_train), (X_val, y_val), input_dim, output_dim, lam=0.1, lr=0.1, 
                     epochs=20, early_stop=5, verbose=True)

Epoch 1/20 - Training Loss: 0.0276 - Validation Loss: 0.4302
Epoch 2/20 - Training Loss: 1.1491 - Validation Loss: 0.0895
Epoch 3/20 - Training Loss: 0.1122 - Validation Loss: 0.1326
Epoch 4/20 - Training Loss: 0.2852 - Validation Loss: 0.2582
Epoch 5/20 - Training Loss: 0.6347 - Validation Loss: 0.2057
Epoch 6/20 - Training Loss: 0.4643 - Validation Loss: 0.0893
Epoch 7/20 - Training Loss: 0.1510 - Validation Loss: 0.0361
Epoch 8/20 - Training Loss: 0.0440 - Validation Loss: 0.0645
Epoch 9/20 - Training Loss: 0.1586 - Validation Loss: 0.1089
Epoch 10/20 - Training Loss: 0.2799 - Validation Loss: 0.1115
Epoch 11/20 - Training Loss: 0.2609 - Validation Loss: 0.0754
Epoch 12/20 - Training Loss: 0.1477 - Validation Loss: 0.0366
Early stopping at epoch 12


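If the validation loss does **not improve** for `early_stop=5` consecutive epochs, training stops early. Here the lowest validation loss (0.0361) is reached at epoch 7, and none of epochs 8 to 12 improves on it, so training halts at epoch 12.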
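
The stopping rule can be replayed against the logged validation losses. This is a minimal sketch, and `early_stopping_epoch` is a hypothetical helper. It assumes that an epoch counts as an improvement only when its loss is strictly lower than the best so far. FewShotX may apply a tolerance instead.

```python
def early_stopping_epoch(val_losses, patience):
    """Return (best_epoch, stop_epoch), both 1-indexed.

    Training stops once `patience` consecutive epochs fail to beat the best
    validation loss so far; stop_epoch is None if that never happens.
    """
    best_loss = float("inf")
    best_epoch = 0
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:                      # strict improvement (assumption)
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return best_epoch, epoch
    return best_epoch, None

# Validation losses copied from the training log (epochs 1-12).
logged = [0.4302, 0.0895, 0.1326, 0.2582, 0.2057, 0.0893,
          0.0361, 0.0645, 0.1089, 0.1115, 0.0754, 0.0366]
print(early_stopping_epoch(logged, patience=5))  # (7, 12)
```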

In [5]:
query_data = {
    'text': ['Parrots can talk and mimic sounds',
             'Developing machine learning models is fascinating'],
    'label': ['Pets', 'Code']
}
query_set = pd.DataFrame(query_data)
query_set

|   | text | label |
|---|------|-------|
| 0 | Parrots can talk and mimic sounds | Pets |
| 1 | Developing machine learning models is fascinating | Code |


In [6]:
# Compute predictions
predictions, acc = learner.predict(query_set, k=3, return_accuracy=True)
print("Accuracy: ", acc)
predictions

Accuracy:  1.0


|   | text | label | pred | pred_label | true_label_idx |
|---|------|-------|------|------------|----------------|
| 0 | Parrots can talk and mimic sounds | Pets | 0 | Pets | 0 |
| 1 | Developing machine learning models is fascinating | Code | 1 | Code | 1 |
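
The `pred`, `pred_label` and `true_label_idx` columns hint at how prediction works. Each query receives a predicted label index, which is mapped to a label name and compared with the true index to compute accuracy. The numpy sketch below shows that idea under stated assumptions. It assumes queries are projected as $\mathbf{q}^\top \mathbf{W}^*$ and scored by cosine similarity against the label embeddings. FewShotX's actual scoring rule, and the role of its `k` argument, are not shown in this notebook. `predict_labels` is therefore hypothetical and ignores `k`.

```python
import numpy as np

def predict_labels(Q, W, class_emb, classes):
    """Sketch: project queries through W, then pick the most similar label.

    Q: (m, d) query embeddings, W: (d, d) learned mapping,
    class_emb: (c, d) label embeddings where row i belongs to classes[i].
    """
    P = Q @ W                                         # q^T W for every query
    P = P / np.linalg.norm(P, axis=1, keepdims=True)  # unit-normalize projections
    C = class_emb / np.linalg.norm(class_emb, axis=1, keepdims=True)
    scores = P @ C.T                                  # cosine similarities, shape (m, c)
    pred = scores.argmax(axis=1)                      # analogue of the `pred` column
    return pred, [classes[i] for i in pred]           # analogue of `pred_label`

# Hand-made 3-d vectors for illustration only (not real embeddings).
classes = ['Pets', 'Code']
class_emb = np.array([[1.0, 0.0, 0.0],
                      [0.0, 1.0, 0.0]])
Q = np.array([[0.9, 0.2, 0.1],    # points mostly toward "Pets"
              [0.1, 0.8, 0.3]])   # points mostly toward "Code"
W = np.eye(3)                      # identity mapping for illustration; a trained W* would replace it
pred, pred_label = predict_labels(Q, W, class_emb, classes)
true_label_idx = np.array([0, 1])
accuracy = float(np.mean(pred == true_label_idx))
print(pred.tolist(), pred_label, accuracy)  # [0, 1] ['Pets', 'Code'] 1.0
```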


## 2. Direct method

In [7]:
from FewShotX import Embeddings, FewShotLearner

# Instantiate the Embeddings class
embedding_model = Embeddings(model_name='all-MiniLM-L6-v2')

# Train our learner with the support set
learner = FewShotLearner(support_set, text_col='text', label_col='label', embedding_model=embedding_model)
learner.fit(val_split=0.2, lam=0.1, lr=0.1, epochs=20, early_stop=5, verbose=True)

Epoch 1/20 - Training Loss: 0.0276 - Validation Loss: 0.4302
Epoch 2/20 - Training Loss: 1.1491 - Validation Loss: 0.0895
Epoch 3/20 - Training Loss: 0.1122 - Validation Loss: 0.1326
Epoch 4/20 - Training Loss: 0.2852 - Validation Loss: 0.2582
Epoch 5/20 - Training Loss: 0.6347 - Validation Loss: 0.2057
Epoch 6/20 - Training Loss: 0.4643 - Validation Loss: 0.0893
Epoch 7/20 - Training Loss: 0.1510 - Validation Loss: 0.0361
Epoch 8/20 - Training Loss: 0.0440 - Validation Loss: 0.0645
Epoch 9/20 - Training Loss: 0.1586 - Validation Loss: 0.1089
Epoch 10/20 - Training Loss: 0.2799 - Validation Loss: 0.1115
Epoch 11/20 - Training Loss: 0.2609 - Validation Loss: 0.0754
Epoch 12/20 - Training Loss: 0.1477 - Validation Loss: 0.0366
Early stopping at epoch 12
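
The `.fit()` call bundles data preparation and training into one step. The numpy sketch below illustrates the training stage. It runs full-batch gradient descent on the objective from the introduction and uses the same `lam`, `lr`, `epochs` and `early_stop` settings. It is not FewShotX's implementation. Several choices are assumptions made for this sketch: the name `fit_mapping`, the unscaled sum-of-squares loss, starting from the identity, and plain MSE as the validation loss. Rows of `X_tr` are embeddings, so `X_tr @ W` plays the role of $\mathbf{X}^\top \mathbf{W}$.

```python
import numpy as np

def fit_mapping(X_tr, Y_tr, X_va, Y_va, lam=0.1, lr=0.1, epochs=20, early_stop=5):
    """Sketch: gradient descent on ||X_tr W - Y_tr||^2 + lam * ||W - I||^2.

    Returns the W with the lowest validation MSE, the final W, and the
    validation history; stops after `early_stop` epochs without improvement.
    """
    d = X_tr.shape[1]
    I = np.eye(d)
    W = I.copy()                                  # start at the identity, the penalty's target
    best_W, best_val, stale, history = W.copy(), np.inf, 0, []
    for _ in range(epochs):
        grad = 2 * X_tr.T @ (X_tr @ W - Y_tr) + 2 * lam * (W - I)
        W = W - lr * grad                         # one full-batch gradient step
        val = float(np.mean((X_va @ W - Y_va) ** 2))
        history.append(val)
        if val < best_val:
            best_W, best_val, stale = W.copy(), val, 0
        else:
            stale += 1
            if stale >= early_stop:               # patience exhausted
                break
    return best_W, W, history

def unit_rows(rng, n, d):
    # Random unit-norm rows standing in for real embeddings.
    A = rng.normal(size=(n, d))
    return A / np.linalg.norm(A, axis=1, keepdims=True)

rng = np.random.default_rng(0)
X_tr, Y_tr = unit_rows(rng, 4, 6), unit_rows(rng, 4, 6)
X_va, Y_va = unit_rows(rng, 2, 6), unit_rows(rng, 2, 6)
best_W, final_W, history = fit_mapping(X_tr, Y_tr, X_va, Y_va)

# With stopping disabled and enough epochs, gradient descent reaches the closed-form minimizer.
_, W_gd, _ = fit_mapping(X_tr, Y_tr, X_va, Y_va, epochs=3000, early_stop=3000)
W_closed = np.linalg.solve(X_tr.T @ X_tr + 0.1 * np.eye(6), X_tr.T @ Y_tr + 0.1 * np.eye(6))
print(np.allclose(W_gd, W_closed))  # True
```

This sketch uses an unscaled sum of squares, so its stable step size shrinks as the support set grows. The default `lr=0.1` is safe here only because there are a few unit-norm examples.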


In [8]:
# Compute predictions
predictions, acc = learner.predict(query_set, k=3, return_accuracy=True)
predictions

|   | text | label | pred | pred_label | true_label_idx |
|---|------|-------|------|------------|----------------|
| 0 | Parrots can talk and mimic sounds | Pets | 0 | Pets | 0 |
| 1 | Developing machine learning models is fascinating | Code | 1 | Code | 1 |
