# Lab 2: AG News

In [1]:
import pandas as pd

# Read the AG News sample from its Parquet file into a DataFrame.
df = pd.read_parquet('datasets/agnews2k.parquet')

# Print the number of rows per label, then display the frame as the cell output.
print(df.loc[:, "label"].value_counts())
df

label
Sports      500
World       500
Sci/Tech    500
Business    500
Name: count, dtype: int64


Unnamed: 0,text,label
0,"Latest BCS a feel-good for Cal, upset stomach ...",Sports
1,Sindelar snaps slump with 66 at rainy Open Joe...,Sports
2,Seven killed in Karbala mosque explosion At le...,World
3,Building Dedicated to Space Shuttle Columbia A...,Sci/Tech
4,Guatemala to pay paramilitaries Guatemala's go...,World
...,...,...
1995,Dodgers Clobber Padres 9-6 (AP) AP - Cesar Izt...,Sports
1996,Digital photo album bypasses PCs SanDisk's med...,Sci/Tech
1997,Greek hero Charisteas the man for the big occa...,Sports
1998,FilePlanet Daily Download Marvel Sues NCSoft a...,Sci/Tech


In [2]:
def split_support_query(df, support_pct, random_state=42):
    """
    Split a DataFrame into a random support set and the complementary query set.

    Rows are selected by position rather than by index label, so the two
    returned frames always partition ``df`` exactly, even when the index
    contains duplicate labels (dropping by label would discard every row
    that shares a label with a sampled row).

    Parameters
    ----------
    df : pandas.DataFrame
        The dataset to split.
    support_pct : float
        Fraction of rows to place in the support set; must satisfy
        ``0 < support_pct < 1``.
    random_state : int, optional
        Seed forwarded to the pandas sampler (default 42).

    Returns
    -------
    tuple of (pandas.DataFrame, pandas.DataFrame)
        ``(support_set, query_set)``. The support rows come back in sampled
        order; the query rows keep their original order.

    Raises
    ------
    ValueError
        If ``support_pct`` is not strictly between 0 and 1.
    """
    # Ensure percentage is within the valid range
    if not (0 < support_pct < 1):
        raise ValueError("support_pct must be a float between 0 and 1.")

    # Sample row *positions*. Sampling a length-n Series with the same frac and
    # seed picks the same positions that df.sample(frac=..., random_state=...)
    # would, so results are unchanged for frames with a unique index.
    row_positions = pd.Series(range(len(df)))
    support_positions = row_positions.sample(
        frac=support_pct, random_state=random_state
    ).to_numpy()

    support_set = df.iloc[support_positions]
    # row_positions has a default 0..n-1 index, so dropping by label here is
    # the same as dropping by position and is always unambiguous.
    query_set = df.iloc[row_positions.drop(support_positions).to_numpy()]

    return support_set, query_set

In [3]:
# Put 25% of the rows in the support set; everything else becomes the query set.
support_set, query_set = split_support_query(df, 0.25, random_state=42)

# Preview each split: a header with its row count, then its first rows.
print(f"Support Set: ({len(support_set)} rows):", support_set.head(), sep="\n")
print(f"\nQuery Set: ({len(query_set)} rows):", query_set.head(), sep="\n")

Support Set: (500 rows):
                                                   text     label
1860  Falluja Rebels Had Enough Arms to Rule Iraq -U...     World
353   Notables With a 9:15 p.m. curfew imposed becau...    Sports
1333  Oil Prices Generate Winners and Losers With cr...     World
905   Offshore drilling rig missing after Ivan An of...  Business
1289  Top UN envoy details lack of progress in Darfu...     World

Query Set: (1500 rows):
                                                text     label
0  Latest BCS a feel-good for Cal, upset stomach ...    Sports
1  Sindelar snaps slump with 66 at rainy Open Joe...    Sports
3  Building Dedicated to Space Shuttle Columbia A...  Sci/Tech
4  Guatemala to pay paramilitaries Guatemala's go...     World
5  Manning to Get First Start Vs. Panthers (AP) A...    Sports


In [4]:
from FewShotX.embeddings.embed import Embeddings
from FewShotX.scoring.fewshot import FewShotLearner

# Embedding model wrapper, configured with the 'all-MiniLM-L6-v2' checkpoint.
embed_model = Embeddings(verbose=False, model_name='all-MiniLM-L6-v2')

# Few-shot learner built from a copy of the support set, reading the
# "text" column as input and the "label" column as the target.
few_shot_learner = FewShotLearner(
    embedding_model=embed_model,
    support_set=support_set.copy(),
    label_col="label",
    text_col="text",
)

In [5]:
# Fit the learner on its support set; verbose=True logs the loss for every epoch.
few_shot_learner.fit(
    epochs=50,
    early_stop=10,
    lr=0.05,
    lam=0.1,
    verbose=True,
)

Epoch 1/50 - Training Loss: 0.1604 - Validation Loss: 0.0055
Epoch 2/50 - Training Loss: 0.1119 - Validation Loss: 0.0054
Epoch 3/50 - Training Loss: 0.0900 - Validation Loss: 0.0043
Epoch 4/50 - Training Loss: 0.0734 - Validation Loss: 0.0037
Epoch 5/50 - Training Loss: 0.0667 - Validation Loss: 0.0036
Epoch 6/50 - Training Loss: 0.0611 - Validation Loss: 0.0033
Epoch 7/50 - Training Loss: 0.0618 - Validation Loss: 0.0034
Epoch 8/50 - Training Loss: 0.0608 - Validation Loss: 0.0034
Epoch 9/50 - Training Loss: 0.0617 - Validation Loss: 0.0034
Epoch 10/50 - Training Loss: 0.0648 - Validation Loss: 0.0036
Epoch 11/50 - Training Loss: 0.0665 - Validation Loss: 0.0038
Epoch 12/50 - Training Loss: 0.0694 - Validation Loss: 0.0040
Epoch 13/50 - Training Loss: 0.0746 - Validation Loss: 0.0038
Epoch 14/50 - Training Loss: 0.0763 - Validation Loss: 0.0042
Epoch 15/50 - Training Loss: 0.0746 - Validation Loss: 0.0036
Epoch 16/50 - Training Loss: 0.0720 - Validation Loss: 0.0040
Early stopping at

In [6]:
# Predict on a copy of the query set, requesting the accuracy alongside
# the per-row predictions.
df_pred, acc = few_shot_learner.predict(
    query_set.copy(),
    return_accuracy=True,
)
print("Accuracy: ", acc)
df_pred

Accuracy:  0.748


Unnamed: 0,text,label,pred,pred_label,true_label_idx
0,"Latest BCS a feel-good for Cal, upset stomach ...",Sports,3,Sci/Tech,1
1,Sindelar snaps slump with 66 at rainy Open Joe...,Sports,1,Sports,1
3,Building Dedicated to Space Shuttle Columbia A...,Sci/Tech,3,Sci/Tech,3
4,Guatemala to pay paramilitaries Guatemala's go...,World,2,Business,0
5,Manning to Get First Start Vs. Panthers (AP) A...,Sports,1,Sports,1
...,...,...,...,...,...
1994,"Nuggets 93, Rockets 88 DerMarr Johnson scored ...",Sports,1,Sports,1
1995,Dodgers Clobber Padres 9-6 (AP) AP - Cesar Izt...,Sports,1,Sports,1
1996,Digital photo album bypasses PCs SanDisk's med...,Sci/Tech,3,Sci/Tech,3
1997,Greek hero Charisteas the man for the big occa...,Sports,1,Sports,1


## Hyperparameter tuning

In [7]:
# Candidate values for the two hyperparameters passed to fit(): lr and lam.
learning_rates = [0.0001, 0.001, 0.01]
lambdas = [0.01, 0.1, 0.5]

# Bind every "best so far" slot up front. Without this, the summary print at
# the bottom raises NameError on a fresh kernel when no configuration scores
# above 0, or silently reports values left over from an earlier run of this cell.
best_score = 0
best_lr = None
best_lam = None
best_model = None

# NOTE(review): configurations are compared on query_set, the same rows used
# to report accuracy, so the winning score is an optimistic estimate. Consider
# selecting on a separate validation split.
for lr in learning_rates:
    for lam in lambdas:

        # A new embedding model and learner are created for every configuration.
        # NOTE(review): this rebinds the notebook-level `embed_model` and
        # `few_shot_learner`; after the loop they hold the LAST configuration,
        # not the best one. Use `best_model` for the winner.
        embed_model = Embeddings(model_name='all-MiniLM-L6-v2', verbose=False)
        few_shot_learner = FewShotLearner(
            support_set=support_set.copy(),
            text_col="text",
            label_col="label",
            embedding_model=embed_model
        )

        # Fit quietly: verbose=False, so per-epoch losses are not printed here.
        few_shot_learner.fit(lam=lam, lr=lr, epochs=100, early_stop=10, verbose=False)

        # Evaluate the model on a copy of the query set
        _, new_score = few_shot_learner.predict(query_set.copy(), return_accuracy=True)
        print(f"LR: {lr}, Lambda: {lam}, Score: {new_score}")

        # Keep this configuration only if it strictly beats the best so far
        # (ties keep the earlier configuration).
        if new_score > best_score:
            best_score = new_score
            best_lr = lr
            best_lam = lam
            best_model = few_shot_learner
            print(f"New best hps: LR: {best_lr}, Lambda: {best_lam}, with a score of {best_score}")

# Prints None for LR/Lambda if no configuration scored above 0.
print(f"Best Hyperparameters: LR: {best_lr}, Lambda: {best_lam}, with Score = {best_score}")

LR: 0.0001, Lambda: 0.01, Score: 0.8573333333333333
New best hps: LR: 0.0001, Lambda: 0.01, with a score of 0.8573333333333333
LR: 0.0001, Lambda: 0.1, Score: 0.8633333333333333
New best hps: LR: 0.0001, Lambda: 0.1, with a score of 0.8633333333333333
LR: 0.0001, Lambda: 0.5, Score: 0.8566666666666667
LR: 0.001, Lambda: 0.01, Score: 0.8513333333333334
LR: 0.001, Lambda: 0.1, Score: 0.8586666666666667
LR: 0.001, Lambda: 0.5, Score: 0.8473333333333334
LR: 0.01, Lambda: 0.01, Score: 0.804
LR: 0.01, Lambda: 0.1, Score: 0.8486666666666667
LR: 0.01, Lambda: 0.5, Score: 0.8446666666666667
Best Hyperparameters: LR: 0.0001, Lambda: 0.1, with Score = 0.8633333333333333
