# My Transformer

I wrote this to help me understand _Attention Is All You Need_: https://arxiv.org/abs/1706.03762

I cut out the Encoder and use the remaining decoder-only model as a language model to generate English text.

In [None]:
from collections import defaultdict, Counter
import multiprocessing.pool
from math import sqrt, sin, cos, floor
import os
import random
import sys
import string
import time
import numpy as np

import torch
from torch import tensor
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.checkpoint import checkpoint
from torch.optim.lr_scheduler import LambdaLR
import torchtext
import wandb
from torchtext.data import RawField, ReversibleField, LabelField
from torchtext.datasets import WikiText2
from torchtext.datasets.language_modeling import LanguageModelingDataset


## Setup
 

In [None]:
# Basic Config
# Hyperparameters and run settings for the whole notebook. The same dict is
# passed to wandb as the run config further down.
conf = {
    'attn_heads': 8,
    'training_bptt_len': 180,
    # 'cuda_device_ids': [3, 2, 1, 0],  # I need better GPU cooling first
    'cuda_device_ids': [0],
    'd_model': 512,
    # Alternative data sources (disabled):
    # 'datafile': './city_names.txt',  # from: https://www.britannica.com/topic/list-of-cities-and-towns-in-the-United-States-2023068
    # 'datafile': './corncob_lowercase.txt',  # from: http://www.mieliestronk.com/corncob_lowercase.txt
    # 'datafile': './alphabet_short.txt',
    # 'datafile': './dummy_data.txt',
    # 'dataset': 'WikiText2',
    'dataset': 'WikiText103',
    'dropout': 0.1,
    'initial_learning_rate': 1,
    'total_epochs': 4,
    'num_blocks_encoder': 0,
    'num_blocks_decoder': 6,
    'max_bptt_len': 1024,
    'max_vocab_size': 60000,
    # 'minibatch_size': 32 * 16,
    'minibatch_size': 16,
    # 'optimizer': 'Adam',
    # 'optimizer': 'SGD',
    'random_seed': 0,
    'warmup_steps': 2500,
}


# debugging
#conf['attn_heads'] = 1
#conf['d_model'] = 1
#conf['training_bptt_len'] = 2
#conf['datafile'] = './dummy_data.txt'  
#conf['num_blocks_decoder'] = 1
#conf['minibatch_size'] = 1
#conf['total_epochs'] = 1


# Make sure d_model, heads, and d_key are compatible
assert conf['d_model'] % conf['attn_heads'] == 0, \
    'attn_heads=%s does not evenly divide d_model=%s' % (conf['attn_heads'],
                                                        conf['d_model'])
# Integer division: d_key is used as a layer width (nn.Linear out_features),
# so it must be an int. True division would store 64.0 instead of 64.
conf['d_key'] = conf['d_model'] // conf['attn_heads']

# Set up the RNGs for repeatability.
# Compare against None instead of relying on truthiness: a seed of 0 (the
# configured value) is a perfectly valid seed and must not disable seeding.
# Set conf['random_seed'] to None to run unseeded.
if conf['random_seed'] is not None:
    random.seed(conf['random_seed'])
    np.random.seed(conf['random_seed'])
    torch.manual_seed(conf['random_seed'])
    torch.cuda.manual_seed_all(conf['random_seed'])
    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
    
# Set up Cuda: list every configured GPU
gpu_ids = conf['cuda_device_ids']
print("Using", len(gpu_ids), "GPU(s):")
for i in gpu_ids:
    print("    cuda:%s:" % i, torch.cuda.get_device_name(i))

# The first configured GPU is the device the model and data batches are placed on.
device = torch.device('cuda:' + str(gpu_ids[0]))

print()

# Logging: `experiment` stays None (logging disabled) if the wandb line is
# commented out; _log_status only logs when it is truthy.
experiment = None
experiment = wandb.init(project="my-transformer", config=conf)

## Model Architecture

### Section 3.2.1: Scaled Dot-Product Attention

In [None]:
class AttentionHead(nn.Module):
    """One scaled dot-product attention head (paper section 3.2.1).

    Queries, keys and values are each projected from d_model to d_key
    features by a bias-free linear layer before attention is applied.
    """
    def __init__(self, d_model=conf['d_model'], d_key=conf['d_key']):
        super().__init__()

        # Kept for the 1/sqrt(d_key) score scaling in forward()
        self.d_key = d_key

        # Per-head input projections
        self.Wq = nn.Linear(d_model, d_key, bias=False)
        self.Wk = nn.Linear(d_model, d_key, bias=False)
        self.Wv = nn.Linear(d_model, d_key, bias=False)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, queries, keys, values, mask=None):
        """Return the attention-weighted sum of the projected values.

        mask: optional tensor broadcastable to the score matrix; positions
              where it equals 0 are excluded from attention.
        """
        q = self.Wq(queries)
        k = self.Wk(keys)
        v = self.Wv(values)

        # Compatibility of every query with every key, scaled by sqrt(d_key)
        scores = (q @ k.transpose(-2, -1)) / sqrt(self.d_key)

        # Masked positions get -inf so softmax assigns them zero weight
        if mask is not None:
            scores.masked_fill_(mask == 0, float('-inf'))

        weights = self.softmax(scores)

        # Last axis of the result has d_key features
        return weights @ v

### 3.2.2 Multi-Head Attention

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention (paper section 3.2.2).

    Runs `heads` independent AttentionHead modules, each d_model // heads
    wide, concatenates their outputs on the feature axis and mixes them
    with a bias-free output projection.
    """
    def __init__(self, d_model=conf['d_model'], heads=conf['attn_heads']):
        super().__init__()
        head_width = d_model // heads

        self.attn_heads = nn.ModuleList(
            [AttentionHead(d_model=d_model, d_key=head_width) for _ in range(heads)])
        self.Wo = nn.Linear(d_model, d_model, bias=False)

    def forward(self, queries, keys, values, mask=None):
        """Apply every head to the same inputs and project the concatenation."""
        per_head = []
        for head in self.attn_heads:
            per_head.append(head(queries=queries, keys=keys, values=values, mask=mask))
        return self.Wo(torch.cat(per_head, dim=-1))

### 3.3 Position-wise Feed-Forward Networks

In [None]:
class FFN(nn.Module):
    """Position-wise feed-forward network (paper section 3.3).

    Linear -> ReLU -> Linear, both layers bias-free, with a hidden width of
    `multiplier * d_model`.
    """
    def __init__(self, d_model=conf['d_model'], multiplier=4):
        super().__init__()

        hidden_width = int(multiplier * d_model)
        layers = [
            nn.Linear(d_model, hidden_width, bias=False),
            nn.ReLU(),
            nn.Linear(hidden_width, d_model, bias=False),
        ]
        self.ffn = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the feed-forward stack to x."""
        return self.ffn(x)

### 3.1 Encoder and Decoder Stacks

In [None]:
class EncoderBlock(nn.Module):
    """One encoder layer (paper section 3.1, Encoder).

    Two sub-layers, unmasked self-attention and a feed-forward network,
    each applied as LayerNorm(x + Dropout(sublayer(x))).
    """
    def __init__(self,
                 d_model=conf['d_model'],
                 heads=conf['attn_heads'],
                 dropout=conf['dropout']):
        super().__init__()

        # Self-attention sub-layer
        self.attn = MultiHeadAttention(d_model=d_model, heads=heads)
        self.attn_drop = nn.Dropout(p=dropout)
        self.attn_norm = nn.LayerNorm(d_model)

        # Feed-forward sub-layer
        self.ffn = FFN(d_model)
        self.ffn_drop = nn.Dropout(p=dropout)
        self.ffn_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """Run x through the attention and feed-forward sub-layers."""
        attended = self.attn_drop(self.attn(x, x, x))
        hidden = self.attn_norm(x + attended)

        transformed = self.ffn_drop(self.ffn(hidden))
        return self.ffn_norm(hidden + transformed)

In [None]:
class DecoderBlock(nn.Module):
    """One decoder layer (paper section 3.1, Decoder).

    Sub-layers, each applied as LayerNorm(x + Dropout(sublayer(x))):
      1. masked self-attention,
      2. attention over the encoder output (only if `encoder_attention`),
      3. feed-forward network.
    """
    def __init__(self,
                 d_model=conf['d_model'],
                 heads=conf['attn_heads'],
                 dropout=conf['dropout'],
                 encoder_attention=True):
        super().__init__()

        # 1. masked self-attention
        self.self_attn = MultiHeadAttention(d_model=d_model, heads=heads)
        self.self_attn_drop = nn.Dropout(p=dropout)
        self.self_attn_norm = nn.LayerNorm(d_model)

        # 2. encoder-decoder attention; enc_attn is None when disabled
        if encoder_attention:
            self.enc_attn = MultiHeadAttention(d_model=d_model, heads=heads)
            self.enc_attn_drop = nn.Dropout(p=dropout)
            self.enc_attn_norm = nn.LayerNorm(d_model)
        else:
            self.enc_attn = None

        # 3. feed-forward
        self.ffn = FFN(d_model)
        self.ffn_drop = nn.Dropout(p=dropout)
        self.ffn_norm = nn.LayerNorm(d_model)

    def forward(self, x, encoder_out=None, self_attn_mask=None):
        """Run x through the block.

        encoder_out:    keys/values for the encoder-decoder attention;
                        ignored when that sub-layer is disabled.
        self_attn_mask: mask handed to the self-attention sub-layer.
        """
        attended = self.self_attn_drop(self.self_attn(x, x, x, self_attn_mask))
        hidden = self.self_attn_norm(x + attended)

        if self.enc_attn is not None:
            context = self.enc_attn_drop(
                self.enc_attn(hidden, encoder_out, encoder_out))
            hidden = self.enc_attn_norm(hidden + context)

        transformed = self.ffn_drop(self.ffn(hidden))
        return self.ffn_norm(hidden + transformed)

In [None]:
class Encoder(nn.Module):
    """A stack of `num_blocks` EncoderBlock layers applied in sequence."""
    def __init__(self,
                 d_model=conf['d_model'],
                 heads=conf['attn_heads'],
                 num_blocks=conf['num_blocks_encoder'],
                 dropout=conf['dropout']):
        super().__init__()

        layers = []
        for _ in range(num_blocks):
            layers.append(EncoderBlock(d_model, heads, dropout))
        self.blocks = nn.ModuleList(layers)

    def forward(self, x):
        """Feed x through every block; with zero blocks x is returned as-is."""
        for layer in self.blocks:
            x = layer(x)
        return x

In [None]:
class Decoder(nn.Module):
    """A stack of `num_blocks` DecoderBlock layers applied in sequence."""
    def __init__(self,
                 d_model=conf['d_model'],
                 heads=conf['attn_heads'],
                 num_blocks=conf['num_blocks_decoder'],
                 dropout=conf['dropout'],
                 encoder_attention=True):
        super().__init__()

        layers = []
        for _ in range(num_blocks):
            layers.append(DecoderBlock(d_model, heads, dropout, encoder_attention))
        self.blocks = nn.ModuleList(layers)

    def forward(self, decoder_in, encoder_out=None, self_attn_mask=None):
        """Feed decoder_in through every block.

        The same encoder_out and self_attn_mask are handed to each block.
        """
        hidden = decoder_in
        for layer in self.blocks:
            hidden = layer(hidden, encoder_out, self_attn_mask)
        return hidden

### Section 3 Model Architecture

In [None]:
class Transformer(nn.Module):
    """Decoder-only Transformer language model (paper section 3).

    Token embeddings plus fixed sinusoidal position encodings are fed through
    a stack of decoder blocks under a causal self-attention mask; a final
    bias-free linear layer maps every position to vocabulary logits. The
    encoder stack is not instantiated.
    """
    def __init__(self,
                 vocab_len,
                 pad_index,
                 d_model=conf['d_model'],
                 heads=conf['attn_heads'],
                 max_bptt_len=conf['max_bptt_len'],
                 num_blocks_encoder=conf['num_blocks_encoder'],
                 num_blocks_decoder=conf['num_blocks_decoder'],
                 dropout=conf['dropout'],
                 encoder_decoder_attention=True):
        """parameters:
        vocab_len:    number of rows in the embedding table and width of
                      the output logits.
        pad_index:    vocab index passed to nn.Embedding as padding_idx.
        max_bptt_len: longest sequence supported; sizes the precomputed
                      position encoding and causal mask buffers.
        num_blocks_encoder: accepted for compatibility but unused while the
                      encoder is disabled.
        encoder_decoder_attention: whether decoder blocks get an
                      encoder-decoder attention sub-layer."""
        super().__init__()

        self.d_model = d_model
        self.vocab_len = vocab_len
        self.pad_index = pad_index
        self.max_bptt_len = max_bptt_len
        self.encoder_decoder_attention = encoder_decoder_attention

        self.embedding = nn.Embedding(vocab_len, d_model, padding_idx=pad_index)
        # Buffers (not parameters): saved in the state dict and moved by
        # .to(device), but never trained.
        self.register_buffer('position_encoding', self._position_encoding(max_bptt_len, d_model))
        self.register_buffer('self_attn_mask', self._make_mask(max_bptt_len))

        #self.encoder = Encoder(d_model, heads, num_blocks_encoder, dropout)
        self.decoder = Decoder(d_model, heads, num_blocks_decoder, dropout, encoder_decoder_attention)

        self.linear = nn.Linear(d_model, vocab_len, bias=False)

    @classmethod
    def _make_mask(cls, bptt_len):
        """Return a (bptt_len, bptt_len) lower-triangular matrix of ones.

        Row i has ones in columns 0..i, so position i may attend only to
        itself and earlier positions."""
        return torch.ones([bptt_len, bptt_len]).tril()

    @classmethod
    def _position_encoding(cls, bptt_len, d_model):
        """Return the (bptt_len, d_model) sinusoidal position encoding.

        Even feature i uses sin(pos / 10000^(i/d_model)); the following odd
        feature uses cos with the same frequency."""
        rows = [tensor([sin(pos/(10000**(i/d_model)))
                        if i % 2 == 0
                        else
                        cos(pos/(10000**((i-1)/d_model)))
                        for i in range(d_model)])
                for pos in range(bptt_len)]
        # One row per position, so stacking on dim 0 gives the final layout
        # directly (no transpose needed).
        return torch.stack(rows, dim=0)

    def embed(self, indices):
        """Implements the embedding from Section 3.4 Embeddings and Softmax.

        indices: integer tensor of vocab indices whose last axis is the
                 sequence axis.
        Returns the token embeddings with the position encoding of each
        position added."""
        this_bptt_len = indices.shape[-1]
        pe = self.position_encoding[:this_bptt_len, :]

        embedded = self.embedding(indices)

        return pe + embedded

    def forward(self, decoder_in, encoder_out=None, pos=None, pre_embedded=False):
        """parameters:
        decoder_in:   vocab indices of the tokens seen so far, sequence on
                      the last axis; or, when pre_embedded is True, already
                      embedded vectors with the sequence on the
                      second-to-last axis.
        encoder_out:  (optional) stand-in for the output of the disabled
                      encoder, attended to by the encoder-decoder attention.
                      Given in the same form as decoder_in (indices, or
                      vectors when pre_embedded is True). Only valid for a
                      model created with encoder_decoder_attention=True.
        pos:          (optional) if given, only this sequence position of
                      the decoder output is projected to logits.
        pre_embedded: skip the embedding step for the inputs.

        Returns the vocabulary logits for every position (or only for
        `pos`)."""

        # The sequence axis depends on whether the input still holds indices
        # (..., seq) or already holds embedded vectors (..., seq, d_model).
        seq_axis = -2 if pre_embedded else -1
        this_bptt_len = decoder_in.shape[seq_axis]
        if this_bptt_len > self.max_bptt_len:
            raise ValueError('sequence length %s exceeds max_bptt_len=%s'
                             % (this_bptt_len, self.max_bptt_len))

        # `is not None` rather than truthiness: bool() of a tensor with more
        # than one element raises.
        if encoder_out is not None:
            assert self.encoder_decoder_attention, \
            "encoder_out passed to model created without encoder-decoder attention"

        # Embed
        if pre_embedded:
            di_embedded = decoder_in
            encoded = encoder_out
        else:
            di_embedded = self.embed(decoder_in)
            encoded = None if encoder_out is None else self.embed(encoder_out)

        # Decode; the causal mask is cropped to the current sequence length
        self_attn_mask = self.self_attn_mask[:this_bptt_len, :this_bptt_len]
        decoded = self.decoder(di_embedded,
                               encoder_out=encoded,
                               self_attn_mask=self_attn_mask)

        # Return predictions for next token
        if pos is not None:
            decoded = decoded[:, pos, :]

        y_pred = self.linear(decoded)

        return y_pred


## Load Data and Build the Model

In [None]:
# dataloader and vocab
#train_ds = load_dataset()
def get_dataloader():
    """Load the dataset named by conf['dataset'] and return BPTT iterators.

    Builds the vocabulary from the training split, capped at
    conf['max_vocab_size'] entries, and returns the (train, val, test)
    BPTTIterators placed on `device`.
    """
    dataset_cls = getattr(torchtext.datasets, conf.get('dataset', None))
    text_field = torchtext.data.Field()
    train_split, val_split, test_split = dataset_cls.splits(text_field)
    text_field.build_vocab(train_split, max_size=conf['max_vocab_size'])

    # NOTE(review): windows are twice conf['training_bptt_len'] long; confirm
    # that the doubling is intended.
    window_len = conf['training_bptt_len'] * 2
    return torchtext.data.BPTTIterator.splits(
        (train_split, val_split, test_split),
        batch_size=conf['minibatch_size'],
        bptt_len=window_len,
        device=device)


train_ds, val_ds, test_ds = get_dataloader()
# The iterator reports its number of batches per pass through __len__, so a
# full pass over the training set is not needed just to count them. Kept as a
# float because it is used as a divisor in the progress math.
train_ds_len = float(len(train_ds))
text_field = train_ds.dataset.fields['text']
vocab = text_field.vocab
vocab_len = len(vocab)
# Use the Field's own padding token. An arbitrary string such as '_' resolves
# to some unrelated vocab entry (the <unk> fallback or a real "_" token), and
# nn.Embedding(padding_idx=...) would then pin that entry's vector at zero.
pad_token = text_field.pad_token
pad_index = vocab.stoi[pad_token]
print('There are', int(train_ds_len), 'training batches')
print('vocab_len =', vocab_len)

In [None]:
# Create the model: decoder-only (no encoder-decoder attention), on `device`
model = Transformer(vocab_len, pad_index, encoder_decoder_attention=False).to(device)
num_params = sum(p.numel() for p in model.parameters())
print(f"Model has {num_params:,d} parameters")
#model = nn.DataParallel(model, device_ids=conf['cuda_device_ids'])

# Define the Loss (padding targets are currently not ignored)
#criterion = nn.CrossEntropyLoss(ignore_index=pad_index)
criterion = nn.CrossEntropyLoss()

### Validate the Model

## Training Helper Functions

In [None]:
def _scheduler_lr(step, d_model=conf['d_model'], warmup_steps=conf['warmup_steps']):
    """Return the learning-rate multiplier for `step`.

    The scheduler multiplies this value by conf['initial_learning_rate'].
    The multiplier grows linearly for the first `warmup_steps` steps and then
    decays with the inverse square root of the step number.
    """
    n = step + 1  # the scheduler starts counting at 0
    inverse_sqrt_decay = n**-.5
    linear_warmup = n * (warmup_steps**-1.5)
    return (d_model**-.5) * min(inverse_sqrt_decay, linear_warmup)

# Adam configured with beta2=0.98 and eps=1e-9; the base lr is scaled every
# step by the multiplier that _scheduler_lr returns.
optimizer = optim.Adam(model.parameters(),
                       lr=conf['initial_learning_rate'],
                       betas=(0.9, 0.98),
                       eps=1e-9)

scheduler = LambdaLR(optimizer, lr_lambda=_scheduler_lr)

In [None]:
def accuracy(output, expected_indices):
    """Return the percentage (0-100) of positions predicted correctly.

    output:           scores with the candidate classes on the last axis.
    expected_indices: integer tensor of target class indices, one per
                      position (i.e. `output` without its last axis).

    The prediction for a position is the arg-max over the last axis.
    Returns 0.0 for empty input.
    """
    predicted = output.argmax(dim=-1)
    if predicted.numel() == 0:
        return 0.0
    # No squeeze(): dropping size-1 axes (e.g. a length-1 sequence) would let
    # the comparison below broadcast to the wrong shape and inflate the score.
    hits = (predicted == expected_indices).float()
    return hits.mean().item() * 100

def run_minibatch(di, y, optimizer, scheduler, model):
    """Train on a single minibatch.

    di: decoder input indices for the batch.
    y:  target indices for the batch.

    Performs one optimizer step followed by one scheduler step, writes a
    progress dot to stdout, and returns (loss, accuracy, seconds_taken).
    """
    model.train()
    began = time.time()

    optimizer.zero_grad()
    logits = model(decoder_in=di)
    acc = accuracy(logits, y)

    # CrossEntropyLoss wants the class axis second, so swap the last two axes
    loss = criterion(logits.transpose(-2, -1), y)
    loss.backward()  # Not sure why, but this step logs a UserWarning
    optimizer.step()
    scheduler.step()

    elapsed = time.time() - began
    sys.stdout.write('.')
    return loss.item(), acc, elapsed

def validate(model, max_batches=1):
    """Evaluate `model` on the held-out `test_ds` iterator.

    max_batches: how many batches to evaluate. Defaults to 1, which keeps the
                 periodic checks during training cheap; pass None to run over
                 the whole of test_ds.

    Puts the model in eval mode and returns (accuracy_percent, perplexity),
    where perplexity is exp of the mean per-batch cross-entropy loss.
    """
    model.eval()
    accuracies = []
    losses = []
    with torch.no_grad():
        for batch in test_ds:
            target = batch.target.T
            y_pred = model(decoder_in=batch.text.T)
            accuracies.append(accuracy(y_pred, target))
            # CrossEntropyLoss expects the class axis in position 1
            y_pred = y_pred.transpose(-2, -1)
            losses.append(criterion(y_pred, target).item())
            if max_batches is not None and len(losses) >= max_batches:
                break
    acc = tensor(accuracies, device=device).float().mean().item()
    loss = tensor(losses, device=device).float().mean()
    perplexity = torch.exp(loss).item()
    return acc, perplexity

def _log_status(learning_rate, epoch, epoch_step, epoch_secs, step_secs, step_loss,
                step_train_accuracy, test_accuracy, test_perplexity,
                steps_per_epoch=train_ds_len):
    """Print a one-line progress report and mirror it to the wandb run.

    The ETA assumes the rest of the epoch proceeds at the average pace
    observed so far (epoch_secs for epoch_step out of steps_per_epoch steps).
    Nothing is sent to wandb when `experiment` is falsy.
    """
    ts_format = '%I:%M%p'
    now = time.time()

    # Progress and estimated wall-clock end of the current epoch
    epoch_pct_done = 100.0 * epoch_step / steps_per_epoch
    est_total_secs = epoch_secs / (epoch_pct_done/100)
    est_end_time = now - epoch_secs + est_total_secs

    now_ts = time.strftime(ts_format, time.localtime())
    est_end_ts = time.strftime(ts_format, time.localtime(est_end_time))
    cumulative_step = (epoch * steps_per_epoch) + epoch_step

    report = [
        now_ts,
        '(%.1f%%)' % epoch_pct_done,
        '(eta:%s)' % est_end_ts,
        'epoch:%s' % epoch,
        'step:%s' % epoch_step,
        'lr:%.6f' % learning_rate,
        'test_perplexity=%.1f' % test_perplexity,
        'test_accuracy=%.1f%%' % test_accuracy,
        'train_step_loss=%.4f' % step_loss,
        'train_accuracy=%.1f%%' % step_train_accuracy,
    ]
    print()
    print(*report)

    if not experiment:
        return

    metrics = {
        'learning_rate': learning_rate,
        'epoch': epoch,
        'epoch_step': epoch_step,
        'cumulative_step': cumulative_step,
        'epoch_secs': epoch_secs,
        'step_secs': step_secs,
        'step_loss': step_loss,
        'step_train_accuracy': step_train_accuracy,
        'test_accuracy': test_accuracy,
        'test_perplexity': test_perplexity,
        'epoch_pct_done': epoch_pct_done,
    }
    experiment.log(metrics)
    
            
def do_epoch(epoch, model, optimizer, scheduler, checkpoint_freq=500, steps_per_epoch=train_ds_len):
    """Run one pass over `train_ds`, training on every batch.

    On the first step, then every `checkpoint_freq` steps, and on the final
    step of the epoch, the model is validated, a status line is logged and
    the weights are saved to a .pt file in the current directory.

    epoch:           epoch number, used for logging and the checkpoint name.
    checkpoint_freq: number of steps between validations/checkpoints.
    steps_per_epoch: number of batches in one epoch.

    Returns None.
    """
    epoch_start_time = None
    # Steps are numbered from 1
    for epoch_step, batch in enumerate(train_ds, start=1):
        if epoch_start_time is None:
            # Started at the first batch rather than before the loop
            epoch_start_time = time.time()
        # Read before stepping: this is the rate the step below actually uses
        learning_rate = optimizer.param_groups[0]['lr']

        step_loss, step_train_accuracy, step_secs = run_minibatch(batch.text.T,
                                                                  batch.target.T,
                                                                  optimizer,
                                                                  scheduler,
                                                                  model)

        # Steps 1, 1 + freq, 1 + 2*freq, ... Written this way so that
        # checkpoint_freq == 1 fires on every step (x % 1 == 1 never holds).
        is_periodic_step = (epoch_step - 1) % checkpoint_freq == 0
        # epoch_step is 1-based, so the last step equals steps_per_epoch.
        is_final_step = epoch_step == steps_per_epoch
        if not (is_periodic_step or is_final_step):
            continue

        test_accuracy, test_perplexity = validate(model)
        epoch_secs = time.time() - epoch_start_time
        _log_status(learning_rate=learning_rate,
                    epoch=epoch,
                    epoch_step=epoch_step,
                    epoch_secs=epoch_secs,
                    step_secs=step_secs,
                    step_loss=step_loss,
                    step_train_accuracy=step_train_accuracy,
                    test_accuracy=test_accuracy,
                    test_perplexity=test_perplexity)
        # NOTE(review): the name embeds `epoch-1`, so epoch 0 is saved as
        # "-1-epoch". Left unchanged so existing checkpoint names still match;
        # confirm whether that offset is intended.
        save_file_name = './my-transformer_%s_%s-layer_%s-vocab_%s-epoch_%s_step.pt' % \
                         (conf['dataset'], conf['num_blocks_decoder'], len(vocab), epoch-1, epoch_step)
        # Unwrap a DataParallel-style wrapper so the saved keys carry no
        # "module." prefix.
        to_save = model.module if hasattr(model, 'module') else model
        torch.save(to_save.state_dict(), save_file_name)
    return

def train(num_epochs=conf['total_epochs'],
          model=model,
          checkpoint_frequency=500,
          vocab=vocab,
          optimizer=optimizer,
          scheduler=scheduler,
          criterion=criterion,
          start_epoch=0):
    """Train for `num_epochs` epochs, numbered from `start_epoch`.

    Each epoch is delegated to do_epoch, which logs progress and writes
    checkpoints every `checkpoint_frequency` steps. The `vocab` and
    `criterion` arguments are accepted but not used here.
    """
    for offset in range(num_epochs):
        do_epoch(start_epoch + offset, model, optimizer, scheduler, checkpoint_frequency)


In [None]:
#PATH="""./my-transformer_WikiText103_6-layer_60002-vocab_-1-epoch_17921_step.pt"""
#model.load_state_dict(torch.load(PATH, map_location=device))

## Train the Model

In [None]:
# Train from epoch 0 for conf['total_epochs'] epochs, validating and
# checkpointing every 500 steps.
train(num_epochs=conf['total_epochs'], start_epoch=0, checkpoint_frequency=500)

## Sampling Helper Functions

In [None]:
# Special vocabulary tokens; sample() stops generating when it draws
# eos_token or pad_token.
unk_token = '<unk>'
pad_token = '<pad>'  # NOTE: rebinds the pad_token global set when the vocab was loaded
eos_token = '<eos>'

def numericalize(string, pad=True):
    """Convert a string into a (1, N) tensor of vocab indices.

    The string is tokenized with the dataset's text field and only the last
    conf['training_bptt_len'] tokens are kept. When `pad` is true the tensor
    is right-padded with index 0 up to that length.

    Returns (indices, num_tokens), where num_tokens counts the real
    (unpadded) tokens.
    """
    window = conf['training_bptt_len']
    field = train_ds.dataset.fields['text']

    # Keep only the most recent `window` tokens of the prompt
    recent_tokens = field.tokenize(string)[-window:]
    t = field.numericalize([recent_tokens]).T
    num_tokens = t.shape[1]

    if not pad:
        return t, num_tokens

    filler = torch.zeros(1, max(0, window - num_tokens), dtype=torch.long)
    return torch.cat((t, filler), dim=1), num_tokens

def tokenize(indices):
    """Take a tensor of token indices and return them as one string.

    The tensor is flattened first, so a single index, a 1-D sequence or a
    (1, N) row are all accepted; tokens are joined with single spaces.
    """
    # reshape(-1) instead of squeeze(): squeezing a one-element tensor gives
    # a 0-d tensor, which cannot be iterated.
    flat_indices = indices.reshape(-1).tolist()
    return ' '.join(vocab.itos[i] for i in flat_indices)

def get_next_token(decoder_in, pos, model=model, deterministic=False):
    """Runs one step of auto-regression, returning the output token for
    position `pos`.

    decoder_in:    (1, N) tensor of vocab indices.
    pos:           sequence position whose prediction is wanted.
    deterministic: pick the highest-scoring token instead of sampling from
                   the softmax distribution.

    Returns (vocab_index, token_string).
    """
    # Only the logits at `pos` are needed, so pick and sample at that single
    # position instead of drawing a sample for every position in the sequence.
    logits = model(decoder_in=decoder_in)[0, pos, :]

    if deterministic:
        next_index = torch.argmax(logits, dim=-1).item()
    else:
        probs = nn.functional.softmax(logits.float(), dim=-1)
        next_index = torch.multinomial(probs, num_samples=1).item()

    return next_index, vocab.itos[next_index]

def sample(prompt, deterministic=False, vocab=vocab, prnt=True):
    """Auto-regressively continue `prompt` and return the generated text.

    The prompt is numericalized into the decoder input. Tokens are then
    generated one at a time, up to twice conf['training_bptt_len'] of them,
    stopping early on eos_token or pad_token. Once the input window is full
    it is shifted left by one for every new token.

    Returns the generated tokens joined by spaces; if `prnt` is true the
    prompt and the completion are also printed.
    """
    max_new_tokens = 2 * conf['training_bptt_len']
    stop_tokens = (eos_token, pad_token)

    di, num_tokens = numericalize(prompt)
    di = di.to(device)

    generated = []
    with torch.no_grad():
        eval_model = model.eval()
        pos = num_tokens - 1  # position of the last prompt token
        for _ in range(max_new_tokens):
            next_index, next_token = get_next_token(di, pos=pos, model=eval_model,
                                                    deterministic=deterministic)
            if next_token in stop_tokens:
                break
            generated.append(next_token)
            if pos + 1 < di.shape[1]:
                # Still room in the window: write into the next slot
                pos += 1
            else:
                # Window full: shift everything left and reuse the last slot
                di = torch.roll(di, -1, -1)
            di[0, pos] = next_index

    out = ' '.join(generated)
    if prnt:
        print(prompt + '\n --> \n' + out)
    return out

## Sample the model

In [None]:
# Example prompt with punctuation already separated by spaces.
# NOTE: the next cell rebinds `prompt`, so only one of the two is used below.
prompt = """Born in Omaha , Nebraska , Malcolm X spent his teenage years living in a series of foster homes after his father 's death and his mother 's hospitalization . He engaged in several illicit activities there , eventually being sentenced"""

In [None]:
# Example prompt with punctuation already separated by spaces; rebinds the
# `prompt` assigned in the previous cell.
prompt = """Albert Einstein ( 14 March 1879 @–@ 18 April 1955 ) was a German @-@ born theoretical physicist who developed the theory of relativity , one of the two pillars of modern physics ( alongside quantum mechanics ) . His work"""

In [None]:
# Draw five independent sampled completions of the current prompt
for i in range(5):
    print()
    print('Completion #%s:' % i)
    out = sample(prompt, deterministic=False, prnt=False)
    print(prompt, '-->', out)

In [None]:
# One sampled completion for each of a handful of one-word prompts
prompts = [
    'The', 'Of', 'To', 'In',
    'A', 'Was', 'The', 'On',
    'That', 'For', 'As', 'With',
]
for prompt in prompts:
    print()
    out = sample(prompt, deterministic=False, prnt=False)
    print(prompt, '-->', out)