In [1]:
import warnings
warnings.filterwarnings('ignore')

!pip uninstall -y transformers tokenizers -q
!pip install transformers==4.35.0 -q
!pip install torch datasets sentencepiece scipy rouge-score sentence-transformers nltk -q

# restart the runtime after this cell and skip it

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m123.1/123.1 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.9/7.9 MB[0m [31m116.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.8/3.8 MB[0m [31m104.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m295.0/295.0 kB[0m [31m18.9 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
peft 0.18.0 requires huggingface_hub>=0.25.0, but you have huggingface-hub 0.17.3 which is incompatible.
accelerate 1.12.0 requires huggingface_hub>=0.21.0, but you have huggingface-hub 0.17.3 which is incompatible.
sentence-transformers 5.1.2 requires huggingface-hub>=0.20.0, but you have huggingface-hub 0.17.3 which is incompatible.

In [1]:
import logging
import random
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from datasets import load_dataset
from rouge_score import rouge_scorer
from sentence_transformers import SentenceTransformer, util
from sklearn.decomposition import PCA
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer, GPT2Config, GPT2LMHeadModel

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")


Device: cuda


In [2]:
def set_seed(seed=42):
    """Sets the seed for reproducibility across random, numpy, and torch operations."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

SEED = 42
set_seed(SEED)


In [3]:
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)

log_filename = f"training_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    handlers=[
        logging.FileHandler(log_filename, mode='w'),
        logging.StreamHandler()
    ],
    force=True
)

logger = logging.getLogger(__name__)
logger.info(f"Log file created: {log_filename}")
logger.info(f"Device: {device}")

# Initialize ROUGE scorer with 'rouge1' and 'rougeL' metrics
rouge_scorer_obj = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
# Load SentenceTransformer model for semantic similarity calculation
sbert_model = SentenceTransformer('all-MiniLM-L6-v2').to(device)

nltk.download('punkt', quiet=True)

logger.info("Metrics initialized: ROUGE, Sentence-BERT, and BLEU")


2025-12-13 09:36:08,418 - Log file created: training_log_20251213_093608.txt
2025-12-13 09:36:08,418 - Device: cuda
2025-12-13 09:36:08,419 - Using default tokenizer.
2025-12-13 09:36:08,424 - Use pytorch device_name: cuda:0
2025-12-13 09:36:08,424 - Load pretrained SentenceTransformer: all-MiniLM-L6-v2
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

2025-12-13 09:36:12,711 - Metrics initialized: ROUGE, Sentence-BERT, and BLEU


In [4]:
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """
    Creates a schedule with a learning rate that increases linearly from 0 to the
    initial learning rate set in the optimizer for `num_warmup_steps`,
    then decreases linearly from the initial learning rate to 0 over the remainder
    of the `num_training_steps`.

    Args:
        optimizer (`torch.optim.Optimizer`):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.

    Returns:
        `torch.optim.lr_scheduler.LambdaLR`:
            A learning rate scheduler.
    """
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
    return LambdaLR(optimizer, lr_lambda)


In [5]:
class EmbedParrot(nn.Module):
    """
    A model that reconstructs initial embeddings from deep embeddings produced by a transformer model.
    It consists of an input adapter, a decoder (either a causal LM or a standard transformer),
    and an output adapter.
    """
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        decoder_name: str = "gpt2",
        use_causal_lm: bool = False,
        dropout_p: float = 0.0,
        init_from_scratch: bool = True
    ):
        """
        Initializes the EmbedParrot model.

        Args:
            input_dim (int): The dimension of the input deep embeddings.
            output_dim (int): The dimension of the target initial embeddings.
            decoder_name (str, optional): The name of the decoder model to use (e.g., 'gpt2'). Defaults to "gpt2".
            use_causal_lm (bool, optional): Whether to use a causal language model as the decoder. Defaults to False.
            dropout_p (float, optional): The dropout probability for the adapter layers. Defaults to 0.0.
            init_from_scratch (bool, optional): Whether to initialize the decoder from scratch or use pretrained weights. Defaults to True.
        """
        super().__init__()

        cfg = AutoConfig.from_pretrained(decoder_name)
        cfg.use_cache = False
        cfg.output_hidden_states = True

        if init_from_scratch:
            # Random initialization
            if use_causal_lm:
                self.decoder = AutoModelForCausalLM.from_config(cfg)
            else:
                self.decoder = AutoModel.from_config(cfg)
        else:
            # Pretrained weights
            if use_causal_lm:
                self.decoder = AutoModelForCausalLM.from_pretrained(decoder_name, config=cfg)
            else:
                self.decoder = AutoModel.from_pretrained(decoder_name, config=cfg)

        dec_hidden = self.decoder.config.hidden_size

        self.input_adapter = nn.Sequential(
            nn.Linear(input_dim, dec_hidden),
            nn.Dropout(dropout_p)
        )
        self.output_adapter = nn.Sequential(
            nn.Dropout(dropout_p),
            nn.Linear(dec_hidden, output_dim)
        )

        self.use_causal_lm = use_causal_lm

    def forward(self, hidden_states: torch.Tensor):
        """
        Performs a forward pass through the EmbedParrot model.

        Args:
            hidden_states (torch.Tensor): The input deep embeddings.

        Returns:
            torch.Tensor: The reconstructed initial embeddings.
        """
        x = self.input_adapter(hidden_states)

        outputs = self.decoder(
            inputs_embeds=x,
            return_dict=True,
            output_hidden_states=True,
            use_cache=False
        )

        last_h = getattr(outputs, "last_hidden_state", None)
        if last_h is None:
            last_h = outputs.hidden_states[-1]

        reconstructed = self.output_adapter(last_h)
        return reconstructed


In [6]:
class EmbeddingDataset(Dataset):
    """
    A PyTorch Dataset for extracting and preparing embeddings from text data.
    It uses a target transformer model to get initial and deep embeddings,
    and optionally applies PCA for dimensionality reduction.
    """
    def __init__(
        self,
        texts,
        target_model_name="gpt2",
        target_layer=11,
        max_length=128,
        pca_components=None,
        val_split=0.2,
    ):
        """
        Initializes the EmbeddingDataset.

        Args:
            texts (list): A list of text strings to process.
            target_model_name (str, optional): The name of the transformer model to use for embedding extraction. Defaults to "gpt2".
            target_layer (int, optional): The specific layer from which to extract deep embeddings. Defaults to 11.
            max_length (int, optional): The maximum sequence length for tokenization. Defaults to 128.
            pca_components (int, optional): Number of PCA components to reduce deep embeddings to. If None, PCA is not applied. Defaults to None.
            val_split (float, optional): The fraction of data to be used for validation. Defaults to 0.2.
        """
        self.target_layer = target_layer
        self.max_length = max_length
        self.pca_components = pca_components
        self.val_split = val_split

        indices = np.random.permutation(len(texts))
        val_size = int(len(texts) * val_split)

        self.train_indices = indices[val_size:]
        self.val_indices = indices[:val_size]

        self.train_texts = [texts[i] for i in self.train_indices]
        self.val_texts = [texts[i] for i in self.val_indices]

        print(f"Train samples: {len(self.train_texts)}")
        print(f"Val samples: {len(self.val_texts)}")

        print(f"Loading model: {target_model_name}")
        self.tokenizer = AutoTokenizer.from_pretrained(target_model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModel.from_pretrained(
            target_model_name,
            output_hidden_states=True
        )

        if torch.cuda.is_available():
            self.model = self.model.half().cuda()
        self.model.eval()

        all_texts = self.train_texts + self.val_texts

        print("Extracting embeddings...")
        all_deep_embeddings = []
        all_initial_embeddings = []
        all_attention_masks = []

        for i in tqdm(range(len(all_texts))):
            result = self._extract_embeddings(all_texts[i])
            all_deep_embeddings.append(result['deep_embeddings'])
            all_initial_embeddings.append(result['initial_embeddings'])
            all_attention_masks.append(result['attention_mask'])

        if self.pca_components is not None:
            print(f"Applying PCA to reduce embeddings to {pca_components} components...")

            train_deep_embeddings = all_deep_embeddings[:len(self.train_texts)]

            stacked_train_deep = torch.cat(train_deep_embeddings, dim=0)

            n_components = min(pca_components, stacked_train_deep.shape[0], stacked_train_deep.shape[1])
            pca_deep = PCA(n_components=n_components)
            pca_deep.fit(stacked_train_deep.numpy())

            print(f"PCA fitted on training data only")
            print(f"Explained variance ratio: {pca_deep.explained_variance_ratio_.sum():.4f}")

            stacked_all_deep = torch.cat(all_deep_embeddings, dim=0)
            stacked_deep_pca = pca_deep.transform(stacked_all_deep.numpy())

            print(f"Deep embeddings shape after PCA: {stacked_deep_pca.shape}")

            all_deep_embeddings = []
            idx = 0
            for i in range(len(all_texts)):
                orig_seq_len = all_attention_masks[i].shape[0]
                deep_emb = torch.from_numpy(stacked_deep_pca[idx:idx+orig_seq_len]).float()
                all_deep_embeddings.append(deep_emb)
                idx += orig_seq_len

        self.train_cache = []
        self.val_cache = []

        for i in range(len(self.train_texts)):
            self.train_cache.append({
                'deep_embeddings': all_deep_embeddings[i],
                'initial_embeddings': all_initial_embeddings[i],
                'attention_mask': all_attention_masks[i],
                'text': self.train_texts[i]
            })

        for i in range(len(self.val_texts)):
            idx = len(self.train_texts) + i
            self.val_cache.append({
                'deep_embeddings': all_deep_embeddings[idx],
                'initial_embeddings': all_initial_embeddings[idx],
                'attention_mask': all_attention_masks[idx],
                'text': self.val_texts[i]
            })

    def _extract_embeddings(self, text):
        """
        Extracts initial and deep embeddings for a given text using the target model.

        Args:
            text (str): The input text.

        Returns:
            dict: A dictionary containing 'deep_embeddings', 'initial_embeddings', and 'attention_mask'.
        """
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        input_ids = encoding['input_ids'].to(self.model.device)
        attention_mask = encoding['attention_mask'].to(self.model.device)

        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
            initial_embeddings = outputs.hidden_states[0].squeeze(0).cpu().float()
            deep_embeddings = outputs.hidden_states[self.target_layer].squeeze(0).cpu().float()

        return {
            'deep_embeddings': deep_embeddings,
            'initial_embeddings': initial_embeddings,
            'attention_mask': attention_mask.squeeze(0).cpu()
        }

    def get_train_dataset(self):
        """
        Returns a DatasetView object for the training data.
        """
        return DatasetView(self.train_cache)

    def get_val_dataset(self):
        """
        Returns a DatasetView object for the validation data.
        """
        return DatasetView(self.val_cache)


class DatasetView(Dataset):
    """
    A lightweight wrapper Dataset to provide a view of cached data for training/validation.
    """
    def __init__(self, cache):
        """
        Initializes the DatasetView with a pre-cached list of data samples.

        Args:
            cache (list): A list of dictionaries, where each dictionary represents a data sample.
        """
        self.cache = cache

    def __len__(self):
        """
        Returns the number of samples in the dataset.
        """
        return len(self.cache)

    def __getitem__(self, idx):
        """
        Retrieves a single data sample by index.

        Args:
            idx (int): The index of the desired sample.

        Returns:
            dict: The data sample at the specified index.
        """
        return self.cache[idx]


In [7]:
def cosine_similarity_loss(pred, target, mask=None):
    """
    Calculates the cosine similarity loss between predicted and target embeddings.

    Args:
        pred (torch.Tensor): The predicted embeddings.
        target (torch.Tensor): The target embeddings.
        mask (torch.Tensor, optional): An attention mask to apply to the loss calculation. Defaults to None.

    Returns:
        torch.Tensor: The calculated cosine similarity loss.
    """
    pred_norm = nn.functional.normalize(pred, p=2, dim=-1)
    target_norm = nn.functional.normalize(target, p=2, dim=-1)
    cos_sim = (pred_norm * target_norm).sum(dim=-1)

    if mask is not None:
        cos_sim = cos_sim * mask
        loss = -cos_sim.sum() / (mask.sum() + 1e-8)
    else:
        loss = -cos_sim.mean()

    return loss

In [8]:
def train_embed_parrot(
    model,
    train_dataloader,
    val_dataloader=None,
    num_epochs=3,
    learning_rate=1e-4,
    warmup_ratio=0.1,
    save_path='embed_parrot_best.pt',
    logger=None,
    rouge_scorer_obj=None,
    sbert_model=None,
    target_model=None,
    tokenizer=None
):
    """
    Trains the EmbedParrot model to reconstruct initial embeddings from deep embeddings.

    Args:
        model (nn.Module): The EmbedParrot model to be trained.
        train_dataloader (DataLoader): DataLoader for the training dataset.
        val_dataloader (DataLoader, optional): DataLoader for the validation dataset. Defaults to None.
        num_epochs (int, optional): Number of training epochs. Defaults to 3.
        learning_rate (float, optional): Initial learning rate for the optimizer. Defaults to 1e-4.
        warmup_ratio (float, optional): Ratio of total steps for linear warmup. Defaults to 0.1.
        save_path (str, optional): Path to save the best model checkpoint. Defaults to 'embed_parrot_best.pt'.
        logger (Logger, optional): Logger object for logging training progress. Defaults to None.
        rouge_scorer_obj (rouge_scorer.RougeScorer, optional): ROUGE scorer for evaluation. Defaults to None.
        sbert_model (SentenceTransformer, optional): Sentence-BERT model for semantic similarity evaluation. Defaults to None.
        target_model (AutoModel, optional): The target transformer model used for embedding extraction. Defaults to None.
        tokenizer (AutoTokenizer, optional): The tokenizer corresponding to the target model. Defaults to None.

    Returns:
        nn.Module: The trained EmbedParrot model.
    """
    set_seed(SEED)

    if logger is None:
        logger = logging.getLogger(__name__)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)

    total_steps = len(train_dataloader) * num_epochs
    warmup_steps = int(warmup_ratio * total_steps)
    scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_steps)

    logger.info(f"Training: {num_epochs} epochs, {total_steps} steps")

    # Initialize BLEU smoothing function
    bleu_smoothing = SmoothingFunction().method1

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        train_cosine_sim = 0

        total_batches = len(train_dataloader)

        print(f"Epoch {epoch+1}/{num_epochs} - {total_batches} batches")

        for batch_idx, batch in enumerate(train_dataloader):
            deep_emb = batch['deep_embeddings'].to(device)
            initial_emb = batch['initial_embeddings'].to(device)
            mask = batch['attention_mask'].to(device).float()

            pred_emb = model(deep_emb)
            loss = cosine_similarity_loss(pred_emb, initial_emb, mask)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

        if val_dataloader and epoch == num_epochs - 1:
            model.eval()
            val_loss = 0
            val_cosine_sim = 0
            val_rouge1_scores = []
            val_rougeL_scores = []
            val_sbert_scores = []
            val_bleu_scores = []

            samples_for_text_metrics = 0
            max_text_samples = 50

            with torch.no_grad():
                for batch_idx, batch in enumerate(val_dataloader):
                    deep_emb = batch['deep_embeddings'].to(device)
                    initial_emb = batch['initial_embeddings'].to(device)
                    mask = batch['attention_mask'].to(device).float()
                    texts = batch['text']

                    pred_emb = model(deep_emb)
                    loss = cosine_similarity_loss(pred_emb, initial_emb, mask)
                    val_loss += loss.item()

                    if rouge_scorer_obj and sbert_model and target_model and tokenizer:
                        if samples_for_text_metrics < max_text_samples:
                            for i in range(min(len(texts), 5, max_text_samples - samples_for_text_metrics)):
                                original_text = texts[i]

                                # Get reconstructed text
                                single_deep_emb = deep_emb[i:i+1]
                                single_mask = mask[i:i+1]
                                reconstructed_texts = embed_parrot_reconstruction(
                                    single_deep_emb,
                                    model,
                                    target_model,
                                    tokenizer,
                                    attention_mask=single_mask
                                )
                                reconstructed_text = reconstructed_texts[0]

                                # ROUGE scores
                                rouge_scores = rouge_scorer_obj.score(original_text, reconstructed_text)
                                val_rouge1_scores.append(rouge_scores['rouge1'].fmeasure)
                                val_rougeL_scores.append(rouge_scores['rougeL'].fmeasure)

                                # Sentence-BERT similarity
                                emb1 = sbert_model.encode(original_text, convert_to_tensor=True)
                                emb2 = sbert_model.encode(reconstructed_text, convert_to_tensor=True)
                                sbert_sim = util.cos_sim(emb1, emb2).item()
                                val_sbert_scores.append(sbert_sim)

                                # BLEU score
                                reference = [original_text.split()]
                                candidate = reconstructed_text.split()
                                bleu = sentence_bleu(reference, candidate, smoothing_function=bleu_smoothing)
                                val_bleu_scores.append(bleu)

                                samples_for_text_metrics += 1

                                if i == 0:
                                    logger.info(f"Original text: {original_text}")
                                    logger.info(f"Reconstructed text: {reconstructed_text}")

            avg_val_loss = val_loss / len(val_dataloader)
            avg_rouge1 = np.mean(val_rouge1_scores) * 100 if val_rouge1_scores else 0
            avg_rougeL = np.mean(val_rougeL_scores) * 100 if val_rougeL_scores else 0
            avg_sbert = np.mean(val_sbert_scores) * 100 if val_sbert_scores else 0
            avg_bleu = np.mean(val_bleu_scores) * 100 if val_bleu_scores else 0

            logger.info(f'Epoch {epoch+1} - Val Loss: {avg_val_loss:.4f}')
            if val_rouge1_scores:
                logger.info(f'Epoch {epoch+1} - Val ROUGE-1: {avg_rouge1:.2f} | Val ROUGE-L: {avg_rougeL:.2f}')
                logger.info(f'Epoch {epoch+1} - Val Sentence-BERT: {avg_sbert:.2f} | Val BLEU: {avg_bleu:.2f}')

            torch.save(model.state_dict(), save_path)
            logger.info(f'✓ Saved best model!')

    return model

In [9]:
def hotmap_embed_inversion(hidden_states, model, tokenizer, attention_mask=None):
    """
    Performs a 'hotmap' embedding inversion to reconstruct text from embeddings.
    This function maps reconstructed embeddings back to the vocabulary to generate text.

    Args:
        hidden_states (torch.Tensor): The reconstructed initial embeddings.
        model (nn.Module): The target transformer model (e.g., GPT2) to get input embeddings.
        tokenizer (AutoTokenizer): The tokenizer for decoding token IDs back to text.
        attention_mask (torch.Tensor, optional): An attention mask to consider valid tokens. Defaults to None.

    Returns:
        list: A list of reconstructed text strings.
    """
    with torch.no_grad():
        embedding_matrix = model.get_input_embeddings().weight

        if embedding_matrix.dtype != hidden_states.dtype:
            embedding_matrix = embedding_matrix.to(hidden_states.dtype)

        hidden_norm = nn.functional.normalize(hidden_states, p=2, dim=-1)
        embed_norm = nn.functional.normalize(embedding_matrix, p=2, dim=-1)
        similarity = torch.matmul(hidden_norm, embed_norm.T)
        predicted_ids = torch.argmax(similarity, dim=-1)

        texts = []
        for i, ids in enumerate(predicted_ids):
            if attention_mask is not None:
                mask = attention_mask[i]
                valid_ids = ids[mask == 1]
            else:
                valid_ids = ids
            texts.append(tokenizer.decode(valid_ids, skip_special_tokens=True))
    return texts

def embed_parrot_reconstruction(deep_embeddings, embed_parrot_model, target_model, tokenizer, attention_mask=None):
    """
    Reconstructs text from deep embeddings using the EmbedParrot model and a target model's tokenizer.

    Args:
        deep_embeddings (torch.Tensor): The deep embeddings to reconstruct from.
        embed_parrot_model (EmbedParrot): The trained EmbedParrot model.
        target_model (AutoModel): The target transformer model (e.g., GPT2) for embedding inversion.
        tokenizer (AutoTokenizer): The tokenizer corresponding to the target model.
        attention_mask (torch.Tensor, optional): An attention mask to consider valid tokens. Defaults to None.

    Returns:
        list: A list of reconstructed text strings.
    """
    embed_parrot_model.eval()
    with torch.no_grad():
        reconstructed_emb = embed_parrot_model(deep_embeddings)
        texts = hotmap_embed_inversion(reconstructed_emb, target_model, tokenizer, attention_mask)
    return texts


In [10]:
def evaluate_reconstruction(original_texts, reconstructed_texts):
    """
    Evaluates the quality of reconstructed texts against original texts using ROUGE and Sentence-BERT scores.

    Args:
        original_texts (list): A list of original text strings.
        reconstructed_texts (list): A list of reconstructed text strings.

    Returns:
        dict: A dictionary containing the average 'ROUGE-1', 'ROUGE-L', and 'Sentence-BERT' scores.
    """
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
    sbert_model = SentenceTransformer('all-MiniLM-L6-v2')

    rouge1, rougeL, sbert = [], [], []

    for orig, recon in zip(original_texts, reconstructed_texts):
        scores = scorer.score(orig, recon)
        rouge1.append(scores['rouge1'].fmeasure)
        rougeL.append(scores['rougeL'].fmeasure)

        emb1 = sbert_model.encode(orig, convert_to_tensor=True)
        emb2 = sbert_model.encode(recon, convert_to_tensor=True)
        sbert.append(util.cos_sim(emb1, emb2).item())

    return {
        'ROUGE-1': np.mean(rouge1) * 100,
        'ROUGE-L': np.mean(rougeL) * 100,
        'Sentence-BERT': np.mean(sbert) * 100,
    }


In [11]:
dataset = load_dataset("sentence-transformers/coco-captions")

demo_texts = dataset['train'].shuffle(seed=SEED)['caption1'][:5000]

README.md: 0.00B [00:00, ?B/s]

pair/train-00000-of-00001.parquet:   0%|          | 0.00/23.9M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/414010 [00:00<?, ? examples/s]

In [12]:
pca_component_nums = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 768]
layers = [0, 3, 6, 9, 11]
vals = {}

# Loop through different target layers and PCA component numbers to evaluate model performance.
for target_layer_num in layers:
    for pca_component_num in pca_component_nums:

        logger.info(f"\n{'='*80}")
        logger.info(f"Configuration: Layer={target_layer_num}, PCA={pca_component_num}")
        logger.info(f"{'='*80}")
        logger.info(f"Demo dataset: {len(demo_texts)} samples")
        logger.info("Creating dataset (extracting embeddings)...")

        # Create an EmbeddingDataset for the current configuration.
        demo_dataset = EmbeddingDataset(
            demo_texts,
            target_model_name="gpt2",
            target_layer=target_layer_num,
            pca_components=pca_component_num,
            max_length=64,
            val_split=0.2
        )

        # Get training and validation datasets.
        train_dataset = demo_dataset.get_train_dataset()
        val_dataset = demo_dataset.get_val_dataset()

        logger.info(f"✓ Train samples: {len(train_dataset)}")
        logger.info(f"✓ Val samples: {len(val_dataset)}")

        logger.info("Initializing Embed Parrot...")
        torch.manual_seed(SEED)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(SEED)

        # Initialize the EmbedParrot model with the specified input/output dimensions.
        embed_parrot = EmbedParrot(
            input_dim=pca_component_num,
            output_dim=768,
            decoder_name="gpt2",
            use_causal_lm=False
        ).to(device)
        torch_generator = torch.Generator()
        torch_generator.manual_seed(SEED)

        # Create DataLoaders for training and validation.
        train_loader = DataLoader(
            train_dataset,
            batch_size=32,
            shuffle=True,
            generator=torch_generator
        )
        val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

        vals[pca_component_num] = val_loader

        logger.info(f"✓ Train batches: {len(train_loader)}")
        logger.info(f"✓ Val batches: {len(val_loader)}")
        logger.info(f"Starting training...")

        # Train the EmbedParrot model.
        trained_model = train_embed_parrot(
            embed_parrot,
            train_loader,
            val_loader,
            num_epochs=10,
            learning_rate=1e-4,
            logger=logger,
            rouge_scorer_obj=rouge_scorer_obj,
            sbert_model=sbert_model,
            target_model=demo_dataset.model,
            tokenizer=demo_dataset.tokenizer,
            save_path=f"embed_parrot_layer{target_layer_num}_pca{pca_component_num}.pt"
        )

logger.info(f"\n{'='*80}")
logger.info(f"Training complete! All logs saved to {log_filename}")

2025-12-13 09:36:16,996 - 
2025-12-13 09:36:16,997 - Configuration: Layer=0, PCA=1
2025-12-13 09:36:16,998 - Demo dataset: 5000 samples
2025-12-13 09:36:16,999 - Creating dataset (extracting embeddings)...


Train samples: 4000
Val samples: 1000
Loading model: gpt2


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Extracting embeddings...


100%|██████████| 5000/5000 [01:24<00:00, 59.34it/s]


Applying PCA to reduce embeddings to 1 components...
PCA fitted on training data only
Explained variance ratio: 0.4529


2025-12-13 09:37:47,145 - ✓ Train samples: 4000
2025-12-13 09:37:47,148 - ✓ Val samples: 1000
2025-12-13 09:37:47,149 - Initializing Embed Parrot...


Deep embeddings shape after PCA: (320000, 1)


2025-12-13 09:37:49,338 - ✓ Train batches: 125
2025-12-13 09:37:49,338 - ✓ Val batches: 32
2025-12-13 09:37:49,339 - Starting training...
2025-12-13 09:37:49,342 - Training: 10 epochs, 1250 steps


Epoch 1/10 - 125 batches
Epoch 2/10 - 125 batches
Epoch 3/10 - 125 batches


KeyboardInterrupt: 