In [None]:
class GPTDatasetV1(Dataset):
    """Sliding-window next-token dataset built from one raw text string.

    The text is tokenized once. Each window of ``max_length`` tokens, starting
    at offsets ``0, stride, 2*stride, ...``, becomes one sample whose target is
    the same window shifted right by one token.

    Args:
        txt: Raw text to tokenize.
        tokenizer: Object exposing ``encode(txt)`` returning a sliceable
            sequence of token ids.
        max_length: Number of tokens in each input (and target) window.
        stride: Distance between the start offsets of consecutive windows.
    """

    def __init__(self, txt, tokenizer, max_length, stride):
        # Special-token handling is left to the tokenizer's own ``encode``
        # defaults (the tiktoken-era ``allowed_special`` argument was dropped).
        ids = tokenizer.encode(txt)

        # The stop bound is exclusive, so the last start offset is
        # len(ids) - max_length - 1 and every one-token-shifted target still
        # fits inside ``ids``. Text too short for one window yields no samples.
        starts = range(0, len(ids) - max_length, stride)
        self.input_ids = [torch.tensor(ids[s:s + max_length]) for s in starts]
        self.target_ids = [torch.tensor(ids[s + 1:s + 1 + max_length]) for s in starts]

    def __len__(self):
        """Number of (input, target) windows."""
        return len(self.target_ids)

    def __getitem__(self, idx):
        """Return the ``(input_ids, target_ids)`` tensor pair for window ``idx``."""
        return (self.input_ids[idx], self.target_ids[idx])


def create_dataloader_v1(txt, batch_size=4, max_length=256,
                         stride=128, shuffle=True, drop_last=True, num_workers=0,
                         tokenizer=None):
    """Build a DataLoader of sliding-window (input, target) pairs from raw text.

    Args:
        txt: Raw text to tokenize and window.
        batch_size: Samples per batch.
        max_length: Tokens per input/target window.
        stride: Step between consecutive window start offsets.
        shuffle: Whether the DataLoader shuffles samples.
        drop_last: Whether the final incomplete batch is dropped.
        num_workers: Number of DataLoader worker processes.
        tokenizer: Optional pre-loaded tokenizer exposing ``encode``. When
            omitted, the ``Aananda-giri/NepaliBPE`` tokenizer is loaded with
            ``PreTrainedTokenizerFast.from_pretrained``. Passing one in avoids
            re-creating the tokenizer on every call (e.g. once for the train
            split and again for the validation split).

    Returns:
        A ``DataLoader`` over a ``GPTDatasetV1`` built from ``txt``.
    """
    if tokenizer is None:
        # The Nepali BPE tokenizer replaces the original tiktoken "gpt2" encoding.
        tokenizer = PreTrainedTokenizerFast.from_pretrained("Aananda-giri/NepaliBPE")
    dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
    )

def _collate_input_target(batch):
    """Stack per-sample ``input_ids``/``target_ids`` tensors into ``[inputs, targets]``.

    Kept at module level rather than nested inside ``create_dataloader_v2``:
    nested functions can never be pickled, which DataLoader workers started
    with the ``spawn`` method require when ``num_workers > 0``.
    """
    input_ids = torch.stack([item['input_ids'] for item in batch])
    target_ids = torch.stack([item['target_ids'] for item in batch])
    return [input_ids, target_ids]


def create_dataloader_v2(batch_size=4, shuffle=True, drop_last=True, num_workers=0,
                         train_ratio=.9, context_length=1024, seed=42):
    """Build train/validation DataLoaders from the pre-tokenized NepBERTa parquet file.

    Unlike ``create_dataloader_v1``, this needs no raw text, ``max_length`` or
    ``stride``: windowing was already done when the pre-tokenized dataset was
    prepared, and ``context_length`` only selects which parquet file to load.

    Args:
        batch_size: Samples per batch for both loaders.
        shuffle: Whether the *training* loader shuffles. The validation loader
            never shuffles.
        drop_last: Whether the *training* loader drops its final incomplete
            batch. The validation loader keeps every sample.
        num_workers: Number of DataLoader worker processes.
        train_ratio: Fraction of rows assigned to the training split.
        context_length: Selects ``nepberta_<context_length>.parquet``.
        seed: Seed for the train/validation split.

    Returns:
        ``(train_loader, val_loader)``. Each batch is ``[input_ids, target_ids]``
        with the per-sample tensors stacked along a new first dimension.
    """
    # Download the whole dataset as a single split, then split it locally.
    base_url = "https://huggingface.co/datasets/Aananda-giri/nepali_llm_datasets/resolve/main/pre_tokenized/"
    data_files = {"train": f"{base_url}nepberta_{context_length}.parquet"}
    dataset = load_dataset("parquet", data_files=data_files, split="train", cache_dir='hf_cache')

    print(dataset)

    splits = dataset.train_test_split(train_size=train_ratio, seed=seed)
    # Rows already hold the id lists; only their output format needs to be torch tensors.
    splits.set_format(type="torch", columns=["input_ids", "target_ids"])

    train_loader = DataLoader(
        splits['train'],
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
        collate_fn=_collate_input_target,
    )

    # Validation is evaluated in a fixed order over every sample. With
    # drop_last=True a split smaller than batch_size would produce an empty loader.
    val_loader = DataLoader(
        splits['test'],
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers,
        collate_fn=_collate_input_target,
    )

    return train_loader, val_loader


def create_dataloaders(text_data, train_ratio, batch_size, max_length, stride, num_workers=0):
    """Split ``text_data`` at ``train_ratio`` and build train/validation loaders.

    The first ``int(train_ratio * len(text_data))`` elements feed the training
    loader (shuffled, last partial batch dropped). The remainder feeds the
    validation loader (unshuffled, every sample kept).

    Returns:
        ``(train_loader, val_loader)`` as produced by ``create_dataloader_v1``.
    """
    cut = int(train_ratio * len(text_data))
    shared = dict(
        batch_size=batch_size,
        max_length=max_length,
        stride=stride,
        num_workers=num_workers,
    )
    train_loader = create_dataloader_v1(text_data[:cut], drop_last=True, shuffle=True, **shared)
    val_loader = create_dataloader_v1(text_data[cut:], drop_last=False, shuffle=False, **shared)
    return train_loader, val_loader


# Build the train/validation loaders from `text_data`: the first 90% goes to
# training, the rest to validation. stride equals max_length, so consecutive
# windows do not overlap.
train_loader, val_loader = create_dataloaders(
    text_data,
    batch_size=args.batch_size,
    train_ratio=0.9,
    num_workers=0,
    max_length=GPT_CONFIG_124M["context_length"],
    stride=GPT_CONFIG_124M["context_length"],
)

In [None]:
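# Extract input_ids and target_ids from the sample row `a` below, whose single value holds both id lists as a CSV-style pair of quoted JSON arrays.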
a={'input_ids,target_ids': '"[239, 552, 875, 904, 630, 2809, 13407, 6327, 3525, 38388, 4834, 283, 15880, 1227, 4385, 6106, 410, 37792, 12186, 170, 251, 630, 3981, 745, 12622, 22082, 6478, 875, 904, 38388, 4834, 283, 15880, 1227, 750, 834, 3345, 14118, 1017, 3656, 26349, 410, 1076, 170, 251, 2162, 800, 343, 630, 2809, 13407, 6327, 3525, 38388, 4834, 283, 4743, 2852, 5708, 251, 630, 2809, 43214, 6286, 13407, 6327, 3525, 258, 38388, 4834, 283, 6131, 675, 966, 675, 19002, 937, 15880, 5949, 952, 38388, 18818, 4743, 1962, 597, 684, 251, 1489, 9991, 207, 8484, 935, 479, 11787, 18767, 3862, 6131, 675, 966, 675, 19002, 937, 411, 92, 13623, 7945, 750, 14970, 43214, 2186, 3345, 1076, 170, 251, 25266, 328, 3110, 2022, 10835, 1366, 2549, 10090, 4719, 2036, 339, 251, 834, 7823, 750, 17438, 13763, 410, 728, 251, 5175, 20802, 201, 283, 5288, 1188, 43214, 2186, 5396, 487, 271, 3548, 45567, 19076, 843, 869, 2981, 592, 1366, 13514, 152, 8799, 592, 834, 192, 41096, 639, 750, 3672, 339, 251, 11989, 7771, 18916, 2656, 10676, 1017, 2449, 2968, 542, 1418, 2809, 92, 572, 3690, 11207, 22260, 2449, 1624, 1152, 43214, 3182, 2745, 931, 750, 1366, 5761, 1241, 102, 592, 339, 251, 630, 3981, 745, 12622, 22082, 31017, 5396, 487, 3131, 2959, 2282, 701, 362, 674, 8569, 391, 11460, 8843, 4612, 5223, 350, 549, 251, 43214, 3182, 3221, 18838, 606, 862, 25462, 8907, 21533, 251, 4512, 3442, 11460, 6260, 458, 3221, 7500, 393, 4156, 7057, 1366, 17934, 549, 251, 5347, 46244, 305, 12068, 36198, 19878, 3364, 12521, 336, 10516, 327, 4302, 17037, 608, 4881, 251, 455, 1326, 7808, 283, 665, 6484, 19878, 14661, 5463, 5093, 251, 1156, 350, 30174, 3340, 1416, 3840, 1366, 5396, 20304, 15884, 8907, 44895, 35337, 667, 170, 696, 4719, 350, 549, 251, 6677, 21701, 6467, 664, 38388, 18818, 6131, 675, 966, 675, 19002, 937, 9300, 2144, 9300, 521, 13238, 1011, 13788, 2179, 5830, 251, 838, 188, 241, 248, 434, 248, 270, 239, 24471, 1500, 8246, 5083, 4805, 1033, 1600, 20447, 8308, 736, 9860, 13480, 1782, 152, 2181, 35280, 2904, 
1416, 4068, 8277, 21034, 346, 2133, 529, 251, 2846, 6949, 11961, 3174, 592, 4805, 1033, 1512, 12705, 27441, 36746, 43962, 42202, 283, 2133, 529, 251, 5646, 834, 1744, 4554, 350, 423, 251, 1600, 58, 17885, 1416, 4068, 4540, 1584, 1500, 38216, 1744, 7213, 319, 2634, 522, 529, 251, 20780, 19047, 2195, 3900, 10113, 46151, 3302, 1260, 24393, 96, 32095, 18109, 4540, 715, 42989, 20157, 20157, 863, 12780, 368, 529, 251, 1250, 5307, 2890, 350, 7380, 483, 159, 241, 248, 434, 248, 270, 10377, 2981, 600, 3654, 26799, 18919, 1527, 2597, 10881, 2927, 167, 4068, 11119, 483, 14537, 4510, 3798, 4619, 350, 170, 251, 1680, 2490, 13688, 7365, 404, 251, 952, 42233, 22972, 597, 20108, 975, 3235, 20663, 2903, 684, 251, 13688, 7365, 11617, 2597, 5265, 42233, 22972, 597, 20108, 25968, 2597, 4068, 11119, 1213, 2871, 553, 15550, 10881, 2927, 167, 5214, 1393, 4510, 3798, 1929, 1460, 20108, 18504, 4619, 382, 1500, 12434, 828, 170, 251, 3434, 1193, 283, 7899, 3280, 17704, 20043, 14406, 1393, 2597, 96, 971, 2708, 6161, 639, 4577, 390, 6738, 339, 251]","[552, 875, 904, 630, 2809, 13407, 6327, 3525, 38388, 4834, 283, 15880, 1227, 4385, 6106, 410, 37792, 12186, 170, 251, 630, 3981, 745, 12622, 22082, 6478, 875, 904, 38388, 4834, 283, 15880, 1227, 750, 834, 3345, 14118, 1017, 3656, 26349, 410, 1076, 170, 251, 2162, 800, 343, 630, 2809, 13407, 6327, 3525, 38388, 4834, 283, 4743, 2852, 5708, 251, 630, 2809, 43214, 6286, 13407, 6327, 3525, 258, 38388, 4834, 283, 6131, 675, 966, 675, 19002, 937, 15880, 5949, 952, 38388, 18818, 4743, 1962, 597, 684, 251, 1489, 9991, 207, 8484, 935, 479, 11787, 18767, 3862, 6131, 675, 966, 675, 19002, 937, 411, 92, 13623, 7945, 750, 14970, 43214, 2186, 3345, 1076, 170, 251, 25266, 328, 3110, 2022, 10835, 1366, 2549, 10090, 4719, 2036, 339, 251, 834, 7823, 750, 17438, 13763, 410, 728, 251, 5175, 20802, 201, 283, 5288, 1188, 43214, 2186, 5396, 487, 271, 3548, 45567, 19076, 843, 869, 2981, 592, 1366, 13514, 152, 8799, 592, 834, 192, 41096, 639, 750, 3672, 339, 251, 11989, 
7771, 18916, 2656, 10676, 1017, 2449, 2968, 542, 1418, 2809, 92, 572, 3690, 11207, 22260, 2449, 1624, 1152, 43214, 3182, 2745, 931, 750, 1366, 5761, 1241, 102, 592, 339, 251, 630, 3981, 745, 12622, 22082, 31017, 5396, 487, 3131, 2959, 2282, 701, 362, 674, 8569, 391, 11460, 8843, 4612, 5223, 350, 549, 251, 43214, 3182, 3221, 18838, 606, 862, 25462, 8907, 21533, 251, 4512, 3442, 11460, 6260, 458, 3221, 7500, 393, 4156, 7057, 1366, 17934, 549, 251, 5347, 46244, 305, 12068, 36198, 19878, 3364, 12521, 336, 10516, 327, 4302, 17037, 608, 4881, 251, 455, 1326, 7808, 283, 665, 6484, 19878, 14661, 5463, 5093, 251, 1156, 350, 30174, 3340, 1416, 3840, 1366, 5396, 20304, 15884, 8907, 44895, 35337, 667, 170, 696, 4719, 350, 549, 251, 6677, 21701, 6467, 664, 38388, 18818, 6131, 675, 966, 675, 19002, 937, 9300, 2144, 9300, 521, 13238, 1011, 13788, 2179, 5830, 251, 838, 188, 241, 248, 434, 248, 270, 239, 24471, 1500, 8246, 5083, 4805, 1033, 1600, 20447, 8308, 736, 9860, 13480, 1782, 152, 2181, 35280, 2904, 1416, 4068, 8277, 21034, 346, 2133, 529, 251, 2846, 6949, 11961, 3174, 592, 4805, 1033, 1512, 12705, 27441, 36746, 43962, 42202, 283, 2133, 529, 251, 5646, 834, 1744, 4554, 350, 423, 251, 1600, 58, 17885, 1416, 4068, 4540, 1584, 1500, 38216, 1744, 7213, 319, 2634, 522, 529, 251, 20780, 19047, 2195, 3900, 10113, 46151, 3302, 1260, 24393, 96, 32095, 18109, 4540, 715, 42989, 20157, 20157, 863, 12780, 368, 529, 251, 1250, 5307, 2890, 350, 7380, 483, 159, 241, 248, 434, 248, 270, 10377, 2981, 600, 3654, 26799, 18919, 1527, 2597, 10881, 2927, 167, 4068, 11119, 483, 14537, 4510, 3798, 4619, 350, 170, 251, 1680, 2490, 13688, 7365, 404, 251, 952, 42233, 22972, 597, 20108, 975, 3235, 20663, 2903, 684, 251, 13688, 7365, 11617, 2597, 5265, 42233, 22972, 597, 20108, 25968, 2597, 4068, 11119, 1213, 2871, 553, 15550, 10881, 2927, 167, 5214, 1393, 4510, 3798, 1929, 1460, 20108, 18504, 4619, 382, 1500, 12434, 828, 170, 251, 3434, 1193, 283, 7899, 3280, 17704, 20043, 14406, 1393, 2597, 96, 971, 
2708, 6161, 639, 4577, 390, 6738, 339, 251, 42233]"'}

import csv
import json

# The value of `a` is one CSV row holding two quoted JSON arrays:
#   "[...input ids...]","[...target ids...]"
# csv.reader understands the quoting (and the commas inside each field), so each
# field can then be decoded as JSON. This replaces splitting on '",' and
# stripping quotes by hand.
raw_pair = a['input_ids,target_ids']
input_field, target_field = next(csv.reader([raw_pair]))
input_ids = json.loads(input_field)
target_ids = json.loads(target_field)

# Sanity check: targets should be the inputs shifted left by one token.
assert input_ids[1:] == target_ids[:-1], "target_ids is not input_ids shifted by one"

# Show a summary instead of dumping thousands of token ids.
print(f'input_ids: {type(input_ids)} len={len(input_ids)} head={input_ids[:10]}')
print(f'target_ids: {type(target_ids)} len={len(target_ids)} head={target_ids[:10]}')

input_ids: <class 'list'> [239, 552, 875, 904, 630, 2809, 13407, 6327, 3525, 38388, 4834, 283, 15880, 1227, 4385, 6106, 410, 37792, 12186, 170, 251, 630, 3981, 745, 12622, 22082, 6478, 875, 904, 38388, 4834, 283, 15880, 1227, 750, 834, 3345, 14118, 1017, 3656, 26349, 410, 1076, 170, 251, 2162, 800, 343, 630, 2809, 13407, 6327, 3525, 38388, 4834, 283, 4743, 2852, 5708, 251, 630, 2809, 43214, 6286, 13407, 6327, 3525, 258, 38388, 4834, 283, 6131, 675, 966, 675, 19002, 937, 15880, 5949, 952, 38388, 18818, 4743, 1962, 597, 684, 251, 1489, 9991, 207, 8484, 935, 479, 11787, 18767, 3862, 6131, 675, 966, 675, 19002, 937, 411, 92, 13623, 7945, 750, 14970, 43214, 2186, 3345, 1076, 170, 251, 25266, 328, 3110, 2022, 10835, 1366, 2549, 10090, 4719, 2036, 339, 251, 834, 7823, 750, 17438, 13763, 410, 728, 251, 5175, 20802, 201, 283, 5288, 1188, 43214, 2186, 5396, 487, 271, 3548, 45567, 19076, 843, 869, 2981, 592, 1366, 13514, 152, 8799, 592, 834, 192, 41096, 639, 750, 3672, 339, 251, 11989, 7771, 1891