In [None]:
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package punkt to /home/jeopardy/nltk_data...\n",
      "[nltk_data]   Package punkt is already up-to-date!\n",
      "[nltk_data] Downloading package stopwords to\n",
      "[nltk_data]     /home/jeopardy/nltk_data...\n",
      "[nltk_data]   Package stopwords is already up-to-date!\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import numpy as np \n",
    "import pandas as pd\n",
    "import nltk\n",
    "import string\n",
    "from nltk.stem.porter import PorterStemmer\n",
    "import re\n",
    "# Fetch the NLTK resources this notebook relies on.\n",
    "for nltk_package in ('punkt', 'stopwords'):\n",
    "    nltk.download(nltk_package)\n",
    "\n",
    "# Shared Porter stemmer and English stop-word set.\n",
    "stemmer = PorterStemmer()\n",
    "english_stopword_list = nltk.corpus.stopwords.words('english')\n",
    "stopwords = set(english_stopword_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ratio of duplicate questions in the splits\n",
      "Train set:  0.36919749967314835\n",
      "Validation set:  0.36920279997031835\n",
      "Test set:  0.369190432610255\n"
     ]
    }
   ],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Load the question pairs and drop rows with any missing field.\n",
    "# dropna() returns a new frame, so its result has to be assigned; calling it\n",
    "# without assignment leaves the NaN rows in the data.\n",
    "raw_df = pd.read_csv('questions.csv.zip')\n",
    "df = raw_df.dropna(how=\"any\").reset_index(drop=True)\n",
    "\n",
    "# Stratified train/validation/test split, 70:20:10.\n",
    "# Step 1 holds out 30% of the rows; step 2 splits that hold-out 2:1 into\n",
    "# validation (20% of all rows) and test (10% of all rows).\n",
    "X_train_q1, X_hold_q1, X_train_q2, X_hold_q2, y_train, y_hold = train_test_split(\n",
    "    df['question1'], df['question2'], df['is_duplicate'],\n",
    "    test_size=0.3, random_state=42, stratify=df['is_duplicate'])\n",
    "X_val_q1, X_test_q1, X_val_q2, X_test_q2, y_val, y_test = train_test_split(\n",
    "    X_hold_q1, X_hold_q2, y_hold,\n",
    "    test_size=(1/3), random_state=42, stratify=y_hold)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ratio of duplicate questions in the splits\n",
      "Train set:  0.36919749967314835\n",
      "Validation set:  0.36920279997031835\n",
      "Test set:  0.369190432610255\n"
     ]
    }
   ],
   "source": [
    "# Turn the label containers into plain NumPy arrays.\n",
    "y_train, y_val, y_test = (np.array(labels) for labels in (y_train, y_val, y_test))\n",
    "\n",
    "# Sum of the is_duplicate labels divided by the row count, per split.\n",
    "print(\"Ratio of duplicate questions in the splits\")\n",
    "for split_label, split_targets in (\n",
    "        (\"Train set: \", y_train),\n",
    "        (\"Validation set: \", y_val),\n",
    "        (\"Test set: \", y_test)):\n",
    "    print(split_label, split_targets.sum()/len(split_targets))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Locations of the persisted train/validation/test splits (zipped CSV files).\n",
    "INPUT = './input/'\n",
    "TRAIN_LINEAR_PATH, TEST_LINEAR_PATH, VAL_LINEAR_PATH = (\n",
    "    INPUT + split_file\n",
    "    for split_file in ('train_linear.csv.zip', 'test_linear.csv.zip', 'val_linear.csv.zip')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assemble one frame per split from its two question columns and its labels.\n",
    "train_linear, val_linear, test_linear = (\n",
    "    pd.DataFrame({'question1': q1, 'question2': q2, 'is_duplicate': labels})\n",
    "    for q1, q2, labels in (\n",
    "        (X_train_q1, X_train_q2, y_train),\n",
    "        (X_val_q1, X_val_q2, y_val),\n",
    "        (X_test_q1, X_test_q2, y_test),\n",
    "    )\n",
    ")\n",
    "\n",
    "# Every training question (both columns, question1 first) as one string Series.\n",
    "train_question_columns = (train_linear['question1'], train_linear['question2'])\n",
    "allQuestions = pd.concat(train_question_columns).reset_index(drop=True).astype(str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write each split to its zipped CSV under INPUT (created if missing).\n",
    "os.makedirs(INPUT, exist_ok=True)\n",
    "for split_frame, split_path in (\n",
    "        (train_linear, TRAIN_LINEAR_PATH),\n",
    "        (val_linear, VAL_LINEAR_PATH),\n",
    "        (test_linear, TEST_LINEAR_PATH)):\n",
    "    split_frame.to_csv(split_path, index=False, compression='zip')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the three splits back from disk.\n",
    "train_linear, val_linear, test_linear = (\n",
    "    pd.read_csv(split_path)\n",
    "    for split_path in (TRAIN_LINEAR_PATH, VAL_LINEAR_PATH, TEST_LINEAR_PATH)\n",
    ")\n",
    "\n",
    "def _split_arrays(frame):\n",
    "    \"\"\"Return (question1, question2, is_duplicate) of `frame` as arrays; both question columns are cast with astype('U').\"\"\"\n",
    "    question_arrays = [frame[col].astype('U').values for col in ('question1', 'question2')]\n",
    "    return question_arrays[0], question_arrays[1], frame['is_duplicate'].values\n",
    "\n",
    "X_train_q1, X_train_q2, y_train = _split_arrays(train_linear)\n",
    "X_val_q1, X_val_q2, y_val = _split_arrays(val_linear)\n",
    "X_test_q1, X_test_q2, y_test = _split_arrays(test_linear)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.feature_extraction.text import CountVectorizer\n",
    "from sklearn.model_selection import train_test_split, StratifiedKFold, GridSearchCV\n",
    "from sklearn.linear_model import SGDClassifier\n",
    "from sklearn.svm import SVC, LinearSVC\n",
    "from sklearn.metrics import f1_score\n",
    "from sklearn.metrics import accuracy_score\n",
    "from scipy.sparse import hstack as sparse_hstack, vstack as sparse_vstack, save_npz, load_npz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "stemmer = PorterStemmer()\n",
    "\n",
    "def tokenize(text: str) -> list[str]:\n",
    "    \"\"\"Split `text` into Porter-stemmed tokens with English stop words removed.\n",
    "\n",
    "    Runs of non-ASCII characters are replaced by a space before tokenizing with\n",
    "    nltk.word_tokenize. A token is dropped when either its lower-cased surface\n",
    "    form or its stem is in `stopwords`. The surface-form check matters because\n",
    "    the stop-word list holds unstemmed words, while the Porter stems of several\n",
    "    of them (e.g. 'this' -> 'thi', 'was' -> 'wa') are not in the list, so a\n",
    "    stem-only check lets those stop words through.\n",
    "\n",
    "    Args:\n",
    "        text: Raw question text.\n",
    "\n",
    "    Returns:\n",
    "        The stems of the remaining tokens, in their original order.\n",
    "    \"\"\"\n",
    "    ascii_text = re.sub(r'[^\\x00-\\x7F]+', ' ', text)\n",
    "    kept_stems = []\n",
    "    for token in nltk.word_tokenize(ascii_text):\n",
    "        if token.lower() in stopwords:\n",
    "            continue\n",
    "        stem = stemmer.stem(token)  # stem once; reused for the check and the output\n",
    "        if stem not in stopwords:\n",
    "            kept_stems.append(stem)\n",
    "    return kept_stems\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output directories for the saved n-gram feature matrices.\n",
    "N_GRAMS_PATH = './n_gram_features/'\n",
    "UNIGRAM_PATH = f'{N_GRAMS_PATH}unigrams_linear/'\n",
    "BIGRAM_PATH = f'{N_GRAMS_PATH}bigrams_linear/'\n",
    "TRIGRAM_PATH = f'{N_GRAMS_PATH}trigrams_linear/'\n",
    "\n",
    "# Create the parent first, then one sub-directory per n-gram order.\n",
    "for feature_dir in (N_GRAMS_PATH, UNIGRAM_PATH, BIGRAM_PATH, TRIGRAM_PATH):\n",
    "    os.makedirs(feature_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating Unigram features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count features over single tokens produced by tokenize(); vocabulary is fitted on allQuestions.\n",
    "unigramVectorizer = CountVectorizer(\n",
    "    tokenizer=tokenize,\n",
    "    lowercase=True,\n",
    "    ngram_range=(1, 1),\n",
    "    analyzer='word',\n",
    ")\n",
    "unigramVectorizer.fit(allQuestions)\n",
    "\n",
    "question_cols = ('question1', 'question2')\n",
    "\n",
    "# For every split: vectorize each question column, then place the two blocks side by side.\n",
    "q1_train, q2_train = (unigramVectorizer.transform(train_linear[col].astype(str)) for col in question_cols)\n",
    "X_train_unigram = sparse_hstack([q1_train, q2_train])\n",
    "q1_val, q2_val = (unigramVectorizer.transform(val_linear[col].astype(str)) for col in question_cols)\n",
    "X_val_unigram = sparse_hstack([q1_val, q2_val])\n",
    "q1_test, q2_test = (unigramVectorizer.transform(test_linear[col].astype(str)) for col in question_cols)\n",
    "X_test_unigram = sparse_hstack([q1_test, q2_test])\n",
    "\n",
    "# Save the stacked matrices under UNIGRAM_PATH.\n",
    "for npz_name, feature_matrix in (\n",
    "        (\"train.npz\", X_train_unigram),\n",
    "        (\"val.npz\", X_val_unigram),\n",
    "        (\"test.npz\", X_test_unigram)):\n",
    "    save_npz(UNIGRAM_PATH + npz_name, feature_matrix)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating Bigram features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count features over 1- and 2-token n-grams (ngram_range=(1, 2)); vocabulary is fitted on allQuestions.\n",
    "bigramVectorizer = CountVectorizer(\n",
    "    tokenizer=tokenize,\n",
    "    lowercase=True,\n",
    "    ngram_range=(1, 2),\n",
    "    analyzer='word',\n",
    ")\n",
    "bigramVectorizer.fit(allQuestions)\n",
    "\n",
    "question_cols = ('question1', 'question2')\n",
    "\n",
    "# For every split: vectorize each question column, then place the two blocks side by side.\n",
    "q1_train, q2_train = (bigramVectorizer.transform(train_linear[col].astype(str)) for col in question_cols)\n",
    "X_train_bigram = sparse_hstack([q1_train, q2_train])\n",
    "q1_val, q2_val = (bigramVectorizer.transform(val_linear[col].astype(str)) for col in question_cols)\n",
    "X_val_bigram = sparse_hstack([q1_val, q2_val])\n",
    "q1_test, q2_test = (bigramVectorizer.transform(test_linear[col].astype(str)) for col in question_cols)\n",
    "X_test_bigram = sparse_hstack([q1_test, q2_test])\n",
    "\n",
    "# Save the stacked matrices under BIGRAM_PATH.\n",
    "for npz_name, feature_matrix in (\n",
    "        (\"train.npz\", X_train_bigram),\n",
    "        (\"val.npz\", X_val_bigram),\n",
    "        (\"test.npz\", X_test_bigram)):\n",
    "    save_npz(BIGRAM_PATH + npz_name, feature_matrix)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating Trigram features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count features over 1- to 3-token n-grams (ngram_range=(1, 3)); vocabulary is fitted on allQuestions.\n",
    "trigramVectorizer = CountVectorizer(\n",
    "    tokenizer=tokenize,\n",
    "    lowercase=True,\n",
    "    ngram_range=(1, 3),\n",
    "    analyzer='word',\n",
    ")\n",
    "trigramVectorizer.fit(allQuestions)\n",
    "\n",
    "question_cols = ('question1', 'question2')\n",
    "\n",
    "# For every split: vectorize each question column, then place the two blocks side by side.\n",
    "q1_train, q2_train = (trigramVectorizer.transform(train_linear[col].astype(str)) for col in question_cols)\n",
    "X_train_trigram = sparse_hstack([q1_train, q2_train])\n",
    "q1_val, q2_val = (trigramVectorizer.transform(val_linear[col].astype(str)) for col in question_cols)\n",
    "X_val_trigram = sparse_hstack([q1_val, q2_val])\n",
    "q1_test, q2_test = (trigramVectorizer.transform(test_linear[col].astype(str)) for col in question_cols)\n",
    "X_test_trigram = sparse_hstack([q1_test, q2_test])\n",
    "\n",
    "# Save the stacked matrices under TRIGRAM_PATH.\n",
    "for npz_name, feature_matrix in (\n",
    "        (\"train.npz\", X_train_trigram),\n",
    "        (\"val.npz\", X_val_trigram),\n",
    "        (\"test.npz\", X_test_trigram)):\n",
    "    save_npz(TRIGRAM_PATH + npz_name, feature_matrix)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Logistic Regression"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Unigrams"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Unigram Logistic Regression Accuracy:  0.7418437260382399\n",
      "Unigram Logistic Regression F1 Score:  0.6310840903467534\n"
     ]
    }
   ],
   "source": [
    "# The saved matrices can be reloaded instead of rebuilt:\n",
    "# X_train_unigram = load_npz(UNIGRAM_PATH + \"train.npz\")\n",
    "# X_test_unigram = load_npz(UNIGRAM_PATH + \"test.npz\")\n",
    "\n",
    "# SGD-trained linear classifier with log loss and an L2 penalty on the unigram features.\n",
    "unigram_sgd_settings = {\n",
    "    'loss': 'log_loss',\n",
    "    'penalty': 'l2',\n",
    "    'alpha': 0.00001,\n",
    "    'max_iter': 1000,\n",
    "    'n_iter_no_change': 20,\n",
    "    'learning_rate': 'optimal',\n",
    "    'n_jobs': -1,\n",
    "    'random_state': 42,\n",
    "}\n",
    "unigramLogisticRegressor = SGDClassifier(**unigram_sgd_settings)\n",
    "unigramLogisticRegressor.fit(X_train_unigram, y_train)\n",
    "\n",
    "# Predict on the test split and report accuracy and F1.\n",
    "y_pred_unigram_logistic = unigramLogisticRegressor.predict(X_test_unigram)\n",
    "for metric_label, metric_fn in (\n",
    "        (\"Unigram Logistic Regression Accuracy: \", accuracy_score),\n",
    "        (\"Unigram Logistic Regression F1 Score: \", f1_score)):\n",
    "    print(metric_label, metric_fn(y_test, y_pred_unigram_logistic))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Bigrams"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "bigram Logistic Regression Accuracy:  0.7962106408765984\n",
      "bigram Logistic Regression F1 Score:  0.7066405554566494\n"
     ]
    }
   ],
   "source": [
    "# The saved matrices can be reloaded instead of rebuilt:\n",
    "# X_train_bigram = load_npz(BIGRAM_PATH + \"train.npz\")\n",
    "# X_test_bigram = load_npz(BIGRAM_PATH + \"test.npz\")\n",
    "\n",
    "# SGD-trained linear classifier with log loss and an L2 penalty on the bigram features.\n",
    "bigram_sgd_settings = {\n",
    "    'loss': 'log_loss',\n",
    "    'penalty': 'l2',\n",
    "    'alpha': 0.00001,\n",
    "    'max_iter': 1000,\n",
    "    'n_iter_no_change': 20,\n",
    "    'learning_rate': 'optimal',\n",
    "    'n_jobs': -1,\n",
    "    'random_state': 42,\n",
    "}\n",
    "bigramLogisticRegressor = SGDClassifier(**bigram_sgd_settings)\n",
    "bigramLogisticRegressor.fit(X_train_bigram, y_train)\n",
    "\n",
    "# Predict on the test split and report accuracy and F1.\n",
    "y_pred_bigram_logistic = bigramLogisticRegressor.predict(X_test_bigram)\n",
    "for metric_label, metric_fn in (\n",
    "        (\"bigram Logistic Regression Accuracy: \", accuracy_score),\n",
    "        (\"bigram Logistic Regression F1 Score: \", f1_score)):\n",
    "    print(metric_label, metric_fn(y_test, y_pred_bigram_logistic))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Trigrams"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trigram Logistic Regression Accuracy:  0.8114472284746098\n",
      "trigram Logistic Regression F1 Score:  0.7147614593077642\n"
     ]
    }
   ],
   "source": [
    "# The saved matrices can be reloaded instead of rebuilt:\n",
    "# X_train_trigram = load_npz(TRIGRAM_PATH + \"train.npz\")\n",
    "# X_test_trigram = load_npz(TRIGRAM_PATH + \"test.npz\")\n",
    "\n",
    "# SGD-trained linear classifier with log loss and an L2 penalty on the trigram features.\n",
    "trigram_sgd_settings = {\n",
    "    'loss': 'log_loss',\n",
    "    'penalty': 'l2',\n",
    "    'alpha': 0.00001,\n",
    "    'max_iter': 1000,\n",
    "    'n_iter_no_change': 20,\n",
    "    'learning_rate': 'optimal',\n",
    "    'n_jobs': -1,\n",
    "    'random_state': 42,\n",
    "}\n",
    "trigramLogisticRegressor = SGDClassifier(**trigram_sgd_settings)\n",
    "trigramLogisticRegressor.fit(X_train_trigram, y_train)\n",
    "\n",
    "# Predict on the test split and report accuracy and F1.\n",
    "y_pred_trigram_logistic = trigramLogisticRegressor.predict(X_test_trigram)\n",
    "for metric_label, metric_fn in (\n",
    "        (\"trigram Logistic Regression Accuracy: \", accuracy_score),\n",
    "        (\"trigram Logistic Regression F1 Score: \", f1_score)):\n",
    "    print(metric_label, metric_fn(y_test, y_pred_trigram_logistic))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Trigrams Tuned"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
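    "Apply `GridSearchCV` to the trigram model to find the best combination of `alpha` and `n_iter_no_change`, using 5-fold stratified cross-validation on the combined train and validation sets and refitting on the best F1 score."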
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>#sk-container-id-4 {color: black;background-color: white;}#sk-container-id-4 pre{padding: 0;}#sk-container-id-4 div.sk-toggleable {background-color: white;}#sk-container-id-4 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-4 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-4 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-4 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-4 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-4 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-4 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-4 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-4 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-4 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-4 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-4 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-4 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-4 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-4 
div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-4 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-4 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-4 div.sk-item {position: relative;z-index: 1;}#sk-container-id-4 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-4 div.sk-item::before, #sk-container-id-4 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-4 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-4 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-4 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-4 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-4 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-4 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-4 div.sk-label-container {text-align: center;}#sk-container-id-4 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-4 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-4\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>GridSearchCV(cv=StratifiedKFold(n_splits=5, random_state=None, shuffle=False),\n",
       "             estimator=SGDClassifier(loss=&#x27;log_loss&#x27;, n_jobs=-1,\n",
       "                                     random_state=42),\n",
       "             n_jobs=-1,\n",
       "             param_grid={&#x27;alpha&#x27;: [0.01, 0.001, 0.0001, 1e-05, 1e-06],\n",
       "                         &#x27;n_iter_no_change&#x27;: [5, 10, 15, 20]},\n",
       "             refit=&#x27;f1&#x27;, scoring=[&#x27;accuracy&#x27;, &#x27;f1&#x27;])</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-10\" type=\"checkbox\" ><label for=\"sk-estimator-id-10\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">GridSearchCV</label><div class=\"sk-toggleable__content\"><pre>GridSearchCV(cv=StratifiedKFold(n_splits=5, random_state=None, shuffle=False),\n",
       "             estimator=SGDClassifier(loss=&#x27;log_loss&#x27;, n_jobs=-1,\n",
       "                                     random_state=42),\n",
       "             n_jobs=-1,\n",
       "             param_grid={&#x27;alpha&#x27;: [0.01, 0.001, 0.0001, 1e-05, 1e-06],\n",
       "                         &#x27;n_iter_no_change&#x27;: [5, 10, 15, 20]},\n",
       "             refit=&#x27;f1&#x27;, scoring=[&#x27;accuracy&#x27;, &#x27;f1&#x27;])</pre></div></div></div><div class=\"sk-parallel\"><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-11\" type=\"checkbox\" ><label for=\"sk-estimator-id-11\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">estimator: SGDClassifier</label><div class=\"sk-toggleable__content\"><pre>SGDClassifier(loss=&#x27;log_loss&#x27;, n_jobs=-1, random_state=42)</pre></div></div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-12\" type=\"checkbox\" ><label for=\"sk-estimator-id-12\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">SGDClassifier</label><div class=\"sk-toggleable__content\"><pre>SGDClassifier(loss=&#x27;log_loss&#x27;, n_jobs=-1, random_state=42)</pre></div></div></div></div></div></div></div></div></div></div>"
      ],
      "text/plain": [
       "GridSearchCV(cv=StratifiedKFold(n_splits=5, random_state=None, shuffle=False),\n",
       "             estimator=SGDClassifier(loss='log_loss', n_jobs=-1,\n",
       "                                     random_state=42),\n",
       "             n_jobs=-1,\n",
       "             param_grid={'alpha': [0.01, 0.001, 0.0001, 1e-05, 1e-06],\n",
       "                         'n_iter_no_change': [5, 10, 15, 20]},\n",
       "             refit='f1', scoring=['accuracy', 'f1'])"
      ]
     },
     "execution_count": 81,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Base estimator; alpha and n_iter_no_change are supplied by the grid below.\n",
    "trigramLogisticRegressor = SGDClassifier(\n",
    "    loss='log_loss',\n",
    "    penalty='l2',\n",
    "    max_iter=1000,\n",
    "    learning_rate='optimal',\n",
    "    n_jobs=-1,\n",
    "    random_state=42,\n",
    ")\n",
    "\n",
    "# 5 alpha values x 4 n_iter_no_change values.\n",
    "parameters = {\n",
    "    'alpha': [0.01, 0.001, 0.0001, 0.00001, 0.000001],\n",
    "    'n_iter_no_change': [5, 10, 15, 20],\n",
    "}\n",
    "cv_stratified_splitter = StratifiedKFold(n_splits=5)\n",
    "\n",
    "# Both metrics are scored; the final refit uses the best F1 configuration.\n",
    "grid_search = GridSearchCV(\n",
    "    estimator=trigramLogisticRegressor,\n",
    "    param_grid=parameters,\n",
    "    scoring=['accuracy', 'f1'],\n",
    "    refit='f1',\n",
    "    cv=cv_stratified_splitter,\n",
    "    n_jobs=-1,\n",
    ")\n",
    "\n",
    "# Search over the train and validation rows stacked together.\n",
    "grid_search.fit(\n",
    "    sparse_vstack([X_train_trigram, X_val_trigram]),\n",
    "    np.concatenate((y_train, y_val)),\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>index</th>\n",
       "      <th>mean_fit_time</th>\n",
       "      <th>std_fit_time</th>\n",
       "      <th>mean_score_time</th>\n",
       "      <th>std_score_time</th>\n",
       "      <th>param_alpha</th>\n",
       "      <th>param_n_iter_no_change</th>\n",
       "      <th>params</th>\n",
       "      <th>split0_test_accuracy</th>\n",
       "      <th>split1_test_accuracy</th>\n",
       "      <th>...</th>\n",
       "      <th>std_test_accuracy</th>\n",
       "      <th>rank_test_accuracy</th>\n",
       "      <th>split0_test_f1</th>\n",
       "      <th>split1_test_f1</th>\n",
       "      <th>split2_test_f1</th>\n",
       "      <th>split3_test_f1</th>\n",
       "      <th>split4_test_f1</th>\n",
       "      <th>mean_test_f1</th>\n",
       "      <th>std_test_f1</th>\n",
       "      <th>rank_test_f1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>32.057936</td>\n",
       "      <td>4.823794</td>\n",
       "      <td>1.462032</td>\n",
       "      <td>0.469907</td>\n",
       "      <td>0.01</td>\n",
       "      <td>5</td>\n",
       "      <td>{'alpha': 0.01, 'n_iter_no_change': 5}</td>\n",
       "      <td>0.692152</td>\n",
       "      <td>0.691969</td>\n",
       "      <td>...</td>\n",
       "      <td>0.000602</td>\n",
       "      <td>17</td>\n",
       "      <td>0.352870</td>\n",
       "      <td>0.351276</td>\n",
       "      <td>0.352789</td>\n",
       "      <td>0.354524</td>\n",
       "      <td>0.356614</td>\n",
       "      <td>0.353614</td>\n",
       "      <td>0.001818</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>34.839753</td>\n",
       "      <td>5.847555</td>\n",
       "      <td>1.570614</td>\n",
       "      <td>0.270366</td>\n",
       "      <td>0.01</td>\n",
       "      <td>10</td>\n",
       "      <td>{'alpha': 0.01, 'n_iter_no_change': 10}</td>\n",
       "      <td>0.691190</td>\n",
       "      <td>0.691969</td>\n",
       "      <td>...</td>\n",
       "      <td>0.000497</td>\n",
       "      <td>18</td>\n",
       "      <td>0.347076</td>\n",
       "      <td>0.351426</td>\n",
       "      <td>0.348466</td>\n",
       "      <td>0.349686</td>\n",
       "      <td>0.349993</td>\n",
       "      <td>0.349329</td>\n",
       "      <td>0.001469</td>\n",
       "      <td>18</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>41.940447</td>\n",
       "      <td>10.329488</td>\n",
       "      <td>1.597674</td>\n",
       "      <td>0.275955</td>\n",
       "      <td>0.01</td>\n",
       "      <td>15</td>\n",
       "      <td>{'alpha': 0.01, 'n_iter_no_change': 15}</td>\n",
       "      <td>0.691122</td>\n",
       "      <td>0.691461</td>\n",
       "      <td>...</td>\n",
       "      <td>0.000457</td>\n",
       "      <td>19</td>\n",
       "      <td>0.346570</td>\n",
       "      <td>0.348641</td>\n",
       "      <td>0.347652</td>\n",
       "      <td>0.350843</td>\n",
       "      <td>0.350410</td>\n",
       "      <td>0.348823</td>\n",
       "      <td>0.001617</td>\n",
       "      <td>19</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>44.677039</td>\n",
       "      <td>6.685752</td>\n",
       "      <td>1.307405</td>\n",
       "      <td>0.393770</td>\n",
       "      <td>0.01</td>\n",
       "      <td>20</td>\n",
       "      <td>{'alpha': 0.01, 'n_iter_no_change': 20}</td>\n",
       "      <td>0.690916</td>\n",
       "      <td>0.690843</td>\n",
       "      <td>...</td>\n",
       "      <td>0.000539</td>\n",
       "      <td>20</td>\n",
       "      <td>0.345848</td>\n",
       "      <td>0.345646</td>\n",
       "      <td>0.346039</td>\n",
       "      <td>0.347637</td>\n",
       "      <td>0.349897</td>\n",
       "      <td>0.347013</td>\n",
       "      <td>0.001605</td>\n",
       "      <td>20</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>24.846832</td>\n",
       "      <td>6.631401</td>\n",
       "      <td>1.670751</td>\n",
       "      <td>0.193314</td>\n",
       "      <td>0.001</td>\n",
       "      <td>5</td>\n",
       "      <td>{'alpha': 0.001, 'n_iter_no_change': 5}</td>\n",
       "      <td>0.741415</td>\n",
       "      <td>0.742827</td>\n",
       "      <td>...</td>\n",
       "      <td>0.001528</td>\n",
       "      <td>14</td>\n",
       "      <td>0.546314</td>\n",
       "      <td>0.550779</td>\n",
       "      <td>0.547320</td>\n",
       "      <td>0.547302</td>\n",
       "      <td>0.555941</td>\n",
       "      <td>0.549531</td>\n",
       "      <td>0.003546</td>\n",
       "      <td>14</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>5</td>\n",
       "      <td>31.927570</td>\n",
       "      <td>2.169471</td>\n",
       "      <td>0.891322</td>\n",
       "      <td>0.097146</td>\n",
       "      <td>0.001</td>\n",
       "      <td>10</td>\n",
       "      <td>{'alpha': 0.001, 'n_iter_no_change': 10}</td>\n",
       "      <td>0.742473</td>\n",
       "      <td>0.743638</td>\n",
       "      <td>...</td>\n",
       "      <td>0.001614</td>\n",
       "      <td>13</td>\n",
       "      <td>0.552838</td>\n",
       "      <td>0.556443</td>\n",
       "      <td>0.547028</td>\n",
       "      <td>0.549382</td>\n",
       "      <td>0.549654</td>\n",
       "      <td>0.551069</td>\n",
       "      <td>0.003262</td>\n",
       "      <td>13</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>6</td>\n",
       "      <td>35.785039</td>\n",
       "      <td>6.948313</td>\n",
       "      <td>1.889833</td>\n",
       "      <td>0.513881</td>\n",
       "      <td>0.001</td>\n",
       "      <td>15</td>\n",
       "      <td>{'alpha': 0.001, 'n_iter_no_change': 15}</td>\n",
       "      <td>0.741072</td>\n",
       "      <td>0.743129</td>\n",
       "      <td>...</td>\n",
       "      <td>0.001499</td>\n",
       "      <td>15</td>\n",
       "      <td>0.545130</td>\n",
       "      <td>0.551458</td>\n",
       "      <td>0.546284</td>\n",
       "      <td>0.545874</td>\n",
       "      <td>0.550921</td>\n",
       "      <td>0.547933</td>\n",
       "      <td>0.002690</td>\n",
       "      <td>15</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>7</td>\n",
       "      <td>59.535236</td>\n",
       "      <td>2.297733</td>\n",
       "      <td>1.418594</td>\n",
       "      <td>0.290130</td>\n",
       "      <td>0.001</td>\n",
       "      <td>20</td>\n",
       "      <td>{'alpha': 0.001, 'n_iter_no_change': 20}</td>\n",
       "      <td>0.741236</td>\n",
       "      <td>0.741796</td>\n",
       "      <td>...</td>\n",
       "      <td>0.001460</td>\n",
       "      <td>16</td>\n",
       "      <td>0.546055</td>\n",
       "      <td>0.544507</td>\n",
       "      <td>0.539016</td>\n",
       "      <td>0.542346</td>\n",
       "      <td>0.547395</td>\n",
       "      <td>0.543864</td>\n",
       "      <td>0.002949</td>\n",
       "      <td>16</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>8</td>\n",
       "      <td>32.525184</td>\n",
       "      <td>8.522953</td>\n",
       "      <td>1.530440</td>\n",
       "      <td>0.418499</td>\n",
       "      <td>0.0001</td>\n",
       "      <td>5</td>\n",
       "      <td>{'alpha': 0.0001, 'n_iter_no_change': 5}</td>\n",
       "      <td>0.780660</td>\n",
       "      <td>0.782870</td>\n",
       "      <td>...</td>\n",
       "      <td>0.001323</td>\n",
       "      <td>9</td>\n",
       "      <td>0.652092</td>\n",
       "      <td>0.650675</td>\n",
       "      <td>0.650521</td>\n",
       "      <td>0.648484</td>\n",
       "      <td>0.657971</td>\n",
       "      <td>0.651949</td>\n",
       "      <td>0.003223</td>\n",
       "      <td>10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>9</td>\n",
       "      <td>49.180857</td>\n",
       "      <td>11.419878</td>\n",
       "      <td>1.625803</td>\n",
       "      <td>0.334466</td>\n",
       "      <td>0.0001</td>\n",
       "      <td>10</td>\n",
       "      <td>{'alpha': 0.0001, 'n_iter_no_change': 10}</td>\n",
       "      <td>0.780578</td>\n",
       "      <td>0.782719</td>\n",
       "      <td>...</td>\n",
       "      <td>0.001481</td>\n",
       "      <td>12</td>\n",
       "      <td>0.650713</td>\n",
       "      <td>0.653565</td>\n",
       "      <td>0.643136</td>\n",
       "      <td>0.640263</td>\n",
       "      <td>0.642371</td>\n",
       "      <td>0.646010</td>\n",
       "      <td>0.005172</td>\n",
       "      <td>12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>10</td>\n",
       "      <td>40.283555</td>\n",
       "      <td>10.689865</td>\n",
       "      <td>1.386696</td>\n",
       "      <td>0.343637</td>\n",
       "      <td>0.0001</td>\n",
       "      <td>15</td>\n",
       "      <td>{'alpha': 0.0001, 'n_iter_no_change': 15}</td>\n",
       "      <td>0.780908</td>\n",
       "      <td>0.782870</td>\n",
       "      <td>...</td>\n",
       "      <td>0.001463</td>\n",
       "      <td>10</td>\n",
       "      <td>0.653647</td>\n",
       "      <td>0.655294</td>\n",
       "      <td>0.651344</td>\n",
       "      <td>0.648414</td>\n",
       "      <td>0.656066</td>\n",
       "      <td>0.652953</td>\n",
       "      <td>0.002787</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>11</td>\n",
       "      <td>57.527124</td>\n",
       "      <td>5.102434</td>\n",
       "      <td>1.504030</td>\n",
       "      <td>0.103422</td>\n",
       "      <td>0.0001</td>\n",
       "      <td>20</td>\n",
       "      <td>{'alpha': 0.0001, 'n_iter_no_change': 20}</td>\n",
       "      <td>0.780688</td>\n",
       "      <td>0.782691</td>\n",
       "      <td>...</td>\n",
       "      <td>0.001637</td>\n",
       "      <td>11</td>\n",
       "      <td>0.650919</td>\n",
       "      <td>0.652103</td>\n",
       "      <td>0.647006</td>\n",
       "      <td>0.647918</td>\n",
       "      <td>0.656548</td>\n",
       "      <td>0.650899</td>\n",
       "      <td>0.003388</td>\n",
       "      <td>11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>12</td>\n",
       "      <td>33.256662</td>\n",
       "      <td>3.574577</td>\n",
       "      <td>1.443488</td>\n",
       "      <td>0.112564</td>\n",
       "      <td>0.00001</td>\n",
       "      <td>5</td>\n",
       "      <td>{'alpha': 1e-05, 'n_iter_no_change': 5}</td>\n",
       "      <td>0.807332</td>\n",
       "      <td>0.808237</td>\n",
       "      <td>...</td>\n",
       "      <td>0.001453</td>\n",
       "      <td>1</td>\n",
       "      <td>0.698025</td>\n",
       "      <td>0.714616</td>\n",
       "      <td>0.709538</td>\n",
       "      <td>0.707157</td>\n",
       "      <td>0.707429</td>\n",
       "      <td>0.707353</td>\n",
       "      <td>0.005377</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>13</td>\n",
       "      <td>32.103819</td>\n",
       "      <td>3.748073</td>\n",
       "      <td>1.302405</td>\n",
       "      <td>0.426505</td>\n",
       "      <td>0.00001</td>\n",
       "      <td>10</td>\n",
       "      <td>{'alpha': 1e-05, 'n_iter_no_change': 10}</td>\n",
       "      <td>0.806453</td>\n",
       "      <td>0.807934</td>\n",
       "      <td>...</td>\n",
       "      <td>0.001237</td>\n",
       "      <td>4</td>\n",
       "      <td>0.713679</td>\n",
       "      <td>0.701722</td>\n",
       "      <td>0.710008</td>\n",
       "      <td>0.709509</td>\n",
       "      <td>0.719614</td>\n",
       "      <td>0.710906</td>\n",
       "      <td>0.005842</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>14</td>\n",
       "      <td>44.492694</td>\n",
       "      <td>0.965001</td>\n",
       "      <td>1.558569</td>\n",
       "      <td>0.328143</td>\n",
       "      <td>0.00001</td>\n",
       "      <td>15</td>\n",
       "      <td>{'alpha': 1e-05, 'n_iter_no_change': 15}</td>\n",
       "      <td>0.804680</td>\n",
       "      <td>0.808828</td>\n",
       "      <td>...</td>\n",
       "      <td>0.001937</td>\n",
       "      <td>2</td>\n",
       "      <td>0.713057</td>\n",
       "      <td>0.706113</td>\n",
       "      <td>0.711488</td>\n",
       "      <td>0.711608</td>\n",
       "      <td>0.715967</td>\n",
       "      <td>0.711647</td>\n",
       "      <td>0.003203</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>15</td>\n",
       "      <td>58.558907</td>\n",
       "      <td>10.112322</td>\n",
       "      <td>1.424766</td>\n",
       "      <td>0.203237</td>\n",
       "      <td>0.00001</td>\n",
       "      <td>20</td>\n",
       "      <td>{'alpha': 1e-05, 'n_iter_no_change': 20}</td>\n",
       "      <td>0.806742</td>\n",
       "      <td>0.807426</td>\n",
       "      <td>...</td>\n",
       "      <td>0.001874</td>\n",
       "      <td>3</td>\n",
       "      <td>0.712181</td>\n",
       "      <td>0.711004</td>\n",
       "      <td>0.715343</td>\n",
       "      <td>0.710696</td>\n",
       "      <td>0.715824</td>\n",
       "      <td>0.713009</td>\n",
       "      <td>0.002164</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>16</td>\n",
       "      <td>39.761328</td>\n",
       "      <td>0.877920</td>\n",
       "      <td>1.103284</td>\n",
       "      <td>0.397345</td>\n",
       "      <td>0.000001</td>\n",
       "      <td>5</td>\n",
       "      <td>{'alpha': 1e-06, 'n_iter_no_change': 5}</td>\n",
       "      <td>0.802949</td>\n",
       "      <td>0.803784</td>\n",
       "      <td>...</td>\n",
       "      <td>0.003224</td>\n",
       "      <td>7</td>\n",
       "      <td>0.706448</td>\n",
       "      <td>0.698482</td>\n",
       "      <td>0.704552</td>\n",
       "      <td>0.708678</td>\n",
       "      <td>0.718208</td>\n",
       "      <td>0.707274</td>\n",
       "      <td>0.006433</td>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>17</td>\n",
       "      <td>45.833273</td>\n",
       "      <td>3.359309</td>\n",
       "      <td>1.549358</td>\n",
       "      <td>0.397857</td>\n",
       "      <td>0.000001</td>\n",
       "      <td>10</td>\n",
       "      <td>{'alpha': 1e-06, 'n_iter_no_change': 10}</td>\n",
       "      <td>0.801602</td>\n",
       "      <td>0.802960</td>\n",
       "      <td>...</td>\n",
       "      <td>0.002758</td>\n",
       "      <td>5</td>\n",
       "      <td>0.708099</td>\n",
       "      <td>0.707074</td>\n",
       "      <td>0.710392</td>\n",
       "      <td>0.709978</td>\n",
       "      <td>0.715186</td>\n",
       "      <td>0.710146</td>\n",
       "      <td>0.002797</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>18</td>\n",
       "      <td>56.188770</td>\n",
       "      <td>2.537069</td>\n",
       "      <td>1.026216</td>\n",
       "      <td>0.285707</td>\n",
       "      <td>0.000001</td>\n",
       "      <td>15</td>\n",
       "      <td>{'alpha': 1e-06, 'n_iter_no_change': 15}</td>\n",
       "      <td>0.798497</td>\n",
       "      <td>0.802273</td>\n",
       "      <td>...</td>\n",
       "      <td>0.001822</td>\n",
       "      <td>6</td>\n",
       "      <td>0.714152</td>\n",
       "      <td>0.715007</td>\n",
       "      <td>0.711416</td>\n",
       "      <td>0.706668</td>\n",
       "      <td>0.719169</td>\n",
       "      <td>0.713283</td>\n",
       "      <td>0.004138</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>19</td>\n",
       "      <td>50.262399</td>\n",
       "      <td>7.179083</td>\n",
       "      <td>0.540624</td>\n",
       "      <td>0.348990</td>\n",
       "      <td>0.000001</td>\n",
       "      <td>20</td>\n",
       "      <td>{'alpha': 1e-06, 'n_iter_no_change': 20}</td>\n",
       "      <td>0.797260</td>\n",
       "      <td>0.799813</td>\n",
       "      <td>...</td>\n",
       "      <td>0.002437</td>\n",
       "      <td>8</td>\n",
       "      <td>0.712611</td>\n",
       "      <td>0.715613</td>\n",
       "      <td>0.713257</td>\n",
       "      <td>0.711372</td>\n",
       "      <td>0.716722</td>\n",
       "      <td>0.713915</td>\n",
       "      <td>0.001968</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>20 rows × 24 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "    index  mean_fit_time  std_fit_time  mean_score_time  std_score_time  \\\n",
       "0       0      32.057936      4.823794         1.462032        0.469907   \n",
       "1       1      34.839753      5.847555         1.570614        0.270366   \n",
       "2       2      41.940447     10.329488         1.597674        0.275955   \n",
       "3       3      44.677039      6.685752         1.307405        0.393770   \n",
       "4       4      24.846832      6.631401         1.670751        0.193314   \n",
       "5       5      31.927570      2.169471         0.891322        0.097146   \n",
       "6       6      35.785039      6.948313         1.889833        0.513881   \n",
       "7       7      59.535236      2.297733         1.418594        0.290130   \n",
       "8       8      32.525184      8.522953         1.530440        0.418499   \n",
       "9       9      49.180857     11.419878         1.625803        0.334466   \n",
       "10     10      40.283555     10.689865         1.386696        0.343637   \n",
       "11     11      57.527124      5.102434         1.504030        0.103422   \n",
       "12     12      33.256662      3.574577         1.443488        0.112564   \n",
       "13     13      32.103819      3.748073         1.302405        0.426505   \n",
       "14     14      44.492694      0.965001         1.558569        0.328143   \n",
       "15     15      58.558907     10.112322         1.424766        0.203237   \n",
       "16     16      39.761328      0.877920         1.103284        0.397345   \n",
       "17     17      45.833273      3.359309         1.549358        0.397857   \n",
       "18     18      56.188770      2.537069         1.026216        0.285707   \n",
       "19     19      50.262399      7.179083         0.540624        0.348990   \n",
       "\n",
       "   param_alpha param_n_iter_no_change  \\\n",
       "0         0.01                      5   \n",
       "1         0.01                     10   \n",
       "2         0.01                     15   \n",
       "3         0.01                     20   \n",
       "4        0.001                      5   \n",
       "5        0.001                     10   \n",
       "6        0.001                     15   \n",
       "7        0.001                     20   \n",
       "8       0.0001                      5   \n",
       "9       0.0001                     10   \n",
       "10      0.0001                     15   \n",
       "11      0.0001                     20   \n",
       "12     0.00001                      5   \n",
       "13     0.00001                     10   \n",
       "14     0.00001                     15   \n",
       "15     0.00001                     20   \n",
       "16    0.000001                      5   \n",
       "17    0.000001                     10   \n",
       "18    0.000001                     15   \n",
       "19    0.000001                     20   \n",
       "\n",
       "                                       params  split0_test_accuracy  \\\n",
       "0      {'alpha': 0.01, 'n_iter_no_change': 5}              0.692152   \n",
       "1     {'alpha': 0.01, 'n_iter_no_change': 10}              0.691190   \n",
       "2     {'alpha': 0.01, 'n_iter_no_change': 15}              0.691122   \n",
       "3     {'alpha': 0.01, 'n_iter_no_change': 20}              0.690916   \n",
       "4     {'alpha': 0.001, 'n_iter_no_change': 5}              0.741415   \n",
       "5    {'alpha': 0.001, 'n_iter_no_change': 10}              0.742473   \n",
       "6    {'alpha': 0.001, 'n_iter_no_change': 15}              0.741072   \n",
       "7    {'alpha': 0.001, 'n_iter_no_change': 20}              0.741236   \n",
       "8    {'alpha': 0.0001, 'n_iter_no_change': 5}              0.780660   \n",
       "9   {'alpha': 0.0001, 'n_iter_no_change': 10}              0.780578   \n",
       "10  {'alpha': 0.0001, 'n_iter_no_change': 15}              0.780908   \n",
       "11  {'alpha': 0.0001, 'n_iter_no_change': 20}              0.780688   \n",
       "12    {'alpha': 1e-05, 'n_iter_no_change': 5}              0.807332   \n",
       "13   {'alpha': 1e-05, 'n_iter_no_change': 10}              0.806453   \n",
       "14   {'alpha': 1e-05, 'n_iter_no_change': 15}              0.804680   \n",
       "15   {'alpha': 1e-05, 'n_iter_no_change': 20}              0.806742   \n",
       "16    {'alpha': 1e-06, 'n_iter_no_change': 5}              0.802949   \n",
       "17   {'alpha': 1e-06, 'n_iter_no_change': 10}              0.801602   \n",
       "18   {'alpha': 1e-06, 'n_iter_no_change': 15}              0.798497   \n",
       "19   {'alpha': 1e-06, 'n_iter_no_change': 20}              0.797260   \n",
       "\n",
       "    split1_test_accuracy  ...  std_test_accuracy  rank_test_accuracy  \\\n",
       "0               0.691969  ...           0.000602                  17   \n",
       "1               0.691969  ...           0.000497                  18   \n",
       "2               0.691461  ...           0.000457                  19   \n",
       "3               0.690843  ...           0.000539                  20   \n",
       "4               0.742827  ...           0.001528                  14   \n",
       "5               0.743638  ...           0.001614                  13   \n",
       "6               0.743129  ...           0.001499                  15   \n",
       "7               0.741796  ...           0.001460                  16   \n",
       "8               0.782870  ...           0.001323                   9   \n",
       "9               0.782719  ...           0.001481                  12   \n",
       "10              0.782870  ...           0.001463                  10   \n",
       "11              0.782691  ...           0.001637                  11   \n",
       "12              0.808237  ...           0.001453                   1   \n",
       "13              0.807934  ...           0.001237                   4   \n",
       "14              0.808828  ...           0.001937                   2   \n",
       "15              0.807426  ...           0.001874                   3   \n",
       "16              0.803784  ...           0.003224                   7   \n",
       "17              0.802960  ...           0.002758                   5   \n",
       "18              0.802273  ...           0.001822                   6   \n",
       "19              0.799813  ...           0.002437                   8   \n",
       "\n",
       "    split0_test_f1  split1_test_f1  split2_test_f1  split3_test_f1  \\\n",
       "0         0.352870        0.351276        0.352789        0.354524   \n",
       "1         0.347076        0.351426        0.348466        0.349686   \n",
       "2         0.346570        0.348641        0.347652        0.350843   \n",
       "3         0.345848        0.345646        0.346039        0.347637   \n",
       "4         0.546314        0.550779        0.547320        0.547302   \n",
       "5         0.552838        0.556443        0.547028        0.549382   \n",
       "6         0.545130        0.551458        0.546284        0.545874   \n",
       "7         0.546055        0.544507        0.539016        0.542346   \n",
       "8         0.652092        0.650675        0.650521        0.648484   \n",
       "9         0.650713        0.653565        0.643136        0.640263   \n",
       "10        0.653647        0.655294        0.651344        0.648414   \n",
       "11        0.650919        0.652103        0.647006        0.647918   \n",
       "12        0.698025        0.714616        0.709538        0.707157   \n",
       "13        0.713679        0.701722        0.710008        0.709509   \n",
       "14        0.713057        0.706113        0.711488        0.711608   \n",
       "15        0.712181        0.711004        0.715343        0.710696   \n",
       "16        0.706448        0.698482        0.704552        0.708678   \n",
       "17        0.708099        0.707074        0.710392        0.709978   \n",
       "18        0.714152        0.715007        0.711416        0.706668   \n",
       "19        0.712611        0.715613        0.713257        0.711372   \n",
       "\n",
       "    split4_test_f1  mean_test_f1  std_test_f1  rank_test_f1  \n",
       "0         0.356614      0.353614     0.001818            17  \n",
       "1         0.349993      0.349329     0.001469            18  \n",
       "2         0.350410      0.348823     0.001617            19  \n",
       "3         0.349897      0.347013     0.001605            20  \n",
       "4         0.555941      0.549531     0.003546            14  \n",
       "5         0.549654      0.551069     0.003262            13  \n",
       "6         0.550921      0.547933     0.002690            15  \n",
       "7         0.547395      0.543864     0.002949            16  \n",
       "8         0.657971      0.651949     0.003223            10  \n",
       "9         0.642371      0.646010     0.005172            12  \n",
       "10        0.656066      0.652953     0.002787             9  \n",
       "11        0.656548      0.650899     0.003388            11  \n",
       "12        0.707429      0.707353     0.005377             7  \n",
       "13        0.719614      0.710906     0.005842             5  \n",
       "14        0.715967      0.711647     0.003203             4  \n",
       "15        0.715824      0.713009     0.002164             3  \n",
       "16        0.718208      0.707274     0.006433             8  \n",
       "17        0.715186      0.710146     0.002797             6  \n",
       "18        0.719169      0.713283     0.004138             2  \n",
       "19        0.716722      0.713915     0.001968             1  \n",
       "\n",
       "[20 rows x 24 columns]"
      ]
     },
     "execution_count": 86,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gridSearchCVResults = pd.DataFrame.from_dict(grid_search.cv_results_).reset_index()\n",
    "gridSearchCVResults"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<u>Best Parameters</u>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best Parameters:  {'alpha': 1e-06, 'n_iter_no_change': 20}\n"
     ]
    }
   ],
   "source": [
    "print(\"Best Parameters: \", grid_search.best_params_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trigram Tuned Logistic Regression Accuracy:  80.31116277919315\n",
      "trigram Tuned Logistic Regression F1 Score:  72.37454015409178\n"
     ]
    }
   ],
   "source": [
    "bestAlpha = grid_search.best_params_['alpha']\n",
    "bestNIterNoChange = grid_search.best_params_['n_iter_no_change']\n",
    "trigramTunedLogisticRegressor = SGDClassifier(\n",
    "                            loss='log_loss',\n",
    "                            alpha=bestAlpha,\n",
    "                            penalty='l2',\n",
    "                            max_iter=1000,\n",
    "                            learning_rate='optimal',\n",
    "                            n_iter_no_change=bestNIterNoChange,\n",
    "                            n_jobs=-1,\n",
    "                            random_state=42)\n",
    "trigramTunedLogisticRegressor.fit(X_train_trigram, y_train)\n",
    "y_pred_trigram_tuned_logistic = trigramTunedLogisticRegressor.predict(X_test_trigram)\n",
    "print(\"trigram Tuned Logistic Regression Accuracy: \", 100*accuracy_score(y_test, y_pred_trigram_tuned_logistic))\n",
    "print(\"trigram Tuned Logistic Regression F1 Score: \", 100*f1_score(y_test, y_pred_trigram_tuned_logistic))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SVM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "According to paper, the default parameters to be used for SVM are $C=1.0$ and $kernel=linear$ unless specified otherwise."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Unigram Linear SVM Model "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Unigram Linear SVM Accuracy:  0.7339038808775878\n",
      "Unigram Linear SVM F1 Score:  0.6413282656531306\n"
     ]
    }
   ],
   "source": [
    "# X_train_unigram = load_npz(UNIGRAM_PATH + \"train.npz\")\n",
    "# X_test_unigram = load_npz(UNIGRAM_PATH + \"test.npz\")\n",
    "unigramLinearSVM = LinearSVC(C=1.0, max_iter=10000, random_state=42)\n",
    "unigramLinearSVM.fit(X_train_unigram, y_train)\n",
    "y_pred_unigram_linear_svm = unigramLinearSVM.predict(X_test_unigram)\n",
    "print(\"Unigram Linear SVM Accuracy: \", accuracy_score(y_test, y_pred_unigram_linear_svm))\n",
    "print(\"Unigram Linear SVM F1 Score: \", f1_score(y_test, y_pred_unigram_linear_svm))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Bigram Linear SVM Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Bigram Linear SVM Accuracy:  0.7765465383759182\n",
      "Bigram Linear SVM F1 Score:  0.6993877279382404\n"
     ]
    }
   ],
   "source": [
    "# X_train_bigram = load_npz(BIGRAM_PATH + \"train.npz\")\n",
    "# X_test_bigram = load_npz(BIGRAM_PATH + \"test.npz\")\n",
    "bigramLinearSVM = LinearSVC(C=1.0, max_iter=10000, random_state=42)\n",
    "bigramLinearSVM.fit(X_train_bigram, y_train)\n",
    "y_pred_bigram_linear_svm = bigramLinearSVM.predict(X_test_bigram)\n",
    "print(\"Bigram Linear SVM Accuracy: \", accuracy_score(y_test, y_pred_bigram_linear_svm))\n",
    "print(\"Bigram Linear SVM F1 Score: \", f1_score(y_test, y_pred_bigram_linear_svm))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Trigram Linear SVM Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trigram Linear SVM Accuracy:  0.7926241064582354\n",
      "Trigram Linear SVM F1 Score:  0.7131910235358511\n"
     ]
    }
   ],
   "source": [
    "# X_train_trigram = load_npz(TRIGRAM_PATH + \"train.npz\")\n",
    "# X_test_trigram = load_npz(TRIGRAM_PATH + \"test.npz\")\n",
    "trigramLinearSVM = LinearSVC(C=1.0, max_iter=10000, random_state=42)\n",
    "trigramLinearSVM.fit(X_train_trigram, y_train)\n",
    "y_pred_trigram_linear_svm = trigramLinearSVM.predict(X_test_trigram)\n",
    "print(\"Trigram Linear SVM Accuracy: \", accuracy_score(y_test, y_pred_trigram_linear_svm))\n",
    "print(\"Trigram Linear SVM F1 Score: \", f1_score(y_test, y_pred_trigram_linear_svm))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Parameter tuning for Trigram SVM Model with different kernels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# X_train_trigram = load_npz(TRIGRAM_PATH + \"train.npz\")\n",
    "# X_test_trigram = load_npz(TRIGRAM_PATH + \"test.npz\")\n",
    "# X_val_trigram = load_npz(TRIGRAM_PATH + \"val.npz\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 2 folds for each of 14 candidates, totalling 28 fits\n"
     ]
    }
   ],
   "source": [
    "parameters = dict({\n",
    "                    'C':[0.001, 0.005, 0.1, 0.5, 1.0, 10, 50], \n",
    "                    'kernel':['linear', 'rbf']\n",
    "                })\n",
    "trigramSVM = SVC(max_iter=-1, random_state=42, gamma='auto')\n",
    "cv_stratified_splitter = StratifiedKFold(n_splits=2)\n",
    "grid_search = GridSearchCV(trigramSVM, \n",
    "                            parameters, \n",
    "                            cv=cv_stratified_splitter, \n",
    "                            scoring=['accuracy', 'f1'], \n",
    "                            n_jobs=-1,\n",
    "                            refit='accuracy',\n",
    "                            verbose=1)\n",
    "grid_search.fit(sparse_vstack([X_train_trigram, X_val_trigram]), np.concatenate((y_train, y_val)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best Parameters:  {'C': 0.1, 'kernel': 'linear'}\n"
     ]
    }
   ],
   "source": [
    "print(\"Best Parameters: \", grid_search.best_params_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trigram Tuned SVM Accuracy:  0.8011445843330283\n",
      "Trigram Tuned SVM F1 Score:  0.712977065767285\n"
     ]
    }
   ],
   "source": [
    "bestC=grid_search.best_params_['C']\n",
    "bestKernel=grid_search.best_params_['kernel']\n",
    "trigramTunedSVM = SVC(C=bestC, kernel=bestKernel, max_iter=-1, random_state=42, gamma='auto')\n",
    "trigramTunedSVM.fit(X_train_trigram, y_train)\n",
    "y_pred_trigram_tuned_svm = trigramTunedSVM.predict(X_test_trigram)\n",
    "print(\"Trigram Tuned SVM Accuracy: \", accuracy_score(y_test, y_pred_trigram_tuned_svm))\n",
    "print(\"Trigram Tuned SVM F1 Score: \", f1_score(y_test, y_pred_trigram_tuned_svm))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sentence embeddings as feature vectors"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "50-dimensional word vectors are obtained using GloVe vectors (GloVe.6B.50d.txt). The sentence embeddings are obtained by simply summing the word embeddings in a sentence. The sentence embeddings are then used as feature vectors for classification in the following two ways:\n",
    "- Plain sentence embeddings\n",
    "- Distance measure between vectors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [],
   "source": [
    "GLOVE_PATH = INPUT + 'glove.6B.50d.txt'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocessAndTokenizeForGlove(text: str) -> list[str]:\n",
    "    text = re.sub(r'[^\\x00-\\x7F]+',' ', text.lower())\n",
    "    text = text.translate(str.maketrans('', '', string.punctuation))\n",
    "    tokens = nltk.word_tokenize(text)\n",
    "    tokens = [stemmer.stem(w) for w in tokens if stemmer.stem(w) not in stopwords]\n",
    "    return tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    X_train_q1_tokenized = np.array([preprocessAndTokenizeForGlove(ques) for ques in X_train_q1], dtype=object)\n",
    "    X_train_q2_tokenized = np.array([preprocessAndTokenizeForGlove(ques) for ques in X_train_q2], dtype=object)\n",
    "    X_test_q1_tokenized = np.array([preprocessAndTokenizeForGlove(ques) for ques in X_test_q1], dtype=object)\n",
    "    X_test_q2_tokenized = np.array([preprocessAndTokenizeForGlove(ques) for ques in X_test_q2], dtype=object)\n",
    "    X_val_q1_tokenized = np.array([preprocessAndTokenizeForGlove(ques) for ques in X_val_q1], dtype=object)\n",
    "    X_val_q2_tokenized = np.array([preprocessAndTokenizeForGlove(ques) for ques in X_val_q2], dtype=object)\n",
    "except Exception as e:\n",
    "    print(e)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load GloVe Word Embeddings\n",
    "GloVe_embeddings = {}\n",
    "with open(GLOVE_PATH, 'r') as f:\n",
    "    for line in f:\n",
    "        values = line.split()\n",
    "        word = values[0]\n",
    "        vector = np.asarray(values[1:], \"float32\")\n",
    "        GloVe_embeddings[word] = vector"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating sentence embeddings for each question\n",
    "\n",
    "According to the paper, the sentence embeddings are obtained by simply summing the embeddings of all the tokens in a sentence."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 181,
   "metadata": {},
   "outputs": [],
   "source": [
    "def embed_questions(tokenized_questions, embeddings, dim):\n",
    "    \"\"\"Build one sentence embedding per tokenized question.\n",
    "\n",
    "    A sentence embedding is the sum of the vectors of those tokens that are\n",
    "    keys of `embeddings`; all other tokens are skipped. A zero vector of\n",
    "    length `dim` is always part of the sum, so a question without any known\n",
    "    token maps to the zero vector.\n",
    "\n",
    "    Returns an array with one entry per question (each of length `dim`,\n",
    "    assuming every vector in `embeddings` has that length).\n",
    "    \"\"\"\n",
    "    zero_vector = np.zeros((dim,))\n",
    "    return np.array([\n",
    "        np.sum([embeddings[token] for token in tokens if token in embeddings] + [zero_vector], axis=0)\n",
    "        for tokens in tokenized_questions\n",
    "    ])\n",
    "\n",
    "\n",
    "# Read the dimensionality from the loaded vectors instead of hard-coding 50.\n",
    "EMBEDDING_DIM = len(next(iter(GloVe_embeddings.values())))\n",
    "\n",
    "X_train_q1_embeddings = embed_questions(X_train_q1_tokenized, GloVe_embeddings, EMBEDDING_DIM)\n",
    "X_train_q2_embeddings = embed_questions(X_train_q2_tokenized, GloVe_embeddings, EMBEDDING_DIM)\n",
    "X_test_q1_embeddings = embed_questions(X_test_q1_tokenized, GloVe_embeddings, EMBEDDING_DIM)\n",
    "X_test_q2_embeddings = embed_questions(X_test_q2_tokenized, GloVe_embeddings, EMBEDDING_DIM)\n",
    "X_val_q1_embeddings = embed_questions(X_val_q1_tokenized, GloVe_embeddings, EMBEDDING_DIM)\n",
    "X_val_q2_embeddings = embed_questions(X_val_q2_tokenized, GloVe_embeddings, EMBEDDING_DIM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory that holds the cached sentence embeddings; created if it does not exist yet.\n",
    "QUESTION_EMBEDDINGS = './question_embeddings/'\n",
    "os.makedirs(QUESTION_EMBEDDINGS, exist_ok=True)\n",
    "\n",
    "# One .npz archive per (split, question) combination.\n",
    "TRAIN_Q1_EMBEDDINGS = f'{QUESTION_EMBEDDINGS}train_q1_embeddings.npz'\n",
    "TRAIN_Q2_EMBEDDINGS = f'{QUESTION_EMBEDDINGS}train_q2_embeddings.npz'\n",
    "TEST_Q1_EMBEDDINGS = f'{QUESTION_EMBEDDINGS}test_q1_embeddings.npz'\n",
    "TEST_Q2_EMBEDDINGS = f'{QUESTION_EMBEDDINGS}test_q2_embeddings.npz'\n",
    "VAL_Q1_EMBEDDINGS = f'{QUESTION_EMBEDDINGS}val_q1_embeddings.npz'\n",
    "VAL_Q2_EMBEDDINGS = f'{QUESTION_EMBEDDINGS}val_q2_embeddings.npz'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 183,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist every embedding matrix to its compressed .npz archive.\n",
    "# Each array is passed positionally, so NumPy stores it under the key 'arr_0'.\n",
    "for archive_path, embedding_matrix in (\n",
    "    (TRAIN_Q1_EMBEDDINGS, X_train_q1_embeddings),\n",
    "    (TRAIN_Q2_EMBEDDINGS, X_train_q2_embeddings),\n",
    "    (TEST_Q1_EMBEDDINGS, X_test_q1_embeddings),\n",
    "    (TEST_Q2_EMBEDDINGS, X_test_q2_embeddings),\n",
    "    (VAL_Q1_EMBEDDINGS, X_val_q1_embeddings),\n",
    "    (VAL_Q2_EMBEDDINGS, X_val_q2_embeddings),\n",
    "):\n",
    "    np.savez_compressed(archive_path, embedding_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_embeddings(path):\n",
    "    \"\"\"Return the array stored under 'arr_0' in the .npz archive at `path`.\n",
    "\n",
    "    The archive is opened in a `with` block so that its file handle is closed\n",
    "    once the array has been read, instead of staying open for the lifetime of\n",
    "    the kernel.\n",
    "    \"\"\"\n",
    "    with np.load(path) as archive:\n",
    "        return archive['arr_0']\n",
    "\n",
    "\n",
    "# Reload the cached sentence embeddings written by np.savez_compressed.\n",
    "X_train_q1_embeddings = load_embeddings(TRAIN_Q1_EMBEDDINGS)\n",
    "X_train_q2_embeddings = load_embeddings(TRAIN_Q2_EMBEDDINGS)\n",
    "X_test_q1_embeddings = load_embeddings(TEST_Q1_EMBEDDINGS)\n",
    "X_test_q2_embeddings = load_embeddings(TEST_Q2_EMBEDDINGS)\n",
    "X_val_q1_embeddings = load_embeddings(VAL_Q1_EMBEDDINGS)\n",
    "X_val_q2_embeddings = load_embeddings(VAL_Q2_EMBEDDINGS)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Plain sentence embeddings\n",
    "\n",
    "$100$-dimensional feature vector = the $50$-dimensional question $1$ sentence embedding concatenated with the $50$-dimensional question $2$ sentence embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 188,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Place the question-1 and question-2 embeddings side by side (column-wise),\n",
    "# giving one feature row per question pair.\n",
    "X_train_plain_embeddings = np.concatenate((X_train_q1_embeddings, X_train_q2_embeddings), axis=1)\n",
    "X_test_plain_embeddings = np.concatenate((X_test_q1_embeddings, X_test_q2_embeddings), axis=1)\n",
    "X_val_plain_embeddings = np.concatenate((X_val_q1_embeddings, X_val_q2_embeddings), axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 191,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.7738499097182715\n",
      "F1 Score: 0.6989414766371224\n"
     ]
    }
   ],
   "source": [
    "# RBF-kernel SVM trained on the concatenated embeddings and scored on the test split.\n",
    "SVMmodel = SVC(C=1.0, kernel='rbf', gamma='auto', max_iter=-1, random_state=42)\n",
    "# fit() returns the estimator itself, so training and prediction can be chained.\n",
    "y_pred = SVMmodel.fit(X_train_plain_embeddings, y_train).predict(X_test_plain_embeddings)\n",
    "for label, metric in ((\"Accuracy:\", accuracy_score), (\"F1 Score:\", f1_score)):\n",
    "    print(label, metric(y_test, y_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 193,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/nemesis/.local/lib/python3.10/site-packages/sklearn/svm/_base.py:301: ConvergenceWarning: Solver terminated early (max_iter=3500).  Consider pre-processing your data with StandardScaler or MinMaxScaler.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.63856125306092164\n",
      "F1 Score: 0.6193053333333333\n"
     ]
    }
   ],
   "source": [
    "# Linear-kernel SVM trained on the concatenated embeddings and scored on the test split.\n",
    "# gamma is a coefficient of the rbf/poly/sigmoid kernels only; it has no effect on a linear kernel.\n",
    "# NOTE(review): the saved stderr output of this cell reports max_iter=3500, so it was\n",
    "# presumably produced by an earlier run with a different setting than max_iter=-1 below.\n",
    "SVMmodel = SVC(C=1.0, kernel='linear', gamma='auto', max_iter=-1, random_state=42)\n",
    "# fit() returns the estimator itself, so training and prediction can be chained.\n",
    "y_pred = SVMmodel.fit(X_train_plain_embeddings, y_train).predict(X_test_plain_embeddings)\n",
    "for label, metric in ((\"Accuracy:\", accuracy_score), (\"F1 Score:\", f1_score)):\n",
    "    print(label, metric(y_test, y_pred))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Distance measure between vectors\n",
    "The feature vector is obtained by computing several distance measures between the sentence embeddings of the two questions:\n",
    "- Bray Curtis distance:\n",
    "$$\n",
    "d_{Bray Curtis} = \\frac{\\sum_{i=1}^{50} |q_1[i] - q_2[i]|}{\\sum_{i=1}^{50} |q_1[i] + q_2[i]|}\n",
    "$$\n",
    "- Canberra distance:\n",
    "$$\n",
    "d_{Canberra} = \\sum_{i=1}^{50} \\frac{|q_1[i] - q_2[i]|}{|q_1[i]| + |q_2[i]|}\n",
    "$$\n",
    "- Chebyshev distance:\n",
    "$$d_{Chebyshev} = \\max_{1 \\le i \\le 50} |q_1[i] - q_2[i]|$$\n",
    "- City block distance:\n",
    "$$d_{City block} = \\sum_{i=1}^{50} |q_1[i] - q_2[i]|$$\n",
    "- Correlation distance:\n",
    "$$\n",
    "\\begin{aligned}\n",
    "d_{Correlation} = 1 - \\frac{\\sum_{i=1}^{50} (q_1[i] - \\bar{q_1})(q_2[i] - \\bar{q_2})}{\\sqrt{\\sum_{i=1}^{50} (q_1[i] - \\bar{q_1})^2} \\sqrt{\\sum_{i=1}^{50} (q_2[i] - \\bar{q_2})^2}}\n",
    "\\end{aligned}\n",
    "$$\n",
    "- Cosine distance: \n",
    "$$\n",
    "d_{Cosine} = 1 - \\frac{\\sum_{i=1}^{50} q_1[i]q_2[i]}{\\sqrt{\\sum_{i=1}^{50} q_1[i]^2} \\sqrt{\\sum_{i=1}^{50} q_2[i]^2}}\n",
    "$$\n",
    "- Euclidean distance:\n",
    "$$d_{Euclidean} = \\sqrt{\\sum_{i=1}^{50} (q_1[i] - q_2[i])^2}$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy.spatial.distance as scipyDistance\n",
    "\n",
    "def distances(q1, q2):\n",
    "    \"\"\"Return the 7 distance features between two sentence embeddings.\n",
    "\n",
    "    The features are, in this order: Bray-Curtis, Canberra, Chebyshev,\n",
    "    city block, correlation, cosine and Euclidean distance, each computed\n",
    "    with the function of the same name in scipy.spatial.distance.\n",
    "\n",
    "    Degenerate inputs (for example an all-zero embedding) can make a measure\n",
    "    evaluate 0/0. Such NaN results are replaced by np.nan_to_num (NaN becomes\n",
    "    0.0), so the RuntimeWarnings numpy raises for those divisions are expected\n",
    "    and are silenced locally instead of flooding the cell output.\n",
    "    \"\"\"\n",
    "    metrics = (\n",
    "        scipyDistance.braycurtis,\n",
    "        scipyDistance.canberra,\n",
    "        scipyDistance.chebyshev,\n",
    "        scipyDistance.cityblock,\n",
    "        scipyDistance.correlation,\n",
    "        scipyDistance.cosine,\n",
    "        scipyDistance.euclidean,\n",
    "    )\n",
    "    with np.errstate(divide='ignore', invalid='ignore'):\n",
    "        distanceFeatureVector = np.array([metric(q1, q2) for metric in metrics])\n",
    "    return np.nan_to_num(distanceFeatureVector)\n",
    "\n",
    "\n",
    "# One row of distance features per (question 1, question 2) pair of each split.\n",
    "X_train_distances = np.array([distances(q1, q2) for q1, q2 in zip(X_train_q1_embeddings, X_train_q2_embeddings)])\n",
    "X_test_distances = np.array([distances(q1, q2) for q1, q2 in zip(X_test_q1_embeddings, X_test_q2_embeddings)])\n",
    "X_val_distances = np.array([distances(q1, q2) for q1, q2 in zip(X_val_q1_embeddings, X_val_q2_embeddings)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.683458686969944\n",
      "F1 Score: 0.679439494994484\n"
     ]
    }
   ],
   "source": [
    "# RBF-kernel SVM trained on the distance features and scored on the test split.\n",
    "distanceSVM = SVC(C=1.0, kernel='rbf', gamma='auto', max_iter=-1, random_state=42)\n",
    "# fit() returns the estimator itself, so training and prediction can be chained.\n",
    "y_pred = distanceSVM.fit(X_train_distances, y_train).predict(X_test_distances)\n",
    "for label, metric in ((\"Accuracy:\", accuracy_score), (\"F1 Score:\", f1_score)):\n",
    "    print(label, metric(y_test, y_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.639099378488921\n",
      "F1 Score: 0.625289938484038\n"
     ]
    }
   ],
   "source": [
    "# Linear-kernel SVM trained on the distance features and scored on the test split.\n",
    "# gamma is a coefficient of the rbf/poly/sigmoid kernels only; it has no effect on a linear kernel.\n",
    "distanceSVM = SVC(C=1.0, kernel='linear', gamma='auto', max_iter=-1, random_state=42)\n",
    "# fit() returns the estimator itself, so training and prediction can be chained.\n",
    "y_pred = distanceSVM.fit(X_train_distances, y_train).predict(X_test_distances)\n",
    "for label, metric in ((\"Accuracy:\", accuracy_score), (\"F1 Score:\", f1_score)):\n",
    "    print(label, metric(y_test, y_pred))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.8 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "e7370f93d1d0cde622a1f8e1c04877d8463912d04d973331ad4851f04de6915a"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}