In [None]:
from engine import Model
import torch
# Checkpoint traced throughout this notebook (alternative used during development: 'gpt2').
MODEL_NAME = 'EleutherAI/gpt-j-6b'
#MODEL_NAME = 'gpt2'
model = Model(MODEL_NAME)

In [None]:
import logging
# The root logger (identical to logging.getLogger()); drop every record below CRITICAL.
logger = logging.root
logger.setLevel(logging.CRITICAL)

In [None]:
def get_scores(hs):
    """Decode hidden states into vocabulary logits.

    Applies the global model's final LayerNorm (`transformer.ln_f`) and then its
    `lm_head` to `hs`, which may be a tensor or a tracing proxy of a block output.
    """
    normalized_hs = model.transformer.ln_f(hs)
    return model.lm_head(normalized_hs)

def get_prob_tokens(scores, topk=1):
    """Softmax `scores` over the last axis and keep the `topk` most likely entries.

    Returns a `(probs, token_ids)` pair. Both share the leading dims of `scores`
    and have a trailing dimension of size `topk`, sorted from most to least probable.
    """
    best = torch.softmax(scores, dim=-1).topk(topk, dim=-1)
    return best.values, best.indices

In [None]:
def show_logit_lens(model, tok, prefix, topk=5, color=None):
    """Render a logit-lens table for `prefix` using baukit's `show`.

    Every transformer block's output is decoded with `get_scores` (final LayerNorm
    followed by the LM head). Rows are layers and columns are prompt tokens. Each
    cell shows the top-1 decoded token:
      * foreground darkens with that token's probability;
      * background saturates with the norm of the block's output hidden state;
      * the hover popup lists the `topk` candidates and the relative magnitude.

    Args:
        model: traced model wrapper exposing `.transformer.h` and `.generate`.
        tok: tokenizer used to split `prefix` and decode predictions.
        prefix: prompt text.
        topk: number of candidates listed in the hover popup.
        color: RGB triple used for a maximal-magnitude background (default blue).
    """
    from baukit import show
    num_layers = len(model.transformer.h)

    # Values are read without `invoker.next()`, i.e. from the first generation step.
    with model.generate(device_map='cuda:0', max_new_tokens=3) as generator:
        with generator.invoke(prefix) as invoker:
            saved_logits = []
            saved_hidden = []
            for i in range(num_layers):
                saved_logits.append(get_scores(model.transformer.h[i].output[0]).save())
                saved_hidden.append(model.transformer.h[i].output[0].save())

    # Stacked with the layer index as the leading axis.
    hs = torch.stack([saved.value for saved in saved_logits])
    hidden = torch.stack([saved.value for saved in saved_hidden])

    # The full decoder head normalizes hidden state and applies softmax at the end.
    favorite_probs, favorite_tokens = get_prob_tokens(hs, topk=topk)

    # Background encodes hidden-state magnitude: the norm of each block's output,
    # not of the logits. Upcast so low-precision norms cannot overflow.
    magnitudes = hidden.float().norm(dim=-1)

    # The 0th token tends to have a much larger norm, so normalize by the max over the
    # subsequent tokens when there are any; otherwise fall back to the overall max.
    if magnitudes.shape[-1] > 1:
        magnitudes = magnitudes / magnitudes[:, :, 1:].max()
    else:
        magnitudes = magnitudes / magnitudes.max()

    # All the input tokens.
    prompt_tokens = [tok.decode(t) for t in tok.encode(prefix)]

    # Foreground color shows token probability, and background color shows hs magnitude.
    if color is None:
        color = [66, 135, 245]

    def color_fn(m, p):
        # Clamp into [0, 1]: token 0 is excluded from the normalizer and may exceed 1,
        # which would otherwise produce out-of-range RGB components.
        m = min(max(float(m), 0.0), 1.0)
        p = min(max(float(p), 0.0), 1.0)
        a = [int(255 * (1 - m) + c * m) for c in color]
        b = [int(255 * (1 - p))] * 3
        return show.style(background=f'rgb({a[0]}, {a[1]}, {a[2]})',
                          color=f'rgb({b[0]}, {b[1]}, {b[2]})')

    # In the hover popup, show the unclamped magnitude and all topk candidates.
    def hover(tokenizer, prob, toks, m):
        lines = [f'mag: {m:.2f}']
        for p, t in zip(prob, toks):
            lines.append(f'{tokenizer.decode(t)}: prob {p:.2f}')
        return show.attr(title='\n'.join(lines))

    # Construct the HTML output using show.
    header_line = [
        [[show.style(fontWeight='bold'), 'Layer']] +
        [
            [show.style(background='orange'), show.attr(title=f'Token {i}'), t]
            for i, t in enumerate(prompt_tokens)
        ]
    ]
    layer_rows = [
        # first column: layer index; subsequent columns: one cell per token
        [[show.style(fontWeight='bold'), layer]] +
        [
            [color_fn(m, p[0]), hover(tok, p, t, m), show.style(overflowX='hidden'), tok.decode(t[0])]
            for m, p, t in zip(wordmags, wordprobs, words)
        ]
        for layer, (wordmags, wordprobs, words) in enumerate(
            zip(magnitudes[:, 0], favorite_probs[:, 0], favorite_tokens[:, 0]))
    ]

    # If you want to get the html without showing it, use show.html(...)
    show(header_line + layer_rows + header_line)

In [None]:
# Logit lens over the prompt '/* Copyright (C)', with 20 hover candidates per cell.
show_logit_lens(model, tok=model.tokenizer, prefix='/* Copyright (C)', topk=20)

In [None]:
def get_last_layer_scores():
    """Decode the last transformer block's output into vocabulary logits.

    Must be called inside a tracing context, because it reads `.output` of the
    global model's final block.
    """
    final_block_output = model.transformer.h[-1].output[0]
    normalized = model.transformer.ln_f(final_block_output)
    return model.lm_head(normalized)


def decode(scores, tokenizer=None):
    """Decode the argmax token at the last position of the first sequence.

    Args:
        scores: logits indexed as (batch, seq, vocab). Argmax is taken over the
            last axis, then `[0, -1]` selects the first sequence's last position.
        tokenizer: object with a `.decode` method. Defaults to the global
            `model.tokenizer`, which is the original behavior.

    Returns:
        The decoded string for that single token.
    """
    if tokenizer is None:
        tokenizer = model.tokenizer
    token_id = scores.argmax(dim=-1)[0, -1]
    return tokenizer.decode(token_id)

In [None]:
num_layers = len(model.transformer.h)
num_toks_gen = 3

# Generate `num_toks_gen` tokens and save the final-layer logits at every step.
# max_new_tokens is tied to num_toks_gen so the number of generation steps always
# matches the number of `invoker.next()` reads below.
with model.generate(device_map='cuda:0', max_new_tokens=num_toks_gen) as generator:
    with generator.invoke('Madison square garden is located in the city of New') as invoker:
        tokenized = invoker.tokens
        init_logits = get_last_layer_scores().save()
        next_tok_logits = []
        for i in range(1, num_toks_gen):
            invoker.next()  # advance to the next generation step
            next_tok_logits.append(get_last_layer_scores().save())

output = generator.output
# Join the prompt-step logits with every later step's logits along dim 1.
init_logits = torch.cat([init_logits.value] + [step.value for step in next_tok_logits], dim=1)

pred, fav_toks = get_prob_tokens(init_logits)
for ft in fav_toks[0]:
    print(model.tokenizer.decode(ft))
print(output.shape)
print(init_logits.shape)

In [None]:
def show_logit_lens_extended(model, tok, prefix, topk=5, num_toks_gen=3, color=None):
    """Logit-lens table covering the prompt and the tokens generated after it.

    The model generates `num_toks_gen` tokens from `prefix`. At every generation
    step, every block's output is decoded with `get_scores`, and the steps are
    concatenated along the token axis. The resulting columns are the generated
    sequence without its final token, and rows are layers. Cell coloring and the
    hover popup follow `show_logit_lens`:
      * foreground shows the top-1 probability;
      * background shows the hidden-state norm.

    Args:
        model: traced model wrapper exposing `.transformer.h` and `.generate`.
        tok: tokenizer used to decode tokens.
        prefix: prompt text.
        topk: number of candidates listed in the hover popup.
        num_toks_gen: number of tokens to generate (at least one step is always read).
        color: RGB triple used for a maximal-magnitude background (default blue).
    """
    from baukit import show
    num_layers = len(model.transformer.h)

    with model.generate(device_map='cuda:0', max_new_tokens=num_toks_gen) as generator:
        with generator.invoke(prefix) as invoker:
            # One list per generation step, each with one saved value per layer.
            step_logits = []
            step_hidden = []
            for step in range(max(num_toks_gen, 1)):
                if step > 0:
                    invoker.next()  # advance to the next generation step
                logits_here = []
                hidden_here = []
                for i in range(num_layers):
                    logits_here.append(get_scores(model.transformer.h[i].output[0]).save())
                    hidden_here.append(model.transformer.h[i].output[0].save())
                step_logits.append(logits_here)
                step_hidden.append(hidden_here)

    output = generator.output

    # Stack layers (leading axis) per step, then join steps along the token axis (dim 2).
    hs = torch.cat([torch.stack([s.value for s in saved]) for saved in step_logits], dim=2)
    hidden = torch.cat([torch.stack([s.value for s in saved]) for saved in step_hidden], dim=2)

    # The full decoder head normalizes hidden state and applies softmax at the end.
    favorite_probs, favorite_tokens = get_prob_tokens(hs, topk=topk)

    # Background encodes hidden-state magnitude: the norm of each block's output,
    # not of the logits. Upcast so low-precision norms cannot overflow.
    magnitudes = hidden.float().norm(dim=-1)

    # The 0th token tends to have a much larger norm, so normalize by the max over the
    # subsequent tokens when there are any; otherwise fall back to the overall max.
    if magnitudes.shape[-1] > 1:
        magnitudes = magnitudes / magnitudes[:, :, 1:].max()
    else:
        magnitudes = magnitudes / magnitudes.max()

    # Column labels: the generated sequence without its last token, whose
    # processing step is never run.
    prompt_tokens = [tok.decode(t) for t in output[0, :-1]]

    # Foreground color shows token probability, and background color shows hs magnitude.
    if color is None:
        color = [66, 135, 245]

    def color_fn(m, p):
        # Clamp into [0, 1]: token 0 is excluded from the normalizer and may exceed 1,
        # which would otherwise produce out-of-range RGB components.
        m = min(max(float(m), 0.0), 1.0)
        p = min(max(float(p), 0.0), 1.0)
        a = [int(255 * (1 - m) + c * m) for c in color]
        b = [int(255 * (1 - p))] * 3
        return show.style(background=f'rgb({a[0]}, {a[1]}, {a[2]})',
                          color=f'rgb({b[0]}, {b[1]}, {b[2]})')

    # In the hover popup, show the unclamped magnitude and all topk candidates.
    def hover(tokenizer, prob, toks, m):
        lines = [f'mag: {m:.2f}']
        for p, t in zip(prob, toks):
            lines.append(f'{tokenizer.decode(t)}: prob {p:.2f}')
        return show.attr(title='\n'.join(lines))

    # Construct the HTML output using show.
    header_line = [
        [[show.style(fontWeight='bold'), 'Layer']] +
        [
            [show.style(background='orange'), show.attr(title=f'Token {i}'), t]
            for i, t in enumerate(prompt_tokens)
        ]
    ]
    layer_rows = [
        # first column: layer index; subsequent columns: one cell per token
        [[show.style(fontWeight='bold'), layer]] +
        [
            [color_fn(m, p[0]), hover(tok, p, t, m), show.style(overflowX='hidden'), tok.decode(t[0])]
            for m, p, t in zip(wordmags, wordprobs, words)
        ]
        for layer, (wordmags, wordprobs, words) in enumerate(
            zip(magnitudes[:, 0], favorite_probs[:, 0], favorite_tokens[:, 0]))
    ]

    # If you want to get the html without showing it, use show.html(...)
    show(header_line + layer_rows + header_line)

In [None]:
# Extended logit lens: prompt plus 4 generated tokens, 5 hover candidates per cell.
show_logit_lens_extended(
    model,
    tok=model.tokenizer,
    prefix='Madison square garden is located in the city of New',
    topk=5,
    num_toks_gen=4,
)

In [None]:
def load_prefix(path):
    """Load a saved soft-prefix tensor from `path` and return it as float32.

    NOTE: `torch.load` unpickles its input; only load trusted checkpoint files.
    """
    loaded = torch.load(path)
    return loaded.float()
    
# Trained soft-prefix checkpoint (absolute, machine-specific path — adjust locally).
# Alternative checkpoint: "/disk/u/koyena/PrefixLens/results/training/layer13to4_tk1/soft_prefix.pt"
PREFIX_PATH = "/disk/u/koyena/PrefixLens/results/conll_training/layer13to27_tk1/soft_prefix.pt"
context_vector = load_prefix(PREFIX_PATH)
print(context_vector.shape)
# Give the prefix a leading batch dimension of size 1 (no-op if it already has one).
context_vector = context_vector.expand(1, -1, -1)
print(context_vector.shape)
# Tensor.to is not in-place: keep the returned tensor so the prefix really is on the GPU.
context_vector = context_vector.to("cuda:0")

In [None]:
def _future_lens_predictions(model, tok, context, sub_hs, tgt_in_layer, context_pos, num_toks_gen):
    """Run one soft-prefix generation with `sub_hs` patched in; return its future top-1 tokens and probs."""
    with model.generate(device_map='cuda:0', max_new_tokens=num_toks_gen) as generator2:
        with generator2.invoke("_ _ _ _ _ _ _ _ _ _") as invoker:
            # Replace the placeholder's token embeddings by the soft prefix, then patch
            # the transplanted hidden state into `tgt_in_layer` at `context_pos`.
            model.transformer.wte.output = context
            model.transformer.h[tgt_in_layer].output[0].t[context_pos] = sub_hs
            # Step 1 is the patched pass itself; read steps 2..num_toks_gen, so the
            # number of reads always matches the number of generation steps.
            saved_logits = []
            for _ in range(num_toks_gen - 1):
                invoker.next()
                saved_logits.append(model.lm_head.output.save())

    words = []
    probs = []
    for saved in saved_logits:
        step_scores = torch.squeeze(saved.value, 0)
        step_prob, step_tok = get_prob_tokens(step_scores[0], topk=1)
        words.append(tok.decode(step_tok))
        probs.append(step_prob[0].item())
    return words, probs


def show_future_lens(model, tok, prefix, context, in_layer = 13, tgt_in_layer = 4, topk=5, num_toks_gen=3, color=None):
    """Render a "future lens" table for `prefix` using baukit's `show`.

    First, `prefix` is run once and every block's output (and its logit-lens
    decoding via `get_scores`) is saved. Then, for every (layer, position), that
    hidden state is patched into layer `tgt_in_layer` of a second run. In that run
    the placeholder prompt's embeddings are replaced by `context`, and it
    generates `num_toks_gen` tokens. Each cell shows the logit-lens top-1 token
    followed by the top-1 tokens of generation steps 2..num_toks_gen. The
    background saturates with the mean of the corresponding probabilities, and
    the hover popup lists the `topk` logit-lens candidates. The table is displayed
    and its HTML is printed.

    Args:
        model: traced model wrapper exposing `.transformer`, `.lm_head`, `.generate`.
        tok: tokenizer for `prefix`.
        prefix: prompt to analyze.
        context: soft-prefix embeddings substituted for the placeholder prompt's
            embeddings (presumably (1, 10, d_model) — TODO confirm).
        in_layer: not used by this implementation; every layer is used as a source.
            Kept for backward compatibility with existing callers.
        tgt_in_layer: layer whose output is patched in the second run.
        topk: number of candidates listed in the hover popup.
        num_toks_gen: tokens generated per run; num_toks_gen - 1 future tokens are shown.
        color: RGB triple for a cell with probability 1 (default green).
    """
    from baukit import show

    # Last prompt position; used to print only the generated continuation.
    prefix_pos = len(tok(prefix)['input_ids']) - 1
    context = context.detach()
    # NOTE(review): 9 is assumed to be the last position of the 10-token placeholder
    # prompt (and of `context`) — TODO confirm if either changes.
    context_pos = 9
    num_layers = len(model.transformer.h)

    with model.generate(device_map='cuda:0', max_new_tokens=num_toks_gen) as generator:
        with generator.invoke(prefix) as invoker:
            overall_hs = []
            init_logits = []
            for i in range(num_layers):
                init_logits.append(get_scores(model.transformer.h[i].output[0]).save())
                overall_hs.append(model.transformer.h[i].output[0].save())

    output = generator.output
    print(tok.decode(output[0][prefix_pos+1:]))
    first_set_logits = torch.stack([saved.value for saved in init_logits])
    hs = [saved.value for saved in overall_hs]

    # future_outputs[layer][position] -> list of decoded future tokens;
    # future_preds[layer][position]   -> matching list of probabilities.
    future_outputs = []
    future_preds = []
    for curr_hs in hs:
        curr_future_outputs = []
        curr_future_preds = []
        for x in range(curr_hs.shape[1]):
            sub_hs = curr_hs[:, x, :][None, :]
            words, probs = _future_lens_predictions(
                model, tok, context, sub_hs, tgt_in_layer, context_pos, num_toks_gen)
            curr_future_outputs.append(words)
            curr_future_preds.append(probs)
        future_outputs.append(curr_future_outputs)
        future_preds.append(curr_future_preds)

    # The full decoder head normalizes hidden state and applies softmax at the end.
    favorite_probs, favorite_tokens = get_prob_tokens(first_set_logits, topk=topk)

    # Background color shows the mean of the current and future top-1 probabilities.
    if color is None:
        color = [50, 168, 123]

    def color_fn(p, future_probs = None):
        a = [int(255 * (1-p) + c * p) for c in color]
        if future_probs is not None:
            total_probs = p + sum(future_probs)
            new_p = total_probs / (len(future_probs) + 1)
            a = [int(255 * (1-new_p) + c * new_p) for c in color]
        return show.style(background=f'rgb({a[0]}, {a[1]}, {a[2]})')

    # In the hover popup, show the topk logit-lens candidates.
    def hover(tokenizer, prob, toks):
        lines = []
        for p, t in zip(prob, toks):
            lines.append(f'{tokenizer.decode(t)}: prob {p:.2f}')
        return show.attr(title='\n'.join(lines))

    # Escape control characters (e.g. newlines) so they stay visible in a cell.
    # With actual_decode=False the input is already text and is only escaped.
    def decode_escape(tokenizer, token, actual_decode=True):
        if not actual_decode:
            if type(token) == list:
                return [t.encode("unicode_escape").decode() for t in token]
            return token.encode("unicode_escape").decode()
        if type(token) == list:
            return [tokenizer.decode(t).encode("unicode_escape").decode() for t in token]
        return tokenizer.decode(token).encode("unicode_escape").decode()

    # One row per layer, one cell per prompt position.
    layer_rows = [
        [
            [color_fn(p[0], fprobs), hover(tok, p, t), show.style(overflowX='hidden'),
             f"{decode_escape(tok, t[0])}{''.join(decode_escape(tok, ft, False))}"]
            for p, t, ft, fprobs in zip(wordprobs, words, future_words, future_probs)
        ]
        for wordprobs, words, future_words, future_probs in
            zip(favorite_probs[:, 0], favorite_tokens[:, 0], future_outputs, future_preds)
    ]

    show(layer_rows)
    print(show.html(layer_rows))

In [None]:
# Future lens on the Madison Square Garden prompt: patch target = source = layer 13.
prefix = 'Madison square garden is located in the city of New'
#prefix = """As is known in the art, it is frequently desirable to detect and segment an object from a background of other objects and/or from a background of noise. One application, for example, is in MRI where it is desired to segment an anatomical feature of a human patient, such as, for example, a vertebra of the patent. In other cases it would be desirable to segment a moving, deformable anatomical feature such as the heart. In 1988, Osher and Sethian, in a paper entitled “Fronts propagation with curvature dependent speed: Algorithms based on Hamilton-Jacobi formulations” J. of Comp. Phys., 79:12-49, 1988, introduced the level set method, it being noted that a precursor of the level set method was proposed by Dervieux and Thomasset in a paper entitled “A finite element method for the simulation of Raleigh-Taylor instability”. Springer Lect. Notes in Math., 771:145-158, 1979, as a means to implicitly propagate hypersurfaces C(t) in a domain Ω⊂Rn by evolving an appropriate"""
#prefix = "Bill Nelson studied at University"
#context = '���chargedansionFollowing SlipUFC gallery RugOURsn'
#context = "Hello! Could you please tell me more about"
show_future_lens(
    model,
    tok=model.tokenizer,
    prefix=prefix,
    context=context_vector,
    topk=5,
    in_layer=13,
    tgt_in_layer=13,
    num_toks_gen=4,
)

In [None]:
# Future lens on a character-origin prompt, patching layer 13.
prefix = 'Marty McFly from'
show_future_lens(
    model,
    tok=model.tokenizer,
    prefix=prefix,
    context=context_vector,
    topk=5,
    in_layer=13,
    tgt_in_layer=13,
    num_toks_gen=4,
)

In [None]:
# NOTE(review): this cell repeats the previous cell verbatim (same prefix and
# arguments); consider deleting it.
prefix = 'Marty McFly from'
show_future_lens(model, 
                 model.tokenizer, prefix, context_vector, topk=5, in_layer = 13, tgt_in_layer = 13, num_toks_gen=4)

In [None]:
# Inspect the soft prefix's dimensions.
print(context_vector.size())

In [None]:
# Future lens with the patch target moved down to layer 4.
# NOTE(review): the other show_future_lens calls in this notebook pass num_toks_gen=4;
# confirm that show_future_lens supports num_toks_gen=3 before relying on this cell.
prefix = "Bill Nelson studied at University"
show_future_lens(model, 
                 model.tokenizer, prefix, context_vector, topk=5, in_layer = 13, tgt_in_layer = 4, num_toks_gen=3)

In [None]:
# Logit lens over a ten-'+' placeholder prompt (default color scheme).
show_logit_lens(model, tok=model.tokenizer, prefix="+ + + + + + + + + +", topk=5)

In [None]:
# Run a prompt whose token embeddings are replaced by the soft prefix and capture
# those embeddings. Then feed the captured embeddings into a placeholder prompt
# and decode the argmax prediction at its last position.
print("initial_context_vector", context_vector.detach())

with model.generate(device_map='cuda:0', max_new_tokens=3) as generator:
    with generator.invoke("Madison square garden is located in the city of New") as invoker:
        model.transformer.wte.output = context_vector
        embeddings = model.transformer.wte.output.save()

print("context vector", context_vector.detach().to('cuda'))
print("embeddings", embeddings.value)

with model.generate(device_map='cuda:0', max_new_tokens=3) as generator:
    with generator.invoke("_ _ _ _ _ _ _ _ _ _") as invoker:
        model.transformer.wte.output = embeddings.value
        logits = model.lm_head.output.save()

logits = logits.value.argmax(dim=-1)[0]

print(model.tokenizer.decode(logits[-1:]))

In [None]:
# Plain logit lens on the Marty McFly prompt, top-1 candidate only.
show_logit_lens(model, tok=model.tokenizer, prefix='Marty McFly from', topk=1)