In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import tokenizers
from multi_tokenizer import MultiTokenizer, partition_list_by_lang
from gpt import GPT, GPTConfig
from tqdm import tqdm
import os
import math

assert torch.cuda.is_available()
#torch.manual_seed(1337)
#if torch.cuda.is_available():
    #torch.cuda.manual_seed(1337)

device= 'cuda' if torch.cuda.is_available() else 'cpu'
assert device == 'cuda', "USE A GPU"
torch.set_default_device(device)

In [2]:
list(partition_list_by_lang([0,1,323,323,1,0,3232,2]))

[(1, [323, 323]), (0, [3232, 2])]

In [21]:
# Settings
run = "shared/8k" # multi_8k
run = "multi_8k"

# Data Loading

## Load Tokenizers

In [22]:

EXPT_NAME = 'expt_1'

shared_tokenizer_small_path = f'./{EXPT_NAME}/shared/8k/tokenizer.json'
multi_tokenizer_english = f'./{EXPT_NAME}/multi_8k/english_tokenizer.json'
multi_tokenizer_spanish = f'./{EXPT_NAME}/multi_8k/spanish_tokenizer.json'

shared_tokenizer_small = tokenizers.Tokenizer.from_file(shared_tokenizer_small_path)
multi_tokenizer_english = tokenizers.Tokenizer.from_file(multi_tokenizer_english)
multi_tokenizer_spanish = tokenizers.Tokenizer.from_file(multi_tokenizer_spanish)

multi_tokenizer = MultiTokenizer([multi_tokenizer_english, multi_tokenizer_spanish])

In [23]:
encoded = multi_tokenizer.encode("[BEGIN_EN]Hello, I like to[BEGIN_ES] comer arroz")
decoded = multi_tokenizer.decode(encoded)
decoded

shared_tokenizer_small.encode("Hello, I like to comer arroz").ids


[2127, 17, 409, 689, 281, 911, 6015]

In [24]:
if run == "multi_8k":
    tokenizer = multi_tokenizer
    vocab_size = 8192
elif run == "shared/8k":
    tokenizer = shared_tokenizer_small
    vocab_size = 8192
else:
    raise ValueError(f"Unknown run: {run}")

In [25]:
config = GPTConfig(
    block_size = 256,
    n_layer = 8,
    n_embd=256,
    n_head=8,
    use_multitokenizer=True,
    vocab_size = vocab_size
)


## Load Dataset

In [26]:
if run == "multi_8k":
    train = torch.load(f'./{EXPT_NAME}/multi_8k/train.pt')
    english_test = torch.load(f'./{EXPT_NAME}/multi_8k/english_test.pt')
    spanish_test = torch.load(f'./{EXPT_NAME}/multi_8k/spanish_test.pt')
    translation_test = torch.load(f'./{EXPT_NAME}/multi_8k/translation_test.pt')
elif run == "shared/8k":
    tokenizer = shared_tokenizer_small
    train = torch.load(f'./{EXPT_NAME}/shared/8k/train.pt')
    english_test = torch.load(f'./{EXPT_NAME}/shared/8k/english_test.pt')
    spanish_test = torch.load(f'./{EXPT_NAME}/shared/8k/spanish_test.pt')
    translation_test = torch.load(f'./{EXPT_NAME}/shared/8k/translation_test.pt')

# shuffle train
train = train[torch.randperm(train.size(0))]

In [29]:
tokenizer.decode(train[0].tolist())

'[BEGIN_ES]. ¡Los dos novios querían jugar con el perro!\n\nEl perro no sabía qué hacer. El perro quería jugar con sus novios. Pero no podía jugar con ambos a la vez.\n\nEl perro se puso triste. El perro no sabía qué hacer. El perro estaba muy, muy triste.\n\nDe repente, el perro tuvo una idea. El perro jugó con un novio durante un rato. Luego jugó con el otro novio. ¡Todos estaban felices![END][BEGIN_ES][PROMPT_ES_STORY]Un niño pequeño estaba jugando en el jardín. Vio algo verde y largo en la hierba. "¡Serpiente!" gritó. La serpiente era rápida. Se arrastró hacia un árbol. El niño se asustó. Corrió a decirle a su mamá.\n\nSu mamá fue al jardín. "¡No hay serpiente!" dijo. El niño miró. ¡La serpiente ya no estaba! El niño estaba triste. Luego, vio algo pequeño y amarillo. "¡Pollito!" gritó.\n\nEl pequeño pollito se movía rápido. Se metió debajo del árbol. El niño se rió. "¡Serpiente!" gritó de nuevo. El pollito salió corriendo. ¡El niño se rió mucho! \n\nSu mamá se rió también. "¡No es 

In [30]:
def load_tokens(filename):    
    return torch.load(filename).to(dtype=torch.long)

class DataLoaderLite:
    def __init__(self, data_dir, B, T, split, shuffle):
        self.B, self.T, self.shuffle = B, T, shuffle                
        
        shards = os.listdir(data_dir)
        shards = [s for s in shards if split in s]
        shards = sorted(shards)
        shards = [os.path.join(data_dir, s) for s in shards]
        self.shards = shards
        assert len(shards) > 0, f"no shards found for split {split}"
    
        print(f"found {len(shards)} shards for split {split}")
        
        self.current_shard, self.current_position = -1, 0
        self.reset()

    def reset(self):
        self.current_shard, self.current_position = (self.current_shard + 1) % len(self.shards), 0
        self.tokens = load_tokens(self.shards[self.current_shard])
        if self.shuffle:
            assert False, 'not implemented'
            start = time.time()
            self.shuffle_tokens()
            print(f"shuffled {self.tokens.shape[0]} tokens in {time.time() - start:.1f}s")

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.current_position : self.current_position+B*T+1]
        x = (buf[:-1]).view(B, T) # inputs
        y = (buf[1:]).view(B, T) # targets
        # advance the position in the tensor
        self.current_position += B * T 
        # if loading the next batch would be out of bounds, advance to next shard
        if self.current_position + (B * T + 1) > len(self.tokens):
            self.reset()
        return x, y

# Train

In [31]:

torch.set_float32_matmul_precision(config.matmul_precision)
optimizer = model.configure_optimizers(weight_decay=config.weight_decay, learning_rate=config.max_lr, beta1=config.beta_1, beta2=config.beta_2, device_type="cuda")

loss_accum = []
train = train.to(torch.int64)
def do_epoch():
    batch_stride = config.batch_size
    for current_step, batch_start_idx in enumerate(tqdm(range(0, train.shape[0]-batch_stride, batch_stride))):
        batch = train[batch_start_idx:batch_start_idx+batch_stride]
        inputs = batch[:, 0:config.block_size].contiguous()
        targets = batch[:, 1:config.block_size+1].contiguous()
        model(inputs, targets)
        model.train()
        optimizer.zero_grad()
        with torch.autocast(dtype=torch.bfloat16, device_type="cuda"):
            logits, loss = model(inputs, targets)
        loss_accum.append(loss.detach())
        loss.backward()
        optimizer.step()
        if current_step % 100 ==0:
            print(f"step: {current_step}, loss: {loss.item()}")

do_epoch()


# current_step = 0  # Initialize the current step counter

# for inputs, targets in train_dataloader:
#     model.train()
#     optimizer.zero_grad()
#     inputs, targets = inputs.to(device), targets.to(device)
#     with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
#         logits, loss = model(inputs, targets)
#     loss = loss
#     loss_accum.append(loss.detach())
#     loss.backward()
#     norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
#     optimizer.step()
    
#     current_step += 1
#     if math.log2(step) % 1 == 0:
#         print(f"Step: {current_step}, Loss: {loss.item()}")


num decayed parameter tensors: 34, with 6,356,992 parameters
num non-decayed parameter tensors: 66, with 23,040 parameters
using fused AdamW: True


  0%|          | 0/12760 [00:00<?, ?it/s]

  0%|          | 3/12760 [00:00<09:25, 22.56it/s]

step: 0, loss: 10.382516860961914


  1%|          | 105/12760 [00:04<09:21, 22.55it/s]

step: 100, loss: 3.169614315032959


  2%|▏         | 204/12760 [00:09<09:17, 22.51it/s]

step: 200, loss: 2.6747334003448486


  2%|▏         | 303/12760 [00:13<09:14, 22.48it/s]

step: 300, loss: 2.557680130004883


  3%|▎         | 405/12760 [00:17<09:10, 22.45it/s]

step: 400, loss: 2.34751296043396


  4%|▍         | 504/12760 [00:22<09:05, 22.48it/s]

step: 500, loss: 2.305171489715576


  5%|▍         | 603/12760 [00:26<09:01, 22.46it/s]

step: 600, loss: 2.1953837871551514


  6%|▌         | 705/12760 [00:31<08:58, 22.38it/s]

step: 700, loss: 2.1846060752868652


  6%|▋         | 804/12760 [00:35<08:51, 22.50it/s]

step: 800, loss: 2.0937535762786865


  7%|▋         | 903/12760 [00:40<08:47, 22.48it/s]

step: 900, loss: 2.0739471912384033


  8%|▊         | 1005/12760 [00:44<08:43, 22.46it/s]

step: 1000, loss: 2.0186572074890137


  9%|▊         | 1104/12760 [00:49<08:38, 22.47it/s]

step: 1100, loss: 2.044660806655884


  9%|▉         | 1203/12760 [00:53<08:34, 22.48it/s]

step: 1200, loss: 1.9710441827774048


 10%|█         | 1305/12760 [00:58<08:30, 22.45it/s]

step: 1300, loss: 2.02921462059021


 11%|█         | 1404/12760 [01:02<08:25, 22.48it/s]

step: 1400, loss: 1.9879050254821777


 12%|█▏        | 1503/12760 [01:06<08:26, 22.24it/s]

step: 1500, loss: 1.929155707359314


 13%|█▎        | 1605/12760 [01:11<08:21, 22.23it/s]

step: 1600, loss: 1.9586845636367798


 13%|█▎        | 1704/12760 [01:15<08:16, 22.26it/s]

step: 1700, loss: 1.9999128580093384


 14%|█▍        | 1803/12760 [01:20<08:10, 22.36it/s]

step: 1800, loss: 1.913053035736084


 15%|█▍        | 1905/12760 [01:24<08:03, 22.45it/s]

step: 1900, loss: 1.9616193771362305


 16%|█▌        | 2004/12760 [01:29<07:58, 22.47it/s]

step: 2000, loss: 2.0052990913391113


 16%|█▋        | 2103/12760 [01:33<07:54, 22.45it/s]

step: 2100, loss: 1.9287110567092896


 17%|█▋        | 2205/12760 [01:38<07:49, 22.48it/s]

step: 2200, loss: 1.8738030195236206


 18%|█▊        | 2304/12760 [01:42<07:46, 22.41it/s]

step: 2300, loss: 1.9159350395202637


 19%|█▉        | 2403/12760 [01:47<07:41, 22.45it/s]

step: 2400, loss: 1.9148452281951904


 20%|█▉        | 2505/12760 [01:51<07:26, 22.97it/s]

step: 2500, loss: 1.9300572872161865


 20%|██        | 2604/12760 [01:55<07:31, 22.49it/s]

step: 2600, loss: 1.8535243272781372


 21%|██        | 2703/12760 [02:00<07:21, 22.79it/s]

step: 2700, loss: 1.8744609355926514


 22%|██▏       | 2805/12760 [02:04<07:23, 22.46it/s]

step: 2800, loss: 1.87278413772583


 23%|██▎       | 2904/12760 [02:09<07:20, 22.40it/s]

step: 2900, loss: 1.8617777824401855


 24%|██▎       | 3003/12760 [02:13<07:14, 22.45it/s]

step: 3000, loss: 1.8946609497070312


 24%|██▍       | 3105/12760 [02:18<07:09, 22.50it/s]

step: 3100, loss: 1.8772170543670654


 25%|██▌       | 3204/12760 [02:22<07:05, 22.45it/s]

step: 3200, loss: 1.812859296798706


 26%|██▌       | 3303/12760 [02:27<07:01, 22.46it/s]

step: 3300, loss: 1.8554294109344482


 27%|██▋       | 3405/12760 [02:31<06:59, 22.32it/s]

step: 3400, loss: 1.848440170288086


 27%|██▋       | 3504/12760 [02:36<06:51, 22.47it/s]

step: 3500, loss: 1.91140878200531


 28%|██▊       | 3603/12760 [02:40<06:37, 23.04it/s]

step: 3600, loss: 1.8313504457473755


 29%|██▉       | 3705/12760 [02:44<06:31, 23.14it/s]

step: 3700, loss: 1.8310989141464233


 30%|██▉       | 3804/12760 [02:49<06:33, 22.76it/s]

step: 3800, loss: 1.8323014974594116


 31%|███       | 3903/12760 [02:53<06:32, 22.57it/s]

step: 3900, loss: 1.8190158605575562


 31%|███▏      | 4005/12760 [02:58<06:22, 22.86it/s]

step: 4000, loss: 1.797122597694397


 32%|███▏      | 4104/12760 [03:02<06:16, 23.00it/s]

step: 4100, loss: 1.8303736448287964


 33%|███▎      | 4203/12760 [03:06<06:23, 22.33it/s]

step: 4200, loss: 1.7949304580688477


 34%|███▎      | 4305/12760 [03:11<06:16, 22.46it/s]

step: 4300, loss: 1.8708149194717407


 35%|███▍      | 4404/12760 [03:15<06:13, 22.36it/s]

step: 4400, loss: 1.7947579622268677


 35%|███▌      | 4503/12760 [03:20<06:10, 22.28it/s]

step: 4500, loss: 1.8326525688171387


 36%|███▌      | 4605/12760 [03:24<06:06, 22.26it/s]

step: 4600, loss: 1.8497891426086426


 37%|███▋      | 4704/12760 [03:29<06:01, 22.28it/s]

step: 4700, loss: 1.790736198425293


 38%|███▊      | 4803/12760 [03:33<05:56, 22.34it/s]

step: 4800, loss: 1.8082343339920044


 38%|███▊      | 4905/12760 [03:38<05:49, 22.49it/s]

step: 4900, loss: 1.8204208612442017


 39%|███▉      | 5004/12760 [03:42<05:44, 22.48it/s]

step: 5000, loss: 1.8061001300811768


 40%|███▉      | 5103/12760 [03:46<05:40, 22.47it/s]

step: 5100, loss: 1.811509370803833


 41%|████      | 5205/12760 [03:51<05:35, 22.49it/s]

step: 5200, loss: 1.8021821975708008


 42%|████▏     | 5304/12760 [03:55<05:31, 22.47it/s]

step: 5300, loss: 1.7978709936141968


 42%|████▏     | 5403/12760 [04:00<05:28, 22.43it/s]

step: 5400, loss: 1.817740559577942


 43%|████▎     | 5505/12760 [04:04<05:22, 22.51it/s]

step: 5500, loss: 1.8240822553634644


 44%|████▍     | 5604/12760 [04:09<05:19, 22.42it/s]

step: 5600, loss: 1.8071706295013428


 45%|████▍     | 5703/12760 [04:13<05:13, 22.48it/s]

step: 5700, loss: 1.770641565322876


 45%|████▌     | 5805/12760 [04:18<05:09, 22.49it/s]

step: 5800, loss: 1.7620537281036377


 46%|████▋     | 5904/12760 [04:22<05:04, 22.49it/s]

step: 5900, loss: 1.8380358219146729


 47%|████▋     | 6003/12760 [04:26<05:00, 22.48it/s]

step: 6000, loss: 1.7759628295898438


 48%|████▊     | 6105/12760 [04:31<04:55, 22.52it/s]

step: 6100, loss: 1.7766577005386353


 49%|████▊     | 6204/12760 [04:35<04:51, 22.48it/s]

step: 6200, loss: 1.737634301185608


 49%|████▉     | 6303/12760 [04:40<04:47, 22.49it/s]

step: 6300, loss: 1.787656545639038


 50%|█████     | 6405/12760 [04:44<04:42, 22.49it/s]

step: 6400, loss: 1.723616361618042


 51%|█████     | 6504/12760 [04:49<04:37, 22.51it/s]

step: 6500, loss: 1.8204282522201538


 52%|█████▏    | 6603/12760 [04:53<04:33, 22.48it/s]

step: 6600, loss: 1.7556533813476562


 53%|█████▎    | 6705/12760 [04:58<04:29, 22.50it/s]

step: 6700, loss: 1.7448161840438843


 53%|█████▎    | 6804/12760 [05:02<04:24, 22.50it/s]

step: 6800, loss: 1.8162639141082764


 54%|█████▍    | 6903/12760 [05:06<04:20, 22.49it/s]

step: 6900, loss: 1.82279634475708


 55%|█████▍    | 7005/12760 [05:11<04:15, 22.49it/s]

step: 7000, loss: 1.8026946783065796


 56%|█████▌    | 7104/12760 [05:15<04:11, 22.48it/s]

step: 7100, loss: 1.7875865697860718


 56%|█████▋    | 7203/12760 [05:20<04:08, 22.38it/s]

step: 7200, loss: 1.803047776222229


 57%|█████▋    | 7305/12760 [05:24<04:02, 22.50it/s]

step: 7300, loss: 1.7834811210632324


 58%|█████▊    | 7404/12760 [05:29<03:57, 22.51it/s]

step: 7400, loss: 1.7949702739715576


 59%|█████▉    | 7503/12760 [05:33<03:53, 22.47it/s]

step: 7500, loss: 1.757644534111023


 60%|█████▉    | 7605/12760 [05:38<03:49, 22.46it/s]

step: 7600, loss: 1.7511390447616577


 60%|██████    | 7704/12760 [05:42<03:44, 22.50it/s]

step: 7700, loss: 1.8103666305541992


 61%|██████    | 7803/12760 [05:46<03:35, 22.97it/s]

step: 7800, loss: 1.7666041851043701


 62%|██████▏   | 7905/12760 [05:51<03:35, 22.52it/s]

step: 7900, loss: 1.783050537109375


 63%|██████▎   | 8004/12760 [05:55<03:31, 22.46it/s]

step: 8000, loss: 1.7265985012054443


 64%|██████▎   | 8103/12760 [06:00<03:27, 22.47it/s]

step: 8100, loss: 1.743269681930542


 64%|██████▍   | 8205/12760 [06:04<03:22, 22.49it/s]

step: 8200, loss: 1.814278244972229


 65%|██████▌   | 8304/12760 [06:09<03:18, 22.48it/s]

step: 8300, loss: 1.764228343963623


 66%|██████▌   | 8403/12760 [06:13<03:13, 22.50it/s]

step: 8400, loss: 1.7889931201934814


 67%|██████▋   | 8505/12760 [06:18<03:09, 22.45it/s]

step: 8500, loss: 1.766266107559204


 67%|██████▋   | 8604/12760 [06:22<03:04, 22.47it/s]

step: 8600, loss: 1.7856825590133667


 68%|██████▊   | 8703/12760 [06:26<03:00, 22.53it/s]

step: 8700, loss: 1.7396173477172852


 69%|██████▉   | 8805/12760 [06:31<02:56, 22.36it/s]

step: 8800, loss: 1.8033188581466675


 70%|██████▉   | 8904/12760 [06:35<02:53, 22.27it/s]

step: 8900, loss: 1.7866849899291992


 71%|███████   | 9003/12760 [06:40<02:47, 22.46it/s]

step: 9000, loss: 1.7306792736053467


 71%|███████▏  | 9105/12760 [06:45<02:50, 21.45it/s]

step: 9100, loss: 1.7259527444839478


 72%|███████▏  | 9204/12760 [06:49<02:39, 22.34it/s]

step: 9200, loss: 1.7401759624481201


 73%|███████▎  | 9303/12760 [06:54<02:33, 22.47it/s]

step: 9300, loss: 1.7385790348052979


 74%|███████▎  | 9405/12760 [06:58<02:29, 22.50it/s]

step: 9400, loss: 1.729856252670288


 74%|███████▍  | 9504/12760 [07:03<02:24, 22.47it/s]

step: 9500, loss: 1.7420759201049805


 75%|███████▌  | 9603/12760 [07:07<02:20, 22.47it/s]

step: 9600, loss: 1.7082765102386475


 76%|███████▌  | 9705/12760 [07:11<02:16, 22.46it/s]

step: 9700, loss: 1.754081130027771


 77%|███████▋  | 9804/12760 [07:16<02:11, 22.56it/s]

step: 9800, loss: 1.7449241876602173


 78%|███████▊  | 9903/12760 [07:20<02:06, 22.52it/s]

step: 9900, loss: 1.7367761135101318


 78%|███████▊  | 10005/12760 [07:25<02:02, 22.47it/s]

step: 10000, loss: 1.7421298027038574


 79%|███████▉  | 10104/12760 [07:29<01:58, 22.49it/s]

step: 10100, loss: 1.7671703100204468


 80%|███████▉  | 10203/12760 [07:34<01:53, 22.46it/s]

step: 10200, loss: 1.6856595277786255


 81%|████████  | 10305/12760 [07:38<01:49, 22.42it/s]

step: 10300, loss: 1.776463270187378


 82%|████████▏ | 10404/12760 [07:43<01:44, 22.50it/s]

step: 10400, loss: 1.8147400617599487


 82%|████████▏ | 10503/12760 [07:47<01:40, 22.49it/s]

step: 10500, loss: 1.7169241905212402


 83%|████████▎ | 10605/12760 [07:51<01:36, 22.40it/s]

step: 10600, loss: 1.7197169065475464


 84%|████████▍ | 10704/12760 [07:56<01:31, 22.46it/s]

step: 10700, loss: 1.7478634119033813


 85%|████████▍ | 10803/12760 [08:00<01:27, 22.46it/s]

step: 10800, loss: 1.7526520490646362


 85%|████████▌ | 10905/12760 [08:05<01:22, 22.47it/s]

step: 10900, loss: 1.8683745861053467


 86%|████████▌ | 11004/12760 [08:09<01:18, 22.37it/s]

step: 11000, loss: 1.7629541158676147


 87%|████████▋ | 11103/12760 [08:14<01:14, 22.31it/s]

step: 11100, loss: 1.7545894384384155


 88%|████████▊ | 11205/12760 [08:18<01:09, 22.46it/s]

step: 11200, loss: 1.7413052320480347


 89%|████████▊ | 11304/12760 [08:23<01:04, 22.49it/s]

step: 11300, loss: 1.7497577667236328


 89%|████████▉ | 11403/12760 [08:27<01:00, 22.43it/s]

step: 11400, loss: 1.7378063201904297


 90%|█████████ | 11505/12760 [08:32<00:55, 22.48it/s]

step: 11500, loss: 1.7347111701965332


 91%|█████████ | 11604/12760 [08:36<00:51, 22.45it/s]

step: 11600, loss: 1.7495036125183105


 92%|█████████▏| 11703/12760 [08:40<00:46, 22.49it/s]

step: 11700, loss: 1.7557374238967896


 93%|█████████▎| 11805/12760 [08:45<00:42, 22.45it/s]

step: 11800, loss: 1.7314285039901733


 93%|█████████▎| 11904/12760 [08:49<00:38, 22.48it/s]

step: 11900, loss: 1.7526450157165527


 94%|█████████▍| 12003/12760 [08:54<00:33, 22.42it/s]

step: 12000, loss: 1.717010259628296


 95%|█████████▍| 12105/12760 [08:58<00:29, 22.44it/s]

step: 12100, loss: 1.7595239877700806


 96%|█████████▌| 12204/12760 [09:03<00:24, 22.42it/s]

step: 12200, loss: 1.7270032167434692


 96%|█████████▋| 12303/12760 [09:07<00:20, 22.51it/s]

step: 12300, loss: 1.795576810836792


 97%|█████████▋| 12405/12760 [09:12<00:16, 22.00it/s]

step: 12400, loss: 1.7271603345870972


 98%|█████████▊| 12504/12760 [09:16<00:11, 22.22it/s]

step: 12500, loss: 1.7427892684936523


 99%|█████████▉| 12603/12760 [09:20<00:07, 22.23it/s]

step: 12600, loss: 1.7326565980911255


100%|█████████▉| 12705/12760 [09:25<00:02, 22.21it/s]

step: 12700, loss: 1.7633440494537354


100%|██████████| 12760/12760 [09:28<00:00, 22.46it/s]


NameError: name 'reload' is not defined