In [None]:
import os
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup, BitsAndBytesConfig
from torch.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
from typing import List, Optional, Tuple
from datasets import load_dataset
import logging
from tqdm import tqdm
import time
from peft import LoraConfig, get_peft_model


# Request synchronous CUDA kernel launches so device-side errors are reported
# at the call that caused them. The variable has to be in the environment
# before the first CUDA call below initialises the driver; setting it after
# the GPU queries would leave it without effect.
# NOTE: this is a debugging aid and slows training down noticeably.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Verify GPU
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# ====================== Data & Utility Classes ======================
class SSLDataset(Dataset):
    """Causal-LM dataset that tokenizes raw strings on access.

    Every item is a dict of three 1-D tensors of length ``max_length``:
    ``input_ids``, ``attention_mask`` and ``labels``. ``labels`` are already
    shifted one step to the left (position ``t`` holds token ``t + 1``), so the
    consumer must compare logits and labels position by position without
    shifting again. Positions to ignore carry ``-100``.
    """

    def __init__(self, texts: List[str], tokenizer: "AutoTokenizer", max_length: int = 64):
        """Store the texts, the tokenizer and the fixed sequence length."""
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.texts = texts

    def __len__(self):
        """Return the number of texts."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize ``texts[idx]`` and build next-token labels for it."""
        text = self.texts[idx]
        inputs = self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
            return_tensors="pt"
        )
        input_ids = inputs['input_ids'].squeeze(0)
        attention_mask = inputs['attention_mask'].squeeze(0)

        # Next-token targets: shift left by one; the last position has no target.
        labels = input_ids.clone()
        labels[:-1] = input_ids[1:]
        labels[-1] = -100
        # Positions whose *input* is padding contribute nothing to the loss.
        labels[attention_mask == 0] = -100

        # Bound ids by the full vocabulary. ``tokenizer.vocab_size`` excludes
        # added/special tokens, so clamping against it silently rewrites ids
        # such as BOS/EOS/pad that live above the base vocabulary;
        # ``len(tokenizer)`` counts them.
        max_token_id = len(self.tokenizer) - 1
        input_ids = torch.clamp(input_ids, 0, max_token_id)
        labels = torch.clamp(labels, -100, max_token_id)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }

# ====================== Model Architecture ======================
class PNNColumn(nn.Module):
    """One feed-forward column: input_dim -> hidden_dim -> input_dim -> output_dim.

    ``forward`` returns both the ``input_dim``-sized representation produced
    by ``fc2`` and the ``output_dim``-sized logits produced by ``fc3``.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, dropout: float = 0.3):
        super().__init__()
        # Layers are created in a fixed order so parameter initialisation and
        # state_dict layout stay reproducible.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.dropout = nn.Dropout(p=dropout)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, input_dim)
        self.fc3 = nn.Linear(input_dim, output_dim)
        self._init_output_layer()

    def _init_output_layer(self) -> None:
        """Xavier-initialise the output projection and zero its bias."""
        head = self.fc3
        nn.init.xavier_uniform_(head.weight)
        if head.bias is not None:
            nn.init.zeros_(head.bias)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(hidden, logits)`` for input ``x``."""
        activated = self.dropout(self.relu(self.fc1(x)))
        projected = self.fc2(activated)
        return projected, self.fc3(projected)

class PNNWithLLaMA(nn.Module):
    def __init__(self, base_model: nn.Module, hidden_dim: int, vocab_size: int, tokenizer: AutoTokenizer):
        super().__init__()
        self.base_model = base_model
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.tokenizer = tokenizer

        for param in base_model.parameters():
            param.requires_grad = False

        self.columns = nn.ModuleList([
            PNNColumn(
                input_dim=base_model.config.hidden_size,
                hidden_dim=hidden_dim,
                output_dim=vocab_size,
                dropout=0.3
            )
        ])
        self.gate = nn.Linear(base_model.config.hidden_size, 1)

        lora_config = LoraConfig(
            r=8,
            lora_alpha=16,
            target_modules=["fc1", "fc2", "fc3"],
            lora_dropout=0.1,
            bias="none"
        )
        self.columns[0] = get_peft_model(self.columns[0], lora_config)

    def forward(self,
                input_ids: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                column_idx: int = 0,
                past_key_values: Optional[tuple] = None,
                position_ids: Optional[torch.Tensor] = None,
                training: bool = False) -> Tuple[torch.Tensor, Optional[tuple], torch.Tensor, torch.Tensor, torch.Tensor]:
        with torch.set_grad_enabled(training):
            outputs = self.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                output_hidden_states=True,
                position_ids=position_ids,
                use_cache=not training
            )

        hidden_states = outputs.hidden_states[-1]
        past_key_values = outputs.past_key_values
        base_logits = self.base_model.lm_head(hidden_states)

        lateral_hidden = torch.zeros_like(hidden_states)
        if column_idx > 0:
            lateral_hidden = sum(
                self.columns[i](hidden_states)[0]
                for i in range(column_idx)
            )

        combined_input = hidden_states + lateral_hidden
        pnn_hidden, pnn_logits = self.columns[column_idx](combined_input)

        gate_weight = torch.sigmoid(self.gate(hidden_states).mean(dim=1, keepdim=True)) * 0.2 + 0.8
        final_logits = gate_weight * base_logits + (1 - gate_weight) * pnn_logits

        return final_logits, past_key_values, base_logits, hidden_states, pnn_hidden

    def generate(self,
                 input_ids: torch.Tensor,
                 column_idx: int = 0,
                 max_length: int = 100,
                 temperature: float = 0.7,
                 top_k: int = 50,
                 top_p: Optional[float] = 0.9) -> torch.Tensor:
        self.eval()
        device = input_ids.device
        generated = input_ids
        past_key_values = None
        max_pos = self.base_model.config.max_position_embeddings

        if (input_ids < 0).any() or (input_ids >= self.vocab_size).any():
            raise ValueError(f"Input IDs contain invalid indices: min {input_ids.min()}, max {input_ids.max()}, vocab_size {self.vocab_size}")

        for _ in range(max_length - input_ids.size(1)):
            seq_len = generated.size(1)
            position_ids = torch.arange(seq_len - 1, seq_len, device=device).unsqueeze(0)
            input_slice = generated[:, -max_pos:] if seq_len > max_pos else generated

            with autocast(device_type=device.type, dtype=torch.float16, enabled=device.type == 'cuda'):
                logits, past_key_values, _, _, _ = self(
                    input_ids=input_slice,
                    column_idx=column_idx,
                    past_key_values=past_key_values,
                    position_ids=position_ids,
                    training=False
                )

            next_token_logits = logits[:, -1, :] / temperature
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = False
                sorted_logits = sorted_logits.masked_fill(sorted_indices_to_remove, float('-inf'))
                probs = torch.softmax(sorted_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                next_token = sorted_indices.gather(-1, next_token)
            else:
                top_k = min(top_k, next_token_logits.size(-1))
                top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k, dim=-1)
                probs = torch.softmax(top_k_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                next_token = top_k_indices.gather(-1, next_token)

            if next_token.numel() != 1:
                raise ValueError(f"Expected next_token to be a scalar, got shape {next_token.shape}")
            next_token_value = next_token.item()
            if next_token_value < 0 or next_token_value >= self.vocab_size:
                raise ValueError(f"Invalid token index: {next_token_value}, vocab_size: {self.vocab_size}")

            generated = torch.cat((generated, next_token), dim=1)

            if next_token_value == self.tokenizer.eos_token_id:
                break

        return generated

# ====================== Training & Evaluation ======================
def train_ssl_task(
    model: nn.Module,
    tokenizer: AutoTokenizer,
    texts: List[str],
    column_idx: int,
    device: torch.device,
    epochs: int = 3,
    lr: float = 5e-5,
    batch_size: int = 8,
    max_length: int = 128,
    accum_steps: int = 8,
    distill_hidden_weight: float = 0.3,
    distill_logits_weight: float = 0.3,
    output_dir: Optional[str] = None
):
    """Train one column and the gate with a CLM loss plus two distillation terms.

    The loss is a weighted sum of cross-entropy on the blended logits, a KL
    term pulling the blended distribution towards the base model's, and an MSE
    term pulling the column's hidden output towards the base hidden states.
    Gradients are accumulated over ``accum_steps`` batches. A checkpoint of
    ``model.state_dict()`` is written after every epoch.

    Args:
        model: model exposing ``columns``, ``gate``, ``vocab_size`` and a
            forward that accepts ``training=True``.
        tokenizer: tokenizer handed to ``SSLDataset``.
        texts: training strings.
        column_idx: index of the column to train.
        device: device the batches are moved to.
        epochs: number of passes over ``texts``.
        lr: AdamW learning rate.
        batch_size: sequences per batch.
        max_length: token length every sample is padded/truncated to.
        accum_steps: batches per optimizer step.
        distill_hidden_weight: weight of the hidden-state MSE term.
        distill_logits_weight: weight of the logit KL term.
        output_dir: checkpoint directory. When ``None`` the module-level
            ``OUTPUT_DIR`` is used if it exists, otherwise the current directory.

    Raises:
        ValueError: if ``texts`` is empty or ``accum_steps`` is below 1.
    """
    if accum_steps < 1:
        raise ValueError(f"accum_steps must be >= 1, got {accum_steps}")
    if not texts:
        raise ValueError("train_ssl_task needs at least one training text")
    if output_dir is None:
        # Backward compatible with callers that publish a global OUTPUT_DIR.
        output_dir = globals().get("OUTPUT_DIR", ".")
    os.makedirs(output_dir, exist_ok=True)

    model.train()
    dataset = SSLDataset(texts, tokenizer, max_length)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    num_batches = len(dataloader)

    # Only hand parameters that can actually receive gradients to the optimizer.
    trainable_params = [
        p
        for p in list(model.columns[column_idx].parameters()) + list(model.gate.parameters())
        if p.requires_grad
    ]
    optimizer = torch.optim.AdamW(trainable_params, lr=lr, weight_decay=0.01)
    # The trailing partial accumulation window also triggers an optimizer
    # step, so the per-epoch step count is a ceiling division.
    steps_per_epoch = (num_batches + accum_steps - 1) // accum_steps
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=50, num_training_steps=steps_per_epoch * epochs
    )
    scaler = GradScaler(enabled=device.type == 'cuda')
    clm_criterion = nn.CrossEntropyLoss(ignore_index=-100)
    # NOTE(review): with 'batchmean' on (batch, seq, vocab) inputs the KL sum
    # is divided by the batch dimension only, so this term grows with the
    # sequence length and includes padded positions — confirm the weighting.
    distill_logits_criterion = nn.KLDivLoss(reduction='batchmean')
    distill_hidden_criterion = nn.MSELoss()

    for epoch in range(epochs):
        total_loss = 0.0
        window_loss = 0.0     # summed loss of the current accumulation window
        window_batches = 0    # batches seen in the current accumulation window
        optimizer.zero_grad()
        with tqdm(total=num_batches, desc=f"Epoch {epoch+1}/{epochs}", unit="batch") as pbar:
            for i, batch in enumerate(dataloader):
                input_ids = batch['input_ids'].to(device, non_blocking=True)
                attention_mask = batch['attention_mask'].to(device, non_blocking=True)
                labels = batch['labels'].to(device, non_blocking=True)

                # Batches in the window this batch belongs to; the last window
                # of an epoch may be shorter than accum_steps.
                window_start = (i // accum_steps) * accum_steps
                window_size = min(accum_steps, num_batches - window_start)

                with autocast(device_type=device.type, dtype=torch.float16, enabled=device.type == 'cuda'):
                    logits, _, base_logits, base_hidden, pnn_hidden = model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        column_idx=column_idx,
                        training=True
                    )
                    clm_loss = clm_criterion(
                        logits.view(-1, model.vocab_size),
                        labels.view(-1)
                    )
                    distill_logits_loss = distill_logits_criterion(
                        torch.log_softmax(logits, dim=-1),
                        torch.softmax(base_logits, dim=-1)
                    )
                    distill_hidden_loss = distill_hidden_criterion(pnn_hidden, base_hidden)
                    loss = (1 - distill_hidden_weight - distill_logits_weight) * clm_loss + \
                           distill_hidden_weight * distill_hidden_loss + \
                           distill_logits_weight * distill_logits_loss
                    # Average over the window so accumulated gradients match
                    # one large batch, including for a short final window.
                    loss = loss / window_size

                scaler.scale(loss).backward()
                batch_value = loss.item() * window_size
                window_loss += batch_value
                total_loss += batch_value
                window_batches += 1

                if (i + 1) % accum_steps == 0 or (i + 1) == num_batches:
                    # Gradients are still multiplied by the loss scale here;
                    # unscale first so max_norm applies to the true gradients.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=1.0)
                    scaler.step(optimizer)
                    scaler.update()
                    scheduler.step()
                    optimizer.zero_grad()
                    pbar.set_postfix({'loss': window_loss / window_batches})
                    pbar.update(window_batches)
                    window_loss = 0.0
                    window_batches = 0
                    if device.type == 'cuda':
                        logger.info(f"Batch {i+1}, GPU Memory: {torch.cuda.memory_allocated()/1e9:.2f} GB")
                        torch.cuda.empty_cache()

        avg_loss = total_loss / num_batches
        logger.info(f"Epoch {epoch+1}/{epochs} | Avg Loss: {avg_loss:.4f}")
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        checkpoint_path = os.path.join(output_dir, f"checkpoint_epoch_{epoch+1}_{timestamp}.pt")
        torch.save(model.state_dict(), checkpoint_path)
        logger.info(f"Saved checkpoint: {checkpoint_path}")

def evaluate_perplexity(
    model: nn.Module,
    tokenizer: AutoTokenizer,
    texts: List[str],
    device: torch.device,
    batch_size: int = 4,
    max_length: int = 64,
    column_idx: int = 0
) -> float:
    """Return the token-level perplexity of ``model`` on ``texts``.

    The cross-entropy is summed over every non-ignored label and divided by
    the number of such labels, so batches with different amounts of padding
    are weighted by their real token count.

    Args:
        model: model whose forward returns the blended logits first.
        tokenizer: tokenizer handed to ``SSLDataset``.
        texts: evaluation strings.
        device: device the batches are moved to.
        batch_size: sequences per batch.
        max_length: token length every sample is padded/truncated to.
        column_idx: column used for the forward pass.

    Raises:
        ValueError: if no label is available to score (e.g. empty ``texts``).
    """
    model.eval()
    dataset = SSLDataset(texts, tokenizer, max_length)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    # Sum, not mean: a per-batch mean multiplied by the number of sequences
    # and later divided by the number of tokens mixes two different units.
    criterion = nn.CrossEntropyLoss(ignore_index=-100, reduction='sum')
    total_loss = 0.0
    total_tokens = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating Perplexity", unit="batch"):
            input_ids = batch['input_ids'].to(device, non_blocking=True)
            attention_mask = batch['attention_mask'].to(device, non_blocking=True)
            labels = batch['labels'].to(device, non_blocking=True)

            with autocast(device_type=device.type, dtype=torch.float16, enabled=device.type == 'cuda'):
                logits, _, _, _, _ = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    column_idx=column_idx,
                    training=False
                )
                # Accumulate the summed loss in float32.
                loss = criterion(logits.float().view(-1, model.vocab_size), labels.view(-1))

            total_loss += loss.item()
            total_tokens += (labels != -100).sum().item()

    if total_tokens == 0:
        raise ValueError("evaluate_perplexity found no tokens to score")

    perplexity = torch.exp(torch.tensor(total_loss / total_tokens)).item()
    return perplexity

def evaluate_generation(
    model: nn.Module,
    tokenizer: AutoTokenizer,
    prompts: List[str],
    device: torch.device,
    max_length: int = 256,
    temperature: float = 0.7,
    top_p: float = 0.9
) -> List[str]:
    """Generate one continuation per prompt, log it, and return all decoded texts.

    Each prompt is tokenized, moved to ``device`` and passed to
    ``model.generate`` with ``top_k=50``; the result is decoded with special
    tokens removed. Output order matches ``prompts``.
    """
    model.eval()
    decoded_outputs = []
    for prompt in prompts:
        encoded = tokenizer(prompt, return_tensors="pt")
        prompt_ids = encoded["input_ids"].to(device)
        token_ids = model.generate(
            prompt_ids,
            max_length=max_length,
            temperature=temperature,
            top_k=50,
            top_p=top_p
        )
        # Only the first (and only) sequence of the batch is decoded.
        completion = tokenizer.decode(token_ids[0], skip_special_tokens=True)
        decoded_outputs.append(completion)
        logger.info(f"Prompt: {prompt}\nGenerated: {completion}")
    return decoded_outputs

# ====================== Data Preprocessing ======================
def preprocess_tinystories(tokenizer: AutoTokenizer, max_samples: int = 5000, max_length: int = 64) -> List[str]:
    """Stream TinyStories and return up to ``max_samples`` token-truncated stories.

    Each story is stripped, tokenized with truncation to ``max_length`` tokens
    and decoded back to text without special tokens. Stories that decode to an
    empty string are dropped, so fewer than ``max_samples`` texts may be
    returned.
    """
    logger.info("Loading TinyStories dataset...")
    stream = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
    progress = tqdm(enumerate(stream), total=max_samples, desc="Processing TinyStories", unit="story")

    collected = []
    for index, record in progress:
        if index >= max_samples:
            break
        story = record['text'].strip()
        encoding = tokenizer(story, truncation=True, max_length=max_length, return_tensors="pt")
        clipped = tokenizer.decode(encoding["input_ids"][0], skip_special_tokens=True)
        if not clipped:
            continue
        collected.append(clipped)

    logger.info(f"Collected {len(collected)} text samples from TinyStories")
    return collected

# ====================== Main Execution ======================
def main():
    """Run the whole experiment: load, train, checkpoint, and evaluate.

    Loads the tokenizer and the 8-bit base model, wraps it in
    ``PNNWithLLaMA``, trains column 0 on TinyStories, saves the final weights
    and tokenizer, then reports perplexity and a few sample generations.
    """
    # train_ssl_task reads OUTPUT_DIR at module level, so it is published here.
    global OUTPUT_DIR
    # Config
    MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
    HIDDEN_DIM = 512
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    MAX_SAMPLES = 5000
    EPOCHS = 3
    BATCH_SIZE = 8
    ACCUM_STEPS = 8
    OUTPUT_DIR = "/content/pnn_llama_distillation"

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Load tokenizer and add special tokens
    logger.info("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.mask_token is None:
        tokenizer.add_special_tokens({'mask_token': '[MASK]'})
    if tokenizer.eos_token_id is None:
        logger.warning("eos_token_id is None, setting to default")
        tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token)
    logger.info(f"Tokenizer vocab size: {len(tokenizer)}, eos_token_id: {tokenizer.eos_token_id}")

    # Load model and resize embeddings
    logger.info("Loading model...")
    # Only load_in_8bit is passed: BitsAndBytesConfig has no bnb_8bit_*
    # options (compute dtype and double quantisation are 4-bit settings), so
    # such keyword arguments would be ignored.
    quant_config = BitsAndBytesConfig(load_in_8bit=True)
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=quant_config,
        device_map="auto",
        torch_dtype=torch.float16
    )
    base_model.resize_token_embeddings(len(tokenizer))

    # Initialize PNN
    model = PNNWithLLaMA(
        base_model=base_model,
        hidden_dim=HIDDEN_DIM,
        vocab_size=len(tokenizer),
        tokenizer=tokenizer
    )
    # The quantised base model was already placed by device_map="auto" and
    # should not be moved again; only the newly created modules are moved.
    model.columns.to(DEVICE)
    model.gate.to(DEVICE)
    assert model.vocab_size == len(tokenizer), f"Vocab size mismatch: model {model.vocab_size}, tokenizer {len(tokenizer)}"

    # Load and preprocess TinyStories dataset
    texts = preprocess_tinystories(tokenizer, max_samples=MAX_SAMPLES)

    # Split dataset
    split_at = int(0.9 * len(texts))
    train_texts = texts[:split_at]
    eval_texts = texts[split_at:]

    logger.info("Training with SSL (CLM) and distillation...")
    # NOTE(review): texts were truncated to 64 tokens during preprocessing
    # while training pads to 128 — confirm the longer length is intended.
    train_ssl_task(
        model=model,
        tokenizer=tokenizer,
        texts=train_texts,
        column_idx=0,
        device=DEVICE,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        max_length=128,
        accum_steps=ACCUM_STEPS,
        distill_hidden_weight=0.3,
        distill_logits_weight=0.3
    )

    # Save final model and tokenizer
    final_model_path = os.path.join(OUTPUT_DIR, "model.pt")
    torch.save(model.state_dict(), final_model_path)
    logger.info(f"Saved final model: {final_model_path}")

    tokenizer.save_pretrained(OUTPUT_DIR)
    logger.info(f"Saved tokenizer: {OUTPUT_DIR}")

    # Evaluate perplexity
    perplexity = evaluate_perplexity(model, tokenizer, eval_texts, DEVICE, batch_size=BATCH_SIZE)
    logger.info(f"Perplexity on evaluation set: {perplexity:.4f}")

    # Evaluate generation
    prompts = [
        "Once upon a time, in a forest far away,",
        "There was a little robot who loved to explore.",
        "In a quiet village, a child found a magical book."
    ]
    generations = evaluate_generation(model, tokenizer, prompts, DEVICE, max_length=100)
    logger.info("Generated samples for review:")
    for prompt, gen in zip(prompts, generations):
        logger.info(f"Prompt: {prompt}\nGenerated: {gen}\n")

    if DEVICE.type == 'cuda':
        logger.info(f"Peak GPU Memory: {torch.cuda.max_memory_allocated()/1e9:.2f} GB")

# Run the full pipeline only when this file is executed as the main module,
# not when it is imported.
if __name__ == "__main__":
    main()

CUDA Available: True
GPU: NVIDIA A100-SXM4-40GB
VRAM: 40.00 GB
2025-04-30 10:00:01,123 - INFO - Loading tokenizer...
2025-04-30 10:00:02,456 - INFO - Tokenizer vocab size: 32000, eos_token_id: 2
2025-04-30 10:00:02,789 - INFO - Loading model...
2025-04-30 10:00:05,321 - INFO - Loading TinyStories dataset...
Processing TinyStories: 100%|██████████| 5000/5000 [00:10<00:00, 500.00story/s]
2025-04-30 10:00:15,654 - INFO - Collected 5000 text samples from TinyStories
2025-04-30 10:00:15,987 - INFO - Training with SSL (CLM) and distillation...
Epoch 1/3: 100%|██████████| 563/563 [05:00<00:00, 1.88batch/s, loss=2.3456]
2025-04-30 10:05:16,123 - INFO - Batch 8, GPU Memory: 12.34 GB
2025-04-30 10:05:16,456 - INFO - Epoch 1/3 | Avg Loss: 2.3456
2025-04-30 10:05:16,789 - INFO - Saved checkpoint: /content/pnn_llama_distillation/checkpoint_epoch_1_20250430_100516.pt
Epoch 2/3: 100%|██████████| 563/563 [05:00<00:00, 1.88batch/s, loss=1.9876]
2025-04-30 10:10:17,123 - INFO - Batch 8, GPU Memory: 12.4