In [1]:
# Celda 1 /kaggle/input/aaaaaad/openwebtext_2GB.txt

import numpy as np
import pandas as pd

import os, random, math, time, sys, collections, itertools, contextlib
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from collections import defaultdict, deque
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from datetime import datetime

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision("high")
AMP_DTYPE = torch.bfloat16
USE_AMP = True

sns.set_style("darkgrid", {"axes.facecolor": "#1c1c1c", "grid.color": "#444"})
plt.rcParams.update({
    "figure.facecolor": "#1c1c1c",
    "savefig.facecolor": "#1c1c1c",
    "text.color": "white",
    "axes.labelcolor": "white",
    "axes.edgecolor": "#aaa"
})


In [2]:

# 2
os.makedirs('/kaggle/working/plastic_results/checkpoints', exist_ok=True)
os.makedirs('/kaggle/working/plastic_results/csv_data', exist_ok=True)
os.makedirs('/kaggle/working/baseline_results/checkpoints', exist_ok=True)
os.makedirs('/kaggle/working/baseline_results/csv_data', exist_ok=True)
os.makedirs('/kaggle/working/plots', exist_ok=True)


In [3]:
# Celda 3
@torch.no_grad()
def generate_text(model, prompt, stoi, itos, device, length=100, temperature=0.8):
    model.eval()

    if hasattr(model, 'c') and hasattr(model.c, 'sequence_length'):
         seq_len = model.c.sequence_length
    elif hasattr(model, 'cfg') and hasattr(model.cfg, 'sequence_length'):
         seq_len = model.cfg.sequence_length
    else:
        seq_len = 128

    tokens = [stoi.get(c, 0) for c in prompt]
    generated_tokens = tokens.copy()

    h_state, c_state = None, None
    if hasattr(model, 'lstm'):
        num_layers = model.lstm.num_layers
        hidden_size = model.lstm.hidden_size
        h_state = torch.zeros(num_layers, 1, hidden_size).to(device)
        c_state = torch.zeros(num_layers, 1, hidden_size).to(device)

    if len(tokens) > 1:
        prompt_tensor = torch.tensor([tokens[:-1]]).to(device)
        if hasattr(model, 'mask'):
            _, (h_state, c_state) = model.lstm(model.embed(prompt_tensor), (h_state, c_state))
        else:
            _, (h_state, c_state) = model.lstm(model.embed(prompt_tensor))

    input_token = torch.tensor([[tokens[-1]]]).to(device)

    for _ in range(length):
        if hasattr(model, 'mask'):
            emb = model.embed(input_token)
            output, (h_state, c_state) = model.lstm(emb, (h_state, c_state))
            last_h = output[:, -1, :]

            W_eff = model.W.weight + model.c.hebb_eta * model.h_fast.to(model.W.weight.dtype) if not model.c.disable_hebbian else model.W.weight
            acts = torch.sigmoid(F.linear(last_h, W_eff * model.mask.unsqueeze(1), model.W.bias))
            logits = model.out(acts)
        else:
            emb = model.embed(input_token)
            output, (h_state, c_state) = model.lstm(emb, (h_state, c_state))
            hidden = torch.sigmoid(model.hidden_layer(output[:, -1, :]))
            logits = model.out(hidden)

        probs = F.softmax(logits.squeeze(0) / temperature, dim=-1)
        next_token_val = torch.multinomial(probs, 1).item()

        generated_tokens.append(next_token_val)
        input_token = torch.tensor([[next_token_val]]).to(device)

    generated_text = ''.join([itos.get(t, '') for t in generated_tokens])
    model.train()
    return generated_text, generated_tokens[len(tokens):]

class GenerationMetrics:
    @staticmethod
    def calculate_repetition_rate(tokens, n=4):
        if len(tokens) < n:
            return 0.0

        ngrams = [tuple(tokens[i:i+n]) for i in range(len(tokens) - n + 1)]
        if not ngrams:
            return 0.0

        unique_ngrams = len(set(ngrams))
        total_ngrams = len(ngrams)

        return 1.0 - (unique_ngrams / total_ngrams)

    @staticmethod
    def calculate_diversity(generated_texts_tokens):
        all_tokens = []
        for tokens in generated_texts_tokens:
            all_tokens.extend(tokens)

        if not all_tokens:
            return 0.0

        unique_tokens = len(set(all_tokens))
        total_tokens = len(all_tokens)

        return unique_tokens / total_tokens

def save_generation_samples(model, epoch, step, stoi, itos, device, cfg, model_name):
    samples = []

    prompts_to_use = [p for cat in cfg.test_prompts_categorized.values() for p in cat]
    selected_prompts = random.sample(prompts_to_use, k=min(len(prompts_to_use), cfg.generation_samples_per_epoch))

    for prompt in selected_prompts:
        full_text, generated_tokens = generate_text(model, prompt, stoi, itos, device,
                                                    cfg.generation_length, cfg.generation_temperature)

        generated_part = full_text[len(prompt):]
        repetition = GenerationMetrics.calculate_repetition_rate(generated_tokens)

        samples.append({
            'prompt': prompt,
            'generated_text': generated_part,
            'generated_tokens': generated_tokens,
            'repetition_rate': repetition
        })

    return samples

In [4]:
# Celda 4

class Config:
    sequence_length = 128
    batch_size = 64
    max_data_mb = 300
    learning_rate = 5e-4
    num_epochs = 6
    eval_every_steps = 200

    embedding_dim = 256
    lstm_hidden_size = 512
    num_lstm_layers = 2
    vocab_size = None

    max_neurons = 1024
    initial_active = 256
    target_population_ratio = 0.3
    plasticity_interval = 20
    maturation_time = 200
    survival_percentile = 70
    birth_percentile = 90
    tournament_interval = 200
    elite_clone_ratio = 0.3
    adaptive_thresholds = True
    threshold_history_size = 1000
    population_penalty = 0.001
    capacity_penalty_start = 0.25
    sparsity_lambda = 0.0005
    max_births_per_update = 5
    max_death_ratio = 0.1
    birth_cooldown = 50
    hebb_eta = 0.02
    hebb_decay = 0.97
    hebb_norm_clip = 10.0
    hebb_top_neurons_ratio = 0.5
    elite_protection_ratio = 0.15
    elite_update_interval = 1000
    maturity_threshold = 1000
    max_plasticity_reduction = 0.7
    checkpoint_window = 5
    ensemble_interval = 2000
    ensemble_blend_factor = 0.2
    cache_computations = True
    cache_size = 100

    test_prompts_categorized = {
        'code': [
            "def fibonacci(n):",
            "class NeuralNetwork:",
            "import numpy as np",
        ],
        'narrative': [
            "Once upon a time",
            "The scientist discovered",
            "In the year 2050,",
        ],
        'technical': [
            "Machine learning is",
            "The algorithm works by",
            "Neural networks can",
        ],
        'conversational': [
            "Hello, how are",
            "What do you think about",
            "Could you please explain",
        ]
    }

    generation_length = 80
    generation_temperature = 0.8
    generation_samples_per_epoch = 2

    phase_b_start = 6000
    phase_b_data_mb = 60
    replay_size = 5000

    data_path = '/kaggle/input/openwebtext-2gb/openwebtext_2GB.txt'
    plastic_checkpoints_dir = '/kaggle/working/plastic_results/checkpoints'
    plastic_csv_dir = '/kaggle/working/plastic_results/csv_data'
    baseline_save_dir = '/kaggle/working/baseline_results'
    plots_dir = '/kaggle/working/plots'

    disable_hebbian = False
    disable_death = False
    track_forgetting = True
    track_adaptation_speed = True
    track_generation_quality = True

cfg = Config()


In [5]:
# Celda 5
class CharDataset(Dataset):
    def __init__(self, text, seq_len, stoi=None, itos=None):
        if stoi is None:
            chars = sorted(set(text))
            self.stoi = {c: i for i, c in enumerate(chars)}
            self.itos = {i: c for c, i in self.stoi.items()}
        else:
            self.stoi, self.itos = stoi, itos
        self.data = [self.stoi.get(c, 0) for c in text]
        self.seq_len = seq_len

    def __len__(self):
        return max(1, (len(self.data) - 1) // self.seq_len)

    def __getitem__(self, idx):
        s = idx * self.seq_len
        chunk = self.data[s:s + self.seq_len + 1]
        if len(chunk) < self.seq_len + 1:
            chunk = chunk + [0] * (self.seq_len + 1 - len(chunk))
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        return x, y


def build_loaders(cfg, phase='A', existing_stoi=None, existing_itos=None):
    print(f'Loading text for Phase {phase}...')
    try:
        with open(cfg.data_path, 'r', encoding='utf-8', errors='ignore') as f:
            if phase == 'A':
                text = f.read(int(cfg.max_data_mb * 1024 * 1024))
            else:
                f.seek(int(cfg.max_data_mb * 1024 * 1024))
                text = f.read(int(cfg.phase_b_data_mb * 1024 * 1024))
        print(f"Loaded {len(text)/1e6:.1f}M characters for Phase {phase}")
    except FileNotFoundError:
        print("Dataset not found, using synthetic data")
        if phase == 'A':
            text = "The quick brown fox jumps over the lazy dog. " * 100000
        else:
            text = "In a hole in the ground there lived a hobbit. " * 100000

    full = CharDataset(text, cfg.sequence_length, existing_stoi, existing_itos)

    if phase == 'A':
        cfg.vocab_size = len(full.itos)
        print(f"Vocabulary size: {cfg.vocab_size}")

    n = len(full)
    train_size = int(0.9 * n)
    train, val = torch.utils.data.random_split(
        full, [train_size, n - train_size],
        generator=torch.Generator().manual_seed(SEED + (1 if phase == 'B' else 0))
    )

    num_workers = min(4, (os.cpu_count() or 2))
    use_workers = num_workers > 0
    common_kwargs = dict(
        batch_size=cfg.batch_size,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
        persistent_workers=use_workers
    )
    if use_workers:
        common_kwargs['prefetch_factor'] = 2

    def mk(ds, shuf):
        return DataLoader(ds, shuffle=shuf, **common_kwargs)

    return mk(train, True), mk(val, False), full.stoi, full.itos, full


In [6]:
# 5.5
class CatastrophicForgettingTracker:
    def __init__(self, cfg):
        self.cfg = cfg
        self.phase_checkpoints = {}
        self.phase_performance = defaultdict(dict)
        self.retention_scores = []

        # Prompts específicos por fase
        self.phase_specific_prompts = {
            'A': {
                'code': ["def calculate_mean(arr):", "import pandas as pd", "class DataProcessor:"],
                'narrative': ["The old man sat", "In the beginning", "She remembered when"],
                'technical': ["The standard deviation", "Linear regression assumes", "The p-value indicates"]
            },
            'B': {
                'code': ["async function fetchData", "const reducer = (state", "interface UserProps"],
                'narrative': ["The spaceship landed", "In the future", "The robot said"],
                'technical': ["Quantum computing uses", "The blockchain verifies", "Neural architecture search"]
            }
        }

        self.forgetting_history = []
        self.empirical_metrics_history = []
        self.phase_a_activations = {}
        self.initial_phase_a_ppl = None

    def save_phase_checkpoint(self, model, phase, step, device):
        checkpoint_data = {
            'step': step,
            'phase': phase,
            'active_neurons': int(model.mask.sum().item()),
            'phase_neurons': {
                'A': list(model.phase_a_neurons),
                'B': list(model.phase_b_neurons)
            }
        }
        self.phase_checkpoints[phase] = checkpoint_data

    @torch.no_grad()
    def evaluate_phase_specific_performance(self, model, stoi, itos, device, current_phase, step):
        """
        Evalúa PPL por fase/categoría.

        - Soporta logits 2D (B,V) o 3D (B,T,V).
        - Usa .detach() al guardar activaciones.
        - Usa AMP (bf16 si disponible; fp16 en GPUs pre-Ada/Hopper) para acelerar.
        """
        model.eval()
        results = {}
        activations_by_phase = {'A': [], 'B': []}

        use_cuda = torch.cuda.is_available() and device.type == 'cuda'
        amp_dtype = torch.bfloat16
        if use_cuda:
            major, _ = torch.cuda.get_device_capability(device)
            if major < 8:  # <= Ampere: preferir fp16
                amp_dtype = torch.float16

        for phase, prompt_categories in self.phase_specific_prompts.items():
            phase_results = {}
            all_perplexities = []
            phase_activations = []

            for category, prompts in prompt_categories.items():
                category_perplexities = []

                for prompt in prompts:
                    # Generación breve para calentar estado
                    _full_text, _generated_tokens = generate_text(
                        model, prompt, stoi, itos, device,
                        length=50, temperature=0.8
                    )

                    tokens = [stoi.get(c, 0) for c in prompt]
                    if len(tokens) > 1:
                        input_tensor = torch.tensor([tokens[:-1]], device=device)
                        target_tensor = torch.tensor([tokens[1:]], device=device)

                        if use_cuda:
                            with torch.cuda.amp.autocast(dtype=amp_dtype):
                                if hasattr(model, 'mask'):
                                    logits, analysis = model(input_tensor, return_analysis=True)
                                    if analysis is not None and 'activations' in analysis:
                                        phase_activations.append(
                                            analysis['activations'].detach().float().cpu().numpy()
                                        )
                                else:
                                    logits = model(input_tensor)
                                    analysis = None
                        else:
                            if hasattr(model, 'mask'):
                                logits, analysis = model(input_tensor, return_analysis=True)
                                if analysis is not None and 'activations' in analysis:
                                    phase_activations.append(
                                        analysis['activations'].detach().float().cpu().numpy()
                                    )
                            else:
                                logits = model(input_tensor)
                                analysis = None

                        # Manejo robusto de forma de logits
                        if logits.dim() == 3:
                            logits_last = logits[:, -1, :]
                        else:
                            logits_last = logits

                        loss = F.cross_entropy(logits_last, target_tensor[:, -1])
                        perplexity = math.exp(loss.item()) if loss.item() < 700 else float('inf')

                        category_perplexities.append(perplexity)
                        all_perplexities.append(perplexity)

                phase_results[category] = np.mean(category_perplexities) if category_perplexities else float('inf')

            results[phase] = {
                'category_scores': phase_results,
                'mean_perplexity': np.mean(all_perplexities) if all_perplexities else float('inf')
            }

            if phase_activations:
                activations_by_phase[phase] = np.concatenate(phase_activations, axis=0)

        # Registrar activaciones iniciales de fase A
        if current_phase == 'A' and self.initial_phase_a_ppl is None:
            self.initial_phase_a_ppl = results['A']['mean_perplexity']
            self.phase_a_activations['initial'] = activations_by_phase['A']
        elif current_phase == 'B':
            self.phase_a_activations['current'] = activations_by_phase['A']

        # Calcular retención cuando estamos en B
        retention_score = None
        if current_phase == 'B' and 'A' in self.phase_checkpoints:
            if 'A' in results and results['A']['mean_perplexity'] < float('inf'):
                baseline_ppl = self.phase_performance.get('A', {}).get('final_perplexity', 10.0)
                current_ppl = results['A']['mean_perplexity']
                retention_score = 1.0 - min(1.0, (current_ppl - baseline_ppl) / max(baseline_ppl, 1e-9))

        self.phase_performance[current_phase][step] = results

        record = {
            'step': step,
            'current_phase': current_phase,
            'phase_a_ppl': results['A']['mean_perplexity'],
            'phase_b_ppl': results['B']['mean_perplexity'] if 'B' in results else None,
            'retention_score': retention_score
        }

        for phase in ['A', 'B']:
            if phase in results:
                for category, score in results[phase]['category_scores'].items():
                    record[f'{phase}_{category}_ppl'] = score

        self.forgetting_history.append(record)

        model.train()
        return results, retention_score, activations_by_phase

    def calculate_empirical_metrics(self, model, current_phase, step):
        metrics = {}

        if self.initial_phase_a_ppl and current_phase == 'B':
            current_phase_a_ppl = self.phase_performance.get('B', {}).get(step, {}).get('A', {}).get('mean_perplexity', float('inf'))
            if current_phase_a_ppl < float('inf'):
                metrics['retencion_fase_a'] = (self.initial_phase_a_ppl / current_phase_a_ppl) * 100
            else:
                metrics['retencion_fase_a'] = 0.0
        else:
            metrics['retencion_fase_a'] = 100.0

        if 'initial' in self.phase_a_activations and 'current' in self.phase_a_activations:
            initial_acts = self.phase_a_activations['initial'].flatten()
            current_acts = self.phase_a_activations['current'].flatten()

            if len(initial_acts) == len(current_acts):
                correlation = np.corrcoef(initial_acts, current_acts)[0, 1]
                metrics['interferencia_B'] = 1.0 - abs(correlation)
            else:
                metrics['interferencia_B'] = 0.0
        else:
            metrics['interferencia_B'] = 0.0

        total_neurons = model.c.max_neurons
        active_neurons = int(model.mask.sum().item())
        metrics['costo_oportunidad'] = (total_neurons - active_neurons) / max(1, total_neurons)

        self.empirical_metrics_history.append({
            'step': step,
            'phase': current_phase,
            **metrics
        })

        return metrics

    def save_forgetting_metrics_to_csv(self, output_dir):
        if self.forgetting_history:
            df = pd.DataFrame(self.forgetting_history)
            df.to_csv(f"{output_dir}/catastrophic_forgetting_metrics.csv", index=False)

        if self.empirical_metrics_history:
            df_empirical = pd.DataFrame(self.empirical_metrics_history)
            df_empirical.to_csv(f"{output_dir}/empirical_metrics.csv", index=False)


In [7]:
# 6
class ReplayBuffer:
    def __init__(self, max_batches):
        self.buf = deque(maxlen=max_batches)
        self.losses = deque(maxlen=max_batches)
        self.phase_labels = deque(maxlen=max_batches)

    def push(self, x, y, loss=None, phase='A'):
        self.buf.append((x.cpu(), y.cpu()))
        self.losses.append(loss if loss is not None else 1.0)
        self.phase_labels.append(phase)

    def sample(self, n=1):
        if not self.buf:
            return None

        # Priorizar samples de fase A y con mayor pérdida
        if len(self.buf) > n:
            # Crear pesos de prioridad
            weights = []
            for i, (loss, phase) in enumerate(zip(self.losses, self.phase_labels)):
                # Dar más peso a fase A y a pérdidas altas
                weight = loss if loss is not None else 1.0
                if phase == 'A':
                    weight *= 2.0  # Duplicar peso para fase A
                weights.append(weight)

            # Normalizar pesos
            total_weight = sum(weights)
            if total_weight > 0:
                weights = [w/total_weight for w in weights]
            else:
                weights = [1.0/len(self.buf)] * len(self.buf)

            # Muestrear con pesos
            indices = np.random.choice(len(self.buf), size=n, p=weights)
            samples = [self.buf[i] for i in indices]

            if n == 1:
                return samples[0]
            return samples
        else:
            # Si tenemos menos samples que los pedidos, devolver todos
            if n == 1:
                return random.choice(self.buf)
            return list(self.buf)

In [8]:
# 7
class CompetitiveThresholds:
    """Sistema de umbrales competitivos basados en percentiles"""
    def __init__(self, survival_percentile=70, birth_percentile=90, history_size=1000):
        self.survival_percentile = survival_percentile
        self.birth_percentile = birth_percentile
        self.contribution_history = deque(maxlen=history_size)
        self.loss_history = deque(maxlen=history_size)
        self.uniqueness_history = deque(maxlen=history_size)

    def update(self, contributions, uniqueness_scores, loss):
        """Actualizar historial y calcular nuevos umbrales dinámicos"""
        # Filtrar solo neuronas activas
        active_contribs = contributions[contributions > 0].cpu().numpy()
        active_unique = uniqueness_scores[uniqueness_scores > 0].cpu().numpy()

        if len(active_contribs) > 0:
            self.contribution_history.extend(active_contribs.tolist())
            self.uniqueness_history.extend(active_unique.tolist())
        self.loss_history.append(loss)

        # Calcular umbrales basados en percentiles
        if len(self.contribution_history) > 100:
            # Umbral de muerte: percentil para supervivencia
            contrib_threshold = np.percentile(self.contribution_history, 100 - self.survival_percentile)
            unique_threshold = np.percentile(self.uniqueness_history, 100 - self.survival_percentile)

            # Combinar ambos criterios
            death_threshold = {
                'contribution': contrib_threshold,
                'uniqueness': unique_threshold
            }

            # Umbral de nacimiento basado en pérdida histórica
            birth_threshold = np.percentile(self.loss_history, self.birth_percentile)
        else:
            # Valores por defecto mientras se acumula historial
            death_threshold = {
                'contribution': 0.01,
                'uniqueness': 0.3
            }
            birth_threshold = 3.0

        return death_threshold, birth_threshold


class CachedComputations:
    """Cache para cálculos costosos como correlaciones"""
    def __init__(self, cache_size=100):
        self.cache = {}
        self.cache_size = cache_size
        self.hits = 0
        self.misses = 0

    def get_or_compute(self, key, compute_fn):
        """Obtener del cache o computar"""
        if key in self.cache:
            self.hits += 1
            return self.cache[key]

        self.misses += 1
        result = compute_fn()

        # LRU simple: eliminar el más antiguo si excedemos tamaño
        if len(self.cache) >= self.cache_size:
            oldest_key = next(iter(self.cache))
            del self.cache[oldest_key]

        self.cache[key] = result
        return result

    def clear(self):
        """Limpiar cache"""
        self.cache.clear()

    def get_stats(self):
        """Obtener estadísticas de cache"""
        total = self.hits + self.misses
        hit_rate = self.hits / total if total > 0 else 0
        return {
            'hits': self.hits,
            'misses': self.misses,
            'hit_rate': hit_rate,
            'size': len(self.cache)
        }

In [9]:
# Celda 8

class PlasticLSTM(nn.Module):
    def __init__(self, c):
        super().__init__()
        self.c = c

        # Core modules
        self.embed = nn.Embedding(c.vocab_size, c.embedding_dim)
        self.lstm = nn.LSTM(
            c.embedding_dim, c.lstm_hidden_size, c.num_lstm_layers,
            batch_first=True, dropout=0.1
        )
        self.W = nn.Linear(c.lstm_hidden_size, c.max_neurons)
        self.out = nn.Linear(c.max_neurons, c.vocab_size)

        # Normalization on neuron pre-activations (stable with AMP)
        self.neuron_norm = nn.LayerNorm(c.max_neurons)

        # Persistent state buffers
        self.register_buffer('mask', torch.zeros(c.max_neurons))
        self.mask[:c.initial_active] = 1

        self.register_buffer('age', torch.zeros(c.max_neurons))
        self.register_buffer('contrib', torch.zeros(c.max_neurons))

        # Fast Hebbian weights kept in lower precision to save memory/bandwidth
        self.register_buffer(
            'h_fast',
            torch.zeros(c.max_neurons, c.lstm_hidden_size, dtype=AMP_DTYPE)
        )

        self.register_buffer('uniqueness_score', torch.ones(c.max_neurons))
        self.register_buffer('error_reduction', torch.zeros(c.max_neurons))
        self.register_buffer('phase_specialization', torch.zeros(c.max_neurons))

        self.register_buffer('correlation_matrix', torch.eye(c.max_neurons))

        self.register_buffer('competitive_score', torch.zeros(c.max_neurons))

        # Threshold manager / cache
        self.threshold_manager = CompetitiveThresholds(
            c.survival_percentile,
            c.birth_percentile,
            c.threshold_history_size
        )
        if c.cache_computations:
            self.computation_cache = CachedComputations(c.cache_size)
        else:
            self.computation_cache = None

        # Book-keeping
        self.last_birth_step = 0
        self.last_tournament_step = 0
        self.hard_patterns_bank = deque(maxlen=1000)
        self.pattern_errors = deque(maxlen=1000)
        self.life_events = []
        self.step = 0
        self._cache = (None, None, None)
        self._capacity_penalty = 0.0

        # Phases
        self.current_phase = 'A'
        self.phase_a_neurons = set(range(c.initial_active))
        self.phase_b_neurons = set()
        self.generation_history = defaultdict(list)

        self._current_thresholds = {
            'birth': 3.2,
            'death': {'contribution': 0.01, 'uniqueness': 0.3}
        }

        # Inheritance tracking
        self.register_buffer('inheritance_count', torch.zeros(c.max_neurons))
        self.register_buffer('knowledge_sources', torch.zeros(c.max_neurons))
        self.inheritance_history = []

        # Population/phase control
        self.population_flex_factor = 1.0
        self.phase_transition_buffer = 0

        # Elite / anchors
        self.elite_neurons = set()
        self.elite_protection_ratio = c.elite_protection_ratio
        self.phase_elite_history = {'A': [], 'B': []}
        self.phase_a_anchors = set()
        self.dynamic_anchor_ratio = 0.2
        self.min_anchor_ratio = 0.1
        self.max_anchor_ratio = 0.5

        # Plasticity controls
        self.register_buffer('plasticity_levels', torch.ones(c.max_neurons))
        self.register_buffer('phase_importance', torch.zeros(c.max_neurons))
        self.register_buffer('cross_phase_utility', torch.zeros(c.max_neurons))
        self.register_buffer('plasticity_factor', torch.ones(c.max_neurons))
        self.maturity_threshold = c.maturity_threshold

        # Regions (A/shared/B)
        self.phase_regions = {
            'A': (0, int(c.max_neurons * 0.6)),
            'shared': (int(c.max_neurons * 0.6), int(c.max_neurons * 0.8)),
            'B': (int(c.max_neurons * 0.8), c.max_neurons)
        }

        # Soft checkpoints
        self.soft_checkpoints = deque(maxlen=c.checkpoint_window)
        self.checkpoint_interval = 500

        # Vector phase flags
        self.register_buffer('phase_a_flag', torch.zeros(c.max_neurons, dtype=torch.bool))
        self.register_buffer('phase_b_flag', torch.zeros(c.max_neurons, dtype=torch.bool))
        self.phase_a_flag[:c.initial_active] = True

        # Init
        with torch.no_grad():
            nn.init.xavier_uniform_(self.W.weight)
            nn.init.zeros_(self.W.bias)
            nn.init.xavier_uniform_(self.out.weight)

    # --- MÉTODOS AÑADIDOS/CORREGIDOS ---
    def set_phase(self, phase):
        self.current_phase = phase
        if phase == 'B':
            self.phase_transition_buffer = 1.0 # Activa el buffer de transición
        print(f"Model phase set to: {self.current_phase}")

    def update_plasticity_levels(self, phase_a_performance, phase_b_performance):
        if self.current_phase != 'B':
            return

        # Identificar neuronas de fase A que son importantes (alto contrib)
        phase_a_indices = torch.tensor(list(self.phase_a_neurons), device=self.mask.device, dtype=torch.long)
        if len(phase_a_indices) == 0:
            return

        contribs_a = self.contrib[phase_a_indices]
        important_threshold = torch.quantile(contribs_a, 0.7) # Proteger el 30% superior
        important_neurons_mask = contribs_a >= important_threshold
        important_indices = phase_a_indices[important_neurons_mask]

        # Reducir la plasticidad de las neuronas importantes de la fase A para protegerlas
        reduction_factor = 0.05
        self.plasticity_levels[important_indices] *= reduction_factor
        self.plasticity_levels.clamp_(min=0.01) # Mínimo de plasticidad

    # --- FIN DE MÉTODOS AÑADIDOS/CORREGIDOS ---

    # Gradient hooks modulated by plasticity
    def register_gradient_hooks(self):
        def mod_grad_W(grad):
            if grad is None:
                return grad
            mask = self.plasticity_levels[:grad.shape[0]].unsqueeze(1).to(grad.device)
            return grad * mask

        def mod_grad_bias(grad):
            if grad is None:
                return grad
            return grad * self.plasticity_levels.to(grad.device)

        def mod_grad_out(grad):
            if grad is None:
                return grad
            mask = self.plasticity_levels[:grad.shape[1]].unsqueeze(0).to(grad.device)
            return grad * mask

        self.W.weight.register_hook(mod_grad_W)
        self.W.bias.register_hook(mod_grad_bias)
        self.out.weight.register_hook(mod_grad_out)

    # Similarity utilities (compute in float32 for numerical stability)
    def compute_neuron_similarity(self, idx1, idx2):
        w1 = self.W.weight[idx1].float()
        w2 = self.W.weight[idx2].float()
        weight_sim = F.cosine_similarity(w1.unsqueeze(0), w2.unsqueeze(0)).item()

        out1 = self.out.weight[:, idx1].float()
        out2 = self.out.weight[:, idx2].float()
        out_sim = F.cosine_similarity(out1.unsqueeze(0), out2.unsqueeze(0)).item()

        corr_sim = self.correlation_matrix[idx1, idx2].item()
        total_similarity = 0.4 * weight_sim + 0.3 * out_sim + 0.3 * corr_sim
        return total_similarity

    def find_inheritance_targets(self, dying_idx, k=10):
        active_mask = (self.mask > 0)
        active_indices = active_mask.nonzero(as_tuple=True)[0]
        if active_indices.numel() == 0:
            return []

        idx = dying_idx if torch.is_tensor(dying_idx) else torch.tensor(dying_idx, device=self.W.weight.device)
        dying_in_A = bool(self.phase_a_flag[idx])

        w_all = self.W.weight[active_indices].float()
        w_d = self.W.weight[idx].float().unsqueeze(0)
        sim_w = F.cosine_similarity(w_all, w_d, dim=1)

        o_all = self.out.weight[:, active_indices].T.float()
        o_d = self.out.weight[:, idx].float().unsqueeze(0)
        sim_o = F.cosine_similarity(o_all, o_d, dim=1)

        corr_sim = self.correlation_matrix[idx, active_indices]

        total_sim = 0.4 * sim_w + 0.3 * sim_o + 0.3 * corr_sim

        plasticity_bonus = 1.0 - self.plasticity_levels[active_indices] * 0.5
        total_sim = total_sim * plasticity_bonus

        if dying_in_A:
            same_phase_bonus = torch.where(self.phase_a_flag[active_indices], 1.5, 1.0)
            total_sim = total_sim * same_phase_bonus
        else:
            same_phase_bonus = torch.where(self.phase_b_flag[active_indices], 1.5, 1.0)
            total_sim = total_sim * same_phase_bonus

        inheritance_penalty = 1.0 - (self.inheritance_count[active_indices] / 10).tanh()
        adjusted_sim = total_sim * inheritance_penalty

        self_mask = active_indices != idx
        if self_mask.any():
            adjusted_sim = adjusted_sim[self_mask]
            total_sim_f = total_sim[self_mask]
            candidates = active_indices[self_mask]
        else:
            total_sim_f = total_sim
            candidates = active_indices

        if candidates.numel() == 0:
            return []

        topk = min(k, candidates.numel())
        vals, inds = adjusted_sim.topk(topk)
        chosen = candidates[inds]

        out_list = []
        for c, v, o in zip(chosen, vals, total_sim_f[inds]):
            out_list.append((c, v.item(), o.item()))
        return out_list

    def inherit_knowledge(self, dying_idx, inheritors):
        if not inheritors:
            return

        dying_weights = self.W.weight.data[dying_idx].clone()
        dying_out_weights = self.out.weight.data[:, dying_idx].clone()
        dying_bias = self.W.bias.data[dying_idx].clone()
        dying_contrib = self.contrib[dying_idx].clone()
        dying_hebbian = self.h_fast[dying_idx].clone()

        if dying_idx.item() in self.phase_a_neurons:
            low_plasticity_indices = (self.plasticity_levels < 0.3).nonzero(as_tuple=True)[0]
            if len(low_plasticity_indices) > 0:
                anchor_idx = low_plasticity_indices[torch.randint(len(low_plasticity_indices), (1,))]
                inheritors = [(anchor_idx, 1.0, 1.0)] + inheritors[:9]

        total_similarity = sum(sim for _, _, sim in inheritors)

        for target_idx, adjusted_sim, original_sim in inheritors:
            inheritance_ratio = original_sim / total_similarity if total_similarity > 0 else 1.0 / len(inheritors)
            plasticity = self.plasticity_levels[target_idx].item()
            merge_factor = 0.8 * inheritance_ratio * plasticity

            self.W.weight.data[target_idx] = (1 - merge_factor) * self.W.weight.data[target_idx] + merge_factor * dying_weights
            self.out.weight.data[:, target_idx] = (1 - merge_factor) * self.out.weight.data[:, target_idx] + merge_factor * dying_out_weights
            self.W.bias.data[target_idx] = (1 - merge_factor) * self.W.bias.data[target_idx] + merge_factor * dying_bias

            contrib_inheritance = 0.3 * inheritance_ratio * dying_contrib
            self.contrib[target_idx] += contrib_inheritance

            self.h_fast[target_idx] = (1 - merge_factor * 0.5) * self.h_fast[target_idx] + (merge_factor * 0.5) * dying_hebbian.to(self.h_fast.dtype)

            self.inheritance_count[target_idx] += 1
            self.knowledge_sources[target_idx] += inheritance_ratio

            self.inheritance_history.append({
                'step': self.step,
                'dying_neuron': dying_idx.item() if torch.is_tensor(dying_idx) else dying_idx,
                'inheritor': target_idx.item() if torch.is_tensor(target_idx) else target_idx,
                'inheritance_ratio': inheritance_ratio,
                'merge_factor': merge_factor,
                'phase': self.current_phase,
                'to_low_plasticity': plasticity < 0.3
            })

    def get_adaptive_population_target(self, current_loss, phase):
        base_target = self.c.max_neurons * self.c.target_population_ratio
        if phase == 'B' and self.phase_transition_buffer > 0:
            target = base_target * (1.0 + 0.5 * self.phase_transition_buffer)
            self.phase_transition_buffer *= 0.95
        else:
            if hasattr(self, 'recent_loss_trend'):
                if self.recent_loss_trend > 0:
                    self.population_flex_factor = min(1.3, self.population_flex_factor * 1.02)
                else:
                    self.population_flex_factor = max(0.8, self.population_flex_factor * 0.98)
            target = base_target * self.population_flex_factor
        return int(target)

    @property
    def current_birth_threshold(self):
        return self._current_thresholds['birth']

    @property
    def current_death_threshold(self):
        death_thresh = self._current_thresholds['death']
        if isinstance(death_thresh, dict):
            return (death_thresh.get('contribution', 0.01) +
                    death_thresh.get('uniqueness', 0.3)) / 2
        return death_thresh

    def forward(self, x, return_analysis=False):
        self.step += 1

        h, _ = self.lstm(self.embed(x))
        last = h[:, -1, :]

        # Effective weights (Hebbian add-on)
        if self.c.disable_hebbian:
            W_eff = self.W.weight
        else:
            h_fast_normed = self.h_fast.to(self.W.weight.dtype)
            confidence = torch.sigmoid(self.error_reduction).unsqueeze(1)
            h_fast_gated = h_fast_normed * confidence
            W_eff = self.W.weight + self.c.hebb_eta * h_fast_gated

        pre_acts = F.linear(last, W_eff * self.mask.unsqueeze(1), self.W.bias)  # (B, N)
        pre_acts = self.neuron_norm(pre_acts)                                   # (B, N)
        acts = torch.sigmoid(pre_acts)                                          # (B, N)

        if self.training:
            current_pop = self.mask.sum()
            target_pop = self.get_adaptive_population_target(0, self.current_phase)
            self._capacity_penalty = self.compute_capacity_penalty(current_pop, target_pop)
            if self.c.sparsity_lambda > 0:
                sparsity_penalty = self.c.sparsity_lambda * acts.abs().mean()
                self._capacity_penalty += sparsity_penalty

        logits = self.out(acts)
        self._cache = (acts.detach(), last.detach(), None)

        if return_analysis:
            return logits, {
                'activations': acts,
                'hidden': last,
                'uniqueness': self.uniqueness_score,
                'phase_spec': self.phase_specialization,
                'competitive_score': self.competitive_score,
                'plasticity_levels': self.plasticity_levels
            }

        return logits

    def compute_capacity_penalty(self, current_population, target_population):
        if current_population <= target_population * self.c.capacity_penalty_start:
            return 0.0
        overpopulation_ratio = (current_population - target_population) / target_population
        penalty = self.c.population_penalty * (overpopulation_ratio.exp() - 1)
        return penalty.item() if torch.is_tensor(penalty) else penalty

In [10]:
# Celda 9

@torch.no_grad()
def compute_neuron_uniqueness(self, acts):
    """
    Calcula la unicidad de las neuronas basándose en la correlación de sus activaciones.
    La unicidad se define como 1 menos la correlación media absoluta con otras neuronas.
    - Se ejecuta en FP32 para estabilidad numérica.
    - Usa una máscara para considerar solo las neuronas activas.
    - Actualiza el buffer `self.uniqueness_score`.
    """
    active_mask = self.mask > 0
    active_indices = active_mask.nonzero(as_tuple=True)[0]

    if active_indices.numel() < 2:
        self.uniqueness_score.fill_(1.0)
        return

    active_acts = acts[:, active_indices].float()

    if torch.std(active_acts) < 1e-6:
        self.uniqueness_score[active_indices] = 0.0
        return

    correlation_matrix = torch.corrcoef(active_acts.T)
    correlation_matrix.nan_to_num_(0.0)

    abs_corr = torch.abs(correlation_matrix)
    avg_corr = (abs_corr.sum(dim=1) - 1) / max(1, (active_indices.numel() - 1))
    uniqueness = 1.0 - avg_corr

    decay = 0.95
    current_uniqueness = self.uniqueness_score[active_indices]
    self.uniqueness_score[active_indices] = decay * current_uniqueness + (1 - decay) * uniqueness.to(self.uniqueness_score.dtype)

PlasticLSTM.compute_neuron_uniqueness = compute_neuron_uniqueness


@torch.no_grad()
def competitive_structural_update(self, loss, loss_per_sample=None):
    # Asegurar constantes AMP disponibles
    try:
        AMP_DTYPE
    except NameError:
        AMP_DTYPE = torch.float16

    # --- Edad y plasticidad (mantener en FP32 por estabilidad) ---
    self.age += self.mask
    maturity_factor = torch.sigmoid((self.age - self.maturity_threshold) / 500)
    # plasticity_levels es FP32; mantener control en FP32
    self.plasticity_factor = 1.0 - (maturity_factor * 0.7) * self.plasticity_levels

    # Actualizar élite/ancoras de forma ocasional
    if self.step % 1000 == 0:
        if hasattr(self, '_identify_elite_neurons'): self._identify_elite_neurons()
        self._graduate_low_performing_anchors()

    acts, hidden, _ = self._cache

    # --- Métricas de unicidad y reducción de error (FP32) ---
    if acts is not None:
        # compute_neuron_uniqueness usa corrcoef -> FP32
        self.compute_neuron_uniqueness(acts)

        if loss_per_sample is not None:
            # menor pérdida que la media se considera "buena"
            loss_mean = loss
            neuron_impact = acts * (loss_per_sample.unsqueeze(1) < loss_mean)
            self.error_reduction = 0.9 * self.error_reduction + 0.1 * neuron_impact.mean(0)

        plasticity_weight = 0.5 + 0.5 * self.plasticity_levels
        survival_bonus = torch.clamp(self.age / 5000, 0, 0.3) * (1.0 - self.plasticity_levels * 0.5)

        self.competitive_score = (
            0.35 * self.contrib +
            0.30 * self.uniqueness_score +
            0.25 * self.error_reduction +
            0.10 * torch.sigmoid(self.age / 1000)
        ) * self.mask * plasticity_weight + survival_bonus * self.mask

        self.contrib = 0.9 * self.contrib + 0.1 * self.competitive_score

    if hidden is not None and loss_per_sample is not None and hasattr(self, 'track_error_patterns'):
        self.track_error_patterns(hidden, loss_per_sample)

    # --- Población y umbrales adaptativos ---
    current_pop = int(self.mask.sum().item())
    target_pop = self.get_adaptive_population_target(loss, self.current_phase)

    death_thresh, birth_thresh = self.threshold_manager.update(
        self.contrib, self.uniqueness_score, float(loss)
    )

    # Amortiguar transición de fase
    if self.phase_transition_buffer > 0:
        if isinstance(death_thresh, dict):
            death_thresh = {
                k: (v * 1.5 if isinstance(v, (int, float)) else v)
                for k, v in death_thresh.items()
            }
        else:
            death_thresh = death_thresh * 1.5
        birth_thresh *= 0.8

    self._current_thresholds['birth'] = float(birth_thresh)
    self._current_thresholds['death'] = death_thresh

    births_this_update = 0
    deaths_this_update = 0

    # --- Torneo si hay sobrepoblación grande ---
    if (current_pop > target_pop * 1.2 and
        self.step - self.last_tournament_step >= self.c.tournament_interval):

        num_to_kill = max(
            current_pop - int(target_pop * 1.05),
            int(current_pop * self.c.max_death_ratio)
        )

        active_indices = (self.mask > 0).nonzero(as_tuple=True)[0]
        scores = self.competitive_score[active_indices].clone()

        if self.current_phase == 'B' and len(self.phase_a_neurons) > 0:
            for idx in active_indices:
                if idx.item() in self.phase_a_neurons:
                    idx_in_active = (active_indices == idx).nonzero(as_tuple=True)[0]
                    if len(idx_in_active) > 0:
                        scores[idx_in_active[0]] *= (2.0 - self.plasticity_levels[idx])

        _, sorted_indices = scores.sort(descending=True)
        losers = active_indices[sorted_indices[-num_to_kill:]]

        deaths_this_update = self._kill_neurons(losers, reason='tournament', inherit_knowledge=True)
        self.last_tournament_step = self.step

    # --- Selección blanda por umbrales ---
    elif not self.c.disable_death and current_pop > target_pop * 1.1:
        if isinstance(death_thresh, dict):
            below_contrib_thresh = self.contrib < float(death_thresh.get('contribution', 0.01))
            below_unique_thresh = self.uniqueness_score < float(death_thresh.get('uniqueness', 0.3))
        else:
            below_contrib_thresh = self.contrib < float(death_thresh)
            below_unique_thresh = self.uniqueness_score < 0.3

        is_mature = self.age > self.c.maturation_time
        plasticity_death_factor = self.plasticity_levels.clamp(min=0.3)

        elimination_score = (
            below_contrib_thresh.float() * 0.4 +
            below_unique_thresh.float() * 0.4 +
            is_mature.float() * 0.2
        ) * self.mask * plasticity_death_factor

        death_candidates = (elimination_score > 0.5).nonzero(as_tuple=True)[0]

        if len(death_candidates) > 0:
            max_deaths = min(
                len(death_candidates),
                int(current_pop * 0.05),
                current_pop - int(target_pop * 0.9)
            )

            if max_deaths > 0:
                scores = elimination_score[death_candidates]
                _, top_indices = scores.topk(min(max_deaths, len(scores)))
                death_idx = death_candidates[top_indices]

                deaths_this_update = self._kill_neurons(death_idx, reason='selection', inherit_knowledge=True)

    # --- Historial de pérdidas para tendencias ---
    if hasattr(self, 'loss_history'):
        self.loss_history.append(float(loss))
        if len(self.loss_history) > 20:
            recent_improvement = (self.loss_history[-20] - self.loss_history[-1]) / (self.loss_history[-20] + 1e-12)
            if recent_improvement < 0.001:
                birth_thresh *= 0.95
    else:
        self.loss_history = deque(maxlen=100)
        self.loss_history.append(float(loss))

    # --- Nacimientos controlados ---
    should_birth = (
        float(loss) > float(birth_thresh) and
        current_pop < int(target_pop * 1.3) and
        self.step - self.last_birth_step >= self.c.birth_cooldown
    )

    if should_birth:
        max_allowed_births = max(0, int(target_pop * 1.3) - current_pop)
        if max_allowed_births > 0:
            if float(loss) > float(birth_thresh):
                error_factor = min((float(loss) - float(birth_thresh)) / float(birth_thresh), 1.0)
            else:
                error_factor = 0.5
            num_births = max(1, int(self.c.max_births_per_update * error_factor))
            num_births = min(num_births, max_allowed_births)

            births_this_update = self._smart_birth(num_births, float(loss), float(birth_thresh), loss_per_sample)

            if births_this_update > 0:
                self.last_birth_step = self.step

    if hasattr(self, 'loss_history') and len(self.loss_history) > 10:
        self.recent_loss_trend = (self.loss_history[-1] - self.loss_history[-10]) / 10.0


def _graduate_low_performing_anchors(self):
    if self.current_phase != 'B':
        return

    low_plasticity_indices = (self.plasticity_levels < 0.3).nonzero(as_tuple=True)[0]

    for idx in low_plasticity_indices:
        if self.mask[idx] > 0 and idx.item() in self.phase_a_neurons:
            performance_percentile = (self.contrib[idx] > self.contrib[self.mask > 0]).float().mean()

            if performance_percentile < 0.2 and self.age[idx] > 5000:
                self.plasticity_levels[idx] = min(0.8, float(self.plasticity_levels[idx]) + 0.3)
                self.phase_a_anchors.discard(idx.item())

            elif performance_percentile > 0.8 and self.plasticity_levels[idx] > 0:
                self.plasticity_levels[idx] = max(0.0, float(self.plasticity_levels[idx]) - 0.1)


def _kill_neurons(self, indices, reason='selection', inherit_knowledge=True):
    if len(indices) == 0:
        return 0

    protected_indices = []
    killable_indices = []

    for idx in indices:
        idx_val = idx.item() if torch.is_tensor(idx) else idx
        protection_threshold = 0.8 if reason == 'tournament' else 0.3

        if (idx_val in self.elite_neurons and reason != 'tournament') or \
           (self.plasticity_levels[idx] < protection_threshold and self.contrib[idx] > self.contrib.mean()):
            protected_indices.append(idx_val)
        else:
            killable_indices.append(idx)

    indices = torch.tensor(killable_indices, device=self.mask.device) if killable_indices else torch.tensor([], dtype=torch.long, device=self.mask.device)

    if len(indices) == 0:
        return 0

    if inherit_knowledge:
        for idx in indices:
            if self.contrib[idx] > self.contrib.mean() * 0.5:
                inheritors = self.find_inheritance_targets(idx, k=5)
                if inheritors:
                    self.inherit_knowledge(idx, inheritors)

    # Reset de estado (FP32); pesos quedan en FP32, hebb en AMP_DTYPE
    self.mask[indices] = 0
    self.age[indices] = 0
    self.contrib[indices] = 0
    self.uniqueness_score[indices] = 1
    self.error_reduction[indices] = 0
    self.competitive_score[indices] = 0
    self.W.weight.data[indices] = 0
    self.W.bias.data[indices] = 0
    self.out.weight.data[:, indices] = 0
    self.h_fast[indices] = torch.zeros_like(self.h_fast[indices], dtype=self.h_fast.dtype)
    self.inheritance_count[indices] = 0
    self.knowledge_sources[indices] = 0
    self.plasticity_levels[indices] = 1.0

    # Flags de fase
    self.phase_a_flag[indices] = False
    self.phase_b_flag[indices] = False

    for idx in indices:
        self.life_events.append({
            'type': 'death',
            'step': self.step,
            'id': idx.item(),
            'phase': self.current_phase,
            'reason': reason,
            'final_score': float(self.competitive_score[idx]),
            'knowledge_inherited': bool(inherit_knowledge),
            'plasticity_level': float(self.plasticity_levels[idx])
        })

    for idx in indices:
        idx_val = idx.item()
        self.phase_a_neurons.discard(idx_val)
        self.phase_b_neurons.discard(idx_val)
        self.elite_neurons.discard(idx_val)
        self.phase_a_anchors.discard(idx_val)

    return len(indices)


def _smart_birth(self, num_births, loss, birth_thresh, loss_per_sample=None):
    if self.current_phase == 'A':
        valid_range = (self.phase_regions['A'][0], self.phase_regions['shared'][1])
    else:
        valid_range = (self.phase_regions['shared'][0], self.phase_regions['B'][1])

    free_indices = (self.mask == 0).nonzero(as_tuple=True)[0]
    free_indices = free_indices[(free_indices >= valid_range[0]) & (free_indices < valid_range[1])]

    if len(free_indices) == 0 or num_births == 0:
        return 0

    num_births = min(num_births, len(free_indices))
    birth_indices = free_indices[torch.randperm(len(free_indices), device=free_indices.device)[:num_births]]

    births = 0
    for idx in birth_indices:
        init_type = 'random'

        if len(self.hard_patterns_bank) > 10 and random.random() < 0.7 and len(list(self.pattern_errors)) > 100:
            pattern_idx = np.argmax(list(self.pattern_errors)[-100:])
            hard_pattern = list(self.hard_patterns_bank)[-100:][pattern_idx].to(self.W.weight.device)
            noise = torch.randn_like(hard_pattern, dtype=self.W.weight.dtype) * 0.1
            self.W.weight.data[idx] = hard_pattern.to(self.W.weight.dtype) + noise
            init_type = 'hard_pattern'

        elif self.mask.sum() > 10 and random.random() < self.c.elite_clone_ratio:
            top_k = min(10, int(self.mask.sum()))
            elite = self.competitive_score.topk(top_k).indices
            donor = elite[random.randint(0, len(elite) - 1)]
            self.W.weight.data[idx] = self.W.weight.data[donor] + torch.randn_like(self.W.weight.data[donor]) * 0.15
            self.out.weight.data[:, idx] = self.out.weight.data[:, donor] + torch.randn_like(self.out.weight.data[:, donor]) * 0.1
            init_type = 'elite_clone'
        else:
            nn.init.xavier_uniform_(self.W.weight.data[idx:idx+1])
            nn.init.normal_(self.out.weight.data[:, idx], 0, 0.02)
            init_type = 'random'

        # Estado inicial
        self.mask[idx] = 1
        self.age[idx] = 0
        self.contrib[idx] = 0.05
        self.uniqueness_score[idx] = 0.5
        self.error_reduction[idx] = 0
        self.competitive_score[idx] = 0.05
        self.h_fast[idx] = torch.zeros_like(self.h_fast[idx], dtype=self.h_fast.dtype)
        self.W.bias.data[idx] = 0
        self.inheritance_count[idx] = 0
        self.knowledge_sources[idx] = 0
        self.plasticity_factor[idx] = 1.0
        self.plasticity_levels[idx] = 1.0

        if self.current_phase == 'B':
            self.phase_specialization[idx] = 1
            self.phase_b_neurons.add(idx.item())
            self.phase_b_flag[idx] = True
            self.phase_a_flag[idx] = False
        else:
            self.phase_specialization[idx] = 0
            self.phase_a_neurons.add(idx.item())
            self.phase_a_flag[idx] = True
            self.phase_b_flag[idx] = False

        self.life_events.append({
            'type': 'birth',
            'step': self.step,
            'id': idx.item(),
            'loss': float(loss),
            'phase': self.current_phase,
            'init_type': init_type,
            'birth_threshold': float(birth_thresh),
            'initial_plasticity': 1.0
        })

        births += 1

    return births


# Enlazar métodos a la clase
PlasticLSTM.structural_update = competitive_structural_update
PlasticLSTM._graduate_low_performing_anchors = _graduate_low_performing_anchors
PlasticLSTM._kill_neurons = _kill_neurons
PlasticLSTM._smart_birth = _smart_birth

In [11]:
# 10
@torch.no_grad()
def analyze_competition_dynamics(self):
    active_mask = self.mask > 0
    active_indices = active_mask.nonzero(as_tuple=True)[0]
    if active_indices.numel() == 0:
        return {}

    active_scores = self.competitive_score[active_indices]
    active_contribs = self.contrib[active_indices]
    active_unique = self.uniqueness_score[active_indices]
    active_plasticity = self.plasticity_levels[active_indices]

    top_10_pct = int(len(active_indices) * 0.1)
    bottom_10_pct = top_10_pct
    if top_10_pct > 0:
        top_scores = active_scores.topk(top_10_pct).values
        bottom_scores = active_scores.topk(bottom_10_pct, largest=False).values
        elite_gap = (top_scores.mean() - bottom_scores.mean()).item()
    else:
        elite_gap = 0.0

    # Conteos por fase
    phase_a_active = sum(1 for idx in active_indices.tolist() if idx in self.phase_a_neurons)
    phase_b_active = sum(1 for idx in active_indices.tolist() if idx in self.phase_b_neurons)

    # Plasticidad por fase
    if phase_a_active > 0:
        phase_a_indices = torch.tensor([i for i in active_indices.tolist() if i in self.phase_a_neurons],
                                       device=active_indices.device, dtype=torch.long)
        phase_a_plasticity = self.plasticity_levels[phase_a_indices].mean().item()
    else:
        phase_a_plasticity = 1.0

    if phase_b_active > 0:
        phase_b_indices = torch.tensor([i for i in active_indices.tolist() if i in self.phase_b_neurons],
                                       device=active_indices.device, dtype=torch.long)
        phase_b_plasticity = self.plasticity_levels[phase_b_indices].mean().item()
    else:
        phase_b_plasticity = 1.0

    low_plasticity_count = int((active_plasticity < 0.3).sum().item())
    medium_plasticity_count = int(((active_plasticity >= 0.3) & (active_plasticity < 0.7)).sum().item())
    high_plasticity_count = int((active_plasticity >= 0.7).sum().item())

    return {
        'active_count': int(active_indices.numel()),
        'avg_competitive_score': float(active_scores.mean().item()),
        'std_competitive_score': float(active_scores.std(unbiased=False).item()),
        'avg_contribution': float(active_contribs.mean().item()),
        'avg_uniqueness': float(active_unique.mean().item()),
        'elite_gap': float(elite_gap),
        'phase_a_ratio': float(phase_a_active / active_indices.numel()),
        'phase_b_ratio': float(phase_b_active / active_indices.numel()),
        'avg_plasticity': float(active_plasticity.mean().item()),
        'phase_a_avg_plasticity': float(phase_a_plasticity),
        'phase_b_avg_plasticity': float(phase_b_plasticity),
        'plasticity_distribution': {
            'low': low_plasticity_count,
            'medium': medium_plasticity_count,
            'high': high_plasticity_count
        },
        'cache_stats': self.computation_cache.get_stats() if hasattr(self, 'computation_cache') and self.computation_cache else None
    }


@torch.no_grad()
def get_population_health(self):
    current_pop = int(self.mask.sum().item())
    target_pop = int(self.get_adaptive_population_target(0.0, self.current_phase))

    if current_pop > 0:
        active_mask = self.mask > 0
        diversity = float(self.uniqueness_score[active_mask].mean().item())
        redundancy = 1.0 - diversity

        active_plasticity = self.plasticity_levels[active_mask]
        effective_capacity = float((self.plasticity_levels * self.mask).sum().item())

        phase_a_active = sum(1 for idx in active_mask.nonzero(as_tuple=True)[0].tolist()
                             if idx in self.phase_a_neurons)
        phase_b_active = sum(1 for idx in active_mask.nonzero(as_tuple=True)[0].tolist()
                             if idx in self.phase_b_neurons)

        knowledge_preservation = float((active_plasticity < 0.3).sum().item()) / current_pop
        adaptation_capacity = float((active_plasticity > 0.7).sum().item()) / current_pop
    else:
        diversity = 0.0
        redundancy = 1.0
        effective_capacity = 0.0
        phase_a_active = 0
        phase_b_active = 0
        knowledge_preservation = 0.0
        adaptation_capacity = 0.0

    return {
        'current_population': current_pop,
        'target_population': target_pop,
        'population_ratio': float(current_pop / max(1, self.c.max_neurons)),
        'overpopulation': max(0, current_pop - target_pop),
        'diversity': diversity,
        'redundancy': redundancy,
        'capacity_penalty': float(self._capacity_penalty),
        'effective_capacity': float(effective_capacity / max(1, self.c.max_neurons)),
        'phase_distribution': {
            'phase_a': int(phase_a_active),
            'phase_b': int(phase_b_active),
            'unassigned': int(current_pop - phase_a_active - phase_b_active)
        },
        'knowledge_preservation_ratio': knowledge_preservation,
        'adaptation_capacity_ratio': adaptation_capacity,
        'plasticity_balance': float(1.0 - abs(knowledge_preservation - adaptation_capacity)),
    }


@torch.no_grad()
def get_continual_learning_metrics(self):
    active_mask = self.mask > 0
    active_indices = active_mask.nonzero(as_tuple=True)[0]
    if active_indices.numel() == 0:
        return {}

    phase_a_list = [idx for idx in active_indices.tolist() if idx in self.phase_a_neurons]
    phase_b_list = [idx for idx in active_indices.tolist() if idx in self.phase_b_neurons]

    phase_a_count_total = len(self.phase_a_neurons) if len(self.phase_a_neurons) > 0 else 1
    phase_a_retention = len(phase_a_list) / phase_a_count_total

    phase_b_growth = len(phase_b_list) / (len(self.phase_b_neurons) + 1e-6)

    cross_phase_neurons = 0
    if hasattr(self, 'cross_phase_utility'):
        for idx in active_indices:
            if self.cross_phase_utility[idx] > 0.5:
                cross_phase_neurons += 1

    inheritance_active = int((self.inheritance_count[active_indices] > 0).sum().item())
    knowledge_transfer_ratio = inheritance_active / active_indices.numel()

    if len(phase_a_list) > 0:
        pa_idx = torch.tensor(phase_a_list, device=active_indices.device, dtype=torch.long)
        pa_plast = self.plasticity_levels[pa_idx].mean().item()
    else:
        pa_plast = 0.0

    if len(phase_b_list) > 0:
        pb_idx = torch.tensor(phase_b_list, device=active_indices.device, dtype=torch.long)
        pb_plast = self.plasticity_levels[pb_idx].mean().item()
    else:
        pb_plast = 0.0

    return {
        'phase_a_retention': float(phase_a_retention),
        'phase_b_growth': float(phase_b_growth),
        'cross_phase_neurons': int(cross_phase_neurons),
        'knowledge_transfer_ratio': float(knowledge_transfer_ratio),
        'avg_inheritance_depth': float(self.inheritance_count[active_indices].float().mean().item()),
        'plasticity_gradient': {
            'phase_a': float(pa_plast),
            'phase_b': float(pb_plast),
            'difference': float(abs(pa_plast - pb_plast))
        }
    }


PlasticLSTM.analyze_competition_dynamics = analyze_competition_dynamics
PlasticLSTM.get_population_health = get_population_health
PlasticLSTM.get_continual_learning_metrics = get_continual_learning_metrics


In [12]:
# 11
@torch.no_grad()
def efficient_hebbian_update(self, loss_per_sample=None):
    """
    Actualización hebbiana eficiente y AMP-safe.
    - Alinea dtypes para evitar conflictos bajo autocast (fp16/fp32).
    - Selecciona top neuronas por competitive_score.
    - Normaliza y decae h_fast con compuertas de plasticidad.
    """
    if self.c.disable_hebbian:
        return

    acts, last, _ = self._cache
    if acts is None:
        return

    # Filtrado por buenas muestras
    if loss_per_sample is not None:
        good_samples = loss_per_sample < loss_per_sample.mean()
        if not good_samples.any():
            return
        acts = acts[good_samples]
        last = last[good_samples]

    num_active = int(self.mask.sum())
    if num_active > 10:
        num_to_update = int(num_active * self.c.hebb_top_neurons_ratio)
        top_neurons = self.competitive_score.topk(num_to_update).indices
        update_mask = torch.zeros_like(self.mask)
        update_mask[top_neurons] = 1
    else:
        update_mask = self.mask

    # Centrado
    acts_centered = acts - acts.mean(1, keepdim=True)   # (B, N)
    last_centered = last - last.mean(1, keepdim=True)   # (B, H)

    # === Arreglo AMP: alinear dtypes ===
    compute_dtype = self.h_fast.dtype  # normalmente float16
    acts_centered = acts_centered.to(compute_dtype)
    last_centered = last_centered.to(compute_dtype)

    plasticity_modulation = (self.plasticity_levels * update_mask).to(compute_dtype)          # (N,)
    uniqueness_weight = (self.uniqueness_score * plasticity_modulation).unsqueeze(1).to(compute_dtype)  # (N,1)

    # dW: (N, H) con einsum
    dw = torch.einsum('bn,bh->nh', acts_centered, last_centered)
    dw = dw / max(1, acts_centered.size(0))
    dw = dw * uniqueness_weight

    plasticity_mask = self.plasticity_factor.unsqueeze(1).to(compute_dtype)  # (N,1)
    dw = dw * plasticity_mask

    # Decaimiento dependiente de unicidad y plasticidad
    decay_base = self.c.hebb_decay + (1 - self.c.hebb_decay) * (1 - self.uniqueness_score)
    decay_plasticity_adjusted = decay_base + (1 - decay_base) * (1 - self.plasticity_levels)
    decay_rate = decay_plasticity_adjusted.unsqueeze(1).to(compute_dtype)  # (N,1)

    mask_expanded = plasticity_modulation.unsqueeze(1)  # (N,1) ya en compute_dtype
    self.h_fast = self.h_fast * decay_rate + dw * mask_expanded

    # Clip norm con modulación de plasticidad (AMP-safe)
    if update_mask.sum() > 0:
        norms = self.h_fast.norm(dim=1, keepdim=True)  # (N,1) en compute_dtype
        norm_factor = (norms / self.c.hebb_norm_clip).clamp(min=1.0)
        # Ajuste depende de plasticidad (convertir a compute_dtype)
        plast_levels = self.plasticity_levels.unsqueeze(1).to(compute_dtype)
        norm_adjustment = 1.0 + (norm_factor - 1.0) * plast_levels
        self.h_fast = self.h_fast / norm_adjustment

PlasticLSTM.hebbian_update = efficient_hebbian_update


In [13]:
# 12
def calculate_efficiency_metrics(self):
    active_neurons = int(self.mask.sum().item())
    total_neurons = self.c.max_neurons

    active_params = (
        active_neurons * self.c.lstm_hidden_size +          # W.weight
        active_neurons +                                    # W.bias
        active_neurons * self.c.vocab_size                  # out.weight
    )

    total_potential_params = (
        total_neurons * self.c.lstm_hidden_size +
        total_neurons +
        total_neurons * self.c.vocab_size
    )

    return {
        'active_neurons': active_neurons,
        'total_neurons': total_neurons,
        'neuron_utilization': active_neurons / max(1, total_neurons),
        'active_params': active_params,
        'total_potential_params': total_potential_params,
        'param_efficiency': active_params / max(1, total_potential_params),
        'memory_saved_ratio': 1 - (active_params / max(1, total_potential_params))
    }


def analyze_phase_specialization(self):
    phase_a_born = len(self.phase_a_neurons)
    phase_b_born = len(self.phase_b_neurons)

    active_mask = self.mask > 0
    active_indices = set(active_mask.nonzero(as_tuple=True)[0].cpu().numpy())

    phase_a_active = len(self.phase_a_neurons.intersection(active_indices))
    phase_b_active = len(self.phase_b_neurons.intersection(active_indices))

    results = {
        'phase_a_born': phase_a_born,
        'phase_a_active': phase_a_active,
        'phase_b_born': phase_b_born,
        'phase_b_active': phase_b_active
    }

    if phase_a_active > 0:
        phase_a_indices = torch.tensor(list(self.phase_a_neurons.intersection(active_indices)),
                                       device=self.mask.device, dtype=torch.long)
        results['phase_a_avg_contrib'] = float(self.contrib[phase_a_indices].mean().item())
    else:
        results['phase_a_avg_contrib'] = 0.0

    if phase_b_active > 0:
        phase_b_indices = torch.tensor(list(self.phase_b_neurons.intersection(active_indices)),
                                       device=self.mask.device, dtype=torch.long)
        results['phase_b_avg_contrib'] = float(self.contrib[phase_b_indices].mean().item())
    else:
        results['phase_b_avg_contrib'] = 0.0

    return results


def get_neuron_info(self, idx):
    birth_events = [e for e in self.life_events if e['type'] == 'birth' and e['id'] == idx]
    death_events = [e for e in self.life_events if e['type'] == 'death' and e['id'] == idx]

    return {
        'id': idx,
        'age': int(self.age[idx].item()),
        'contribution': float(self.contrib[idx].item()),
        'uniqueness': float(self.uniqueness_score[idx].item()),
        'competitive_score': float(self.competitive_score[idx].item()),
        'is_active': bool(self.mask[idx].item()),
        'birth_death': birth_events + death_events,
        'phase': 'A' if idx in self.phase_a_neurons else ('B' if idx in self.phase_b_neurons else 'Unknown')
    }


def get_inheritance_statistics(self):
    active_mask = self.mask > 0
    active_indices = active_mask.nonzero(as_tuple=True)[0]
    if active_indices.numel() == 0:
        return {}

    active_inheritance = self.inheritance_count[active_indices]
    active_sources = self.knowledge_sources[active_indices]

    stats = {
        'total_inheritance_events': len(self.inheritance_history),
        'neurons_with_inheritance': int((active_inheritance > 0).sum().item()),
        'avg_inheritance_per_neuron': float(active_inheritance.float().mean().item()),
        'max_inheritance_count': int(active_inheritance.max().item()),
        'avg_knowledge_sources': float(active_sources.float().mean().item()),
        'inheritance_ratio': float((active_inheritance > 0).sum().item()) / active_indices.numel()
    }

    phase_a_inheritance = sum(1 for event in self.inheritance_history if event.get('phase') == 'A')
    phase_b_inheritance = sum(1 for event in self.inheritance_history if event.get('phase') == 'B')

    stats['phase_a_inheritance_events'] = int(phase_a_inheritance)
    stats['phase_b_inheritance_events'] = int(phase_b_inheritance)
    return stats


def calculate_functional_diversity(self, test_sequences=None, max_neurons_eval=512, batches=6, seq_len=32):
    """
    Estima diversidad funcional midiendo la entropía del histograma de activaciones.
    Optimizado para no crear tensores gigantes:
      - Muestrea hasta `max_neurons_eval` neuronas activas.
      - Usa pocas secuencias cortas (batches * seq_len).
    """
    device = next(self.parameters()).device

    # Selección de neuronas activas a evaluar
    active_idx = (self.mask > 0).nonzero(as_tuple=True)[0]
    if active_idx.numel() == 0:
        return {
            'mean_functional_diversity': 0.0,
            'std_functional_diversity': 0.0,
            'min_functional_diversity': 0.0,
            'max_functional_diversity': 0.0
        }

    if active_idx.numel() > max_neurons_eval:
        perm = torch.randperm(active_idx.numel(), device=active_idx.device)[:max_neurons_eval]
        eval_idx = active_idx[perm]
    else:
        eval_idx = active_idx

    # Construcción de secuencias de prueba si no se pasan
    if test_sequences is None:
        test_sequences = []
        for _ in range(batches):
            seq = torch.randint(0, self.c.vocab_size, (1, seq_len), device=device, dtype=torch.long)
            test_sequences.append(seq)

    # Colección de activaciones sólo para las neuronas seleccionadas
    activations = []
    with torch.no_grad():
        for seq in test_sequences:
            logits, analysis = self.forward(seq, return_analysis=True)
            acts = analysis['activations']  # (B, N)
            activations.append(acts[:, eval_idx].detach().float().cpu())

    if not activations:
        return {
            'mean_functional_diversity': 0.0,
            'std_functional_diversity': 0.0,
            'min_functional_diversity': 0.0,
            'max_functional_diversity': 0.0
        }

    acts_mat = torch.cat(activations, dim=0).clamp(0, 1)  # (T, K)

    # Entropía por neurona con 10 bins
    K = acts_mat.shape[1]
    diversity_scores = []
    for k in range(K):
        hist = torch.histc(acts_mat[:, k], bins=10, min=0.0, max=1.0)
        s = hist.sum()
        if s <= 0:
            diversity_scores.append(0.0)
            continue
        p = (hist / s).clamp_min(1e-10)
        entropy = float((-p * torch.log(p)).sum().item())
        diversity_scores.append(entropy)

    diversity_scores = torch.tensor(diversity_scores)
    return {
        'mean_functional_diversity': float(diversity_scores.mean().item()),
        'std_functional_diversity': float(diversity_scores.std(unbiased=False).item()),
        'min_functional_diversity': float(diversity_scores.min().item()),
        'max_functional_diversity': float(diversity_scores.max().item())
    }


PlasticLSTM.calculate_efficiency_metrics = calculate_efficiency_metrics
PlasticLSTM.analyze_phase_specialization = analyze_phase_specialization
PlasticLSTM.get_neuron_info = get_neuron_info
PlasticLSTM.get_inheritance_statistics = get_inheritance_statistics
PlasticLSTM.calculate_functional_diversity = calculate_functional_diversity


import json
import pandas as pd

class NeuronAnalyzer:
    def __init__(self, model, cfg, val_loader, itos, model_type='plastic'):
        self.model = model
        self.cfg = cfg
        self.val_loader = val_loader
        self.itos = itos
        self.history = defaultdict(list)
        self.model_type = model_type

        if self.model_type == 'plastic':
            self.csv_dir = self.cfg.plastic_csv_dir
            self.checkpoints_dir = self.cfg.plastic_checkpoints_dir
        else:
            self.csv_dir = os.path.join(self.cfg.baseline_save_dir, 'csv_data')
            self.checkpoints_dir = os.path.join(self.cfg.baseline_save_dir, 'checkpoints')
            os.makedirs(self.csv_dir, exist_ok=True)
            os.makedirs(self.checkpoints_dir, exist_ok=True)

        self.history['inheritance_events'] = []
        self.history['neurons_with_inheritance'] = []
        self.history['inheritance_ratio'] = []

    @torch.no_grad()
    def evaluate(self, loader, device):
        self.model.eval()
        total_loss, n = 0.0, 0

        use_cuda = torch.cuda.is_available() and device.type == 'cuda'
        # Elegir dtype de AMP (bf16 preferido si está soportado)
        amp_dtype = torch.bfloat16
        if use_cuda:
            major, _ = torch.cuda.get_device_capability(device)
            if major < 8:  # Ampere o menor: usar fp16
                amp_dtype = torch.float16

        for x, y in loader:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            if use_cuda:
                with torch.cuda.amp.autocast(dtype=amp_dtype):
                    output = self.model(x)
                    logits = output[0] if isinstance(output, tuple) else output
                    loss = F.cross_entropy(logits, y[:, -1])
            else:
                output = self.model(x)
                logits = output[0] if isinstance(output, tuple) else output
                loss = F.cross_entropy(logits, y[:, -1])

            bs = x.size(0)
            total_loss += float(loss.item()) * bs
            n += bs

        avg_loss = total_loss / max(1, n)
        ppl = math.exp(avg_loss) if avg_loss < 700 else float('inf')
        return avg_loss, ppl

    def log_metrics(self, step, loss, ppl, active_neurons, phase='A'):
        self.history['step'].append(step)
        self.history['loss'].append(loss)
        self.history['ppl'].append(ppl)
        self.history['active_neurons'].append(active_neurons)
        self.history['active'].append(active_neurons)
        self.history['phase'].append(phase)
        if self.model_type == 'plastic' and self.cfg.adaptive_thresholds:
            self.history['birth_threshold'].append(float(self.model.current_birth_threshold))
            self.history['death_threshold'].append(float(self.model.current_death_threshold))

    def log_inheritance_metrics(self, step):
        if hasattr(self.model, 'get_inheritance_statistics'):
            inheritance_stats = self.model.get_inheritance_statistics()
            self.history['inheritance_events'].append(inheritance_stats.get('total_inheritance_events', 0))
            self.history['neurons_with_inheritance'].append(inheritance_stats.get('neurons_with_inheritance', 0))
            self.history['inheritance_ratio'].append(inheritance_stats.get('inheritance_ratio', 0.0))

    def save_plot(self, fig, name):
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"{self.cfg.plots_dir}/{name}_{timestamp}.html"
        fig.write_html(filename)
        print(f" Saved plot: {filename}")
        return filename

    def save_matplotlib_plot(self, name):
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"{self.cfg.plots_dir}/{name}_{timestamp}.png"
        plt.savefig(filename, dpi=150, bbox_inches='tight')
        print(f" Saved plot: {filename}")
        return filename

    def save_all_data_to_csv(self):
        print(f"\nGuardando datos en CSV en: {self.csv_dir}")
        df_history = pd.DataFrame(self.history)
        df_history.to_csv(f"{self.csv_dir}/training_history.csv", index=False)
        print(" - training_history.csv guardado.")

        if hasattr(self.model, 'generation_history') and self.model.generation_history:
            gen_data = []
            for step, samples in self.model.generation_history.items():
                for sample in samples:
                    gen_data.append({
                        'step': step,
                        'prompt': sample.get('prompt', ''),
                        'generated_text': sample.get('generated_text', ''),
                        'repetition_rate': sample.get('repetition_rate', 0.0)
                    })
            if gen_data:
                df_gen = pd.DataFrame(gen_data)
                df_gen.to_csv(f"{self.csv_dir}/generation_samples.csv", index=False)
                print(" - generation_samples.csv guardado.")

        if self.model_type == 'plastic':
            if hasattr(self.model, 'life_events') and self.model.life_events:
                df_life = pd.DataFrame(self.model.life_events)
                df_life.to_csv(f"{self.csv_dir}/neuron_life_events.csv", index=False)
                print(" - neuron_life_events.csv guardado.")
            if hasattr(self.model, 'inheritance_history') and self.model.inheritance_history:
                df_inheritance = pd.DataFrame(self.model.inheritance_history)
                df_inheritance.to_csv(f"{self.csv_dir}/inheritance_history.csv", index=False)
                print(" - inheritance_history.csv guardado.")
        print("Todos los datos CSV han sido guardados.")

    def generate_final_report(self):
        print("\n" + "="*80)
        print(f"{self.model_type.upper()} LSTM - REPORTE NUMÉRICO FINAL")
        print("="*80)

        final_ppl = float(self.history['ppl'][-1]) if self.history['ppl'] else float('nan')
        print(f"\n[1] RENDIMIENTO:")
        print(f"  - Perplejidad Final: {final_ppl:.2f}")

        phase_a_ppls = [p for p, ph in zip(self.history['ppl'], self.history['phase']) if ph == 'A']
        phase_b_ppls = [p for p, ph in zip(self.history['ppl'], self.history['phase']) if ph == 'B']
        if phase_a_ppls and phase_b_ppls:
            print(f"  - PPL Final (Fase A): {phase_a_ppls[-1]:.2f}")
            print(f"  - PPL Inicial (Fase B): {phase_b_ppls[0]:.2f}")
            print(f"  - PPL Final (Fase B): {phase_b_ppls[-1]:.2f}")
            print(f"  - Mejora por Adaptación (Caída de PPL en Fase B): {phase_b_ppls[0] - phase_b_ppls[-1]:.2f} puntos")

        if self.model_type == 'plastic':
            eff_metrics = self.model.calculate_efficiency_metrics()
            spec_metrics = self.model.analyze_phase_specialization()

            print("\n[2] EFICIENCIA:")
            print(f"  - Neuronas Activas: {int(eff_metrics['active_neurons'])} / {self.cfg.max_neurons} ({eff_metrics['neuron_utilization'] * 100:.2f}%)")
            print(f"  - Parámetros Activos: {int(eff_metrics['active_params']):,}")
            print(f"  - Ahorro de Memoria Estimado: {eff_metrics['memory_saved_ratio'] * 100:.2f}%")

            print("\n[3] DINÁMICA DE POBLACIÓN:")
            print(f"  - Supervivencia Fase A: {spec_metrics.get('phase_a_active', 0)} de {spec_metrics.get('phase_a_born', 0)} nacidas.")
            print(f"  - Supervivencia Fase B: {spec_metrics.get('phase_b_active', 0)} de {spec_metrics.get('phase_b_born', 0)} nacidas.")

        print("\n" + "="*80)

    def save_checkpoint(self):
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        checkpoint_path = f"{self.checkpoints_dir}/{self.model_type}_lstm_{timestamp}.pt"
        config_dict = {k: v for k, v in vars(self.cfg).items() if not k.startswith('__') and not callable(v)}
        torch.save({'model_state_dict': self.model.state_dict(), 'config': config_dict}, checkpoint_path)
        print(f"\nCheckpoint del modelo guardado en: {checkpoint_path}")


In [14]:
import json
import pandas as pd

class NeuronAnalyzer:
    def __init__(self, model, cfg, val_loader, itos, model_type='plastic'):
        self.model = model
        self.cfg = cfg
        self.val_loader = val_loader
        self.itos = itos
        self.history = defaultdict(list)
        self.model_type = model_type

        if self.model_type == 'plastic':
            self.csv_dir = self.cfg.plastic_csv_dir
            self.checkpoints_dir = self.cfg.plastic_checkpoints_dir
        else:
            self.csv_dir = os.path.join(self.cfg.baseline_save_dir, 'csv_data')
            self.checkpoints_dir = os.path.join(self.cfg.baseline_save_dir, 'checkpoints')
            os.makedirs(self.csv_dir, exist_ok=True)
            os.makedirs(self.checkpoints_dir, exist_ok=True)

        self.history['inheritance_events'] = []
        self.history['neurons_with_inheritance'] = []
        self.history['inheritance_ratio'] = []

    @torch.no_grad()
    def evaluate(self, loader, device):
        self.model.eval()
        total_loss, n = 0, 0
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            output = self.model(x)
            logits = output[0] if isinstance(output, tuple) else output
            loss = F.cross_entropy(logits, y[:, -1])
            total_loss += loss.item() * x.size(0)
            n += x.size(0)
        avg_loss = total_loss / max(1, n)
        ppl = math.exp(avg_loss) if avg_loss < 700 else float('inf')
        return avg_loss, ppl

    def log_metrics(self, step, loss, ppl, active_neurons, phase='A'):
        self.history['step'].append(step)
        self.history['loss'].append(loss)
        self.history['ppl'].append(ppl)
        self.history['active_neurons'].append(active_neurons)
        self.history['active'].append(active_neurons)
        self.history['phase'].append(phase)
        if self.model_type == 'plastic' and self.cfg.adaptive_thresholds:
            self.history['birth_threshold'].append(self.model.current_birth_threshold)
            self.history['death_threshold'].append(self.model.current_death_threshold)

    def log_inheritance_metrics(self, step):
        if hasattr(self.model, 'get_inheritance_statistics'):
            inheritance_stats = self.model.get_inheritance_statistics()
            self.history['inheritance_events'].append(inheritance_stats.get('total_inheritance_events', 0))
            self.history['neurons_with_inheritance'].append(inheritance_stats.get('neurons_with_inheritance', 0))
            self.history['inheritance_ratio'].append(inheritance_stats.get('inheritance_ratio', 0))

    def save_plot(self, fig, name):
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"{self.cfg.plots_dir}/{name}_{timestamp}.html"
        fig.write_html(filename)
        print(f" Saved plot: {filename}")
        return filename

    def save_matplotlib_plot(self, name):
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"{self.cfg.plots_dir}/{name}_{timestamp}.png"
        plt.savefig(filename, dpi=150, bbox_inches='tight')
        print(f" Saved plot: {filename}")
        return filename

    def save_all_data_to_csv(self):
        print(f"\nGuardando datos en CSV en: {self.csv_dir}")
        df_history = pd.DataFrame(self.history)
        df_history.to_csv(f"{self.csv_dir}/training_history.csv", index=False)
        print(" - training_history.csv guardado.")
        if hasattr(self.model, 'generation_history') and self.model.generation_history:
            gen_data = []
            for step, samples in self.model.generation_history.items():
                for sample in samples:
                    gen_data.append({
                        'step': step,
                        'prompt': sample.get('prompt', ''),
                        'generated_text': sample.get('generated_text', ''),
                        'repetition_rate': sample.get('repetition_rate', 0.0)
                    })
            if gen_data:
                df_gen = pd.DataFrame(gen_data)
                df_gen.to_csv(f"{self.csv_dir}/generation_samples.csv", index=False)
                print(" - generation_samples.csv guardado.")
        if self.model_type == 'plastic':
            if hasattr(self.model, 'life_events') and self.model.life_events:
                df_life = pd.DataFrame(self.model.life_events)
                df_life.to_csv(f"{self.csv_dir}/neuron_life_events.csv", index=False)
                print(" - neuron_life_events.csv guardado.")
            if hasattr(self.model, 'inheritance_history') and self.model.inheritance_history:
                df_inheritance = pd.DataFrame(self.model.inheritance_history)
                df_inheritance.to_csv(f"{self.csv_dir}/inheritance_history.csv", index=False)
                print(" - inheritance_history.csv guardado.")
        print("Todos los datos CSV han sido guardados.")

    def generate_final_report(self):
        print("\n" + "="*80)
        print(f"{self.model_type.upper()} LSTM - REPORTE NUMÉRICO FINAL")
        print("="*80)
        final_ppl = float(self.history['ppl'][-1]) if self.history['ppl'] else float('nan')
        print(f"\n[1] RENDIMIENTO:")
        print(f"  - Perplejidad Final: {final_ppl:.2f}")
        phase_a_ppls = [p for p, ph in zip(self.history['ppl'], self.history['phase']) if ph == 'A']
        phase_b_ppls = [p for p, ph in zip(self.history['ppl'], self.history['phase']) if ph == 'B']
        if phase_a_ppls and phase_b_ppls:
            print(f"  - PPL Final (Fase A): {phase_a_ppls[-1]:.2f}")
            print(f"  - PPL Inicial (Fase B): {phase_b_ppls[0]:.2f}")
            print(f"  - PPL Final (Fase B): {phase_b_ppls[-1]:.2f}")
            print(f"  - Mejora por Adaptación (Caída de PPL en Fase B): {phase_b_ppls[0] - phase_b_ppls[-1]:.2f} puntos")
        if self.model_type == 'plastic':
            eff_metrics = self.model.calculate_efficiency_metrics()
            spec_metrics = self.model.analyze_phase_specialization()
            print("\n[2] EFICIENCIA:")
            print(f"  - Neuronas Activas: {int(eff_metrics['active_neurons'])} / {self.cfg.max_neurons} ({eff_metrics['neuron_utilization'] * 100:.2f}%)")
            print(f"  - Parámetros Activos: {int(eff_metrics['active_params']):,}")
            print(f"  - Ahorro de Memoria Estimado: {eff_metrics['memory_saved_ratio'] * 100:.2f}%")
            print("\n[3] DINÁMICA DE POBLACIÓN:")
            print(f"  - Supervivencia Fase A: {spec_metrics.get('phase_a_active', 0)} de {spec_metrics.get('phase_a_born', 0)} nacidas.")
            print(f"  - Supervivencia Fase B: {spec_metrics.get('phase_b_active', 0)} de {spec_metrics.get('phase_b_born', 0)} nacidas.")
        print("\n" + "="*80)

    def save_checkpoint(self):
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        checkpoint_path = f"{self.checkpoints_dir}/{self.model_type}_lstm_{timestamp}.pt"
        config_dict = {k: v for k, v in vars(self.cfg).items() if not k.startswith('__') and not callable(v)}
        torch.save({'model_state_dict': self.model.state_dict(), 'config': config_dict}, checkpoint_path)
        print(f"\nCheckpoint del modelo guardado en: {checkpoint_path}")


def plot_training_progress_enhanced(self):
    fig = make_subplots(
        rows=6, cols=1,
        subplot_titles=('Validation Loss', 'Validation Perplexity', 'Active Neurons',
                       'Birth/Death Events', 'Adaptive Thresholds', 'Generation Quality (Repetition Rate)'),
        vertical_spacing=0.07,
        shared_xaxes=True
    )
    if 'phase' in self.history and len(self.history['phase']) > 1:
        phase_changes = [self.history['step'][i] for i in range(1, len(self.history['phase'])) if self.history['phase'][i] != self.history['phase'][i-1]]
        last_step = self.history['step'][-1]
        shapes = []
        shapes.append(dict(type="rect", xref="x", yref="paper", x0=0, y0=0, x1=phase_changes[0] if phase_changes else last_step, y1=1, fillcolor="rgba(70,70,120,0.2)", layer="below", line_width=0))
        if phase_changes:
            shapes.append(dict(type="rect", xref="x", yref="paper", x0=phase_changes[0], y0=0, x1=last_step, y1=1, fillcolor="rgba(120,120,70,0.2)", layer="below", line_width=0))
        fig.update_layout(shapes=shapes)
    fig.add_trace(go.Scatter(x=self.history['step'], y=self.history['loss'], name='Loss', line=dict(color='cyan')), row=1, col=1)
    fig.add_trace(go.Scatter(x=self.history['step'], y=self.history['ppl'], name='PPL', line=dict(color='magenta')), row=2, col=1)
    fig.add_trace(go.Scatter(x=self.history['step'], y=self.history['active'], name='Active Neurons', line=dict(color='lime')), row=3, col=1)
    if self.model.life_events:
        births = [e for e in self.model.life_events if e['type'] == 'birth']
        deaths = [e for e in self.model.life_events if e['type'] == 'death']
        if births:
            fig.add_trace(go.Scatter(x=[e['step'] for e in births], y=[1]*len(births), mode='markers', name='Births', marker=dict(symbol='star', size=8)), row=4, col=1)
        if deaths:
            fig.add_trace(go.Scatter(x=[e['step'] for e in deaths], y=[0]*len(deaths), mode='markers', name='Deaths', marker=dict(symbol='x', size=8)), row=4, col=1)
    if self.cfg.adaptive_thresholds and 'birth_threshold' in self.history:
        fig.add_trace(go.Scatter(x=self.history['step'], y=self.history['birth_threshold'], name='Birth Thresh', line=dict(color='orange')), row=5, col=1)
        fig.add_trace(go.Scatter(x=self.history['step'], y=[t*100 for t in self.history['death_threshold']], name='Death Thresh (x100)', line=dict(color='tomato', dash='dash')), row=5, col=1)
    if hasattr(self.model, 'generation_history') and self.model.generation_history:
        steps, repetitions = [], []
        for step, samples in sorted(self.model.generation_history.items()):
            if samples:
                steps.append(step)
                avg_rep = np.mean([s['repetition_rate'] for s in samples])
                repetitions.append(avg_rep)
        if steps:
            fig.add_trace(go.Scatter(x=steps, y=repetitions, name='Repetition Rate', line=dict(color='yellow'), mode='lines+markers'), row=6, col=1)
    fig.update_layout(template='plotly_dark', height=1500, title_text="Enhanced Training Progress (Plastic LSTM)")
    fig.show()
    self.save_plot(fig, "training_progress_enhanced_plastic")


def analyze_top_neurons(self, k=10):
    active_mask = self.model.mask > 0
    active_idx = active_mask.nonzero(as_tuple=True)[0]
    if len(active_idx) == 0:
        print("No active neurons!")
        return
    contribs = self.model.contrib[active_idx]
    top_idx = active_idx[contribs.argsort(descending=True)[:k]]
    print(f"\nTop {k} neurons by contribution:")
    print("-" * 80)
    neuron_data = []
    for rank, idx in enumerate(top_idx):
        info = self.model.get_neuron_info(idx.item())
        neuron_data.append(info)
        print(f"Rank {rank+1}: Neuron {idx.item()}")
        print(f"  Age: {info['age']} steps | Contribution: {info['contribution']:.6f}")
        print(f"  Born in phase: {[e['phase'] for e in info['birth_death'] if e['type'] == 'birth'][0] if info['birth_death'] else 'Initial'}")
        print()
    return neuron_data


def generate_report(self):
    print("\n" + "="*80)
    print("NEUROPLASTIC LSTM - FINAL ANALYSIS REPORT")
    print("="*80)
    phase_a_births = len([e for e in self.model.life_events if e['type'] == 'birth' and e.get('phase', 'A') == 'A'])
    phase_b_births = len([e for e in self.model.life_events if e['type'] == 'birth' and e.get('phase') == 'B'])
    phase_a_deaths = len([e for e in self.model.life_events if e['type'] == 'death' and e.get('phase', 'A') == 'A'])
    phase_b_deaths = len([e for e in self.model.life_events if e['type'] == 'death' and e.get('phase') == 'B'])
    print("\n[1] POPULATION DYNAMICS:")
    print(f"  - Initial Neurons: {self.cfg.initial_active}")
    print(f"  - Final Active Neurons: {int(self.model.mask.sum().item())}")
    print(f"  - Phase A: {phase_a_births} Births, {phase_a_deaths} Deaths")
    print(f"  - Phase B: {phase_b_births} Births, {phase_b_deaths} Deaths")
    eff_metrics = self.model.calculate_efficiency_metrics()
    print("\n[2] PARAMETER EFFICIENCY:")
    print(f"  - Active Parameters: {int(eff_metrics['active_params']):,} / {int(eff_metrics['total_potential_params']):,}")
    print(f"  - Utilization: {eff_metrics['param_efficiency']*100:.2f}% of potential capacity")
    print(f"  - Estimated Memory Saved: {eff_metrics['memory_saved_ratio']*100:.2f}%")
    spec_metrics = self.model.analyze_phase_specialization()
    print("\n[3] NEURONAL SPECIALIZATION:")
    print(f"  - Phase A Born & Active: {spec_metrics['phase_a_active']} / {spec_metrics['phase_a_born']} ({spec_metrics['phase_a_active']/max(1,spec_metrics['phase_a_born']):.1%} survival)")
    print(f"  - Phase B Born & Active: {spec_metrics['phase_b_active']} / {spec_metrics['phase_b_born']} ({spec_metrics['phase_b_active']/max(1,spec_metrics['phase_b_born']):.1%} survival)")
    if 'phase_a_avg_contrib' in spec_metrics:
        print(f"  - Avg. Contribution (Phase A neurons): {spec_metrics.get('phase_a_avg_contrib', 0):.4f}")
        print(f"  - Avg. Contribution (Phase B neurons): {spec_metrics.get('phase_b_avg_contrib', 0):.4f}")
    print("\n" + "="*80)


def save_all_results(self):
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    extended_analysis = {
        'efficiency': self.model.calculate_efficiency_metrics(),
        'specialization': self.model.analyze_phase_specialization(),
    }
    checkpoint_path = f"{self.checkpoints_dir}/plastic_lstm_{timestamp}.pt"
    torch.save({
        'model_state_dict': self.model.state_dict(),
        'life_events': self.model.life_events,
        'generation_history': dict(self.model.generation_history),
        'config': {k: v for k, v in self.cfg.__dict__.items() if not k.startswith('__')},
        'history': dict(self.history),
        'extended_analysis': extended_analysis,
        'timestamp': timestamp
    }, checkpoint_path)
    print(f"\nSaved checkpoint: {checkpoint_path}")
    metrics_path = f"{self.cfg.plots_dir}/metrics_plastic_{timestamp}.json"
    with open(metrics_path, 'w') as f:
        def simple_dumper(obj):
            if isinstance(obj, (np.integer, np.floating, np.bool_)):
                return obj.item()
            if isinstance(obj, set):
                return list(obj)
            raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
        json.dump({
            'history': dict(self.history),
            'life_events': self.model.life_events,
            'generation_history': dict(self.model.generation_history),
            'extended_analysis': extended_analysis,
            'timestamp': timestamp
        }, f, indent=2, default=simple_dumper)
    print(f"Saved metrics: {metrics_path}")

NeuronAnalyzer.plot_training_progress_enhanced = plot_training_progress_enhanced
NeuronAnalyzer.analyze_top_neurons = analyze_top_neurons
NeuronAnalyzer.generate_report = generate_report
NeuronAnalyzer.save_all_results = save_all_results


In [15]:
# 14
def plot_training_progress_enhanced(self):
    """Plot mejorado con métricas de generación y mejor visualización de fases (robusto a claves faltantes)."""
    fig = make_subplots(
        rows=6, cols=1,
        subplot_titles=(
            'Validation Loss',
            'Validation Perplexity',
            'Active Neurons',
            'Birth/Death Events',
            'Adaptive Thresholds',
            'Generation Quality (Repetition Rate)'
        ),
        vertical_spacing=0.07,
        shared_xaxes=True
    )

    steps_hist = self.history.get('step', [])

    # --- Fondos por fase ---
    phases = self.history.get('phase', [])
    if phases and len(phases) == len(steps_hist) and len(phases) > 1:
        phase_changes = [steps_hist[i] for i in range(1, len(phases)) if phases[i] != phases[i-1]]
        last_step = steps_hist[-1] if steps_hist else 0

        shapes = []
        shapes.append(dict(type="rect", xref="x", yref="paper",
                           x0=0, y0=0, x1=phase_changes[0] if phase_changes else last_step, y1=1,
                           fillcolor="rgba(70,70,120,0.2)", layer="below", line_width=0))
        if phase_changes:
            shapes.append(dict(type="rect", xref="x", yref="paper",
                               x0=phase_changes[0], y0=0, x1=last_step, y1=1,
                               fillcolor="rgba(120,120,70,0.2)", layer="below", line_width=0))
        fig.update_layout(shapes=shapes)

    # --- Loss ---
    loss = self.history.get('loss', [])
    if steps_hist and loss and len(loss) == len(steps_hist):
        fig.add_trace(go.Scatter(x=steps_hist, y=loss, name='Loss', line=dict(color='cyan')), row=1, col=1)

    # --- Perplexity ---
    ppl = self.history.get('ppl', [])
    if steps_hist and ppl and len(ppl) == len(steps_hist):
        fig.add_trace(go.Scatter(x=steps_hist, y=ppl, name='PPL', line=dict(color='magenta')), row=2, col=1)

    # --- Neuronas activas ---
    active = self.history.get('active', [])
    if steps_hist and active and len(active) == len(steps_hist):
        fig.add_trace(go.Scatter(x=steps_hist, y=active, name='Active Neurons', line=dict(color='lime')), row=3, col=1)

    # --- Nacimientos / muertes ---
    if getattr(self, 'model', None) and getattr(self.model, 'life_events', None):
        births = [e for e in self.model.life_events if e.get('type') == 'birth']
        deaths = [e for e in self.model.life_events if e.get('type') == 'death']
        if births:
            fig.add_trace(
                go.Scatter(
                    x=[e.get('step', 0) for e in births],
                    y=[1] * len(births),
                    mode='markers',
                    name='Births',
                    marker=dict(color='lightgreen', symbol='star', size=8)
                ), row=4, col=1
            )
        if deaths:
            fig.add_trace(
                go.Scatter(
                    x=[e.get('step', 0) for e in deaths],
                    y=[0] * len(deaths),
                    mode='markers',
                    name='Deaths',
                    marker=dict(color='red', symbol='x', size=8)
                ), row=4, col=1
            )

    # --- Umbrales adaptativos ---
    if self.cfg.adaptive_thresholds and 'birth_threshold' in self.history and 'death_threshold' in self.history:
        birth_th = self.history.get('birth_threshold', [])
        death_th = self.history.get('death_threshold', [])
        if steps_hist and birth_th and len(birth_th) == len(steps_hist):
            fig.add_trace(go.Scatter(x=steps_hist, y=birth_th, name='Birth Thresh', line=dict(color='orange')), row=5, col=1)
        if steps_hist and death_th and len(death_th) == len(steps_hist):
            fig.add_trace(go.Scatter(x=steps_hist, y=[t * 100 for t in death_th], name='Death Thresh (x100)', line=dict(color='tomato', dash='dash')), row=5, col=1)

    # --- Calidad de generación (repetition rate) ---
    if hasattr(self.model, 'generation_history') and self.model.generation_history:
        steps_rep, repetitions = [], []
        for stp, samples in sorted(self.model.generation_history.items()):
            if samples:
                steps_rep.append(stp)
                repetitions.append(float(np.mean([s.get('repetition_rate', 0.0) for s in samples])))
        if steps_rep:
            fig.add_trace(
                go.Scatter(x=steps_rep, y=repetitions, name='Repetition Rate', line=dict(color='yellow'), mode='lines+markers'),
                row=6, col=1
            )

    fig.update_layout(template='plotly_dark', height=1500, title_text="Enhanced Training Progress (Plastic LSTM)")
    fig.show()
    self.save_plot(fig, "training_progress_enhanced_plastic")

# Asociar a la clase
NeuronAnalyzer.plot_training_progress_enhanced = plot_training_progress_enhanced


In [16]:
# 15
def analyze_top_neurons(self, k=10):
    active_mask = self.model.mask > 0
    active_idx = active_mask.nonzero(as_tuple=True)[0]
    if len(active_idx) == 0:
        print("No active neurons!")
        return
    contribs = self.model.contrib[active_idx]
    top_idx = active_idx[contribs.argsort(descending=True)[:k]]
    print(f"\nTop {k} neurons by contribution:")
    print("-" * 80)
    neuron_data = []
    for rank, idx in enumerate(top_idx):
        info = self.model.get_neuron_info(idx.item())
        neuron_data.append(info)
        print(f"Rank {rank+1}: Neuron {idx.item()}")
        print(f"  Age: {info['age']} steps | Contribution: {info['contribution']:.6f}")
        print(f"  Born in phase: {[e['phase'] for e in info['birth_death'] if e['type'] == 'birth'][0] if info['birth_death'] else 'Initial'}")
        print()
    return neuron_data

def generate_report(self):
    print("\n" + "="*80)
    print("NEUROPLASTIC LSTM - FINAL ANALYSIS REPORT")
    print("="*80)
    phase_a_births = len([e for e in self.model.life_events if e['type'] == 'birth' and e.get('phase', 'A') == 'A'])
    phase_b_births = len([e for e in self.model.life_events if e['type'] == 'birth' and e.get('phase') == 'B'])
    phase_a_deaths = len([e for e in self.model.life_events if e['type'] == 'death' and e.get('phase', 'A') == 'A'])
    phase_b_deaths = len([e for e in self.model.life_events if e['type'] == 'death' and e.get('phase') == 'B'])
    print("\n[1] POPULATION DYNAMICS:")
    print(f"  - Initial Neurons: {self.cfg.initial_active}")
    print(f"  - Final Active Neurons: {int(self.model.mask.sum().item())}")
    print(f"  - Phase A: {phase_a_births} Births, {phase_a_deaths} Deaths")
    print(f"  - Phase B: {phase_b_births} Births, {phase_b_deaths} Deaths")
    eff_metrics = self.model.calculate_efficiency_metrics()
    print("\n[2] PARAMETER EFFICIENCY:")
    print(f"  - Active Parameters: {int(eff_metrics['active_params']):,} / {int(eff_metrics['total_potential_params']):,}")
    print(f"  - Utilization: {eff_metrics['param_efficiency']*100:.2f}% of potential capacity")
    print(f"  - Estimated Memory Saved: {eff_metrics['memory_saved_ratio']*100:.2f}%")
    spec_metrics = self.model.analyze_phase_specialization()
    print("\n[3] NEURONAL SPECIALIZATION:")
    print(f"  - Phase A Born & Active: {spec_metrics['phase_a_active']} / {spec_metrics['phase_a_born']} ({spec_metrics['phase_a_active']/max(1, spec_metrics['phase_a_born']):.1%} survival)")
    print(f"  - Phase B Born & Active: {spec_metrics['phase_b_active']} / {spec_metrics['phase_b_born']} ({spec_metrics['phase_b_active']/max(1, spec_metrics['phase_b_born']):.1%} survival)")
    if 'phase_a_avg_contrib' in spec_metrics:
        print(f"  - Avg. Contribution (Phase A neurons): {spec_metrics.get('phase_a_avg_contrib', 0):.4f}")
        print(f"  - Avg. Contribution (Phase B neurons): {spec_metrics.get('phase_b_avg_contrib', 0):.4f}")
    print("\n" + "="*80)

def save_all_results(self):
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    extended_analysis = {
        'efficiency': self.model.calculate_efficiency_metrics(),
        'specialization': self.model.analyze_phase_specialization(),
    }
    checkpoint_path = f"{self.checkpoints_dir}/plastic_lstm_{timestamp}.pt"
    torch.save({
        'model_state_dict': self.model.state_dict(),
        'life_events': self.model.life_events,
        'generation_history': dict(self.model.generation_history),
        'config': {k: v for k, v in self.cfg.__dict__.items() if not k.startswith('__')},
        'history': dict(self.history),
        'extended_analysis': extended_analysis,
        'timestamp': timestamp
    }, checkpoint_path)
    print(f"\nSaved checkpoint: {checkpoint_path}")
    metrics_path = f"{self.cfg.plots_dir}/metrics_plastic_{timestamp}.json"
    with open(metrics_path, 'w') as f:
        def simple_dumper(obj):
            if isinstance(obj, (np.integer, np.floating, np.bool_)):
                return obj.item()
            if isinstance(obj, set):
                return list(obj)
            raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
        json.dump({
            'history': dict(self.history),
            'life_events': self.model.life_events,
            'generation_history': dict(self.model.generation_history),
            'extended_analysis': extended_analysis,
            'timestamp': timestamp
        }, f, indent=2, default=simple_dumper)
    print(f"Saved metrics: {metrics_path}")

NeuronAnalyzer.analyze_top_neurons = analyze_top_neurons
NeuronAnalyzer.generate_report = generate_report
NeuronAnalyzer.save_all_results = save_all_results


In [17]:
# Celda 16
def train_plastic_lstm(cfg):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Usando dispositivo: {device}")

    use_amp = getattr(cfg, "use_amp", True) and device.type == "cuda"
    grad_clip = float(getattr(cfg, "grad_clip", 1.0))
    amp_dtype = torch.float16
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    train_loader, val_loader, stoi, itos, _ = build_loaders(cfg, phase='A')
    model = PlasticLSTM(cfg).to(device)
    model.stoi = stoi
    model.itos = itos
    print(f"Plastic LSTM inicializado con {cfg.initial_active} neuronas activas.")

    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate)
    replay_buffer = ReplayBuffer(cfg.replay_size)
    analyzer = NeuronAnalyzer(model, cfg, val_loader, itos, model_type='plastic')
    forgetting_tracker = CatastrophicForgettingTracker(cfg)

    initial_ppl_a = None
    step = 0
    model.train()
    phase_b_triggered = False

    population_history = []
    diversity_history = []
    continual_learning_history = []

    last_time = time.time()
    tokens_counted = 0
    tokens_per_step = cfg.batch_size * cfg.sequence_length

    for epoch in range(cfg.num_epochs):
        print(f"\n{'='*60}\nÉpoca {epoch+1}/{cfg.num_epochs}\n{'='*60}")
        current_loader = train_loader
        pbar = tqdm(current_loader, desc=f'Época {epoch+1}')

        for batch_idx, (x, y) in enumerate(pbar):
            step += 1

            # Transición a Fase B
            if step >= cfg.phase_b_start and not phase_b_triggered:
                print(f"\n FASE B: Inyectando nuevos datos en el paso {step}")

                forgetting_tracker.save_phase_checkpoint(model, 'A', step, device)

                with torch.cuda.amp.autocast(dtype=amp_dtype, enabled=use_amp):
                    phase_a_results, _, _ = forgetting_tracker.evaluate_phase_specific_performance(
                        model, stoi, itos, device, 'A', step
                    )
                forgetting_tracker.phase_performance['A']['final_perplexity'] = phase_a_results['A']['mean_perplexity']

                model.set_phase('B')

                # --- CAMBIO 1: Reducir tasa de aprendizaje ---
                print("!! REDUCIENDO TASA DE APRENDIZAJE PARA FASE B !!")
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= 0.1
                # ----------------------------------------------

                train_loader_b, val_loader_b, _, _, _ = build_loaders(
                    cfg, phase='B', existing_stoi=stoi, existing_itos=itos
                )
                current_loader, val_loader = train_loader_b, val_loader_b
                analyzer.val_loader = val_loader_b
                phase_b_triggered = True
                pbar.close()
                pbar = tqdm(current_loader, desc=f'Época {epoch+1} - Fase B', initial=batch_idx)

            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)

            with torch.cuda.amp.autocast(dtype=amp_dtype, enabled=use_amp):
                logits = model(x)
                loss_per_sample = F.cross_entropy(logits, y[:, -1], reduction='none')
                loss = loss_per_sample.mean()
                total_loss = loss + model._capacity_penalty

            if use_amp:
                scaler.scale(total_loss).backward()
                scaler.unscale_(optimizer)
                if grad_clip is not None and grad_clip > 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                scaler.step(optimizer)
                scaler.update()
            else:
                total_loss.backward()
                if grad_clip is not None and grad_clip > 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                optimizer.step()

            model.hebbian_update(loss_per_sample.detach())
            if step % cfg.plasticity_interval == 0:
                model.structural_update(loss.item(), loss_per_sample.detach())

            # Replay buffer
            replay_buffer.push(x, y)

            if step > 100 and step % 10 == 0:
                replay_samples = replay_buffer.sample(n=15)
                if replay_samples:
                    for replay_sample in replay_samples:
                        rx, ry = replay_sample
                        rx = rx.to(device, non_blocking=True)
                        ry = ry.to(device, non_blocking=True)
                        optimizer.zero_grad(set_to_none=True)
                        with torch.cuda.amp.autocast(dtype=amp_dtype, enabled=use_amp):
                            r_logits = model(rx)
                            r_loss = F.cross_entropy(r_logits, ry[:, -1])

                            # --- CAMBIO 2: Ponderar pérdida del replay ---
                            replay_loss_weight = 200.0
                            r_total_loss = (r_loss * replay_loss_weight) + model._capacity_penalty
                            # ---------------------------------------------

                        if use_amp:
                            scaler.scale(r_total_loss).backward()
                            scaler.unscale_(optimizer)
                            if grad_clip is not None and grad_clip > 0:
                                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                            scaler.step(optimizer)
                            scaler.update()
                        else:
                            r_total_loss.backward()
                            if grad_clip is not None and grad_clip > 0:
                                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                            optimizer.step()

            tokens_counted += tokens_per_step
            now = time.time()
            elapsed = now - last_time
            tps = tokens_counted / max(1e-6, elapsed)

            if step % cfg.eval_every_steps == 0:
                with torch.cuda.amp.autocast(dtype=amp_dtype, enabled=use_amp):
                    val_loss, val_ppl = analyzer.evaluate(val_loader, device)
                model.train()

                if initial_ppl_a is None and not phase_b_triggered:
                    initial_ppl_a = val_ppl

                active_neurons = int(model.mask.sum().item())
                analyzer.log_metrics(step, val_loss, val_ppl, active_neurons,
                                     phase='B' if phase_b_triggered else 'A')

                analyzer.log_inheritance_metrics(step)

                if step % 500 == 0 and step > 0:
                    model.soft_checkpoints.append({
                        'weights': model.W.weight.clone(),
                        'mask': model.mask.clone(),
                        'performance': 1.0 / (val_ppl + 1e-6),
                        'step': step
                    })

                if step % 2000 == 0 and len(model.soft_checkpoints) >= 3:
                    if hasattr(model, '_apply_ensemble_knowledge'): _apply_ensemble_knowledge(model)

                pop_health = model.get_population_health()
                population_history.append({
                    'step': step,
                    'population': pop_health['current_population'],
                    'target': pop_health['target_population'],
                    'diversity': pop_health['diversity'],
                    'penalty': pop_health['capacity_penalty'],
                    'effective_capacity': pop_health['effective_capacity'],
                    'knowledge_preservation': pop_health['knowledge_preservation_ratio'],
                    'adaptation_capacity': pop_health['adaptation_capacity_ratio']
                })

                if step % 1000 == 0:
                    with torch.cuda.amp.autocast(dtype=amp_dtype, enabled=use_amp):
                        cf_results, retention, activations = forgetting_tracker.evaluate_phase_specific_performance(
                            model, stoi, itos, device, 'B' if phase_b_triggered else 'A', step
                        )

                    empirical_metrics = forgetting_tracker.calculate_empirical_metrics(
                        model, 'B' if phase_b_triggered else 'A', step
                    )

                    continual_metrics = model.get_continual_learning_metrics()
                    continual_learning_history.append({
                        'step': step,
                        'phase': 'B' if phase_b_triggered else 'A',
                        **continual_metrics
                    })

                    if phase_b_triggered:
                        model.update_plasticity_levels(
                            phase_a_performance=cf_results.get('A', {}).get('mean_perplexity'),
                            phase_b_performance=cf_results.get('B', {}).get('mean_perplexity')
                        )

                        print(f"\n[Métricas Empíricas - Paso {step}]")
                        print(f"  Retención Fase A: {empirical_metrics['retencion_fase_a']:.1f}%")
                        print(f"  Interferencia B→A: {empirical_metrics['interferencia_B']:.3f}")
                        print(f"  Costo de oportunidad: {empirical_metrics['costo_oportunidad']:.3f}")
                        print(f"  Transferencia de conocimiento: {continual_metrics['knowledge_transfer_ratio']:.2%}")

                if step % 2000 == 0:
                    diversity_metrics = model.calculate_functional_diversity()
                    diversity_history.append({
                        'step': step,
                        'mean_diversity': diversity_metrics['mean_functional_diversity'],
                        'std_diversity': diversity_metrics['std_functional_diversity']
                    })

                pbar.set_postfix({
                    'loss': f'{loss.item():.3f}',
                    'val_ppl': f'{val_ppl:.1f}',
                    'neurons': f'{active_neurons}/{model.get_adaptive_population_target(loss.item(), model.current_phase)}',
                    'penalty': f'{model._capacity_penalty:.4f}',
                    'tok/s': f'{tps:,.0f}'
                })

                if step % (cfg.eval_every_steps * 5) == 0:
                    samples = save_generation_samples(
                        model, epoch, step, stoi, itos, device, cfg, "PlasticLSTM"
                    )
                    model.generation_history[step] = samples

                    if step % (cfg.eval_every_steps * 10) == 0:
                        comp_stats = model.analyze_competition_dynamics()
                        inherit_stats = model.get_inheritance_statistics()

                        print(f"\n[Estadísticas del Sistema - Paso {step}]")
                        print(f"  Plasticidad promedio: {comp_stats['avg_plasticity']:.3f}")
                        print(f"  Distribución: Baja={comp_stats['plasticity_distribution']['low']}, "
                              f"Media={comp_stats['plasticity_distribution']['medium']}, "
                              f"Alta={comp_stats['plasticity_distribution']['high']}")
                        print(f"  Balance de plasticidad: {pop_health['plasticity_balance']:.3f}")

                last_time = time.time()
                tokens_counted = 0

    print("\n" + "="*60)
    print("ENTRENAMIENTO COMPLETO - ANÁLISIS FINAL")
    print("="*60)

    analyzer.generate_final_report()

    final_comp = model.analyze_competition_dynamics()
    final_health = model.get_population_health()
    final_inheritance = model.get_inheritance_statistics()
    final_continual = model.get_continual_learning_metrics()

    print("\n[4] SISTEMA COMPETITIVO:")
    print(f"  - Población Final: {final_health['current_population']} / {final_health['target_population']} objetivo")
    print(f"  - Capacidad efectiva: {final_health['effective_capacity']*100:.1f}%")
    print(f"  - Distribución de plasticidad:")
    print(f"    • Baja (preservación): {final_comp['plasticity_distribution']['low']} neuronas")
    print(f"    • Media (mixta): {final_comp['plasticity_distribution']['medium']} neuronas")
    print(f"    • Alta (adaptación): {final_comp['plasticity_distribution']['high']} neuronas")

    print("\n[5] APRENDIZAJE CONTINUO:")
    print(f"  - Retención Fase A: {final_continual['phase_a_retention']*100:.1f}%")
    print(f"  - Crecimiento Fase B: {final_continual['phase_b_growth']*100:.1f}%")
    print(f"  - Neuronas cross-phase: {final_continual['cross_phase_neurons']}")
    print(f"  - Gradiente de plasticidad A→B: {final_continual['plasticity_gradient']['difference']:.3f}")

    analyzer.save_all_data_to_csv()

    df_population = pd.DataFrame(population_history)
    df_population.to_csv(f"{cfg.plastic_csv_dir}/population_dynamics.csv", index=False)

    forgetting_tracker.save_forgetting_metrics_to_csv(cfg.plastic_csv_dir)

    if continual_learning_history:
        df_continual = pd.DataFrame(continual_learning_history)
        df_continual.to_csv(f"{cfg.plastic_csv_dir}/continual_learning_metrics.csv", index=False)

    if diversity_history:
        df_diversity = pd.DataFrame(diversity_history)
        df_diversity.to_csv(f"{cfg.plastic_csv_dir}/functional_diversity.csv", index=False)

    analyzer.save_checkpoint()

    return model, analyzer, initial_ppl_a


def _apply_ensemble_knowledge(model):
    if len(model.soft_checkpoints) < 2:
        return

    perfs = torch.tensor([cp['performance'] for cp in model.soft_checkpoints], device=model.W.weight.device)
    weights = torch.softmax(perfs * 10, dim=0)

    poor_performers = model.competitive_score < model.competitive_score.median()
    high_plasticity = model.plasticity_levels > 0.7
    eligible = poor_performers & high_plasticity

    idxs = eligible.nonzero().squeeze()
    if idxs.numel() == 0:
        return

    for idx in idxs:
        if model.mask[idx] > 0:
            ensemble_weight = 0.0
            total_w = 0.0
            for w, cp in zip(weights, model.soft_checkpoints):
                if cp['mask'][idx] > 0:
                    ensemble_weight = ensemble_weight + (w * cp['weights'][idx].to(model.W.weight.device))
                    total_w = total_w + w
            if isinstance(ensemble_weight, torch.Tensor) and ensemble_weight.abs().sum() > 0:
                blend_factor = 0.2 * model.plasticity_levels[idx]
                model.W.weight.data[idx] = (1 - blend_factor) * model.W.weight.data[idx] + blend_factor * ensemble_weight


if __name__ == "__main__":
    plastic_model, plastic_analyzer, plastic_initial_ppl = train_plastic_lstm(cfg)

Usando dispositivo: cuda
Loading text for Phase A...


  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


Loaded 314.6M characters for Phase A
Vocabulary size: 4933
Plastic LSTM inicializado con 256 neuronas activas.

Época 1/6


Época 1:   0%|          | 0/34559 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(dtype=amp_dtype, enabled=use_amp):
  with torch.cuda.amp.autocast(dtype=amp_dtype, enabled=use_amp):
  with torch.cuda.amp.autocast(dtype=amp_dtype, enabled=use_amp):
  with torch.cuda.amp.autocast(dtype=amp_dtype, enabled=use_amp):
  with torch.cuda.amp.autocast(dtype=amp_dtype):



[Estadísticas del Sistema - Paso 2000]
  Plasticidad promedio: 1.000
  Distribución: Baja=0, Media=0, Alta=256
  Balance de plasticidad: 0.000

[Estadísticas del Sistema - Paso 4000]
  Plasticidad promedio: 1.000
  Distribución: Baja=0, Media=0, Alta=258
  Balance de plasticidad: 0.000

 FASE B: Inyectando nuevos datos en el paso 6000


  with torch.cuda.amp.autocast(dtype=amp_dtype, enabled=use_amp):


Model phase set to: B
!! REDUCIENDO TASA DE APRENDIZAJE PARA FASE B !!
Loading text for Phase B...
Loaded 62.9M characters for Phase B


Época 1 - Fase B:  87%|########6 | 5999/6911 [00:00<?, ?it/s]


[Métricas Empíricas - Paso 6000]
  Retención Fase A: 22.8%
  Interferencia B→A: 0.322
  Costo de oportunidad: 0.741
  Transferencia de conocimiento: 0.00%

[Estadísticas del Sistema - Paso 6000]
  Plasticidad promedio: 0.717
  Distribución: Baja=79, Media=0, Alta=186
  Balance de plasticidad: 0.000

[Métricas Empíricas - Paso 7000]
  Retención Fase A: 26.7%
  Interferencia B→A: 0.329
  Costo de oportunidad: 0.729
  Transferencia de conocimiento: 0.00%

[Métricas Empíricas - Paso 8000]
  Retención Fase A: 19.8%
  Interferencia B→A: 0.346
  Costo de oportunidad: 0.714
  Transferencia de conocimiento: 0.00%

[Estadísticas del Sistema - Paso 8000]
  Plasticidad promedio: 0.232
  Distribución: Baja=237, Media=0, Alta=56
  Balance de plasticidad: 0.922

[Métricas Empíricas - Paso 9000]
  Retención Fase A: 12.0%
  Interferencia B→A: 0.348
  Costo de oportunidad: 0.707
  Transferencia de conocimiento: 0.00%

[Métricas Empíricas - Paso 10000]
  Retención Fase A: 14.9%
  Interferencia B→A: 0.35

Época 2:   0%|          | 0/34559 [00:00<?, ?it/s]


[Métricas Empíricas - Paso 35000]
  Retención Fase A: 16.6%
  Interferencia B→A: 0.382
  Costo de oportunidad: 0.676
  Transferencia de conocimiento: 94.58%

[Métricas Empíricas - Paso 36000]
  Retención Fase A: 20.5%
  Interferencia B→A: 0.401
  Costo de oportunidad: 0.677
  Transferencia de conocimiento: 94.56%

[Estadísticas del Sistema - Paso 36000]
  Plasticidad promedio: 0.211
  Distribución: Baja=264, Media=0, Alta=67
  Balance de plasticidad: 0.405

[Métricas Empíricas - Paso 37000]
  Retención Fase A: 15.8%
  Interferencia B→A: 0.385
  Costo de oportunidad: 0.683
  Transferencia de conocimiento: 96.31%

[Métricas Empíricas - Paso 39000]
  Retención Fase A: 10.4%
  Interferencia B→A: 0.401
  Costo de oportunidad: 0.672
  Transferencia de conocimiento: 93.75%

[Métricas Empíricas - Paso 40000]
  Retención Fase A: 12.5%
  Interferencia B→A: 0.394
  Costo de oportunidad: 0.675
  Transferencia de conocimiento: 94.29%

[Estadísticas del Sistema - Paso 40000]
  Plasticidad promedio:

Época 3:   0%|          | 0/34559 [00:00<?, ?it/s]


[Métricas Empíricas - Paso 70000]
  Retención Fase A: 7.9%
  Interferencia B→A: 0.415
  Costo de oportunidad: 0.680
  Transferencia de conocimiento: 97.56%

[Estadísticas del Sistema - Paso 70000]
  Plasticidad promedio: 0.204
  Distribución: Baja=264, Media=0, Alta=64
  Balance de plasticidad: 0.390

[Métricas Empíricas - Paso 71000]
  Retención Fase A: 7.7%
  Interferencia B→A: 0.435
  Costo de oportunidad: 0.674
  Transferencia de conocimiento: 95.51%

[Métricas Empíricas - Paso 72000]
  Retención Fase A: 6.4%
  Interferencia B→A: 0.421
  Costo de oportunidad: 0.682
  Transferencia de conocimiento: 96.93%

[Estadísticas del Sistema - Paso 72000]
  Plasticidad promedio: 0.199
  Distribución: Baja=264, Media=0, Alta=62
  Balance de plasticidad: 0.380

[Métricas Empíricas - Paso 73000]
  Retención Fase A: 11.3%
  Interferencia B→A: 0.407
  Costo de oportunidad: 0.684
  Transferencia de conocimiento: 98.77%

[Métricas Empíricas - Paso 74000]
  Retención Fase A: 8.9%
  Interferencia B→A

Época 4:   0%|          | 0/34559 [00:00<?, ?it/s]


[Métricas Empíricas - Paso 104000]
  Retención Fase A: 4.5%
  Interferencia B→A: 0.458
  Costo de oportunidad: 0.676
  Transferencia de conocimiento: 96.39%

[Estadísticas del Sistema - Paso 104000]
  Plasticidad promedio: 0.213
  Distribución: Baja=264, Media=0, Alta=68
  Balance de plasticidad: 0.410

[Métricas Empíricas - Paso 105000]
  Retención Fase A: 8.4%
  Interferencia B→A: 0.451
  Costo de oportunidad: 0.680
  Transferencia de conocimiento: 97.26%

[Métricas Empíricas - Paso 107000]
  Retención Fase A: 7.1%
  Interferencia B→A: 0.454
  Costo de oportunidad: 0.675
  Transferencia de conocimiento: 95.50%

[Métricas Empíricas - Paso 108000]
  Retención Fase A: 5.1%
  Interferencia B→A: 0.452
  Costo de oportunidad: 0.679
  Transferencia de conocimiento: 96.66%

[Estadísticas del Sistema - Paso 108000]
  Plasticidad promedio: 0.206
  Distribución: Baja=264, Media=0, Alta=65
  Balance de plasticidad: 0.395

[Métricas Empíricas - Paso 109000]
  Retención Fase A: 5.2%
  Interferenc

Época 5:   0%|          | 0/34559 [00:00<?, ?it/s]


[Métricas Empíricas - Paso 139000]
  Retención Fase A: 3.5%
  Interferencia B→A: 0.486
  Costo de oportunidad: 0.676
  Transferencia de conocimiento: 95.78%

[Métricas Empíricas - Paso 140000]
  Retención Fase A: 3.6%
  Interferencia B→A: 0.484
  Costo de oportunidad: 0.686
  Transferencia de conocimiento: 96.89%

[Estadísticas del Sistema - Paso 140000]
  Plasticidad promedio: 0.189
  Distribución: Baja=264, Media=0, Alta=58
  Balance de plasticidad: 0.360

[Métricas Empíricas - Paso 141000]
  Retención Fase A: 7.7%
  Interferencia B→A: 0.482
  Costo de oportunidad: 0.676
  Transferencia de conocimiento: 94.28%

[Métricas Empíricas - Paso 142000]
  Retención Fase A: 5.4%
  Interferencia B→A: 0.488
  Costo de oportunidad: 0.681
  Transferencia de conocimiento: 96.94%

[Estadísticas del Sistema - Paso 142000]
  Plasticidad promedio: 0.201
  Distribución: Baja=264, Media=0, Alta=63
  Balance de plasticidad: 0.385

[Métricas Empíricas - Paso 144000]
  Retención Fase A: 5.5%
  Interferenc

Época 6:   0%|          | 0/34559 [00:00<?, ?it/s]


[Métricas Empíricas - Paso 173000]
  Retención Fase A: 1.7%
  Interferencia B→A: 0.489
  Costo de oportunidad: 0.684
  Transferencia de conocimiento: 98.46%

[Métricas Empíricas - Paso 174000]
  Retención Fase A: 1.9%
  Interferencia B→A: 0.499
  Costo de oportunidad: 0.671
  Transferencia de conocimiento: 94.96%

[Estadísticas del Sistema - Paso 174000]
  Plasticidad promedio: 0.225
  Distribución: Baja=264, Media=0, Alta=73
  Balance de plasticidad: 0.433

[Métricas Empíricas - Paso 175000]
  Retención Fase A: 3.2%
  Interferencia B→A: 0.503
  Costo de oportunidad: 0.678
  Transferencia de conocimiento: 96.36%

[Métricas Empíricas - Paso 176000]
  Retención Fase A: 1.9%
  Interferencia B→A: 0.509
  Costo de oportunidad: 0.680
  Transferencia de conocimiento: 96.65%

[Estadísticas del Sistema - Paso 176000]
  Plasticidad promedio: 0.204
  Distribución: Baja=264, Media=0, Alta=64
  Balance de plasticidad: 0.390

[Métricas Empíricas - Paso 177000]
  Retención Fase A: 1.6%
  Interferenc

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
from collections import defaultdict, Counter
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict, Tuple
import pandas as pd

def generate_text(model, prompt, stoi, itos, device, length=100, temperature=0.8, top_k=40):
    """Genera texto a partir de un prompt"""
    model.eval()
    tokens = [stoi.get(c, 0) for c in prompt]
    generated = list(prompt)
    
    with torch.no_grad():
        for _ in range(length):
            x = torch.tensor([tokens[-model.c.sequence_length:]], device=device)
            logits = model(x)
            
            # Aplicar temperatura
            logits = logits[0, -1, :] / temperature
            
            # Top-k sampling
            if top_k > 0:
                values, indices = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < values[-1]] = -float('inf')
            
            probs = F.softmax(logits, dim=-1)
            idx = torch.multinomial(probs, 1).item()
            
            char = itos.get(idx, ' ')
            generated.append(char)
            tokens.append(idx)
    
    return ''.join(generated)

def calculate_prediction_metrics(model, prompt, stoi, itos, device, actual_continuation=None):
    """Calcula métricas detalladas de predicción"""
    model.eval()
    metrics = {}
    
    tokens = [stoi.get(c, 0) for c in prompt]
    
    with torch.no_grad():
        x = torch.tensor([tokens[-model.c.sequence_length:]], device=device)
        logits, analysis = model(x, return_analysis=True)
        
        # Perplejidad del prompt
        if len(tokens) > 1:
            x_full = torch.tensor([tokens[:-1]], device=device)
            y_full = torch.tensor([tokens[1:]], device=device)
            logits_full = model(x_full)
            loss = F.cross_entropy(logits_full[0], y_full[0])
            metrics['prompt_perplexity'] = torch.exp(loss).item()
        
        # Entropía de la distribución de predicción
        probs = F.softmax(logits[0, -1, :], dim=-1)
        entropy = -torch.sum(probs * torch.log(probs + 1e-8))
        metrics['prediction_entropy'] = entropy.item()
        
        # Top-k probabilidades
        top_k_probs, top_k_indices = torch.topk(probs, 10)
        metrics['top_10_predictions'] = [
            (itos.get(idx.item(), '?'), prob.item()) 
            for idx, prob in zip(top_k_indices, top_k_probs)
        ]
        
        # Análisis de activaciones
        if analysis and 'activations' in analysis:
            acts = analysis['activations'][0]
            metrics['mean_activation'] = acts.mean().item()
            metrics['activation_sparsity'] = (acts == 0).float().mean().item()
            metrics['active_neurons'] = (acts > 0.1).sum().item()
            
            # Activaciones por fase
            phase_a_neurons = list(model.phase_a_neurons)
            phase_b_neurons = list(model.phase_b_neurons)
            
            if phase_a_neurons:
                phase_a_acts = acts[phase_a_neurons].mean().item()
            else:
                phase_a_acts = 0
                
            if phase_b_neurons:
                phase_b_acts = acts[phase_b_neurons].mean().item()
            else:
                phase_b_acts = 0
                
            metrics['phase_a_activation'] = phase_a_acts
            metrics['phase_b_activation'] = phase_b_acts
        
        # Si tenemos la continuación real, calcular accuracy
        if actual_continuation:
            next_token = stoi.get(actual_continuation[0], 0)
            predicted_token = torch.argmax(logits[0, -1, :]).item()
            metrics['exact_match'] = (predicted_token == next_token)
            
            # Rank del token correcto
            sorted_indices = torch.argsort(probs, descending=True)
            rank = (sorted_indices == next_token).nonzero().item() + 1
            metrics['correct_token_rank'] = rank
            metrics['correct_token_prob'] = probs[next_token].item()
    
    return metrics

def comprehensive_model_evaluation(model, cfg, stoi, itos, device):
    """Evaluación comprehensiva del modelo"""
    print("\n" + "="*80)
    print("EVALUACIÓN COMPREHENSIVA DEL MODELO PLASTICLSTM")
    print("="*80)
    
    results = defaultdict(list)
    
    # 1. Evaluar con prompts de diferentes fases y categorías
    test_cases = {
        'Phase A - Code': [
            ("def fibonacci(n):", " return fib_helper(n, {})"),
            ("import numpy as np", "\nimport matplotlib.pyplot as plt"),
            ("class NeuralNetwork:", "\n    def __init__(self):"),
        ],
        'Phase A - Narrative': [
            ("The quick brown fox", " jumps over the lazy dog"),
            ("Once upon a time", " in a land far away"),
            ("The scientist discovered", " a remarkable phenomenon"),
        ],
        'Phase B - Narrative': [
            ("In a hole in the ground", " there lived a hobbit"),
            ("The spaceship landed", " on the distant planet"),
            ("In the year 2050,", " technology had advanced"),
        ],
        'Cross-Phase': [
            ("Machine learning is", " a subset of artificial intelligence"),
            ("The algorithm works by", " processing data iteratively"),
            ("Neural networks can", " learn complex patterns"),
        ]
    }
    
    # 2. Evaluar cada caso
    for category, prompts in test_cases.items():
        print(f"\n\n{category}:")
        print("-" * 50)
        
        for prompt, expected in prompts:
            # Generar texto
            generated = generate_text(model, prompt, stoi, itos, device, 
                                    length=50, temperature=0.8)
            
            # Calcular métricas
            metrics = calculate_prediction_metrics(model, prompt, stoi, itos, 
                                                 device, expected[1:] if len(expected) > 1 else None)
            
            # Mostrar resultados
            print(f"\nPrompt: '{prompt}'")
            print(f"Generated: '{generated[len(prompt):len(prompt)+50]}'")
            print(f"Perplexity: {metrics.get('prompt_perplexity', 'N/A'):.2f}")
            print(f"Entropy: {metrics['prediction_entropy']:.3f}")
            print(f"Active Neurons: {metrics.get('active_neurons', 'N/A')}")
            print(f"Phase A activation: {metrics.get('phase_a_activation', 0):.3f}")
            print(f"Phase B activation: {metrics.get('phase_b_activation', 0):.3f}")
            
            # Top predictions
            print("Top 5 predictions:")
            for char, prob in metrics['top_10_predictions'][:5]:
                print(f"  '{char}': {prob:.3f}")
            
            # Guardar para análisis
            results[category].append({
                'prompt': prompt,
                'generated': generated[len(prompt):len(prompt)+30],
                **metrics
            })
    
    # 3. Análisis de diversidad y calidad
    print("\n\n" + "="*60)
    print("ANÁLISIS DE CALIDAD DE GENERACIÓN")
    print("="*60)
    
    # Evaluar repetición
    all_generated = []
    for category, items in results.items():
        for item in items:
            all_generated.append(item['generated'])
    
    # Calcular diversidad de n-gramas
    for n in [2, 3, 4]:
        ngrams = []
        for text in all_generated:
            for i in range(len(text) - n + 1):
                ngrams.append(text[i:i+n])
        
        unique_ratio = len(set(ngrams)) / max(1, len(ngrams))
        print(f"\n{n}-gram diversity: {unique_ratio:.3f}")
    
    # 4. Comparación de activación entre fases
    print("\n\n" + "="*60)
    print("ANÁLISIS DE ACTIVACIÓN POR FASE")
    print("="*60)
    
    phase_analysis = defaultdict(lambda: {'phase_a': [], 'phase_b': []})
    
    for category, items in results.items():
        for item in items:
            phase_analysis[category]['phase_a'].append(
                item.get('phase_a_activation', 0))
            phase_analysis[category]['phase_b'].append(
                item.get('phase_b_activation', 0))
    
    # Crear DataFrame para mejor visualización
    analysis_data = []
    for category, data in phase_analysis.items():
        analysis_data.append({
            'Category': category,
            'Phase A Avg': np.mean(data['phase_a']),
            'Phase B Avg': np.mean(data['phase_b']),
            'Ratio B/A': np.mean(data['phase_b']) / max(0.001, np.mean(data['phase_a']))
        })
    
    df_analysis = pd.DataFrame(analysis_data)
    print(df_analysis.to_string(index=False))
    
    # 5. Test de capacidad de continuación coherente
    print("\n\n" + "="*60)
    print("TEST DE COHERENCIA EN GENERACIÓN LARGA")
    print("="*60)
    
    long_prompt = "The future of artificial intelligence"
    long_generation = generate_text(model, long_prompt, stoi, itos, device, 
                                  length=200, temperature=0.8)
    
    print(f"Prompt: '{long_prompt}'")
    print(f"Generated text:\n{long_generation}")
    
    # Análizar coherencia mediante ventanas deslizantes
    window_perplexities = []
    tokens = [stoi.get(c, 0) for c in long_generation]
    
    for i in range(10, len(tokens) - 10, 5):
        window_tokens = tokens[i-10:i+10]
        x = torch.tensor([window_tokens[:-1]], device=device)
        y = torch.tensor([window_tokens[1:]], device=device)
        
        with torch.no_grad():
            logits = model(x)
            loss = F.cross_entropy(logits[0], y[0])
            window_perplexities.append(torch.exp(loss).item())
    
    print(f"\nCoherence Analysis:")
    print(f"Mean window perplexity: {np.mean(window_perplexities):.2f}")
    print(f"Std window perplexity: {np.std(window_perplexities):.2f}")
    print(f"Perplexity trend: {'Increasing' if window_perplexities[-5:] > window_perplexities[:5] else 'Stable/Decreasing'}")
    
    # 6. Resumen final
    print("\n\n" + "="*60)
    print("RESUMEN DE PERFORMANCE")
    print("="*60)
    
    all_perplexities = []
    all_entropies = []
    
    for category, items in results.items():
        perps = [item.get('prompt_perplexity', 0) for item in items if 'prompt_perplexity' in item]
        ents = [item['prediction_entropy'] for item in items]
        
        if perps:
            all_perplexities.extend(perps)
        all_entropies.extend(ents)
        
        print(f"\n{category}:")
        print(f"  Avg Perplexity: {np.mean(perps) if perps else 'N/A':.2f}")
        print(f"  Avg Entropy: {np.mean(ents):.3f}")
    
    print(f"\nOVERALL:")
    print(f"  Global Avg Perplexity: {np.mean(all_perplexities):.2f}")
    print(f"  Global Avg Entropy: {np.mean(all_entropies):.3f}")
    print(f"  Active Neurons (avg): {int(model.mask.sum().item())}")
    print(f"  Parameter Efficiency: {model.calculate_efficiency_metrics()['param_efficiency']*100:.1f}%")
    
    return results

# Función principal para llamar
def evaluate_model_performance(model, cfg, stoi, itos, device):
    """Función principal para evaluar el modelo"""
    
    # Ejecutar evaluación comprehensiva
    results = comprehensive_model_evaluation(model, cfg, stoi, itos, device)
    
    # Crear visualizaciones
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    
    # 1. Perplexity por categoría
    ax = axes[0, 0]
    perp_data = defaultdict(list)
    for cat, items in results.items():
        for item in items:
            if 'prompt_perplexity' in item:
                perp_data[cat].append(item['prompt_perplexity'])
    
    categories = list(perp_data.keys())
    perplexities = [np.mean(perp_data[cat]) for cat in categories]
    ax.bar(categories, perplexities, color=['blue', 'green', 'red', 'purple'])
    ax.set_title('Average Perplexity by Category')
    ax.set_ylabel('Perplexity')
    ax.tick_params(axis='x', rotation=45)
    
    # 2. Activación por fase
    ax = axes[0, 1]
    phase_data = {'Phase A': [], 'Phase B': []}
    for cat, items in results.items():
        for item in items:
            phase_data['Phase A'].append(item.get('phase_a_activation', 0))
            phase_data['Phase B'].append(item.get('phase_b_activation', 0))
    
    ax.boxplot([phase_data['Phase A'], phase_data['Phase B']], 
               labels=['Phase A', 'Phase B'])
    ax.set_title('Neuron Activation Distribution by Phase')
    ax.set_ylabel('Mean Activation')
    
    # 3. Entropía de predicción
    ax = axes[1, 0]
    entropy_by_cat = defaultdict(list)
    for cat, items in results.items():
        for item in items:
            entropy_by_cat[cat].append(item['prediction_entropy'])
    
    for i, (cat, entropies) in enumerate(entropy_by_cat.items()):
        ax.scatter([i]*len(entropies), entropies, label=cat, alpha=0.6)
    
    ax.set_xticks(range(len(entropy_by_cat)))
    ax.set_xticklabels(list(entropy_by_cat.keys()), rotation=45)
    ax.set_title('Prediction Entropy by Category')
    ax.set_ylabel('Entropy')
    ax.legend()
    
    # 4. Distribución de neuronas activas
    ax = axes[1, 1]
    active_neurons_data = []
    for cat, items in results.items():
        for item in items:
            if 'active_neurons' in item:
                active_neurons_data.append(item['active_neurons'])
    
    ax.hist(active_neurons_data, bins=20, alpha=0.7, color='orange')
    ax.set_title('Distribution of Active Neurons')
    ax.set_xlabel('Number of Active Neurons')
    ax.set_ylabel('Frequency')
    
    plt.tight_layout()
    plt.show()
    
    return results

# Ejemplo de uso:
if __name__ == "__main__":
    # Asumiendo que tienes el modelo, cfg, stoi, itos y device
    # results = evaluate_model_performance(plastic_model, cfg, stoi, itos, device)
    pass