In [1]:
import json
import pandas as pd
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import time
from sklearn.naive_bayes import MultinomialNB
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import RidgeClassifier
from sklearn.multioutput import MultiOutputClassifier
from sklearn.metrics import f1_score, classification_report
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import MultiLabelBinarizer
import nltk
from nltk.corpus import stopwords
import re

# Load data
# JSON text is UTF-8 by specification. Without an explicit encoding, open()
# falls back to the locale codec (e.g. cp1252 on Windows), which can corrupt
# or fail to decode non-ASCII characters in titles and abstracts.
with open('arxiv_data.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

# Create DataFrame
# One row per paper; the three JSON fields are read as parallel sequences.
df = pd.DataFrame({
    'title': data['titles'],
    'abstract': data['summaries'],
    'labels': data['terms']
})

# Preprocess text
# Fetch the NLTK stopword corpus (a no-op when it is already up to date) and
# keep the English list as a set for fast membership tests.
nltk.download('stopwords')
stop_words = set(stopwords.words('english'))
def preprocess_text(text):
    """Normalize a piece of text for vectorization.

    Every non-word character is replaced by a space, the result is
    lowercased, and whitespace-separated tokens found in the module-level
    ``stop_words`` collection are dropped.

    Args:
        text: The raw string to clean.

    Returns:
        The remaining tokens joined by single spaces.
    """
    # Substitute first, then lowercase: same order as the stop-word check expects.
    normalized = re.sub(r'\W', ' ', text).lower()
    kept_tokens = filter(lambda token: token not in stop_words, normalized.split())
    return ' '.join(kept_tokens)

# Clean every abstract; the cleaned text overwrites the 'abstract' column.
df['abstract'] = df['abstract'].apply(preprocess_text)

# Convert labels to binary matrix
mlb = MultiLabelBinarizer()
y = mlb.fit_transform(df['labels'])

# Split dataset
# First hold out 15% for testing, then take 0.1765 of the remaining 85% for
# validation (0.1765 * 0.85 ~= 0.15), i.e. roughly a 70/15/15 split overall.
train_texts, test_texts, y_train, y_test = train_test_split(df['abstract'], y, test_size=0.15, random_state=42)
train_texts, val_texts, y_train, y_val = train_test_split(train_texts, y_train, test_size=0.1765, random_state=42)

# Vectorize text
# The vocabulary and IDF weights are learned from the training split only and
# then applied to validation/test. The matrices are densified so the same
# arrays can be passed to every estimator below
# (NOTE(review): histogram gradient boosting is the one expected to need dense input - confirm).
tfidf = TfidfVectorizer(max_features=500)
X_train = tfidf.fit_transform(train_texts).toarray()
X_val = tfidf.transform(val_texts).toarray()
X_test = tfidf.transform(test_texts).toarray()

# Define models
# HistGradientBoostingClassifier is seeded explicitly: its automatic early
# stopping uses a random internal validation split on large training sets, so
# without a fixed random_state its scores change from run to run.
models = {
    'Naive Bayes': MultinomialNB(),
    'Logistic Regression': LogisticRegression(max_iter=1000),
    'Hist Gradient Boosting': HistGradientBoostingClassifier(max_iter=100, random_state=42),
    'K-Nearest Neighbors': KNeighborsClassifier(n_neighbors=3),
    'Ridge Classifier': RidgeClassifier()
}

# Track metrics
train_times = {}
inference_times = {}
val_micro_f1_scores = {}
val_macro_f1_scores = {}
reports = {}

# Best fitted model seen so far, ranked by validation micro F1. Keeping the
# fitted estimator means the test report evaluates exactly the model that was
# validated, and avoids training the winner a second time.
best_model_name = None
best_model = None

# Train and evaluate models
total_models = len(models)
for idx, (name, model) in enumerate(models.items(), start=1):
    # Wrap the base estimator so it is fitted once per label column.
    multi_target_model = MultiOutputClassifier(model, n_jobs=-1)

    # Training time
    # perf_counter() is a monotonic, high-resolution clock intended for
    # measuring intervals; time.time() can jump if the system clock changes.
    start_time = time.perf_counter()
    multi_target_model.fit(X_train, y_train)
    train_times[name] = time.perf_counter() - start_time

    # Inference time
    start_time = time.perf_counter()
    y_val_pred = multi_target_model.predict(X_val)
    inference_times[name] = time.perf_counter() - start_time

    # Calculate micro and macro F1 scores
    micro_f1 = f1_score(y_val, y_val_pred, average='micro')
    macro_f1 = f1_score(y_val, y_val_pred, average='macro')
    val_micro_f1_scores[name] = micro_f1
    val_macro_f1_scores[name] = macro_f1

    print(f"{name} Validation Micro F1 Score: {micro_f1:.4f}")
    print(f"{name} Validation Macro F1 Score: {macro_f1:.4f}")

    # Generate classification report
    reports[name] = classification_report(y_val, y_val_pred, zero_division=0)
    print(f"{name} Validation Classification Report:\n{reports[name]}")

    # Remember the best model. The strict '>' keeps the earliest model on a
    # tie, which is the same choice max() over the score dict would make.
    if best_model is None or micro_f1 > val_micro_f1_scores[best_model_name]:
        best_model_name = name
        best_model = multi_target_model

    # Show progress
    progress = (idx / total_models) * 100
    print(f"Progress: {progress:.2f}% complete\n")

# An empty `models` dict leaves nothing to report on.
if best_model is None:
    raise ValueError("No models were evaluated; `models` is empty.")

# Generate the test report for the model with the best validation micro F1
# score, reusing the estimator fitted in the loop above.
y_test_pred = best_model.predict(X_test)
test_report = classification_report(y_test, y_test_pred, zero_division=0)
print(f"\nTest Classification Report ({best_model_name}):\n{test_report}")

# Display training and inference times
print("\nTraining and Inference Times:")
for name in models.keys():
    print(f"{name} - Training time: {train_times[name]:.4f} seconds, Inference time: {inference_times[name]:.4f} seconds")


[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\USER\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


Naive Bayes Validation Micro F1 Score: 0.7026
Naive Bayes Validation Macro F1 Score: 0.0299
Naive Bayes Validation Classification Report:
              precision    recall  f1-score   support

           0       0.00      0.00      0.00         4
           1       0.00      0.00      0.00        11
           2       0.00      0.00      0.00        10
           3       0.00      0.00      0.00         6
           4       0.55      0.16      0.25      1150
           5       0.00      0.00      0.00         5
           6       0.00      0.00      0.00         4
           7       0.00      0.00      0.00         9
           8       0.00      0.00      0.00        19
           9       0.00      0.00      0.00       238
          10       0.00      0.00      0.00       103
          11       0.89      0.92      0.91      4526
          12       0.00      0.00      0.00        30
          13       0.00      0.00      0.00        23
          14       0.00      0.00      0.00        

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


Logistic Regression Validation Micro F1 Score: 0.7216
Logistic Regression Validation Macro F1 Score: 0.0452
Logistic Regression Validation Classification Report:
              precision    recall  f1-score   support

           0       0.00      0.00      0.00         4
           1       0.00      0.00      0.00        11
           2       0.00      0.00      0.00        10
           3       0.00      0.00      0.00         6
           4       0.58      0.15      0.24      1150
           5       0.00      0.00      0.00         5
           6       0.00      0.00      0.00         4
           7       0.00      0.00      0.00         9
           8       0.00      0.00      0.00        19
           9       0.56      0.15      0.23       238
          10       0.42      0.08      0.13       103
          11       0.94      0.91      0.92      4526
          12       0.00      0.00      0.00        30
          13       0.00      0.00      0.00        23
          14       0.00    

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


K-Nearest Neighbors Validation Micro F1 Score: 0.6951
K-Nearest Neighbors Validation Macro F1 Score: 0.1300
K-Nearest Neighbors Validation Classification Report:
              precision    recall  f1-score   support

           0       0.00      0.00      0.00         4
           1       0.00      0.00      0.00        11
           2       1.00      0.20      0.33        10
           3       0.20      0.17      0.18         6
           4       0.46      0.36      0.41      1150
           5       0.00      0.00      0.00         5
           6       0.00      0.00      0.00         4
           7       0.00      0.00      0.00         9
           8       0.56      0.26      0.36        19
           9       0.50      0.32      0.39       238
          10       0.38      0.15      0.21       103
          11       0.89      0.88      0.88      4526
          12       0.14      0.03      0.05        30
          13       1.00      0.04      0.08        23
          14       0.33    

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


Ridge Classifier Validation Micro F1 Score: 0.7131
Ridge Classifier Validation Macro F1 Score: 0.0310
Ridge Classifier Validation Classification Report:
              precision    recall  f1-score   support

           0       0.00      0.00      0.00         4
           1       0.00      0.00      0.00        11
           2       0.00      0.00      0.00        10
           3       0.00      0.00      0.00         6
           4       0.59      0.14      0.22      1150
           5       0.00      0.00      0.00         5
           6       0.00      0.00      0.00         4
           7       0.00      0.00      0.00         9
           8       0.00      0.00      0.00        19
           9       0.55      0.05      0.09       238
          10       0.00      0.00      0.00       103
          11       0.93      0.90      0.92      4526
          12       0.00      0.00      0.00        30
          13       0.00      0.00      0.00        23
          14       0.00      0.00   