In [3]:
from transformers import GPT2Tokenizer, GPT2ForSequenceClassification
import torch

intents_data = {
    "intents": [
        {
            "intent": "GetSpicyDishesForFever",
            "text": [
                "Can you tell me about the spicy dishes?",
                "Any spicy dishes recommendation?",
                "Spicy dishes?"
            ],
            "responses": [
                "Here are some spicy dishes suitable for fever: [list of spicy dish names]."
            ],
            "context": {
                "in": "",
                "out": "SpicyDishesList",
                "clear": False
            },
            "entityType": "NA",
            "entities": [],
            "query": "spicy_dishes = []\nfor item in menu_data:\n    description = item.get('description', '')\n    if 'spicy' in description.lower():\n        spicy_dishes.append(item)\nreturn spicy_dishes"
        },
        {
            "intent": "GetKidsFriendlyDishes",
            "text": [
                "Which dishes are kids-friendly?",
                "Which dishes are kids friendly?",
                "Kids friendly dishes recommendation?"
            ],
            "responses": [
                "Here are some kids-friendly dishes: [list of kids-friendly dish names]."
            ],
            "context": {
                "in": "",
                "out": "KidsFriendlyDishesList",
                "clear": False
            },
            "entityType": "NA",
            "entities": [],
            "query": "kids_friendly_dishes = [item for item in menu_data if item.get('kidsFriendly') == True]\nreturn kids_friendly_dishes"
        },
        {
            "intent": "GetVeganDishes",
            "text": [
                "What are the vegan options?",
                "Vegan options recommendation?"
            ],
            "responses": [
                "Here are some vegan dishes: [list of vegan dish names]."
            ],
            "context": {
                "in": "",
                "out": "VeganDishesList",
                "clear": False
            },
            "entityType": "NA",
            "entities": [],
            "query": "vegan_dishes = []\nfor item in menu_data:\n    item_filters = item.get('itemFilter', [])\n    is_vegan = any(\n        filter_item.get('name', '').lower() == 'vegan'\n        for filter_item in item_filters\n    )\n    if is_vegan:\n        vegan_dishes.append(item)\nreturn vegan_dishes"
        },
        {
            "intent": "GetNutFreeDishes",
            "text": [
                "List the nut-free dishes."
            ],
            "responses": [
                "Here are some nut-free dishes: [list of nut-free dish names]."
            ],
            "context": {
                "in": "",
                "out": "NutFreeDishesList",
                "clear": False
            },
            "entityType": "NA",
            "entities": [],
            "query": "nut_free_dishes = []\nfor item in menu_data:\n    item_info = f\"{item.get('description', '')} {item.get('allergicInfo', '')} {item.get('itemName', '')}\".lower()\n    if 'nuts' not in item_info:\n        nut_free_dishes.append(item)\nreturn nut_free_dishes"
        },
        {
            "intent": "GetFishFreeDishes",
            "text": [
                "List the fish-free dishes."
            ],
            "responses": [
                "Here are some fish-free dishes: [list of fish-free dish names]."
            ],
            "context": {
                "in": "",
                "out": "FishFreeDishesList",
                "clear": False
            },
            "entityType": "NA",
            "entities": [],
            "query": "fish_free_dishes = []\nfor item in menu_data:\n    item_info = f\"{item.get('description', '')} {item.get('allergicInfo', '')} {item.get('itemName', '')}\".lower()\n    if 'fish' not in item_info:\n        fish_free_dishes.append(item)\nreturn fish_free_dishes"
        },
        {
            "intent": "FindDishWithLeastPrepTime",
            "text": [
                "Find the dish with the least prep time.",
                "Find the dish with the least prep time for appetizers.",
                "Find the dish with the least prep time for A2B combos."
            ],
            "responses": [
                "The dish with the least prep time is [dish name] with a prep time of [prep time] minutes."
            ],
            "context": {
                "in": "",
                "out": "LeastPrepTimeDish",
                "clear": False
            },
            "entityType": "NA",
            "entities": [],
            "query": "filtered_items = [item for item in menu_data if item.get('subCategory', '').lower() == subcategory.lower()]\nif filtered_items:\n    min_item = min(filtered_items, key=lambda x: int(x['prepTimeInMins']))\n    return min_item\nreturn None"
        },
        {
            "intent": "RetrieveDishDescription",
            "text": [
                "What is the description of the [dish name]?",
                "What is the description of [dish name]?"
            ],
            "responses": [
                "[dish name] is described as: [description]."
            ],
            "context": {
                "in": "",
                "out": "DishDescription",
                "clear": False
            },
            "entityType": "DishName",
            "entities": [
                {
                    "entity": "DishName",
                    "rangeFrom": 5,
                    "rangeTo": 6
                }
            ],
            "query": "for item in menu_data:\n    if item['itemName'].lower() in query.lower():\n        return item\nreturn None"
        },
        {
            "intent": "RetrieveDishAllergicInfo",
            "text": [
                "What is the allergic info for [dish name]?"
            ],
            "responses": [
                "[dish name] is described as: [allergicInfo]."
            ],
            "context": {
                "in": "",
                "out": "DishAllergicInfo",
                "clear": False
            },
            "entityType": "DishName",
            "entities": [
                {
                    "entity": "DishName",
                    "rangeFrom": 5,
                    "rangeTo": 6
                }
            ],
            "query": "for item in menu_data:\n    if item['itemName'].lower() in query.lower():\n        return item\nreturn None"
        },
        {
            "intent": "RetrieveDishPrice",
            "text": [
                "How much does the [dish name] cost?"
            ],
            "responses": [
                "The price of [dish name] is $[price]."
            ],
            "context": {
                "in": "",
                "out": "DishPrice",
                "clear": False
            },
            "entityType": "DishName",
            "entities": [
                {
                    "entity": "DishName",
                    "rangeFrom": 5,
                    "rangeTo": 6
                }
            ],
            "query": "for item in menu_data:\n    if item['itemName'].lower() in query.lower():\n        return item\nreturn None"
        }
    ]
}


# Load pre-trained model with classification head
# Make sure num_labels corresponds to the number of intents, which is the length of intents_data['intents']
model = GPT2ForSequenceClassification.from_pretrained('distilgpt2', num_labels=len(intents_data['intents']))
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')

# Function to recognize intent
def recognize_intent(user_input, intents_data):
    # Tokenize the input
    inputs = tokenizer(user_input, return_tensors="pt")
    # Get model outputs
    outputs = model(**inputs)
    logits = outputs.logits
    # Find the predicted class (intent)
    predicted_class_id = torch.argmax(logits, dim=-1).item()
    # Access the intent from intents_data['intents'] list
    return intents_data['intents'][predicted_class_id]

# Adjusted intents structure

# Example usage
user_input = "What is the price of spicy dish ?"
recognized_intent = recognize_intent(user_input, intents_data)
print(f"Recognized Intent: {recognized_intent}")


Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at distilgpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Recognized Intent: {'intent': 'RetrieveDishPrice', 'text': ['How much does the [dish name] cost?'], 'responses': ['The price of [dish name] is $[price].'], 'context': {'in': '', 'out': 'DishPrice', 'clear': False}, 'entityType': 'DishName', 'entities': [{'entity': 'DishName', 'rangeFrom': 5, 'rangeTo': 6}], 'query': "for item in menu_data:\n    if item['itemName'].lower() in query.lower():\n        return item\nreturn None"}


In [4]:
# Example usage
user_input2 = "What is the price of keerai vadai ?"
recognized_intent2 = recognize_intent(user_input2, intents_data)
print(f"Recognized Intent: {recognized_intent2}")


Recognized Intent: {'intent': 'RetrieveDishPrice', 'text': ['How much does the [dish name] cost?'], 'responses': ['The price of [dish name] is $[price].'], 'context': {'in': '', 'out': 'DishPrice', 'clear': False}, 'entityType': 'DishName', 'entities': [{'entity': 'DishName', 'rangeFrom': 5, 'rangeTo': 6}], 'query': "for item in menu_data:\n    if item['itemName'].lower() in query.lower():\n        return item\nreturn None"}


In [11]:
# Example usage
user_input3 = "get Nut FreeDishes"
recognized_intent3 = recognize_intent(user_input3, intents_data)
print(f"Recognized Intent: {recognized_intent3}")


Recognized Intent: {'intent': 'GetFishFreeDishes', 'text': ['List the fish-free dishes.'], 'responses': ['Here are some fish-free dishes: [list of fish-free dish names].'], 'context': {'in': '', 'out': 'FishFreeDishesList', 'clear': False}, 'entityType': 'NA', 'entities': [], 'query': 'fish_free_dishes = []\nfor item in menu_data:\n    item_info = f"{item.get(\'description\', \'\')} {item.get(\'allergicInfo\', \'\')} {item.get(\'itemName\', \'\')}".lower()\n    if \'fish\' not in item_info:\n        fish_free_dishes.append(item)\nreturn fish_free_dishes'}


# Training intent on distilGPT

In [3]:
from transformers import GPT2Tokenizer, GPT2ForSequenceClassification, Trainer, TrainingArguments
import torch
from torch.utils.data import Dataset

# Load tokenizer and pre-trained GPT2 model for sequence classification
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')

# Add a pad token to the tokenizer (GPT-2 doesn't have one by default)
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# Load the model and resize embeddings to match the tokenizer
model = GPT2ForSequenceClassification.from_pretrained('distilgpt2', num_labels=8)  # Adjust num_labels as needed
model.resize_token_embeddings(len(tokenizer))  # Resize the model's token embeddings to accommodate the added pad_token

# Set pad_token_id to avoid issues when padding
model.config.pad_token_id = tokenizer.pad_token_id

# Define the Dataset class for training data
class IntentDataset(Dataset):
    def __init__(self, tokenizer, intents_data, max_len=128):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.intents_data = intents_data['intents']
        self.texts = []
        self.labels = []
        
        # Prepare the texts and labels
        for idx, intent in enumerate(self.intents_data):
            for example_text in intent['text']:
                self.texts.append(example_text)
                self.labels.append(idx)  # Label is the index of the intent
                
    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        
        # Tokenize the text
        inputs = self.tokenizer(text, padding="max_length", truncation=True, max_length=self.max_len, return_tensors="pt")
        input_ids = inputs["input_ids"].squeeze()  # Squeeze out the extra dimension
        attention_mask = inputs["attention_mask"].squeeze()  # Squeeze out the extra dimension
        
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": torch.tensor(label, dtype=torch.long)
        }

# Example intents_data (use your full intents_data)
intents_data = {
    "intents": [
        {
            "intent": "GetSpicyDishesForFever",
            "text": [
                "Can you tell me about the spicy dishes?",
                "Any spicy dishes recommendation?",
                "Spicy dishes?"
            ],
            "responses": [
                "Here are some spicy dishes suitable for fever: [list of spicy dish names]."
            ],
            "context": {
                "in": "",
                "out": "SpicyDishesList",
                "clear": False
            },
            "entityType": "NA",
            "entities": [],
            "query": "spicy_dishes = []\nfor item in menu_data:\n    description = item.get('description', '')\n    if 'spicy' in description.lower():\n        spicy_dishes.append(item)\nreturn spicy_dishes"
        },
        {
            "intent": "GetKidsFriendlyDishes",
            "text": [
                "Which dishes are kids-friendly?",
                "Which dishes are kids friendly?",
                "Kids friendly dishes recommendation?"
            ],
            "responses": [
                "Here are some kids-friendly dishes: [list of kids-friendly dish names]."
            ],
            "context": {
                "in": "",
                "out": "KidsFriendlyDishesList",
                "clear": False
            },
            "entityType": "NA",
            "entities": [],
            "query": "kids_friendly_dishes = [item for item in menu_data if item.get('kidsFriendly') == True]\nreturn kids_friendly_dishes"
        },
        {
            "intent": "GetVeganDishes",
            "text": [
                "What are the vegan options?",
                "Vegan options recommendation?"
            ],
            "responses": [
                "Here are some vegan dishes: [list of vegan dish names]."
            ],
            "context": {
                "in": "",
                "out": "VeganDishesList",
                "clear": False
            },
            "entityType": "NA",
            "entities": [],
            "query": "vegan_dishes = []\nfor item in menu_data:\n    item_filters = item.get('itemFilter', [])\n    is_vegan = any(\n        filter_item.get('name', '').lower() == 'vegan'\n        for filter_item in item_filters\n    )\n    if is_vegan:\n        vegan_dishes.append(item)\nreturn vegan_dishes"
        },
        {
            "intent": "GetNutFreeDishes",
            "text": [
                "List the nut-free dishes."
            ],
            "responses": [
                "Here are some nut-free dishes: [list of nut-free dish names]."
            ],
            "context": {
                "in": "",
                "out": "NutFreeDishesList",
                "clear": False
            },
            "entityType": "NA",
            "entities": [],
            "query": "nut_free_dishes = []\nfor item in menu_data:\n    item_info = f\"{item.get('description', '')} {item.get('allergicInfo', '')} {item.get('itemName', '')}\".lower()\n    if 'nuts' not in item_info:\n        nut_free_dishes.append(item)\nreturn nut_free_dishes"
        },
        {
            "intent": "GetFishFreeDishes",
            "text": [
                "List the fish-free dishes."
            ],
            "responses": [
                "Here are some fish-free dishes: [list of fish-free dish names]."
            ],
            "context": {
                "in": "",
                "out": "FishFreeDishesList",
                "clear": False
            },
            "entityType": "NA",
            "entities": [],
            "query": "fish_free_dishes = []\nfor item in menu_data:\n    item_info = f\"{item.get('description', '')} {item.get('allergicInfo', '')} {item.get('itemName', '')}\".lower()\n    if 'fish' not in item_info:\n        fish_free_dishes.append(item)\nreturn fish_free_dishes"
        },
        {
            "intent": "FindDishWithLeastPrepTime",
            "text": [
                "Find the dish with the least prep time.",
                "Find the dish with the least prep time for appetizers.",
                "Find the dish with the least prep time for A2B combos."
            ],
            "responses": [
                "The dish with the least prep time is [dish name] with a prep time of [prep time] minutes."
            ],
            "context": {
                "in": "",
                "out": "LeastPrepTimeDish",
                "clear": False
            },
            "entityType": "NA",
            "entities": [],
            "query": "filtered_items = [item for item in menu_data if item.get('subCategory', '').lower() == subcategory.lower()]\nif filtered_items:\n    min_item = min(filtered_items, key=lambda x: int(x['prepTimeInMins']))\n    return min_item\nreturn None"
        },
        {
            "intent": "RetrieveDishDescription",
            "text": [
                "What is the description of the [dish name]?",
                "What is the description of [dish name]?"
            ],
            "responses": [
                "[dish name] is described as: [description]."
            ],
            "context": {
                "in": "",
                "out": "DishDescription",
                "clear": False
            },
            "entityType": "DishName",
            "entities": [
                {
                    "entity": "DishName",
                    "rangeFrom": 5,
                    "rangeTo": 6
                }
            ],
            "query": "for item in menu_data:\n    if item['itemName'].lower() in query.lower():\n        return item\nreturn None"
        },
        {
            "intent": "RetrieveDishAllergicInfo",
            "text": [
                "What is the allergic info for [dish name]?"
            ],
            "responses": [
                "[dish name] is described as: [allergicInfo]."
            ],
            "context": {
                "in": "",
                "out": "DishAllergicInfo",
                "clear": False
            },
            "entityType": "DishName",
            "entities": [
                {
                    "entity": "DishName",
                    "rangeFrom": 5,
                    "rangeTo": 6
                }
            ],
            "query": "for item in menu_data:\n    if item['itemName'].lower() in query.lower():\n        return item\nreturn None"
        },
        {
            "intent": "RetrieveDishPrice",
            "text": [
                "How much does the [dish name] cost?"
            ],
            "responses": [
                "The price of [dish name] is $[price]."
            ],
            "context": {
                "in": "",
                "out": "DishPrice",
                "clear": False
            },
            "entityType": "DishName",
            "entities": [
                {
                    "entity": "DishName",
                    "rangeFrom": 5,
                    "rangeTo": 6
                }
            ],
            "query": "for item in menu_data:\n    if item['itemName'].lower() in query.lower():\n        return item\nreturn None"
        }
    ]
}
# Create dataset object
train_dataset = IntentDataset(tokenizer, intents_data)

# Training arguments - these control how training is performed
training_args = TrainingArguments(
    output_dir='./results',          # Output directory to save checkpoints and model
    num_train_epochs=3,              # Number of training epochs
    per_device_train_batch_size=8,   # Batch size per device during training
    warmup_steps=500,                # Number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # Strength of weight decay
    logging_dir='./logs',            # Directory to store logs
    logging_steps=10,
    save_steps=100,
    evaluation_strategy="no",        # We don't need evaluation right now
)

# Initialize the Trainer for the fine-tuning process
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
)

# Start the fine-tuning process
trainer.train()

# Save the fine-tuned model
trainer.save_model('./fine_tuned_distilgpt2')
tokenizer.save_pretrained('./fine_tuned_distilgpt2')


Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at distilgpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  0%|          | 0/9 [05:52<?, ?it/s]
                                             
100%|██████████| 9/9 [00:10<00:00,  1.17s/it]


{'train_runtime': 10.4386, 'train_samples_per_second': 4.886, 'train_steps_per_second': 0.862, 'train_loss': 2.659591462877062, 'epoch': 3.0}


('./fine_tuned_distilgpt2/tokenizer_config.json',
 './fine_tuned_distilgpt2/special_tokens_map.json',
 './fine_tuned_distilgpt2/vocab.json',
 './fine_tuned_distilgpt2/merges.txt',
 './fine_tuned_distilgpt2/added_tokens.json')

# Testing

In [1]:
from transformers import GPT2Tokenizer, GPT2ForSequenceClassification
import torch

# Load the fine-tuned model and tokenizer
model = GPT2ForSequenceClassification.from_pretrained('./fine_tuned_distilgpt2')
tokenizer = GPT2Tokenizer.from_pretrained('./fine_tuned_distilgpt2')

# Define the intents data (make sure this matches what you trained on)
intents_data = {
    "intents": [
        {
            "intent": "GetSpicyDishesForFever",
            "text": [
                "Can you tell me about the spicy dishes?",
                "Any spicy dishes recommendation?",
                "Spicy dishes?"
            ],
            "responses": [
                "Here are some spicy dishes suitable for fever: [list of spicy dish names]."
            ],
            "context": {
                "in": "",
                "out": "SpicyDishesList",
                "clear": False
            },
            "entityType": "NA",
            "entities": [],
            "query": "spicy_dishes = []\nfor item in menu_data:\n    description = item.get('description', '')\n    if 'spicy' in description.lower():\n        spicy_dishes.append(item)\nreturn spicy_dishes"
        },
        {
            "intent": "GetKidsFriendlyDishes",
            "text": [
                "Which dishes are kids-friendly?",
                "Which dishes are kids friendly?",
                "Kids friendly dishes recommendation?"
            ],
            "responses": [
                "Here are some kids-friendly dishes: [list of kids-friendly dish names]."
            ],
            "context": {
                "in": "",
                "out": "KidsFriendlyDishesList",
                "clear": False
            },
            "entityType": "NA",
            "entities": [],
            "query": "kids_friendly_dishes = [item for item in menu_data if item.get('kidsFriendly') == True]\nreturn kids_friendly_dishes"
        },
        {
            "intent": "GetVeganDishes",
            "text": [
                "What are the vegan options?",
                "Vegan options recommendation?"
            ],
            "responses": [
                "Here are some vegan dishes: [list of vegan dish names]."
            ],
            "context": {
                "in": "",
                "out": "VeganDishesList",
                "clear": False
            },
            "entityType": "NA",
            "entities": [],
            "query": "vegan_dishes = []\nfor item in menu_data:\n    item_filters = item.get('itemFilter', [])\n    is_vegan = any(\n        filter_item.get('name', '').lower() == 'vegan'\n        for filter_item in item_filters\n    )\n    if is_vegan:\n        vegan_dishes.append(item)\nreturn vegan_dishes"
        },
        {
            "intent": "GetNutFreeDishes",
            "text": [
                "List the nut-free dishes."
            ],
            "responses": [
                "Here are some nut-free dishes: [list of nut-free dish names]."
            ],
            "context": {
                "in": "",
                "out": "NutFreeDishesList",
                "clear": False
            },
            "entityType": "NA",
            "entities": [],
            "query": "nut_free_dishes = []\nfor item in menu_data:\n    item_info = f\"{item.get('description', '')} {item.get('allergicInfo', '')} {item.get('itemName', '')}\".lower()\n    if 'nuts' not in item_info:\n        nut_free_dishes.append(item)\nreturn nut_free_dishes"
        },
        {
            "intent": "GetFishFreeDishes",
            "text": [
                "List the fish-free dishes."
            ],
            "responses": [
                "Here are some fish-free dishes: [list of fish-free dish names]."
            ],
            "context": {
                "in": "",
                "out": "FishFreeDishesList",
                "clear": False
            },
            "entityType": "NA",
            "entities": [],
            "query": "fish_free_dishes = []\nfor item in menu_data:\n    item_info = f\"{item.get('description', '')} {item.get('allergicInfo', '')} {item.get('itemName', '')}\".lower()\n    if 'fish' not in item_info:\n        fish_free_dishes.append(item)\nreturn fish_free_dishes"
        },
        {
            "intent": "FindDishWithLeastPrepTime",
            "text": [
                "Find the dish with the least prep time.",
                "Find the dish with the least prep time for appetizers.",
                "Find the dish with the least prep time for A2B combos."
            ],
            "responses": [
                "The dish with the least prep time is [dish name] with a prep time of [prep time] minutes."
            ],
            "context": {
                "in": "",
                "out": "LeastPrepTimeDish",
                "clear": False
            },
            "entityType": "NA",
            "entities": [],
            "query": "filtered_items = [item for item in menu_data if item.get('subCategory', '').lower() == subcategory.lower()]\nif filtered_items:\n    min_item = min(filtered_items, key=lambda x: int(x['prepTimeInMins']))\n    return min_item\nreturn None"
        },
        {
            "intent": "RetrieveDishDescription",
            "text": [
                "What is the description of the [dish name]?",
                "What is the description of [dish name]?"
            ],
            "responses": [
                "[dish name] is described as: [description]."
            ],
            "context": {
                "in": "",
                "out": "DishDescription",
                "clear": False
            },
            "entityType": "DishName",
            "entities": [
                {
                    "entity": "DishName",
                    "rangeFrom": 5,
                    "rangeTo": 6
                }
            ],
            "query": "for item in menu_data:\n    if item['itemName'].lower() in query.lower():\n        return item\nreturn None"
        },
        {
            "intent": "RetrieveDishAllergicInfo",
            "text": [
                "What is the allergic info for [dish name]?"
            ],
            "responses": [
                "[dish name] is described as: [allergicInfo]."
            ],
            "context": {
                "in": "",
                "out": "DishAllergicInfo",
                "clear": False
            },
            "entityType": "DishName",
            "entities": [
                {
                    "entity": "DishName",
                    "rangeFrom": 5,
                    "rangeTo": 6
                }
            ],
            "query": "for item in menu_data:\n    if item['itemName'].lower() in query.lower():\n        return item\nreturn None"
        },
        {
            "intent": "RetrieveDishPrice",
            "text": [
                "How much does the [dish name] cost?"
            ],
            "responses": [
                "The price of [dish name] is $[price]."
            ],
            "context": {
                "in": "",
                "out": "DishPrice",
                "clear": False
            },
            "entityType": "DishName",
            "entities": [
                {
                    "entity": "DishName",
                    "rangeFrom": 5,
                    "rangeTo": 6
                }
            ],
            "query": "for item in menu_data:\n    if item['itemName'].lower() in query.lower():\n        return item\nreturn None"
        }
    ]
}

# List of intent labels
intent_labels = [intent['intent'] for intent in intents_data['intents']]

# Function to recognize the intent
def recognize_intent(user_input):
    # Tokenize the input
    inputs = tokenizer(user_input, return_tensors="pt", padding=True, truncation=True, max_length=128)

    # Get model predictions
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    # Get the predicted class (the index of the maximum logit)
    predicted_class_id = torch.argmax(logits, dim=-1).item()

    # Map the predicted class to the intent label
    predicted_intent = intent_labels[predicted_class_id]
    
    return predicted_intent



  from .autonotebook import tqdm as notebook_tqdm


Recognized Intent: GetSpicyDishesForFever


# Question 1

In [10]:
# Example usage
user_input = "Can you tell me about the spicy dishes?"
predicted_intent = recognize_intent(user_input)
print(f"Recognized Intent: {predicted_intent}")

Recognized Intent: GetSpicyDishesForFever


# Question 2

In [7]:
# Example usage
user_input2 = "Kids friendly dishes recommendation?"
predicted_intent2 = recognize_intent(user_input2)
print(f"Recognized Intent: {predicted_intent2}")


Recognized Intent: GetSpicyDishesForFever


# Question 3

In [9]:
# Example usage
user_input3 = "Find the dish with the least prep time."
predicted_intent3 = recognize_intent(user_input3)
print(f"Recognized Intent: {predicted_intent3}")

Recognized Intent: GetSpicyDishesForFever
