In [1]:
import os
import json
from argparse import ArgumentParser
import torch
from transformers import AutoTokenizer

from arch.config import Config
from arch.model import NanoFormerForCausalLM

In [2]:
# Restrict this kernel to the GPU with index 1.
# NOTE(review): presumably this must run before CUDA is first initialised in the
# kernel to have any effect — confirm.
os.environ.update(CUDA_VISIBLE_DEVICES='1')

In [3]:
# Checkpoint directory to inspect. Swap in the commented path to look at another run.
model_path = '/home/datta0/models/nanoformer/ngpt_full_new_1ep/best_model'
# model_path = '/home/datta0/models/nanoformer/gqa_minipile_1ep/checkpoint-0'

# Fail early, with a readable message, when the checkpoint layout is not as expected.
config_path = os.path.join(model_path, 'config.json')
if not os.path.isdir(model_path):
    raise OSError(f"Path {model_path} does not exist")
if not os.path.isfile(config_path):
    raise OSError(f'Config file does not exist in {model_path}')

# Read the saved hyper-parameters and turn them into a Config object.
with open(config_path, 'r') as config_file:
    config_kwargs = json.load(config_file)
config = Config(**config_kwargs)
# Switched off here because the notebook only runs the model for inspection.
config.gradient_checkpointing = False

# Build the architecture from the config; the checkpoint weights are loaded in a later cell.
model = NanoFormerForCausalLM(config)


`RotaryEmbedding` can now be fully parameterized by passing the model config through the `config` argument. All other arguments will be removed in v4.46


In [4]:
# Display the token-embedding matrix and the first layer's query projection weights.
# This cell sits before the checkpoint-loading cell, so it shows the values the
# freshly constructed model starts with.
model.model.embed_tokens.weight, model.model.layers[0].attention.q.weight

(Parameter containing:
 tensor([[-0.7722,  0.7752, -0.5999,  ...,  2.0298, -0.3548,  0.6717],
         [-0.5416,  1.1817,  1.1714,  ...,  0.0441, -0.9582, -0.9147],
         [-0.2746, -0.7430,  1.0229,  ..., -1.8631,  0.0672, -0.7881],
         ...,
         [-0.6127, -0.6503, -0.1991,  ...,  0.2042,  0.6728,  0.8433],
         [-0.6183,  0.3486,  1.2508,  ..., -0.5478,  0.7902,  0.5463],
         [-0.0797, -0.5059,  0.7187,  ...,  0.4446,  0.8132,  0.4625]],
        requires_grad=True),
 Parameter containing:
 tensor([[ 0.0619, -0.0468, -0.0600,  ...,  0.0579, -0.0264,  0.0656],
         [-0.0053, -0.0014, -0.0181,  ..., -0.0016, -0.0706, -0.0135],
         [-0.0212,  0.0609,  0.0306,  ..., -0.0723, -0.0439,  0.0407],
         ...,
         [ 0.0315, -0.0125,  0.0224,  ...,  0.0249,  0.0431,  0.0302],
         [ 0.0103, -0.0107, -0.0215,  ...,  0.0274, -0.0368, -0.0064],
         [-0.0631,  0.0515, -0.0587,  ..., -0.0079, -0.0129,  0.0227]],
        requires_grad=True))

In [5]:
# Deserialize the checkpoint onto the CPU: the model above was built on the CPU,
# and this keeps loading independent of the device the checkpoint was saved from.
state_dict = torch.load(
    os.path.join(model_path, 'pytorch_model.bin'),
    map_location='cpu',
    weights_only=True,
)

# strict=False tolerates key mismatches, so surface them explicitly. Otherwise a
# checkpoint that matches none of the parameters would still "load successfully"
# and leave the model at its random initialisation.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(f'Missing keys (left at their initial values): {load_result.missing_keys}')
if load_result.unexpected_keys:
    print(f'Unexpected keys (ignored): {load_result.unexpected_keys}')
model.eval()

# Tokenizer used for the prompts below; the EOS token doubles as the padding token.
tokenizer = AutoTokenizer.from_pretrained("NeelNanda/gpt-neox-tokenizer-digits")
tokenizer.pad_token = tokenizer.eos_token
model.tokenizer = tokenizer

print('Model loaded successfully')
    

Model loaded successfully


In [6]:
# Show the `use_ngpt` flag of the first layer's attention module.
model.model.layers[0].attention.use_ngpt

True

In [7]:
# model.model.embed_tokens.weight, model.model.layers[0].attention.q.weight
# model.model.layers[0].attention.sqk

In [8]:
# Autoregressive sampling from the loaded model: temperature scaling, top-k
# filtering, then multinomial sampling. Every step re-runs the model on the whole
# sequence. `torch` comes from the notebook's import cell; `torch.softmax` is
# used so this cell needs no imports of its own.
with torch.no_grad():
    text = "Be afraid, England and Wales 2019. The Aussies are coming. Or rather, the Aussies are still coming, after an 86-run defeat of a New Zealand team who seemed consumed by the occasion at Lord’s. At times in the Black Caps’ attempts to chase 243 this felt a bit like a Sunday morning junior age group game. Steve Smith sent down some weird, wonky all-sorts. Wickets were greeted with jokey huddles. It took the return of Mitchell Starc to restore a sense of World Cup order, figures of five for 26 reflecting a spell of brutal, high-grade, white-ball fast-bowling that blew away the tail. Pakistan’s Imad Wasim holds nerve to see off Afghanistan in thriller Read more Victory leaves Australia on their own at the top of the group stage table with seven wins from eight, and with some of their own question marks finding an answer or two. They had some help along the way, not least from Kane Williamson’s diffident captaincy. On a sun-baked north London day New Zealand had first shown how to beat Australia; then almost immediately they showed how to fail to beat Australia. Exposing that thin-looking middle order had always looked a plan. Failing to punch through by taking off your best bowlers was where the game got away, captured by the sight of the skipper wheeling out seven overs of mid-innings part-time leg-spin. Trent Boult even had time at the end of Australia’s innings to conjure a largely pointless World Cup hat-trick. Instead it was a gutsy, occasionally streaky 107-run sixth-wicket partnership between Usman Khawaja and Alex Carey that decided this game. From the start Lord’s was a place of Trans-Tasman good cheer as the grey shroud of the last few weeks lifted. Australia had won the toss and elected to bat. In any list of David Warner’s top five career sledges"
    tokens = tokenizer([text], return_tensors='pt')
    input_ids = tokens['input_ids']
    attention_mask = tokens['attention_mask']

    # An id that is not a valid row of the embedding matrix makes the forward
    # pass die with an opaque "IndexError: index out of range in self". Check up
    # front and say what is wrong instead.
    # NOTE(review): this covers token ids only; if the model also indexes a
    # position table, an over-long prompt could trigger the same error — confirm.
    embedding_rows = model.model.embed_tokens.weight.size(0)
    largest_id = int(input_ids.max())
    if largest_id >= embedding_rows:
        raise ValueError(
            f"Tokenizer produced token id {largest_id}, but the model embedding has only "
            f"{embedding_rows} rows; the tokenizer does not fit this checkpoint's vocabulary."
        )

    new_tokens = 50  # Number of tokens to generate
    temperature = 1e-3  # Logits are divided by this; such a small value makes sampling close to greedy
    top_k = 10  # Keep only the k highest logits before sampling (0 disables the filter)

    # `count` ends at `new_tokens`, as in the original while-loop.
    for count in range(1, new_tokens + 1):
        output = model(input_ids, attention_mask)

        # Logits for the last position, scaled by the temperature.
        next_token_logits = output[0][:, -1, :] / temperature

        # Top-k filtering: everything below the k-th best logit is masked out.
        if top_k > 0:
            k = min(top_k, next_token_logits.size(-1))  # topk raises if k exceeds the vocabulary
            kth_best = torch.topk(next_token_logits, k).values[:, -1:]
            next_token_logits = next_token_logits.masked_fill(
                next_token_logits < kth_best, float('-inf')
            )

        # Convert logits to probabilities and sample one token per sequence.
        next_token_probs = torch.softmax(next_token_logits, dim=-1)
        predicted_token_id = torch.multinomial(next_token_probs, num_samples=1)

        # Extend the sequence and its mask by one position for every row of the batch.
        input_ids = torch.cat((input_ids, predicted_token_id), dim=1)
        new_mask_column = torch.ones(
            attention_mask.size(0), 1, dtype=attention_mask.dtype, device=attention_mask.device
        )
        attention_mask = torch.cat((attention_mask, new_mask_column), dim=1)

        # Stream the newly sampled token of the first sequence.
        print(tokenizer.decode(input_ids[0][-1], skip_special_tokens=True).replace('Ġ',''), end=' ', flush=True)

        # Check for end-of-sequence token (optional, depends on your tokenizer)
        # if predicted_token_id == tokenizer.eos_token_id:
        #     break

    # Decode the full sequence once, after generation, rather than on every step.
    # Both names end up holding the prompt plus the generated continuation.
    out_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    text = out_text

IndexError: index out of range in self

In [11]:
# Vocabulary size reported by the tokenizer.
tokenizer.vocab_size

48262

In [24]:
# Token ids of the first sequence, together with the most recently sampled token id.
input_ids[0], predicted_token_id

(tensor([5023,  220,   75,  337,  292,   17, 1671,  248, 3106, 2373,   19,  317,
          259,  971,  362,  414,  372,  247,   19, 1035, 3442,   17,  223,  259,
          971,  362,  414, 1761,  372,  247,   17,  656,  348,  890,   27,   18,
           87,  314, 3541,  236,  220,  594, 2649,  781,  561,  436,  310,  229,
          780, 4583,  333,  223, 3974,  251,  343, 3853, 1551,   88,   19, 1030,
         2153,  239,  223, 2256, 2554,   88, 1551, 2293, 3321,  254,  422,  642,
         1855,   24,  630,  240, 2717,  220,  244,  235, 1800,  220, 4344,  296,
          255, 4728,  558, 3436, 2012, 1031, 1187,   19, 4957, 2811, 2174, 1493,
          980, 1547, 1125,   17, 1193, 3360,  725,   18,   88, 1352,   19,  316,
          718, 1038,  452,  308, 1522,  229,  331,  558,  506,  522,  263, 2558,
          578,   19,  519, 1543,  223, 1222,  236,  266,  235,  733,  743, 3269,
           72,  254, 1596,  393,  220,  238, 1869,  236,  934, 1480, 1677,   17,
         3248,  958,  236, 1

In [1]:
# Probability distribution the last sampling step drew its token from.
next_token_probs

In [47]:
# Full id tensor and the most recently sampled token id.
input_ids, predicted_token_id

In [23]:
# Decoded text produced by the sampling cell.
out_text

In [34]:
# First element: the highest-scoring token id at every position of the first sequence.
# Second element: the highest-scoring token id at the final position of the last sequence.
# The second arg-max is taken over the vocabulary axis of that one position; taken
# over the whole (positions x vocab) matrix it would return a flat offset, not a token id.
torch.argmax(output[0][0], dim=-1), torch.argmax(output[0][-1][-1])

In [36]:
# Highest-scoring token id at the final position of the last sequence in the batch.
torch.argmax(output[0][-1][-1])

In [2]:
# Point input_ids back at the tokenised prompt; the sampling cell rebinds this
# name to the prompt plus the generated ids.
input_ids = tokens['input_ids']

In [8]:
# Run the prompt ids through the token embedding; `x` is the input to the layer
# stack in the cells below.
x = model.model.embed_tokens(input_ids)
# Display the embeddings.
x

In [9]:
# Display the model's stack of layers.
model.model.layers

In [10]:
# Positions 0..seq_len-1, repeated for every sequence in the batch, on the same device as the ids.
position_ids = torch.arange(input_ids.size(1), device=input_ids.device).unsqueeze(0).expand(input_ids.size(0), -1)
# Pass the float embeddings `x` (from the embedding cell) rather than the integer
# `input_ids` as the reference tensor for the rotary embedding.
# NOTE(review): assumes rotary_emb uses its first argument only for dtype/device,
# as Hugging Face's rotary embedding does (it casts cos/sin to that dtype, which
# would truncate them for an int64 tensor) — confirm against arch.model.
position_embeddings = model.model.rotary_emb(x, position_ids)

In [11]:
# Push the hidden states through the layers one at a time, printing the shape and
# values after each layer, followed by a run of blank lines as a separator.
# NOTE: `x` is rebound on every iteration, so re-running this cell without first
# re-running the embedding cell feeds already-transformed states back in.
with torch.no_grad():
    for block in model.model.layers:
        x = block(x, position_embeddings)
        print(x.shape, x, "\n\n\n", sep="\n")