In [11]:
import torch
from tokenizer import tokenizer
from torch.utils.data import Dataset, SubsetRandomSampler, DataLoader
import numpy as np

def preprocess_corpus(corpus: str) -> str:
    """Run the project tokenizer's ``preprocess`` step over ``corpus``.

    Args:
        corpus: Raw text to be preprocessed.

    Returns:
        Whatever ``tokenizer.preprocess`` returns for ``corpus`` (annotated
        as ``str``).
    """
    preprocessed = tokenizer.preprocess(corpus)
    return preprocessed

class TextDataset(Dataset):
    """Sliding-window next-token dataset built from a text corpus.

    The corpus is encoded once with ``tokenizer.encode`` and stored as a 1-D
    ``torch.long`` tensor. Sample ``i`` is the pair ``(x, y)`` where ``x`` is
    the ``context_len`` tokens starting at position ``i`` and ``y`` is the
    same window shifted one token to the right.

    Args:
        corpus: Text to encode (presumably already preprocessed — confirm
            against the caller).
        context_len: Number of tokens in every input/target window; must be
            at least 1.

    Raises:
        ValueError: If ``context_len`` is smaller than 1.
    """

    def __init__(self, corpus: str, context_len: int) -> None:
        super().__init__()
        if context_len < 1:
            raise ValueError(f"context_len must be a positive integer, got {context_len}")
        self.data = torch.tensor(tokenizer.encode(corpus), dtype=torch.long)
        self.context_len = context_len
        # A sample consumes context_len + 1 tokens (the window plus one shifted
        # target), so there are len(data) - context_len valid start positions.
        # Clamp at zero so a too-short corpus gives an empty dataset instead of
        # a negative length, which would make len() raise.
        self.size = max(0, len(self.data) - context_len)

    def __len__(self) -> int:
        """Return the number of (input, target) windows available."""
        return self.size

    def __getitem__(self, index: int):
        """Return the ``(x, y)`` window pair that starts at ``index``.

        Negative indices count from the end, as with built-in sequences.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
                Raising here (rather than returning a truncated slice) is also
                what lets plain ``for x, y in dataset`` iteration terminate.
        """
        if index < 0:
            index += self.size
        if not 0 <= index < self.size:
            raise IndexError(f"index out of range for TextDataset of size {self.size}")
        x = self.data[index : index + self.context_len]
        y = self.data[index + 1 : index + self.context_len + 1]
        return x, y

def get_dataloader(corpus: str, batch_size: int, context_len: int, cv_ratio: float, num_workers: int, device: str, shuffle: bool = True, seed=None):
    """Build training and validation ``DataLoader``s over one ``TextDataset``.

    Both loaders wrap the same dataset object and differ only in the subset
    of window indices their ``SubsetRandomSampler`` draws from.

    Args:
        corpus: Text handed to ``TextDataset`` for encoding.
        batch_size: Batch size used by both loaders.
        context_len: Window length handed to ``TextDataset``.
        cv_ratio: Fraction of windows assigned to validation; must lie in
            ``[0, 1]``. The first ``int(n * cv_ratio)`` indices (after the
            optional shuffle) go to validation, the rest to training.
        num_workers: Worker process count used by both loaders.
        device: Target device string; only used to decide ``pin_memory``
            (enabled when the string contains "cuda", case-insensitively).
        shuffle: Whether to shuffle the indices *before splitting*. This does
            not control batch order: both loaders use ``SubsetRandomSampler``,
            which draws in random order regardless of this flag.
        seed: Optional integer. When given, the split shuffle and both
            samplers use generators seeded from it, so the split and the
            sampling order are reproducible. When ``None`` (the default), the
            global NumPy / PyTorch random state is used, as before.

    Returns:
        ``(train_loader, val_loader)``.

    Raises:
        ValueError: If ``cv_ratio`` is not within ``[0, 1]`` (including NaN).
    """
    # Validate up front: a negative or >1 ratio would otherwise be accepted by
    # list slicing and silently yield a wrong train/validation partition.
    if not 0.0 <= cv_ratio <= 1.0:
        raise ValueError(f"cv_ratio must be within [0, 1], got {cv_ratio}")

    dataset_obj = TextDataset(corpus, context_len)
    indices = list(range(len(dataset_obj)))
    if shuffle:
        if seed is None:
            np.random.shuffle(indices)
        else:
            np.random.default_rng(seed).shuffle(indices)

    # NOTE(review): neighbouring window indices overlap in all but one token,
    # so with a shuffled split the validation windows share most of their
    # tokens with training windows. Consider shuffle=False (contiguous split)
    # if an uncontaminated validation estimate is needed.
    split = int(len(indices) * cv_ratio)
    train_indices = indices[split:]
    val_indices = indices[:split]

    if seed is None:
        train_generator = None
        val_generator = None
    else:
        # Separate generators so the two samplers do not share a random stream.
        train_generator = torch.Generator().manual_seed(seed)
        val_generator = torch.Generator().manual_seed(seed + 1)
    train_sampler = SubsetRandomSampler(train_indices, generator=train_generator)
    val_sampler = SubsetRandomSampler(val_indices, generator=val_generator)

    # Enable pin_memory if using CUDA
    pin_memory = "cuda" in device.lower()

    train_loader = DataLoader(dataset_obj, batch_size=batch_size, sampler=train_sampler,
                              num_workers=num_workers, pin_memory=pin_memory)
    val_loader = DataLoader(dataset_obj, batch_size=batch_size, sampler=val_sampler,
                            num_workers=num_workers, pin_memory=pin_memory)
    return train_loader, val_loader
    

In [None]:
# Read the whole lyrics corpus from disk as UTF-8 text.
file = "data/lyrics.txt"
with open(file, mode="r", encoding="UTF-8") as f:
    corpus = f.read()

# Run the tokenizer preprocessing, then build the train/validation loaders.
corpus = preprocess_corpus(corpus)
train_loader, val_loader = get_dataloader(
    corpus,
    batch_size=256,
    context_len=256,
    cv_ratio=0.1,   # fraction of windows assigned to validation
    num_workers=4,
    device="cuda",  # only used by get_dataloader to decide pin_memory
)





In [8]:
# Number of batches in the training loader, then in the validation loader,
# one per line.
print(len(train_loader), len(val_loader), sep="\n")

11697251
