# Token IDs 转 Token embeddings：数据采样
* 为了得到 input-target 对

## 前置

In [1]:
# 使用 tiktoken 库的 BPE 分词器对数据集分词：

import tiktoken

# Load the GPT-2 encoding from tiktoken.
tokenizer = tiktoken.get_encoding("gpt2")

# Read the whole text file into memory as a single string.
file_path = '../../input/the-verdict.txt'
with open(file_path, mode='r', encoding='utf-8') as file:
    raw_text = file.read()

# Encode the text into token IDs and report how many tokens were produced.
enc_text = tokenizer.encode(raw_text)
token_count = len(enc_text)
print(token_count)

5145


## 示例 - 用滑动窗口生成输入-目标对

In [2]:
context_size = 4
# x is the first context_size token IDs; y is the same window shifted right by one.
x = enc_text[:context_size]
y = enc_text[1:context_size+1]

# Show each growing context next to the token ID that follows it.
# The variable is called `context` rather than `input` so the built-in
# input() is not shadowed for the rest of the kernel session.
for i in range(1, context_size+1):
    context = enc_text[:i]
    target = enc_text[i]
    print(context, "---->", target)

# Same pairs again, decoded back to text.
for i in range(1, context_size+1):
    context = enc_text[:i]
    target = enc_text[i]
    print(tokenizer.decode(context), "---->", tokenizer.decode([target]))


[40] ----> 367
[40, 367] ----> 2885
[40, 367, 2885] ----> 1464
[40, 367, 2885, 1464] ----> 1807
I ---->  H
I H ----> AD
I HAD ---->  always
I HAD always ---->  thought


### 引入张量
使用 PyTorch 张量可以高效地进行批处理。

In [3]:
# 用于批处理输入和目标的数据集：

import torch
from torch.utils.data import Dataset, DataLoader

# Pytorch的Dataset的子类
class GPTDatasetV1(Dataset):
    """Sliding-window dataset of (input, target) token-ID tensors.

    The text is tokenized once. A window of ``max_length`` tokens is then moved
    over the token sequence in steps of ``stride``; each window is stored as an
    input tensor and the same window shifted one token to the right is stored
    as its target tensor.

    Args:
        txt: Text passed to ``tokenizer.encode``.
        tokenizer: Object whose ``encode(txt)`` returns a sliceable sequence
            of token IDs.
        max_length: Number of tokens per input/target chunk; must be >= 1.
        stride: Number of tokens the window advances between chunks; must be >= 1.

    Raises:
        ValueError: If ``max_length`` or ``stride`` is smaller than 1.
    """

    def __init__(self, txt, tokenizer, max_length, stride):
        # Reject window settings that would otherwise silently produce empty
        # chunks, mis-sliced chunks, or an empty dataset.
        if max_length < 1:
            raise ValueError(f"max_length must be >= 1, got {max_length}")
        if stride < 1:
            raise ValueError(f"stride must be >= 1, got {stride}")

        self.tokenizer = tokenizer
        self.input_ids = []
        self.target_ids = []

        token_ids = tokenizer.encode(txt)
        # Stop max_length tokens before the end so that every target chunk
        # (the input window shifted by one) is still max_length tokens long.
        for start in range(0, len(token_ids) - max_length, stride):
            end = start + max_length
            self.input_ids.append(torch.tensor(token_ids[start:end]))
            self.target_ids.append(torch.tensor(token_ids[start + 1:end + 1]))

    def __len__(self):
        """Return the number of stored (input, target) pairs."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` tensor pair stored at position ``idx``."""
        return self.input_ids[idx], self.target_ids[idx]

In [4]:
# 用于生成input-target的批次数据加载器：

def create_dataloader_v1(txt, bath_size = 4,
        max_length = 256, stride = 128, shuffle=True, drop_last=True,
        *, batch_size=None):
    """Build a DataLoader of sliding-window (input, target) batches from text.

    Args:
        txt: Text to tokenize; it is encoded with tiktoken's "gpt2" encoding.
        bath_size: Number of samples per batch. Misspelled legacy name, kept
            so existing positional and keyword callers continue to work.
        max_length: Tokens per sample, forwarded to GPTDatasetV1.
        stride: Window step between samples, forwarded to GPTDatasetV1.
        shuffle: Forwarded to DataLoader.
        drop_last: Forwarded to DataLoader.
        batch_size: Correctly spelled keyword-only alias for ``bath_size``.
            When given (not None) it takes precedence over ``bath_size``.

    Returns:
        A ``torch.utils.data.DataLoader`` wrapping a ``GPTDatasetV1``.
    """
    # Prefer the correctly spelled argument when the caller supplied it.
    effective_batch_size = bath_size if batch_size is None else batch_size

    tokenizer = tiktoken.get_encoding("gpt2")
    dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)
    dataloader = DataLoader(
        dataset, batch_size=effective_batch_size, shuffle=shuffle,
        drop_last=drop_last)
    return dataloader

### 测试

In [5]:
# Smoke-test the data loader with a context size (max_length) of 4 and a batch size of 1.

file_path = '../../input/the-verdict.txt'
# Keep the file open only while reading it; everything below works on the
# in-memory string, so the handle does not need to stay open.
with open(file_path, 'r', encoding="utf-8") as file:
    raw_text = file.read()

dataloader = create_dataloader_v1(
    raw_text, bath_size=1, max_length=4, stride=1, shuffle=False)
data_iter = iter(dataloader)
first_batch = next(data_iter)
print(first_batch)

[tensor([[  40,  367, 2885, 1464]]), tensor([[ 367, 2885, 1464, 1807]])]


## 练习 2.2：用不同步长和上下文大小的数据加载器加载数据

### 解决

In [6]:
# Exercise: windows of 2 tokens, advancing 2 tokens per sample (max_length=2, stride=2).

dataloader = create_dataloader_v1(
    raw_text,
    bath_size=1,
    max_length=2,
    stride=2,
    shuffle=False,
)
# Pull only the first batch to inspect its contents.
data_iter = iter(dataloader)
first_batch = next(data_iter)
print(first_batch)

[tensor([[ 40, 367]]), tensor([[ 367, 2885]])]


In [7]:
# Exercise: windows of 8 tokens, advancing 2 tokens per sample (max_length=8, stride=2).

dataloader = create_dataloader_v1(
    raw_text,
    bath_size=1,
    max_length=8,
    stride=2,
    shuffle=False,
)
# Pull only the first batch to inspect its contents.
data_iter = iter(dataloader)
first_batch = next(data_iter)
print(first_batch)

[tensor([[  40,  367, 2885, 1464, 1807, 3619,  402,  271]]), tensor([[  367,  2885,  1464,  1807,  3619,   402,   271, 10899]])]


看看batch_size > 1的tensor输出

In [8]:
# Batches of 8 samples; stride equals max_length, so consecutive input windows do not overlap.
dataloader = create_dataloader_v1(
    raw_text,
    bath_size=8,
    max_length=4,
    stride=4,
    shuffle=False,
)
data_iter = iter(dataloader)

# Each batch is an (inputs, targets) pair; unpack the first one and show both.
inputs, targets = next(data_iter)
print("Input:\n", inputs)
print("Targets:\n", targets)

Input:
 tensor([[   40,   367,  2885,  1464],
        [ 1807,  3619,   402,   271],
        [10899,  2138,   257,  7026],
        [15632,   438,  2016,   257],
        [  922,  5891,  1576,   438],
        [  568,   340,   373,   645],
        [ 1049,  5975,   284,   502],
        [  284,  3285,   326,    11]])
Targets:
 tensor([[  367,  2885,  1464,  1807],
        [ 3619,   402,   271, 10899],
        [ 2138,   257,  7026, 15632],
        [  438,  2016,   257,   922],
        [ 5891,  1576,   438,   568],
        [  340,   373,   645,  1049],
        [ 5975,   284,   502,   284],
        [ 3285,   326,    11,   287]])
