In [1]:
import torch
from datasets import load_dataset
import re
from collections import Counter
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer
from transformers import AutoTokenizer
import datasets

ds = datasets.load_from_disk("data/imdb_dataset/")
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
train_ds = ds['train']
val_ds = ds['test']

In [9]:
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize
class TextProcessing:
    def __init__(self, lemmatizer=WordNetLemmatizer(), stop_words=None):
        self.lemmatizer = lemmatizer
        self.stop_words = stop_words

    def __call__(self, text):
        text = re.sub(r'<.*?>', '', text)
        text = re.sub(r'[^a-zA-Z\s]', '', text)
        text = text.lower()

        tokens = word_tokenize(text)
        
        if self.stop_words is not None:
            tokens = [token for token in tokens if not token in self.stop_words]
        
        lemmatized_tokens = [self.lemmatizer.lemmatize(token) for token in tokens]
        
        return lemmatized_tokens     

In [74]:
class Vocabulary:
    def __init__(self, processing=AutoTokenizer.from_pretrained('bert-base-uncased'), min_freq=10):
        self.processing = processing
        self.min_freq = min_freq
        #self.max_words = max_words
        self.idx_to_str = {}
        self.str_to_idx = {}

    def __len__(self):
        return len(self.idx_to_str)

    @staticmethod
    def clean_sentence(sentence):
        return re.sub(r'[^a-zA-Z\s]', ' ', re.sub(r'[.-]', ' ', re.sub(r'<.*?>', '', sentence))).strip()

    def build_vocabulary(self, sentence_list):
        frequencies = Counter() 

        for sentence in sentence_list:
            sentence = self.clean_sentence(sentence)
            words = self.processing.tokenize(sentence)  
            frequencies.update(words) 

        frequencies = {k: v for k, v in frequencies.items() if v > self.min_freq}
        frequencies = dict(sorted(frequencies.items(), key=lambda x: -x[1]))

        for idx, word in enumerate(frequencies.keys(), start=len(self.idx_to_str)):
            self.str_to_idx[word] = idx
            self.idx_to_str[idx] = word

    def numericalize(self, sentence):
        sentence = self.clean_sentence(sentence)
        tokens = self.processing.encode(sentence, truncation=True, max_length=512, padding='max_length')
        
        return tokens




In [76]:
class CustomImdbDataset(Dataset):
    def __init__(self, df, max_length=512, min_freq=2):
        self.df = df
        self.reviews = self.df['text']
        self.labels = self.df['label']

        self.vocab = Vocabulary()
        self.vocab.build_vocabulary(self.reviews)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        label = self.labels[idx]
        review = self.reviews[idx]

        tokenized_review = self.vocab.numericalize(review) 
        
        return torch.tensor(tokenized_review), torch.tensor(label)

In [77]:
train_ds = CustomImdbDataset(ds['train'])

Token indices sequence length is longer than the specified maximum sequence length for this model (558 > 512). Running this sequence through the model will result in indexing errors


In [78]:
review, label = train_ds[0]
review, label

(tensor([  101,  1045, 12524,  1045,  2572,  8025,  3756,  2013,  2026,  2678,
          3573,  2138,  1997,  2035,  1996,  6704,  2008,  5129,  2009,  2043,
          2009,  2001,  2034,  2207,  1999,  1045,  2036,  2657,  2008,  2012,
          2034,  2009,  2001,  8243,  2011,  1057,  1055,  8205,  2065,  2009,
          2412,  2699,  2000,  4607,  2023,  2406,  3568,  2108,  1037,  5470,
          1997,  3152,  2641,  6801,  1045,  2428,  2018,  2000,  2156,  2023,
          2005,  2870,  1996,  5436,  2003,  8857,  2105,  1037,  2402,  4467,
          3689,  3076,  2315, 14229,  2040,  4122,  2000,  4553,  2673,  2016,
          2064,  2055,  2166,  1999,  3327,  2016,  4122,  2000,  3579,  2014,
          3086,  2015,  2000,  2437,  2070,  4066,  1997,  4516,  2006,  2054,
          1996,  2779, 25430, 14728,  2245,  2055,  3056,  2576,  3314,  2107,
          2004,  1996,  5148,  2162,  1998,  2679,  3314,  1999,  1996,  2142,
          2163,  1999,  2090,  4851,  8801,  1998,  