<a href="https://colab.research.google.com/github/M1croZavr/compression_horizon/blob/task%2Fhybrid_loss/notebooks/Compression_hybrid_loss.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import subprocess

import torch
from datasets import load_dataset, load_from_disk
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
# Probe for an NVIDIA driver by invoking `nvidia-smi`.
# The command is given as an argument list and run WITHOUT a shell: with `shell=True`
# only the first list item is treated as the command and any further items go to the
# shell itself, which silently breaks as soon as an argument is added.
try:
    subprocess.check_output(["nvidia-smi"])
except (OSError, subprocess.CalledProcessError):
    # OSError: the executable could not be started (e.g. it is not installed);
    # CalledProcessError: it started but exited with a non-zero status.
    print("nvidia-smi is not available")

# Launching experiments

In [None]:
# (Re)load the TensorBoard notebook extension and open it on the hybrid-loss log directory.
# NOTE(review): this cell comes before the `git clone` cell, so on a fresh runtime the
# log directory below may not exist yet — confirm the intended execution order.
# %load_ext tensorboard
%reload_ext tensorboard
# Alternative log directory for the common-loss runs:
# %tensorboard --logdir=/content/compression_horizon/artifacts/experiments/common_loss
%tensorboard --logdir=/content/compression_horizon/artifacts/experiments/hybrid_loss

In [None]:
# Clone the `task/hybrid_loss` branch of the project into ./compression_horizon.
!git clone --branch task/hybrid_loss https://github.com/M1croZavr/compression_horizon.git

## Common loss launches

In [None]:
# Common-loss run: no --loss_type / --hybrid_alpha / --num_alignment_layers flags are passed; max sequence length 4.
!cd ./compression_horizon/; uv run python scripts/hybrid_loss.py --model_checkpoint HuggingFaceTB/SmolLM2-1.7B --learning_rate 0.01 --max_sequence_length 4 --number_of_mem_tokens 1 --max_optimization_steps_per_sample 1000 --warmup_steps 100

In [None]:
# Common-loss run: same flags as the length-4 launch, with max sequence length 32.
!cd ./compression_horizon/; uv run python scripts/hybrid_loss.py --model_checkpoint HuggingFaceTB/SmolLM2-1.7B --learning_rate 0.01 --max_sequence_length 32 --number_of_mem_tokens 1 --max_optimization_steps_per_sample 1000 --warmup_steps 100

In [None]:
# Common-loss run: same flags as the length-4 launch, with max sequence length 128.
!cd ./compression_horizon/; uv run python scripts/hybrid_loss.py --model_checkpoint HuggingFaceTB/SmolLM2-1.7B --learning_rate 0.01 --max_sequence_length 128 --number_of_mem_tokens 1 --max_optimization_steps_per_sample 1000 --warmup_steps 100

In [None]:
# Recursively copy the common-loss experiment artifacts into ./drive/MyDrive/compression_horizon/.
# NOTE(review): presumably requires Google Drive to be mounted under ./drive — no mount cell is visible here; confirm.
!cp -R /content/compression_horizon/artifacts/experiments/common_loss ./drive/MyDrive/compression_horizon/

## Hybrid loss launches

In [None]:
# Hybrid-loss run: --loss_type l2, --hybrid_alpha 0.2, 1 alignment layer, max sequence length 4.
!cd ./compression_horizon/; uv run python scripts/hybrid_loss.py --model_checkpoint HuggingFaceTB/SmolLM2-1.7B --learning_rate 0.01 --max_sequence_length 4 --number_of_mem_tokens 1 --max_optimization_steps_per_sample 1000 --warmup_steps 100 --loss_type l2 --hybrid_alpha 0.2 --num_alignment_layers 1

In [None]:
# Hybrid-loss run: --loss_type l2, --hybrid_alpha 0.2, 1 alignment layer, max sequence length 32.
!cd ./compression_horizon/; uv run python scripts/hybrid_loss.py --model_checkpoint HuggingFaceTB/SmolLM2-1.7B --learning_rate 0.01 --max_sequence_length 32 --number_of_mem_tokens 1 --max_optimization_steps_per_sample 1000 --warmup_steps 100 --loss_type l2 --hybrid_alpha 0.2 --num_alignment_layers 1

In [None]:
# Hybrid-loss run: --loss_type l2, --hybrid_alpha 0.2, 1 alignment layer, max sequence length 128.
!cd ./compression_horizon/; uv run python scripts/hybrid_loss.py --model_checkpoint HuggingFaceTB/SmolLM2-1.7B --learning_rate 0.01 --max_sequence_length 128 --number_of_mem_tokens 1 --max_optimization_steps_per_sample 1000 --warmup_steps 100 --loss_type l2 --hybrid_alpha 0.2 --num_alignment_layers 1

In [None]:
# Hybrid-loss run: --loss_type cosine, --hybrid_alpha 0.2, 1 alignment layer, max sequence length 128.
!cd ./compression_horizon/; uv run python scripts/hybrid_loss.py --model_checkpoint HuggingFaceTB/SmolLM2-1.7B --learning_rate 0.01 --max_sequence_length 128 --number_of_mem_tokens 1 --max_optimization_steps_per_sample 1000 --warmup_steps 100 --loss_type cosine --hybrid_alpha 0.2 --num_alignment_layers 1

In [None]:
# Hybrid-loss run: --loss_type l1, --hybrid_alpha 0.2, 1 alignment layer, max sequence length 128.
!cd ./compression_horizon/; uv run python scripts/hybrid_loss.py --model_checkpoint HuggingFaceTB/SmolLM2-1.7B --learning_rate 0.01 --max_sequence_length 128 --number_of_mem_tokens 1 --max_optimization_steps_per_sample 1000 --warmup_steps 100 --loss_type l1 --hybrid_alpha 0.2 --num_alignment_layers 1

In [None]:
# Hybrid-loss run: --loss_type cosine, --hybrid_alpha 0.2, 3 alignment layers, max sequence length 128.
!cd ./compression_horizon/; uv run python scripts/hybrid_loss.py --model_checkpoint HuggingFaceTB/SmolLM2-1.7B --learning_rate 0.01 --max_sequence_length 128 --number_of_mem_tokens 1 --max_optimization_steps_per_sample 1000 --warmup_steps 100 --loss_type cosine --hybrid_alpha 0.2 --num_alignment_layers 3

In [None]:
# Disabled launch (kept commented out): cosine loss, --hybrid_alpha 0.3, 5 alignment layers, max sequence length 128.
# !cd ./compression_horizon/; uv run python scripts/hybrid_loss.py --model_checkpoint HuggingFaceTB/SmolLM2-1.7B --learning_rate 0.01 --max_sequence_length 128 --number_of_mem_tokens 1 --max_optimization_steps_per_sample 1000 --warmup_steps 100 --loss_type cosine --hybrid_alpha 0.3 --num_alignment_layers 5

In [None]:
# NOTE(review): the literal `...` looks like a placeholder rather than real operands —
# fill in the source and destination paths before running this cell.
!cp ...

# CE comparison

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Running device:", device)

In [None]:
# Load the tokenizer and the causal LM (float32) for the chosen checkpoint, and move
# the model to `device`. The EOS token is registered as the padding token.
# checkpoint = "HuggingFaceTB/SmolLM2-135M"  # alternative checkpoint
checkpoint = "HuggingFaceTB/SmolLM2-1.7B"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token})
model = AutoModelForCausalLM.from_pretrained(checkpoint, dtype=torch.float32)
model = model.to(device)

In [None]:
# Take exactly the sample at index 0 (the one the prefix was trained on) and
# tokenize it, truncated to 4 tokens.
raw_dataset = load_dataset("mrsndmn/pg19", split="test")
train_dataset = raw_dataset.select(range(1))
example = tokenizer(
    train_dataset[0]["text"],
    truncation=True,
    max_length=4,
    return_tensors="pt",
)
# Move both tensors of the encoding to the working device.
input_ids, attention_mask = (example[field].to(device) for field in ("input_ids", "attention_mask"))

In [None]:
# Load the saved compressed prefix for sample 0 and prepend it to the token embeddings.
result = load_from_disk(
    "/content/drive/MyDrive/compression_horizon/l2_None_0_4_666004aa-d739-42a9-8f39-15a11466c4f8/compressed_prefixes"
)
# Build the float32 tensor directly on `device` and add a leading batch axis.
# `torch.tensor` replaces the legacy `torch.FloatTensor` constructor.
compressed_embeddings = torch.tensor(result[0]["embedding"], dtype=torch.float32, device=device).unsqueeze(dim=0)
with torch.no_grad():
    # Architecture-agnostic accessor for the input embedding layer
    # (instead of reaching into `model.model.embed_tokens`).
    sequence_embeddings = model.get_input_embeddings()(input_ids)
united_embeddings = torch.cat(
    (compressed_embeddings, sequence_embeddings),
    dim=1,
)
# One attended position per compression token, however many the prefix holds
# (previously hard-coded to a single position).
compression_attention_mask = torch.ones(
    compressed_embeddings.shape[:2],
    dtype=attention_mask.dtype,
    device=device,
)
united_attention_mask = torch.cat(
    (compression_attention_mask, attention_mask),
    dim=1,
)

In [None]:
# Forward pass over [compressed prefix + sequence] embeddings.
# The mask has to be `united_attention_mask`: it covers the prefix positions as well, so its
# length matches `united_embeddings`. The plain `attention_mask` only covers the sequence tokens.
with torch.no_grad():
    outputs = model(
        inputs_embeds=united_embeddings,
        attention_mask=united_attention_mask,
    )

In [None]:
# Cross-entropy of the sequence tokens given the compressed prefix.
# The logits at position p predict the input at position p + 1, so the targets `input_ids`
# are predicted by the positions from the last prefix token up to (but excluding) the
# final position. The prefix length is derived from the shapes rather than assumed to be 1.
num_prefix_positions = outputs.logits.size(1) - input_ids.size(1)
prediction_logits = outputs.logits[:, num_prefix_positions - 1 : -1, :]
torch.nn.functional.cross_entropy(prediction_logits.flatten(0, 1), input_ids.flatten()).item()

# Generation outside the compressed sequence

In [None]:
# Same device selection as in the CE-comparison section: GPU if visible, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Running device:", device)

In [None]:
# Load the tokenizer and the float32 causal LM for the generation experiments and
# place the model on `device`. EOS doubles as the padding token.
# checkpoint = "HuggingFaceTB/SmolLM2-135M"  # alternative checkpoint
checkpoint = "HuggingFaceTB/SmolLM2-1.7B"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token})
model = AutoModelForCausalLM.from_pretrained(checkpoint, dtype=torch.float32)
model = model.to(device)

In [None]:
# Load the saved compressed prefix and take the embedding stored for the first record.
result = load_from_disk(
    "/content/drive/MyDrive/compression_horizon/l2_None_0_4_666004aa-d739-42a9-8f39-15a11466c4f8/compressed_prefixes"
)
# `torch.tensor` (instead of the legacy `torch.FloatTensor` constructor) builds the float32
# tensor directly on `device`; `unsqueeze` adds the leading batch axis.
compressed_embeddings = torch.tensor(result[0]["embedding"], dtype=torch.float32, device=device).unsqueeze(dim=0)

In [None]:
@torch.no_grad()
def generate_from_compression(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    compressed_embeddings: torch.Tensor,  # [1, mem, hidden]
    max_new_tokens: int,
    num_return_sequences: int = 1,
) -> list[str]:
    """Greedily decode text conditioned only on a compressed prefix.

    At every step the model is run on the compression embeddings followed by the
    embeddings of all tokens generated so far (no key/value cache is used), and the
    arg-max token at the last position is appended.

    Args:
        model: Causal LM that accepts `inputs_embeds` / `attention_mask` and returns an
            object with `.logits`. It is moved to the prefix's device if needed and
            switched to eval mode.
        tokenizer: Tokenizer used for decoding. If it has no pad token but has an EOS
            token, EOS is registered as the pad token (the tokenizer is mutated).
        compressed_embeddings: Prefix embeddings of shape [1, mem, hidden].
        max_new_tokens: Upper bound on the number of generated tokens per sequence.
        num_return_sequences: Batch size to expand the prefix to. Decoding is greedy,
            so all returned sequences are produced from identical inputs.

    Returns:
        One decoded string per sequence, with special tokens skipped.
    """
    # Run everything on the device the prefix lives on.
    device = compressed_embeddings.device
    if model.device != device:
        model = model.to(device)
    model.eval()

    # Make sure the tokenizer has a pad token.
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token})
    eos_token_id = tokenizer.eos_token_id

    # Replicate the prefix across the batch.
    if num_return_sequences > 1:
        compressed_embeddings = compressed_embeddings.expand(num_return_sequences, -1, -1)  # [batch, mem, hidden]
    batch_size, num_compression_tokens, hidden_size = compressed_embeddings.shape

    # Generated token ids so far; starts with zero columns.
    generated_token_ids = torch.empty((batch_size, 0), dtype=torch.long, device=device)
    input_embeddings = model.get_input_embeddings()

    for _ in range(max_new_tokens):
        num_generated = generated_token_ids.size(1)

        if num_generated == 0:
            # The empty placeholder must carry the prefix dtype: a default float32 tensor
            # would make `torch.cat` promote a half-precision prefix on the first step.
            generated_embeddings = torch.empty(
                (batch_size, 0, hidden_size), dtype=compressed_embeddings.dtype, device=device
            )
        else:
            generated_embeddings = input_embeddings(generated_token_ids)  # [batch, sequence, hidden]
        united_token_embeddings = torch.cat(
            [compressed_embeddings, generated_embeddings], dim=1
        )  # [batch, mem + sequence, hidden]

        # Every position (prefix and generated) is attended to.
        united_attention_mask = torch.ones(
            (batch_size, num_compression_tokens + num_generated), dtype=torch.long, device=device
        )  # [batch, mem + sequence]

        outputs = model(inputs_embeds=united_token_embeddings, attention_mask=united_attention_mask)
        next_token_ids = torch.argmax(outputs.logits[:, -1, :], dim=-1)  # [batch]

        # A sequence whose previous token is EOS keeps emitting EOS.
        if eos_token_id is not None and num_generated > 0:
            reached_eos = generated_token_ids[:, -1].eq(eos_token_id)
            next_token_ids = torch.where(reached_eos, torch.full_like(next_token_ids, eos_token_id), next_token_ids)

        generated_token_ids = torch.cat([generated_token_ids, next_token_ids.unsqueeze(-1)], dim=-1)

        # Stop as soon as every sequence's newest token is EOS.
        if eos_token_id is not None and torch.all(next_token_ids.eq(eos_token_id)):
            break

    return tokenizer.batch_decode(generated_token_ids, skip_special_tokens=True)