In [3]:
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch

# Load pre-trained T5 model and tokenizer
model_name = 't5-small'  # You can use 't5-base' or 't5-large' for larger models
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

# Function to recognize intent using T5
def recognize_intent(user_input, intents_data):
    # Add a prefix for the task to help the model understand the task
    input_text = f"intent recognition: {user_input}"
    
    # Tokenize the input
    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=128)
    
    # Generate output (predicted intent)
    with torch.no_grad():
        outputs = model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=10)
        predicted_intent = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    return predicted_intent

# Adjusted intents structure
intents_data = {
    "intents": [
        {
            "intent": "GetSpicyDishesForFever",
            "text": [
                "Can you tell me about the spicy dishes?",
                "Any spicy dishes recommendation?",
                "Spicy dishes?"
            ],
            "responses": [
                "Here are some spicy dishes suitable for fever: [list of spicy dish names]."
            ],
            "context": {
                "in": "",
                "out": "SpicyDishesList",
                "clear": False
            },
            "entityType": "NA",
            "entities": [],
            "query": "spicy_dishes = []\nfor item in menu_data:\n    description = item.get('description', '')\n    if 'spicy' in description.lower():\n        spicy_dishes.append(item)\nreturn spicy_dishes"
        },
        {
            "intent": "GetKidsFriendlyDishes",
            "text": [
                "Which dishes are kids-friendly?",
                "Which dishes are kids friendly?",
                "Kids friendly dishes recommendation?"
            ],
            "responses": [
                "Here are some kids-friendly dishes: [list of kids-friendly dish names]."
            ],
            "context": {
                "in": "",
                "out": "KidsFriendlyDishesList",
                "clear": False
            },
            "entityType": "NA",
            "entities": [],
            "query": "kids_friendly_dishes = [item for item in menu_data if item.get('kidsFriendly') == True]\nreturn kids_friendly_dishes"
        },
        {
            "intent": "GetVeganDishes",
            "text": [
                "What are the vegan options?",
                "Vegan options recommendation?"
            ],
            "responses": [
                "Here are some vegan dishes: [list of vegan dish names]."
            ],
            "context": {
                "in": "",
                "out": "VeganDishesList",
                "clear": False
            },
            "entityType": "NA",
            "entities": [],
            "query": "vegan_dishes = []\nfor item in menu_data:\n    item_filters = item.get('itemFilter', [])\n    is_vegan = any(\n        filter_item.get('name', '').lower() == 'vegan'\n        for filter_item in item_filters\n    )\n    if is_vegan:\n        vegan_dishes.append(item)\nreturn vegan_dishes"
        },
        {
            "intent": "GetNutFreeDishes",
            "text": [
                "List the nut-free dishes."
            ],
            "responses": [
                "Here are some nut-free dishes: [list of nut-free dish names]."
            ],
            "context": {
                "in": "",
                "out": "NutFreeDishesList",
                "clear": False
            },
            "entityType": "NA",
            "entities": [],
            "query": "nut_free_dishes = []\nfor item in menu_data:\n    item_info = f\"{item.get('description', '')} {item.get('allergicInfo', '')} {item.get('itemName', '')}\".lower()\n    if 'nuts' not in item_info:\n        nut_free_dishes.append(item)\nreturn nut_free_dishes"
        },
        {
            "intent": "GetFishFreeDishes",
            "text": [
                "List the fish-free dishes."
            ],
            "responses": [
                "Here are some fish-free dishes: [list of fish-free dish names]."
            ],
            "context": {
                "in": "",
                "out": "FishFreeDishesList",
                "clear": False
            },
            "entityType": "NA",
            "entities": [],
            "query": "fish_free_dishes = []\nfor item in menu_data:\n    item_info = f\"{item.get('description', '')} {item.get('allergicInfo', '')} {item.get('itemName', '')}\".lower()\n    if 'fish' not in item_info:\n        fish_free_dishes.append(item)\nreturn fish_free_dishes"
        },
        {
            "intent": "FindDishWithLeastPrepTime",
            "text": [
                "Find the dish with the least prep time.",
                "Find the dish with the least prep time for appetizers.",
                "Find the dish with the least prep time for A2B combos."
            ],
            "responses": [
                "The dish with the least prep time is [dish name] with a prep time of [prep time] minutes."
            ],
            "context": {
                "in": "",
                "out": "LeastPrepTimeDish",
                "clear": False
            },
            "entityType": "NA",
            "entities": [],
            "query": "filtered_items = [item for item in menu_data if item.get('subCategory', '').lower() == subcategory.lower()]\nif filtered_items:\n    min_item = min(filtered_items, key=lambda x: int(x['prepTimeInMins']))\n    return min_item\nreturn None"
        },
        {
            "intent": "RetrieveDishDescription",
            "text": [
                "What is the description of the [dish name]?",
                "What is the description of [dish name]?"
            ],
            "responses": [
                "[dish name] is described as: [description]."
            ],
            "context": {
                "in": "",
                "out": "DishDescription",
                "clear": False
            },
            "entityType": "DishName",
            "entities": [
                {
                    "entity": "DishName",
                    "rangeFrom": 5,
                    "rangeTo": 6
                }
            ],
            "query": "for item in menu_data:\n    if item['itemName'].lower() in query.lower():\n        return item\nreturn None"
        },
        {
            "intent": "RetrieveDishAllergicInfo",
            "text": [
                "What is the allergic info for [dish name]?"
            ],
            "responses": [
                "[dish name] is described as: [allergicInfo]."
            ],
            "context": {
                "in": "",
                "out": "DishAllergicInfo",
                "clear": False
            },
            "entityType": "DishName",
            "entities": [
                {
                    "entity": "DishName",
                    "rangeFrom": 5,
                    "rangeTo": 6
                }
            ],
            "query": "for item in menu_data:\n    if item['itemName'].lower() in query.lower():\n        return item\nreturn None"
        },
        {
            "intent": "RetrieveDishPrice",
            "text": [
                "How much does the [dish name] cost?"
            ],
            "responses": [
                "The price of [dish name] is $[price]."
            ],
            "context": {
                "in": "",
                "out": "DishPrice",
                "clear": False
            },
            "entityType": "DishName",
            "entities": [
                {
                    "entity": "DishName",
                    "rangeFrom": 5,
                    "rangeTo": 6
                }
            ],
            "query": "for item in menu_data:\n    if item['itemName'].lower() in query.lower():\n        return item\nreturn None"
        }
    ]
}

# Example usage
user_input = "Can you tell me about the spicy dishes?"
recognized_intent = recognize_intent(user_input, intents_data)
print(f"Recognized Intent: {recognized_intent}")


Recognized Intent: Können Sie mir sagen über die spicy dishes


In [5]:
# Example usage
user_input = "How much does keeari vadai cost?"
recognized_intent = recognize_intent(user_input, intents_data)
print(f"Recognized Intent: {recognized_intent}")

Recognized Intent: keeari vadai cost cost


# Training Intent on T5

In [6]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments

# Load pre-trained T5 tokenizer and model
model_name = 't5-small'  # You can also use 't5-base' or 't5-large'
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

# Custom dataset class
class IntentDataset(Dataset):
    def __init__(self, intents_data, tokenizer, max_length=128):
        self.intents_data = intents_data['intents']
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.inputs = []
        self.targets = []
        self._build()

    def _build(self):
        # Create input-output pairs
        for intent in self.intents_data:
            for query in intent['text']:
                input_text = f"intent recognition: {query}"
                target_text = intent['intent']
                self.inputs.append(input_text)
                self.targets.append(target_text)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        input_text = self.inputs[idx]
        target_text = self.targets[idx]
        input_enc = self.tokenizer(input_text, padding='max_length', truncation=True, max_length=self.max_length, return_tensors="pt")
        target_enc = self.tokenizer(target_text, padding='max_length', truncation=True, max_length=10, return_tensors="pt")  # Short max_length for intent

        input_ids = input_enc.input_ids.squeeze()
        attention_mask = input_enc.attention_mask.squeeze()
        labels = target_enc.input_ids.squeeze()

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }

# Prepare intents data (your JSON structure)
intents_data = {
    "intents": [
        {
            "intent": "GetSpicyDishesForFever",
            "text": [
                "Can you tell me about the spicy dishes?",
                "Any spicy dishes recommendation?",
                "Spicy dishes?"
            ],
            "responses": [
                "Here are some spicy dishes suitable for fever: [list of spicy dish names]."
            ],
            "context": {
                "in": "",
                "out": "SpicyDishesList",
                "clear": False
            },
            "entityType": "NA",
            "entities": [],
            "query": "spicy_dishes = []\nfor item in menu_data:\n    description = item.get('description', '')\n    if 'spicy' in description.lower():\n        spicy_dishes.append(item)\nreturn spicy_dishes"
        },
        {
            "intent": "GetKidsFriendlyDishes",
            "text": [
                "Which dishes are kids-friendly?",
                "Which dishes are kids friendly?",
                "Kids friendly dishes recommendation?"
            ],
            "responses": [
                "Here are some kids-friendly dishes: [list of kids-friendly dish names]."
            ],
            "context": {
                "in": "",
                "out": "KidsFriendlyDishesList",
                "clear": False
            },
            "entityType": "NA",
            "entities": [],
            "query": "kids_friendly_dishes = [item for item in menu_data if item.get('kidsFriendly') == True]\nreturn kids_friendly_dishes"
        },
        {
            "intent": "GetVeganDishes",
            "text": [
                "What are the vegan options?",
                "Vegan options recommendation?"
            ],
            "responses": [
                "Here are some vegan dishes: [list of vegan dish names]."
            ],
            "context": {
                "in": "",
                "out": "VeganDishesList",
                "clear": False
            },
            "entityType": "NA",
            "entities": [],
            "query": "vegan_dishes = []\nfor item in menu_data:\n    item_filters = item.get('itemFilter', [])\n    is_vegan = any(\n        filter_item.get('name', '').lower() == 'vegan'\n        for filter_item in item_filters\n    )\n    if is_vegan:\n        vegan_dishes.append(item)\nreturn vegan_dishes"
        },
        {
            "intent": "GetNutFreeDishes",
            "text": [
                "List the nut-free dishes."
            ],
            "responses": [
                "Here are some nut-free dishes: [list of nut-free dish names]."
            ],
            "context": {
                "in": "",
                "out": "NutFreeDishesList",
                "clear": False
            },
            "entityType": "NA",
            "entities": [],
            "query": "nut_free_dishes = []\nfor item in menu_data:\n    item_info = f\"{item.get('description', '')} {item.get('allergicInfo', '')} {item.get('itemName', '')}\".lower()\n    if 'nuts' not in item_info:\n        nut_free_dishes.append(item)\nreturn nut_free_dishes"
        },
        {
            "intent": "GetFishFreeDishes",
            "text": [
                "List the fish-free dishes."
            ],
            "responses": [
                "Here are some fish-free dishes: [list of fish-free dish names]."
            ],
            "context": {
                "in": "",
                "out": "FishFreeDishesList",
                "clear": False
            },
            "entityType": "NA",
            "entities": [],
            "query": "fish_free_dishes = []\nfor item in menu_data:\n    item_info = f\"{item.get('description', '')} {item.get('allergicInfo', '')} {item.get('itemName', '')}\".lower()\n    if 'fish' not in item_info:\n        fish_free_dishes.append(item)\nreturn fish_free_dishes"
        },
        {
            "intent": "FindDishWithLeastPrepTime",
            "text": [
                "Find the dish with the least prep time.",
                "Find the dish with the least prep time for appetizers.",
                "Find the dish with the least prep time for A2B combos."
            ],
            "responses": [
                "The dish with the least prep time is [dish name] with a prep time of [prep time] minutes."
            ],
            "context": {
                "in": "",
                "out": "LeastPrepTimeDish",
                "clear": False
            },
            "entityType": "NA",
            "entities": [],
            "query": "filtered_items = [item for item in menu_data if item.get('subCategory', '').lower() == subcategory.lower()]\nif filtered_items:\n    min_item = min(filtered_items, key=lambda x: int(x['prepTimeInMins']))\n    return min_item\nreturn None"
        },
        {
            "intent": "RetrieveDishDescription",
            "text": [
                "What is the description of the [dish name]?",
                "What is the description of [dish name]?"
            ],
            "responses": [
                "[dish name] is described as: [description]."
            ],
            "context": {
                "in": "",
                "out": "DishDescription",
                "clear": False
            },
            "entityType": "DishName",
            "entities": [
                {
                    "entity": "DishName",
                    "rangeFrom": 5,
                    "rangeTo": 6
                }
            ],
            "query": "for item in menu_data:\n    if item['itemName'].lower() in query.lower():\n        return item\nreturn None"
        },
        {
            "intent": "RetrieveDishAllergicInfo",
            "text": [
                "What is the allergic info for [dish name]?"
            ],
            "responses": [
                "[dish name] is described as: [allergicInfo]."
            ],
            "context": {
                "in": "",
                "out": "DishAllergicInfo",
                "clear": False
            },
            "entityType": "DishName",
            "entities": [
                {
                    "entity": "DishName",
                    "rangeFrom": 5,
                    "rangeTo": 6
                }
            ],
            "query": "for item in menu_data:\n    if item['itemName'].lower() in query.lower():\n        return item\nreturn None"
        },
        {
            "intent": "RetrieveDishPrice",
            "text": [
                "How much does the [dish name] cost?"
            ],
            "responses": [
                "The price of [dish name] is $[price]."
            ],
            "context": {
                "in": "",
                "out": "DishPrice",
                "clear": False
            },
            "entityType": "DishName",
            "entities": [
                {
                    "entity": "DishName",
                    "rangeFrom": 5,
                    "rangeTo": 6
                }
            ],
            "query": "for item in menu_data:\n    if item['itemName'].lower() in query.lower():\n        return item\nreturn None"
        }
    ]
}

# Create dataset and data loader
dataset = IntentDataset(intents_data, tokenizer)
train_loader = DataLoader(dataset, batch_size=4, shuffle=True)

# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',          
    per_device_train_batch_size=4,  
    num_train_epochs=3,              
    logging_dir='./logs',            
    logging_steps=10,
    save_steps=1000,
    evaluation_strategy="no",
    save_total_limit=1,
)

# Initialize the Trainer
trainer = Trainer(
    model=model,                        
    args=training_args,                  
    train_dataset=dataset                
)

# Train the model
trainer.train()

# Save the fine-tuned model
model.save_pretrained('./fine_tuned_t5')
tokenizer.save_pretrained('./fine_tuned_t5')


 73%|███████▎  | 11/15 [00:03<00:00,  5.55it/s]

{'loss': 7.7775, 'grad_norm': 97.10293579101562, 'learning_rate': 1.6666666666666667e-05, 'epoch': 2.0}


100%|██████████| 15/15 [00:05<00:00,  2.87it/s]

{'train_runtime': 5.2347, 'train_samples_per_second': 9.743, 'train_steps_per_second': 2.865, 'train_loss': 7.641370391845703, 'epoch': 3.0}





('./fine_tuned_t5/tokenizer_config.json',
 './fine_tuned_t5/special_tokens_map.json',
 './fine_tuned_t5/spiece.model',
 './fine_tuned_t5/added_tokens.json')

# Testing

In [1]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Load the fine-tuned model and tokenizer
model_path = './fine_tuned_t5'  # Path to the saved fine-tuned model
model = T5ForConditionalGeneration.from_pretrained(model_path)
tokenizer = T5Tokenizer.from_pretrained(model_path)

# Function to recognize intent using the fine-tuned model
def recognize_intent(user_input):
    # Prepare the input in the format that was used during training
    input_text = f"intent recognition: {user_input}"
    
    # Tokenize the input
    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=128)
    
    # Use the model to generate the predicted intent
    outputs = model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=10)
    
    # Decode the generated output to get the predicted intent
    predicted_intent = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    return predicted_intent

# Example usage
user_input = "Can you tell me about the spicy dishes?"
recognized_intent = recognize_intent(user_input)
print(f"Recognized Intent: {recognized_intent}")

  from .autonotebook import tqdm as notebook_tqdm


Recognized Intent: Können Sie mir sagen über die spicy dishes
