In [18]:
import torch, re, json
import pandas as pd
import numpy as np
from pathlib import Path
from tqdm.auto import tqdm
from src.main import main
from src.prepare_latents import compute_chunks
from src.utils import device, progress
from transformers import pipeline
from sentence_transformers import SentenceTransformer
from IPython.display import clear_output
import torchmetrics.functional as tmf

In [2]:
# Default selection: a single dataset, with no per-dataset subject filter.
datasets, subjects = ["li2022_EN_SS_trimmed_mean"], None

In [None]:
# Alternative selection: running this cell rebinds `datasets` and `subjects`,
# replacing whatever an earlier cell assigned to those names.
datasets = ["lebel2023"]
subjects = dict(lebel2023=["UTS03"])

In [4]:
# Keyword arguments forwarded to `main(**config)`. Two entries are also read
# directly by later cells: "model" (name of the sentence encoder) and
# "context_length" (number of previous chunks used as context when decoding).
config = dict(
    datasets=datasets,
    subjects=subjects,
    model="bert-base-uncased",
    decoder="brain_decoder",
    loss="mixco",
    valid_ratio=0.1,
    test_ratio=0.1,
    context_length=6,
    lag=3,
    smooth=6,
    stack=0,
    dropout=0.7,
    patience=20,
    lr=1e-4,
    weight_decay=1e-6,
    batch_size=1,
    temperature=0.05,
    # top_encoding_voxels=5000,
)

In [None]:
# GPT-2 text-generation pipeline; used further down to sample candidate continuations.
gpt2 = pipeline(task="text-generation", model="gpt2", device=device)

# Fetch data and decoder

In [5]:
# First call: fetch the train/valid/test splits only (no cache, no W&B logging).
df_train, df_valid, df_test = main(
    return_data=True, cache=False, wandb_mode="disabled", **config
)
# Second call: obtain the decoder (second element of the returned pair).
# The decoding cells below call `decoder(...)` and `decoder.projector[...]`, so
# without this step a fresh "Restart & Run All" fails with a NameError.
# NOTE(review): this presumably runs a full training pass — consider caching it.
# NOTE(review): the inference cells only use `torch.no_grad()`; confirm that the
# returned decoder is already in eval mode (config sets dropout=0.7).
_ignored, decoder = main(wandb_mode="disabled", **config)
decoder = decoder.to(device)
clear_output()  # hide the logs printed by `main`

In [None]:
# Pick the first training row of the "wheretheressmoke" run and decode its latents.
# Requires `decoder` to be defined already in the kernel.
smoke_rows = df_train.loc[df_train["run"] == "wheretheressmoke"]
row = smoke_rows.iloc[0]
projector_key = "/".join((row.dataset, row.subject))
with torch.no_grad():
    # Subject-specific projection first, then the shared decoder.
    projected_input = decoder.projector[projector_key](row.X.to(device))
    predicted_latents = decoder(projected_input)

# Encode

In [None]:
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score

In [None]:
# Encoding direction: fit a ridge regression that predicts X from Y on the
# concatenated training rows.
# NOTE(review): the name `model` is rebound to other objects in later cells.
X, Y = (np.concatenate(tuple(df_train[column])) for column in ("X", "Y"))
model = Ridge()
model.fit(Y, X)

In [None]:
# Score the ridge encoder on the validation rows: one R² value per output
# column of X (multioutput="raw_values").
X, Y = (np.concatenate(tuple(df_valid[column])) for column in ("X", "Y"))
predicted_X = model.predict(Y)
r2 = r2_score(X, predicted_X, multioutput="raw_values")

In [None]:
# Scratch check: shape of the index array of the (up to) 1,000,000 highest-R² outputs.
np.argsort(r2)[-1_000_000:].shape

In [None]:
import plotly.express as px
# Distribution of the strictly positive validation R² scores.
positive_r2 = r2[r2 > 0]
px.histogram(positive_r2, nbins=100)

# Decode simple

In [None]:
# Load the sentence encoder named in the config (rebinds `model`).
encoder_name = config["model"]
model = SentenceTransformer(encoder_name, device=device)
clear_output()  # discard the loading output

In [None]:
# Candidate pool for retrieval decoding: every distinct text chunk found in the
# train and valid splits, taking each (dataset, run) pair once.
# A set comprehension avoids `Series.sum()`, which concatenates lists
# quadratically and returns 0 (not an empty list) when there are no rows.
unique_chunks = {
    chunk
    for split in (df_train, df_valid)
    for run_text in split.drop_duplicates(["dataset", "run"]).text
    for chunk in run_text
}
# Sorted so that candidate indices (and argmax tie-breaking downstream) are
# reproducible across kernels; plain set iteration order is not.
chunks = pd.Series(sorted(unique_chunks))

In [None]:
# Decode the latents of the second training row (positional index 1).
# Requires `decoder` to be defined already in the kernel.
row = df_train.iloc[1]
projector_key = "/".join((row.dataset, row.subject))
with torch.no_grad():
    # Subject-specific projection first, then the shared decoder.
    projected_input = decoder.projector[projector_key](row.X.to(device))
    predicted_latents = decoder(projected_input)

In [None]:
# Greedy retrieval decoding: for each step i, append every candidate chunk to the
# previously decoded context, embed the result, and keep the candidate whose
# embedding has the highest cosine similarity with predicted_latents[i].
context_length = config["context_length"]
num_steps = len(row.X)
candidate_texts = chunks.tolist()  # plain list: positional indexing, safe for encode()
decoded_chunks = []
with progress:
    # The total matches the number of loop iterations so the bar ends at 100%.
    task = progress.add_task(f"Decoding {row.run}", total=num_steps)
    for i in range(num_steps):
        # Explicit start index: a negative slice would return the whole history
        # when context_length is 0.
        context_start = max(0, i - context_length)
        context_sentence = " ".join(decoded_chunks[context_start:])
        continuations = [context_sentence + " " + candidate for candidate in candidate_texts]
        continuations_latents = model.encode(continuations, convert_to_numpy=False, convert_to_tensor=True)
        scores = tmf.pairwise_cosine_similarity(predicted_latents[[i]], continuations_latents)[0].cpu()
        best_continuation = candidate_texts[scores.argmax().item()]
        decoded_chunks.append(best_continuation)

        # Both display windows cover the same positions [context_start, i], so
        # zip() compares the chunk at position k with the prediction at position k.
        correct_chunks = list(row.text[context_start:i + 1])
        predicted_chunks = decoded_chunks[context_start:i + 1]
        for j, (correct, predicted) in enumerate(zip(correct_chunks, predicted_chunks)):
            if correct == predicted:
                # Exact matches are printed in green on both lines.
                correct_chunks[j] = f"\033[92m{correct}\033[0m"
                predicted_chunks[j] = f"\033[92m{correct}\033[0m"

        print(f"Chunk {i+1}/{num_steps}")
        print("Correct  :", " \033[91m|\033[0m ".join(correct_chunks))
        print("Predicted:", " \033[91m|\033[0m ".join(predicted_chunks))
        progress.update(task, advance=1)

# Decode Tang

In [None]:
from semantic_decoding.decoding.GPT import GPT
from semantic_decoding.decoding.LanguageModel import LanguageModel
from semantic_decoding.decoding.Decoder import Decoder, Hypothesis

# Load the two vocabularies from disk and build the GPT-backed language model
# that proposes next-word candidates for the beam search.
data_lm = Path("data/data_lm")
perceived_dir = data_lm / "perceived"
gpt_vocab = json.loads((perceived_dir / "vocab.json").read_text())
decoder_vocab = json.loads((data_lm / "decoder_vocab.json").read_text())
gpt = GPT(path=perceived_dir / "model", vocab=gpt_vocab, device=device)
lm = LanguageModel(gpt, decoder_vocab, nuc_mass=0.9, nuc_ratio=0.1)

In [None]:
# Beam-search decoder with one "word time" per word of the selected row.
total_words = sum(row.num_words)
gpt_decoder = Decoder(word_times=range(total_words), beam_width=200)

In [None]:
# (Re)load the sentence encoder named in the config; `model` may have been
# rebound by an earlier cell.
sentence_encoder_name = config["model"]
model = SentenceTransformer(sentence_encoder_name, device=device)
clear_output()  # discard the loading output

In [None]:
# Beam-search decoding: for every chunk, extend each beam hypothesis one word at
# a time with words proposed by `lm`, and score each extension by the cosine
# similarity between its sentence embedding and the predicted latent of the chunk.
# Relies on `row`, `predicted_latents`, `model`, `lm` and `gpt_decoder` from earlier cells.
with tqdm(total=sum(row.num_words)) as pbar:
    for i, num_words in enumerate(row.num_words):
        # if i > 0:
        #     print("\033[F\033[F", end='')
        pbar.set_description(f"Chunk {i+1} / {len(row.num_words)}")
        # Number of words contained in the previous `context_length` chunks.
        context_window = sum(row.num_words[max(0, i-config["context_length"]):i])
        for _ in range(num_words):
            beam_nucs = lm.beam_propose(gpt_decoder.beam, context_window)
            for c, (hyp, nextensions) in enumerate(gpt_decoder.get_hypotheses()):
                # `logprobs` is unpacked but not used: ranking relies on the
                # cosine scores computed below instead.
                nuc, logprobs = beam_nucs[c]
                if len(nuc) < 1: continue
                # NOTE(review): when context_window == 0 the slice [-0:] keeps ALL
                # words of the hypothesis rather than none — confirm this is only
                # reached while `hyp.words` is still empty.
                extend_words = [' '.join(hyp.words[-context_window:] + [x]) for x in nuc]
                embs = model.encode(extend_words, convert_to_numpy=False, convert_to_tensor=True)
                scores = tmf.pairwise_cosine_similarity(predicted_latents[[i]], embs)[0].cpu()
                # Replace the embeddings with None placeholders so the hypotheses
                # do not hold on to the tensors.
                embs = [None] * len(embs)
                local_extensions = [Hypothesis(parent = hyp, extension = x) for x in zip(nuc, scores, embs)]
                gpt_decoder.add_extensions(local_extensions, scores, nextensions)
            gpt_decoder.extend(verbose = False)
            # One more word of the current chunk is now part of the context.
            context_window += 1
            pbar.update(1)
        # Best hypothesis = highest summed per-word score stored in `logprobs`
        # (here presumably the cosine similarities passed in above — TODO confirm).
        best_hyp = np.argmax([sum(hyp.logprobs) for hyp in gpt_decoder.beam])
        print("Correct chunk:", row.text[i])
        print("Best hypothesis:", " ".join(gpt_decoder.beam[best_hyp].words[-num_words:]))
        print()

# Decode

In [None]:
# Sentence encoder used to embed the generated candidates (rebinds `model`).
embedding_model_name = config["model"]
model = SentenceTransformer(embedding_model_name, device=device)
clear_output()  # discard the loading output

In [None]:
# Take the first training row and decode its latents.
# Requires `decoder` to be defined already in the kernel.
first_label, row = next(df_train.iterrows())
projector_key = "/".join((row.dataset, row.subject))
with torch.no_grad():
    # Subject-specific projection first, then the shared decoder.
    projected_input = decoder.projector[projector_key](row.X.to(device))
    predicted_latents = decoder(projected_input)

## GPT-2 generation ranked by BERT embeddings

In [None]:
# Iterative generation: at each step sample many GPT-2 continuations of the
# running prompt, embed them with `model`, and keep the one whose embedding is
# closest (cosine similarity) to the predicted latent of chunk i.
context_length = config["context_length"]
max_steps = 10  # only the first chunks are decoded
prompt = ""
generated_chunks_lengths = []  # character length of each accepted continuation
current_crop_length = 0  # characters of old context removed before embedding
torch.cuda.empty_cache()
for i in range(min(max_steps, len(row.text))):
    generated_sentences = gpt2(
        prompt,
        max_new_tokens=8,
        num_return_sequences=1000,
        pad_token_id=50256,  # presumably GPT-2's end-of-text id — TODO confirm
        top_k=0,
        top_p=0.6,
        temperature=1.8,
        repetition_penalty=1.8,
    )
    generated_sentences = [re.sub(r'\n+', ' ', s["generated_text"]) for s in generated_sentences]
    if i > context_length:
        # Accumulate: the prompt holds every chunk generated so far, so keeping
        # only the last `context_length` chunks means dropping ALL older ones,
        # i.e. the running sum of their lengths, not just the latest one.
        current_crop_length += generated_chunks_lengths[i - context_length - 1]
    generated_sentences_cropped = [s[current_crop_length:] for s in generated_sentences]
    embeddings = model.encode(generated_sentences_cropped, convert_to_numpy=False, convert_to_tensor=True)
    similarities = tmf.pairwise_cosine_similarity(predicted_latents[[i]], embeddings)
    # .item() turns the 0-d tensor into a plain int usable as a list index.
    best_sentence_index = similarities.argmax().item()
    best_sentence = generated_sentences[best_sentence_index]
    generated_chunks_lengths.append(len(best_sentence) - len(prompt))
    print("Generated: ", best_sentence[len(prompt):])
    print("Correct: ", row.text[i])
    print(i, current_crop_length)
    prompt = best_sentence

# BERT decoder

In [None]:
from transformers import EncoderDecoderModel, AutoTokenizer

# Load the encoder-decoder model and its tokenizer from the same checkpoint.
# NOTE(review): this rebinds `model`, previously the sentence encoder.
checkpoint = "patrickvonplaten/bert2bert_cnn_daily_mail"
model = EncoderDecoderModel.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [None]:
# Smoke test of the encoder-decoder model on a short news paragraph.
ARTICLE_TO_SUMMARIZE = (
    "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
    "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
    "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
)
encoded_article = tokenizer(ARTICLE_TO_SUMMARIZE, return_tensors="pt")
input_ids = encoded_article.input_ids

# Generate with the model's default generation settings, then decode the first
# (only) returned sequence back to text.
generated_ids = model.generate(input_ids)
decoded_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
generated_text = decoded_texts[0]

In [None]:
# Scratch check of the double-argsort trick: `ranks[k]` is the position of
# element k in descending order, so argmax() gives the index of the smallest value.
values = np.array([0, 3, 1, 2])
descending_order = values.argsort()[::-1]
ranks = descending_order.argsort()
ranks.argmax()

In [None]:
# Run only the encoder half of the encoder-decoder model on the tokenised article.
# NOTE(review): despite the name, the result is presumably a model-output object
# rather than a bare tensor — TODO confirm.
article_encoder = model.encoder
embeddings = article_encoder(input_ids)