In [88]:
import argparse
from itertools import cycle
import torch
import typing
import pandas as pd
from torch.utils.data import DataLoader, Dataset, IterableDataset, random_split
import numpy as np
from encoder import Encoder, create_encoder
from typing import Generator

In [90]:
class TextDataset(IterableDataset):
    """Stream fixed-length token sequences from a text file.

    The file is read line by line. Newline characters are stripped from each
    line, so line breaks never reach the encoder. Every remaining character is
    passed to ``encoder.encode`` individually, and the resulting tokens are
    packed into consecutive, non-overlapping lists of exactly ``seq_len``
    tokens. A trailing partial sequence shorter than ``seq_len`` is dropped.
    This keeps every yielded item the same length, so a batch can be turned
    into a single rectangular tensor.
    """

    def __init__(
        self,
        path: str,
        encoder: Encoder,
        seq_len: int = 512,
        encoding: str = "utf-8",
    ) -> None:
        """
        Args:
            path: Path of the text file to stream.
            encoder: Object whose ``encode(char)`` returns an iterable of tokens.
            seq_len: Number of tokens in every yielded sequence; must be > 0.
            encoding: Text encoding used to open ``path``. It is explicit so
                results do not depend on the platform's default locale.

        Raises:
            ValueError: If ``seq_len`` is not positive.
        """
        if seq_len <= 0:
            # A non-positive length could never be reached exactly, so the
            # buffer in __iter__ would grow without bound.
            raise ValueError(f"seq_len must be positive, got {seq_len}")
        self.seq_len = seq_len
        self.encoder = encoder
        self.path = path
        self.encoding = encoding

    def read_file(self) -> Generator[str, None, None]:
        """Yield the file's characters one at a time, omitting newlines."""
        with open(self.path, "r", encoding=self.encoding) as f:
            for line in f:
                # NOTE: removing "\n" means the last character of one line is
                # immediately followed by the first character of the next.
                yield from line.strip("\n")

    def __iter__(self) -> Generator[list, None, None]:
        """Encode the streamed characters and yield full ``seq_len`` token lists.

        Each yielded list is a fresh object. Leftover tokens that do not fill
        a complete sequence at end-of-file are discarded.
        """
        sequence = []
        for char in self.read_file():
            for token in self.encoder.encode(char):
                sequence.append(token)
                # Emit as soon as the sequence is full. Checking before
                # appending would lose the final complete sequence at EOF.
                if len(sequence) == self.seq_len:
                    yield sequence
                    sequence = []

def get_dataloaders(
    batch_size: int = 128,
    seq_len: int = 512,
    train_path: str = "./data/gutenberg_train.txt",
    test_path: str = "./data/gutenberg_test.txt",
    encoder_path: str = "./data/pg16457.txt",
) -> typing.Tuple[DataLoader, DataLoader]:
    """Create the training and validation dataloaders.

    A single encoder, built by ``create_encoder(encoder_path)``, is shared by
    both datasets so they use the same token mapping.

    Args:
        batch_size: Number of sequences per batch.
        seq_len: Number of tokens per sequence.
        train_path: Text file streamed by the training loader.
        test_path: Text file streamed by the validation loader.
        encoder_path: Text file passed to ``create_encoder``.

    Returns:
        ``(train_loader, val_loader)``. Each batch is ``torch.tensor`` applied
        to a list of equal-length token lists. Presumably this is
        ``(batch_size, seq_len)`` when tokens are scalars; the last batch may
        be smaller because ``drop_last`` is left at its default.
    """
    encoder = create_encoder(encoder_path)
    train_dataset = TextDataset(train_path, encoder, seq_len=seq_len)
    test_dataset = TextDataset(test_path, encoder, seq_len=seq_len)

    # The default collate would transpose the lists of tokens, so
    # torch.tensor is used to build one tensor per batch. It is passed
    # directly rather than wrapped in a lambda: lambdas cannot be pickled,
    # which breaks num_workers > 0 under the "spawn" start method.
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, collate_fn=torch.tensor
    )
    val_loader = DataLoader(
        test_dataset, batch_size=batch_size, collate_fn=torch.tensor
    )

    return train_loader, val_loader

torch.Size([128, 512])
