# Setup

In [6]:
!pip install langchain_groq langchain_core langchain_huggingface

Collecting langchain_groq
  Downloading langchain_groq-0.2.1-py3-none-any.whl.metadata (2.9 kB)
Collecting langchain_core
  Downloading langchain_core-0.3.22-py3-none-any.whl.metadata (6.3 kB)
Collecting langchain_huggingface
  Downloading langchain_huggingface-0.1.2-py3-none-any.whl.metadata (1.3 kB)
Collecting groq<1,>=0.4.1 (from langchain_groq)
  Downloading groq-0.13.0-py3-none-any.whl.metadata (13 kB)
Collecting langsmith<0.2.0,>=0.1.125 (from langchain_core)
  Downloading langsmith-0.1.147-py3-none-any.whl.metadata (14 kB)
Collecting packaging<25,>=23.2 (from langchain_core)
  Downloading packaging-24.2-py3-none-any.whl.metadata (3.2 kB)
Collecting sentence-transformers>=2.6.0 (from langchain_huggingface)
  Downloading sentence_transformers-3.3.1-py3-none-any.whl.metadata (10 kB)
Collecting requests-toolbelt<2.0.0,>=1.0.0 (from langsmith<0.2.0,>=0.1.125->langchain_core)
  Downloading requests_toolbelt-1.0.0-py2.py3-none-any.whl.metadata (14 kB)
Downloading langchain_groq-0.2.1-p

In [7]:
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

## Data Loading

In [8]:
data = pd.read_csv("/kaggle/input/tubes-nlp/seq2seq_data.csv")
data

Unnamed: 0,topic_category,original_text,base_word_text
0,9.0,what makes friendship click?,what make friendship click
1,2.0,why does zebras have stripes?,why zebra stripe
2,4.0,what did the itsy bitsy sipder climb up?,what itsy bitsy sipder climb up
3,4.0,what is the difference between a bachelors and...,what difference between bachelor and master de...
4,3.0,why do women get pms?,why woman get pm
...,...,...,...
174712,9.0,imperative: tell me what guys only guys must do!,tell me what guy only guy must
174713,9.0,tell me the story of any fantasy figure i'd ch...,tell me story of any fantasy figure i d choose
174714,8.0,imperative: reveal a secret about life.,reveal secret about life
174715,6.0,imperative: demande à domenech ce qu'il en est...,demande à domenech ce quil en est de son méti...


In [25]:
data["topic_category"] = data["topic_category"]-1
data["topic_category"]

0         8.0
1         1.0
2         3.0
3         3.0
4         2.0
         ... 
174712    8.0
174713    8.0
174714    7.0
174715    5.0
174716    4.0
Name: topic_category, Length: 173907, dtype: float64

# Data Preparation

## Data Cleaning

In [26]:
data.dropna(inplace=True)
data.isna().sum()

topic_category    0
original_text     0
base_word_text    0
processed_text    0
dtype: int64

## Data Preprocessing

In [27]:
import string
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
import nltk

# Ensure you have the necessary NLTK data files
nltk.download('punkt')
nltk.download('stopwords')

# Define stopwords and punctuation
# stop_words = set(stopwords.words('english'))
stop_words = set()
stop_words.update(["imperative", "declarative"])
punctuation = string.punctuation

# Function to preprocess text
def preprocess_text(text):
    # Convert to lowercase
    text = text.lower()
    # Remove punctuation
    text = text.translate(str.maketrans('', '', punctuation))
    # Tokenize text
    words = word_tokenize(text)
    # Remove stopwords
    words = [word for word in words if word not in stop_words]
    return ' '.join(words)

# Apply the preprocessing function to the 'original_text' column
data['processed_text'] = data['original_text'].apply(preprocess_text)
data[['original_text', 'processed_text']].head()

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /usr/share/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


Unnamed: 0,original_text,processed_text
0,what makes friendship click?,what makes friendship click
1,why does zebras have stripes?,why does zebras have stripes
2,what did the itsy bitsy sipder climb up?,what did the itsy bitsy sipder climb up
3,what is the difference between a bachelors and...,what is the difference between a bachelors and...
4,why do women get pms?,why do women get pms


## Data Splitting

In [28]:
test_ratio = 0.01  
instances_per_class = int(len(data) * test_ratio / 10)  # Calculate instances per class

# Sample data equally for each class
test_data = data.groupby('topic_category').sample(n=instances_per_class, random_state=42)

# Ensure balanced test set
print(test_data['topic_category'].value_counts())

topic_category
0.0    173
1.0    173
2.0    173
3.0    173
4.0    173
5.0    173
6.0    173
7.0    173
8.0    173
9.0    173
Name: count, dtype: int64


In [29]:
X_test = test_data['processed_text']
y_test = test_data['topic_category']

# Model Development

## QWEN

In [None]:
qwen_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")

In [None]:
# Check if CUDA is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Move the model to the GPU
qwen_model.to(device)

In [None]:
# def generate_answer(is_zero_shot:bool):

#   generated_texts = []

#   for i in tqdm(range(len(data["original_text"]))):
#     text = data["original_text"][i]

#     if(is_zero_shot):
#       prompt = f"""Classify this text to positive or negative sentiment. text: {text} sentiment:"""
#     else:
#       prompt = f"""Classify this text to positive or negative sentiment. for example
#                   text:
#                   text: {text} sentiment:"""

#     inputs = tokenizer(prompt, return_tensors="pt").to(device)
#     outputs = model.generate(**inputs, max_new_tokens=10)

#     generated_text = (tokenizer.decode(outputs[0], skip_special_tokens=True)).replace(prompt, "")
#     generated_texts.append(generated_text)

#   return generated_texts

## Gemma

In [13]:
from langchain_huggingface import HuggingFaceEmbeddings

embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.7k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

In [14]:
from langchain_core.messages import HumanMessage, SystemMessage

def format_prompt(text):
  messages = [
      SystemMessage(content="""You are a model that classify the topic from a text. 
            You only classify the text to 10 topics below.
            0: Society & Culture
            1: Science & Mathematics
            2: Health
            3: Education & Reference
            4: Computers & Internet
            5: Sports
            6: Business & Finance
            7: Entertainment & Music
            8: Family & Relationships
            9: Politics & Government

            Answer it in this format
            text: what did the itsy bitsy spider climb up
            topic: 3

            text: why do women get pms
            topic: 2
      
      """),
      HumanMessage(content=f"text: {text} topic:")
  ]
  return messages

In [15]:
from langchain_groq import ChatGroq
from kaggle_secrets import UserSecretsClient

chat_model = ChatGroq(
    api_key=  UserSecretsClient().get_secret("GROQ_API_KEY"),
    model="llama3-70b-8192",
    temperature=0,
    max_tokens=512,
)

In [16]:
from langchain_core.output_parsers import StrOutputParser

In [33]:
gemma_classifier = (
    format_prompt
    | chat_model
    | StrOutputParser()
)

In [34]:
response = gemma_classifier.invoke("tell me what guys only guys must do")

In [35]:
response.split("topic:")

['text: tell me what guys only guys must do\n', ' 8']

# Evaluation

## Inference

In [32]:
from tqdm import tqdm

### Gemma

In [36]:
gemma_predictions = []
for text in tqdm(X_test):
    gemma_predictions.append(gemma_classifier.invoke(text))
gemma_predictions

100%|██████████| 1730/1730 [1:04:19<00:00,  2.23s/it]


['topic: 9',
 'topic: 0',
 'topic: 7',
 'text: if life everafter is taken out of religious ideologywhat happens to amount of believers\ntopic: 0',
 'topic: 0',
 'topic: 0',
 'topic: 3',
 'topic: 1',
 'topic: 0',
 'topic: 0',
 'topic: 0',
 'text: could you tell me clever answers to any kind of insult without swear words\ntopic: 0',
 'topic: 4',
 'topic: 9',
 'topic: 0',
 'topic: 0',
 'topic: 0',
 'topic: 1',
 'topic: 4',
 'text: how do u grade instinct intellect and intuition\ntopic: 3',
 'topic: 3',
 'topic: 0',
 'text: how many ppl use myspace\ntopic: 4',
 'topic: 0',
 'topic: 0',
 'topic: 0',
 'text: which group is the coolest african americans european americans or van halen\ntopic: 7',
 'topic: 0',
 'text: does anyone have any brand new toys that would like to donate\ntopic: 8',
 'text: is it right for the nativity story to label itself based on a true story in the movie trailer\ntopic: 0',
 'text: what is the dialect spoken in perugia italy\ntopic: 0',
 'topic: 9',
 'topic: 1',
 '

## Output Parsing

In [37]:
existing_labels = [str(label) for label in range(10)]
def parse_output_to_label(output: str) -> int:
    output_label = output.split("topic: ")
    for label in existing_labels:
        if label in output_label:
            return int(label)
    return 0

### Gemma

In [38]:
gemma_label_predictions = [parse_output_to_label(output) for output in gemma_predictions]
gemma_label_predictions

[9,
 0,
 7,
 0,
 0,
 0,
 3,
 1,
 0,
 0,
 0,
 0,
 4,
 9,
 0,
 0,
 0,
 1,
 4,
 3,
 3,
 0,
 4,
 0,
 0,
 0,
 7,
 0,
 8,
 0,
 0,
 9,
 1,
 0,
 0,
 0,
 9,
 0,
 1,
 0,
 8,
 0,
 8,
 0,
 8,
 2,
 0,
 0,
 0,
 0,
 0,
 7,
 1,
 7,
 0,
 0,
 0,
 0,
 7,
 7,
 8,
 3,
 0,
 7,
 0,
 6,
 9,
 1,
 9,
 0,
 0,
 3,
 0,
 6,
 0,
 0,
 0,
 0,
 7,
 3,
 0,
 0,
 8,
 1,
 9,
 0,
 0,
 8,
 9,
 0,
 8,
 7,
 8,
 0,
 7,
 0,
 7,
 9,
 2,
 0,
 0,
 7,
 6,
 7,
 0,
 0,
 7,
 2,
 0,
 0,
 0,
 0,
 4,
 9,
 0,
 9,
 0,
 0,
 8,
 0,
 8,
 0,
 0,
 0,
 1,
 4,
 0,
 9,
 0,
 0,
 9,
 0,
 4,
 0,
 8,
 0,
 3,
 7,
 7,
 0,
 3,
 6,
 0,
 1,
 3,
 0,
 1,
 7,
 0,
 6,
 3,
 0,
 1,
 0,
 0,
 0,
 0,
 2,
 0,
 8,
 3,
 8,
 1,
 8,
 0,
 0,
 1,
 7,
 1,
 7,
 3,
 7,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 3,
 1,
 3,
 7,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 3,
 4,
 2,
 3,
 1,
 1,
 2,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 7,
 1,
 2,
 1,
 1,
 1,
 1,
 0,
 1,
 1,
 1,
 1,
 1,
 7,
 2,
 1,
 1,
 1,
 5,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 3,
 1,
 0,
 1,
 2,
 1,
 0,
 1,
 1,


## Metric Evaluation

In [39]:
from sklearn.metrics import classification_report

### Gemma

In [40]:
print("Classification Report:\n", classification_report(y_test, gemma_label_predictions))

Classification Report:
               precision    recall  f1-score   support

         0.0       0.42      0.49      0.45       173
         1.0       0.64      0.74      0.69       173
         2.0       0.73      0.83      0.78       173
         3.0       0.51      0.35      0.42       173
         4.0       0.73      0.89      0.80       173
         5.0       0.91      0.77      0.84       173
         6.0       0.56      0.46      0.50       173
         7.0       0.61      0.76      0.68       173
         8.0       0.68      0.64      0.66       173
         9.0       0.77      0.61      0.68       173

    accuracy                           0.65      1730
   macro avg       0.66      0.65      0.65      1730
weighted avg       0.66      0.65      0.65      1730

