In [14]:
from typing import List
from torch.utils.data import Dataset


class ExampleDataset(Dataset):
    def __init__(self, large_file_path, chunk_size):
        self.large_file_path = large_file_path
        self.line_offsets = self.get_line_offsets(large_file_path, chunk_size)

    def get_line_offsets(self, path: str, chunk_size: int) -> List[int]:
        offsets = [0]
        with open(path, "rb") as file:
            chunk = file.readlines(chunk_size)
            while chunk:
                for line in chunk:
                    offsets.append(offsets[-1] + len(line))
                print(f"Lines found: {len(offsets)}", end='\r')
                chunk = file.readlines(chunk_size)
        return offsets
    
    def __len__(self):
        return len(self.line_offsets)
    
    def __getitem__(self, line):
        offset = self.line_offsets[line]
        with open(self.large_file_path, 'r', encoding='utf-8') as f:
            f.seek(offset)
            line = f.readline()
            return line

In [15]:
dataset = ExampleDataset("data/full_dataset.txt", chunk_size=2**6)

print(len(dataset))


Lines found: 26
Lines found: 48
Lines found: 70
Lines found: 92
Lines found: 111
Lines found: 128
Lines found: 145
Lines found: 162
Lines found: 179
Lines found: 196
Lines found: 201
201


In [None]:
from torch.utils.data import DataLoader


dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
for batch in dataloader:
    print(batch)