In [1]:
!nvidia-smi

Sat May 11 10:51:12 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-PG5...  On   | 00000000:10:00.0 Off |                    0 |
| N/A   37C    P0    80W / 330W |   5257MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-PG5...  On   | 00000000:13:00.0 Off |                    0 |
| N/A   32C    P0    54W / 330W |      3MiB / 40960MiB |      0%      Default |
|       

In [2]:
import sys
sys.path.insert(0, '/home/bhavul/bhavul/NEKO/')

In [3]:
from __future__ import annotations
# supports dataset in huggingface datasets library for now

import torch
import time
import os

import wandb
import numpy as np
import torch
import random
import os
from datetime import datetime

import wandb
import torch

from peft import LoraConfig, TaskType, get_peft_model
from accelerate import Accelerator
from accelerate import DistributedDataParallelKwargs
from accelerate import DataLoaderConfiguration

from datasets import load_dataset, concatenate_datasets
import numpy as np
from torch import nn
from typing import TYPE_CHECKING, List,Dict
from transformers import AutoTokenizer
import torch
import copy

from __future__ import annotations
from typing import Optional, Union, TYPE_CHECKING
import torch
import torch.nn as nn

import transformers
from transformers import AutoTokenizer

# import gato
from gato.transformers import GPT2Model
from gato.training.trainer import Trainer
from gato.training.schedulers import get_linear_warmup_cosine_decay_scheduler
from gato.tasks.task import Task
from gato.utils.utils import save_model
from gato.training.arguments import TrainingArgs


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
class GatoPolicy(nn.Module):
    def __init__(
        self,
        device: Union[torch.device, str],
        embed_dim: int,
        layers: int,
        heads: int,
        dropout: float,

        activation_fn='gelu',

        mu: int = 100,
        M: int = 256,

        patch_size: int = 16,
        resid_mid_channels: int = 132,
        num_groups: int = 32,
        position_vocab_size: int = 128,
        continuous_tokens: int = 1024,
        discrete_tokens: int = 1024,

        context_len=1024,

        use_pos_encoding: bool = True,
        use_patch_pos_encoding: bool = True,

        pretrained_lm: Optional[str] = None, # Optional, name of pretrained language model to use
        flash: bool = False, # TODO verify correctness
        tokenizer_model_name: str = 'gpt2',
        pad_seq: bool = False
    ):
        super().__init__()
        self.device = device
        self.context_len = context_len
        
        # Text Tokenizer
        self.text_tokenizer = AutoTokenizer.from_pretrained(tokenizer_model_name)        
        # tokens
        self.vocab_size = self.text_tokenizer.vocab_size 
        if self.text_tokenizer.pad_token is None:
            self.text_tokenizer.pad_token = self.text_tokenizer.eos_token
        

        if pretrained_lm is not None:
            print(f'loading pretrained GPT2 weights')
            config = transformers.GPT2Config.from_pretrained(pretrained_lm)
            config.attn_pdrop = dropout # 0.1
            config.resid_pdrop = dropout
            config.flash = flash
            config.gate = False
            config.attn_pdrop = dropout # 0.1
            config.resid_pdrop = dropout
            self.transformer = GPT2Model.from_pretrained(
                pretrained_lm,
                config=config,
            )
            embed_dim = config.n_embd
            # assert self.transformer.wte.weight.shape[0] == self.text_tokens, "pretrained token/expected mimsatch" # potentially make text_tokens dynamic
        else:
            gate = False
            if activation_fn == 'geglu':
                gate = True
                activation_fn = 'gelu'
            config = transformers.GPT2Config(
                vocab_size=1,  # doesn't matter -- we don't use the vocab
                n_embd=embed_dim,
                n_head=heads,
                n_layer=layers,
                resid_pdrop=dropout,
                attn_pdrop=dropout,
                n_positions=context_len,
                n_inner=embed_dim * 4,
                activation_function=activation_fn,
            )
            config.n_ctx = context_len
            config.gate = gate
            config.flash = flash
            self.transformer = self.transformer = GPT2Model(config)
        
        # embedding tokens
        self.embed_token = nn.Embedding(self.vocab_size, embed_dim)
        if pretrained_lm is not None:
            self.embed_token.weight.data[:] = self.transformer.wte.weight.data
        
        
        # head
        self.predict_token = nn.Linear(embed_dim, self.vocab_size, bias=False)
        self.separator_token = nn.Parameter(torch.zeros(embed_dim))

    @property
    def module(self):
        return self

    def forward(self, inputs: Optional[list]=None, compute_loss=False, **kwargs):
        # tokenize inputs
        if inputs is not None:
            token_embeddings, tokens, token_masks, target_tokens, target_masks = self.tokenize_input_dicts(inputs)
        else:
            token_embeddings = kwargs['token_embeddings']
            tokens = kwargs['tokens']
            token_target_masks = kwargs['token_target_masks']
            token_masks = kwargs['token_masks']

        assert token_embeddings is not None, "token_embeddings is None"
        assert token_masks is not None, "token_masks is None"

        final_representations = self.transformer(inputs_embeds=token_embeddings, attention_mask=token_masks)['last_hidden_state']
        logits = self.predict_token(final_representations)
        # assert 'target' in kwargs, "target is not there in kwargs"

        # print(f"Type of target_tokens: {type(target_tokens)}")
        # print(f"Shape of target_tokens: {target_tokens.shape if isinstance(target_tokens, torch.Tensor) else 'N/A'}")
        # print(f"Type of pad_token_id: {type(self.text_tokenizer.pad_token_id)}")
        if compute_loss:
            # Ensuring target_tokens is a tensor
            if not isinstance(target_tokens, torch.Tensor):
                raise TypeError("target_tokens must be a torch.Tensor")
            
            # Correctly computing the loss mask
            loss_masks = (target_tokens != self.text_tokenizer.pad_token_id)
            if isinstance(loss_masks, torch.Tensor):
                loss_masks = loss_masks.float()  # Convert boolean tensor to float
            else:
                raise TypeError("Loss mask calculation did not return a tensor.")
            # loss_masks = (target_tokens != self.text_tokenizer.pad_token_id).float()
            loss = torch.nn.functional.cross_entropy(logits.view(-1, self.vocab_size), target_tokens.view(-1), reduction='none')
            loss = (loss * loss_masks.view(-1)).sum() / loss_masks.sum()
        else:
            loss = None
    
        return logits, loss


    def tokenize_input_dicts(self, inputs: list):
        # if not inputs:
        #     return None, None, None, None
    
        batch_len = len(inputs)
        max_input_tokens = max(len(batch['text']) for batch in inputs)
        max_target_tokens = max(len(batch['target']) for batch in inputs) if 'target' in inputs[0] else 0
        
        # Allocate tensors for input tokens
        token_embeddings = torch.zeros((batch_len, max_input_tokens, self.embed_token.embedding_dim), device=self.device)
        tokens = torch.zeros((batch_len, max_input_tokens), dtype=torch.long, device=self.device)
        token_masks = torch.zeros((batch_len, max_input_tokens), device=self.device)
        
        # Allocate tensors for target tokens if they exist
        target_tokens = torch.zeros((batch_len, max_target_tokens), dtype=torch.long, device=self.device)
        target_masks = torch.zeros((batch_len, max_target_tokens), device=self.device)
    
        for i, batch in enumerate(inputs):
            # Process input tokens
            input_tokens = batch['text'].to(device=self.device) if isinstance(batch['text'], torch.Tensor) else torch.tensor(batch['text'], dtype=torch.long, device=self.device)
            n_input_timesteps = len(input_tokens)
            
            tokens[i, :n_input_timesteps] = input_tokens
            token_embeddings[i, :n_input_timesteps] = self.embed_token(input_tokens)
            token_masks[i, :n_input_timesteps] = 1
            
            # Process target tokens if they exist
            if 'target' in batch:
                target_data = batch['target'].to(device=self.device) if isinstance(batch['target'], torch.Tensor) else torch.tensor(batch['target'], dtype=torch.long, device=self.device)
                n_target_timesteps = len(target_data)
                target_tokens[i, :n_target_timesteps] = target_data
                target_masks[i, :n_target_timesteps] = 1
    
        return token_embeddings, tokens, token_masks, target_tokens, target_masks

    def predict_text(self, input_text, max_length=20, deterministic=True, context_length=1024):
        tokenized_outputs = self.text_tokenizer(input_text, truncation=True, padding="longest", max_length=args.sequence_length, return_tensors='pt')

        input_tokens = tokenized_outputs['input_ids']
        predicted_tokens = input_tokens.clone()
    
        for _ in range(max_length):
            token_embeddings = self.embed_token(predicted_tokens.to(device))
            token_masks = torch.ones((predicted_tokens.to(device).shape[0], 1), device=device)

            logits, _ = self.forward(token_embeddings=token_embeddings, tokens=predicted_tokens.to(device), token_masks=token_masks, token_target_masks=None)
            logits = logits[:, -1, :]
                
    
            if deterministic:
                next_token = torch.argmax(logits, dim=-1).unsqueeze(-1)  # Ensure it keeps batch dimension
            else:
                probs = torch.nn.functional.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, 1)  # Sampling a token
    
            predicted_tokens = torch.cat([predicted_tokens.to(device), next_token.to(device)], dim=1)
    
        # all_logits = torch.cat(logits_list, dim=1)
        return predicted_tokens[:, input_tokens.size(1):]

    
    def predict_text_single_single(self, batch_dict, max_length=20, deterministic=True, top_p=0.9):
        input_tokens = torch.tensor(batch_dict['text'], dtype=torch.long, device=self.device).unsqueeze(0)
        
        predicted_tokens = []
    
        for _ in range(max_length):
            token_embeddings = self.embed_token(input_tokens)
            token_masks = torch.ones_like(input_tokens)

            logits, _ = self.forward(token_embeddings=token_embeddings, tokens=input_tokens, token_target_masks=None, token_masks=token_masks)
            logits = logits[:, -1, :]  # focus on the last time-step logits
    
            if deterministic:
                token = torch.argmax(logits, dim=-1)
            else:
                probs = torch.nn.functional.softmax(logits, dim=-1)
                token = torch.multinomial(probs, 1)  # Sampling a token
    
            if token.numel() == 1:  # Checking if token is a single element
                predicted_tokens.append(token.item())
            else:
                print(f"Expected a single element, got {token.numel()} elements.")
    
            input_tokens = torch.cat([input_tokens, token], dim=1)  # Append the predicted token

            if token == self.text_tokenizer.eos_token_id:
                break
    
        return logits, predicted_tokens

In [5]:
class TextTask(Task): 
    def __init__(self, dataset_names:List[str], dataset_paths:List[str], context_length:int, tokenizer_model:str):
        super().__init__()
        self.context_length = context_length
        self.text_tokenizer = AutoTokenizer.from_pretrained(tokenizer_model)
        if self.text_tokenizer.pad_token is None:
            self.text_tokenizer.pad_token = self.text_tokenizer.eos_token
        text_datasets_list = []
        assert len(dataset_names) == len(dataset_paths), "The dataset names and paths parameters should have corresponding values and hence equal lengths"
        for i, text_dataset in enumerate(dataset_names):
            text_datasets_list.append(load_dataset(path=dataset_paths[i], name=text_dataset))
        if len(text_datasets_list) == 1:
            self.text_dataset = text_datasets_list[0]
        else:            
            # https://huggingface.co/docs/datasets/v2.14.4/en/process#concatenate
            # must have the same feature columns
            self.text_dataset = concatenate_datasets(text_datasets_list)

    def sample_batch(self, batch_size, is_test=False) -> List[Dict]:
        split = 'train' if not is_test else 'test'
        dataset_split = self.text_dataset[split]

        if len(dataset_split) < batch_size:
            print(f"Warning: Requested batch size {batch_size} is larger than the dataset size {len(dataset_split)}.")
            batch_size = len(dataset_split)  # Adjust batch size to available data size

        if batch_size == 0:
            return []  # Early exit if no data is available

        
        sampled_indices = torch.randperm(len(dataset_split))[:batch_size]
        samples = dataset_split.select(sampled_indices)
        tokenized_outputs = self.text_tokenizer(samples['text'], truncation=True, padding="longest", max_length=self.context_length, return_tensors='pt')
    
        batch_dicts = []
        for input_ids in tokenized_outputs["input_ids"]:
            if input_ids.numel() > 0:  # Check if non-empty
                # Split into input and target tokens
                input_tokens = input_ids[:-1]
                target_tokens = input_ids[1:]
                batch_dicts.append({
                    'text': input_tokens,
                    'target': target_tokens,
                })
    
        return batch_dicts

    def evaluate(self, model: GatoPolicy, num_examples_to_test=50, deterministic=False, is_test=True):
        split = 'train' if not is_test else 'test'
        dataset_split = self.text_dataset[split]
        
        num_examples_to_test = min(num_examples_to_test, len(dataset_split))
        
        if num_examples_to_test == 0:
            return {'loss': float('nan'), 'perplexity': float('nan')}
    
        batch_dicts = self.sample_batch(num_examples_to_test, is_test)

        # input_tokens = torch.stack([b['text'] for b in batch_dicts]).to(model.device)
        # target_tokens = torch.stack([b['target'] for b in batch_dicts]).to(model.device)
        
        # input_tokens = torch.stack([b['text'] for b in batch_dicts]).to(model.device)
        # target_tokens = torch.stack([b['target'] for b in batch_dicts]).to(model.device)

        # Forward pass    
        logits, loss = model(batch_dicts, compute_loss=True)
        
        # total_tokens = input_tokens.size(0) * input_tokens.size(1)
        # print(f'total tokens:{total_tokens}')
        avg_loss = loss.item() 
        perplexity = torch.exp(torch.tensor(avg_loss)).item()
                        
        return {'loss': avg_loss, 'perplexity': perplexity}

    def evaluate_single_single(self, model: GatoPolicy, num_examples_to_test=50, deterministic=False, log_examples_to_output=False, is_test=True):
        split = 'train' if not is_test else 'test'
        dataset_split = self.text_dataset[split]
        num_examples_to_test = min(num_examples_to_test, len(dataset_split))
    
        if num_examples_to_test == 0:
            return {'loss': float('nan'), 'perplexity': float('nan')}
    
        batch_dicts = self.sample_batch(num_examples_to_test, is_test)
        total_loss, total_tokens = 0.0, 0
    
        for batch_dict in batch_dicts:
            input_tokens = batch_dict['text'].to(device=model.device)
            target_tokens = batch_dict['target'].to(device=model.device)
    
            total_loss_per_sequence = 0.0
            pred_tokens = []
    
            for idx in range(input_tokens.size(0)):
                pred_logits, single_pred_tokens = model.predict_text({'text': input_tokens[idx].unsqueeze(0)}, max_length=1, deterministic=deterministic)
                loss = torch.nn.functional.cross_entropy(pred_logits, target_tokens[idx].unsqueeze(0))
                total_loss_per_sequence += loss.item()
                pred_tokens.extend(single_pred_tokens)
            
            total_loss += total_loss_per_sequence / input_tokens.size(0)
            total_tokens += input_tokens.size(0)
    
            if log_examples_to_output:
                decoded_input = self.text_tokenizer.decode(input_tokens.squeeze(), skip_special_tokens=True)
                decoded_target = self.text_tokenizer.decode(target_tokens.squeeze(), skip_special_tokens=True)
                decoded_prediction = self.text_tokenizer.decode(torch.tensor(pred_tokens), skip_special_tokens=True)            
                print(f'=>Input: {decoded_input} \n =>Target: {decoded_target} \n =>Prediction: {decoded_prediction}')
    
        avg_loss = total_loss / total_tokens
        perplexity = torch.exp(torch.tensor(avg_loss)).item()
    
        return {'loss': avg_loss, 'perplexity': perplexity}

## trainer stuff

In [15]:
args = TrainingArgs(
    training_steps=10000,
    log_eval_freq=10,
    warmup_steps=100,
    batch_size=8,
    sequence_length=1024,
    eval_episodes=5,
    text_prop=1,
    eval_text_log_examples=True,
    # pretrained_lm='gpt2',
    text_datasets=['wikitext-2-v1'],
    text_datasets_paths=['wikitext'],
    use_wandb=True,
    device='cuda',
    eval_mode='stochastic',
    eval_text_num_examples=100,
    # disable_cosine_decay=True
)

In [16]:
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
if args.use_wandb:
    log_with = 'wandb'
else:
    log_with = None
dl_config = DataLoaderConfiguration(split_batches=True)
accelerator = Accelerator(
    cpu=args.cpu,
    dataloader_config=dl_config, 
    # mixed_precision=args.mixed_precision,
    # gradient_accumulation_steps=args.gradient_accumulation_steps,
    kwargs_handlers=[ddp_kwargs],
    log_with=log_with
)
args.device = accelerator.device.type
exp_date = datetime.now().strftime('%y-%m-%d_%H-%M-%S')
exp_name = f'neko-gato_{exp_date}'

model = GatoPolicy(
        device=args.device,
        embed_dim=args.embed_dim,
        layers=args.layers,
        heads=args.heads,
        dropout=args.dropout,
        mu=args.mu,
        M=args.M,
        patch_size=args.patch_size,
        resid_mid_channels=args.resid_mid_channels,
        continuous_tokens=args.continuous_tokens,
        discrete_tokens=args.discrete_tokens,
        context_len=args.sequence_length,
        use_patch_pos_encoding=not args.disable_patch_pos_encoding,
        use_pos_encoding=not args.disable_inner_pos_encoding,
        activation_fn=args.activation_fn,
        pretrained_lm=args.pretrained_lm,
        flash=args.flash,
        tokenizer_model_name=args.tokenizer_model_name,
        pad_seq=args.pad_seq,
    )
model = accelerator.prepare(model)
model.device = args.device

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=args.learning_rate,
    betas=(args.beta_1, args.beta_2),
    eps=args.adam_eps,
    weight_decay=args.weight_decay,
)

scheduler = get_linear_warmup_cosine_decay_scheduler(optimizer, args.warmup_steps, args.training_steps, base_lr=args.learning_rate, init_lr=args.init_lr, min_lr=args.learning_rate / args.min_factor, cosine_decay=not args.disable_cosine_decay)
optimizer, scheduler = accelerator.prepare(optimizer, scheduler)

if args.use_wandb:
    accelerator.init_trackers(args.wandb_project, init_kwargs={'wandb': {'name': exp_name, 'config': args}})
else:
    accelerator.init_trackers('')

tasks = [TextTask(args.text_datasets, args.text_datasets_paths, args.sequence_length, tokenizer_model=args.tokenizer_model_name)]
args = args
print_logs = True # args.print_logs
device = torch.device(args.device)

min_lr = args.learning_rate / args.min_factor
deterministic = args.eval_mode == 'deterministic'

exp_name = exp_name
exp_dir = os.path.join(args.save_dir, exp_name)

steps = 0
start_time = None

# Create save dir if does not exist
if args.save_model and not os.path.exists(args.save_dir):
    os.makedirs(args.save_dir)

Using pad_token, but it is not set yet.


Using pad_token, but it is not set yet.


In [17]:
def sample_text_batch(batch_size):
    batch_dicts = []
    text_tasks = [t for t in tasks if isinstance(t, TextTask)]
    for i,task in enumerate (text_tasks):
        return task.sample_batch(batch_size)

In [18]:
def train_step():
    logs = {}
    logs['training/learning_rate'] = scheduler.get_lr()[0] # store LR at current step
    # Build training batch
    start_time = time.time()

    # Calculate batch size for each task, the following need to be revised to including more new tasks
    text_batch_size = int(args.text_prop * args.batch_size)
    remainder = args.batch_size - text_batch_size

    if remainder > 0: 
        text_batch_size += remainder

    assert args.batch_size == text_batch_size, "Total batch size is not eqaual to the sum of each task's batch size" 

    text_batch_dicts = []

    # Sample text and control batches
    if text_batch_size > 0:
        text_batch_dicts = sample_text_batch(text_batch_size)

    if not text_batch_dicts:  # Handle empty batch case
        # print("Received an empty batch. Skipping this step.")
        return None  # You could return None or handle this case based on your training logic

    # print(f'text_batch_size:{text_batch_size}')

    logs['time/sample_batch'] = time.time() - start_time
    with accelerator.accumulate(model):
        # Compute loss and update model
        logits, loss = model.forward(inputs = text_batch_dicts, compute_loss=True)
        accelerator.backward(loss)

        if not args.disable_grad_clip and accelerator.sync_gradients:
            accelerator.clip_grad_norm_(model.parameters(), args.grad_norm_clip)

        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    return loss.detach().cpu().item(), logs

In [41]:
def train_iteration(num_steps, iter):
    logs = {}

    train_start = time.time()

    train_losses = []
    steps = 0
    model.train()
    for i in range(num_steps):
        steps += 1
        result = train_step()
        if result is None:
            # steps -= 1
            # print("Skipped a training step due to empty batch.")
            continue
        train_loss, step_logs = result
        train_losses.append(train_loss)

    # add logs from last train_step as well
    for log in step_logs:
        logs[log] = step_logs[log]

    logs['time/training'] = time.time() - train_start

    eval_start = time.time()
    model.eval()

    # loop over eval for each env
    with torch.no_grad():
        for task in tasks:
            eval_logs = {}
            if isinstance(task, TextTask):
                eval_logs = task.evaluate(model, num_examples_to_test=args.eval_text_num_examples, deterministic=deterministic)
                for k, v in eval_logs.items():
                    logs[f'evaluation/text/{k}'] = v
                pass

                if iter % 100 == 0 and args.eval_text_log_examples:
                    dataset_split = task.text_dataset['test']

                    sampled_indices = torch.randperm(len(dataset_split))[:5]
                    samples = dataset_split.select(sampled_indices)
                    
                    for sample in samples:
                        actual_text = sample['text']
                        # roughly speaking...splitting by spaces
                        words_list = actual_text.split()
                        if len(words_list) > 1:
                            split_index = random.randint(1, len(words_list)-1)
                            input_text, target_text = ' '.join(words_list[:split_index]), ' '.join(words_list[split_index:])  
                            pred_tokens = model.predict_text(input_text='Hello how are', max_length=len(words_list[split_index:]), deterministic=deterministic)
                            decoded_target = task.text_tokenizer.decode(pred_tokens.squeeze(), skip_special_tokens=True)
                            print(f'Input: {input_text} | Output : {target_text} | Prediction: {decoded_target}')

    logs['time/total'] = time.time() - start_time
    logs['time/evaluation'] = time.time() - eval_start
    logs['training/train_loss_mean'] = np.mean(train_losses)
    logs['training/train_loss_std'] = np.std(train_losses)

    if accelerator.is_main_process:
        if print_logs:
            print('=' * 80)
            print(f'Iteration {iter}')
            for k, v in logs.items():
                print(f'{k}: {v}')
            print('=' * 80)

    ## Save model
    if args.save_model and args.save_mode == 'checkpoint':
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            unwrapped_model = accelerator.unwrap_model(model)
            save_model(unwrapped_model, exp_dir, f'checkpoint_{steps}', args)
                

    return logs

In [38]:
# torch.cuda.empty_cache()

In [39]:
with torch.no_grad():
    for task in tasks:
        eval_logs = {}
        if isinstance(task, TextTask):
            eval_logs = task.evaluate(model, num_examples_to_test=args.eval_text_num_examples, deterministic=deterministic)
            print(eval_logs)


{'loss': 8.033501625061035, 'perplexity': 3082.5166015625}


In [40]:
start_time = time.time()
iters = args.training_steps // args.log_eval_freq
print(f'iters:{iters}')
for i in range(iters):
    logs = train_iteration(args.log_eval_freq, i)
    accelerator.log(logs)

## Save model at end of training only if not saving checkpoints
if args.save_model and args.save_mode == 'last':
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unwrapped_model = accelerator.unwrap_model(model)
        save_model(unwrapped_model, exp_dir, f'checkpoint_{steps}', args)
        torch.cuda.empty_cache()    

accelerator.end_training()

iters:1000
Input: <unk>openedtheyearfollowingtherevoltbysharingthe<unk>with<unk>.Again,thehonoursuggested<unk>hadplayedapartinuncoveringtheconspiracy,perhapsinafashionsimilartowhathedidduringthe<unk>conspiracyunder | Output : <unk>.Alternatively,<unk>mayhaveselected<unk>ashiscolleaguetoemphasisethestabilityandstatus@-@<unk>oftheregime.Therevolthadbeensuppressed,andtheEmpirecouldreturntoorder. | Prediction:  321 the entirely externalToEVAOnlyOwner..." Jak seats 255. ABE Gordon card DEN the clone defiant fusionabbyocr
Input: = | Output : ==<unk>=== | Prediction: ,review forolding fram militants debts GM unbelievable diplomat taps Exercise caveats Sitting prostitution ultrasound manner initiallymodulesscore
Input: TheFrenchnavybuiltthefirstironcladtotrytogainastrategicadvantageovertheBritish,butwereconsistentlyout@-@builtbytheBritish.Despitetakingtheleadwithanumberofinnovationslikebreech@-@loadingweaponsandsteelconstruction,theFrenchnavycouldnevermatchthesizeoftheRoyalNavy.Inthe1870s,thec

0,1
evaluation/text/loss,█▆▆▅▅▅▄▄▃▄▃▄▃▃▂▃▄▄▃▄▃▂▃▂▃▄▂▃▃▂▂▁▂▂▁▁▁▁▁▁
evaluation/text/perplexity,█▅▆▄▄▄▃▃▂▃▂▃▂▂▂▂▃▃▂▃▂▂▂▂▂▃▂▂▂▂▁▁▂▂▁▁▁▁▁▁
time/evaluation,▅▂▅▆▅▃█▄▆▂▅▄▂▃▂▅▄▁▄▃▃▅▆▂▆▃▁▄▂▃▄▃▄▂▂▁▁▃▃▂
time/sample_batch,▄▆▅▁▅▂▆▇▄▆▆▆▄▁▅▅▇▅▄▅▄▆▄▂█▅▅▃▃▆▄▅▅▅▄▃▆▃▃▆
time/total,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
time/training,▆▅▄▄▆▇▆██▂▅▃▄▅▃▃▅▅▅▃▃▂▄▅▅▇▄▁▂▂▇▄▄▅▃▆▃▆▅▅
training/learning_rate,███████▇▇▇▇▇▆▆▆▆▆▅▅▅▄▄▄▄▃▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁
training/train_loss_mean,█▆▅▅▅▅▄▄▄▃▃▃▄▂▃▄▂▄▃▂▁▃▂▂▂▂▂▁▂▂▃▂▂▂▂▁▁▁▂▂
training/train_loss_std,▄▇▁▃▂▃▂▃▂▃▁▁▂█▃▂▃▃▂▄█▃▃▂▂▂▂▃▂▃▁▂▂▆▃▂▆▃▂▂

0,1
evaluation/text/loss,4.94261
evaluation/text/perplexity,140.1357
time/evaluation,0.82331
time/sample_batch,0.00898
time/total,1697.54563
time/training,1.22272
training/learning_rate,1e-05
training/train_loss_mean,4.88799
training/train_loss_std,0.31569


## Testing of the trained model.

In [56]:
def test_predict_text_on_random_examples(task, num_of_examples_to_test=10, deterministic=False):
    
    model.eval()
    with torch.no_grad():
        dataset_split = task.text_dataset['test']
        
        sampled_indices = torch.randperm(len(dataset_split))[:num_of_examples_to_test]
        samples = dataset_split.select(sampled_indices)
        
        for sample in samples:
            actual_text = sample['text']
            # roughly speaking...splitting by spaces
            words_list = actual_text.split()
            if len(words_list) > 1:
                split_index = random.randint(1, len(words_list)-1)
                input_text, target_text = ' '.join(words_list[:split_index]), ' '.join(words_list[split_index:])  
                pred_tokens = model.predict_text(input_text=input_text, max_length=len(words_list[split_index:]), deterministic=deterministic)
                decoded_target = task.text_tokenizer.decode(pred_tokens.squeeze(), skip_special_tokens=True)
                print(f'Input: {input_text} \nOutput : {target_text} \nPrediction: {decoded_target}\n\n')

In [57]:
test_predict_text_on_random_examples(tasks[0], deterministic=True)

Input: Below this frequency 
Output : , the image <unk> is real , 
Prediction:  of the <unk>, the


Input: The first text to suggest that <unk> ordered the execution of an <unk> is a letter by Clement to the <unk> traditional dated to around 96 <unk> The <unk> Ascension of Isaiah , a Christian writing from the 2nd century says , " the <unk> of his mother , who 
Output : himself ( even ) this king , will <unk> the plant which the Twelve Apostles of the Beloved have planted . Of the Twelve one will be delivered into his hands " was interpreted to mean <unk> . 
Prediction:  is a <unk>, and the <unk> of the <unk> of the <unk> of the <unk> of the <unk> of the <unk> of the


Input: = = = <unk> 
Output : = = = 
Prediction:  
,


Input: = = Reactions of <unk> = 
Output : = 
Prediction:  


Input: The loss snapped Butler 's 25 @-@ game winning streak , the longest in school history . Butler became the smallest school to play for a National Championship since Jacksonville in 1970 . Stevens becam

In [79]:
test_predict_text_on_random_examples(tasks[0], deterministic=False)

Input: = = = <unk> <unk> Records and New Album Evolution ( 
Output : 2006 – 2015 ) = = = 
Prediction:  August challenge with an American Soldier visas


Input: After marrying Robin <unk> , brother of artist Gloria <unk> , <unk> <unk> moved to the region of <unk> , north @-@ east of Alice Springs , which is where she was living when she began painting around 1990 . They had seven children , one of whom , <unk> <unk> , went on to become an artist like his mother . By 2008 , <unk> 
Output : <unk> 's husband had died , and <unk> was dividing her time between Alice Springs and <unk> Range , to its north @-@ east . 
Prediction: . A touchdowns — " The book gave her points out for the primaryrate <unk> ", respectively. Art, she


Input: The continuous shadows in the south polar craters cause the floors of these formations to maintain a temperature that never exceeds about 100 K. For <unk> , the average temperature was determined to be about 90 K , reaching 88 K at the crater floor . 
Output : 

### Nucleus sampling

In [93]:
def test_new_predict_text(input_text, max_length=20, deterministic=True, temperature=1.0, top_p=1.0, context_length=1024):
    tokenized_outputs = model.text_tokenizer(input_text, truncation=True, padding="longest", max_length=context_length, return_tensors='pt')
    input_tokens = tokenized_outputs['input_ids']
    predicted_tokens = input_tokens.clone()

    for _ in range(max_length):
        token_embeddings = model.embed_token(predicted_tokens.to(device))
        token_masks = torch.ones((predicted_tokens.shape[0], 1), device=device)

        logits, _ = model.forward(token_embeddings=token_embeddings, tokens=predicted_tokens, token_masks=token_masks, token_target_masks=None)
        logits = logits[:, -1, :] / temperature  # Apply temperature scaling

        if deterministic:
            next_token = torch.argmax(logits, dim=-1).unsqueeze(-1)
        else:
            # Apply nucleus (top-p) filtering
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
            sorted_indices_to_remove[:, 0] = 0

            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            logits[indices_to_remove] = float('-inf')

            probs = torch.nn.functional.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, 1)

        predicted_tokens = torch.cat([predicted_tokens.to(device), next_token.to(device)], dim=1)

    return predicted_tokens[:, input_tokens.size(1):]

In [155]:
def test_predict_text_on_random_examples_with_nucleus(task, num_of_examples_to_test=10, deterministic=False, temperature=1.0, top_p=1.0, split='test'):    
    model.eval()
    with torch.no_grad():
        dataset_split = task.text_dataset['test']
        
        sampled_indices = torch.randperm(len(dataset_split))[:num_of_examples_to_test]
        samples = dataset_split.select(sampled_indices)
        
        for sample in samples:
            actual_text = sample['text']
            # roughly speaking...splitting by spaces
            words_list = actual_text.split()
            if len(words_list) > 1:
                split_index = random.randint(1, len(words_list)-1)
                input_text, target_text = ' '.join(words_list[:split_index]), ' '.join(words_list[split_index:])  
                pred_tokens = test_new_predict_text(input_text='Hello how are', max_length=len(words_list[split_index:]), deterministic=deterministic, temperature=temperature, top_p=top_p)
                decoded_target = task.text_tokenizer.decode(pred_tokens.squeeze(), skip_special_tokens=True)
                print(f'Input: {input_text} \nOutput : {target_text} \nPrediction: {decoded_target}\n\n')

In [100]:
test_predict_text_on_random_examples_with_nucleus(tasks[0], deterministic=False)

Input: = = = = Public transportation = = = 
Output : = 
Prediction:  shot


Input: The discovery of a colossal head at <unk> <unk> in the nineteenth century spurred the first archaeological investigations 
Output : of <unk> culture by Matthew <unk> in 1938 . Seventeen confirmed examples are known from four sites within the <unk> <unk> on the Gulf Coast of Mexico . Most colossal heads were sculpted from spherical boulders but two from San Lorenzo <unk> were re @-@ carved from massive stone <unk> . An additional monument , at <unk> <unk> in Guatemala , is a throne that may have been carved from a colossal head . This is the only known example from outside the <unk> <unk> . 
Prediction:  Do Army's allowed to result it building withReason with a widely first technology record came – 620 murder. The later was each of 17 musicalph favorable reviews from their country.93 in get only three prince after 200 @-@ containing a memorableDánhised by audiences. It change the confusion, statement L Os

In [104]:
test_predict_text_on_random_examples_with_nucleus(tasks[0], deterministic=False, temperature=0.7, top_p=0.8)

Input: Operation <unk> , the Allied invasion of French North Africa in November 1942 , was coordinated from the " Rock " . General Dwight D. Eisenhower , who was given command of the operation , set up his headquarters in Gibraltar during the planning phases of the operation . Following the successful completion of the North African campaign and the surrender of Italy in 1943 , Gibraltar 's role shifted from a forward operating base to a rear @-@ area supply position . The harbour continued to operate dry docks and supply depots for the convoy routes through the 
Output : Mediterranean until V @-@ E Day in 1945 . 
Prediction:  the first to a self @-@ <


Input: = = Recent 
Output : events = = 
Prediction:  the world,


Input: Family 
Output : <unk> 
Prediction:  the


Input: Manila has six representative districts for the lower house of the Philippine Congress . Furthermore , the city is composed of 16 districts , namely : <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , Port Area , <un

In [154]:
# lower top-p : less diversity in choice of next words, higher top-p large set of next possible words
# low temp : more deterministic, less diverse. 
test_predict_text_on_random_examples_with_nucleus(tasks[0], deterministic=False, temperature=0.84, top_p=0.89)

Input: <unk> began his reign in 54 by promising the Senate more autonomy . In this first year , he forbade others to refer to 
Output : him with regard to <unk> , for which he was praised by the Senate . <unk> was known for spending his time visiting <unk> and <unk> during this period . 
Prediction:  his name to his character in his Christian take his name, in a market, he was the whose English style of the film that was ", a


Input: <unk> Place Manila is the largest shopping mall in the city . The mall was the second and by @-@ far , the largest Robinson Mall ever built by John <unk> . SM <unk> maintains presence in the city . One of their shopping mall is the SM City Manila , the first SM <unk> in the city featuring major SM brands like The SM Store , SM <unk> , SM <unk> and SM <unk> . It is located right beside the Manila City Hall . SM City San <unk> is the second SM <unk> in Manila . It is located in Santa Cruz . SM City San <unk> was constructed on the site of the former San <un

In [195]:
# lower top-p : less diversity in choice of next words, higher top-p large set of next possible words
# low temp : more deterministic, less diverse. 
test_predict_text_on_random_examples_with_nucleus(tasks[0], deterministic=False, temperature=0.7, top_p=0.5, split='test')

Input: Canadian Agency is headed by a secretary @-@ general and responsible for Canada , the entire Americas ( including the 
Output : Caribbean ) 
Prediction:  the top


Input: In mid @-@ 1941 , the Royal Armoured Corps in Britain created three tank squadrons for special overseas operations , known as ' A ' , ' B ' and ' C ' Special Service Squadrons . Both ' A ' and ' B ' Squadrons were equipped with Valentine Infantry tanks and Mark <unk> light tanks , but ' C ' Squadron was equipped with twelve <unk> transferred from the 2nd Armoured Brigade , 1st Armoured Division . On 31 July 1941 , ' C ' Squadron was officially activated and immediately received orders to prepare for overseas service alongside ' A ' and ' B ' Squadrons in an unspecified tropical climate . All three squadrons were transported to <unk> in Scotland for intensive training that focused 
Output : on embarkation and <unk> from ships and landing craft to prepare them for action in potential amphibious operations . In ea

In [150]:
def find_good_ones(task, num_of_examples_to_test=10, deterministic=False, temperature=1.0, top_p=1.0):    
    model.eval()
    with torch.no_grad():
        dataset_split = task.text_dataset['test']
        
        sampled_indices = torch.randperm(len(dataset_split))[:num_of_examples_to_test]
        samples = dataset_split.select(sampled_indices)

        for temperature, top_p in [(0.5,0.5),(0.5,0.7),(0.5,0.9),(0.75,0.75),(0.75,0.9),(0.8,0.7),(0.8,0.9),(0.9,0.5),(0.9,0.75),(0.9,0.9),(1.0,0.5),(1.0,0.75),(1.0,0.9),(1.1,0.9)]:
            print('--'*30)
            print(f'Temperature :  {temperature} || Top_p : {top_p}')
            for sample in samples:
                actual_text = sample['text']
                # roughly speaking...splitting by spaces
                words_list = actual_text.split()
                if len(words_list) > 1:
                    split_index = random.randint(1, len(words_list)-1)
                    input_text, target_text = ' '.join(words_list[:split_index]), ' '.join(words_list[split_index:])  
                    pred_tokens = test_new_predict_text(input_text='Hello how are', max_length=len(words_list[split_index:]), deterministic=deterministic, temperature=temperature, top_p=top_p)
                    decoded_target = task.text_tokenizer.decode(pred_tokens.squeeze(), skip_special_tokens=True)
                    print(f'[Input]: {input_text} \n[Output]: {target_text} \n[Prediction]: {decoded_target}\n\n\n')
        print('='*80)

In [151]:
find_good_ones(tasks[0], deterministic=False)

------------------------------------------------------------
Temperature :  0.5 || Top_p : 0.5
[Input]: A number of design faults of the <unk> were revealed through its operational use . Its size limited the possible crew to three , a driver in the hull and a gunner and commander in the turret , resulting in too few crew members to operate the <unk> effectively . The gunner or commander , in addition to his own duties , had to act as <unk> for the 2 pounder , which caused delays in combat . A report on the tank written in January 1941 stated that as the commander had to both fight and control the tank , controlling a troop of <unk> during combat would be 
[Output]: almost impossible . 
[Prediction]:  a <unk



[Input]: = = Habitat and ecology = 
[Output]: = 
[Prediction]:  a



[Input]: On the regimental left along the main <unk> @-@ <unk> @-@ <unk> road North Korean soldiers completely overran C Company by 0300 September 1 . Only seven men of C Company could be accounted for , and thr

## Things pending

*Integration*
- Fixing text inside the current way codebase is written

*Benchmarking*:
- Benchmark on PILE
- Benchmark on treepenn
- Add a way to use different text dataset for eval, and diff for training (perplexity on wikitext after training on penn treebank)

[[2nd half]]
- Varying batch_size, params, dropout, etc. -- see what's the lowest we can do?

*Deployment*:
- A way to easily load a trained model
- Save checkpoints with lowest perplexity
- Deploy via gradio
- Allow different kinds of sampling
- Be able to predict

# Rough - ignore

In [73]:
with torch.no_grad():
    model.eval()
    predicted_tokens = test_predict_text('hello how', max_length=20, deterministic=False)

In [74]:
model.text_tokenizer.decode(predicted_tokens.squeeze(), skip_special_tokens=True)

'cription. RBI, 28 Wat Lists in remain, area notedoca. species the Socrates, Kurd 13'

In [29]:
with torch.no_grad():
    for task in tasks:
        eval_logs = {}
        if isinstance(task, TextTask):
            eval_logs = task.evaluate(model, num_examples_to_test=args.eval_text_num_examples, deterministic=deterministic, log_examples_to_output=args.eval_text_log_examples)
            print(eval_logs)


{'loss': 6.181807041168213, 'perplexity': 483.86553955078125}


In [33]:
test_tokenized_outputs = model.text_tokenizer('Hello how are', truncation=True, padding="longest", max_length=args.sequence_length, return_tensors='pt')

In [35]:
test_tokenized_outputs

{'input_ids': tensor([[15496,   703,   389]]), 'attention_mask': tensor([[1, 1, 1]])}

In [38]:
test_token_embeddings = model.embed_token(test_tokenized_outputs['input_ids'].to(device))

In [39]:
test_token_embeddings.shape

torch.Size([1, 3, 768])

In [44]:
test_token_masks = torch.ones((test_tokenized_outputs['input_ids'].to(device).shape[0], 1), device=device)

In [46]:
test_token_masks.shape

torch.Size([1, 1])

In [49]:
logits, _ = model.forward(token_embeddings=test_token_embeddings, tokens=test_tokenized_outputs['input_ids'].to(device), token_masks=test_token_masks, token_target_masks=None)

In [50]:
logits = logits[:, -1, :]

In [51]:
det_next_token = torch.argmax(logits, dim=-1).unsqueeze(-1)

In [52]:
probs = torch.nn.functional.softmax(logits, dim=-1)
nondet_next_token = torch.multinomial(probs, 1)

In [54]:
decoded_det = model.text_tokenizer.decode(det_next_token.squeeze(), skip_special_tokens=True)

In [55]:
decoded_nondet = model.text_tokenizer.decode(nondet_next_token.squeeze(), skip_special_tokens=True)

In [56]:
decoded_det

','

In [57]:
decoded_nondet

' Museum'