# Fluxentropy
By [Green](https://x.com/myainotez) and [Blue](https://x.com/tensorqt) knights.

## Initialize

In [1]:
import os
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from google.colab import userdata

import plotly.graph_objects as go
from plotly.subplots import make_subplots
from huggingface_hub import notebook_login
# Authenticate with the Hugging Face Hub (needed for the gated Llama checkpoint
# loaded below) using the HF_TOKEN stored in Colab secrets.
# `notebook_login` does not accept a token argument (its options are flags such
# as `new_session`), so the secret was never actually used for authentication;
# `login(token=...)` takes the token directly and works non-interactively.
from huggingface_hub import login
login(token=userdata.get('HF_TOKEN'))


# Device configuration: prefer Apple-silicon MPS, then CUDA, otherwise CPU.
# The probes are callables, so a later backend is only queried when every
# earlier one reported itself unavailable (same order as an if/elif chain).
_device_probes = (
    ("mps", torch.backends.mps.is_available),
    ("cuda", torch.cuda.is_available),
)
device = torch.device(
    next((backend for backend, is_available in _device_probes if is_available()), "cpu")
)

print(f"Using device: {device}")
if device.type == "cuda":
    # NOTE(review): PYTORCH_CUDA_ALLOC_CONF is read when the CUDA caching
    # allocator initializes; setting it here may be too late if CUDA memory
    # was already allocated earlier in the session.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    torch.cuda.empty_cache()
# Permit faster reduced-precision kernels for float32 matrix multiplications.
torch.set_float32_matmul_precision("high")


class EntropixModel:
    """Measure next-token entropy statistics of a causal language model.

    Wraps a causal LM and its tokenizer. At every analysed position the model's
    logits define a distribution ``p`` over the vocabulary, from which two
    statistics are derived:

    * entropy     ``H = -sum(p * log p)``
    * varentropy  ``V = sum(p * (log p + H) ** 2)``  (variance of ``-log p``)

    Statistics are always computed in float32 regardless of the model dtype, so
    bfloat16 logits neither lose precision nor yield tensors that ``.numpy()``
    cannot convert.

    Relies on the module-level ``device`` (a ``torch.device``) of this notebook.
    """

    def __init__(self, model, tokenizer, seed: int = 1337, dtype: torch.dtype = torch.bfloat16):
        """
        Args:
            model: Causal LM that, when called with ``input_ids`` and
                ``attention_mask``, returns an object whose ``logits`` have shape
                (batch_size, seq_len, vocab_size). Moved to ``device`` and put
                in eval mode.
            tokenizer: Tokenizer matching ``model``.
            seed (int): Seed passed to ``torch.manual_seed``.
            dtype (torch.dtype): Stored as ``self.dtype``; not used by the
                current computations.
        """
        self.model = model.to(device)
        # Inference only: make sure dropout-style layers are disabled.
        self.model.eval()
        self.tokenizer = tokenizer
        self.dtype = dtype
        self.seed = seed
        torch.manual_seed(seed)

    @staticmethod
    def _entropy_stats(logits: torch.Tensor):
        """Return float32 (entropy, varentropy) of softmax(logits) over the last dim."""
        # Upcast first: low-precision log-softmax is noisy and bf16 tensors
        # cannot be converted to numpy later on.
        log_probs = torch.log_softmax(logits.float(), dim=-1)
        probs = torch.exp(log_probs)
        entropy = -torch.sum(probs * log_probs, dim=-1)
        varentropy = torch.sum(probs * (log_probs + entropy.unsqueeze(-1)) ** 2, dim=-1)
        return entropy, varentropy

    @staticmethod
    def _last_token_index(padding_mask: torch.Tensor) -> torch.Tensor:
        """Return, per row, the index of the last non-padding position (left or right padding)."""
        positions = torch.arange(
            padding_mask.shape[1], device=padding_mask.device, dtype=torch.float32
        )
        # Real positions keep their (strictly increasing) index, pads become 0,
        # so the row-wise argmax is the last real position.
        return (padding_mask.to(torch.float32) * positions).argmax(dim=1)

    def entropy_characterize(
        self,
        input_strings: list,
        config: dict,
        max_length: int = 512
    ):
        """
        Computes specified characteristics based on the configuration for a batch of input strings.

        Args:
            input_strings (list): List of input strings to analyze.
            config (dict): Configuration dictionary. Recognised keys:
                {
                    "mechanism": "per_token",     # or "per_string" (default "per_token")
                    "compute_entropy": True,      # default False
                    "compute_varentropy": False,  # default False
                    "output_format": "dict",      # or "tensor" / "list" (default "dict")
                }
            max_length (int): Maximum sequence length for tokenization.

        Returns:
            With ``output_format="dict"``: a dict with ``input_strings``,
            ``tokens`` (per-string token lists, padding included),
            ``attention_mask`` (CPU tensor, (batch_size, seq_len)) and, when
            requested, ``entropy`` / ``varentropy`` as float32 CPU tensors:
              * per_token: shape (batch_size, seq_len). Position i describes the
                model's prediction of the token that follows token i; padded
                positions are zero.
              * per_string: shape (batch_size,), taken at each string's last
                non-padding token, i.e. the distribution of the next token after
                the full input.
            With ``output_format="tensor"`` / ``"list"``: the requested
            statistics stacked on a new last axis (entropy first), i.e.
            (batch_size, n_features) for per_string or
            (batch_size, seq_len, n_features) for per_token; ``"list"`` returns
            the same values as nested Python lists.

        Raises:
            ValueError: Unknown ``mechanism`` or ``output_format``, or a
                tensor/list output was requested without any characteristic.
        """
        mechanism = config.get("mechanism", "per_token")
        compute_entropy = config.get("compute_entropy", False)
        compute_varentropy = config.get("compute_varentropy", False)
        output_format = config.get("output_format", "dict")
        # Add more characteristics as needed

        # Validate the configuration before running the expensive forward pass.
        if mechanism not in ("per_token", "per_string"):
            raise ValueError(f"Unknown mechanism: {mechanism}")
        if output_format not in ("dict", "tensor", "list"):
            raise ValueError(f"Unknown output_format: {output_format}")
        if output_format != "dict" and not (compute_entropy or compute_varentropy):
            raise ValueError("No characteristics were computed based on the configuration.")

        # Batched tokenization needs a pad token; fall back to EOS only when the
        # tokenizer has none (padded positions are masked out below anyway).
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        encodings = self.tokenizer(
            input_strings,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )

        input_ids = encodings["input_ids"].to(device)
        padding_mask = encodings["attention_mask"].to(device)
        batch_size, seq_len = input_ids.shape

        results = {
            "input_strings": input_strings,
            "tokens": [self.tokenizer.convert_ids_to_tokens(ids) for ids in input_ids.tolist()],
            "attention_mask": padding_mask.cpu(),
        }

        with torch.no_grad():
            logits = self.model(input_ids=input_ids, attention_mask=padding_mask).logits
            # logits: (batch_size, seq_len, vocab_size)

            if mechanism == "per_token":
                if compute_entropy or compute_varentropy:
                    entropy_cols, varentropy_cols = [], []
                    # One position at a time: only a (batch_size, vocab_size)
                    # slice is upcast to float32, never the whole logits tensor.
                    for i in range(seq_len):
                        entropy_i, varentropy_i = self._entropy_stats(logits[:, i, :])
                        entropy_cols.append(entropy_i)
                        varentropy_cols.append(varentropy_i)
                    mask = padding_mask.to(torch.float32)
                    if compute_entropy:
                        results["entropy"] = (torch.stack(entropy_cols, dim=1) * mask).cpu()
                    if compute_varentropy:
                        results["varentropy"] = (torch.stack(varentropy_cols, dim=1) * mask).cpu()
                # Add additional characteristics computations here

            else:  # "per_string"
                if compute_entropy or compute_varentropy:
                    # Use each row's last *real* token: with right padding,
                    # logits[:, -1] would be the prediction at a pad position.
                    last_idx = self._last_token_index(padding_mask)
                    rows = torch.arange(batch_size, device=logits.device)
                    entropy, varentropy = self._entropy_stats(logits[rows, last_idx])
                    if compute_entropy:
                        results["entropy"] = entropy.cpu()
                    if compute_varentropy:
                        results["varentropy"] = varentropy.cpu()
                # Add additional characteristics computations here

        if output_format == "dict":
            return results

        characteristics = [
            results[name]
            for name, wanted in (("entropy", compute_entropy), ("varentropy", compute_varentropy))
            if wanted
        ]
        # Add additional characteristics here

        combined = torch.stack(characteristics, dim=-1)
        if output_format == "tensor":
            return combined
        return combined.cpu().numpy().tolist()

    def visualize_results(self, results, config: dict, title=None, height=800):
        """
        Creates interactive visualizations for entropy and varentropy results using Plotly.

        Args:
            results (dict): Output of ``entropy_characterize`` produced with
                ``output_format="dict"`` (values are read by key).
            config (dict): Configuration dictionary that produced ``results``.
            title (str, optional): Title for the visualization.
            height (int, optional): Height of the plot in pixels.

        Returns:
            plotly.graph_objects.Figure: Interactive figure with visualizations.

        Raises:
            TypeError: ``results`` is not a dict.
            ValueError: Unknown ``mechanism``.
        """
        if not isinstance(results, dict):
            raise TypeError(
                "visualize_results expects the dict returned by entropy_characterize "
                "with output_format='dict'."
            )
        mechanism = config.get("mechanism", "per_token")
        compute_entropy = config.get("compute_entropy", False)
        compute_varentropy = config.get("compute_varentropy", False)
        # Add more characteristics as needed
        metrics = (
            ("entropy", "Entropy", compute_entropy),
            ("varentropy", "Varentropy", compute_varentropy),
        )

        if mechanism == "per_string":
            # Visualization for full-string characteristics
            fig = go.Figure()
            for key, label, wanted in metrics:
                if not wanted:
                    continue
                values = results[key].tolist()
                fig.add_trace(
                    go.Bar(
                        x=results['input_strings'],
                        y=values,
                        text=[f"{v:.3f}" for v in values],
                        textposition='auto',
                        name=label
                    )
                )

            fig.update_layout(
                title=title or 'Entropy and Varentropy Analysis (Full String)',
                xaxis_title='Input Strings',
                yaxis_title='Value',
                barmode='group',
                height=height
            )

        elif mechanism == "per_token":
            # Visualization for per-token characteristics
            fig = make_subplots(
                rows=3, cols=1,
                subplot_titles=('Entropy Over Tokens', 'Varentropy Over Tokens', 'Token-wise Analysis'),
                vertical_spacing=0.1,
                row_heights=[0.35, 0.35, 0.3]
            )

            for batch_idx in range(len(results['input_strings'])):
                # Select real (non-padding) positions via the mask so this works
                # for both left- and right-padded batches.
                valid = results['attention_mask'][batch_idx].bool()
                tokens = [
                    tok for tok, keep in zip(results['tokens'][batch_idx], valid.tolist()) if keep
                ]
                positions = np.arange(len(tokens))

                heatmap_z = []
                heatmap_y = []
                for row, (key, label, wanted) in enumerate(metrics, start=1):
                    if not wanted:
                        continue
                    values = results[key][batch_idx][valid].float().numpy()
                    fig.add_trace(
                        go.Scatter(
                            x=positions,
                            y=values,
                            mode='lines+markers',
                            name=f'{label} (String {batch_idx + 1})',
                            hovertemplate=f'Position: %{{x}}<br>{label}: %{{y:.3f}}<extra></extra>'
                        ),
                        row=row, col=1
                    )
                    heatmap_z.append(values)
                    heatmap_y.append(label)

                if heatmap_z:
                    # x uses positions (not token strings): a categorical axis
                    # would merge repeated tokens. Tokens are shown on hover.
                    # All heatmaps share one color axis -> a single colorbar.
                    fig.add_trace(
                        go.Heatmap(
                            z=heatmap_z,
                            x=positions,
                            y=heatmap_y,
                            text=[tokens] * len(heatmap_z),
                            coloraxis='coloraxis',
                            hoverongaps=False,
                            hovertemplate='Token: %{text}<br>Metric: %{y}<br>Value: %{z:.3f}<extra></extra>'
                        ),
                        row=3, col=1
                    )

            fig.update_layout(
                height=height,
                showlegend=True,
                title=title or 'Entropy and Varentropy Analysis',
                hovermode='closest',
                coloraxis=dict(colorscale='Viridis', colorbar=dict(len=0.3, y=0.15))
            )

            # Update axes labels
            fig.update_xaxes(title_text='Token Position', row=1, col=1)
            fig.update_xaxes(title_text='Token Position', row=2, col=1)
            fig.update_xaxes(title_text='Token Position', row=3, col=1)

            if compute_entropy:
                fig.update_yaxes(title_text='Entropy', row=1, col=1)
            if compute_varentropy:
                fig.update_yaxes(title_text='Varentropy', row=2, col=1)

        else:
            raise ValueError(f"Unknown mechanism: {mechanism}")

        return fig

    def permute_dataset(
        self,
        dataset: list,
        config: dict,
        sort_by: str,
        descending: bool = False,
        max_length: int = 512
    ):
        """
        Permutes (sorts) the dataset based on a specified characteristic.

        Args:
            dataset (list): List of input strings.
            config (dict): Configuration dictionary for entropy_characterize.
                Its "output_format" is ignored (a dict is always requested).
            sort_by (str): The characteristic to sort by.
                per_string mechanism: 'entropy' or 'varentropy'.
                per_token mechanism: 'entropy_token_avg', 'entropy_token_sum',
                'varentropy_token_avg' or 'varentropy_token_sum' (masked
                mean / sum over non-padding positions).
            descending (bool): Whether to sort in descending order. Ties keep
                their original relative order in either direction.
            max_length (int): Maximum sequence length for tokenization.

        Returns:
            tuple: (permuted_dataset, sorted_characteristics), the latter a
            numpy array aligned with the former.

        Raises:
            ValueError: ``sort_by`` is unavailable for the configured mechanism.
        """
        # The lookups below index results by key, so always request a dict.
        results = self.entropy_characterize(
            input_strings=dataset,
            config={**config, "output_format": "dict"},
            max_length=max_length
        )

        mechanism = config.get("mechanism", "per_token")
        # Add more characteristics as needed

        if mechanism == "per_token":
            token_aggregates = {
                "entropy_token_avg": ("entropy", "avg"),
                "entropy_token_sum": ("entropy", "sum"),
                "varentropy_token_avg": ("varentropy", "avg"),
                "varentropy_token_sum": ("varentropy", "sum"),
            }
            # The metric is only present in results if it was computed.
            if sort_by not in token_aggregates or token_aggregates[sort_by][0] not in results:
                raise ValueError(f"Unknown sort_by option: {sort_by}")
            metric, reduction = token_aggregates[sort_by]
            mask = results["attention_mask"].to(torch.float32)
            totals = torch.sum(results[metric] * mask, dim=1)
            if reduction == "avg":
                # clamp avoids 0/0 for a row with no real tokens
                characteristic = totals / mask.sum(dim=1).clamp(min=1)
            else:
                characteristic = totals
        else:
            if sort_by not in ("entropy", "varentropy") or sort_by not in results:
                raise ValueError(f"sort_by '{sort_by}' not found in results and mechanism is '{mechanism}'")
            characteristic = results[sort_by]

        values = characteristic.float().cpu().numpy()
        # Stable sort; negate (rather than reverse) for descending order so
        # tied items are not flipped.
        sorted_indices = np.argsort(-values if descending else values, kind="stable")
        permuted_dataset = [dataset[idx] for idx in sorted_indices]

        return permuted_dataset, values[sorted_indices]

    def display_sorted_characteristics(
        self,
        dataset: list,
        config: dict,
        sort_by: str,
        descending: bool = False,
        max_length: int = 512
    ):
        """
        Permutes the dataset and prints the sorted characteristics.

        Args:
            dataset (list): List of input strings.
            config (dict): Configuration dictionary for entropy_characterize.
            sort_by (str): The characteristic to sort by (see permute_dataset).
            descending (bool): Whether to sort in descending order.
            max_length (int): Maximum sequence length for tokenization.

        Returns:
            tuple: (permuted_dataset, sorted_characteristics)
        """
        permuted_dataset, sorted_characteristics = self.permute_dataset(
            dataset=dataset,
            config=config,
            sort_by=sort_by,
            descending=descending,
            max_length=max_length
        )
        print(f"Dataset sorted by '{sort_by}' in {'descending' if descending else 'ascending'} order.")
        for idx, (string, characteristic) in enumerate(zip(permuted_dataset, sorted_characteristics)):
            print(f"{idx + 1}: {characteristic:.4f} - {string}")
        return permuted_dataset, sorted_characteristics



if __name__ == "__main__":
    # Seed the global torch RNG; EntropixModel re-seeds with the same value.
    seed = 1337
    torch.manual_seed(seed=seed)

    # Gated checkpoint: needs an authenticated Hugging Face session.
    # Ungated alternative (no token needed):
    # model_id = 'HuggingFaceTB/SmolLM-360M-Instruct'
    model_id = "meta-llama/Llama-3.2-1B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    entropix_model = EntropixModel(model=model, tokenizer=tokenizer, seed=seed)

ModuleNotFoundError: No module named 'transformers.utils'

## Run

In [None]:
# Example inputs: a deliberately diverse probe set (natural languages in many
# scripts, source code, shell commands, logs, emojis) so entropy/varentropy can
# be compared across very different kinds of text.
input_strings = [
    "The quick brown fox jumps over the lazy dog.",  # Classic pangram
    "In quantum mechanics, particles can exist in multiple states simultaneously.",  # Scientific
    "她站在窗前，望着远方的山峰。",  # Chinese ("She stood at the window, gazing at the distant peaks.")
    "To be, or not to be, that is the question.",  # Literary/Shakespeare
    "The cryptocurrency market experienced significant volatility today.",  # Financial news
    "Je pense, donc je suis.",  # French philosophy (Descartes)
    "🌟 Dancing under the moonlight, spirits high and hearts light. 🌙",  # Emojis and poetic
    "SELECT * FROM users WHERE age > 18;",  # SQL code
    "The neural network achieved 98.5% accuracy on the test dataset.",  # AI/ML
    "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",  # Latin placeholder
    "Breaking: Major breakthrough in fusion energy announced today!",  # News headline
    "def fibonacci(n): return n if n <= 1 else fibonacci(n-1) + fibonacci(n-2)",  # Python code
    "Step 1: Preheat oven to 350°F. Step 2: Mix ingredients thoroughly.",  # Recipe instructions
    "Once upon a time, in a galaxy far, far away...",  # Story opening
    "Error 404: Page not found. Please check the URL and try again.",  # Technical error
    "Climate change threatens biodiversity in coral reef ecosystems.",  # Environmental
    "おはようございます、今日はいい天気ですね。",  # Japanese (Good morning, nice weather today)
    "1234567890 !@#$%^&*()_+ <>?:\"{}|",  # Numbers and special characters
    "URGENT: Meeting rescheduled to 3PM EST - All hands required",  # Business communication
    "The composition of Bach's fugues demonstrates mathematical precision.",  # Music analysis
    "Das Leben ist wie ein Fahrrad. Man muss sich vorwärts bewegen.",  # German (Einstein quote)
    "for i in range(len(array)): if array[i] > max_val: max_val = array[i]",  # More Python code
    "CREATE TABLE employees (id INT PRIMARY KEY, name VARCHAR(255));",  # SQL DDL
    "La vita è bella quando si vive con passione.",  # Italian (Life is beautiful when lived with passion)
    "RT @SpaceX: Successful launch of Starship prototype #42! 🚀",  # Social media
    "В тихом омуте черти водятся.",  # Russian proverb (lit. "devils live in still waters")
    "async function fetchData() { const response = await fetch(url); }",  # JavaScript async
    "🎮 Level Up! You've earned 1000 XP and unlocked new achievements! 🏆",  # Gaming with emojis
    "<!DOCTYPE html><html><head><title>Hello World</title></head></html>",  # HTML
    "Hola mundo, ¿cómo estás hoy?",  # Spanish greeting
    "import numpy as np; X = np.array([[1, 2], [3, 4]])",  # Scientific Python
    "Breaking News: Artificial Intelligence Achieves New Milestone in Protein Folding",  # Science news
    "public class HelloWorld { public static void main(String[] args) {} }",  # Java
    "The mitochondria is the powerhouse of the cell.",  # Biology
    "git commit -m \"Fix: resolve memory leak in main loop\"",  # Git command
    "अतिथि देवो भव:",  # Sanskrit (Guest is God)
    "try { throw new Error('Test'); } catch (e) { console.log(e); }",  # JavaScript error handling
    "Dans les champs de l'observation, le hasard ne favorise que les esprits préparés.",  # French (Pasteur)
    "docker run -d -p 80:80 nginx:latest",  # Docker command
    "While(true) { System.out.println(\"Hello, World!\"); }",  # Java-like infinite loop (capital "While" is not valid Java)
    "kubectl get pods -n kubernetes-dashboard",  # Kubernetes command
    "Χαίρετε! Πώς είστε σήμερα;",  # Greek greeting ("Hello! How are you today?")
    "const handleSubmit = (e) => { e.preventDefault(); setState(newValue); };",  # React code
    "مرحبا بالعالم",  # Arabic (Hello World)
    "SELECT COUNT(*) OVER (PARTITION BY department) FROM employees;",  # Advanced SQL
    "pip install tensorflow==2.8.0 torch==2.0.0 transformers==4.28.0",  # Package installation
    "한글은 세상에서 가장 과학적인 글자입니다.",  # Korean (Hangul is the most scientific writing system)
    "{ \"name\": \"John\", \"age\": 30, \"city\": \"New York\" }",  # JSON data
    "CRITICAL: Memory usage exceeded 90% threshold at 02:45:30 UTC",  # System log
    "@media (max-width: 768px) { .container { flex-direction: column; } }",  # CSS media query
    "Fibonacci sequence: 0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144...",  # Mathematical sequence
    "$ curl -X POST https://api.example.com/v1/data -H \"Content-Type: application/json\"",  # CURL command
    "WARNING: Certificate expires in 7 days. Please renew SSL certificate.",  # Security warning
    "sudo apt-get update && sudo apt-get upgrade -y",  # Linux command
    "print(f\"Current temperature: {temp:.2f}°C at {time:%H:%M:%S}\")",  # Python f-string
    "Революция в квантовых вычислениях: создан 1000-кубитный процессор",  # Russian tech news ("Quantum computing revolution: 1000-qubit processor built")
    "interface User { id: string; name: string; age: number; }",  # TypeScript interface
    "O Romeo, Romeo! wherefore art thou Romeo?",  # Shakespeare quote
    "Exception in thread \"main\" java.lang.NullPointerException at Main.java:42",  # Java error
    "今日は富士山に登りました。頂上からの景色は素晴らしかったです。"  # Japanese ("I climbed Mt. Fuji today; the view from the summit was wonderful.")
]

# Both analyses compute entropy and varentropy and return dicts; they differ
# only in "mechanism" ("per_token" vs "per_string").
# "output_format" may also be "tensor" or "list".
_analysis_defaults = {
    "compute_entropy": True,
    "compute_varentropy": True,
    "output_format": "dict",
}
config_per_token = {"mechanism": "per_token", **_analysis_defaults}
config_per_string = {"mechanism": "per_string", **_analysis_defaults}

# Characterize every input, first per token, then per string
# (adjust max_length as needed).
results_per_token, results_full = (
    entropix_model.entropy_characterize(
        input_strings=input_strings,
        config=analysis_config,
        max_length=512,
    )
    for analysis_config in (config_per_token, config_per_string)
)

# Plot the per-token view, then the full-string view.
fig_per_token = entropix_model.visualize_results(
    results=results_per_token,
    config=config_per_token,
    title="Entropy Analysis (Per Token)",
)
fig_per_token.show()

fig_full = entropix_model.visualize_results(
    results=results_full,
    config=config_per_string,
    title="Entropy Analysis (Full String)",
)
fig_full.show()

# Optional: save the figures as HTML files
# fig_per_token.write_html("entropy_analysis_per_token.html")
# fig_full.write_html("entropy_analysis_full.html")

# Reorder the dataset from highest to lowest average per-token entropy.
permuted_dataset, sorted_characteristics = entropix_model.permute_dataset(
    dataset=input_strings,
    config=config_per_token,
    sort_by="entropy_token_avg",
    descending=True,
    max_length=512,
)

print("\nPermuted Dataset Sorted by Average Entropy per Token:")
for rank, (text, avg_entropy) in enumerate(zip(permuted_dataset, sorted_characteristics), start=1):
    print(f"{rank}: {avg_entropy:.4f} - {text}")

# Optionally, re-plot the per-token statistics in the sorted order.
sorted_results = entropix_model.entropy_characterize(
    input_strings=permuted_dataset,
    config=config_per_token,
    max_length=512,
)
sorted_fig = entropix_model.visualize_results(
    results=sorted_results,
    config=config_per_token,
    title="Sorted Entropy Analysis (Per Token)",
)
sorted_fig.show()


In [None]:
# Re-display the full-string figure; as the cell's bare last expression it is rendered inline.
fig_full

In [None]:
# Re-display the per-token figure; as the cell's bare last expression it is rendered inline.
fig_per_token