In [215]:
from transformers import BertModel, BertTokenizer
from torch.utils.data import Dataset , DataLoader
import torch
from collections import Counter
from bs4 import BeautifulSoup
import re
from torch.nn.utils.rnn import pad_sequence

In [103]:
file_name = "input.txt"

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [104]:
def clean_text(text):
    # Remove HTML tags
    text = BeautifulSoup(text, "html.parser").get_text()
    # Convert text to lowercase
    text = text.lower()
    # Remove punctuation and numbers
    text = re.sub(r'[^a-z\s]', ' ', text)
    return text

In [105]:
with open(file_name, 'r') as file:
    data_str = file.read()
    data_str = clean_text(data_str)

In [203]:
class MyDataset(Dataset):

    def __init__(self, data_str, tokenizer, block_size):
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.data_str = data_str
        self.split_str = self.data_str.split()
    

    def __getitem__(self, i):
        input = self.tokenizer.encode(" ".join(self.split_str[i:i+self.block_size]) , max_length=self.block_size, padding='max_length')
        target = self.tokenizer.encode(" ".join(self.split_str[i+1:i+1+self.block_size]), max_length=self.block_size, padding='max_length')
        return input , target

    def __len__(self):
        return len(self.split_str) - self.block_size

In [204]:
data = MyDataset(data_str, tokenizer , 5)

In [205]:
input , target = data.__getitem__(1)

In [216]:
def collate_fn(batch):
    inputs, targets = zip(*batch)
    inputs = [torch.tensor(x) for x in inputs]
    targets = [torch.tensor(x) for x in targets]
    inputs = pad_sequence(inputs, batch_first=True, padding_value=tokenizer.pad_token_id)
    targets = pad_sequence(targets, batch_first=True, padding_value=tokenizer.pad_token_id)
    return inputs, targets
train_loader = DataLoader(data , batch_size = 100 , shuffle = True, collate_fn = collate_fn)

In [218]:
for input , target in train_loader:
    print(input.shape , target.shape)
    break

torch.Size([100, 11]) torch.Size([100, 11])
