In [1]:
import torch
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from sklearn.model_selection import train_test_split
from datasets import Dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Labeled example queries for the query-type classifier.
# Each entry is a dict {"query": <user text>, "tag": <class name>} where the
# tag is one of "qna", "recommendation" or "image".
# NOTE(review): the last batch of entries (from "Who is the director of Good
# Will Hunting?" onward) repeats some earlier queries with only punctuation /
# whitespace differences (e.g. the Halle Berry and Denzel Washington image
# queries), and several of its query strings end with a trailing space. With a
# random train/test split such near-duplicates can end up on both sides of the
# split -- consider de-duplicating.
training_data = [
    {"query": "what was the earning of the movie Titanic", "tag": "qna"},
    {"query": "how much did Avatar make", "tag": "qna"},
    {"query": "box office collection of Jurassic Park", "tag": "qna"},
    {"query": "total revenue of Inception", "tag": "qna"},
    {"query": "what was the gross of Frozen", "tag": "qna"},
    {"query": "how much did Avengers Endgame earn", "tag": "qna"},
    {"query": "earnings of The Dark Knight", "tag": "qna"},
    {"query": "how much money did Titanic gross", "tag": "qna"},
    {"query": "how much revenue did Star Wars make", "tag": "qna"},
    {"query": "revenue of Iron Man", "tag": "qna"},
    {"query": "who directed Titanic", "tag": "qna"},
    {"query": "who was the director of Inception", "tag": "qna"},
    {"query": "who directed the movie Avatar", "tag": "qna"},
    {"query": "who directed Star Wars", "tag": "qna"},
    {"query": "who was the director of The Godfather", "tag": "qna"},
    {"query": "director of The Dark Knight", "tag": "qna"},
    {"query": "who directed Forrest Gump", "tag": "qna"},
    {"query": "who is the director of Jurassic Park", "tag": "qna"},
    {"query": "who directed Schindler's List", "tag": "qna"},
    {"query": "director of The Matrix", "tag": "qna"},
    {"query": "who played the lead role in Titanic", "tag": "qna"},
    {"query": "who acted in The Dark Knight", "tag": "qna"},
    {"query": "lead actor in Inception", "tag": "qna"},
    {"query": "who starred in Avatar", "tag": "qna"},
    {"query": "who was the main actor in The Godfather", "tag": "qna"},
    {"query": "lead actor of Iron Man", "tag": "qna"},
    {"query": "who played in The Matrix", "tag": "qna"},
    {"query": "who was the actor in Star Wars", "tag": "qna"},
    {"query": "who acted in Interstellar", "tag": "qna"},
    {"query": "who starred in Pulp Fiction", "tag": "qna"},
    {"query": "when was Titanic released", "tag": "qna"},
    {"query": "release date of Avatar", "tag": "qna"},
    {"query": "when was Inception released", "tag": "qna"},
    {"query": "release year of Star Wars", "tag": "qna"},
    {"query": "when did The Dark Knight come out", "tag": "qna"},
    {"query": "when was Schindler's List released", "tag": "qna"},
    {"query": "release date of The Matrix", "tag": "qna"},
    {"query": "release year of The Godfather", "tag": "qna"},
    {"query": "when did Forrest Gump release", "tag": "qna"},
    {"query": "when was Jurassic Park released", "tag": "qna"},
    {"query": "what genre is Titanic", "tag": "qna"},
    {"query": "genre of Inception", "tag": "qna"},
    {"query": "what type of movie is The Godfather", "tag": "qna"},
    {"query": "what genre is Star Wars", "tag": "qna"},
    {"query": "what genre is The Dark Knight", "tag": "qna"},
    {"query": "genre of Avatar", "tag": "qna"},
    {"query": "what type of movie is Jurassic Park", "tag": "qna"},
    {"query": "what genre is The Matrix", "tag": "qna"},
    {"query": "what type of movie is Schindler's List", "tag": "qna"},
    {"query": "genre of Pulp Fiction", "tag": "qna"},
    {"query": "what is the rating of Titanic", "tag": "qna"},
    {"query": "rating of Inception", "tag": "qna"},
    {"query": "how was Star Wars rated", "tag": "qna"},
    {"query": "rating of The Dark Knight", "tag": "qna"},
    {"query": "what rating does The Matrix have", "tag": "qna"},
    {"query": "rating of Avatar", "tag": "qna"},
    {"query": "what rating does The Godfather have", "tag": "qna"},
    {"query": "rating of Forrest Gump", "tag": "qna"},
    {"query": "how was Pulp Fiction rated", "tag": "qna"},
    {"query": "what is the rating of Schindler's List", "tag": "qna"},
    {"query": "what was the budget of Titanic", "tag": "qna"},
    {"query": "budget of Inception", "tag": "qna"},
    {"query": "how much was spent on Avatar", "tag": "qna"},
    {"query": "budget for The Dark Knight", "tag": "qna"},
    {"query": "how much did Star Wars cost", "tag": "qna"},
    {"query": "what was the budget of The Godfather", "tag": "qna"},
    {"query": "how much did The Matrix cost", "tag": "qna"},
    {"query": "budget for Jurassic Park", "tag": "qna"},
    {"query": "how much did Pulp Fiction cost", "tag": "qna"},
    {"query": "budget of Forrest Gump", "tag": "qna"},
    {"query": "who produced Titanic", "tag": "qna"},
    {"query": "who was the producer of Inception", "tag": "qna"},
    {"query": "producer of Avatar", "tag": "qna"},
    {"query": "who produced The Dark Knight", "tag": "qna"},
    {"query": "who was the producer of Star Wars", "tag": "qna"},
    {"query": "who produced The Matrix", "tag": "qna"},
    {"query": "producer of The Godfather", "tag": "qna"},
    {"query": "who produced Jurassic Park", "tag": "qna"},
    {"query": "producer of Schindler's List", "tag": "qna"},
    {"query": "who produced Pulp Fiction", "tag": "qna"},
    {"query": "who wrote the screenplay for Titanic", "tag": "qna"},
    {"query": "screenwriter of Inception", "tag": "qna"},
    {"query": "who wrote the screenplay for The Dark Knight", "tag": "qna"},
    {"query": "who wrote Star Wars", "tag": "qna"},
    {"query": "who was the screenwriter of Avatar", "tag": "qna"},
    {"query": "screenwriter of The Matrix", "tag": "qna"},
    {"query": "who wrote The Godfather screenplay", "tag": "qna"},
    {"query": "who wrote the script for Pulp Fiction", "tag": "qna"},
    {"query": "who was the screenwriter of Schindler's List", "tag": "qna"},
    {"query": "who wrote the screenplay for Jurassic Park", "tag": "qna"},
    {"query": "when was Leonardo DiCaprio born", "tag": "qna"},
    {"query": "birth date of James Cameron", "tag": "qna"},
    {"query": "when was Robert Downey Jr born", "tag": "qna"},
    {"query": "birth date of Steven Spielberg", "tag": "qna"},
    {"query": "when was Christopher Nolan born", "tag": "qna"},
    {"query": "when was Al Pacino born", "tag": "qna"},
    {"query": "birth date of Keanu Reeves", "tag": "qna"},
    {"query": "when was Quentin Tarantino born", "tag": "qna"},
    {"query": "birth date of Harrison Ford", "tag": "qna"},
    {"query": "when was Samuel L. Jackson born", "tag": "qna"},
    {"query": "where was Leonardo DiCaprio born", "tag": "qna"},
    {"query": "birthplace of James Cameron", "tag": "qna"},
    {"query": "where was Robert Downey Jr born", "tag": "qna"},
    {"query": "where was Steven Spielberg born", "tag": "qna"},
    {"query": "birthplace of Christopher Nolan", "tag": "qna"},
    {"query": "where was Al Pacino born", "tag": "qna"},
    {"query": "where was Keanu Reeves born", "tag": "qna"},
    {"query": "birthplace of Quentin Tarantino", "tag": "qna"},
    {"query": "where was Harrison Ford born", "tag": "qna"},
    {"query": "where was Samuel L. Jackson born", "tag": "qna"},
    {"query":"Can you suggest movies like The Dark Knight?", "tag":"recommendation"},
    {"query":"I loved Inception and Interstellar. What else should I watch?", "tag":"recommendation"},
    {"query":"Give me movie recommendations similar to Toy Story and Finding Nemo.", "tag":"recommendation"},
    {"query":"Recommend some movies if I enjoyed Titanic.", "tag":"recommendation"},
    {"query":"I recently watched Joker, what else might I enjoy?", "tag":"recommendation"},
    {"query":"What are some movies like Avatar and The Matrix?", "tag":"recommendation"},
    {"query":"Suggest films that are similar to Parasite.", "tag":"recommendation"},
    {"query":"I loved Shutter Island and Gone Girl. What do you recommend?", "tag":"recommendation"},
    {"query":"If I liked Star Wars and Guardians of the Galaxy, what other movies should I watch?", "tag":"recommendation"},
    {"query":"What are some great movies for fans of The Lord of the Rings?", "tag":"recommendation"},
    {"query":"Can you recommend movies like The Godfather?", "tag":"recommendation"},
    {"query":"I enjoyed Black Panther. Can you suggest some similar films?", "tag":"recommendation"},
    {"query":"What movies should I watch if I liked Forrest Gump?", "tag":"recommendation"},
    {"query":"Are there any films similar to Pulp Fiction and Kill Bill?", "tag":"recommendation"},
    {"query":"Recommend some good films like Harry Potter.", "tag":"recommendation"},
    {"query":"I just watched The Grand Budapest Hotel. What else would I enjoy?", "tag":"recommendation"},
    {"query":"Suggest movies that are similar to La La Land and Whiplash.", "tag":"recommendation"},
    {"query":"What films are similar to Avengers: Endgame?", "tag":"recommendation"},
    {"query":"I enjoyed The Hunger Games and Divergent. What should I watch next?", "tag":"recommendation"},
    {"query":"Can you give me a list of movies like Coco?", "tag":"recommendation"},
    {"query":"If I liked The Lion King, what other films might I enjoy?", "tag":"recommendation"},
    {"query":"What are some movies like The Silence of the Lambs?", "tag":"recommendation"},
    {"query":"Can you suggest films for fans of Frozen?", "tag":"recommendation"},
    {"query":"I loved The Social Network. Can you recommend similar movies?", "tag":"recommendation"},
    {"query":"What films should I check out if I liked The Shawshank Redemption?", "tag":"recommendation"},
    {"query":"Recommend some movies similar to Mad Max: Fury Road.", "tag":"recommendation"},
    {"query":"What should I watch if I loved 1917 and Dunkirk?", "tag":"recommendation"},
    {"query":"Can you recommend movies like Knives Out?", "tag":"recommendation"},
    {"query":"I’m looking for movies like Spirited Away and My Neighbor Totoro. Any ideas?", "tag":"recommendation"},
    {"query":"What are some great films similar to The Big Short?", "tag":"recommendation"},
    {"query":"I enjoyed Deadpool. Can you suggest similar movies?", "tag":"recommendation"},
    {"query":"Can you recommend movies like The Wolf of Wall Street?", "tag":"recommendation"},
    {"query":"What movies are good if I liked Fight Club?", "tag":"recommendation"},
    {"query":"I liked A Quiet Place and Bird Box. Any recommendations?", "tag":"recommendation"},
    {"query":"Suggest films similar to The Revenant.", "tag":"recommendation"},
    {"query":"What should I watch if I enjoyed Moonlight?", "tag":"recommendation"},
    {"query":"Can you suggest movies like The Perks of Being a Wallflower?", "tag":"recommendation"},
    {"query":"What are some movies like Up and Wall-E?", "tag":"recommendation"},
    {"query":"I loved Good Will Hunting. Any similar recommendations?", "tag":"recommendation"},
    {"query":"If I enjoyed The Shape of Water, what else might I like?", "tag":"recommendation"},
    {"query":"What films are similar to The Irishman?", "tag":"recommendation"},
    {"query":"Can you recommend some good movies like Gladiator?", "tag":"recommendation"},
    {"query":"I loved The Notebook. What other romantic films should I watch?", "tag":"recommendation"},
    {"query":"What are some great animated films like How to Train Your Dragon?", "tag":"recommendation"},
    {"query":"Suggest movies similar to The Sixth Sense.", "tag":"recommendation"},
    {"query":"I enjoyed Jumanji and Zathura. Any recommendations?", "tag":"recommendation"},
    {"query":"What should I watch if I loved The Greatest Showman?", "tag":"recommendation"},
    {"query":"Can you recommend movies like The Departed and Heat?", "tag":"recommendation"},
    {"query":"What films are similar to Pride and Prejudice?", "tag":"recommendation"},
    {"query":"I liked Crazy Rich Asians. Any similar films?", "tag":"recommendation"},
    {"query":"Suggest movies similar to The Fault in Our Stars.", "tag":"recommendation"},
    {"query":"What are some great spy films like Skyfall and Mission Impossible?", "tag":"recommendation"},
    {"query":"If I enjoyed Monsters, Inc. and Despicable Me, what else might I enjoy?", "tag":"recommendation"},
    {"query":"Can you recommend movies like The Bourne Identity?", "tag":"recommendation"},
    {"query":"I enjoyed Inside Out and Soul. What else is similar?", "tag":"recommendation"},
    {"query":"What are some films like The Girl with the Dragon Tattoo?", "tag":"recommendation"},
    {"query":"Can you suggest movies for fans of Rocky and Creed?", "tag":"recommendation"},
    {"query":"I loved The Blind Side. Any recommendations?", "tag":"recommendation"},
    {"query":"What should I watch if I liked The Imitation Game?", "tag":"recommendation"},
    {"query":"Recommend some good musicals like Mamma Mia!", "tag":"recommendation"},
    {"query":"I enjoyed The Lego Movie. What similar movies are out there?", "tag":"recommendation"},
    {"query":"What movies are great if I liked 10 Things I Hate About You?", "tag":"recommendation"},
    {"query":"I loved Enola Holmes. What other mystery films should I watch?", "tag":"recommendation"},
    {"query":"Can you recommend movies like The Conjuring and Annabelle?", "tag":"recommendation"},
    {"query":"What films are similar to The Pianist?", "tag":"recommendation"},
    {"query":"I liked Coco and Ratatouille. What else might I enjoy?", "tag":"recommendation"},
    {"query":"Suggest films similar to Slumdog Millionaire.", "tag":"recommendation"},
    {"query":"I loved The Secret Life of Walter Mitty. What else would you recommend?", "tag":"recommendation"},
    {"query":"What movies are good for fans of The Great Gatsby?", "tag":"recommendation"},
    {"query":"If I enjoyed Catch Me If You Can, what should I watch next?", "tag":"recommendation"},
    {"query":"Recommend some movies like The Twilight Saga.", "tag":"recommendation"},
    {"query":"I liked Coraline. What other dark animated films should I watch?", "tag":"recommendation"},
    {"query":"What are some films similar to The Good, the Bad and the Ugly?", "tag":"recommendation"},
    {"query":"Can you recommend movies like Casablanca?", "tag":"recommendation"},
    {"query":"What should I watch if I enjoyed The Truman Show?", "tag":"recommendation"},
    {"query":"I loved A Star is Born. Any similar suggestions?", "tag":"recommendation"},
    {"query":"What are some films similar to Little Women?", "tag":"recommendation"},
    {"query":"Can you recommend movies like Train to Busan?", "tag":"recommendation"},
    {"query":"I enjoyed Hidden Figures. What else might I like?", "tag":"recommendation"},
    {"query":"What films are similar to 500 Days of Summer?", "tag":"recommendation"},
    {"query": "Show me a picture of Halle Berry", "tag": "image"},
    {"query": "What does Denzel Washington look like", "tag": "image"},
    {"query": "What does Sandra Bullock look like", "tag": "image"},
    {"query": "Show me a poster of Moana", "tag": "image"},
    {"query": "How did Iron Man look?", "tag": "image"},
    {"query": "Give me a glimpse of Harry Potter", "tag": "image"},
    {"query": "Can I see a photo of Tom Hanks?", "tag": "image"},
    {"query": "Show me Scarlett Johansson's portrait", "tag": "image"},
    {"query": "Find a poster of The Godfather", "tag": "image"},
    {"query": "Display an image of Leonardo DiCaprio", "tag": "image"},
    {"query": "What does Gal Gadot look like?", "tag": "image"},
    {"query": "Show me the Avengers Endgame poster", "tag": "image"},
    {"query": "Find a picture of Meryl Streep", "tag": "image"},
    {"query": "What does the Joker look like in the 2019 movie?", "tag": "image"},
    {"query": "Show me a still of Black Panther", "tag": "image"},
    {"query": "Can you display a picture of Chris Hemsworth?", "tag": "image"},
    {"query": "What does Ryan Gosling look like?", "tag": "image"},
    {"query": "Find me a photo of Jennifer Lawrence", "tag": "image"},
    {"query": "Show me the Frozen movie poster", "tag": "image"},
    {"query": "Display an image of the Minions", "tag": "image"},
    {"query": "What does Keanu Reeves look like?", "tag": "image"},
    {"query": "Find a picture of the movie Titanic", "tag": "image"},
    {"query": "Show me what Emma Watson looks like", "tag": "image"},
    {"query": "What does Daniel Radcliffe look like?", "tag": "image"},
    {"query": "Display a poster of Star Wars", "tag": "image"},
    {"query": "Can you find an image of Christian Bale?", "tag": "image"},
    {"query": "What does Zoe Saldana look like?", "tag": "image"},
    {"query": "Show me a still from Pirates of the Caribbean", "tag": "image"},
    {"query": "Find a poster of the movie Inception", "tag": "image"},
    {"query": "What does Robert Downey Jr. look like?", "tag": "image"},
    {"query": "Can you display a photo of Anne Hathaway?", "tag": "image"},
    {"query": "Show me the Toy Story movie poster", "tag": "image"},
    {"query": "Find a picture of Vin Diesel", "tag": "image"},
    {"query": "Display a poster of the movie Interstellar", "tag": "image"},
    {"query": "What does Cate Blanchett look like?", "tag": "image"},
    {"query": "Can you show me a picture of the Hulk?", "tag": "image"},
    {"query": "Find me a photo of Hugh Jackman", "tag": "image"},
    {"query": "Show me a poster of Shrek", "tag": "image"},
    {"query": "What does Margot Robbie look like?", "tag": "image"},
    {"query": "Display an image of the Fast and Furious movie", "tag": "image"},
    {"query": "Can I see a picture of Johnny Depp?", "tag": "image"},
    {"query": "Find a poster of the movie Aladdin", "tag": "image"},
    {"query": "What does Charlize Theron look like?", "tag": "image"},
    {"query": "Show me a still of Wonder Woman", "tag": "image"},
    {"query": "Can you display a photo of Mark Ruffalo?", "tag": "image"},
    {"query": "Find me a picture of Angelina Jolie", "tag": "image"},
    {"query": "Display a poster of Finding Nemo", "tag": "image"},
    {"query": "What does Matt Damon look like?", "tag": "image"},
    {"query": "Show me a poster of Guardians of the Galaxy", "tag": "image"},
    {"query": "Find a picture of Will Smith", "tag": "image"},
    {"query": "Display an image of the character Darth Vader", "tag": "image"},
    {"query": "Can you show me a picture of Mila Kunis?", "tag": "image"},
    {"query": "What does Morgan Freeman look like?", "tag": "image"},
    {"query": "Find a poster of Frozen 2", "tag": "image"},
    {"query": "Show me what Chris Pratt looks like", "tag": "image"},
    {"query": "Can I see a picture of Benedict Cumberbatch?", "tag": "image"},
    {"query": "What does Nicole Kidman look like?", "tag": "image"},
    {"query": "Show me the Batman movie poster", "tag": "image"},
    {"query": "Find a still of Captain America", "tag": "image"},
    {"query": "Display an image of Julia Roberts", "tag": "image"},
    {"query": "What does Idris Elba look like?", "tag": "image"},
    {"query": "Show me a poster of The Lion King", "tag": "image"},
    {"query": "Can you display a photo of Tom Cruise?", "tag": "image"},
    {"query": "Find me a picture of Natalie Portman", "tag": "image"},
    {"query": "What does Joaquin Phoenix look like?", "tag": "image"},
    {"query": "Show me a poster of Coco", "tag": "image"},
    {"query": "Display an image of Eddie Murphy", "tag": "image"},
    {"query": "Find a picture of Bruce Willis", "tag": "image"},
    {"query": "What does Kristen Stewart look like?", "tag": "image"},
    {"query": "Can you show me a poster of Spider-Man?", "tag": "image"},
    {"query": "Find me a photo of Brad Pitt", "tag": "image"},
    {"query": "Show me the Frozen poster with Olaf", "tag": "image"},
    {"query": "What does Vin Diesel's character in Fast and Furious look like?", "tag": "image"},
    {"query": "Display an image of the movie Ratatouille", "tag": "image"},
    {"query": "Can I see a photo of Amy Adams?", "tag": "image"},
    {"query": "What does Jake Gyllenhaal look like?", "tag": "image"},
    {"query": "Find a picture of the movie Tangled", "tag": "image"},
    {"query": "Show me a still from Avengers: Infinity War", "tag": "image"},
    {"query": "What does Chris Evans look like?", "tag": "image"},
    {"query": "Display a poster of Inside Out", "tag": "image"},
    {"query": "Find me a photo of Emily Blunt", "tag": "image"},
    {"query": "Can you show me what Adam Driver looks like?", "tag": "image"},
    {"query": "What does Gal Gadot's Wonder Woman costume look like?", "tag": "image"},
    {"query": "Find a picture of Heath Ledger as the Joker", "tag": "image"},
    {"query": "Show me a poster of Monsters, Inc.", "tag": "image"},
    {"query": "Display an image of the cast of Friends", "tag": "image"},
    {"query": "What does Benedict Wong look like?", "tag": "image"},
    {"query": "Find me a photo of Tom Holland in Spider-Man", "tag": "image"},
    {"query": "Who is the director of Good Will Hunting?", "tag": "qna"},
    {"query": "Who directed The Bridge on the River Kwai?", "tag": "qna"},
    {"query": "Who is the director of Star Wars: Episode VI - Return of the Jedi?", "tag": "qna"},
    {"query": "Who is the screenwriter of The Masked Gang: Cyprus?", "tag": "qna"},
    {"query": "What is the MPAA film rating of Weathering with You? ", "tag": "qna"},
    {"query": "What is the genre of Good Neighbors? ", "tag": "qna"},
    {"query": "Show me a picture of Halle Berry.", "tag": "image"},
    {"query": "What does Denzel Washington look like? ", "tag": "image"},
    {"query": "Let me know what Sandra Bullock looks like. ", "tag": "image"},
    {"query": "Recommend movies similar to Hamlet and Othello. ", "tag": "recommendation"},
    {"query": "Given that I like The Lion King, Pocahontas, and The Beauty and the Beast, can you recommend some movies? ", "tag": "recommendation"},
    {"query": "Recommend movies like Nightmare on Elm Street, Friday the 13th, and Halloween. ", "tag": "recommendation"},
    {"query": "What is the box office of The Princess and the Frog? ", "tag": "qna"},
    {"query": "Can you tell me the publication date of Tom Meets Zizou? ", "tag": "qna"},
    {"query": "Who is the executive producer of X-Men: First Class?", "tag": "qna"},
]

In [3]:
# The three query categories; a tag's position in this list is its class id.
label_list = ["qna", "recommendation", "image"]

# Lookup tables: tag name -> class id, and the inverse class id -> tag name
tag2id = dict(zip(label_list, range(len(label_list))))
id2tag = dict(zip(tag2id.values(), tag2id.keys()))

# Parallel lists: raw query text and the integer class id of its tag
queries = list(map(lambda example: example['query'], training_data))
labels = [tag2id[tag] for tag in (example['tag'] for example in training_data)]

In [4]:
# Split into train and test (80% / 20%).
# random_state pins the shuffle so a fresh "Restart & Run All" reproduces the
# same partition; stratify=labels keeps the proportion of each tag the same in
# the train and test splits.
train_queries, test_queries, train_labels, test_labels = train_test_split(
    queries,
    labels,
    test_size=0.2,
    random_state=42,
    stratify=labels,
)

In [5]:
# Show the training-split queries as this cell's output (bare last expression).
train_queries

['who starred in Pulp Fiction',
 'What does Robert Downey Jr. look like?',
 'rating of The Dark Knight',
 'What should I watch if I enjoyed Moonlight?',
 'What does Jake Gyllenhaal look like?',
 'rating of Inception',
 'Can you find an image of Christian Bale?',
 'screenwriter of The Matrix',
 'If I enjoyed Catch Me If You Can, what should I watch next?',
 'when did Forrest Gump release',
 'What does Ryan Gosling look like?',
 'Find me a photo of Brad Pitt',
 'Recommend some movies like The Twilight Saga.',
 'Show me what Emma Watson looks like',
 'I loved Enola Holmes. What other mystery films should I watch?',
 'Display an image of the cast of Friends',
 'where was Robert Downey Jr born',
 'I liked Crazy Rich Asians. Any similar films?',
 'Can you recommend movies like Casablanca?',
 'What are some movies like Avatar and The Matrix?',
 'I loved The Blind Side. Any recommendations?',
 'when did The Dark Knight come out',
 'what was the gross of Frozen',
 'Find a picture of Bruce Willi

In [6]:
# Load the pre-trained tokenizer that matches the 'bert-base-uncased' checkpoint
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Encode the train and test queries with identical settings
# (truncation on, padding on, max_length of 64), train split first.
train_encodings, test_encodings = (
    tokenizer(split_queries, truncation=True, padding=True, max_length=64)
    for split_queries in (train_queries, test_queries)
)

# Wrap encodings + labels in a PyTorch map-style dataset
class QTypeDataset(torch.utils.data.Dataset):
    """Dataset that pairs per-example encoding fields with a label.

    ``encodings`` is a mapping from field name to an indexable collection
    holding one entry per example; ``labels`` is an indexable collection with
    one label per example. Indexing the dataset returns a dict containing every
    encoding field plus a ``'labels'`` entry, each converted to a tensor.
    """

    def __init__(self, encodings, labels):
        self.labels = labels
        self.encodings = encodings

    def __len__(self):
        # The number of samples is defined by the number of labels.
        return len(self.labels)

    def __getitem__(self, idx):
        sample = {}
        for field, values in self.encodings.items():
            sample[field] = torch.tensor(values[idx])
        # The label goes in last, under the 'labels' key.
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample
    
# Wrap the tokenized splits so the Trainer can index them
train_dataset = QTypeDataset(train_encodings, train_labels)
test_dataset = QTypeDataset(test_encodings, test_labels)

# The classification head ('classifier.weight' / 'classifier.bias') is newly
# initialized rather than loaded from the checkpoint, so seed torch *before*
# building the model to make that random initialization repeatable across runs.
torch.manual_seed(42)

# Load pre-trained BERT model for sequence classification (one output per tag)
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=len(label_list))

# Define training arguments
# NOTE(review): the training log for this run reports 1500 total steps, so
# warmup_steps=500 spends roughly the first third of training ramping up the
# learning rate, and eval_loss is already below 1e-3 around epoch 11 -- fewer
# epochs / a shorter warmup would likely suffice; confirm before changing.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy`; switch if the installed version warns or errors -- confirm
# against the installed version.
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=50,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    evaluation_strategy="epoch",
    seed=42,  # make the Trainer's seed explicit
)

# Use the Trainer API to train the model; the test split doubles as the
# per-epoch evaluation set.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)

# Train the model
trainer.train()

# Test the model
predictions = trainer.predict(test_dataset)
# argmax over dim 1 selects, for every row of scores, the highest-scoring class id
predicted_labels = torch.argmax(torch.tensor(predictions.predictions), dim=1)

# Map predicted labels back to tags
predicted_tags = [id2tag[label.item()] for label in predicted_labels]

# Print the results
for query, pred_tag in zip(test_queries, predicted_tags):
    print(f"Query: {query} --> Predicted Tag: {pred_tag}")

# Compare against the held-out labels so the run reports a quality measure
# rather than only a list of predictions.
predicted_ids = predicted_labels.tolist()
num_correct = 0
for query, pred_id, gold_id in zip(test_queries, predicted_ids, test_labels):
    if pred_id == gold_id:
        num_correct += 1
    else:
        print(f"MISCLASSIFIED: {query} --> predicted {id2tag[pred_id]}, expected {id2tag[gold_id]}")

if test_labels:  # guard against an empty test split
    print(f"Test accuracy: {num_correct}/{len(test_labels)} ({num_correct / len(test_labels):.1%})")

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  1%|          | 10/1500 [00:04<08:37,  2.88it/s]

{'loss': 1.0942, 'grad_norm': 8.825994491577148, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.33}


  1%|▏         | 20/1500 [00:07<08:20,  2.96it/s]

{'loss': 1.0984, 'grad_norm': 7.346462726593018, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.67}


  2%|▏         | 30/1500 [00:11<11:11,  2.19it/s]

{'loss': 1.0576, 'grad_norm': 21.50229263305664, 'learning_rate': 3e-06, 'epoch': 1.0}


                                                 
  2%|▏         | 30/1500 [00:11<11:11,  2.19it/s]

{'eval_loss': 0.9294255375862122, 'eval_runtime': 0.6013, 'eval_samples_per_second': 98.129, 'eval_steps_per_second': 13.306, 'epoch': 1.0}


  3%|▎         | 40/1500 [00:15<08:31,  2.86it/s]

{'loss': 0.9767, 'grad_norm': 10.231073379516602, 'learning_rate': 4.000000000000001e-06, 'epoch': 1.33}


  3%|▎         | 50/1500 [00:18<08:10,  2.96it/s]

{'loss': 0.8852, 'grad_norm': 6.847652912139893, 'learning_rate': 5e-06, 'epoch': 1.67}


  4%|▍         | 60/1500 [00:22<07:52,  3.05it/s]

{'loss': 0.7949, 'grad_norm': 10.28961181640625, 'learning_rate': 6e-06, 'epoch': 2.0}


                                                 
  4%|▍         | 60/1500 [00:22<07:52,  3.05it/s]

{'eval_loss': 0.6760149002075195, 'eval_runtime': 0.3288, 'eval_samples_per_second': 179.422, 'eval_steps_per_second': 24.328, 'epoch': 2.0}


  5%|▍         | 70/1500 [00:25<08:05,  2.94it/s]

{'loss': 0.7129, 'grad_norm': 6.7794508934021, 'learning_rate': 7.000000000000001e-06, 'epoch': 2.33}


  5%|▌         | 80/1500 [00:29<08:11,  2.89it/s]

{'loss': 0.6606, 'grad_norm': 5.4618659019470215, 'learning_rate': 8.000000000000001e-06, 'epoch': 2.67}


  6%|▌         | 90/1500 [00:32<07:45,  3.03it/s]

{'loss': 0.573, 'grad_norm': 7.38939094543457, 'learning_rate': 9e-06, 'epoch': 3.0}


                                                 
  6%|▌         | 90/1500 [00:32<07:45,  3.03it/s]

{'eval_loss': 0.4617913067340851, 'eval_runtime': 0.3311, 'eval_samples_per_second': 178.199, 'eval_steps_per_second': 24.163, 'epoch': 3.0}


  7%|▋         | 100/1500 [00:36<07:58,  2.93it/s]

{'loss': 0.5275, 'grad_norm': 6.095213890075684, 'learning_rate': 1e-05, 'epoch': 3.33}


  7%|▋         | 110/1500 [00:39<07:48,  2.97it/s]

{'loss': 0.4689, 'grad_norm': 11.217734336853027, 'learning_rate': 1.1000000000000001e-05, 'epoch': 3.67}


  8%|▊         | 120/1500 [00:43<07:28,  3.07it/s]

{'loss': 0.3611, 'grad_norm': 2.819086790084839, 'learning_rate': 1.2e-05, 'epoch': 4.0}


                                                  
  8%|▊         | 120/1500 [00:43<07:28,  3.07it/s]

{'eval_loss': 0.27850937843322754, 'eval_runtime': 0.373, 'eval_samples_per_second': 158.198, 'eval_steps_per_second': 21.451, 'epoch': 4.0}


  9%|▊         | 130/1500 [00:46<07:46,  2.94it/s]

{'loss': 0.3596, 'grad_norm': 7.947752475738525, 'learning_rate': 1.3000000000000001e-05, 'epoch': 4.33}


  9%|▉         | 140/1500 [00:50<07:36,  2.98it/s]

{'loss': 0.2251, 'grad_norm': 2.9988133907318115, 'learning_rate': 1.4000000000000001e-05, 'epoch': 4.67}


 10%|█         | 150/1500 [00:53<07:20,  3.06it/s]

{'loss': 0.1537, 'grad_norm': 3.600109100341797, 'learning_rate': 1.5e-05, 'epoch': 5.0}


                                                  
 10%|█         | 150/1500 [00:53<07:20,  3.06it/s]

{'eval_loss': 0.07652691751718521, 'eval_runtime': 0.3292, 'eval_samples_per_second': 179.236, 'eval_steps_per_second': 24.303, 'epoch': 5.0}


 11%|█         | 160/1500 [00:57<07:35,  2.94it/s]

{'loss': 0.112, 'grad_norm': 1.0241100788116455, 'learning_rate': 1.6000000000000003e-05, 'epoch': 5.33}


 11%|█▏        | 170/1500 [01:00<07:30,  2.95it/s]

{'loss': 0.045, 'grad_norm': 0.5955237150192261, 'learning_rate': 1.7000000000000003e-05, 'epoch': 5.67}


 12%|█▏        | 180/1500 [01:03<07:18,  3.01it/s]

{'loss': 0.0273, 'grad_norm': 0.7851498126983643, 'learning_rate': 1.8e-05, 'epoch': 6.0}


                                                  
 12%|█▏        | 180/1500 [01:04<07:18,  3.01it/s]

{'eval_loss': 0.012037835083901882, 'eval_runtime': 0.3284, 'eval_samples_per_second': 179.649, 'eval_steps_per_second': 24.359, 'epoch': 6.0}


 13%|█▎        | 190/1500 [01:07<07:24,  2.95it/s]

{'loss': 0.0143, 'grad_norm': 0.22550523281097412, 'learning_rate': 1.9e-05, 'epoch': 6.33}


 13%|█▎        | 200/1500 [01:11<07:16,  2.98it/s]

{'loss': 0.0096, 'grad_norm': 0.14439156651496887, 'learning_rate': 2e-05, 'epoch': 6.67}


 14%|█▍        | 210/1500 [01:14<07:01,  3.06it/s]

{'loss': 0.0068, 'grad_norm': 0.11178883165121078, 'learning_rate': 2.1e-05, 'epoch': 7.0}


                                                  
 14%|█▍        | 210/1500 [01:14<07:01,  3.06it/s]

{'eval_loss': 0.003632626496255398, 'eval_runtime': 0.3282, 'eval_samples_per_second': 179.794, 'eval_steps_per_second': 24.379, 'epoch': 7.0}


 15%|█▍        | 220/1500 [01:18<07:16,  2.93it/s]

{'loss': 0.0045, 'grad_norm': 0.08390093594789505, 'learning_rate': 2.2000000000000003e-05, 'epoch': 7.33}


 15%|█▌        | 230/1500 [01:21<07:06,  2.98it/s]

{'loss': 0.0042, 'grad_norm': 0.06608575582504272, 'learning_rate': 2.3000000000000003e-05, 'epoch': 7.67}


 16%|█▌        | 240/1500 [01:24<06:59,  3.00it/s]

{'loss': 0.0031, 'grad_norm': 0.07512814551591873, 'learning_rate': 2.4e-05, 'epoch': 8.0}


                                                  
 16%|█▌        | 240/1500 [01:25<06:59,  3.00it/s]

{'eval_loss': 0.002095309318974614, 'eval_runtime': 0.3284, 'eval_samples_per_second': 179.642, 'eval_steps_per_second': 24.358, 'epoch': 8.0}


 17%|█▋        | 250/1500 [01:28<07:05,  2.93it/s]

{'loss': 0.0028, 'grad_norm': 0.04594515636563301, 'learning_rate': 2.5e-05, 'epoch': 8.33}


 17%|█▋        | 260/1500 [01:31<07:00,  2.95it/s]

{'loss': 0.0024, 'grad_norm': 0.04391583055257797, 'learning_rate': 2.6000000000000002e-05, 'epoch': 8.67}


 18%|█▊        | 270/1500 [01:35<06:47,  3.02it/s]

{'loss': 0.0022, 'grad_norm': 0.06757207214832306, 'learning_rate': 2.7000000000000002e-05, 'epoch': 9.0}


                                                  
 18%|█▊        | 270/1500 [01:35<06:47,  3.02it/s]

{'eval_loss': 0.0014017770299687982, 'eval_runtime': 0.3343, 'eval_samples_per_second': 176.472, 'eval_steps_per_second': 23.928, 'epoch': 9.0}


 19%|█▊        | 280/1500 [01:39<07:04,  2.87it/s]

{'loss': 0.0018, 'grad_norm': 0.029394026845693588, 'learning_rate': 2.8000000000000003e-05, 'epoch': 9.33}


 19%|█▉        | 290/1500 [01:42<06:52,  2.93it/s]

{'loss': 0.0017, 'grad_norm': 0.028468681499361992, 'learning_rate': 2.9e-05, 'epoch': 9.67}


 20%|██        | 300/1500 [01:45<06:46,  2.95it/s]

{'loss': 0.0015, 'grad_norm': 0.03354921564459801, 'learning_rate': 3e-05, 'epoch': 10.0}


                                                  
 20%|██        | 300/1500 [01:46<06:46,  2.95it/s]

{'eval_loss': 0.0009872820228338242, 'eval_runtime': 0.3628, 'eval_samples_per_second': 162.63, 'eval_steps_per_second': 22.052, 'epoch': 10.0}


 21%|██        | 310/1500 [01:49<06:48,  2.92it/s]

{'loss': 0.0014, 'grad_norm': 0.025238897651433945, 'learning_rate': 3.1e-05, 'epoch': 10.33}


 21%|██▏       | 320/1500 [01:53<06:38,  2.96it/s]

{'loss': 0.0013, 'grad_norm': 0.019813109189271927, 'learning_rate': 3.2000000000000005e-05, 'epoch': 10.67}


 22%|██▏       | 330/1500 [01:56<06:23,  3.05it/s]

{'loss': 0.0011, 'grad_norm': 0.02340281754732132, 'learning_rate': 3.3e-05, 'epoch': 11.0}


                                                  
 22%|██▏       | 330/1500 [01:56<06:23,  3.05it/s]

{'eval_loss': 0.0007498391787521541, 'eval_runtime': 0.373, 'eval_samples_per_second': 158.191, 'eval_steps_per_second': 21.45, 'epoch': 11.0}


 23%|██▎       | 340/1500 [02:00<06:39,  2.90it/s]

{'loss': 0.001, 'grad_norm': 0.020626824349164963, 'learning_rate': 3.4000000000000007e-05, 'epoch': 11.33}


 23%|██▎       | 350/1500 [02:03<06:31,  2.94it/s]

{'loss': 0.001, 'grad_norm': 0.019910870119929314, 'learning_rate': 3.5e-05, 'epoch': 11.67}


 24%|██▍       | 360/1500 [02:07<06:16,  3.03it/s]

{'loss': 0.0009, 'grad_norm': 0.019854428246617317, 'learning_rate': 3.6e-05, 'epoch': 12.0}


                                                  
 24%|██▍       | 360/1500 [02:07<06:16,  3.03it/s]

{'eval_loss': 0.0005981941940262914, 'eval_runtime': 0.3275, 'eval_samples_per_second': 180.143, 'eval_steps_per_second': 24.426, 'epoch': 12.0}


 25%|██▍       | 370/1500 [02:10<06:25,  2.93it/s]

{'loss': 0.0009, 'grad_norm': 0.0154303889721632, 'learning_rate': 3.7e-05, 'epoch': 12.33}


 25%|██▌       | 380/1500 [02:14<06:19,  2.95it/s]

{'loss': 0.0008, 'grad_norm': 0.013265828602015972, 'learning_rate': 3.8e-05, 'epoch': 12.67}


 26%|██▌       | 390/1500 [02:17<06:04,  3.05it/s]

{'loss': 0.0007, 'grad_norm': 0.013483802787959576, 'learning_rate': 3.9000000000000006e-05, 'epoch': 13.0}


                                                  
 26%|██▌       | 390/1500 [02:17<06:04,  3.05it/s]

{'eval_loss': 0.00047817148151807487, 'eval_runtime': 0.3315, 'eval_samples_per_second': 178.002, 'eval_steps_per_second': 24.136, 'epoch': 13.0}


 27%|██▋       | 400/1500 [02:21<06:16,  2.92it/s]

{'loss': 0.0007, 'grad_norm': 0.00986010767519474, 'learning_rate': 4e-05, 'epoch': 13.33}


 27%|██▋       | 410/1500 [02:24<06:07,  2.97it/s]

{'loss': 0.0007, 'grad_norm': 0.011852752417325974, 'learning_rate': 4.1e-05, 'epoch': 13.67}


 28%|██▊       | 420/1500 [02:28<05:59,  3.00it/s]

{'loss': 0.0006, 'grad_norm': 0.015482986345887184, 'learning_rate': 4.2e-05, 'epoch': 14.0}


                                                  
 28%|██▊       | 420/1500 [02:28<05:59,  3.00it/s]

{'eval_loss': 0.0003931986866518855, 'eval_runtime': 0.3566, 'eval_samples_per_second': 165.472, 'eval_steps_per_second': 22.437, 'epoch': 14.0}


 29%|██▊       | 430/1500 [02:31<06:08,  2.91it/s]

{'loss': 0.0006, 'grad_norm': 0.012839000672101974, 'learning_rate': 4.3e-05, 'epoch': 14.33}


 29%|██▉       | 440/1500 [02:35<05:59,  2.95it/s]

{'loss': 0.0005, 'grad_norm': 0.013947228901088238, 'learning_rate': 4.4000000000000006e-05, 'epoch': 14.67}


 30%|███       | 450/1500 [02:38<05:45,  3.04it/s]

{'loss': 0.0005, 'grad_norm': 0.010091590695083141, 'learning_rate': 4.5e-05, 'epoch': 15.0}


                                                  
 30%|███       | 450/1500 [02:38<05:45,  3.04it/s]

{'eval_loss': 0.0003343716380186379, 'eval_runtime': 0.3316, 'eval_samples_per_second': 177.92, 'eval_steps_per_second': 24.125, 'epoch': 15.0}


 31%|███       | 460/1500 [02:42<05:54,  2.93it/s]

{'loss': 0.0005, 'grad_norm': 0.010360774584114552, 'learning_rate': 4.600000000000001e-05, 'epoch': 15.33}


 31%|███▏      | 470/1500 [02:45<05:49,  2.95it/s]

{'loss': 0.0005, 'grad_norm': 0.010227937251329422, 'learning_rate': 4.7e-05, 'epoch': 15.67}


 32%|███▏      | 480/1500 [02:49<05:33,  3.06it/s]

{'loss': 0.0004, 'grad_norm': 0.007675485219806433, 'learning_rate': 4.8e-05, 'epoch': 16.0}


                                                  
 32%|███▏      | 480/1500 [02:49<05:33,  3.06it/s]

{'eval_loss': 0.0002820275549311191, 'eval_runtime': 0.3279, 'eval_samples_per_second': 179.924, 'eval_steps_per_second': 24.396, 'epoch': 16.0}


 33%|███▎      | 490/1500 [02:52<05:46,  2.91it/s]

{'loss': 0.0004, 'grad_norm': 0.006773033179342747, 'learning_rate': 4.9e-05, 'epoch': 16.33}


 33%|███▎      | 500/1500 [02:56<05:37,  2.96it/s]

{'loss': 0.0004, 'grad_norm': 0.008038283325731754, 'learning_rate': 5e-05, 'epoch': 16.67}


 34%|███▍      | 510/1500 [03:01<05:49,  2.83it/s]

{'loss': 0.0004, 'grad_norm': 0.004830126650631428, 'learning_rate': 4.9500000000000004e-05, 'epoch': 17.0}


                                                  
 34%|███▍      | 510/1500 [03:01<05:49,  2.83it/s]

{'eval_loss': 0.00024132100224960595, 'eval_runtime': 0.3391, 'eval_samples_per_second': 173.974, 'eval_steps_per_second': 23.59, 'epoch': 17.0}


 35%|███▍      | 520/1500 [03:05<06:00,  2.72it/s]

{'loss': 0.0004, 'grad_norm': 0.007854222320020199, 'learning_rate': 4.9e-05, 'epoch': 17.33}


 35%|███▌      | 530/1500 [03:08<05:33,  2.90it/s]

{'loss': 0.0003, 'grad_norm': 0.006248532794415951, 'learning_rate': 4.85e-05, 'epoch': 17.67}


 36%|███▌      | 540/1500 [03:12<05:20,  3.00it/s]

{'loss': 0.0003, 'grad_norm': 0.006373974960297346, 'learning_rate': 4.8e-05, 'epoch': 18.0}


                                                  
 36%|███▌      | 540/1500 [03:12<05:20,  3.00it/s]

{'eval_loss': 0.0002110218774760142, 'eval_runtime': 0.3309, 'eval_samples_per_second': 178.296, 'eval_steps_per_second': 24.176, 'epoch': 18.0}


 37%|███▋      | 550/1500 [03:15<05:24,  2.93it/s]

{'loss': 0.0003, 'grad_norm': 0.005918843671679497, 'learning_rate': 4.75e-05, 'epoch': 18.33}


 37%|███▋      | 560/1500 [03:19<05:16,  2.97it/s]

{'loss': 0.0003, 'grad_norm': 0.0059205275028944016, 'learning_rate': 4.7e-05, 'epoch': 18.67}


 38%|███▊      | 570/1500 [03:22<05:04,  3.06it/s]

{'loss': 0.0003, 'grad_norm': 0.004121669568121433, 'learning_rate': 4.6500000000000005e-05, 'epoch': 19.0}


                                                  
 38%|███▊      | 570/1500 [03:23<05:04,  3.06it/s]

{'eval_loss': 0.00018975081911776215, 'eval_runtime': 0.3277, 'eval_samples_per_second': 180.061, 'eval_steps_per_second': 24.415, 'epoch': 19.0}


 39%|███▊      | 580/1500 [03:26<05:14,  2.93it/s]

{'loss': 0.0003, 'grad_norm': 0.00672291312366724, 'learning_rate': 4.600000000000001e-05, 'epoch': 19.33}


 39%|███▉      | 590/1500 [03:29<05:07,  2.96it/s]

{'loss': 0.0003, 'grad_norm': 0.004705086350440979, 'learning_rate': 4.55e-05, 'epoch': 19.67}


 40%|████      | 600/1500 [03:33<04:55,  3.05it/s]

{'loss': 0.0003, 'grad_norm': 0.010221980512142181, 'learning_rate': 4.5e-05, 'epoch': 20.0}


                                                  
 40%|████      | 600/1500 [03:33<04:55,  3.05it/s]

{'eval_loss': 0.00017094363283831626, 'eval_runtime': 0.3308, 'eval_samples_per_second': 178.34, 'eval_steps_per_second': 24.182, 'epoch': 20.0}


 41%|████      | 610/1500 [03:36<05:03,  2.94it/s]

{'loss': 0.0003, 'grad_norm': 0.004128170665353537, 'learning_rate': 4.4500000000000004e-05, 'epoch': 20.33}


 41%|████▏     | 620/1500 [03:40<04:56,  2.97it/s]

{'loss': 0.0003, 'grad_norm': 0.004148968495428562, 'learning_rate': 4.4000000000000006e-05, 'epoch': 20.67}


 42%|████▏     | 630/1500 [03:43<04:46,  3.04it/s]

{'loss': 0.0002, 'grad_norm': 0.005801196675747633, 'learning_rate': 4.35e-05, 'epoch': 21.0}


                                                  
 42%|████▏     | 630/1500 [03:44<04:46,  3.04it/s]

{'eval_loss': 0.00015574399731121957, 'eval_runtime': 0.3635, 'eval_samples_per_second': 162.319, 'eval_steps_per_second': 22.009, 'epoch': 21.0}


 43%|████▎     | 640/1500 [03:47<04:55,  2.91it/s]

{'loss': 0.0002, 'grad_norm': 0.004284354392439127, 'learning_rate': 4.3e-05, 'epoch': 21.33}


 43%|████▎     | 650/1500 [03:50<04:46,  2.97it/s]

{'loss': 0.0002, 'grad_norm': 0.0038560947868973017, 'learning_rate': 4.25e-05, 'epoch': 21.67}


 44%|████▍     | 660/1500 [03:54<04:38,  3.02it/s]

{'loss': 0.0002, 'grad_norm': 0.004808507394045591, 'learning_rate': 4.2e-05, 'epoch': 22.0}


                                                  
 44%|████▍     | 660/1500 [03:54<04:38,  3.02it/s]

{'eval_loss': 0.00014325037773232907, 'eval_runtime': 0.3284, 'eval_samples_per_second': 179.679, 'eval_steps_per_second': 24.363, 'epoch': 22.0}


 45%|████▍     | 670/1500 [03:57<04:42,  2.94it/s]

{'loss': 0.0002, 'grad_norm': 0.003991642035543919, 'learning_rate': 4.15e-05, 'epoch': 22.33}


 45%|████▌     | 680/1500 [04:01<04:36,  2.97it/s]

{'loss': 0.0002, 'grad_norm': 0.0037362673319876194, 'learning_rate': 4.1e-05, 'epoch': 22.67}


 46%|████▌     | 690/1500 [04:04<04:25,  3.05it/s]

{'loss': 0.0002, 'grad_norm': 0.004773550666868687, 'learning_rate': 4.05e-05, 'epoch': 23.0}


                                                  
 46%|████▌     | 690/1500 [04:04<04:25,  3.05it/s]

{'eval_loss': 0.0001326251367572695, 'eval_runtime': 0.3329, 'eval_samples_per_second': 177.24, 'eval_steps_per_second': 24.033, 'epoch': 23.0}


 47%|████▋     | 700/1500 [04:08<04:32,  2.93it/s]

{'loss': 0.0002, 'grad_norm': 0.0033088841009885073, 'learning_rate': 4e-05, 'epoch': 23.33}


 47%|████▋     | 710/1500 [04:11<04:26,  2.97it/s]

{'loss': 0.0002, 'grad_norm': 0.0038174488581717014, 'learning_rate': 3.9500000000000005e-05, 'epoch': 23.67}


 48%|████▊     | 720/1500 [04:15<04:14,  3.06it/s]

{'loss': 0.0002, 'grad_norm': 0.0031442244071513414, 'learning_rate': 3.9000000000000006e-05, 'epoch': 24.0}


                                                  
 48%|████▊     | 720/1500 [04:15<04:14,  3.06it/s]

{'eval_loss': 0.00012432486983016133, 'eval_runtime': 0.3656, 'eval_samples_per_second': 161.383, 'eval_steps_per_second': 21.883, 'epoch': 24.0}


 49%|████▊     | 730/1500 [04:18<04:22,  2.93it/s]

{'loss': 0.0002, 'grad_norm': 0.003055920824408531, 'learning_rate': 3.85e-05, 'epoch': 24.33}


 49%|████▉     | 740/1500 [04:22<04:16,  2.97it/s]

{'loss': 0.0002, 'grad_norm': 0.003729382762685418, 'learning_rate': 3.8e-05, 'epoch': 24.67}


 50%|█████     | 750/1500 [04:25<04:04,  3.06it/s]

{'loss': 0.0002, 'grad_norm': 0.0033157956786453724, 'learning_rate': 3.7500000000000003e-05, 'epoch': 25.0}


                                                  
 50%|█████     | 750/1500 [04:25<04:04,  3.06it/s]

{'eval_loss': 0.00011647631617961451, 'eval_runtime': 0.3266, 'eval_samples_per_second': 180.652, 'eval_steps_per_second': 24.495, 'epoch': 25.0}


 51%|█████     | 760/1500 [04:29<04:12,  2.93it/s]

{'loss': 0.0002, 'grad_norm': 0.003190933959558606, 'learning_rate': 3.7e-05, 'epoch': 25.33}


 51%|█████▏    | 770/1500 [04:32<04:07,  2.96it/s]

{'loss': 0.0002, 'grad_norm': 0.0037128038238734007, 'learning_rate': 3.65e-05, 'epoch': 25.67}


 52%|█████▏    | 780/1500 [04:36<04:00,  3.00it/s]

{'loss': 0.0002, 'grad_norm': 0.0035515539348125458, 'learning_rate': 3.6e-05, 'epoch': 26.0}


                                                  
 52%|█████▏    | 780/1500 [04:36<04:00,  3.00it/s]

{'eval_loss': 0.00010999586811522022, 'eval_runtime': 0.3297, 'eval_samples_per_second': 178.941, 'eval_steps_per_second': 24.263, 'epoch': 26.0}


 53%|█████▎    | 790/1500 [04:39<04:04,  2.91it/s]

{'loss': 0.0002, 'grad_norm': 0.0028398248832672834, 'learning_rate': 3.55e-05, 'epoch': 26.33}


 53%|█████▎    | 800/1500 [04:43<03:55,  2.97it/s]

{'loss': 0.0002, 'grad_norm': 0.0032570832408964634, 'learning_rate': 3.5e-05, 'epoch': 26.67}


 54%|█████▍    | 810/1500 [04:46<03:48,  3.02it/s]

{'loss': 0.0002, 'grad_norm': 0.0037238115910440683, 'learning_rate': 3.45e-05, 'epoch': 27.0}


                                                  
 54%|█████▍    | 810/1500 [04:46<03:48,  3.02it/s]

{'eval_loss': 0.00010425904474686831, 'eval_runtime': 0.336, 'eval_samples_per_second': 175.61, 'eval_steps_per_second': 23.812, 'epoch': 27.0}


 55%|█████▍    | 820/1500 [04:50<03:53,  2.91it/s]

{'loss': 0.0002, 'grad_norm': 0.002679910510778427, 'learning_rate': 3.4000000000000007e-05, 'epoch': 27.33}


 55%|█████▌    | 830/1500 [04:53<03:48,  2.94it/s]

{'loss': 0.0002, 'grad_norm': 0.0028100863564759493, 'learning_rate': 3.35e-05, 'epoch': 27.67}


 56%|█████▌    | 840/1500 [04:57<03:39,  3.01it/s]

{'loss': 0.0001, 'grad_norm': 0.002577619394287467, 'learning_rate': 3.3e-05, 'epoch': 28.0}


                                                  
 56%|█████▌    | 840/1500 [04:57<03:39,  3.01it/s]

{'eval_loss': 9.940935706254095e-05, 'eval_runtime': 0.3373, 'eval_samples_per_second': 174.896, 'eval_steps_per_second': 23.715, 'epoch': 28.0}


 57%|█████▋    | 850/1500 [05:00<03:41,  2.94it/s]

{'loss': 0.0001, 'grad_norm': 0.0025898872409015894, 'learning_rate': 3.2500000000000004e-05, 'epoch': 28.33}


 57%|█████▋    | 860/1500 [05:04<03:35,  2.97it/s]

{'loss': 0.0001, 'grad_norm': 0.0022566423285752535, 'learning_rate': 3.2000000000000005e-05, 'epoch': 28.67}


 58%|█████▊    | 870/1500 [05:07<03:26,  3.05it/s]

{'loss': 0.0001, 'grad_norm': 0.0051058330573141575, 'learning_rate': 3.15e-05, 'epoch': 29.0}


                                                  
 58%|█████▊    | 870/1500 [05:07<03:26,  3.05it/s]

{'eval_loss': 9.490928641753271e-05, 'eval_runtime': 0.3629, 'eval_samples_per_second': 162.584, 'eval_steps_per_second': 22.045, 'epoch': 29.0}


 59%|█████▊    | 880/1500 [05:11<03:31,  2.93it/s]

{'loss': 0.0001, 'grad_norm': 0.002395190764218569, 'learning_rate': 3.1e-05, 'epoch': 29.33}


 59%|█████▉    | 890/1500 [05:14<03:25,  2.97it/s]

{'loss': 0.0001, 'grad_norm': 0.002274327212944627, 'learning_rate': 3.05e-05, 'epoch': 29.67}


 60%|██████    | 900/1500 [05:18<03:16,  3.06it/s]

{'loss': 0.0001, 'grad_norm': 0.0022905603982508183, 'learning_rate': 3e-05, 'epoch': 30.0}


                                                  
 60%|██████    | 900/1500 [05:18<03:16,  3.06it/s]

{'eval_loss': 9.061733726412058e-05, 'eval_runtime': 0.3285, 'eval_samples_per_second': 179.611, 'eval_steps_per_second': 24.354, 'epoch': 30.0}


 61%|██████    | 910/1500 [05:21<03:21,  2.93it/s]

{'loss': 0.0001, 'grad_norm': 0.002675343072041869, 'learning_rate': 2.95e-05, 'epoch': 30.33}


 61%|██████▏   | 920/1500 [05:25<03:15,  2.96it/s]

{'loss': 0.0001, 'grad_norm': 0.0026784895453602076, 'learning_rate': 2.9e-05, 'epoch': 30.67}


 62%|██████▏   | 930/1500 [05:28<03:05,  3.07it/s]

{'loss': 0.0001, 'grad_norm': 0.0030053132213652134, 'learning_rate': 2.8499999999999998e-05, 'epoch': 31.0}


                                                  
 62%|██████▏   | 930/1500 [05:28<03:05,  3.07it/s]

{'eval_loss': 8.730545232538134e-05, 'eval_runtime': 0.3308, 'eval_samples_per_second': 178.374, 'eval_steps_per_second': 24.186, 'epoch': 31.0}


 63%|██████▎   | 940/1500 [05:32<03:11,  2.93it/s]

{'loss': 0.0001, 'grad_norm': 0.00229508220218122, 'learning_rate': 2.8000000000000003e-05, 'epoch': 31.33}


 63%|██████▎   | 950/1500 [05:35<03:05,  2.96it/s]

{'loss': 0.0001, 'grad_norm': 0.00218142569065094, 'learning_rate': 2.7500000000000004e-05, 'epoch': 31.67}


 64%|██████▍   | 960/1500 [05:38<02:56,  3.06it/s]

{'loss': 0.0001, 'grad_norm': 0.002887640381231904, 'learning_rate': 2.7000000000000002e-05, 'epoch': 32.0}


                                                  
 64%|██████▍   | 960/1500 [05:39<02:56,  3.06it/s]

{'eval_loss': 8.383596286876127e-05, 'eval_runtime': 0.3325, 'eval_samples_per_second': 177.462, 'eval_steps_per_second': 24.063, 'epoch': 32.0}


 65%|██████▍   | 970/1500 [05:42<03:00,  2.93it/s]

{'loss': 0.0001, 'grad_norm': 0.0037698487285524607, 'learning_rate': 2.6500000000000004e-05, 'epoch': 32.33}


 65%|██████▌   | 980/1500 [05:46<02:55,  2.96it/s]

{'loss': 0.0001, 'grad_norm': 0.00183938629925251, 'learning_rate': 2.6000000000000002e-05, 'epoch': 32.67}


 66%|██████▌   | 990/1500 [05:49<02:47,  3.05it/s]

{'loss': 0.0001, 'grad_norm': 0.003575096372514963, 'learning_rate': 2.5500000000000003e-05, 'epoch': 33.0}


                                                  
 66%|██████▌   | 990/1500 [05:49<02:47,  3.05it/s]

{'eval_loss': 8.083931606961414e-05, 'eval_runtime': 0.3275, 'eval_samples_per_second': 180.13, 'eval_steps_per_second': 24.424, 'epoch': 33.0}


 67%|██████▋   | 1000/1500 [05:53<02:52,  2.90it/s]

{'loss': 0.0001, 'grad_norm': 0.0020889793522655964, 'learning_rate': 2.5e-05, 'epoch': 33.33}


 67%|██████▋   | 1010/1500 [05:58<02:59,  2.73it/s]

{'loss': 0.0001, 'grad_norm': 0.0019144502002745867, 'learning_rate': 2.45e-05, 'epoch': 33.67}


 68%|██████▊   | 1020/1500 [06:02<02:41,  2.98it/s]

{'loss': 0.0001, 'grad_norm': 0.002584288828074932, 'learning_rate': 2.4e-05, 'epoch': 34.0}


                                                   
 68%|██████▊   | 1020/1500 [06:02<02:41,  2.98it/s]

{'eval_loss': 7.833976997062564e-05, 'eval_runtime': 0.3313, 'eval_samples_per_second': 178.102, 'eval_steps_per_second': 24.149, 'epoch': 34.0}


 69%|██████▊   | 1030/1500 [06:05<02:40,  2.92it/s]

{'loss': 0.0001, 'grad_norm': 0.0022895645815879107, 'learning_rate': 2.35e-05, 'epoch': 34.33}


 69%|██████▉   | 1040/1500 [06:09<02:35,  2.96it/s]

{'loss': 0.0001, 'grad_norm': 0.0018976592691615224, 'learning_rate': 2.3000000000000003e-05, 'epoch': 34.67}


 70%|███████   | 1050/1500 [06:12<02:27,  3.06it/s]

{'loss': 0.0001, 'grad_norm': 0.0020646657794713974, 'learning_rate': 2.25e-05, 'epoch': 35.0}


                                                   
 70%|███████   | 1050/1500 [06:12<02:27,  3.06it/s]

{'eval_loss': 7.587455911561847e-05, 'eval_runtime': 0.3493, 'eval_samples_per_second': 168.885, 'eval_steps_per_second': 22.9, 'epoch': 35.0}


 71%|███████   | 1060/1500 [06:16<02:31,  2.91it/s]

{'loss': 0.0001, 'grad_norm': 0.0020603288430720568, 'learning_rate': 2.2000000000000003e-05, 'epoch': 35.33}


 71%|███████▏  | 1070/1500 [06:19<02:24,  2.97it/s]

{'loss': 0.0001, 'grad_norm': 0.002112769987434149, 'learning_rate': 2.15e-05, 'epoch': 35.67}


 72%|███████▏  | 1080/1500 [06:23<02:17,  3.06it/s]

{'loss': 0.0001, 'grad_norm': 0.0023808307014405727, 'learning_rate': 2.1e-05, 'epoch': 36.0}


                                                   
 72%|███████▏  | 1080/1500 [06:23<02:17,  3.06it/s]

{'eval_loss': 7.370033563347533e-05, 'eval_runtime': 0.3273, 'eval_samples_per_second': 180.285, 'eval_steps_per_second': 24.445, 'epoch': 36.0}


 73%|███████▎  | 1090/1500 [06:26<02:19,  2.93it/s]

{'loss': 0.0001, 'grad_norm': 0.004066396038979292, 'learning_rate': 2.05e-05, 'epoch': 36.33}


 73%|███████▎  | 1100/1500 [06:30<02:14,  2.96it/s]

{'loss': 0.0001, 'grad_norm': 0.0017668831860646605, 'learning_rate': 2e-05, 'epoch': 36.67}


 74%|███████▍  | 1110/1500 [06:33<02:07,  3.06it/s]

{'loss': 0.0001, 'grad_norm': 0.0022259349934756756, 'learning_rate': 1.9500000000000003e-05, 'epoch': 37.0}


                                                   
 74%|███████▍  | 1110/1500 [06:33<02:07,  3.06it/s]

{'eval_loss': 7.187164737842977e-05, 'eval_runtime': 0.3284, 'eval_samples_per_second': 179.673, 'eval_steps_per_second': 24.362, 'epoch': 37.0}


 75%|███████▍  | 1120/1500 [06:37<02:09,  2.94it/s]

{'loss': 0.0001, 'grad_norm': 0.001932106213644147, 'learning_rate': 1.9e-05, 'epoch': 37.33}


 75%|███████▌  | 1130/1500 [06:40<02:04,  2.97it/s]

{'loss': 0.0001, 'grad_norm': 0.0023962222039699554, 'learning_rate': 1.85e-05, 'epoch': 37.67}


 76%|███████▌  | 1140/1500 [06:44<01:57,  3.06it/s]

{'loss': 0.0001, 'grad_norm': 0.001860545831732452, 'learning_rate': 1.8e-05, 'epoch': 38.0}


                                                   
 76%|███████▌  | 1140/1500 [06:44<01:57,  3.06it/s]

{'eval_loss': 7.024100341368467e-05, 'eval_runtime': 0.3311, 'eval_samples_per_second': 178.196, 'eval_steps_per_second': 24.162, 'epoch': 38.0}


 77%|███████▋  | 1150/1500 [06:47<01:59,  2.94it/s]

{'loss': 0.0001, 'grad_norm': 0.0016721367137506604, 'learning_rate': 1.75e-05, 'epoch': 38.33}


 77%|███████▋  | 1160/1500 [06:51<01:54,  2.97it/s]

{'loss': 0.0001, 'grad_norm': 0.0022023285273462534, 'learning_rate': 1.7000000000000003e-05, 'epoch': 38.67}


 78%|███████▊  | 1170/1500 [06:54<01:47,  3.06it/s]

{'loss': 0.0001, 'grad_norm': 0.0019494619918987155, 'learning_rate': 1.65e-05, 'epoch': 39.0}


                                                   
 78%|███████▊  | 1170/1500 [06:54<01:47,  3.06it/s]

{'eval_loss': 6.870934885228053e-05, 'eval_runtime': 0.3283, 'eval_samples_per_second': 179.735, 'eval_steps_per_second': 24.371, 'epoch': 39.0}


 79%|███████▊  | 1180/1500 [06:58<01:48,  2.94it/s]

{'loss': 0.0001, 'grad_norm': 0.0020210212096571922, 'learning_rate': 1.6000000000000003e-05, 'epoch': 39.33}


 79%|███████▉  | 1190/1500 [07:01<01:44,  2.97it/s]

{'loss': 0.0001, 'grad_norm': 0.0021332823671400547, 'learning_rate': 1.55e-05, 'epoch': 39.67}


 80%|████████  | 1200/1500 [07:04<01:37,  3.06it/s]

{'loss': 0.0001, 'grad_norm': 0.004033771343529224, 'learning_rate': 1.5e-05, 'epoch': 40.0}


                                                   
 80%|████████  | 1200/1500 [07:05<01:37,  3.06it/s]

{'eval_loss': 6.730299355695024e-05, 'eval_runtime': 0.3295, 'eval_samples_per_second': 179.034, 'eval_steps_per_second': 24.276, 'epoch': 40.0}


 81%|████████  | 1210/1500 [07:08<01:38,  2.94it/s]

{'loss': 0.0001, 'grad_norm': 0.0018764038104563951, 'learning_rate': 1.45e-05, 'epoch': 40.33}


 81%|████████▏ | 1220/1500 [07:12<01:34,  2.97it/s]

{'loss': 0.0001, 'grad_norm': 0.0014986322494223714, 'learning_rate': 1.4000000000000001e-05, 'epoch': 40.67}


 82%|████████▏ | 1230/1500 [07:15<01:28,  3.06it/s]

{'loss': 0.0001, 'grad_norm': 0.0019254910293966532, 'learning_rate': 1.3500000000000001e-05, 'epoch': 41.0}


                                                   
 82%|████████▏ | 1230/1500 [07:15<01:28,  3.06it/s]

{'eval_loss': 6.605423550354317e-05, 'eval_runtime': 0.3291, 'eval_samples_per_second': 179.26, 'eval_steps_per_second': 24.306, 'epoch': 41.0}


 83%|████████▎ | 1240/1500 [07:19<01:28,  2.94it/s]

{'loss': 0.0001, 'grad_norm': 0.001693899161182344, 'learning_rate': 1.3000000000000001e-05, 'epoch': 41.33}


 83%|████████▎ | 1250/1500 [07:22<01:24,  2.97it/s]

{'loss': 0.0001, 'grad_norm': 0.00243524182587862, 'learning_rate': 1.25e-05, 'epoch': 41.67}


 84%|████████▍ | 1260/1500 [07:25<01:19,  3.03it/s]

{'loss': 0.0001, 'grad_norm': 0.0034511741250753403, 'learning_rate': 1.2e-05, 'epoch': 42.0}


                                                   
 84%|████████▍ | 1260/1500 [07:26<01:19,  3.03it/s]

{'eval_loss': 6.49287539999932e-05, 'eval_runtime': 0.3281, 'eval_samples_per_second': 179.83, 'eval_steps_per_second': 24.384, 'epoch': 42.0}


 85%|████████▍ | 1270/1500 [07:29<01:18,  2.94it/s]

{'loss': 0.0001, 'grad_norm': 0.0019448749953880906, 'learning_rate': 1.1500000000000002e-05, 'epoch': 42.33}


 85%|████████▌ | 1280/1500 [07:32<01:14,  2.97it/s]

{'loss': 0.0001, 'grad_norm': 0.0019354360410943627, 'learning_rate': 1.1000000000000001e-05, 'epoch': 42.67}


 86%|████████▌ | 1290/1500 [07:36<01:08,  3.06it/s]

{'loss': 0.0001, 'grad_norm': 0.002233816310763359, 'learning_rate': 1.05e-05, 'epoch': 43.0}


                                                   
 86%|████████▌ | 1290/1500 [07:36<01:08,  3.06it/s]

{'eval_loss': 6.393460353137925e-05, 'eval_runtime': 0.3304, 'eval_samples_per_second': 178.58, 'eval_steps_per_second': 24.214, 'epoch': 43.0}


 87%|████████▋ | 1300/1500 [07:40<01:08,  2.94it/s]

{'loss': 0.0001, 'grad_norm': 0.0034294729121029377, 'learning_rate': 1e-05, 'epoch': 43.33}


 87%|████████▋ | 1310/1500 [07:43<01:03,  2.97it/s]

{'loss': 0.0001, 'grad_norm': 0.0017909301677718759, 'learning_rate': 9.5e-06, 'epoch': 43.67}


 88%|████████▊ | 1320/1500 [07:46<00:58,  3.06it/s]

{'loss': 0.0001, 'grad_norm': 0.0019819922745227814, 'learning_rate': 9e-06, 'epoch': 44.0}


                                                   
 88%|████████▊ | 1320/1500 [07:47<00:58,  3.06it/s]

{'eval_loss': 6.313039193628356e-05, 'eval_runtime': 0.328, 'eval_samples_per_second': 179.877, 'eval_steps_per_second': 24.39, 'epoch': 44.0}


 89%|████████▊ | 1330/1500 [07:50<00:57,  2.93it/s]

{'loss': 0.0001, 'grad_norm': 0.0015935241244733334, 'learning_rate': 8.500000000000002e-06, 'epoch': 44.33}


 89%|████████▉ | 1340/1500 [07:53<00:54,  2.96it/s]

{'loss': 0.0001, 'grad_norm': 0.0017646211199462414, 'learning_rate': 8.000000000000001e-06, 'epoch': 44.67}


 90%|█████████ | 1350/1500 [07:57<00:48,  3.06it/s]

{'loss': 0.0001, 'grad_norm': 0.0018039565766230226, 'learning_rate': 7.5e-06, 'epoch': 45.0}


                                                   
 90%|█████████ | 1350/1500 [07:57<00:48,  3.06it/s]

{'eval_loss': 6.245954864425585e-05, 'eval_runtime': 0.3288, 'eval_samples_per_second': 179.446, 'eval_steps_per_second': 24.332, 'epoch': 45.0}


 91%|█████████ | 1360/1500 [08:00<00:47,  2.93it/s]

{'loss': 0.0001, 'grad_norm': 0.0018674117745831609, 'learning_rate': 7.000000000000001e-06, 'epoch': 45.33}


 91%|█████████▏| 1370/1500 [08:04<00:44,  2.95it/s]

{'loss': 0.0001, 'grad_norm': 0.0015586907975375652, 'learning_rate': 6.5000000000000004e-06, 'epoch': 45.67}


 92%|█████████▏| 1380/1500 [08:07<00:39,  3.05it/s]

{'loss': 0.0001, 'grad_norm': 0.002500680508092046, 'learning_rate': 6e-06, 'epoch': 46.0}


                                                   
 92%|█████████▏| 1380/1500 [08:08<00:39,  3.05it/s]

{'eval_loss': 6.192003638716415e-05, 'eval_runtime': 0.329, 'eval_samples_per_second': 179.322, 'eval_steps_per_second': 24.315, 'epoch': 46.0}


 93%|█████████▎| 1390/1500 [08:11<00:37,  2.93it/s]

{'loss': 0.0001, 'grad_norm': 0.0015319710364565253, 'learning_rate': 5.500000000000001e-06, 'epoch': 46.33}


 93%|█████████▎| 1400/1500 [08:14<00:33,  2.97it/s]

{'loss': 0.0001, 'grad_norm': 0.0015098094008862972, 'learning_rate': 5e-06, 'epoch': 46.67}


 94%|█████████▍| 1410/1500 [08:18<00:29,  3.06it/s]

{'loss': 0.0001, 'grad_norm': 0.0015109609812498093, 'learning_rate': 4.5e-06, 'epoch': 47.0}


                                                   
 94%|█████████▍| 1410/1500 [08:18<00:29,  3.06it/s]

{'eval_loss': 6.15118769928813e-05, 'eval_runtime': 0.3283, 'eval_samples_per_second': 179.732, 'eval_steps_per_second': 24.37, 'epoch': 47.0}


 95%|█████████▍| 1420/1500 [08:21<00:27,  2.93it/s]

{'loss': 0.0001, 'grad_norm': 0.001998133258894086, 'learning_rate': 4.000000000000001e-06, 'epoch': 47.33}


 95%|█████████▌| 1430/1500 [08:25<00:23,  2.97it/s]

{'loss': 0.0001, 'grad_norm': 0.0016069381963461637, 'learning_rate': 3.5000000000000004e-06, 'epoch': 47.67}


 96%|█████████▌| 1440/1500 [08:28<00:19,  3.06it/s]

{'loss': 0.0001, 'grad_norm': 0.002478568581864238, 'learning_rate': 3e-06, 'epoch': 48.0}


                                                   
 96%|█████████▌| 1440/1500 [08:28<00:19,  3.06it/s]

{'eval_loss': 6.120272155385464e-05, 'eval_runtime': 0.3279, 'eval_samples_per_second': 179.915, 'eval_steps_per_second': 24.395, 'epoch': 48.0}


 97%|█████████▋| 1450/1500 [08:32<00:17,  2.93it/s]

{'loss': 0.0001, 'grad_norm': 0.0013752616941928864, 'learning_rate': 2.5e-06, 'epoch': 48.33}


 97%|█████████▋| 1460/1500 [08:35<00:13,  2.97it/s]

{'loss': 0.0001, 'grad_norm': 0.0016493207076564431, 'learning_rate': 2.0000000000000003e-06, 'epoch': 48.67}


 98%|█████████▊| 1470/1500 [08:39<00:09,  3.06it/s]

{'loss': 0.0001, 'grad_norm': 0.0017246822826564312, 'learning_rate': 1.5e-06, 'epoch': 49.0}


                                                   
 98%|█████████▊| 1470/1500 [08:39<00:09,  3.06it/s]

{'eval_loss': 6.101480539655313e-05, 'eval_runtime': 0.3312, 'eval_samples_per_second': 178.151, 'eval_steps_per_second': 24.156, 'epoch': 49.0}


 99%|█████████▊| 1480/1500 [08:42<00:06,  2.93it/s]

{'loss': 0.0001, 'grad_norm': 0.0020352336578071117, 'learning_rate': 1.0000000000000002e-06, 'epoch': 49.33}


 99%|█████████▉| 1490/1500 [08:46<00:03,  2.97it/s]

{'loss': 0.0001, 'grad_norm': 0.0014875916531309485, 'learning_rate': 5.000000000000001e-07, 'epoch': 49.67}


100%|██████████| 1500/1500 [08:49<00:00,  3.06it/s]

{'loss': 0.0001, 'grad_norm': 0.001861209631897509, 'learning_rate': 0.0, 'epoch': 50.0}


                                                   
100%|██████████| 1500/1500 [08:52<00:00,  2.82it/s]


{'eval_loss': 6.0950140323257074e-05, 'eval_runtime': 0.3361, 'eval_samples_per_second': 175.54, 'eval_steps_per_second': 23.802, 'epoch': 50.0}
{'train_runtime': 532.224, 'train_samples_per_second': 21.983, 'train_steps_per_second': 2.818, 'train_loss': 0.06812195948363903, 'epoch': 50.0}


100%|██████████| 8/8 [00:00<00:00, 19.97it/s]

Query: when was Samuel L. Jackson born --> Predicted Tag: qna
Query: Can I see a photo of Tom Hanks? --> Predicted Tag: image
Query: Find a poster of Frozen 2 --> Predicted Tag: image
Query: birthplace of Christopher Nolan --> Predicted Tag: qna
Query: What does Daniel Radcliffe look like? --> Predicted Tag: image
Query: What does Matt Damon look like? --> Predicted Tag: image
Query: total revenue of Inception --> Predicted Tag: qna
Query: who was the producer of Star Wars --> Predicted Tag: qna
Query: who was the director of Inception --> Predicted Tag: qna
Query: Who is the screenwriter of The Masked Gang: Cyprus? --> Predicted Tag: qna
Query: Suggest films similar to Slumdog Millionaire. --> Predicted Tag: recommendation
Query: Display an image of Leonardo DiCaprio --> Predicted Tag: image
Query: what was the budget of Titanic --> Predicted Tag: qna
Query: What are some movies like The Silence of the Lambs? --> Predicted Tag: recommendation
Query: What are some films similar to The 




In [7]:
# Persist the fine-tuned classifier to disk so a later cell can reload it
# with `from_pretrained` from the same directory.
model.save_pretrained(save_directory="../data/final_classification_model")

In [10]:
# Reload the fine-tuned classifier from disk and classify one sample query.
model = BertForSequenceClassification.from_pretrained("../data/final_classification_model")
# The tokenizer is loaded by base-checkpoint name rather than from the saved model directory.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# Maps predicted class index -> tag name.
# NOTE(review): hard-coded here; it must match the label encoding used during
# training — confirm against the training cell.
id2tag = {0: 'qna', 1: 'recommendation', 2: 'image'}

# `from_pretrained` already returns the model in evaluation mode; calling
# `eval()` keeps this cell correct even if `model` was left in train mode.
model.eval()

# Test the model
query = "Who is the director of Inception?"
inputs = tokenizer(query, return_tensors="pt")
# Inference only: skip autograd bookkeeping so no graph is built for the forward pass.
with torch.no_grad():
    outputs = model(**inputs)
# A single query was tokenized, so row 0 of the logits is that query's scores.
predicted_label = torch.argmax(outputs.logits[0]).item()
predicted_tag = id2tag[predicted_label]
predicted_tag

'qna'

In [None]:
# Display the class-index -> tag mapping defined in the model-loading cell.
id2tag