In [1]:
import tiktoken             # 导入分词工具
import torch
import os
import numpy as np

In [2]:
def load_tokens(filename):
    """Load one pre-processed token shard and return it as a ``torch.long`` tensor.

    Args:
        filename: path to a NumPy array file (e.g. ``.npy``) readable by ``np.load``.

    Returns:
        A ``torch.Tensor`` holding the stored values, cast to ``torch.long`` (int64).
    """
    return torch.tensor(np.load(filename), dtype=torch.long)

In [3]:
class DataLoaderLite:
    """Sequential (x, y) batch loader over pre-tokenized shard files on disk.

    Shard files are the entries of ``data_root`` whose names contain ``split``.
    They are visited in sorted filename order; after the last shard the loader
    wraps around to the first one.

    Args:
        B: number of rows per batch.
        T: tokens per row (sequence length).
        split: ``'train'`` or ``'val'``; shards are selected by substring match
            on the filename.
        data_root: directory holding the shard files (default ``"edu_fineweb10B"``).
        verbose: print shard counts and per-batch positions (default True, which
            matches the original behaviour; pass False inside training loops).
    """

    def __init__(self, B, T, split, data_root="edu_fineweb10B", verbose=True):
        self.B = B
        self.T = T
        self.verbose = verbose
        assert split in {'train', 'val'}

        # Shard files for this split, sorted so they are always read in the same order,
        # stored as full paths (data_root joined with the file name).
        shard_names = sorted(s for s in os.listdir(data_root) if split in s)
        self.shards = [os.path.join(data_root, s) for s in shard_names]
        assert len(self.shards) > 0, f"no shards found for split {split}"

        if self.verbose:
            print(f"found {len(self.shards)} shards for split {split}")
        # reset() initialises current_shard / tokens / current_position, so the loader
        # can serve batches right after construction; calling it again later rewinds
        # to the start of the first shard.
        self.reset()

    def reset(self):
        """Rewind to position 0 of the first shard and load that shard's tokens."""
        self.current_shard = 0
        self.tokens = load_tokens(self.shards[self.current_shard])
        self.current_position = 0
        if self.verbose:
            tokens_per_batch = self.B * self.T
            # Each batch consumes B*T tokens but needs B*T + 1 (for the shifted
            # targets), so the number of full batches is (len - 1) // (B*T).
            full_batches = max(0, (len(self.tokens) - 1) // tokens_per_batch)
            print(f"1 shard tokens:{len(self.tokens)}  batch:{tokens_per_batch}")
            print(f"1 shard = {full_batches} batchs")

    def next_batch(self):
        """Return the next (x, y) batch and advance, moving to the next shard when needed.

        Returns:
            x: ``(B, T)`` tensor of input tokens.
            y: ``(B, T)`` tensor of the same token stream shifted one position
                later, i.e. the next-token target for each entry of ``x``.

        Raises:
            ValueError: if the current shard does not hold ``B*T + 1`` tokens from
                the current position, so no full batch can be cut from it.
        """
        B, T = self.B, self.T
        span = B * T + 1  # one extra token so y can be x shifted by one
        buf = self.tokens[self.current_position : self.current_position + span]
        if len(buf) < span:
            raise ValueError(
                f"shard {self.current_shard} ({self.shards[self.current_shard]}) has only "
                f"{len(buf)} tokens from position {self.current_position}; "
                f"one batch needs B*T+1 = {span}"
            )
        x = buf[:-1].view(B, T)  # inputs
        y = buf[1:].view(B, T)   # targets
        if self.verbose:
            print(f"position:{self.current_position} shard:{self.current_shard}")

        self.current_position += B * T

        # If the next batch would run past the end of this shard, load the next
        # shard; the modulo wraps back to shard 0 after the last one.
        if self.current_position + span > len(self.tokens):
            self.current_shard = (self.current_shard + 1) % len(self.shards)
            self.tokens = load_tokens(self.shards[self.current_shard])
            self.current_position = 0
            if self.verbose:
                print(f"新的碎片shard： {self.current_shard}")
        return x, y


In [7]:
# Smoke test: draw 50 batches of B=3000 rows x T=1024 tokens from the training
# shards. This is long enough to hit a shard rollover whenever a shard holds
# fewer than about 50 * 3000 * 1024 tokens.
train_data = DataLoaderLite(3000, 1024, 'train')
for i in range(50):
    x, y = train_data.next_batch()

found 99 shards for split train
1 shard tokens:100000000  min——batch:3072000
1 shard = 32 batchs
position:0 shard:0
position:3072000 shard:0
position:6144000 shard:0
position:9216000 shard:0
position:12288000 shard:0
position:15360000 shard:0
position:18432000 shard:0
position:21504000 shard:0
position:24576000 shard:0
position:27648000 shard:0
position:30720000 shard:0
position:33792000 shard:0
position:36864000 shard:0
position:39936000 shard:0
position:43008000 shard:0
position:46080000 shard:0
position:49152000 shard:0
position:52224000 shard:0
position:55296000 shard:0
position:58368000 shard:0
position:61440000 shard:0
position:64512000 shard:0
position:67584000 shard:0
position:70656000 shard:0
position:73728000 shard:0
position:76800000 shard:0
position:79872000 shard:0
position:82944000 shard:0
position:86016000 shard:0
position:89088000 shard:0
position:92160000 shard:0
position:95232000 shard:0
新的碎片shard： 1
position:0 shard:1
position:3072000 shard:1
position:6144000 shard:1

In [79]:
class DataLoaderLite_input:
    """Sequential (x, y) batch loader over one text file tokenized with GPT-2 BPE.

    The whole file is read and tokenized up front. Batches are cut from the token
    stream in order, and once the next batch would run past the end, the loader
    starts over at token 0 (a new epoch).

    Args:
        B: rows per batch.
        T: tokens per row (sequence length).
        data_path: text file to read (default ``'input.txt'``), decoded as UTF-8.
        verbose: print token/epoch counts and per-batch positions (default True,
            matching the original behaviour).
    """

    def __init__(self, B, T, data_path='input.txt', verbose=True):
        self.B = B
        self.T = T
        self.verbose = verbose

        # Read the training text. The encoding is explicit so decoding does not
        # depend on the platform's default locale encoding (e.g. GBK / cp1252 on Windows).
        with open(data_path, 'r', encoding='utf-8') as f:
            text = f.read()

        # Tokenize with the GPT-2 BPE vocabulary and keep the ids as an int64 tensor
        # (same dtype that load_tokens produces for the shard loader).
        enc = tiktoken.get_encoding('gpt2')
        self.tokens = torch.tensor(enc.encode(text), dtype=torch.long)

        if self.verbose:
            print(f"loaded {len(self.tokens)} tokens")
            # Each batch needs B*T + 1 tokens, so full batches per epoch = (len - 1) // (B*T).
            print(f"1 epoch = {max(0, (len(self.tokens) - 1) // (B * T))} batchs")

        # Offset into self.tokens where the next batch starts.
        self.current_position = 0

    def next_batch(self):
        """Return the next (x, y) batch and advance, wrapping to token 0 at the end of an epoch.

        Returns:
            x: ``(B, T)`` tensor of input tokens.
            y: ``(B, T)`` tensor of the same token stream shifted one position
                later (next-token targets).

        Raises:
            ValueError: if fewer than ``B*T + 1`` tokens are available from the
                current position (i.e. the text is too short for one batch).
        """
        B, T = self.B, self.T
        span = B * T + 1  # one extra token so y can be x shifted by one
        buf = self.tokens[self.current_position : self.current_position + span]
        if len(buf) < span:
            raise ValueError(
                f"only {len(buf)} tokens available from position {self.current_position}; "
                f"one batch needs B*T+1 = {span}"
            )

        if self.verbose:
            print(f"position : {self.current_position}")
        x = buf[:-1].view(B, T)  # inputs
        y = buf[1:].view(B, T)   # targets

        # The next batch starts right after this one.
        self.current_position += B * T

        # If the next batch would run past the end, start a new epoch from token 0.
        if self.current_position + span > len(self.tokens):
            self.current_position = 0
            if self.verbose:
                print("新的轮回")

        return x, y

In [80]:
# Smoke test: draw 10 batches of B=50 rows x T=1000 tokens from input.txt.
# This is enough to wrap past the end of a small corpus and start a new epoch.
train_data = DataLoaderLite_input(50, 1000)
for i in range(10):
    x, y = train_data.next_batch()

loaded 338025 tokens
1 epoch = 6 batchs
position : 0
position : 50000
position : 100000
position : 150000
position : 200000
position : 250000
新的轮回
position : 0
position : 50000
position : 100000
position : 150000


In [89]:
# Gradient-accumulation setup: one optimizer step should see total_batch_size
# tokens, built up from micro-batches of B rows x T tokens each.
total_batch_size = 2**19  # 524,288 tokens (~0.5M) per optimizer step
B = 16    # micro-batch size; an RTX 3090 only fits 16 (the tutorial can use 64)
T = 1024  # sequence length
assert total_batch_size % (B * T) == 0  # the big batch must be a whole number of micro-batches
grad_accum_steps = total_batch_size // (B * T)  # micro-batches accumulated per optimizer step
print(
    f"total desired batch size: {total_batch_size}\n"
    f">= calculated gradient accumulation steps: {grad_accum_steps}"
)

total desired batch size: 524288
>= calculated gradient accumulation steps: 32
