In [1]:
!pip install rouge_score datasets transformers wandb nltk psutil  -q

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from transformers import Trainer, TrainerCallback
from datasets import load_dataset, Dataset
import os
import gc
import numpy as np
import random
import time
import datetime
import json
from rouge_score import rouge_scorer
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import nltk
import matplotlib.pyplot as plt
import wandb
import logging
import psutil
from torch.utils.data import DataLoader

In [3]:
nltk.download('punkt')

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("ether_training.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [4]:



# Memory utilities
def clear_memory():
    """Clear unused memory."""
    gc.collect()
    torch.cuda.empty_cache()

def get_memory_stats():
    """Get detailed memory usage statistics."""
    stats = {}

    # GPU memory
    if torch.cuda.is_available():
        stats["gpu_allocated_gb"] = torch.cuda.memory_allocated() / (1024**3)
        stats["gpu_reserved_gb"] = torch.cuda.memory_reserved() / (1024**3)
        stats["gpu_max_allocated_gb"] = torch.cuda.max_memory_allocated() / (1024**3)
        stats["gpu_max_reserved_gb"] = torch.cuda.max_memory_reserved() / (1024**3)

        # Reset max memory stats for future tracking
        torch.cuda.reset_peak_memory_stats()

    # CPU memory
    mem = psutil.virtual_memory()
    stats["cpu_used_gb"] = (mem.total - mem.available) / (1024**3)
    stats["cpu_total_gb"] = mem.total / (1024**3)
    stats["cpu_percent"] = mem.percent

    return stats

def print_gpu_memory():
    """Print GPU memory usage."""
    if torch.cuda.is_available():
        print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
        print(f"GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")



In [5]:
class COMPUTEMETRICS():
    """Class to calculate ROUGE and BLEU scores."""
    def __init__(self):
        self.rouge_scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
        self.smooth = SmoothingFunction().method1

    def calculate_rouge(self, prediction, reference):
        """Calculate ROUGE scores for a prediction against a reference."""
        # Handle empty predictions/references
        if not prediction or not reference:
            return {
                "rouge1": 0.0,
                "rouge2": 0.0,
                "rougeL": 0.0
            }

        try:
            scores = self.rouge_scorer.score(prediction, reference)
            return {
                "rouge1": scores["rouge1"].fmeasure,
                "rouge2": scores["rouge2"].fmeasure,
                "rougeL": scores["rougeL"].fmeasure
            }
        except Exception as e:
            logging.error(f"Error calculating ROUGE score: {e}")
            return {
                "rouge1": 0.0,
                "rouge2": 0.0,
                "rougeL": 0.0
            }

    def calculate_bleu(self, prediction, reference):
        """Calculate BLEU score for a prediction against a reference."""
        # Handle empty predictions/references
        if not prediction or not reference:
            return 0.0

        try:
            prediction_tokens = nltk.word_tokenize(prediction.lower())
            reference_tokens = [nltk.word_tokenize(reference.lower())]

            # Handle empty token sequences
            if not prediction_tokens or not reference_tokens[0]:
                return 0.0

            return sentence_bleu(reference_tokens, prediction_tokens, smoothing_function=self.smooth)
        except Exception as e:
            logging.error(f"Error calculating BLEU score: {e}")
            return 0.0

    def calculate_metrics(self, predictions, references):
        """Calculate all metrics for a batch of predictions and references."""
        if not predictions or not references:
            return {
                "rouge1": 0.0,
                "rouge2": 0.0,
                "rougeL": 0.0,
                "bleu": 0.0
            }

        metrics = {
            "rouge1": [],
            "rouge2": [],
            "rougeL": [],
            "bleu": []
        }

        for pred, ref in zip(predictions, references):
            # Calculate ROUGE
            rouge_scores = self.calculate_rouge(pred, ref)
            for key in ["rouge1", "rouge2", "rougeL"]:
                metrics[key].append(rouge_scores[key])

            # Calculate BLEU
            metrics["bleu"].append(self.calculate_bleu(pred, ref))

        # Average the metrics and handle empty lists
        return {key: sum(values) / len(values) if values else 0 for key, values in metrics.items()}

In [6]:


# ETHER Layer Implementations
class HouseholderTransform(nn.Module):
    """
    Implementation of the Householder transformation for ETHER.
    H = I - 2uu^T where u is a unit vector.
    """
    def __init__(self, dim, n_blocks=1):
        super().__init__()
        self.dim = dim
        self.n_blocks = n_blocks

        # If using blocks, divide dimension
        self.block_size = dim // n_blocks
        if self.block_size == 0:
            self.block_size = dim
            self.n_blocks = 1

        # Initialize unit normal vectors for each block
        if self.n_blocks > 1:
            self.u_vectors = nn.ParameterList([
                nn.Parameter(torch.zeros(self.block_size))
                for _ in range(self.n_blocks)
            ])
            # Initialize with small values
            for u in self.u_vectors:
                nn.init.normal_(u, mean=0, std=0.01)
        else:
            # Single vector for the entire dimension
            self.u = nn.Parameter(torch.zeros(dim))
            nn.init.normal_(self.u, mean=0, std=0.01)

    def forward(self, x):
        """
        Apply Householder transformation: H = I - 2uu^T

        """
        # Process input in blocks if using multiple blocks
        if self.n_blocks > 1:
            results = []

            for i, u_param in enumerate(self.u_vectors):
                # Get the block from input
                start_idx = i * self.block_size
                end_idx = min(start_idx + self.block_size, self.dim)
                if start_idx >= x.shape[-1]:
                    continue
                x_block = x[..., start_idx:end_idx]

                # Normalize u to ensure it's a unit vector
                u = F.normalize(u_param, p=2, dim=0)

                # Compute the Householder transformation: x - 2u(u^T x)
                u_dot_x = torch.matmul(x_block, u)
                reflection = 2 * u_dot_x.unsqueeze(-1) * u.unsqueeze(0)
                transformed_block = x_block - reflection

                results.append(transformed_block)

            # Concatenate transformed blocks
            return torch.cat(results, dim=-1)
        else:
            # Normalize u to ensure it's a unit vector
            u = F.normalize(self.u, p=2, dim=0)

            # Compute the Householder transformation: x - 2u(u^T x)
            u_dot_x = torch.matmul(x, u)
            reflection = 2 * u_dot_x.unsqueeze(-1) * u.unsqueeze(0)
            return x - reflection


In [7]:

class ETHERPlusTransform(nn.Module):
    """
    Implementation of ETHER+, a relaxation of ETHER.
    H+ = I - uu^T + vv^T
    """
    def __init__(self, dim, n_blocks=1):
        super().__init__()
        self.dim = dim
        self.n_blocks = n_blocks

        # If using blocks, divide dimension
        self.block_size = dim // n_blocks
        if self.block_size == 0:
            self.block_size = dim
            self.n_blocks = 1

        # Initialize two sets of vectors for each block
        if self.n_blocks > 1:
            self.u_vectors = nn.ParameterList([
                nn.Parameter(torch.zeros(self.block_size))
                for _ in range(self.n_blocks)
            ])
            self.v_vectors = nn.ParameterList([
                nn.Parameter(torch.zeros(self.block_size))
                for _ in range(self.n_blocks)
            ])
            # Initialize with small values
            for u, v in zip(self.u_vectors, self.v_vectors):
                nn.init.normal_(u, mean=0, std=0.01)
                nn.init.normal_(v, mean=0, std=0.01)
        else:
            # Single vectors for the entire dimension
            self.u = nn.Parameter(torch.zeros(dim))
            self.v = nn.Parameter(torch.zeros(dim))
            nn.init.normal_(self.u, mean=0, std=0.01)
            nn.init.normal_(self.v, mean=0, std=0.01)

    def forward(self, x):
        """
        Apply ETHER+ transformation: H+ = I - uu^T + vv^T

        """
        # Process input in blocks if using multiple blocks
        if self.n_blocks > 1:
            results = []

            for i, (u_param, v_param) in enumerate(zip(self.u_vectors, self.v_vectors)):
                # Get the block from input
                start_idx = i * self.block_size
                end_idx = min(start_idx + self.block_size, self.dim)
                if start_idx >= x.shape[-1]:
                    continue
                x_block = x[..., start_idx:end_idx]

                # Normalize vectors to ensure they're unit vectors
                u = F.normalize(u_param, p=2, dim=0)
                v = F.normalize(v_param, p=2, dim=0)

                # Compute the relaxed transformation: x - u(u^T x) + v(v^T x)
                u_dot_x = torch.matmul(x_block, u)
                v_dot_x = torch.matmul(x_block, v)

                u_term = u_dot_x.unsqueeze(-1) * u.unsqueeze(0)
                v_term = v_dot_x.unsqueeze(-1) * v.unsqueeze(0)

                transformed_block = x_block - u_term + v_term

                results.append(transformed_block)

            # Concatenate transformed blocks
            return torch.cat(results, dim=-1)
        else:
            # Normalize vectors to ensure they're unit vectors
            u = F.normalize(self.u, p=2, dim=0)
            v = F.normalize(self.v, p=2, dim=0)

            # Compute the relaxed transformation: x - u(u^T x) + v(v^T x)
            u_dot_x = torch.matmul(x, u)
            v_dot_x = torch.matmul(x, v)

            u_term = u_dot_x.unsqueeze(-1) * u.unsqueeze(0)
            v_term = v_dot_x.unsqueeze(-1) * v.unsqueeze(0)

            return x - u_term + v_term


In [8]:

class ETHERLinear(nn.Module):
    """
    Linear layer with ETHER transformation.
    """
    def __init__(self, base_layer, use_ether_plus=True, n_blocks=16, double_sided=True):
        super().__init__()
        self.in_features = base_layer.in_features
        self.out_features = base_layer.out_features
        self.use_ether_plus = use_ether_plus
        self.double_sided = double_sided

        # Store original weights and bias, but freeze them
        self.weight = nn.Parameter(base_layer.weight.data.clone(), requires_grad=False)
        self.bias = None
        if hasattr(base_layer, 'bias') and base_layer.bias is not None:
            self.bias = nn.Parameter(base_layer.bias.data.clone(), requires_grad=False)

        # Create the transformation modules
        if use_ether_plus:
            # For ETHER+, we use ether_plus on both sides for better performance
            self.left_transform = ETHERPlusTransform(self.in_features, n_blocks)
            if double_sided:
                self.right_transform = ETHERPlusTransform(self.out_features, n_blocks)
        else:
            # For basic ETHER, we use Householder on the left
            self.left_transform = HouseholderTransform(self.in_features, n_blocks)
            if double_sided:
                self.right_transform = HouseholderTransform(self.out_features, n_blocks)

    def forward(self, x):
        """
        Forward pass with ETHER transformation.
        """
        # Make sure the input is on the same device as the weights
        device = self.weight.device
        x = x.to(device)

        # Apply left transform, and compute W · x
        x_transformed = self.left_transform(x)
        output = F.linear(x_transformed, self.weight, bias=None)

        # Apply right transform if double-sided
        if self.double_sided:
            output = self.right_transform(output)

        # Add bias if present
        if self.bias is not None:
            output = output + self.bias

        return output


In [9]:

# Function to apply ETHER to a model
def apply_ether_to_model(model, use_ether_plus=True, target_modules=None, n_blocks=16, double_sided=True):
    """
    Apply ETHER transformations to target modules in the model.

    """
    if target_modules is None:
        # Default set to attention layers
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]

    # Count transformations for reporting
    transformation_count = 0
    total_trainable_params = 0

    # Inspect model architecture
    logger.info("Model architecture inspection:")
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            logger.info(f"Linear layer found: {name}, in_features={module.in_features}, out_features={module.out_features}")


    # Helper function to recursively replace modules
    def replace_module(module, path=""):
        nonlocal transformation_count, total_trainable_params

        for name, child in list(module.named_children()):
            full_name = f"{path}.{name}" if path else name

            # Checking for target module
            if isinstance(child, nn.Linear) and any(target in full_name for target in target_modules):
                # Replace with ETHER linear layer
                ether_layer = ETHERLinear(
                    child,
                    use_ether_plus=use_ether_plus,
                    n_blocks=n_blocks,
                    double_sided=double_sided
                )

                # Count parameters
                params = sum(p.numel() for p in ether_layer.parameters() if p.requires_grad)
                total_trainable_params += params

                # Replace the module
                setattr(module, name, ether_layer)
                transformation_count += 1
                logger.info(f"Applied ETHER {'Plus' if use_ether_plus else ''} to {full_name} (+{params} params)")
            else:

                replace_module(child, full_name)

    # Start the replacement process
    replace_module(model)

    logger.info(f"Applied ETHER {'Plus' if use_ether_plus else ''} to {transformation_count} modules")
    logger.info(f"Total trainable parameters: {total_trainable_params:,}")

    return model


In [46]:

# Data preparation
def load_alpaca_dataset(tokenizer, max_length=2048, sample_size=None):
    """
    Prepare the Alpaca dataset with proper tokenization.
    """
    try:
        # Load the dataset
        dataset = load_dataset("tatsu-lab/alpaca")
        logger.info(f"Loaded dataset with {len(dataset['train'])} examples")

        if sample_size is not None:

            indices = random.sample(range(len(dataset["train"])), min(sample_size, len(dataset["train"])))
            dataset["train"] = dataset["train"].select(indices)
            logger.info(f"Sampled {len(dataset['train'])} examples")

        # Tokenize and format the dataset
        def preprocess_function(examples):
            # Process a batch of examples
            batch_size = len(examples["instruction"])
            inputs = []
            labels = []

            for i in range(batch_size):
                instruction = examples["instruction"][i]
                input_text = examples["input"][i] if examples["input"][i] else ""
                output = examples["output"][i]

                if input_text:
                    prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.<|start_header_id|>user<|end_header_id|>\n{instruction}\n{input_text}<|start_header_id|>assistant<|end_header_id|>\n"
                else:
                    prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.<|start_header_id|>user<|end_header_id|>\n{instruction}<|start_header_id|>assistant<|end_header_id|>\n"

                # Tokenize
                tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors=None)
                tokenized_output = tokenizer(output + tokenizer.eos_token, truncation=False, return_tensors=None)

                # Create labels and input IDs
                input_ids = tokenized_prompt["input_ids"]
                attention_mask = tokenized_prompt["attention_mask"]
                labels_ids = [-100] * len(input_ids)  # Initialize labels with -100

                # Add output tokens and corresponding label IDs
                input_ids.extend(tokenized_output["input_ids"])
                attention_mask.extend(tokenized_output["attention_mask"])
                labels_ids.extend(tokenized_output["input_ids"])

                # Truncate if too long
                max_length = 512  # Define max length
                if len(input_ids) > max_length:
                    input_ids = input_ids[:max_length]
                    attention_mask = attention_mask[:max_length]
                    labels_ids = labels_ids[:max_length]

                inputs.append({
                    "input_ids": input_ids,
                    "attention_mask": attention_mask
                })
                labels.append(labels_ids)

            batch = {
                "input_ids": [inp["input_ids"] for inp in inputs],
                "attention_mask": [inp["attention_mask"] for inp in inputs],
                "labels": labels
            }

            return batch

        # Process the dataset
        processed_dataset = dataset.map(
            preprocess_function,
            batched=True,
            batch_size=100,
            remove_columns=dataset["train"].column_names
        )

        # Split the dataset
        train_size = int(0.9 * len(processed_dataset["train"]))
        train_dataset = processed_dataset["train"].select(range(train_size))
        eval_dataset = processed_dataset["train"].select(range(train_size-50, len(processed_dataset["train"])))

        return {
            "train": train_dataset,
            "eval": eval_dataset
        }

    except Exception as e:
        logger.error(f"Error preparing dataset: {e}")
        return None


In [11]:

# Custom collator to handle different sequence lengths
class DATACOLLATOR:
    """
    Collator that pads sequences in a batch.
    """
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.pad_token_id = tokenizer.pad_token_id

    def __call__(self, examples):
        # Get max lengths
        max_input_length = max(len(example["input_ids"]) for example in examples)
        max_label_length = max(len(example["labels"]) for example in examples)
        max_attn_length = max(len(example["attention_mask"]) for example in examples)

        max_length = max(max_input_length, max_label_length, max_attn_length)

        # Initialize batch
        batch = {
            "input_ids": [],
            "attention_mask": [],
            "labels": []
        }

        # Pad each example
        for example in examples:
            # Pad input_ids
            input_ids = example["input_ids"]
            padding_length = max_length - len(input_ids)
            padded_input_ids = input_ids + [self.pad_token_id] * padding_length
            batch["input_ids"].append(padded_input_ids)

            # Pad attention mask
            attention_mask = example["attention_mask"]
            padded_attention_mask = attention_mask + [0] * padding_length
            batch["attention_mask"].append(padded_attention_mask)

            # Pad labels (-100 is ignored)
            labels = example["labels"]
            padded_labels = labels + [-100] * padding_length
            batch["labels"].append(padded_labels)

        # Convert to tensors
        batch = {k: torch.tensor(v) for k, v in batch.items()}

        return batch


In [12]:
class MetricsCallback(TrainerCallback):
    """
    Track detailed metrics during training.
    """
    def __init__(self, tokenizer, eval_dataset, log_dir="./metrics"):
        self.tokenizer = tokenizer
        self.eval_dataset = eval_dataset
        self.log_dir = log_dir
        self.metrics_calculator = COMPUTEMETRICS()

        # Create log directory
        os.makedirs(log_dir, exist_ok=True)

        # Initialize metrics storage
        self.gpu_memory_stats = []
        self.cpu_memory_stats = []
        self.training_time = 0
        self.start_time = None
        self.step_times = []
        self.loss_values = []
        self.evaluation_metrics = []
        self.best_metrics = {
            "rouge1": 0,
            "rouge2": 0,
            "rougeL": 0,
            "bleu": 0,
            "step": 0
        }

    def on_train_begin(self, args, state, control, **kwargs):
        """Called at the beginning of training."""
        self.start_time = time.time()
        logger.info("Training started")

        # Log initial memory
        memory_stats = get_memory_stats()
        self.gpu_memory_stats.append({"step": 0, **memory_stats})
        logger.info(f"Initial memory: {memory_stats}")

        # Initialize wandb if used
        if args.report_to == "wandb" and wandb.run is not None:
            wandb.log({"memory": memory_stats, "step": 0})

    def on_step_end(self, args, state, control, **kwargs):
        """Called at the end of a training step."""
        # Record time per step
        if len(self.step_times) == 0:
            self.step_times.append(time.time() - self.start_time)
        else:
            # Use the time since the last step
            self.step_times.append(time.time() - (self.start_time + sum(self.step_times)))

        # Log every N steps
        if state.global_step % args.logging_steps == 0:
            # Get memory stats
            memory_stats = get_memory_stats()
            self.gpu_memory_stats.append({"step": state.global_step, **memory_stats})

            # Calculate training speed
            recent_step_times = self.step_times[-args.logging_steps:]
            avg_step_time = sum(recent_step_times) / len(recent_step_times)
            steps_per_second = 1.0 / avg_step_time if avg_step_time > 0 else 0

            # Log to console
            logger.info(f"Step {state.global_step}: {avg_step_time:.3f} sec/step, {steps_per_second:.3f} steps/sec")
            logger.info(f"Memory: {memory_stats['gpu_allocated_gb']:.2f} GB allocated, {memory_stats['gpu_reserved_gb']:.2f} GB reserved")

            # Log to wandb
            if args.report_to == "wandb" and wandb.run is not None:
                wandb.log({
                    "step": state.global_step,
                    "avg_step_time": avg_step_time,
                    "steps_per_second": steps_per_second,
                    "memory": memory_stats,
                })

            # Save memory stats to a file
            with open(os.path.join(self.log_dir, "memory_stats.json"), "w") as f:
                json.dump(self.gpu_memory_stats, f, indent=2)

            # Periodic cleanup
            if state.global_step % 50 == 0:
                gc.collect()
                torch.cuda.empty_cache()

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Called when logs are available."""
        if logs is None:
            return

        # Extract loss
        if "loss" in logs:
            self.loss_values.append({"step": state.global_step, "loss": logs["loss"]})

            # Save loss to file
            with open(os.path.join(self.log_dir, "loss.json"), "w") as f:
                json.dump(self.loss_values, f, indent=2)

            # Plot loss
            if len(self.loss_values) % 10 == 0:
                self._plot_loss()

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Called after evaluation."""
        if metrics is None:
            return

        # Add step number to metrics
        metrics["step"] = state.global_step
        self.evaluation_metrics.append(metrics)

        # Check if we have a new best model
        if metrics.get("eval_rouge1", 0) > self.best_metrics["rouge1"]:
            self.best_metrics["rouge1"] = metrics.get("eval_rouge1", 0)
            self.best_metrics["rouge2"] = metrics.get("eval_rouge2", 0)
            self.best_metrics["rougeL"] = metrics.get("eval_rougeL", 0)
            self.best_metrics["bleu"] = metrics.get("eval_bleu", 0)
            self.best_metrics["step"] = state.global_step

        # Log eval metrics
        logger.info(f"Evaluation at step {state.global_step}:")
        logger.info(f"  ROUGE-1: {metrics.get('eval_rouge1', 0):.4f}")
        logger.info(f"  ROUGE-2: {metrics.get('eval_rouge2', 0):.4f}")
        logger.info(f"  ROUGE-L: {metrics.get('eval_rougeL', 0):.4f}")
        logger.info(f"  BLEU: {metrics.get('eval_bleu', 0):.4f}")

        # Save eval metrics
        with open(os.path.join(self.log_dir, "eval_metrics.json"), "w") as f:
            json.dump(self.evaluation_metrics, f, indent=2)

        # Plot metrics
        self._plot_metrics()

    def on_train_end(self, args, state, control, **kwargs):

        # Calculate total training time
        self.training_time = time.time() - self.start_time

        # Log final stats
        logger.info("Training completed")
        logger.info(f"Total training time: {datetime.timedelta(seconds=self.training_time)}")
        logger.info(f"Average step time: {sum(self.step_times) / len(self.step_times) if self.step_times else 0:.3f} seconds")
        logger.info(f"Total steps: {state.global_step}")
        logger.info(f"Best metrics: {self.best_metrics}")

        # Calculate final metrics
        final_metrics = self._calculate_final_metrics()

        # Save final report
        self._save_final_report(state, final_metrics)

        # Create and save visualizations
        self._plot_loss()
        self._plot_metrics()
        self._plot_memory_usage()

    def _calculate_final_metrics(self):
        """Calculate comprehensive final metrics on the evaluation set."""
        return self.best_metrics

    def _plot_loss(self):
        """Plot the loss curve."""
        if not self.loss_values:
            return

        plt.figure(figsize=(10, 6))
        steps = [item["step"] for item in self.loss_values]
        losses = [item["loss"] for item in self.loss_values]
        plt.plot(steps, losses)
        plt.title("Training Loss")
        plt.xlabel("Step")
        plt.ylabel("Loss")
        plt.grid(True)
        plt.savefig(os.path.join(self.log_dir, "loss_curve.png"))
        plt.close()

    def _plot_metrics(self):
        """Plot the evaluation metrics."""
        if not self.evaluation_metrics:
            return

        plt.figure(figsize=(12, 8))

        # Extract metrics
        steps = [item["step"] for item in self.evaluation_metrics]
        rouge1 = [item.get("eval_rouge1", 0) for item in self.evaluation_metrics]
        rouge2 = [item.get("eval_rouge2", 0) for item in self.evaluation_metrics]
        rougeL = [item.get("eval_rougeL", 0) for item in self.evaluation_metrics]
        bleu = [item.get("eval_bleu", 0) for item in self.evaluation_metrics]

        # Plot metrics
        plt.plot(steps, rouge1, label="ROUGE-1")
        plt.plot(steps, rouge2, label="ROUGE-2")
        plt.plot(steps, rougeL, label="ROUGE-L")
        plt.plot(steps, bleu, label="BLEU")

        plt.title("Evaluation Metrics")
        plt.xlabel("Step")
        plt.ylabel("Score")
        plt.legend()
        plt.grid(True)
        plt.savefig(os.path.join(self.log_dir, "eval_metrics.png"))
        plt.close()

    def _plot_memory_usage(self):
        """Plot memory usage during training."""
        if not self.gpu_memory_stats:
            return

        plt.figure(figsize=(12, 8))

        # Extract memory stats
        steps = [item["step"] for item in self.gpu_memory_stats]
        allocated = [item.get("gpu_allocated_gb", 0) for item in self.gpu_memory_stats]
        reserved = [item.get("gpu_reserved_gb", 0) for item in self.gpu_memory_stats]

        # Plot memory usage
        plt.plot(steps, allocated, label="Allocated")
        plt.plot(steps, reserved, label="Reserved")

        plt.title("GPU Memory Usage")
        plt.xlabel("Step")
        plt.ylabel("Memory (GB)")
        plt.legend()
        plt.grid(True)
        plt.savefig(os.path.join(self.log_dir, "memory_usage.png"))
        plt.close()

    def _save_final_report(self, state, final_metrics):
        """Create and save a comprehensive final report."""
        report = {
            "training_time_seconds": self.training_time,
            "training_time_formatted": str(datetime.timedelta(seconds=self.training_time)),
            "total_steps": state.global_step,
            "average_step_time": sum(self.step_times) / len(self.step_times) if self.step_times else 0,
            "peak_memory_usage": max([s.get("gpu_allocated_gb", 0) for s in self.gpu_memory_stats]) if self.gpu_memory_stats else 0,
            "final_metrics": final_metrics,
            "best_metrics": self.best_metrics
        }

        # Save to file
        with open(os.path.join(self.log_dir, "final_report.json"), "w") as f:
            json.dump(report, f, indent=2)

        # Print report
        logger.info("\n" + "="*50)
        logger.info("TRAINING COMPLETE - FINAL REPORT")
        logger.info("="*50)
        logger.info(f"Total training time: {report['training_time_formatted']}")
        logger.info(f"Total steps: {report['total_steps']}")
        logger.info(f"Average step time: {report['average_step_time']:.3f} seconds")
        logger.info(f"Peak memory usage: {report['peak_memory_usage']:.2f} GB")
        logger.info("\nBest metrics:")
        logger.info(f"  ROUGE-1: {report['best_metrics']['rouge1']:.4f}")
        logger.info(f"  ROUGE-2: {report['best_metrics']['rouge2']:.4f}")
        logger.info(f"  ROUGE-L: {report['best_metrics']['rougeL']:.4f}")
        logger.info(f"  BLEU: {report['best_metrics']['bleu']:.4f}")
        logger.info(f"  (achieved at step {report['best_metrics']['step']})")
        logger.info("="*50)


In [13]:
def compute_metrics_function(eval_pred, tokenizer):
    """
    Compute ROUGE and BLEU metrics for model evaluation.
    """
    metrics_calculator = COMPUTEMETRICS()

    # Unpack predictions and labels
    predictions, labels = eval_pred

    # For models that return loss and logits as a tuple
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # If predictions are in logits format (3D tensor), convert to token IDs
    if len(predictions.shape) == 3:
        predictions = np.argmax(predictions, axis=-1)

    # Create masks for valid label positions (where labels != -100)
    label_mask = labels != -100

    # Create decoded lists
    decoded_preds = []
    decoded_labels = []

    # Process each sample individually to handle variable length properly
    for pred, label, mask in zip(predictions, labels, label_mask):
        # Filter out the padding and ignored positions
        filtered_pred = pred[mask]
        filtered_label = label[mask]

        # Decode to text
        pred_text = tokenizer.decode(filtered_pred, skip_special_tokens=True)
        label_text = tokenizer.decode(filtered_label, skip_special_tokens=True)

        decoded_preds.append(pred_text)
        decoded_labels.append(label_text)

    # Compute metrics
    metrics = metrics_calculator.calculate_metrics(decoded_preds, decoded_labels)

    return metrics

In [14]:


output_dir = "./ether_llama32"
metrics_dir = os.path.join(output_dir, "metrics")
model_name = "unsloth/llama-3.2-3b-instruct"
learning_rate = 2e-3
batch_size = 4
gradient_accumulation_steps = 8
use_ether_plus = True
n_blocks = 16
sample_size = 1000
use_wandb = True

# Setting random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Create output directories
os.makedirs(output_dir, exist_ok=True)
os.makedirs(metrics_dir, exist_ok=True)

# Initialize wandb if enabled
if use_wandb:
    wandb.init(
        project="ether-llama32",
        config={
            "model_name": model_name,
            "learning_rate": learning_rate,
            "batch_size": batch_size * gradient_accumulation_steps,
            "use_ether_plus": use_ether_plus,
            "n_blocks": n_blocks,
            "sample_size": sample_size,
            "ether_variant": "ETHER+" if use_ether_plus else "ETHER",
        }
    )


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mdinesh-te[0m ([33mdinesh-te-northeastern-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [15]:

# Log initial memory
logger.info("Initial GPU memory:")
initial_memory = get_memory_stats()
logger.info(json.dumps(initial_memory, indent=2))



In [16]:

# Load model and tokenizer

start_loading = time.time()

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
    device_map="auto"
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

loading_time = time.time() - start_loading
print(f"Model loaded in {loading_time:.2f} seconds")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Model loaded in 5.20 seconds


In [17]:
model.gradient_checkpointing_enable()


In [18]:
# Log memory after loading
logger.info("GPU memory after model loading:")
post_loading_memory = get_memory_stats()
print(json.dumps(post_loading_memory, indent=2))


{
  "gpu_allocated_gb": 5.984213352203369,
  "gpu_reserved_gb": 6.720703125,
  "gpu_max_allocated_gb": 6.71875,
  "gpu_max_reserved_gb": 6.720703125,
  "cpu_used_gb": 3.4578285217285156,
  "cpu_total_gb": 83.47704696655273,
  "cpu_percent": 4.1
}


In [19]:


# Applying ETHER transformations
logger.info("Applying ETHER transformations...")
start_ether = time.time()

model = apply_ether_to_model(
    model,
    use_ether_plus=True,
    n_blocks=n_blocks,
    double_sided=True
)

ether_time = time.time() - start_ether
print(f"ETHER transformations applied in {ether_time:.2f} seconds")

ETHER transformations applied in 0.20 seconds


In [47]:

# Get dataset
logger.info("Preparing dataset...")
start_dataset = time.time()

datasets = load_alpaca_dataset(
    tokenizer=tokenizer,
    max_length=512,
    sample_size=sample_size
)

dataset_time = time.time() - start_dataset
print(f"Dataset prepared in {dataset_time:.2f} seconds")

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Dataset prepared in 3.44 seconds


In [48]:



train_dataset = datasets["train"]
eval_dataset = datasets["eval"]
print(f"Training on {len(train_dataset)} examples, evaluating on {len(eval_dataset)} examples")


Training on 900 examples, evaluating on 150 examples


In [22]:
#  data collator
data_collator = DATACOLLATOR(tokenizer)

In [23]:
# Create metrics tracker
metrics_tracker = MetricsCallback(tokenizer, eval_dataset, log_dir=metrics_dir)

In [24]:
def compute_metrics(eval_pred):
    return compute_metrics_function(eval_pred, tokenizer)

In [25]:
#Setting training arguments
training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        num_train_epochs=1,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,

        weight_decay=0.0,
        logging_dir=os.path.join(output_dir, "logs"),
        logging_steps=10,
        save_steps=50,
        eval_steps=10,
        save_total_limit=2,

        bf16=torch.cuda.is_bf16_supported(),
        fp16=not torch.cuda.is_bf16_supported() and torch.cuda.is_available(),

        dataloader_drop_last=False,
        report_to="wandb",
        optim="adamw_torch",
        lr_scheduler_type="cosine",
        warmup_ratio=0.03,

        remove_unused_columns=False,
        )


In [26]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[metrics_tracker]
)

  trainer = Trainer(


In [27]:

# Log memory before training
logger.info("GPU memory before training:")
pre_training_memory = get_memory_stats()
logger.info(json.dumps(pre_training_memory, indent=2))


In [28]:
# Train the model
logger.info("Starting training...")
start_training = time.time()

train_result = trainer.train()


training_time = time.time() - start_training
print(f"Training completed in {datetime.timedelta(seconds=training_time)}")


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss
10,10.9386
20,7.9978


In [29]:
# Log training metrics
train_metrics = {
        "train_runtime": training_time,
        "train_samples_per_second": train_result.metrics.get("train_samples_per_second", 0),
        "train_steps_per_second": train_result.metrics.get("train_steps_per_second", 0),
        "train_loss": train_result.metrics.get("train_loss", 0),
    }
print(f"Training metrics: {train_metrics}")

Training metrics: {'train_runtime': 1483.8208255767822, 'train_samples_per_second': 0.607, 'train_steps_per_second': 0.019, 'train_loss': 8.908872604370117}


In [33]:
torch.cuda.empty_cache()
torch.cuda.synchronize()

In [34]:
import nltk
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [43]:

# Print final memory stats
logger.info("Final GPU memory:")
final_memory = get_memory_stats()
print(json.dumps(final_memory, indent=2))

{
  "gpu_allocated_gb": 15.362429141998291,
  "gpu_reserved_gb": 36.0625,
  "gpu_max_allocated_gb": 15.362429141998291,
  "gpu_max_reserved_gb": 36.0625,
  "cpu_used_gb": 5.143348693847656,
  "cpu_total_gb": 83.47704696655273,
  "cpu_percent": 6.2
}


In [50]:

summary = {
    "model": model_name,
    "ether_variant": "ETHER+" if use_ether_plus else "ETHER",
    "blocks": n_blocks,
    "dataset_size": len(train_dataset),
    "times": {
        "loading": loading_time,
        "ether_application": ether_time,
        "dataset_preparation": dataset_time,
        "training": training_time,
        "evaluation": eval_time,
        "total": loading_time + ether_time + dataset_time + training_time + eval_time
    },
    "memory": {
        "initial": initial_memory,
        "after_loading": post_loading_memory,
        "before_training": pre_training_memory,
        "final": final_memory,
        "peak_allocated": max([s.get("gpu_allocated_gb", 0) for s in metrics_tracker.gpu_memory_stats]) if metrics_tracker.gpu_memory_stats else 0
    },
    "train_metrics": train_metrics
}

In [56]:

# Print final summary
print("\n" + "="*50)
print("ETHER TRAINING SUMMARY")
print("="*50)
print(f"Model: {summary['model']}")
print(f"Variant: {summary['ether_variant']} with {summary['blocks']} blocks")
print(f"Dataset size: {summary['dataset_size']} examples")
print("="*50)
print("\nTime statistics:")
print(f"  Model loading: {summary['times']['loading']:.2f} seconds")
print(f"  ETHER application: {summary['times']['ether_application']:.2f} seconds")
print(f"  Dataset preparation: {summary['times']['dataset_preparation']:.2f} seconds")
print(f"  Training: {datetime.timedelta(seconds=summary['times']['training'])}")
print(f"  Evaluation: {summary['times']['evaluation']:.2f} seconds")
print(f"  Total: {datetime.timedelta(seconds=summary['times']['total'])}")
print("="*50)
print("\nMemory statistics:")
print(f"  Initial: {summary['memory']['initial']['gpu_allocated_gb']:.2f} GB")
print(f"  After loading: {summary['memory']['after_loading']['gpu_allocated_gb']:.2f} GB")
print(f"  Peak during training: {summary['memory']['peak_allocated']:.2f} GB")
print(f"  Final: {summary['memory']['final']['gpu_allocated_gb']:.2f} GB")
print("="*50)


ETHER TRAINING SUMMARY
Model: unsloth/llama-3.2-3b-instruct
Variant: ETHER+ with 16 blocks
Dataset size: 900 examples

Time statistics:
  Model loading: 5.20 seconds
  ETHER application: 0.20 seconds
  Dataset preparation: 3.44 seconds
  Training: 0:24:43.820826
  Evaluation: 72.02 seconds
  Total: 0:26:04.679760

Memory statistics:
  Initial: 0.00 GB
  After loading: 5.98 GB
  Peak during training: 15.36 GB
  Final: 15.36 GB


In [68]:
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)

('./ether-lora-optimized/tokenizer_config.json',
 './ether-lora-optimized/special_tokens_map.json',
 './ether-lora-optimized/tokenizer.json')

In [None]:
def evaluate_base_model(model,eval_dataset, tokenizer, device="cuda"):
    """
    Evaluate the model on the evaluation dataset
    """
    model.eval()
    metrics_calculator = COMPUTEMETRICS()

    all_decoded_preds = []
    all_decoded_labels = []
    examples_to_show = 5


    eval_dataloader = DataLoader(
        eval_dataset,
        batch_size=2,  
        collate_fn=data_collator
    )

    print(f"Processing {len(eval_dataset)} evaluation examples")

    with torch.no_grad():
        for batch_idx, batch in enumerate(eval_dataloader):
            # Move batch to device
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

            # Get model output
            outputs = model(**batch)

            # Get predictions
            logits = outputs.logits
            predictions = torch.argmax(logits, dim=-1).cpu().numpy()
            labels = batch["labels"].cpu().numpy()

            # Process each example in the batch individually
            for i in range(predictions.shape[0]):
                pred = predictions[i]
                label = labels[i]

                # Replace -100 with padding token
                label_mask = label != -100
                filtered_label = label[label_mask]

                # Find where the actual output starts (where labels are not -100)
                valid_positions = np.where(label_mask)[0]
                if len(valid_positions) > 0:
                    # Get the prediction outputs only for valid positions
                    filtered_pred = pred[valid_positions]

                    # Decode text
                    try:
                        pred_text = tokenizer.decode(filtered_pred, skip_special_tokens=True)
                        label_text = tokenizer.decode(filtered_label, skip_special_tokens=True)

                        all_decoded_preds.append(pred_text)
                        all_decoded_labels.append(label_text)

                        # Print some examples to inspect
                        example_idx = batch_idx * batch["input_ids"].shape[0] + i
                        if example_idx < examples_to_show:
                            print(f"\nExample {example_idx+1}:")
                            # Get the input prompt
                            input_ids = batch["input_ids"][i].cpu().numpy()
                            prompt = tokenizer.decode(input_ids, skip_special_tokens=False)
                            # Find where the assistant's response starts
                            assistant_start = prompt.rfind("<|start_header_id|>assistant<|end_header_id|>")
                            if assistant_start != -1:
                                prompt = prompt[:assistant_start + len("<|start_header_id|>assistant<|end_header_id|>")]
                            print(f"Input prompt: {prompt}")
                            print(f"Prediction: {pred_text}")
                            print(f"Reference: {label_text}")
                    except Exception as e:
                        print(f"Error decoding example: {e}")

    # Calculate metrics
    metrics = metrics_calculator.calculate_metrics(all_decoded_preds, all_decoded_labels)

    # Print overall metrics
    print("\nEvaluation Metrics for Base Model:")
    print(f"ROUGE-1: {metrics['rouge1']:.4f}")
    print(f"ROUGE-2: {metrics['rouge2']:.4f}")
    print(f"ROUGE-L: {metrics['rougeL']:.4f}")
    print(f"BLEU: {metrics['bleu']:.4f}")

    # Print comparison for a few examples
    print("\n====== Detailed Comparison of Base Model vs Expected Output ======")
    for i in range(min(examples_to_show, len(all_decoded_preds))):
        pred = all_decoded_preds[i]
        ref = all_decoded_labels[i]

        # Calculate individual metrics for this example
        rouge_scores = metrics_calculator.calculate_rouge(pred, ref)
        bleu_score = metrics_calculator.calculate_bleu(pred, ref)

        print(f"\n--- Example {i+1} ---")
        print(f"Base Model Output: {pred[:200]}..." if len(pred) > 200 else f"Base Model Output: {pred}")
        print(f"Expected Output: {ref[:200]}..." if len(ref) > 200 else f"Expected Output: {ref}")
        print(f"Individual Metrics:")
        print(f"  ROUGE-1: {rouge_scores['rouge1']:.4f}")
        print(f"  ROUGE-2: {rouge_scores['rouge2']:.4f}")
        print(f"  ROUGE-L: {rouge_scores['rougeL']:.4f}")
        print(f"  BLEU: {bleu_score:.4f}")

    # Clean up to free memory
    del model
    torch.cuda.empty_cache()

    return {
        "predictions": all_decoded_preds,
        "references": all_decoded_labels,
        "metrics": metrics
    }


In [66]:
model_results = evaluate_base_model(model,eval_dataset, tokenizer)

Processing 150 evaluation examples

Example 1:
Input prompt: <|begin_of_text|><|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.<|start_header_id|>user<|end_header_id|>
Calculate the mass of 4.5 moles of carbon dioxide.<|start_header_id|>assistant<|end_header_id|>
Prediction: The mass of 4.5 moles of carbon dioxide is 4 grams0 grams. is be calculated by multiplying the4.5 moles by the molar mass of carbon dioxide, which is 44.01 grams per mole.
Reference: The mass of 4.5 moles of carbon dioxide is 324.75 grams. This can be calculated by multiplying 4.5 moles by the molar mass of carbon dioxide, which is 44.01 grams per mole.

Example 2:
Input prompt: <|begin_of_text|><|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.<|start_header_id|>user<|end_header_

In [67]:
print(model_results["metrics"])

{'rouge1': 0.627784442891029, 'rouge2': 0.3929223310724818, 'rougeL': 0.5823602828416555, 'bleu': 0.3304549498825494}
