# Fluxentropy
By [Green](https://x.com/myainotez) and [Blue](https://x.com/tensorqt) knights.

## Initialize

In [None]:
import os
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer

import plotly.graph_objects as go
from plotly.subplots import make_subplots


# Pick the compute backend: Apple MPS if present, else CUDA, else CPU.
if torch.backends.mps.is_available():
    _backend = "mps"
elif torch.cuda.is_available():
    _backend = "cuda"
else:
    _backend = "cpu"
device = torch.device(_backend)
print(f"Using device: {device}")

if device.type == "cuda":
    # Allocator option read by PyTorch's CUDA caching allocator; it must be set
    # before CUDA memory is first allocated to have any effect.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
torch.cuda.empty_cache()
# Allow faster (reduced-precision) float32 matmul kernels where supported.
torch.set_float32_matmul_precision("high")


class EntropixModel:
    """Characterize how uncertain a causal language model is about input text.

    Every statistic is derived from a next-token distribution ``p`` over the
    vocabulary:

    * entropy    ``H = -sum_v p(v) * log p(v)``
    * varentropy ``V = sum_v p(v) * (log p(v) + H) ** 2``, i.e. the variance of
      the surprisal ``-log p(v)`` under ``p``.

    Natural logarithms are used, so both values are in nats.
    """

    def __init__(self, model, tokenizer, seed: int = 1337, dtype: torch.dtype = torch.bfloat16):
        """
        Args:
            model: Causal LM whose forward pass returns an object with ``.logits``;
                it is moved to the module-level ``device``.
            tokenizer: Tokenizer paired with ``model``.
            seed (int): Recorded on the instance only; no RNG is seeded here.
            dtype (torch.dtype): Recorded on the instance only; the model is not cast.
        """
        self.model = model.to(device)
        self.tokenizer = tokenizer
        self.seed = seed
        self.dtype = dtype

    def _ensure_pad_token(self):
        """Make the tokenizer pad with its EOS token if it defines no pad token.

        An existing pad token is left untouched.
        """
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    @staticmethod
    def _entropy_stats(logits: torch.Tensor):
        """Return ``(entropy, varentropy)`` reduced over the last (vocabulary) dim.

        The computation is done in float32 for two reasons. Summing over a large
        vocabulary in bfloat16 loses precision. Also, NumPy (used for plotting)
        has no bfloat16 dtype, so ``.numpy()`` on a bfloat16 tensor fails.

        Args:
            logits (torch.Tensor): Unnormalized scores of shape ``(..., vocab)``.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: float32 tensors of shape ``(...)``.
        """
        log_probs = torch.log_softmax(logits.float(), dim=-1)
        probs = torch.exp(log_probs)
        entropy = -torch.sum(probs * log_probs, dim=-1)
        # log p + H is the deviation of log p from its mean (-H) under p.
        varentropy = torch.sum(probs * (log_probs + entropy.unsqueeze(-1)) ** 2, dim=-1)
        return entropy, varentropy

    @staticmethod
    def _last_token_index(attention_mask: torch.Tensor) -> torch.Tensor:
        """Index of the last non-padding position in each row of a 2-D mask.

        This works for both right- and left-padded batches. A row with no real
        tokens maps to index 0.
        """
        positions = torch.arange(attention_mask.shape[-1], device=attention_mask.device)
        # Padding positions become 0, so the row maximum is the last position whose mask is 1.
        return (positions * attention_mask).amax(dim=-1)

    def entropy_characterize(self, input_strings: list, max_length: int = 512):
        """
        Computes the entropy and varentropy per token for a batch of input strings.

        The value at position ``i`` describes the model's prediction of token
        ``i + 1`` given tokens ``0..i``.

        Args:
            input_strings (list): List of input strings to analyze.
            max_length (int): Maximum sequence length for tokenization.

        Returns:
            dict: ``entropy`` and ``varentropy`` as float32 tensors of shape
            (batch_size, seq_len), zeroed at padding positions. Also
            ``input_strings``, ``tokens`` (per-row token strings, padding
            included) and ``attention_mask``.
        """
        self._ensure_pad_token()
        # Tokenize input strings with padding and truncation
        encodings = self.tokenizer(
            input_strings,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )

        input_ids = encodings["input_ids"].to(device)
        attention_mask = encodings["attention_mask"].to(device)

        with torch.no_grad():
            logits = self.model(input_ids=input_ids, attention_mask=attention_mask).logits

            # Process one position at a time so the float32 working set is
            # (batch, vocab) rather than (batch, seq_len, vocab).
            entropy_list, varentropy_list = [], []
            for i in range(logits.shape[1]):
                entropy, varentropy = self._entropy_stats(logits[:, i, :])
                entropy_list.append(entropy)
                varentropy_list.append(varentropy)

            # Zero out padding positions.
            mask = attention_mask.to(torch.float32)
            entropy_tensor = torch.stack(entropy_list, dim=1) * mask
            varentropy_tensor = torch.stack(varentropy_list, dim=1) * mask

        return {
            "entropy": entropy_tensor.cpu(),
            "varentropy": varentropy_tensor.cpu(),
            "input_strings": input_strings,
            "tokens": [self.tokenizer.convert_ids_to_tokens(ids) for ids in input_ids.cpu().tolist()],
            "attention_mask": attention_mask.cpu(),
        }

    def entropy_characterize_full(self, input_strings: list, max_length: int = 512):
        """
        Computes the entropy and varentropy of the next-token prediction made
        after each complete input string.

        Args:
            input_strings (list): List of input strings to analyze.
            max_length (int): Maximum sequence length for tokenization.

        Returns:
            dict: ``entropy`` and ``varentropy`` as float32 tensors of shape
            (batch_size,), plus ``input_strings``, ``tokens`` and
            ``attention_mask``.
        """
        self._ensure_pad_token()
        # Tokenize input strings without adding special tokens
        encodings = self.tokenizer(
            input_strings,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
            add_special_tokens=False
        )

        input_ids = encodings["input_ids"].to(device)
        attention_mask = encodings["attention_mask"].to(device)

        with torch.no_grad():
            logits = self.model(input_ids=input_ids, attention_mask=attention_mask).logits

            # With right padding, position -1 is a pad token for every string
            # shorter than the longest one. Read each row at its own last real token.
            last_idx = self._last_token_index(attention_mask)
            rows = torch.arange(logits.shape[0], device=logits.device)
            next_token_logits = logits[rows, last_idx]  # Shape: (batch_size, vocab_size)

            entropy, varentropy = self._entropy_stats(next_token_logits)

        return {
            "entropy": entropy.cpu(),
            "varentropy": varentropy.cpu(),
            "input_strings": input_strings,
            "tokens": [self.tokenizer.convert_ids_to_tokens(ids) for ids in input_ids.cpu().tolist()],
            "attention_mask": attention_mask.cpu(),
        }

    def visualize_results(self, results, title=None, height=800):
        """
        Creates interactive visualizations for entropy and varentropy results using plotly.

        Args:
            results (dict): Output of ``entropy_characterize`` (2-D entropy) or
                ``entropy_characterize_full`` (1-D entropy).
            title (str, optional): Title for the visualization
            height (int, optional): Height of the plot in pixels

        Returns:
            plotly.graph_objects.Figure: Interactive figure with entropy and varentropy visualizations
        """
        # One value per string (1-D) -> bar chart; one value per position (2-D) -> line plots.
        if results['entropy'].dim() == 1:
            return self._plot_full_string(results, title, height)
        return self._plot_per_token(results, title, height)

    @staticmethod
    def _plot_full_string(results, title, height):
        """Grouped bar chart with one entropy bar and one varentropy bar per string."""
        fig = go.Figure()
        for key, name in (('entropy', 'Entropy'), ('varentropy', 'Varentropy')):
            # Plain Python floats: Plotly cannot consume bfloat16 tensors.
            values = results[key].float().tolist()
            fig.add_trace(
                go.Bar(
                    x=results['input_strings'],
                    y=values,
                    text=[f"{v:.3f}" for v in values],
                    textposition='auto',
                    name=name
                )
            )

        fig.update_layout(
            title=title or 'Entropy and Varentropy Analysis (Full String)',
            xaxis_title='Input Strings',
            yaxis_title='Value',
            barmode='group',
            height=height
        )
        return fig

    @staticmethod
    def _plot_per_token(results, title, height):
        """Per-position entropy/varentropy line plots plus a heatmap (tokens on hover)."""
        fig = make_subplots(
            rows=3, cols=1,
            subplot_titles=('Entropy Over Time', 'Varentropy Over Time', 'Token-wise Analysis'),
            vertical_spacing=0.1,
            row_heights=[0.35, 0.35, 0.3]
        )

        for batch_idx, mask in enumerate(results['attention_mask']):
            # Select real positions by mask value (not a prefix slice) so that
            # left-padded batches are handled as well as right-padded ones.
            keep = mask.bool()
            tokens = [tok for tok, k in zip(results['tokens'][batch_idx], keep.tolist()) if k]
            entropy_values = results['entropy'][batch_idx][keep].float().numpy()
            varentropy_values = results['varentropy'][batch_idx][keep].float().numpy()
            positions = np.arange(len(tokens))

            fig.add_trace(
                go.Scatter(
                    x=positions,
                    y=entropy_values,
                    mode='lines+markers',
                    name=f'Entropy (String {batch_idx + 1})',
                    hovertemplate='Position: %{x}<br>Entropy: %{y:.3f}<extra></extra>'
                ),
                row=1, col=1
            )

            fig.add_trace(
                go.Scatter(
                    x=positions,
                    y=varentropy_values,
                    mode='lines+markers',
                    name=f'Varentropy (String {batch_idx + 1})',
                    hovertemplate='Position: %{x}<br>Varentropy: %{y:.3f}<extra></extra>'
                ),
                row=2, col=1
            )

            # Use positions on x, not token strings. Repeated tokens would
            # otherwise collapse into one categorical column and overwrite each
            # other. The token is shown on hover instead.
            # NOTE(review): heatmaps for all strings still share the same axes,
            # so later strings draw over earlier ones.
            fig.add_trace(
                go.Heatmap(
                    z=[entropy_values, varentropy_values],
                    x=positions,
                    y=['Entropy', 'Varentropy'],
                    text=[tokens, tokens],
                    coloraxis='coloraxis',
                    hoverongaps=False,
                    hovertemplate='Position: %{x}<br>Token: %{text}<br>Metric: %{y}<br>Value: %{z:.3f}<extra></extra>'
                ),
                row=3, col=1
            )

        fig.update_layout(
            height=height,
            showlegend=True,
            title=title or 'Entropy and Varentropy Analysis',
            hovermode='closest',
            # One shared colour scale/bar for every heatmap, instead of one per string.
            coloraxis={'colorscale': 'Viridis'}
        )

        fig.update_xaxes(title_text='Position', row=1, col=1)
        fig.update_xaxes(title_text='Position', row=2, col=1)
        fig.update_xaxes(title_text='Position', row=3, col=1)

        fig.update_yaxes(title_text='Entropy', row=1, col=1)
        fig.update_yaxes(title_text='Varentropy', row=2, col=1)

        return fig


if __name__ == "__main__":
    # Seed PyTorch's global RNG so repeated runs are reproducible.
    seed = 1337
    torch.manual_seed(seed=seed)

    # Llama 3.2 1B Instruct in bfloat16, placed on the device selected above.
    model_id = "meta-llama/Llama-3.2-1B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    entropix_model = EntropixModel(model=model, tokenizer=tokenizer, seed=seed)


## Run

In [None]:
    # Example inputs: a deliberately varied batch. It mixes natural languages
    # and scripts, source code, shell commands, log lines, news text and emoji,
    # so entropy can be compared across very different kinds of text.
    input_strings = [
        "The quick brown fox jumps over the lazy dog.",  # Classic pangram
        "In quantum mechanics, particles can exist in multiple states simultaneously.",  # Scientific
        "她站在窗前，望着远方的山峰。",  # Chinese (Looking at distant mountains)
        "To be, or not to be, that is the question.",  # Literary/Shakespeare
        "The cryptocurrency market experienced significant volatility today.",  # Financial news
        "Je pense, donc je suis.",  # French philosophy (Descartes)
        "🌟 Dancing under the moonlight, spirits high and hearts light. 🌙",  # Emojis and poetic
        "SELECT * FROM users WHERE age > 18;",  # SQL code
        "The neural network achieved 98.5% accuracy on the test dataset.",  # AI/ML
        "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",  # Latin placeholder
        "Breaking: Major breakthrough in fusion energy announced today!",  # News headline
        "def fibonacci(n): return n if n <= 1 else fibonacci(n-1) + fibonacci(n-2)",  # Python code
        "Step 1: Preheat oven to 350°F. Step 2: Mix ingredients thoroughly.",  # Recipe instructions
        "Once upon a time, in a galaxy far, far away...",  # Story opening
        "Error 404: Page not found. Please check the URL and try again.",  # Technical error
        "Climate change threatens biodiversity in coral reef ecosystems.",  # Environmental
        "おはようございます、今日はいい天気ですね。",  # Japanese (Good morning, nice weather today)
        "1234567890 !@#$%^&*()_+ <>?:\"{}|",  # Numbers and special characters
        "URGENT: Meeting rescheduled to 3PM EST - All hands required",  # Business communication
        "The composition of Bach's fugues demonstrates mathematical precision.",  # Music analysis
        "Das Leben ist wie ein Fahrrad. Man muss sich vorwärts bewegen.",  # German (Einstein quote)
        "for i in range(len(array)): if array[i] > max_val: max_val = array[i]",  # More Python code
        "CREATE TABLE employees (id INT PRIMARY KEY, name VARCHAR(255));",  # SQL DDL
        "La vita è bella quando si vive con passione.",  # Italian (Life is beautiful...)
        "RT @SpaceX: Successful launch of Starship prototype #42! 🚀",  # Social media
        "В тихом омуте черти водятся.",  # Russian proverb
        "async function fetchData() { const response = await fetch(url); }",  # JavaScript async
        "🎮 Level Up! You've earned 1000 XP and unlocked new achievements! 🏆",  # Gaming with emojis
        "<!DOCTYPE html><html><head><title>Hello World</title></head></html>",  # HTML
        "Hola mundo, ¿cómo estás hoy?",  # Spanish greeting
        "import numpy as np; X = np.array([[1, 2], [3, 4]])",  # Scientific Python
        "Breaking News: Artificial Intelligence Achieves New Milestone in Protein Folding",  # Science news
        "public class HelloWorld { public static void main(String[] args) {} }",  # Java
        "The mitochondria is the powerhouse of the cell.",  # Biology
        "git commit -m \"Fix: resolve memory leak in main loop\"",  # Git command
        "अतिथि देवो भव:",  # Sanskrit (Guest is God)
        "try { throw new Error('Test'); } catch (e) { console.log(e); }",  # JavaScript error handling
        "Dans les champs de l'observation, le hasard ne favorise que les esprits préparés.",  # French (Pasteur)
        "docker run -d -p 80:80 nginx:latest",  # Docker command
        "While(true) { System.out.println(\"Hello, World!\"); }",  # Infinite loop
        "kubectl get pods -n kubernetes-dashboard",  # Kubernetes command
        "Χαίρετε! Πώς είστε σήμερα;",  # Greek greeting
        "const handleSubmit = (e) => { e.preventDefault(); setState(newValue); };",  # React code
        "مرحبا بالعالم",  # Arabic (Hello World)
        "SELECT COUNT(*) OVER (PARTITION BY department) FROM employees;",  # Advanced SQL
        "pip install tensorflow==2.8.0 torch==2.0.0 transformers==4.28.0",  # Package installation
        "한글은 세상에서 가장 과학적인 글자입니다.",  # Korean (Hangul is the most scientific writing system)
        "{ \"name\": \"John\", \"age\": 30, \"city\": \"New York\" }",  # JSON data
        "CRITICAL: Memory usage exceeded 90% threshold at 02:45:30 UTC",  # System log
        "@media (max-width: 768px) { .container { flex-direction: column; } }",  # CSS media query
        "Fibonacci sequence: 0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144...",  # Mathematical sequence
        "$ curl -X POST https://api.example.com/v1/data -H \"Content-Type: application/json\"",  # CURL command
        "WARNING: Certificate expires in 7 days. Please renew SSL certificate.",  # Security warning
        "sudo apt-get update && sudo apt-get upgrade -y",  # Linux command
        "print(f\"Current temperature: {temp:.2f}°C at {time:%H:%M:%S}\")",  # Python f-string
        "Революция в квантовых вычислениях: создан 1000-кубитный процессор",  # Russian tech news
        "interface User { id: string; name: string; age: number; }",  # TypeScript interface
        "O Romeo, Romeo! wherefore art thou Romeo?",  # Shakespeare quote
        "Exception in thread \"main\" java.lang.NullPointerException at Main.java:42",  # Java error
        "今日は富士山に登りました。頂上からの景色は素晴らしかったです。"  # Japanese (Climbing Mt. Fuji)
    ]

    # Two views of the same batch: statistics at every token position, and the
    # statistics of the single next-token prediction after each whole string.
    results_per_token = entropix_model.entropy_characterize(input_strings)
    results_full = entropix_model.entropy_characterize_full(input_strings)

    fig_per_token, fig_full = (
        entropix_model.visualize_results(results_per_token, title="Entropy Analysis (Per Token)"),
        entropix_model.visualize_results(results_full, title="Entropy Analysis (Full String)"),
    )

    # The figures are displayed by the notebook cells below. To keep them as
    # files instead, call fig_per_token.write_html(path) / fig_full.write_html(path).


In [None]:
# Display the grouped bar chart of full-string entropy and varentropy.
fig_full

In [None]:
# Display the per-position line plots and the token-wise heatmap.
fig_per_token