In [1]:
import torch
import os
import torch.nn as nn
import numpy as np

In [2]:
def load_tokens(filename):
    npt = np.load(filename)
    npt = npt.astype(np.int32)
    ptt = torch.tensor(npt, dtype=torch.long)
    return ptt

In [None]:
# no. of tokens, vector representation of tokens

In [3]:
class data_loader:
    def __init__(self, B, T, split, data_root):
        self.B = B
        self.T = T
        assert split in {'train', 'val'}

        data_root = data_root
        shards = os.listdir(data_root)
        shards = [s for s in shards if split in s]
        shards = sorted(shards)
        shards = [os.path.join(data_root, s) for s in shards]
        self.shards = shards
        if (len(shards) > 0) == False:
            print(f"no shards found for split {split}")
        # assert len(shards) > 0, f"no shards found for split {split}"
        print(f"[green]found {len(shards)} shards for split {split}[/green]")
        self.reset()

    def reset(self):
        # state, init at shard zero
        self.current_shard = 0
        self.tokens = load_tokens(self.shards[self.current_shard])
        self.current_position = self.B * self.T

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.current_position : self.current_position+B*T+1]
        x = (buf[:-1]).view(B, T) # inputs
        y = (buf[1:]).view(B, T) # targets
        # advance the position in the tensor
        self.current_position += B * T
        # if loading the next batch would be out of bounds, advance to next shard
        if self.current_position + (B * T + 1) > len(self.tokens):
            self.current_shard = (self.current_shard + 1) % len(self.shards)
            self.tokens = load_tokens(self.shards[self.current_shard])
            self.current_position = B * T
        return x, y

In [41]:
data = data_loader(B=4096, T=1024, split='train', data_root='data')

[green]found 99 shards for split train[/green]


In [42]:
data_batch = data.next_batch()

In [43]:
data_batch

(tensor([[  329,   262,  1175,  ...,   477,  5582, 33052],
         [  290,  8849,   606,  ...,   262,  5103,   286],
         [  262, 20091,  1671,  ..., 24475,   779,   329],
         ...,
         [30933,    69,   380,  ...,  4967,    13,   785],
         [   14, 16049, 45429,  ..., 48072,   287,   262],
         [ 1366,   284,   670,  ...,   423, 20261,   517]]),
 tensor([[  262,  1175,    13,  ...,  5582, 33052,   290],
         [ 8849,   606,   351,  ...,  5103,   286,   262],
         [20091,  1671,  9116,  ...,   779,   329,   262],
         ...,
         [   69,   380,   408,  ...,    13,   785,    14],
         [16049, 45429,   406,  ...,   287,   262,  1366],
         [  284,   670,  1088,  ..., 20261,   517,   588]]))

In [44]:
data_batch[0].shape

torch.Size([4096, 1024])

In [34]:
len(data_batch)

2