In [1]:
import pandas as pd
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel
from typing import Any

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class CommonLitDataset(Dataset):
    """Map-style dataset over a pandas DataFrame of CommonLit texts.

    Each item tokenizes the row's ``excerpt`` column with
    ``padding="max_length"`` and ``truncation=True``. When the frame has a
    ``content`` column, the item also carries the row's ``content`` and
    ``wording`` targets.

    Args:
        model: Either a model name/path passed to
            ``AutoTokenizer.from_pretrained``, or an already-built tokenizer
            (any callable with the Hugging Face tokenizer call signature).
            Passing a tokenizer lets callers reuse one they have customized,
            e.g. with ``tok.add_tokens(...)``.
        df: DataFrame with an ``excerpt`` column and, for labelled data,
            ``content`` and ``wording`` columns.
    """

    def __init__(self, model, df) -> None:
        super().__init__()
        # Names/paths (str, PathLike) are not callable; tokenizer objects are.
        self.tok = model if callable(model) else AutoTokenizer.from_pretrained(model)
        self.df = df

    def __len__(self) -> int:
        # Required by DataLoader's default samplers, which call len(dataset).
        return len(self.df)

    def __getitem__(self, idx) -> Any:
        """Return a dict with ``input_ids``, ``attention_mask``, ``content``, ``wording``.

        ``content``/``wording`` are ``None`` when ``df`` has no ``content``
        column. NOTE: torch's default collate function cannot batch ``None``,
        so unlabelled frames need a custom ``collate_fn`` in a DataLoader.
        """
        row = self.df.iloc[idx]
        encoded = self.tok(
            row.excerpt, padding="max_length", truncation=True, return_tensors="pt"
        )
        labelled = "content" in self.df.columns
        return {
            # squeeze(0) drops only the batch dim added by return_tensors="pt";
            # a bare squeeze() would also collapse a length-1 sequence to 0-d.
            "input_ids": encoded.input_ids.squeeze(0),
            "attention_mask": encoded.attention_mask.squeeze(0),
            "content": row.content if labelled else None,
            "wording": row.wording if labelled else None,
        }

In [8]:
tok = AutoTokenizer.from_pretrained("distilbert-base-uncased")
# Register six tag-style tokens (opening tags first, then closing tags; the
# order fixes the ids they receive). Presumably these delimit prompt / summary /
# title sections of the input text.
# NOTE(review): CommonLitDataset builds its own tokenizer from the model name,
# so the tokens added to `tok` here are not applied by the datasets below.
tok.add_tokens([
    f"<{slash}{tag}>"
    for slash in ("", "/")
    for tag in ("PROMPT", "SUMMARY", "PROMPT_TITLE")
])

6

In [9]:
# Vocabulary size after add_tokens: the base vocabulary plus the 6 new tokens.
len(tok)

30528

In [4]:
# Load the training table and list which prompts it covers.
df = pd.read_csv("./data/train.csv")
df["prompt_id"].unique()

array(['814d6b', 'ebad26', '3b9047', '39c16e'], dtype=object)

In [5]:
# Hold out one whole prompt for validation, so no prompt appears in both splits.
VAL_PROMPT_ID = "814d6b"
is_val = df["prompt_id"] == VAL_PROMPT_ID
# reset_index(drop=True) already returns a new frame, so no extra .copy() is needed.
train_df = df[~is_val].reset_index(drop=True)
val_df = df[is_val].reset_index(drop=True)
assert not val_df.empty, f"prompt {VAL_PROMPT_ID!r} not found in df"
assert len(train_df) + len(val_df) == len(df)
train_df.shape, val_df.shape

((6062, 9), (1103, 9))

In [6]:
# One dataset per split, both tokenized with the same pretrained checkpoint.
# NOTE(review): passing a model name means each dataset loads a fresh tokenizer,
# without the markup tokens that were added to `tok` earlier.
train_dataset, val_dataset = (
    CommonLitDataset(model="distilbert-base-uncased", df=split_df)
    for split_df in (train_df, val_df)
)

Downloading (…)okenizer_config.json: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 28.0/28.0 [00:00<00:00, 170kB/s]
Downloading (…)lve/main/config.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 483/483 [00:00<00:00, 3.22MB/s]
Downloading (…)solve/main/vocab.txt: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 232k/232k [00:00<00:00, 586kB/s]
Downloading (…)/main/tokenizer.json: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 466k/466k [00:00<00:00, 787kB/s]


In [13]:
# Sanity check: one item's input_ids is a 1-D tensor padded to a fixed length.
train_dataset[1]['input_ids'].shape

torch.Size([512])