In [40]:
import torch
from torch.utils.data import Dataset, DataLoader

class TextDataset(Dataset):
    def __init__(self, text, chunk_size, chunk_overlap):
        self.text = text
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

        # Create overlapping sequences
        self.sequences = [text[i:i + chunk_size] for i in range(0, len(text) - chunk_size + 1, chunk_size - chunk_overlap)]

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        # Convert characters to indices
        sequence_indices = [char_to_index[char] for char in sequence]

        # Convert to PyTorch tensor
        sequence_tensor = torch.tensor(sequence_indices, dtype=torch.long)

        return sequence_tensor

In [47]:
# Example usage
with open('file.c', 'r') as f:
    text = f.read()
char_to_index = {char: idx for idx, char in enumerate(set(text))}
index_to_char = {idx: char for char, idx in char_to_index.items()}

chunk_size = 100
chunk_overlap = 50

dataset = TextDataset(text, chunk_size, chunk_overlap)
dataloader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)

print(dir(dataset))
print(dataset.sequences)
print("\n" * 50)
for batch in dataloader:
    print("Batch:\n")
    for sequence in batch:
        decoded_sequence = ''.join([index_to_char[idx.item()] for idx in sequence])
        print(decoded_sequence)
    print("-" * 20)

['__add__', '__annotations__', '__class__', '__class_getitem__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__orig_bases__', '__parameters__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', '__weakref__', '_is_protocol', 'chunk_overlap', 'chunk_size', 'sequences', 'text']
['#include "std_testcase.h"\n\n#include <wchar.h>\n\n#ifdef _WIN32\n#define COMMAND_INT_PATH L"%WINDIR%\\\\sy', 'def _WIN32\n#define COMMAND_INT_PATH L"%WINDIR%\\\\system32\\\\cmd.exe"\n#define COMMAND_INT L"cmd.exe"\n#d', 'stem32\\\\cmd.exe"\n#define COMMAND_INT L"cmd.exe"\n#define COMMAND_ARG1 L"/c"\n#define COMMAND_ARG2 L"di', 'efine COMMAND_ARG1 L"/c"\n#define COMMAND_ARG2 L"dir "\n#define COMMAND_ARG3 data\n#else /* NOT _WIN32 ', 'r "\n

In [48]:
len(dataset)

111

In [50]:
dataset.sequences[50], dataset.sequences[51]

('t sockaddr*)&service, sizeof(service)) == SOCKET_ERROR)\n            {\n                break;\n       ',
 'RROR)\n            {\n                break;\n            }\n            if (listen(listenSocket, LISTEN')