In [6]:
import torch, re, json
import pandas as pd
from pathlib import Path
from src.main import main
from src.prepare_latents import compute_chunks
from src.utils import device
from transformers import pipeline
from sentence_transformers import SentenceTransformer
from IPython.display import clear_output
import torchmetrics.functional as tmf

In [14]:
from semantic_decoding.decoding.GPT import GPT
from semantic_decoding.decoding.LanguageModel import LanguageModel
from semantic_decoding.decoding.Decoder import Decoder

# Locate the language-model artefacts and load the two vocabularies stored next to them.
data_lm = Path("data/data_lm")
perceived_dir = data_lm / "perceived"
gpt_vocab = json.loads((perceived_dir / "vocab.json").read_text())
decoder_vocab = json.loads((data_lm / "decoder_vocab.json").read_text())

# Build the GPT wrapper from the model stored under `perceived/model`, then the language
# model that pairs it with the decoder vocabulary.
# NOTE(review): nuc_mass / nuc_ratio look like nucleus-sampling settings — confirm in LanguageModel.
gpt = GPT(path=perceived_dir / "model", vocab=gpt_vocab, device=device)
lm = LanguageModel(gpt, decoder_vocab, nuc_mass=0.9, nuc_ratio=0.1)

# NOTE(review): the meaning of the two positional arguments ([1, 2, 3] and 5) is defined by
# Decoder and is not visible here — confirm.
# A later cell rebinds `decoder` to the model returned by `main`, so this instance is only
# reachable until that cell has been executed.
decoder = Decoder([1, 2, 3], 5)

In [17]:
# Query the language model with the decoder's current beam and the value 6.
# The recorded output below is a single (words, scores) pair: six pronouns that all share the
# score -1.79175947, which is numerically ln(1/6).
# NOTE(review): presumably the second argument is the number of proposals and the scores are
# log-probabilities — confirm in LanguageModel.beam_propose.
# `decoder` must still be the `Decoder([1, 2, 3], 5)` instance here; a later cell rebinds the name.
lm.beam_propose(decoder.beam, 6)

[(['i', 'we', 'she', 'he', 'they', 'it'],
  array([-1.79175947, -1.79175947, -1.79175947, -1.79175947, -1.79175947,
         -1.79175947]))]

In [2]:
# Dataset selection, option A. The next cell assigns the same two names again, so when the
# notebook is run top to bottom that later selection is the one that takes effect.
# NOTE(review): `subjects = None` presumably means "no subject filter" — confirm in src.main.main.
datasets, subjects = ["li2022_EN_SS_trimmed_mean"], None

In [2]:
# Dataset selection, option B: a single dataset restricted to one subject.
# This cell reassigns `datasets` and `subjects`, replacing any selection made in an earlier cell.
datasets = ["lebel2023"]
subjects = {dataset_name: ["UTS03"] for dataset_name in datasets}

In [3]:
# Run configuration. The whole mapping is forwarded as keyword arguments to `main`;
# `config["model"]` is additionally used to load the sentence encoder and
# `config["context_length"]` is read by the generation loop at the end of the notebook.
# NOTE(review): the exact meaning of the individual options (lag, smooth, stack, ...) is
# defined inside src.main and is not visible from this notebook.
config = dict(
    # What to load
    datasets=datasets,
    subjects=subjects,
    # Model / objective
    model="bert-base-uncased",
    decoder="brain_decoder",
    loss="mixco",
    # Data split
    valid_ratio=0.1,
    test_ratio=0.1,
    # Input preparation
    context_length=6,
    lag=3,
    smooth=6,
    stack=0,
    # Optimisation
    dropout=0.7,
    patience=20,
    lr=1e-4,
    weight_decay=1e-6,
    batch_size=1,
    temperature=0.05,
)

In [4]:
# GPT-2 text-generation pipeline. Later cells use its tokenizer to count tokens per chunk and
# call the pipeline itself to sample candidate continuations.
gpt2 = pipeline(task="text-generation", model="gpt2", device=device)

# Fetch data and decoder

In [5]:
# First call: only fetch the three data splits (caching disabled).
df_train, df_valid, df_test = main(**config, return_data=True, caching=False)

# Second call: run `main` with the same configuration and keep the second element of its
# result as the decoder. This rebinds `decoder`, replacing the `Decoder(...)` instance
# created in the language-model setup cell.
_, brain_decoder = main(**config)
decoder = brain_decoder.to(device)

# Discard whatever the two calls wrote to the cell output.
clear_output()

In [None]:
# Inspect the `run` column of the training split (the same column is used in the next cell
# to find the TextGrid file of every run).
df_train.run

In [None]:
# Build one row per (dataset, run) holding the list of text chunks of that run and, for
# every chunk, its GPT-2 token count.
run_keys = ["dataset", "run"]
df = pd.concat([split[run_keys] for split in (df_train, df_valid, df_test)]).drop_duplicates()

chunk_rows = []
for _, row in df.iterrows():
    if row.dataset == "lebel2023":
        textgrid_path = f"data/lebel2023/derivative/TextGrids/{row.run}.TextGrid"
    else:
        # Previously `textgrid_path` was left unbound (NameError) or silently reused from the
        # previous row for any other dataset; fail loudly instead of chunking the wrong file.
        raise ValueError(f"No TextGrid location known for dataset {row.dataset!r} (run {row.run!r})")
    # NOTE(review): the meaning of the arguments 2 and 0 is defined by compute_chunks — confirm.
    chunk_rows.append([row.dataset, row.run, compute_chunks(textgrid_path, 2, 0)])
chunks = pd.DataFrame(chunk_rows, columns=["dataset", "run", "text"])

# `text` holds an iterable of chunk strings per run, so count the tokens of each chunk.
# (The original lambda referenced an undefined name `chunk` and lacked a closing parenthesis.)
chunks["num_tokens"] = chunks.text.apply(
    lambda run_chunks: [len(gpt2.tokenizer.encode(chunk)) for chunk in run_chunks]
)

In [None]:
# Replace the `text` column of every split with the chunked text (and per-chunk token counts)
# computed per (dataset, run).
# - The join keys are explicit, so the merge no longer depends on which columns the two frames
#   happen to share.
# - Columns supplied by `chunks` are dropped first, which keeps the cell re-runnable: a second
#   run would otherwise try to join on the list-valued `num_tokens` column.
# - `validate` asserts that `chunks` has at most one row per (dataset, run).
chunk_keys = ["dataset", "run"]
chunk_columns = [column for column in chunks.columns if column not in chunk_keys]
df_train, df_valid, df_test = (
    split.drop(columns=["text", *chunk_columns], errors="ignore").merge(
        chunks, on=chunk_keys, validate="many_to_one"
    )
    for split in (df_train, df_valid, df_test)
)

In [None]:
# Flat list with the GPT-2 token count of every chunk of every run (input of the histogram).
num_tokens = [
    len(gpt2.tokenizer.encode(chunk))
    for text in chunks.text
    for chunk in text
]

In [None]:
import plotly.express as px

# Distribution of the per-chunk token counts, with a box plot in the margin.
# The figure is the last expression of the cell so that it is rendered as the cell output.
token_histogram = px.histogram(
    num_tokens,
    nbins=50,
    marginal="box",
    title="Number of tokens per chunk",
)
token_histogram

# Decode

In [None]:
# Sentence encoder used to embed candidate sentences; it is loaded from the same model name
# that is passed to `main` through `config`.
encoder_name = config["model"]
model = SentenceTransformer(encoder_name, device=device)
# Remove the loading messages from the cell output.
clear_output()

In [None]:
# Take the first row of the training split.
_, row = next(iter(df_train.iterrows()))

# Projectors are looked up by "<dataset>/<subject>".
projector_key = "/".join((row.dataset, row.subject))

# Inference only: project the row's input with the matching projector, then run the decoder.
with torch.no_grad():
    projected_input = decoder.projector[projector_key](row.X.to(device))
    predicted_latents = decoder(projected_input)

## GPT-2 generation guided by BERT embeddings

In [None]:
# Chunk-by-chunk decoding. For every chunk, GPT-2 samples many continuations of the text
# decoded so far; each candidate is embedded with the sentence encoder and the one whose
# embedding has the highest cosine similarity with `predicted_latents[i]` is kept.
# NOTE(review): this assumes row i of `predicted_latents` corresponds to chunk i of
# `row.text` — confirm against the data preparation.
MAX_CHUNKS = 10        # only the first chunks of the run are decoded
NUM_CANDIDATES = 1000  # continuations sampled per chunk

prompt = ""                    # all text accepted so far
generated_chunks_lengths = []  # character length of every accepted chunk
current_crop_length = 0        # characters removed from the front before embedding
torch.cuda.empty_cache()
for i in range(min(len(row.text), MAX_CHUNKS)):
    candidates = gpt2(
        prompt,
        max_new_tokens=8,
        num_return_sequences=NUM_CANDIDATES,
        # Several return sequences require sampling; state it instead of relying on defaults.
        do_sample=True,
        # GPT-2 has no dedicated padding token, so the end-of-text token is used (id 50256).
        pad_token_id=gpt2.tokenizer.eos_token_id,
        top_k=0,
        top_p=0.6,
        temperature=1.8,
        repetition_penalty=1.8,
    )
    # The prompt itself never contains newlines (it is a previously cleaned sentence), so this
    # only changes the newly generated part and prompt-length arithmetic below stays valid.
    generated_sentences = [re.sub(r'\n+', ' ', s["generated_text"]) for s in candidates]

    # Keep only the last `context_length + 1` chunks (including the new one) for the embedding.
    # The crop must accumulate: every step past the window drops one more old chunk from the
    # front. (Previously the crop was overwritten with the length of a single chunk, so the
    # embedded window kept growing after the first crop.)
    if i > config["context_length"]:
        current_crop_length += generated_chunks_lengths[i - config["context_length"] - 1]
    generated_sentences_cropped = [s[current_crop_length:] for s in generated_sentences]

    embeddings = model.encode(generated_sentences_cropped, convert_to_numpy=False, convert_to_tensor=True)
    similarities = tmf.pairwise_cosine_similarity(predicted_latents[[i]], embeddings)
    best_sentence_index = int(similarities.argmax())
    best_sentence = generated_sentences[best_sentence_index]

    generated_chunks_lengths.append(len(best_sentence) - len(prompt))
    print("Generated: ", best_sentence[len(prompt):])
    print("Correct: ", row.text[i])
    print(i, current_crop_length)
    prompt = best_sentence