In [1]:
import torch

In [3]:
print('hello')

hello


In [6]:
!pwd

/home/bhavul/bhavul/NEKO/dev_notebooks


In [7]:
!export PYTHONPATH="${PYTHONPATH}:/home/bhavul/bhavul/NEKO/"

In [9]:
import sys
sys.path.insert(0, '/home/bhavul/bhavul/NEKO/')

# Imports

In [10]:
import time
import os

import wandb
import numpy as np
import torch
from gato.utils.utils import save_model

In [11]:
import random
import os
from datetime import datetime

import wandb
import torch

from peft import LoraConfig, TaskType, get_peft_model
from accelerate import Accelerator
from accelerate import DistributedDataParallelKwargs
from accelerate import DataLoaderConfiguration


  from .autonotebook import tqdm as notebook_tqdm


# Provide Args

# Text Task

In [135]:
from __future__ import annotations
# supports dataset in huggingface datasets library for now
from datasets import load_dataset, concatenate_datasets
from gato.tasks.task import Task
import numpy as np
from torch import nn
from typing import TYPE_CHECKING, List,Dict
from transformers import AutoTokenizer
import torch
import copy

In [225]:
class TextTask(Task): 
    def __init__(self, dataset_names:List[str], dataset_paths:List[str], context_length:int, tokenizer_model:str):
        super().__init__()
        self.context_length = context_length
        self.text_tokenizer = AutoTokenizer.from_pretrained(tokenizer_model)
        text_datasets_list = []
        assert len(dataset_names) == len(dataset_paths), "The dataset names and paths parameters should have corresponding values and hence equal lengths"
        for i, text_dataset in enumerate(dataset_names):
            text_datasets_list.append(load_dataset(path=dataset_paths[i], name=text_dataset))
        if len(text_datasets_list) == 1:
            self.text_dataset = text_datasets_list[0]
        else:            
            # https://huggingface.co/docs/datasets/v2.14.4/en/process#concatenate
            # must have the same feature columns
            self.text_dataset = concatenate_datasets(text_datasets_list)

    def sample_batch(self, batch_size, is_test=False) -> List[Dict]:
        split = 'train' if not is_test else 'test'
        dataset_split = self.text_dataset[split]
        sampled_indices = torch.randperm(len(dataset_split))[:batch_size]
        samples = dataset_split.select(sampled_indices)
        tokenized_outputs = self.text_tokenizer(samples['text'], truncation=True, max_length=self.context_length, return_tensors='pt')
    
        batch_dicts = []
        for input_ids in tokenized_outputs["input_ids"]:
            if input_ids.numel() > 0:  # Check if non-empty
                # Split into input and target tokens
                input_tokens = input_ids[:-1]
                target_tokens = input_ids[1:]
                batch_dicts.append({
                    'text': input_tokens,
                    'target': target_tokens,
                    'images': None,
                    'continuous_obs': None,
                    'discrete_obs': None,
                    'continuous_actions': None,
                    'discrete_actions': None
                })
    
        return batch_dicts
        
    def sample_batch_old(self, batch_size, is_test=False)->List[Dict]:
        """Gets called during training and test both, fetch as many examples as batch_size param."""
        split = 'train' if not is_test else 'test'
        random_indices = np.random.randint(0, len(self.text_dataset[split]), size=batch_size)
        tokenized_outputs = self.text_tokenizer(self.text_dataset[split][random_indices]['text'], truncation=True,
            max_length=self.context_length,
            return_overflowing_tokens=True,
            return_length=True)
        
        batch_dicts = []
        count = 0 
        # todo - ii. vectorize this? Also do we wanna impose any length constraint?
        for length, input_ids in zip(tokenized_outputs["length"], tokenized_outputs["input_ids"]):
            # we only pick non-empty string examples right now
            if length > 0:
                batch_example_dict = {
                    'text': input_ids,  # A list of tokens
                    'images': None,
                    'continuous_obs': None,
                    'discrete_obs': None,
                    'continuous_actions': None,
                    'discrete_actions': None
                }
                batch_dicts.append(batch_example_dict)
                count += 1
                if count == batch_size:
                    break
        
        return batch_dicts

    def evaluate(self, model: GatoPolicy, num_examples_to_test=50, deterministic=False, log_examples_to_output=False, is_test=True):
        split = 'train' if not is_test else 'test'
        dataset_split = self.text_dataset[split]
        num_examples_to_test = min(num_examples_to_test, len(dataset_split))
    
        if num_examples_to_test == 0:
            return {'loss': float('nan'), 'perplexity': float('nan')}
    
        batch_dicts = self.sample_batch(num_examples_to_test, is_test)
        total_loss, total_tokens = 0.0, 0
    
        for batch_dict in batch_dicts:
            tokens = batch_dict['text']
            ith_position = np.random.randint(1, len(tokens))
            input_tokens = tokens[:ith_position]
            target_tokens = tokens[ith_position:]
    
            pred_logits, pred_tokens = model.predict_text({'text': input_tokens}, max_length=len(target_tokens), deterministic=deterministic)
    
            target_tokens_tensor = torch.tensor(target_tokens, dtype=torch.long, device=model.device)
            loss = torch.nn.functional.cross_entropy(pred_logits[:, -len(target_tokens):], target_tokens_tensor.unsqueeze(0))
            total_loss += loss.item()
            total_tokens += len(target_tokens)
    
            if log_examples_to_output:
                print(f'Input: {self.text_tokenizer.decode(input_tokens)} | Target: {self.text_tokenizer.decode(target_tokens)} | Prediction: {self.text_tokenizer.decode(pred_tokens)}')
    
        avg_loss = total_loss / len(batch_dicts)
        perplexity = torch.exp(torch.tensor(avg_loss)).item()
    
        return {'loss': avg_loss, 'perplexity': perplexity}
        
    def evaluate_old(self, model: GatoPolicy, num_examples_to_test=50, deterministic=False, log_examples_to_output=False, is_test=True):
        tokenizer = model.module.text_tokenizer
        loss_fn = nn.CrossEntropyLoss()
        total_loss = 0
        total_tokens = 0
    
        split = 'train' if not is_test else 'test'
        if num_examples_to_test > len(self.text_dataset[split]):
            print(f'num_examples_to_test chosen is more than test examples, so setting it to whole test dataset.')
            num_examples_to_test = len(self.text_dataset[split])
    
        if log_examples_to_output:
            print(f'--- examples ---')
        
        batch_dicts = self.sample_batch(num_examples_to_test, is_test)
        print(f'Num of examples to test : {num_examples_to_test} | Actual batch size of eval data : {len(batch_dicts)}')
        
        actual_examples_tested = 0
        for idx in range(min(num_examples_to_test, len(batch_dicts))):
            batch_dict = batch_dicts[idx]
            
            # Split the tokens into input and target tokens
            tokens = batch_dict['text']
            ith_position = np.random.randint(1, len(tokens))
            input_tokens = tokens[:ith_position]
            target_tokens = tokens[ith_position:]
    
            new_batch_dict = copy.deepcopy(batch_dict)
            new_batch_dict['text'] = input_tokens
    
            # Generate prediction
            pred_logits, pred_tokens = model.module.predict_text(new_batch_dict, max_length=len(target_tokens), deterministic=deterministic)
            
            if log_examples_to_output and idx%50==0:
                print(f'Text Example : {tokenizer.decode(batch_dict["text"])} \n Input passed to model : {tokenizer.decode(new_batch_dict["text"])} \n Predicted output : {tokenizer.decode(pred_tokens)}')
                print("----")
    
            # Calculate loss
            target_tokens_tensor = torch.tensor(target_tokens, dtype=torch.long, device=model.device)
            loss = loss_fn(pred_logits[:, -len(target_tokens):, :], target_tokens_tensor.unsqueeze(0))
            total_loss += loss.item()
            total_tokens += len(target_tokens)
            actual_examples_tested += 1
            
        if log_examples_to_output:
            print(f'--- examples end ---')
    
        avg_loss = total_loss / actual_examples_tested
        perplexity = torch.exp(torch.tensor(avg_loss))
    
        metrics = {
            'loss': avg_loss,
            'perplexity': perplexity.item()
        }
        return metrics

## Gato Policy

In [226]:
from __future__ import annotations
from typing import Optional, Union, TYPE_CHECKING
import torch
import torch.nn as nn

import transformers
from transformers import AutoTokenizer

# import gato
from gato.transformers import GPT2Model

In [227]:
class GatoPolicy(nn.Module):
    def __init__(
        self,
        device: Union[torch.device, str],
        embed_dim: int,
        layers: int,
        heads: int,
        dropout: float,

        activation_fn='gelu',

        mu: int = 100,
        M: int = 256,

        patch_size: int = 16,
        resid_mid_channels: int = 132,
        num_groups: int = 32,
        position_vocab_size: int = 128,
        continuous_tokens: int = 1024,
        discrete_tokens: int = 1024,

        context_len=1024,

        use_pos_encoding: bool = True,
        use_patch_pos_encoding: bool = True,

        pretrained_lm: Optional[str] = None, # Optional, name of pretrained language model to use
        flash: bool = False, # TODO verify correctness
        tokenizer_model_name: str = 'gpt2',
        pad_seq: bool = False
    ):
        super().__init__()
        self.device = device
        self.context_len = context_len
        
        # Text Tokenizer
        self.text_tokenizer = AutoTokenizer.from_pretrained(tokenizer_model_name)        
        # tokens
        self.vocab_size = self.text_tokenizer.vocab_size 
        

        if pretrained_lm is not None:
            print(f'loading pretrained GPT2 weights')
            config = transformers.GPT2Config.from_pretrained(pretrained_lm)
            config.attn_pdrop = dropout # 0.1
            config.resid_pdrop = dropout
            self.transformer = GPT2Model.from_pretrained(
                pretrained_lm,
                config=config,
            )
            embed_dim = config.n_embd
            assert self.transformer.wte.weight.shape[0] == self.text_tokens, "pretrained token/expected mimsatch" # potentially make text_tokens dynamic
        else:
            gate = False
            if activation_fn == 'geglu':
                gate = True
                activation_fn = 'gelu'
            config = transformers.GPT2Config(
                vocab_size=1,  # doesn't matter -- we don't use the vocab
                n_embd=embed_dim,
                n_head=heads,
                n_layer=layers,
                resid_pdrop=dropout,
                attn_pdrop=dropout,
                n_positions=context_len,
                n_inner=embed_dim * 4,
                activation_function=activation_fn,
            )
            config.n_ctx = context_len
            config.gate = gate
            config.flash = flash
            self.transformer = self.transformer = GPT2Model(config)
        
        # embedding tokens
        self.embed_token = nn.Embedding(self.vocab_size, embed_dim)
        if pretrained_lm is not None:
            self.embed_token.weight.data[:] = self.transformer.wte.weight.data
        
        
        # head
        self.predict_token = nn.Linear(embed_dim, self.vocab_size, bias=False)
        self.separator_token = nn.Parameter(torch.zeros(embed_dim))

    @property
    def module(self):
        return self

    def forward(self, inputs: Optional[list]=None, compute_loss=False, **kwargs):
        # tokenize inputs
        if inputs is not None:
            token_embeddings, tokens, token_target_masks, token_masks = self.tokenize_input_dicts(inputs)
        else:
            token_embeddings = kwargs['token_embeddings']
            tokens = kwargs['tokens']
            token_target_masks = kwargs['token_target_masks']
            token_masks = kwargs['token_masks']

        assert token_embeddings is not None, "token_embeddings is None"
        assert token_masks is not None, "token_masks is None"

        final_representations = self.transformer(inputs_embeds=token_embeddings, attention_mask=token_masks)['last_hidden_state']
        logits = self.predict_token(final_representations)
        
        if compute_loss and 'target' in kwargs:
            target_tokens = kwargs['target']
            loss_masks = (target_tokens != self.text_tokenizer.pad_token_id).float()
            loss = torch.nn.functional.cross_entropy(logits.view(-1, self.vocab_size), target_tokens.view(-1), reduction='none')
            loss = (loss * loss_masks.view(-1)).sum() / loss_masks.sum()
        else:
            loss = None
    
        return logits, loss
        
    
    # predicts next token (for each input token)
    def forward_old(self, inputs: Optional[list]=None, compute_loss=False, **kwargs):
        # tokenize inputs
        if inputs is not None:
            token_embeddings, tokens, token_target_masks, token_masks = self.tokenize_input_dicts(inputs)
        else:
            token_embeddings = kwargs['token_embeddings']
            tokens = kwargs['tokens']
            token_target_masks = kwargs['token_target_masks']
            token_masks = kwargs['token_masks']

        # pass to transformer
        #final_representations = self.transformer(x = token_embeddings, custom_mask = token_masks, batch_first=True)
        final_representations = self.transformer(inputs_embeds=token_embeddings, attention_mask=token_masks)['last_hidden_state']
        logits = self.predict_token(final_representations)

        if compute_loss:
            loss_logits = logits[:, :-1, :].contiguous().view(-1, self.vocab_size)
            target_tokens = tokens[:, 1:].contiguous().view(-1)
            loss_masks = token_target_masks[:, 1:].contiguous().view(-1)
            loss = torch.nn.functional.cross_entropy(loss_logits, target_tokens, reduction='none')
            loss = (loss * loss_masks).sum() / loss_masks.sum()
        else:
            loss = None

        return logits, loss

    def tokenize_input_dicts(self, inputs: list):
        # if not inputs:
        #     return None, None, None, None
    
        batch_len = len(inputs)
        max_tokens = max(len(batch['text']) for batch in inputs)
    
        # Allocate tensors
        token_embeddings = torch.zeros((batch_len, max_tokens, self.embed_token.embedding_dim), device=self.device)
        tokens = torch.zeros((batch_len, max_tokens), dtype=torch.long, device=self.device)
        token_target_masks = torch.zeros((batch_len, max_tokens), device=self.device)
        token_masks = torch.zeros((batch_len, max_tokens), device=self.device)
    
        for i, batch in enumerate(inputs):
            text_tokens = torch.tensor(batch['text'], dtype=torch.long, device=self.device)
            n_timesteps = len(text_tokens)
    
            tokens[i, :n_timesteps] = text_tokens
            token_embeddings[i, :n_timesteps] = self.embed_token(text_tokens)
            token_target_masks[i, 1:n_timesteps] = 1  # Target mask should start from second token to avoid prediction of first token
            token_masks[i, :n_timesteps] = 1

        print("token_embeddings shape:", token_embeddings.shape)  # Debug print
        print("token_masks shape:", token_masks.shape)  # Debug print
    
        return token_embeddings, tokens, token_target_masks, token_masks

    def tokenize_input_dicts_old(self, inputs: list):
        """"
        inputs: list of dicts for a batch
        [
            {
                # observations
                text: T x L  or None
                
            },
            ...
            {
            }
        ]

        returns: the tokens_id, tokens_embedding for each batch respectively
        """
        batch_len = len(inputs)

        token_embeddings = []
        tokens = []
        token_target_masks = []
        max_tokens = -1 # max number of timesteps across all batches

        for batch in inputs:

            text_tokens = torch.tensor(batch['text'], dtype=torch.long, device=self.device).unsqueeze(0)

            # Split the tokens into input and target tokens
            input_tokens = text_tokens[:, :-1]
            target_tokens = text_tokens[:, 1:]
        
            text_embeddings = self.embed_token(text_tokens)
            text_targets_masks = torch.ones_like(target_tokens)
            n_timesteps = text_tokens.shape[1]

            batch_embeddings = text_embeddings
            batch_tokens = input_tokens
            batch_target_masks = text_targets_masks

            token_embeddings.append(batch_embeddings)
            tokens.append(batch_tokens)
            token_target_masks.append(batch_target_masks)
            max_tokens = max(max_tokens, batch_embeddings.shape[1])

        token_masks = []
        # (left pad) to max tokens
        for i in range(batch_len):
            pad_len = max_tokens - token_embeddings[i].shape[1]
            token_masks.append(torch.cat([torch.ones_like(token_embeddings[i]), torch.zeros(1, pad_len, device=self.device)], dim=1))
            token_embeddings[i] = torch.cat([token_embeddings[i], torch.zeros(1, pad_len, self.embed_token.embedding_dim, device=self.device)], dim=1)
            tokens[i] = torch.cat([tokens[i], torch.zeros(1, pad_len, dtype=torch.long, device=self.device)], dim=1)
            token_target_masks[i] = torch.cat([token_target_masks[i], torch.zeros(1, pad_len, device=self.device)], dim=1)

        # Check if token_embeddings list is not empty before concatenating
        if token_embeddings:
            token_embeddings = torch.cat(token_embeddings, dim=0)
            tokens = torch.cat(tokens, dim=0)
            token_target_masks = torch.cat(token_target_masks, dim=0)
            token_masks = torch.cat(token_masks, dim=0)
        else:
            token_embeddings = None
            tokens = None
            token_target_masks = None
            token_masks = None
        
        # if self.pad_seq:
        #     # get seq length
        #     seq_len = token_embeddings.shape[1]
        #     pad_len = self.context_len - seq_len
        #     if pad_len > 0:
        #         token_embeddings = torch.cat([token_embeddings, torch.zeros(batch_len, pad_len, self.embed_dim, device=self.device)], dim=1)
        #         tokens = torch.cat([tokens, torch.zeros(batch_len, pad_len, dtype=torch.long, device=self.device)], dim=1)
        #         token_target_masks = torch.cat([token_target_masks, torch.zeros(batch_len, pad_len, device=self.device)], dim=1)
        #         token_masks = torch.cat([token_masks, torch.zeros(batch_len, pad_len, device=self.device)], dim=1)
        return token_embeddings, tokens, token_target_masks, token_masks


    def predict_text(self, batch_dict, max_length=20, deterministic=True, top_p=0.9):
        input_tokens = torch.tensor(batch_dict['text'], dtype=torch.long, device=self.device).unsqueeze(0)
        
        predicted_tokens = []
    
        for _ in range(max_length):
            token_embeddings = self.embed_token(input_tokens)
            token_masks = torch.ones_like(input_tokens)

            logits, _ = self.forward(token_embeddings=token_embeddings, tokens=input_tokens, token_target_masks=None, token_masks=token_masks)
            logits = logits[:, -1, :]  # focus on the last time-step logits
    
            if deterministic:
                token = torch.argmax(logits, dim=-1)
            else:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                sorted_logits[sorted_indices_to_remove] = 0
                sorted_probs = torch.nn.functional.softmax(sorted_logits, dim=-1)
                token = torch.multinomial(sorted_probs, 1)[0]
                token = sorted_indices[torch.where(sorted_probs == token)]
    
            predicted_tokens.append(token.item())
    
            if token == self.text_tokenizer.eos_token_id:
                break

            input_tokens = torch.cat([input_tokens, next_token.unsqueeze(0)], dim=1)  # Append the predicted token

            # Ensure that sequence does not exceed context_len
            # if input_tokens.size(1) > self.context_len:
            #     input_tokens = input_tokens[:, -self.context_len:]
            #     token_embeddings = token_embeddings[:, -self.context_len:, :]
            #     token_masks = token_masks[:, -self.context_len:]
    
        return logits, predicted_tokens

    def predict_text_old(self, batch_dict, max_length:int=20, deterministic:bool=True, top_p:float=0.9):
        input_tokens = torch.tensor(batch_dict['text'], dtype=torch.long, device=self.device).unsqueeze(0)
        token_embeddings = self.embed_token(input_tokens)
        token_masks = torch.ones_like(input_tokens)
        predicted_tokens = []
        
        # predict tokens, sampling or deterministically picking best token
        for i in range(max_length):
            logits, _ = self.forward(token_embeddings=token_embeddings, tokens=input_tokens, token_target_masks=None, token_masks=token_masks)
            logits = logits[0, -1, :]
    
            if deterministic:
                token = torch.argmax(logits, dim=-1)
            else:
                probs = torch.nn.functional.softmax(logits, dim=-1)
                sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                sorted_probs[sorted_indices_to_remove] = 0.0
                sorted_probs /= sorted_probs.sum()
                token = torch.multinomial(sorted_probs, num_samples=1)[0]
                token = sorted_indices[token]
    
            predicted_tokens.append(token.item())
    
            if token == self.text_tokenizer.eos_token_id:
                break
            
            input_tokens = torch.cat([input_tokens, token.unsqueeze(0).unsqueeze(0)], dim=1)
            token_embeddings = self.embed_token(input_tokens)
            token_masks = torch.ones_like(input_tokens)
    
            input_tokens = input_tokens[:, -self.context_len:]
            token_embeddings = token_embeddings[:, -self.context_len:, :]
            token_masks = token_masks[:, -self.context_len:]

        return logits, predicted_tokens

## Trainer

In [228]:
class Trainer:
    def __init__(
        self,
        model,
        optimizer,
        accelerator,
        scheduler,
        tasks,
        exp_name,
        args
    ):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.accelerator = accelerator
        self.tasks = tasks
        self.args = args
        self.print_logs = True # args.print_logs
        self.device = torch.device(args.device)

        self.min_lr = self.args.learning_rate / self.args.min_factor
        self.deterministic = self.args.eval_mode == 'deterministic'

        self.exp_name = exp_name
        self.exp_dir = os.path.join(self.args.save_dir, self.exp_name)

        self.steps = 0
        self.start_time = None

    def train(self):
        self.start_time = time.time()
        iters = self.args.training_steps // self.args.log_eval_freq
        for i in range(iters):
            logs = self.train_iteration(self.args.log_eval_freq, i)
            self.accelerator.log(logs)

        ## Save model at end of training only if not saving checkpoints
        if self.args.save_model and self.args.save_mode == 'last':
            self.accelerator.wait_for_everyone()
            if self.accelerator.is_main_process:
                unwrapped_model = self.accelerator.unwrap_model(self.model)
                save_model(unwrapped_model, self.exp_dir, f'checkpoint_{self.steps}', self.args)

        self.accelerator.end_training()


    def train_iteration(self, num_steps, iter):
        logs = {}

        train_start = time.time()

        train_losses = []

        self.model.train()
        for i in range(num_steps):
            self.steps += 1
            train_loss, step_logs = self.train_step()
            train_losses.append(train_loss)

        # add logs from last train_step as well
        for log in step_logs:
            logs[log] = step_logs[log]

        logs['time/training'] = time.time() - train_start

        eval_start = time.time()
        self.model.eval()

        # loop over eval for each env
        with torch.no_grad():
            for task in self.tasks:
                eval_logs = {}
                if isinstance(task, TextTask):
                    eval_logs = task.evaluate(self.model, num_examples_to_test=self.args.eval_text_num_examples, deterministic=self.deterministic, log_examples_to_output=self.args.eval_text_log_examples)
                    for k, v in eval_logs.items():
                        logs[f'evaluation/text/{k}'] = v
                    pass

        logs['time/total'] = time.time() - self.start_time
        logs['time/evaluation'] = time.time() - eval_start
        logs['training/train_loss_mean'] = np.mean(train_losses)
        logs['training/train_loss_std'] = np.std(train_losses)

        if self.accelerator.is_main_process:
            if self.print_logs:
                print('=' * 80)
                print(f'Iteration {iter}')
                for k, v in logs.items():
                    print(f'{k}: {v}')
                print('=' * 80)

        ## Save model
        if self.args.save_model and self.args.save_mode == 'checkpoint':
            self.accelerator.wait_for_everyone()
            if self.accelerator.is_main_process:
                unwrapped_model = self.accelerator.unwrap_model(self.model)
                save_model(unwrapped_model, self.exp_dir, f'checkpoint_{self.steps}', self.args)

        return logs

    def train_step(self):
        logs = {}
        logs['training/learning_rate'] = self.scheduler.get_lr()[0] # store LR at current step
        # Build training batch
        start_time = time.time()

        # Calculate batch size for each task, the following need to be revised to including more new tasks
        text_batch_size = int(self.args.text_prop * self.args.batch_size)
        remainder = self.args.batch_size - text_batch_size

        if remainder > 0: 
            text_batch_size += remainder

        assert self.args.batch_size == text_batch_size, "Total batch size is not eqaual to the sum of each task's batch size" 

        text_batch_dicts = []

        # Sample text and control batches
        if text_batch_size > 0:
            text_batch_dicts = self.sample_text_batch(text_batch_size)

        assert text_batch_dicts, "Batch dicts is empty"
        print(f'text_batch_size:{text_batch_size}')


        logs['time/sample_batch'] = time.time() - start_time
        with self.accelerator.accumulate(self.model):
            # Compute loss and update model
            logits, loss = self.model.forward(inputs = text_batch_dicts, compute_loss=True)
            self.accelerator.backward(loss)

            if not self.args.disable_grad_clip and self.accelerator.sync_gradients:
                self.accelerator.clip_grad_norm_(self.model.parameters(), self.args.grad_norm_clip)

            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()

        return loss.detach().cpu().item(), logs

    def sample_text_batch(self, batch_size):
        batch_dicts = []
        text_tasks = [t for t in self.tasks if isinstance(t, TextTask)]
        for i,task in enumerate (text_tasks):
            batch_dicts.extend(task.sample_batch(batch_size))
        return batch_dicts

## Arguments

In [229]:
args = TrainingArgs(
    training_steps=1,
    log_eval_freq=1,
    warmup_steps=1,
    batch_size=16,
    eval_episodes=1,
    text_prop=1,
    eval_text_log_examples=True,
    pretrained_lm=None,
    text_datasets=['wikitext-2-v1'],
    text_datasets_paths=['wikitext'],
    learning_rate=0.00001,
    use_wandb=False
)

# train.py

In [230]:
from gato.training.trainer import Trainer
from gato.training.schedulers import get_linear_warmup_cosine_decay_scheduler

In [231]:
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
if args.use_wandb:
    log_with = 'wandb'
else:
    log_with = None

dl_config = DataLoaderConfiguration(split_batches=True)
accelerator = Accelerator(
    cpu=args.cpu,
    dataloader_config=dl_config, 
    mixed_precision=args.mixed_precision,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    kwargs_handlers=[ddp_kwargs],
    log_with=log_with
)
args.device = accelerator.device.type

exp_date = datetime.now().strftime('%y-%m-%d_%H-%M-%S')
exp_name = f'neko-gato_{exp_date}'

In [232]:
tasks = []

In [233]:
tasks.append(TextTask(args.text_datasets, args.text_datasets_paths, args.sequence_length, tokenizer_model=args.tokenizer_model_name))

In [234]:
model = GatoPolicy(
        device=args.device,
        embed_dim=args.embed_dim,
        layers=args.layers,
        heads=args.heads,
        dropout=args.dropout,
        mu=args.mu,
        M=args.M,
        patch_size=args.patch_size,
        resid_mid_channels=args.resid_mid_channels,
        continuous_tokens=args.continuous_tokens,
        discrete_tokens=args.discrete_tokens,
        context_len=args.sequence_length,
        use_patch_pos_encoding=not args.disable_patch_pos_encoding,
        use_pos_encoding=not args.disable_inner_pos_encoding,
        activation_fn=args.activation_fn,
        pretrained_lm=args.pretrained_lm,
        flash=args.flash,
        tokenizer_model_name=args.tokenizer_model_name,
        pad_seq=args.pad_seq,
    )

In [235]:
model = accelerator.prepare(model)

In [236]:
# print trainable parameters
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Trainable Parameters:', '{}M'.format(params / 1e6))
args.trainable_params = params

Trainable Parameters: 133.9008M


In [237]:
model.device = args.device

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=args.learning_rate,
    betas=(args.beta_1, args.beta_2),
    eps=args.adam_eps,
    weight_decay=args.weight_decay,
)

# Setup scheduler
scheduler = get_linear_warmup_cosine_decay_scheduler(optimizer, args.warmup_steps, args.training_steps, base_lr=args.learning_rate, init_lr=args.init_lr, min_lr=args.learning_rate / args.min_factor, cosine_decay=not args.disable_cosine_decay)

# setup up Accelerate, without dataloader:
#model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler)
optimizer, scheduler = accelerator.prepare(optimizer, scheduler)

if args.use_wandb:
    accelerator.init_trackers(args.wandb_project, init_kwargs={'wandb': {'name': exp_name, 'config': args}})
else:
    accelerator.init_trackers('')

In [238]:
# Create save dir if does not exist
if args.save_model and not os.path.exists(args.save_dir):
    os.makedirs(args.save_dir)
trainer = Trainer(
    model = model,
    optimizer = optimizer,
    scheduler = scheduler,
    accelerator = accelerator,
    tasks = tasks,
    exp_name = exp_name,
    args=args
)
trainer.train()

ValueError: max() arg is an empty sequence

wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
wandb: ERROR Dropped streaming file chunk (see wand