In [2]:
import torch
from tuned_lens.nn.lenses import TunedLens
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.utils.data import DataLoader
from tuned_lens.scripts.train_loop_latent import RelatedCollator
from tuned_lens.scripts.lens_latent import RelatedDataset
from tuned_lens.plotting import plot_lens, get_lens_stream
#import ipywidgets as widgets
from plotly import graph_objects as go
import json
import torch as tf
import pandas as pd
from plotly.subplots import make_subplots


ModuleNotFoundError: No module named 'torch'

In [3]:
import os

import torch
from tuned_lens.nn.lenses import TunedLens, LogitLens
from transformers import AutoModelForCausalLM, AutoTokenizer

# Environment detection.
# `windows` is only set when the Colab import path fails, mirroring the
# original flow: in Colab we enable the custom widget manager instead.
windows = False
try:
  from google.colab import output
  output.enable_custom_widget_manager()
except Exception:
  # Best effort: not running in Colab (or the widget manager is unavailable).
  # `Exception` rather than a bare `except:` so that KeyboardInterrupt and
  # SystemExit are no longer swallowed.
  if os.name == 'nt':
    windows = True

# bfloat16 is only used on GPUs with compute capability major version >= 8.
# Guard with is_available(): get_device_capability() raises when no CUDA device
# is present, which previously made this cell crash on CPU-only machines.
bfloat16_available = (
  torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
)
floatDtype = torch.bfloat16 if bfloat16_available else torch.float16

In [6]:
# Model / lens configuration.
device = torch.device('cuda')
loadLens = True
modelName = 'EleutherAI/pythia-1b'
# Built with os.path.join so the path also resolves on POSIX systems. The
# trailing empty component keeps the trailing separator, so on Windows this is
# exactly the previous literal "..\\outputs\\test-1b\\".
LensName = os.path.join("..", "outputs", "test-1b", "")
# modelName = 'sshleifer/tiny-gpt2'
# To try a different model / lens, check that the lens is available, then modify this code.

# Load the model in inference mode with gradients disabled, then move it to the device.
model = AutoModelForCausalLM.from_pretrained(modelName, low_cpu_mem_usage=True, torch_dtype=floatDtype)
model.eval()
model.requires_grad_(False)
model = model.to(device)

tokenizer = AutoTokenizer.from_pretrained(modelName)
# Either restore a saved lens from LensName or build a fresh one from the model.
if loadLens:
    tuned_lens = TunedLens.load(LensName, map_location=device)
else:
    tuned_lens = TunedLens(model)
tuned_lens = tuned_lens.to(device)
# logit_lens = LogitLens(model)
logit_lens = None


In [17]:
# Dataset directory. The previous literal "..\\datasets/" mixed Windows and
# POSIX separators and only resolved on Windows. The trailing empty component
# keeps a trailing separator so any `dataDir + filename` usage keeps working.
dataDir = os.path.join("..", "datasets", "")
# JSON is UTF-8 by specification; be explicit instead of relying on the
# platform default encoding.
with open(os.path.join(dataDir, "related.json"), "r", encoding="utf-8") as f:
    relatedJson = json.load(f)

# NOTE(review): this rebinds `tokenizer` to the GPT-2 tokenizer, replacing the
# one loaded from `modelName` in the model cell — confirm this is intended for
# the model being inspected.
tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side='left', return_special_tokens_mask=True)
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# One example per batch, shuffled; the collator receives the dataset's
# pad_to_multiple_of setting.
dataset = RelatedDataset(relatedJson, tokenizer)
collator = RelatedCollator(tokenizer, dataset.pad_to_multiple_of)
dl = DataLoader(dataset, batch_size=1, collate_fn=collator, shuffle=True)





In [20]:
# Run the lens stream on the first batch only (note the `break`) and print the
# shapes of what comes back.
for prompt, response, related in dl:
    # Size of the first dimension of related['input_ids'].
    len_related_words = len(related['input_ids'])

    print("prompt", prompt["input_ids"].shape)

    # The keyword arguments are the same for every call, so build them once.
    lens_inputs = dict(
        input_ids=prompt["input_ids"],
        input_att_mask=prompt["attention_mask"],
        response_ids=response["input_ids"],
        response_att_mask=response["attention_mask"],
        related_ids=related['input_ids'],
        related_att_mask=related['attention_mask'],
    )

    # NOTE(review): the loop index is not used, so every iteration makes the
    # same call with the same inputs.
    for word_idx in range(len_related_words):
        hidden_lps, responseOutput = get_lens_stream(
            model, tokenizer, tuned_lens, **lens_inputs
        )

        print(hidden_lps.layers[0].shape)
        print(responseOutput.logits.shape)
    break

    # for each word put it in the modified plot lens to make an appropriate return
    #response_masked_input = response['input_ids'] * response['attention_mask']
   
            

prompt torch.Size([1, 9])
torch.Size([1, 9, 50304])
torch.Size([1, 3, 50304])


In [None]:

# Load the pre-computed lens statistics (`tf` is the `torch` alias imported at
# the top of the notebook).
# NOTE: torch.load unpickles the file — only load stats files from a trusted source.
stats = tf.load("latent_stats.pt")
# Fixed: a stray closing parenthesis here made the whole cell a SyntaxError.
perplexity = stats['response_perplexity']
# Thresholds are read from the layer-6 precision entry and reused as the x-axis
# for both layers — assumes layer 12 uses the same thresholds; TODO confirm.
thresholds = list(stats['precision'][6].keys())

# Per-threshold metric series for layer 6.
id_6_precision = list(stats['precision'][6].values())
id_6_recall = list(stats['recall'][6].values())
id_6_f1 = list(stats['f1'][6].values())

# Per-threshold metric series for layer 12.
id_12_precision = list(stats['precision'][12].values())
id_12_recall = list(stats['recall'][12].values())
id_12_f1 = list(stats['f1'][12].values())


In [None]:
# Precision Per Layer
# Render `df_precision` as a plotly table, one table column per frame column.
# NOTE(review): `df_precision` is not defined in any cell visible here — it
# must be created before this cell can run on a fresh kernel.
precision_columns = list(df_precision.columns)
precision_table = go.Table(
    header=dict(values=precision_columns, align='left'),
    cells=dict(
        values=[df_precision[name] for name in precision_columns],
        align='left',
    ),
)
fig_precision = go.Figure(data=[precision_table])

fig_precision.show()

In [None]:
# Layer 6: precision plotted against recall (one point per threshold).
fig_id_6 = make_subplots(specs=[[dict(secondary_y=True)]])

pr_curve_6 = go.Scatter(x=id_6_recall, y=id_6_precision, name="Lens Prediction")
fig_id_6.add_trace(pr_curve_6, secondary_y=False)

# Axis labels.
fig_id_6.update_xaxes(title_text="Recall")
fig_id_6.update_yaxes(title_text="Precision", secondary_y=False)

# Title and legend in a single layout update.
fig_id_6.update_layout(
    title_text="Layer 6: Precision v Recall",
    showlegend=True,
)

fig_id_6.show()

In [None]:
# Layer 6: precision, recall and F1 as a function of the prediction threshold.
# Precision is drawn on the primary y-axis; recall and F1 share the secondary one.
# Fixed: this figure is titled "Layer 6" but previously plotted the layer-12
# series (id_12_*), a copy-paste slip from the layer-12 cell.
fig_id_6 = make_subplots(specs=[[{"secondary_y": True}]])
fig_id_6.add_trace(
    go.Scatter(x=thresholds, y=id_6_precision, name="Precision"),
    secondary_y=False,
)

fig_id_6.add_trace(
    go.Scatter(x=thresholds, y=id_6_recall, name="Recall"),
    secondary_y=True,
)

fig_id_6.add_trace(
    go.Scatter(x=thresholds, y=id_6_f1, name="f1"),
    secondary_y=True,
)

fig_id_6.update_layout(
    title_text="Layer 6: Prediction Threshold v Precision, Recall and F1"
)

fig_id_6.update_xaxes(title_text="Prediction Threshold")


fig_id_6.show()

In [None]:
# Layer 12: precision plotted against recall (one point per threshold).
fig_id_12 = make_subplots(specs=[[{"secondary_y": True}]])

fig_id_12.add_trace(
    go.Scatter(x=id_12_recall, y=id_12_precision, name="Lens Prediction"),
    secondary_y=False,
)


# Fixed: the title previously claimed "... and F1", but this figure contains
# only the precision-vs-recall trace; it now matches the layer-6 figure's title.
fig_id_12.update_layout(
    title_text="Layer 12: Precision v Recall"
)


fig_id_12.update_xaxes(title_text="Recall")


fig_id_12.update_yaxes(title_text="Precision", secondary_y=False)
fig_id_12.update_layout(showlegend=True)

fig_id_12.show()

In [None]:
# Layer 12: precision, recall and F1 as a function of the prediction threshold.
fig_id_12 = make_subplots(specs=[[dict(secondary_y=True)]])

# (trace name, y-values, draw on secondary y-axis?) — added in this order.
layer_12_series = (
    ("Precision", id_12_precision, False),
    ("Recall", id_12_recall, True),
    ("f1", id_12_f1, True),
)
for trace_name, y_values, on_secondary in layer_12_series:
    fig_id_12.add_trace(
        go.Scatter(x=thresholds, y=y_values, name=trace_name),
        secondary_y=on_secondary,
    )

fig_id_12.update_layout(
    title_text="Layer 12: Prediction Threshold v Precision, Recall and F1"
)
fig_id_12.update_xaxes(title_text="Prediction Threshold")

fig_id_12.show()