From 7e3096358e648856d6e3c09088e9bf5cacd27055 Mon Sep 17 00:00:00 2001 From: Omar Shaikh Date: Sun, 4 Oct 2020 18:01:38 -0400 Subject: [PATCH] edit baseline notebook --- baselines.ipynb | 31 +++++-------------------------- 1 file changed, 5 insertions(+), 26 deletions(-) diff --git a/baselines.ipynb b/baselines.ipynb index 0468cc7..5f2613d 100644 --- a/baselines.ipynb +++ b/baselines.ipynb @@ -206,7 +206,7 @@ "metadata": {}, "outputs": [], "source": [ - "class ToxicDataset(Dataset):\n", + "class PersuasionDataset(Dataset):\n", " \n", " def __init__(self, tokenizer: AutoTokenizer, dataframe: pd.DataFrame, lazy: bool = False):\n", " self.tokenizer = tokenizer\n", @@ -252,9 +252,9 @@ " y = torch.stack(y)\n", " return x.cuda(), y.cuda(), index\n", "\n", - "train_dataset = ToxicDataset(tokenizer, train_df, lazy=True)\n", - "dev_dataset = ToxicDataset(tokenizer, val_df, lazy=True)\n", - "test_dataset = ToxicDataset(tokenizer, test_df, lazy=True)\n", + "train_dataset = PersuasionDataset(tokenizer, train_df, lazy=True)\n", + "dev_dataset = PersuasionDataset(tokenizer, val_df, lazy=True)\n", + "test_dataset = PersuasionDataset(tokenizer, test_df, lazy=True)\n", "collate_fn = partial(collate_fn)\n", "BATCH_SIZE = 8\n", "train_sampler = RandomSampler(train_dataset)\n", @@ -489,27 +489,6 @@ " evaluate(model, test_iterator, test_df)" ] }, - { - "cell_type": "code", - "execution_count": 31, - "metadata": {}, - "outputs": [], - "source": [ - "# from transformers import BertTokenizer, DistillBertModel, AdamW, get_linear_schedule_with_warmup, DistillBertPreTrainedModel\n", - "from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score, accuracy_score\n", - "from tqdm import tqdm\n", - "from pathlib import Path\n", - "import logging\n", - "import pickle\n", - "import random\n", - "from nltk.tokenize import sent_tokenize, word_tokenize\n", - "import itertools\n", - "import random\n", - "random.seed(0)\n", - "np.random.seed(0)\n", - "import json" - ] - }, { "cell_type": "code", "execution_count": 29, @@ -517,7 +496,7 @@ "outputs": [], "source": [ "from sklearn.dummy import DummyClassifier\n", - "from sklearn.feature_extraction.text import TfidfVectorizer\n" + "from sklearn.feature_extraction.text import TfidfVectorizer" ] }, {