# SetFit for Text Classification

In this notebook, we'll learn how to do few-shot text classification with SetFit.

## Setup

If you're running this notebook on Colab or some other cloud platform, you will need to install the `setfit` library. Run the following cell to do so:

In [1]:
%pip install setfit datasets

Note: you may need to restart the kernel to use updated packages.


To be able to share your model with the community, there are a few more steps to follow.

First, you have to store your authentication token from the Hugging Face Hub (sign up [here](https://huggingface.co/join) if you haven't already!). To do so, execute the following cell and input an [access token](https://huggingface.co/docs/hub/security-tokens) associated with your account:

In [None]:
from huggingface_hub import notebook_login

notebook_login()

Then you need to install Git-LFS, which you can do by uncommenting and running the following command:

In [None]:
# !apt install git-lfs

Finally, you may need to configure Git on your system by providing details about who you are:

In [None]:
# !git config --global user.email "you@example.com"
# !git config --global user.name "Your Name"

This notebook is designed to work with any multiclass [text classification dataset](https://huggingface.co/models?pipeline_tag=text-classification&sort=downloads) and pretrained [Sentence Transformer](https://huggingface.co/models?library=sentence-transformers&sort=downloads) on the Hub. Change the values below to try a different dataset / model!

In [2]:
# Using your custom intent classification data
model_id = "sentence-transformers/all-MiniLM-L6-v2"  # Same model used in your RAG system

## Loading and sampling the dataset

We will use the 🤗 Datasets library to wrap our local intent examples in a `Dataset` object, which can be done as follows:

In [3]:
from datasets import Dataset
from training_data.intent_examples import TRAINING_DATA, LABELS, LABEL_TO_ID

# Convert your training data to HuggingFace Dataset format
texts = [text for text, label in TRAINING_DATA]
labels = [LABEL_TO_ID[label] for text, label in TRAINING_DATA]

# Create the dataset
full_dataset = Dataset.from_dict({
    "text": texts,
    "label": labels
})

print(f"Dataset created with {len(full_dataset)} examples")
print(f"Labels: {LABELS}")
print(f"Label mapping: {LABEL_TO_ID}")
full_dataset


Dataset created with 512 examples
Labels: ['casual', 'visa_query', 'follow_up', 'booking', 'ticket_change', 'flight_info', 'clarification_origin', 'clarification_destination']
Label mapping: {'casual': 0, 'visa_query': 1, 'follow_up': 2, 'booking': 3, 'ticket_change': 4, 'flight_info': 5, 'clarification_origin': 6, 'clarification_destination': 7}


Dataset({
    features: ['text', 'label'],
    num_rows: 512
})
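
The `training_data.intent_examples` module is local to this project and is not shown in the notebook. As a rough sketch of the shape the cell above relies on, assume that `TRAINING_DATA` is a list of `(text, label)` tuples and that `LABEL_TO_ID` follows the order of `LABELS`. The three example rows below are placeholders borrowed from the test list further down, not the real training data:

```python
# Hypothetical contents of training_data/intent_examples.py (illustrative only)
LABELS = [
    "casual", "visa_query", "follow_up", "booking",
    "ticket_change", "flight_info",
    "clarification_origin", "clarification_destination",
]

# Map each label name to its integer id, following the order of LABELS
LABEL_TO_ID = {label: idx for idx, label in enumerate(LABELS)}

# Each example is a (text, label_name) tuple; these rows are placeholders
TRAINING_DATA = [
    ("do i need a visa", "visa_query"),
    ("please cancel my booking", "ticket_change"),
    ("whats your name", "casual"),
]

# Same conversion as in the notebook cell: texts plus integer label ids
texts = [text for text, label in TRAINING_DATA]
label_ids = [LABEL_TO_ID[label] for text, label in TRAINING_DATA]
print(label_ids)
```

Keeping the label names in one list and deriving the ids from it means the printed label mapping and `LABELS[pred]` lookups later in the notebook stay in sync.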

Most datasets on the Hub have many more labeled examples than those one encounters in few-shot settings. Our intent dataset is already small, with 512 labeled examples across eight classes, so rather than subsampling it we shuffle it and hold out 20% of the examples for evaluation:

In [4]:
# Split into train (80%) and eval (20%)
# Shuffle first to ensure random distribution
full_dataset = full_dataset.shuffle(seed=42)
split = full_dataset.train_test_split(test_size=0.2, seed=42)
train_dataset = split["train"]
eval_dataset = split["test"]

print(f"Train size: {len(train_dataset)}")
print(f"Eval size: {len(eval_dataset)}")

# Show class distribution
from collections import Counter
train_labels = [LABELS[l] for l in train_dataset["label"]]
eval_labels = [LABELS[l] for l in eval_dataset["label"]]
print(f"\nTrain distribution: {Counter(train_labels)}")
print(f"Eval distribution: {Counter(eval_labels)}")

train_dataset

Train size: 409
Eval size: 103

Train distribution: Counter({'flight_info': 55, 'casual': 54, 'clarification_origin': 53, 'ticket_change': 52, 'visa_query': 52, 'booking': 51, 'follow_up': 48, 'clarification_destination': 44})
Eval distribution: Counter({'clarification_destination': 20, 'follow_up': 16, 'booking': 13, 'ticket_change': 12, 'visa_query': 12, 'clarification_origin': 11, 'casual': 10, 'flight_info': 9})


Dataset({
    features: ['text', 'label'],
    num_rows: 409
})

Here we have 409 examples to train with, spread across the eight intent classes, and 103 held-out examples for evaluation. Let's preview a few of the evaluation examples:

In [5]:
# Eval dataset already created above - let's preview it
print("Sample eval examples:")
for i in range(min(3, len(eval_dataset))):
    text = eval_dataset[i]["text"]
    label = LABELS[eval_dataset[i]["label"]]
    print(f"  '{text}' -> {label}")

Sample eval examples:
  'I want a refund airline changed the schedule' -> ticket_change
  'british national' -> follow_up
  'do i need a visa' -> visa_query


Okay, now we have the dataset, let's load and train a model!

## Fine-tuning the model

To train a SetFit model, the first thing to do is download a pretrained checkpoint from the Hub. We can do so by using the `from_pretrained()` method associated with the `SetFitModel` class:

In [6]:
from setfit import SetFitModel

model = SetFitModel.from_pretrained(model_id)

model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.


Here, we've downloaded a pretrained Sentence Transformer from the Hub and added a logistic classification head to create the SetFit model. As indicated in the message, we need to train this model on some labeled examples. We can do so by using the `Trainer` class together with `TrainingArguments` as follows:

In [7]:
from sentence_transformers.losses import CosineSimilarityLoss
from setfit import Trainer, TrainingArguments

# Training arguments - REDUCED iterations to prevent overfitting
args = TrainingArguments(
    batch_size=16,
    num_epochs=1,
    num_iterations=10,  # Reduced from 20 to prevent overfitting
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

Map: 100%|██████████| 409/409 [00:00<00:00, 25931.47 examples/s]


The main arguments to notice when setting up training are the following:

* `loss_class`: The loss function to use for contrastive learning with the Sentence Transformer body. The cell above does not set it, so the library default is used.
* `num_iterations`: The number of text pairs to generate for contrastive learning
* `column_mapping`: The trainer expects the inputs to be found in a `text` and `label` column. A mapping can rename other column names to these; our dataset already uses `text` and `label`, so none is passed here.
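
To make the idea of generating text pairs for contrastive learning more concrete, here is a simplified sketch. It is not SetFit's actual sampling code: the helper `generate_pairs` is hypothetical, and it assumes that each iteration draws one positive pair (same label, target similarity 1.0) and one negative pair (different label, target similarity 0.0) for every training example, and that every class has at least two examples.

```python
import random


def generate_pairs(texts, labels, num_iterations, seed=0):
    """Return (text_a, text_b, target_similarity) triples (illustrative only)."""
    rng = random.Random(seed)
    pairs = []
    for _ in range(num_iterations):
        for i, (text, label) in enumerate(zip(texts, labels)):
            # Candidates sharing the label (excluding the example itself)
            same = [t for j, (t, l) in enumerate(zip(texts, labels)) if l == label and j != i]
            # Candidates with any other label
            diff = [t for t, l in zip(texts, labels) if l != label]
            pairs.append((text, rng.choice(same), 1.0))  # positive pair
            pairs.append((text, rng.choice(diff), 0.0))  # negative pair
    return pairs


# Toy data: two visa_query (1) and two ticket_change (4) examples
texts = [
    "do i need a visa",
    "visa requirements for japan",
    "please cancel my booking",
    "change my seat to window",
]
labels = [1, 1, 4, 4]

pairs = generate_pairs(texts, labels, num_iterations=10)
print(len(pairs))  # 2 pairs per example per iteration
```

Under these assumptions the number of pairs grows linearly with `num_iterations`, which is why lowering it shortens training.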

Now that we've created a trainer, we can train it!

In [8]:
trainer.train()

***** Running training *****
  Num unique pairs = 8180
  Batch size = 16
  Num epochs = 1


| Step | Training Loss |
|------|---------------|
| 1 | 0.5136 |
| 50 | 0.2357 |
| 100 | 0.1545 |
| 150 | 0.1071 |
| 200 | 0.0768 |
| 250 | 0.0651 |
| 300 | 0.0484 |
| 350 | 0.0398 |
| 400 | 0.0313 |
| 450 | 0.0258 |
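
The `Num unique pairs = 8180` figure in the log can be reproduced with a little arithmetic. This is a sketch under the assumption that one positive and one negative pair is generated per training example in each iteration; that assumption is not stated in the notebook, but it matches the logged number:

```python
n_train = 409        # training examples, from the split above
num_iterations = 10  # value passed to TrainingArguments

# Assumption: one positive and one negative pair per example per iteration
num_pairs = 2 * num_iterations * n_train
print(num_pairs)  # 8180
```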


The final step is to compute the model's performance using the `evaluate()` method:

In [9]:
metrics = trainer.evaluate()
metrics

***** Running evaluation *****


{'accuracy': 0.9514563106796117}
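
To read this number in terms of individual examples, we can convert it back into counts. A minimal sketch, assuming the reported accuracy is the fraction of the 103 evaluation examples that were classified correctly:

```python
accuracy = 0.9514563106796117  # value returned by trainer.evaluate() above
n_eval = 103                   # size of the evaluation split

# Recover the number of correct and incorrect predictions
n_correct = round(accuracy * n_eval)
n_wrong = n_eval - n_correct
print(n_correct, n_wrong)  # 98 5
```

With an evaluation split this small, each misclassified example moves the accuracy by roughly one percentage point.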

Now save the model locally for use in your chatbot:

In [10]:
# Save the model locally
model.save_pretrained("models/intent_classifier_setfit")
print("Model saved to models/intent_classifier_setfit/")

Model saved to models/intent_classifier_setfit/


Now test the model on some examples to see how it performs:

In [11]:
# Test on critical examples that were failing before
test_examples = [

    # CASUAL
    "i want to create an app for visa applications can you help me",
    "why is dubai so expensive",
    "are you sure about that",
    "but i heard i dont need visa",
    "whats your name",
    "bro what is this",
    "this system makes no sense",
    "how does this work",
    "tell me a story",
    "i dont know man",
    "let me think first",
    "thats crazy",
    "tell me something interesting",
    "what are you doing",
    "youre too slow today",

    # === NON-TRAVEL CASUAL (NEW) ===
    "what's the weather like today",
    "how do i cook biryani",
    "who won the football match yesterday",
    "what is the capital of mars",
    "can you explain quantum physics",
    "why is the sky blue",
    "who is the president of the world",
    "how do I fix my laptop",
    "what's your favourite food",
    "can you help me with my homework",
    "how do i start a business",
    "why do cats sleep so much",
    "explain minecraft speedrunning",
    "should i buy an iphone or samsung",
    "how many pushups should i do daily",
    "tell me a funny joke",
    "why am i always tired",
    "what's 2948 times 334",
    "who invented the light bulb",
    "why do humans dream",
    "what's better gym or running",
    "teach me programming",
    "how do i lose weight fast",
    "why do people lie",
    "can you help me write my essay",
    "explain how cameras work",
    "why does coffee keep me awake",
    "how to improve my memory",
    "who is the richest person today",
    "how do i become famous",

    # VISA_QUERY
    "do i need a visa",
    "i want to go to dubai",
    "visa requirements for japan",
    "is visa required for thailand",
    "documents needed for usa entry",
    "can pakistanis visit turkey",
    "does indian passport need visa for germany",
    "can i travel without visa",
    "what visa should i apply for",
    "entry requirements for canada",

    # FOLLOW_UP
    "pakistani",
    "im from pakistan",
    "what about turkey",
    "indian passport",
    "and japan",
    "destination dubai",
    "british citizen",
    "my nationality is egyptian",
    "thats my destination",
    "france",

    # BOOKING
    "i want to book a flight to dubai for next week",
    "find me the cheapest ticket to istanbul",
    "can you book a hotel for me in doha",
    "i need 3 tickets for me and my family",
    "look for flights under 400 pounds",
    "book me anything to turkey tomorrow morning",
    "i wanna fly to abu dhabi with 30kg baggage",
    "search for return flights to singapore in april",
    "find me a nonstop flight to new york",
    "i want a package holiday to maldives",
    "can u find flights for 4 people to saudi",
    "i need a ticket asap cheapest possible",
    "show me business class prices for jeddah",
    "compare flight options for next friday",
    "i want to reserve a hotel with breakfast",

    # TICKET_CHANGE
    "i need to change my flight date",
    "can i get a refund for my cancelled flight",
    "my name is spelled wrong on the ticket",
    "i missed my flight what are my options",
    "move my return flight to next week",
    "please cancel my booking",
    "i want to upgrade to business class",
    "can you rebook me on another flight",
    "change my seat to window",
    "i want compensation for delay",
    "correct my travel date",
    "add baggage to my existing booking",
    "the airline changed the time i want refund",
    "change the outbound flight only",
    "i booked the wrong airport fix it",

    # FLIGHT_INFO
    "how much baggage can i take",
    "can i choose my seat beforehand",
    "is food served on qatar airways",
    "does emirates have wifi",
    "what time is check in",
    "are blankets provided",
    "can i take baby milk onboard",
    "is online check in available now",
    "what's the hand luggage weight",
    "is there entertainment on this flight",
    "do i need to print my boarding pass",
    "what are the rules for power banks",
    "is the flight full today",
    "is it ok to bring snacks",
    "how long is the layover in doha",
    # LOOKS LIKE VISA QUERY - BUT ISN'T
    "do i need permission to enter your app",
    "what documents do i need to access this website",
    "can i get a pass to enter your system",
    "is there an entry requirement to join your service",
    "do i need authorization to proceed",
    "how do i apply for access",
    "is registration required to enter this chat",
    "what permit do i need to use this program",

    # LOOKS LIKE BOOKING - BUT ISN'T
    "can you book me an appointment with my dentist",
    "i want to reserve a seat in the classroom",
    "can you schedule a meeting for me",
    "find me a free slot this afternoon",
    "can u book me a study session",
    "reserve space in my calendar tomorrow",
    "i wanna book a table at a restaurant",
    "help me book a haircut appointment",
    "find cheap deals for a gaming pc",
    "i want to book a slot at the gym",

    # LOOKS LIKE FLIGHT INFO - BUT ISN'T
    "how much baggage is allowed in my relationship",
    "what time does my motivation take off",
    "how long is the layover between breakfast and lunch",
    "can i carry emotional baggage into this conversation",
    "is there wifi in my dreams",
    "can i choose a window seat in life",
    "is check-in open for my new job",
    "do i need to upgrade my brain to premium",
    "how strict is the weight limit on my backpack of stress",
    "do i get free snacks during my daily routine",
    "is turbulence normal in relationships",

    # LOOKS LIKE FOLLOW-UP - BUT ISN'T
    "germany is my favourite football team",
    "turkey is what i'm cooking tonight",
    "japan is the anime capital",
    "dubai perfume smells good",
    "india is where my favourite actor is from",
    "france is my new laptop wallpaper",
    "canada is the brand of my jacket",
    "oman is the name of my cat",

    # LOOKS LIKE TRAVEL BUT IS COMPLETELY RANDOM
    "can i board the train of success",
    "do i need a visa to enter her heart",
    "when does my motivation depart",
    "is my life economy class or business class",
    "bro thats wild",
    "do you like cats",
    "tell me something cool",
    "what should i eat today",
    "is today a lucky day",
    "explain why humans sneeze",
    "what's your favourite movie",
    "i think i'm bored",
    "you answer too fast lol",
    "not sure what to ask",
    "do you know any riddles",
    "why do people sleep",
    "that's interesting tell me more",
    "explain gravity in simple words",
    "my phone keeps lagging help",
    "what music do you listen to",
    "tell me a random fact",
    "you're confusing me",
    "idk man feels weird",
    "thanks but that's not what i meant",

    # ============================
    # VISA_QUERY (20)
    # ============================
    "does a kenyan passport need visa for spain",
    "what documents are required to enter malaysia",
    "is visa on arrival available for viet nam",
    "can a filipino enter jordan without visa",
    "requirements for nepali citizens to visit qatar",
    "is biometrics needed for schengen",
    "do i need a visa to visit south korea",
    "how do bangladeshi citizens enter oman",
    "is a transit visa required for hong kong",
    "can sri lankans travel to uae visa free",
    "what permit do i need for canada visit",
    "visa policy for mexican passport holders",
    "is evisa possible for kenya",
    "can i enter germany with expired schengen",
    "what entry rules apply for japan",
    "does my nationality need an evisa",
    "im planning a vacation to poland what visa",
    "how to get tourist visa for egypt",
    "visa process for turkey explained",
    "do i need a visa for layover in beijing",

    # ============================
    # FOLLOW_UP (15)
    # ============================
    "nigerian",
    "bangladeshi",
    "im from morocco",
    "my passport is sri lankan",
    "saudi citizen",
    "oman",
    "doha",
    "and italy",
    "for japan",
    "my nationality is spanish",
    "that's my country",
    "where im going? turkey",
    "portugal",
    "im italian",
    "destination is thailand",

    # ============================
    # BOOKING (15)
    # ============================
    "search flights for next weekend to rome",
    "i need a ticket from dubai to london asap",
    "look for hotels near istanbul city center",
    "find the cheapest one way flight to india",
    "book me a return trip to kuwait",
    "can you check fares for bali in june",
    "need accommodation in doha for 2 nights",
    "reserve 2 seats for me and my brother",
    "i want direct flights only to toronto",
    "help me find a trip under 300 pounds",
    "show me evening flights to madrid",
    "get me a flight with 25kg baggage included",
    "find holiday packages for malaysia",
    "compare prices for flights to saudi next month",
    "book the earliest flight tomorrow morning",

    # ============================
    # TICKET_CHANGE (15)
    # ============================
    "i need to push my flight one day later",
    "cancel the whole trip please",
    "can you modify my booking reference",
    "i want to move my outbound flight earlier",
    "the spelling of my middle name is wrong",
    "i want a refund for the delayed flight",
    "switch my return to a different airport",
    "can i change the class to economy plus",
    "please fix the date on my ticket",
    "i chose the wrong seat change it please",
    "my booking email is wrong update it",
    "add a meal request to my ticket",
    "rebook me for the next available flight",
    "change my ticket type to flexible",
    "cancel the outbound but keep return",

    # ============================
    # FLIGHT_INFO (15)
    # ============================
    "can i bring a guitar on the flight",
    "is dinner served on long flights",
    "what's the check in counter opening time",
    "are small power banks allowed",
    "does this plane have usb ports",
    "can infants get their own seat",
    "is cabin luggage included",
    "do they provide pillows in economy",
    "how early can i drop off my baggage",
    "is there free water onboard",
    "how strict is the carry on weight",
    "can i bring homemade food",
    "is transit security check required",
    "do they show movies during flight",
    "can i take my medication onboard",
    # Add these to test_examples:

    # ============================
    # CLARIFICATION_ORIGIN (20) - User indicating country is their nationality
    # ============================
    "nationality",
    "from there",
    "my country",
    "where im from",
    "citizen",
    "passport",
    "i live there",
    "thats my home",
    "origin",
    "im from there",
    "thats where i live",
    "my homeland",
    "born there",
    "i was born there",
    "home country",
    "where i was born",
    "my passport country",
    "i am from there",
    "thats my nationality",
    "where i come from",

    # ============================
    # CLARIFICATION_DESTINATION (20) - User indicating country is their destination
    # ============================
    "travel",
    "going there",
    "destination",
    "visiting",
    "traveling",
    "travelling",
    "want to go there",
    "where im going",
    "thats where i want to visit",
    "trip",
    "holiday",
    "vacation",
    "going to visit",
    "planning to go",
    "want to travel there",
    "thats my destination",
    "where i want to go",
    "i want to visit there",
    "flying there",
    "headed there",

    # ============================
    # EDGE CASES - Short ambiguous responses
    # ============================
    "there",
    "yes there",
    "yep",
    "that one",
    "the first one",
    "second option",
    
    # ============================
    # MIXED/TRICKY - Should be casual, not clarification
    # ============================
    "i travel a lot for work",
    "my passport expired last year",
    "nationality doesnt matter to me",
    "im not going anywhere",
    "i dont have a destination in mind",
    "my home is where my heart is",
    "i was born to be wild",
    "vacation mode activated lol",
    "holiday shopping is expensive",
    "traveling is my passion",
    
    # ============================
    # MORE CASUAL EDGE CASES
    # ============================
    "lol",
    "haha",
    "bruh",
    "hmm",
    "interesting",
    "wow",
    "cool",
    "nice",
    "okay then",
    "alright",
    "sure thing",
    "whatever",
    "maybe",
    "perhaps",
    "i guess",
    "probably",
    "definitely",
    "absolutely",
    "of course",
    "why not",
    # CASUAL (general confusion, complaints, bot-related)
    "why are you responding like this",
    "bro this app is tripping",
    "what kind of system are you",
    "this makes zero sense",
    "are you even working properly",
    "explain what you just said",
    "tell me something random",
    "i cant understand you",
    "youre acting weird today",
    "what do you mean by that",
    "i dont get it man",
    "say something interesting",
    "whats your purpose",
    "bro calm down",
    "you sound confused",

    # === NON-TRAVEL CASUAL (NEW) ===
    "why do people get headaches",
    "how do i bake a cake",
    "what's the capital of saturn",
    "which animal runs fastest",
    "why do birds migrate",
    "explain nuclear fusion",
    "how can i fix my keyboard",
    "teach me how to study better",
    "what's your favourite colour",
    "tell me a story about space",
    "do ants sleep",
    "why do people forget things",
    "which phone has the best camera",
    "how do i train my dog",
    "what's a good movie to watch",
    "how do i stop procrastinating",
    "explain how headphones work",
    "why does sugar dissolve in water",
    "who discovered gravity",
    "why do babies cry",

    # VISA_QUERY
    "is visa needed for a trip to norway",
    "documents required for uae trip",
    "do russians need visa for turkey",
    "visa type needed for canada visit",
    "is evisa available for oman",
    "how to get visa for malaysia",
    "do italians need visa for morocco",
    "is visa required to enter chile",
    "visiting japan what visa needed",
    "do i need visa for mexico holiday",

    # FOLLOW_UP
    "uzbek",
    "im algerian",
    "my passport is brazilian",
    "ethiopian citizen",
    "from jordan",
    "lebanon",
    "qatar national",
    "im from nigeria",
    "iranian",
    "south african passport",

    # BOOKING
    "find me a cheap flight to muscat",
    "i want a direct flight to delhi",
    "search hotels near abu dhabi beach",
    "book me two seats for saturday",
    "look for tickets under 500 dollars",
    "get me morning flights only",
    "reserve a hotel with free breakfast",
    "find me flights landing before 6pm",
    "i need a one way to kuwait tonight",
    "book me a 5 night stay in doha",

    # TICKET_CHANGE
    "move my flight to next monday",
    "change my seat to aisle please",
    "refund my ticket airline cancelled",
    "i need to adjust departure time",
    "update name spelling on my booking",
    "switch me to an earlier flight",
    "add extra baggage to my ticket",
    "fix the wrong birthdate on ticket",
    "cancel my entire booking",
    "modify my return flight only",

    # FLIGHT_INFO
    "is 25kg baggage allowed",
    "can i take a stroller onboard",
    "is food free on gulf air",
    "how early does boarding start",
    "do they provide pillows",
    "whats the allowed cabin size",
    "are pets allowed in cabin",
    "is there free wi-fi on board",
    "can i carry liquids in hand bag",
    "do infants get free seats",

    # LOOKS LIKE VISA_QUERY - BUT ISN'T (system/app related)
    "do i need permission to enter the server",
    "what documents do i need to sign up",
    "is registration mandatory for this platform",
    "do i need an access pass for this account",
    "how do i apply for system access",
    "do i need authorization to continue",
    "is there a requirement to unlock features",
    "what permit lets me use this tool",

    # LOOKS LIKE BOOKING - BUT ISN'T
    "can you book me a dentist checkup",
    "i want to reserve a library seat",
    "schedule a reminder for me",
    "book me a study room",
    "help me reserve a parking space",
    "find me an appointment slot tomorrow",
    "book a reservation at a cafe",
    "help me book a barber slot",
    "find me a discount for a laptop",
    "reserve gym time for me",

    # LOOKS LIKE FLIGHT INFO - BUT ISN'T
    "how much emotional baggage can i carry",
    "when does my productivity take off",
    "how long is the layover between naps",
    "can i carry my stress onboard",
    "is there wifi in my dreams",
    "can i upgrade my life to premium",
    "is turbulence normal in friendships",
    "whats the weight limit of my problems",
    "do i get snacks during my daily routine",
    "is check-in open for happiness",

    # LOOKS LIKE FOLLOW-UP - BUT ISN'T
    "france is my favourite holiday photo",
    "germany makes good cars",
    "turkey is what i had for dinner",
    "dubai is my favourite perfume scent",
    "china makes great gadgets",
    "italy has amazing food",
    "qatar is my favourite football team",
    "oman is my cat's name",

    # RANDOM TRAVEL-SOUNDING BUT NOT TRAVEL
    "can i board the train of ambition",
    "is my future economy or business",
    "when does my motivation land",
    "do i need a visa to enter success",
    "my confidence took off today",
    "tell me something exciting",
    "whats something fun to do today",
    "why do humans sneeze",
    "explain daydreaming",
    "what music should i listen to",
    "i think im hungry",
    "you respond too quickly lol",
    "i dont know what to ask",
    "tell me something surprising",
    "explain rainbows",
    "my laptop keeps freezing",
    "what snack should i eat",
    "my brain is tired",
    "thanks but that didnt help",

    # ============================
    # VISA_QUERY (more tricky)
    # ============================
    "does a jordanian need visa for cyprus",
    "entry rules for thailand trip",
    "is schengen needed for germany visit",
    "do filipinos need visa for korea",
    "turkey visit what documents needed",
    "is visa required for argentina",
    "how long does dubai visa last",
    "can yemenis travel to qatar visa free",
    "do i need evisa for cambodia",
    "requirements for iraqi visiting malaysia",

    # ============================
    # FOLLOW_UP (more tricky)
    # ============================
    "tunisian",
    "im from sudan",
    "my passport is kuwaiti",
    "im lebanese",
    "oman citizen",
    "im from spain",
    "egyptian passport",
    "im from turkey",
    "saudi",
    "im qatari",

    # ============================
    # BOOKING (more tricky)
    # ============================
    "find me late night flights to india",
    "i need a direct flight to paris tomorrow",
    "book me the cheapest hotel in muscat",
    "search for flights with free baggage",
    "i need a flexible ticket to dubai",
    "get me flights to bahrain this week",
    "reserve a room in central paris",
    "can you book a flight with no stopover",
    "find me seats for 3 adults",
    "i want morning flights only to cairo",

    # ============================
    # TICKET_CHANGE (more tricky)
    # ============================
    "fix the wrong passenger name",
    "i want to postpone my flight",
    "update my booking details",
    "cancel only my return flight",
    "modify the itinerary please",
    "i want a refund airline changed time",
    "switch my seat to front row",
    "i need to add baggage allowance",
    "change my trip to next wednesday",
    "correct my flight schedule",

    # ============================
    # FLIGHT_INFO (more tricky)
    # ============================
    "is 30kg baggage allowed in economy",
    "are power outlets available",
    "is food free on long routes",
    "whats the cabin bag size",
    "is check-in online available",
    "do they offer kids meals",
    "can i bring sweets on board",
    "is water served for free",
    "do all seats have screens",
    "how early can i check in",

]


print("Testing model predictions:\n")
preds = model.predict(test_examples)
for text, pred in zip(test_examples, preds):
    label = LABELS[pred]
    print(f"'{text}'")
    print(f"  → {label}\n")

Testing model predictions:

'i want to create an app for visa applications can you help me'
  → visa_query

'why is dubai so expensive'
  → casual

'are you sure about that'
  → casual

'but i heard i dont need visa'
  → visa_query

'whats your name'
  → casual

'bro what is this'
  → casual

'this system makes no sense'
  → casual

'how does this work'
  → casual

'tell me a story'
  → casual

'i dont know man'
  → casual

'let me think first'
  → casual

'thats crazy'
  → casual

'tell me something interesting'
  → casual

'what are you doing'
  → casual

'youre too slow today'
  → casual

'what's the weather like today'
  → casual

'how do i cook biryani'
  → casual

'who won the football match yesterday'
  → casual

'what is the capital of mars'
  → visa_query

'can you explain quantum physics'
  → casual

'why is the sky blue'
  → casual

'who is the president of the world'
  → casual

'how do I fix my laptop'
  → ticket_change

'what'
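
Reading through a long list of predictions by eye is error-prone, so it can help to compare them against the intended labels programmatically. The sketch below assumes that the comment headings in `test_examples` (for example `# CASUAL`) give the intended label for the lines beneath them. The `results` list is typed in by hand from a few of the predictions printed above; it is not produced by the model here.

```python
# (text, expected label from the section comment, predicted label from the output above)
results = [
    ("why is dubai so expensive", "casual", "casual"),
    ("whats your name", "casual", "casual"),
    ("i want to create an app for visa applications can you help me", "casual", "visa_query"),
    ("but i heard i dont need visa", "casual", "visa_query"),
    ("what is the capital of mars", "casual", "visa_query"),
    ("how do I fix my laptop", "casual", "ticket_change"),
]

# Keep only the rows where the prediction disagrees with the intended label
mismatches = [(t, e, p) for t, e, p in results if e != p]

for text, expected, predicted in mismatches:
    print(f"'{text}': expected {expected}, got {predicted}")
print(f"{len(mismatches)} of {len(results)} listed examples misclassified")
```

Storing the test cases as `(text, expected_label)` pairs instead of a flat list with comments would let the same check run over the full set.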

## Fine-tuning with a pure PyTorch model

`setfit` also provides a pure PyTorch implementation of `SetFitModel`, where the head is a dense layer instead of a classifier from `scikit-learn`. This allows one to do backprop end-to-end and have more fine-grained control over the training process.

To use the PyTorch model, we load a pretrained model with `use_differentiable_head=True` and specify the number of classes to include in the head:

In [None]:
from setfit import SetFitModel

num_classes = len(train_dataset.unique("label"))
model = SetFitModel.from_pretrained(model_id, use_differentiable_head=True, head_params={"out_features": num_classes})

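As before, we instantiate a trainer. This section uses the `SetFitTrainer` class, so we import it first:

In [None]:
from setfit import SetFitTrainer

trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss_class=CosineSimilarityLoss,
    num_iterations=20,
    # No column_mapping is needed: the dataset already has `text` and `label` columns
)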

Next, we freeze the weights of the classification head and apply contrastive learning to the body:

In [None]:
trainer.freeze()
trainer.train(body_learning_rate=1e-5, num_epochs=1)

Note that here we can specify the learning rate for the model's body. We find that small values in the 1e-5 range work well for this step.

Now that the model body is tuned, we can unfreeze the head and train it:

In [None]:
trainer.unfreeze(keep_body_frozen=True)
trainer.train(learning_rate=1e-2, num_epochs=50)

Note that a larger learning rate is used when training the head. We recommend using values in the 1e-2 range. Now that the model is trained, we can evaluate it as usual:

In [None]:
trainer.evaluate()

Nice! This is comparable to the results found with the `scikit-learn` head.