In [1]:
import ast

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor

In [2]:
# Load the data
DATA_PATH = 'merged_dhia.csv'
SEED = 1

data = pd.read_csv(DATA_PATH)

# Shuffle the rows. Passing random_state explicitly makes this shuffle
# reproducible on its own instead of depending on the global NumPy seed having
# been set immediately before (an int random_state and a freshly seeded global
# generator produce the same permutation). The global seed is still set for any
# other NumPy-based randomness later in the notebook.
np.random.seed(SEED)
data = data.sample(frac=1, random_state=SEED).reset_index(drop=True)

In [3]:
def _parse_token_list(value):
    """Turn a stringified Python list such as "['a', 'b']" into a real list.

    Non-string values (e.g. lists left over from a previous run of this cell)
    are returned unchanged, so re-running the cell does not fail.
    ast.literal_eval only accepts Python literals. Unlike eval, it cannot
    execute arbitrary code that happens to be stored in the CSV.
    """
    if isinstance(value, str):
        return ast.literal_eval(value)
    return value


# Parse the `processed_quotes` column from string to list
data['processed_quotes'] = data['processed_quotes'].apply(_parse_token_list)

# Convert lists of tokens back into space-separated text for vectorization
data['text_data'] = data['processed_quotes'].apply(' '.join)

# Encode categories as integer codes 0..K-1, ordered by the sorted original
# labels (missing values become -1). This always re-encodes, even when the
# column is already numeric. Keep the code -> label mapping so the numeric
# classes in the evaluation report can be traced back to category names.
category_as_cat = data['Category'].astype('category')
category_labels = category_as_cat.cat.categories
data['Category'] = category_as_cat.cat.codes

# Split the dataset (80% train / 20% test, fixed seed for reproducibility)
train_one, test_one = train_test_split(data, test_size=0.2, random_state=1)


In [4]:
# Vectorize the `text_data` column
vectorizer = CountVectorizer()
X_train = vectorizer.fit_transform(train_one['text_data'])
X_test = vectorizer.transform(test_one['text_data'])


In [5]:
# Target column
y_train = train_one['Category']
y_test = test_one['Category']


In [6]:

# Category codes are nominal labels, so this is a classification task.
# KNeighborsClassifier takes a majority vote over the neighbours' labels.
# A regressor would instead average the codes, which produces meaningless
# in-between classes as soon as N_NEIGHBORS > 1. With N_NEIGHBORS = 1 both
# return the nearest neighbour's label, so results match the earlier run.
N_NEIGHBORS = 1

knn = KNeighborsClassifier(n_neighbors=N_NEIGHBORS, algorithm='auto')
knn.fit(X_train, y_train)  # X_train is CountVectorizer's sparse matrix, passed as-is
predictions = knn.predict(X_test)

# The classifier already returns labels seen in y_train, so no rounding or
# clipping is needed. Keep the integer array under the name the evaluation
# cell uses.
predicted_classes = np.asarray(predictions, dtype=int)
            

In [7]:
# Calculate accuracy
accuracy = accuracy_score(y_test, predicted_classes)

print(classification_report(y_test, predicted_classes))

              precision    recall  f1-score   support

           1       0.00      0.00      0.00        12
           3       1.00      0.07      0.13        14
           4       0.00      0.00      0.00        19
           5       0.00      0.00      0.00         7
           6       0.76      0.83      0.79        30
           8       0.26      0.45      0.33        20
           9       0.00      0.00      0.00         5
          10       1.00      0.20      0.33        10
          12       0.96      0.98      0.97       107
          13       0.17      0.10      0.12        10
          14       0.72      0.46      0.56        46
          15       0.45      0.57      0.51        35
          16       0.83      0.91      0.87        94
          17       0.56      0.43      0.49        21
          18       0.16      0.19      0.17        21
          19       0.90      0.62      0.73       105
          20       0.73      0.85      0.79        48
          21       0.06    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
