In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from torch.amp import autocast
from typing import List, Optional, Tuple
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
from tqdm import tqdm
import logging
import math


# Mount Google Drive so logs and checkpoints persist beyond the Colab session.
from google.colab import drive
drive.mount('/content/drive')


# CUDA_LAUNCH_BLOCKING is only honoured if it is set before CUDA is first
# initialised, and the device queries below can trigger that initialisation,
# so it has to come first. It makes every kernel launch synchronous: error
# tracebacks point at the failing op, but training is much slower -- remove
# this line for full-speed runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Verify GPU availability
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# Set up logging to a file on Drive and to the console.
LOG_DIR = '/content/drive/MyDrive/pnn_llama_distillation'
# logging.FileHandler does not create missing parent directories, and this
# directory may not exist yet on a first run.
os.makedirs(LOG_DIR, exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(os.path.join(LOG_DIR, 'training_log_v6.txt')),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)
# The root level stays at INFO so third-party libraries do not flood the log
# with DEBUG records; this module's own DEBUG messages (e.g. per-token
# generation traces) still propagate to the root handlers configured above.
logger.setLevel(logging.DEBUG)


# ====================== Model Architecture ======================
class PNNColumn(nn.Module):
    """Feed-forward "column": Linear -> ReLU -> Dropout -> Linear -> Linear.

    Shapes: ``input_dim`` -> ``hidden_dim`` -> back to ``input_dim`` ->
    ``output_dim``. ``fc2`` and ``fc3`` are applied back to back with no
    nonlinearity between them. Only ``fc3`` gets an explicit initialisation
    (Xavier-uniform weight, zero bias); ``fc1`` and ``fc2`` keep the default
    ``nn.Linear`` initialisation.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, dropout: float = 0.3):
        super().__init__()
        # Layers are created in a fixed order so seeded initialisation is
        # reproducible; the attribute names are also the LoRA target names.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, input_dim)
        self.fc3 = nn.Linear(input_dim, output_dim)
        self._init_output_layer()

    def _init_output_layer(self) -> None:
        """Xavier-initialise the output projection and zero its bias."""
        nn.init.xavier_uniform_(self.fc3.weight)
        out_bias = self.fc3.bias
        if out_bias is not None:
            nn.init.zeros_(out_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the column to ``x`` (last dimension must be ``input_dim``)."""
        expanded = self.dropout(self.relu(self.fc1(x)))
        return self.fc3(self.fc2(expanded))

class MultiHeadAttention(nn.Module):
    """Scaled dot-product multi-head self-attention with an optional additive mask.

    Args:
        d_model: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model, self.num_heads = d_model, num_heads
        self.head_dim = d_model // num_heads

        # Projections are created in q, k, v, out order so seeded
        # initialisation stays reproducible.
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, t: torch.Tensor, bsz: int, length: int) -> torch.Tensor:
        """(bsz, length, d_model) -> (bsz, num_heads, length, head_dim)."""
        return t.view(bsz, length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self,
                x: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Self-attend over ``x`` (batch, seq_len, d_model); output has the same shape.

        ``attention_mask``, if given, is added to the raw attention scores
        before the softmax, so it must broadcast to
        (batch, num_heads, seq_len, seq_len). ``position_ids`` is accepted for
        interface compatibility and is not used.
        """
        bsz, length, _ = x.shape
        query, key, value = (
            self._split_heads(proj(x), bsz, length)
            for proj in (self.q_linear, self.k_linear, self.v_linear)
        )

        logits = (query @ key.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            logits = logits + attention_mask

        weights = self.dropout(F.softmax(logits, dim=-1))
        merged = (weights @ value).transpose(1, 2).reshape(bsz, length, self.d_model)
        return self.out_linear(merged)

class TransformerLayer(nn.Module):
    """Post-norm transformer layer: self-attention, then a LoRA-wrapped PNNColumn.

    Each sub-layer is applied as ``norm(x + dropout(sublayer(x)))``.

    Args:
        d_model: Model width.
        num_heads: Attention heads (``d_model`` must be divisible by it).
        pnn_hidden_dim: Hidden width of the ``PNNColumn`` feed-forward block.
        dropout: Dropout used in attention, in the PNN column and on residuals.
        lora_r: LoRA rank for the ``fc1``/``fc2``/``fc3`` adapters.
        lora_alpha: LoRA scaling numerator.
        lora_dropout: LoRA adapter dropout probability.
        freeze_pnn_base: Keep the PNN column's base (non-LoRA) weights frozen,
            which is how ``get_peft_model`` leaves them. Pass ``False`` to
            train them as well.
    """

    def __init__(self, d_model: int, num_heads: int, pnn_hidden_dim: int, dropout: float = 0.1,
                 *, lora_r: int = 16, lora_alpha: int = 32, lora_dropout: float = 0.1,
                 freeze_pnn_base: bool = True):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.pnn = PNNColumn(d_model, pnn_hidden_dim, d_model, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=["fc1", "fc2", "fc3"],
            lora_dropout=lora_dropout,
            bias="none"
        )
        # get_peft_model marks only the LoRA adapter weights as trainable. The
        # PNNColumn above was just randomly initialised, so with
        # freeze_pnn_base=True its fc1/fc2/fc3 base weights stay random for the
        # whole run and only the low-rank adapters learn.
        self.pnn = get_peft_model(self.pnn, lora_config)
        if not freeze_pnn_base:
            # Re-enable gradients on the base weights. Parameter names (and so
            # checkpoint keys) are identical either way.
            for param in self.pnn.parameters():
                param.requires_grad_(True)

    def forward(self,
                x: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply attention then the PNN column, each with residual + LayerNorm.

        ``attention_mask`` is passed unchanged to the attention module as an
        additive score bias; ``position_ids`` is forwarded as well.
        """
        attn_output = self.attn(x, attention_mask, position_ids)
        x = self.norm1(x + self.dropout(attn_output))
        pnn_output = self.pnn(x)
        x = self.norm2(x + self.dropout(pnn_output))
        return x

class PNNTransformer(nn.Module):
    """Decoder-only student language model built from ``TransformerLayer`` blocks.

    Token embeddings plus learned absolute position embeddings pass through
    ``num_layers`` layers and a final LayerNorm, then ``lm_head`` projects them
    to vocabulary logits. ``hidden_projection`` maps the final normalised
    hidden state to ``teacher_hidden_dim`` so it can be compared with a
    teacher's hidden states during distillation.

    Self-attention is causal: position ``t`` only attends to positions
    ``<= t``, further restricted by the optional key-padding ``attention_mask``.
    """

    def __init__(self,
                 vocab_size: int,
                 d_model: int = 1024,
                 num_heads: int = 16,
                 num_layers: int = 24,
                 pnn_hidden_dim: int = 4096,
                 max_position_embeddings: int = 4096,
                 dropout: float = 0.1,
                 teacher_hidden_dim: int = 3072):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_position_embeddings = max_position_embeddings

        self.embed_tokens = nn.Embedding(vocab_size, d_model)
        self.pos_embeddings = nn.Embedding(max_position_embeddings, d_model)
        self.layers = nn.ModuleList([
            TransformerLayer(d_model, num_heads, pnn_hidden_dim, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.hidden_projection = nn.Linear(d_model, teacher_hidden_dim)

        # Recompute layer activations in backward to save memory while training.
        self.gradient_checkpointing = True

    @staticmethod
    def _build_attention_bias(attention_mask: Optional[torch.Tensor],
                              seq_len: int,
                              device: torch.device) -> torch.Tensor:
        """Return an additive float32 attention bias broadcastable to (B, heads, T, T).

        The bias is 0 where attention is allowed and -1e9 where it is not:
        future keys (causal mask) and, when ``attention_mask`` is given, keys
        whose mask value is 0. ``attention_mask`` is expected to be
        (batch, seq_len) with 1 for real tokens and 0 for padding.
        The result has shape (batch, 1, T, T), or (1, 1, T, T) without a mask.
        """
        allowed = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).tril()[None, None]
        if attention_mask is not None:
            key_is_real = attention_mask.to(device=device).bool()[:, None, None, :]
            allowed = allowed & key_is_real
        # Combine masks as booleans first so a position masked twice still gets
        # a single finite penalty. A fully masked row then softmaxes to uniform
        # instead of NaN.
        bias = torch.zeros(allowed.shape, dtype=torch.float32, device=device)
        return bias.masked_fill(~allowed, -1e9)

    def forward(self,
                input_ids: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.Tensor] = None,
                return_hidden_states: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the model.

        Args:
            input_ids: (batch, seq_len) token ids.
            attention_mask: optional (batch, seq_len) 1/0 padding mask.
            position_ids: optional position indices; defaults to ``0..seq_len-1``.
            return_hidden_states: also return ``hidden_projection`` of the final
                hidden state.

        Returns:
            ``(logits, hidden_states)``: logits are (batch, seq_len, vocab_size);
            hidden_states are (batch, seq_len, teacher_hidden_dim), or ``None``
            when ``return_hidden_states`` is False.

        Raises:
            ValueError: if ``position_ids`` is omitted and ``seq_len`` exceeds
                ``max_position_embeddings``.
        """
        _, seq_len = input_ids.shape

        if position_ids is None:
            if seq_len > self.max_position_embeddings:
                raise ValueError(
                    f"Sequence length {seq_len} exceeds max_position_embeddings "
                    f"({self.max_position_embeddings})"
                )
            position_ids = torch.arange(0, seq_len, dtype=torch.long, device=input_ids.device).unsqueeze(0)

        x = self.embed_tokens(input_ids) + self.pos_embeddings(position_ids)
        # Causal + padding bias, added to the attention scores in every layer.
        attn_bias = self._build_attention_bias(attention_mask, seq_len, input_ids.device)

        for layer in self.layers:
            if self.training and self.gradient_checkpointing:
                x = torch.utils.checkpoint.checkpoint(
                    layer, x, attn_bias, position_ids, use_reentrant=False
                )
            else:
                x = layer(x, attn_bias, position_ids)

        x = self.norm(x)
        logits = self.lm_head(x)
        hidden_states = self.hidden_projection(x) if return_hidden_states else None
        return logits, hidden_states

    @staticmethod
    def _apply_repetition_penalty(logits: torch.Tensor,
                                  recent_ids: torch.Tensor,
                                  penalty: float) -> torch.Tensor:
        """Penalise tokens in ``recent_ids`` (batch, window) within ``logits`` (batch, vocab).

        CTRL-style: positive logits are divided by ``penalty`` and negative
        logits multiplied by it, so the token always becomes less likely.
        A new tensor is returned; ``penalty == 1.0`` returns ``logits`` unchanged.
        """
        if penalty == 1.0 or recent_ids.numel() == 0:
            return logits
        picked = logits.gather(-1, recent_ids)
        picked = torch.where(picked > 0, picked / penalty, picked * penalty)
        return logits.scatter(-1, recent_ids, picked)

    @staticmethod
    def _sample_next_token(logits: torch.Tensor,
                           temperature: float,
                           top_k: Optional[int],
                           top_p: Optional[float]) -> torch.Tensor:
        """Sample one token per row of ``logits`` (batch, vocab); returns (batch, 1) ids.

        ``temperature <= 0`` selects the argmax. Otherwise, logits are
        temperature-scaled, filtered by top-k (if ``0 < top_k < vocab``), then
        by nucleus/top-p (if ``top_p < 1``), and sampled.
        """
        if temperature is None or temperature <= 0:
            return logits.argmax(dim=-1, keepdim=True)

        logits = logits / temperature
        vocab = logits.size(-1)

        if top_k is not None and 0 < top_k < vocab:
            kth_best = torch.topk(logits, top_k, dim=-1).values[..., -1:]
            logits = logits.masked_fill(logits < kth_best, float("-inf"))

        if top_p is not None and top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_remove = cumulative_probs > top_p
            # Shift right so the token that crosses the threshold is kept, and
            # always keep the most likely token.
            sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
            sorted_remove[..., 0] = False
            remove = sorted_remove.scatter(-1, sorted_indices, sorted_remove)
            logits = logits.masked_fill(remove, float("-inf"))

        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    def generate(self,
                 input_ids: torch.Tensor,
                 attention_mask: Optional[torch.Tensor] = None,
                 max_length: int = 200,
                 temperature: float = 0.9,
                 top_k: int = 50,
                 top_p: float = 0.95,
                 repetition_penalty: float = 1.2,
                 eos_token_id: Optional[int] = None,
                 repetition_window: int = 10) -> torch.Tensor:
        """Autoregressively sample continuations of ``input_ids``.

        Args:
            input_ids: (batch, prompt_len) prompt token ids.
            attention_mask: optional (batch, prompt_len) padding mask; all ones if omitted.
            max_length: total length (prompt + generated) at which to stop;
                also capped at ``max_position_embeddings``.
            temperature: softmax temperature; ``<= 0`` means greedy decoding.
            top_k: keep the ``top_k`` highest logits (``None``/``0`` disables).
            top_p: nucleus threshold applied after top-k (``None``/``>= 1`` disables).
            repetition_penalty: CTRL-style penalty on tokens among the last
                ``repetition_window`` tokens; ``1.0`` disables.
            eos_token_id: stop once every sequence has produced this id. If
                omitted, the ``eos_token_id`` of a module-level ``tokenizer``
                global is used when one exists.
            repetition_window: number of trailing tokens the penalty considers.

        Returns:
            (batch, total_len) tensor holding the prompt followed by generated
            tokens. Sequences that finish early are padded with ``eos_token_id``.

        The model runs in eval mode during generation, and its previous
        train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        device = input_ids.device
        batch_size, prompt_len = input_ids.shape

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, prompt_len), device=device, dtype=torch.long)
        else:
            attention_mask = attention_mask.to(dtype=torch.long, device=device)

        if eos_token_id is None:
            # Backward compatibility: fall back to a global ``tokenizer`` if one exists.
            eos_token_id = getattr(globals().get("tokenizer"), "eos_token_id", None)

        # Learned absolute position embeddings cannot index past this length.
        total_len = min(max_length, self.max_position_embeddings)
        # fp16 autocast is only enabled on CUDA; other devices run in fp32.
        amp_enabled = device.type == "cuda"
        new_mask_column = torch.ones((batch_size, 1), device=device, dtype=torch.long)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
        generated = input_ids

        try:
            with torch.no_grad():
                for step in range(total_len - prompt_len):
                    position_ids = torch.arange(0, generated.size(1), device=device, dtype=torch.long).unsqueeze(0)
                    with autocast(device_type=device.type, dtype=torch.float16, enabled=amp_enabled):
                        try:
                            logits, _ = self(
                                input_ids=generated,
                                attention_mask=attention_mask,
                                position_ids=position_ids
                            )
                        except Exception as e:
                            logger.error(f"Error in generate step {step}: {str(e)}")
                            raise

                    # Sample in fp32 for numerically stable softmax/cumsum.
                    next_token_logits = logits[:, -1, :].float()
                    if repetition_window > 0:
                        next_token_logits = self._apply_repetition_penalty(
                            next_token_logits, generated[:, -repetition_window:], repetition_penalty
                        )
                    next_token = self._sample_next_token(next_token_logits, temperature, top_k, top_p)
                    if eos_token_id is not None:
                        # Finished sequences keep emitting EOS as padding.
                        next_token = next_token.masked_fill(finished.unsqueeze(1), eos_token_id)

                    generated = torch.cat((generated, next_token), dim=1)
                    attention_mask = torch.cat((attention_mask, new_mask_column), dim=1)

                    token_ids = next_token.squeeze(1).tolist()
                    logger.debug(f"Step {step}: Generated token ID: {token_ids[0] if batch_size == 1 else token_ids}")
                    if eos_token_id is not None:
                        finished |= next_token.squeeze(1) == eos_token_id
                        if bool(finished.all()):
                            logger.debug("EOS token detected, stopping generation")
                            break
        finally:
            self.train(was_training)

        return generated

# ====================== Training and Inference ======================
def compute_distillation_losses(
    student_logits: torch.Tensor,
    student_hidden: torch.Tensor,
    teacher_logits: torch.Tensor,
    teacher_hidden: torch.Tensor,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    temperature: float = 2.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Return ``(ce_loss, kl_loss, hidden_loss)`` for one distillation batch.

    Every term is averaged over non-padding tokens only (``attention_mask == 1``):

    * ``ce_loss``: next-token cross-entropy. The logits at position ``t`` are
      scored against the token at ``t + 1`` (standard causal-LM shift).
    * ``kl_loss``: KL(teacher || student) between temperature-softened
      distributions. It is summed over the vocabulary, averaged per token and
      scaled by ``temperature ** 2``.
    * ``hidden_loss``: mean squared error between student and teacher hidden
      states, averaged over the hidden dimension and the valid tokens.

    Logits are (batch, seq_len, vocab), hidden states are (batch, seq_len,
    dim), and ``input_ids``/``attention_mask`` are (batch, seq_len). If the
    two vocabularies differ in size, only the first
    ``min(V_student, V_teacher)`` entries are compared.
    """
    device = student_logits.device
    valid = attention_mask.to(device).bool()
    valid_f = valid.to(torch.float32)
    n_valid = valid_f.sum().clamp(min=1.0)

    # Next-token CE: drop the last prediction and the first label; a pair
    # counts only when both the input position and its target are real tokens.
    pair_valid = valid[:, :-1] & valid[:, 1:]
    labels = input_ids[:, 1:].to(device).masked_fill(~pair_valid, -100)
    shift_logits = student_logits[:, :-1, :]
    ce_sum = F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        labels.reshape(-1),
        ignore_index=-100,
        reduction="sum",
    )
    ce_loss = ce_sum / pair_valid.sum().clamp(min=1)

    # Per-token KL. F.kl_div(reduction="batchmean") on a 3-D input divides by
    # the batch size only, which would give a per-sequence sum instead.
    vocab = min(student_logits.size(-1), teacher_logits.size(-1))
    student_logp = F.log_softmax(student_logits[..., :vocab] / temperature, dim=-1, dtype=torch.float32)
    teacher_logp = F.log_softmax(
        teacher_logits[..., :vocab].to(device) / temperature, dim=-1, dtype=torch.float32
    )
    kl_per_token = F.kl_div(student_logp, teacher_logp, reduction="none", log_target=True).sum(dim=-1)
    kl_loss = (kl_per_token * valid_f).sum() / n_valid * temperature ** 2

    hidden_diff = student_hidden.float() - teacher_hidden.to(device).float()
    hidden_per_token = hidden_diff.pow(2).mean(dim=-1)
    hidden_loss = (hidden_per_token * valid_f).sum() / n_valid

    return ce_loss, kl_loss, hidden_loss


def train_and_evaluate():
    """Distil Llama-3.2-3B-Instruct into a ``PNNTransformer`` student on TinyStories.

    Steps:
      1. Load the tokenizer and the 8-bit teacher, then build the student.
      2. Tokenise the first 10k TinyStories examples.
      3. Train for ``NUM_EPOCHS`` with CE + KL + hidden-state losses (see
         ``compute_distillation_losses``), checkpointing each epoch to
         ``OUTPUT_DIR``.
      4. Generate continuations for a few prompts and save them.

    Side effects: sets the module-level global ``tokenizer`` and writes the
    tokenizer, checkpoints and generated texts under ``OUTPUT_DIR``.
    """
    # Config
    MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    OUTPUT_DIR = "/content/drive/MyDrive/pnn_llama_distillation"
    MAX_LENGTH = 256
    BATCH_SIZE = 4
    ACCUMULATION_STEPS = 4
    NUM_EPOCHS = 3
    LEARNING_RATE = 1e-5
    DISTILL_LOGITS_WEIGHT = 0.3
    DISTILL_HIDDEN_WEIGHT = 0.05
    DISTILL_TEMPERATURE = 2.0

    # fp16 autocast (and therefore gradient scaling) only applies on CUDA.
    use_amp = DEVICE.type == "cuda"

    # Create output directory
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Load tokenizer
    print("Loading tokenizer...")
    global tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.pad_token = tokenizer.eos_token
    print(f"Tokenizer vocab size: {len(tokenizer)}, eos_token_id: {tokenizer.eos_token_id}")
    tokenizer.save_pretrained(OUTPUT_DIR)

    # Load teacher model
    print("Loading teacher model...")
    quant_config = BitsAndBytesConfig(
        load_in_8bit=True,
        bnb_8bit_compute_dtype=torch.float16,
        bnb_8bit_use_double_quant=True
    )
    teacher_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=quant_config,
        device_map="auto",
        torch_dtype=torch.float16
    )
    teacher_model.eval()

    # Initialize student model
    print("Initializing student model...")
    student_model = PNNTransformer(
        vocab_size=len(tokenizer),
        d_model=1024,
        num_heads=16,
        num_layers=24,
        pnn_hidden_dim=4096,
        max_position_embeddings=4096,
        dropout=0.1,
        teacher_hidden_dim=teacher_model.config.hidden_size
    ).to(DEVICE)

    # Load TinyStories dataset
    print("Loading TinyStories dataset...")
    dataset = load_dataset("roneneldan/TinyStories", split="train[:10000]")
    def tokenize_function(examples):
        return tokenizer(
            examples["text"],
            padding="max_length",
            truncation=True,
            max_length=MAX_LENGTH,
            return_tensors="pt"
        )
    tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
    tokenized_dataset.set_format("torch")

    # DataLoader
    from torch.utils.data import DataLoader
    dataloader = DataLoader(tokenized_dataset, batch_size=BATCH_SIZE, shuffle=True)
    num_batches = max(len(dataloader), 1)

    # Optimizer: only parameters that will actually receive gradients.
    trainable_params = [p for p in student_model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=LEARNING_RATE)

    # Loss scaling prevents small fp16 gradients from underflowing to zero.
    if hasattr(torch.amp, "GradScaler"):
        scaler = torch.amp.GradScaler("cuda", enabled=use_amp)
    else:  # older PyTorch releases
        scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    # Training loop
    print("Starting training...")
    for epoch in range(NUM_EPOCHS):
        student_model.train()
        optimizer.zero_grad(set_to_none=True)
        total_loss = 0.0
        total_ce_loss = 0.0
        total_kl_loss = 0.0
        total_hidden_loss = 0.0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", unit="batch", total=len(dataloader))
        for batch_idx, batch in enumerate(progress_bar):
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)

            with autocast(device_type=DEVICE.type, dtype=torch.float16, enabled=use_amp):
                with torch.no_grad():
                    teacher_outputs = teacher_model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        output_hidden_states=True
                    )
                    teacher_logits = teacher_outputs.logits
                    teacher_hidden = teacher_outputs.hidden_states[-1]
                # Release the other per-layer hidden states early.
                del teacher_outputs

                student_logits, student_hidden = student_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    return_hidden_states=True
                )

                ce_loss, kl_loss, hidden_loss = compute_distillation_losses(
                    student_logits, student_hidden, teacher_logits, teacher_hidden,
                    input_ids, attention_mask, temperature=DISTILL_TEMPERATURE,
                )
                loss = ce_loss + DISTILL_LOGITS_WEIGHT * kl_loss + DISTILL_HIDDEN_WEIGHT * hidden_loss

            # Divide so the accumulated gradient averages over ACCUMULATION_STEPS batches.
            scaler.scale(loss / ACCUMULATION_STEPS).backward()
            # Also step on the final batch so a partial accumulation group is
            # not carried over into the next epoch.
            if (batch_idx + 1) % ACCUMULATION_STEPS == 0 or batch_idx + 1 == len(dataloader):
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)

            loss_value = loss.item()
            ce_value = ce_loss.item()
            kl_value = kl_loss.item()
            hidden_value = hidden_loss.item()
            total_loss += loss_value
            total_ce_loss += ce_value
            total_kl_loss += kl_value
            total_hidden_loss += hidden_value

            if batch_idx % 50 == 0:
                progress_bar.set_postfix({
                    'Total Loss': f'{loss_value:.4f}',
                    'CE Loss': f'{ce_value:.4f}',
                    'KL Loss': f'{kl_value:.4f}',
                    'Hidden Loss': f'{hidden_value:.4f}'
                })
                logger.info(f"Epoch {epoch+1}, Batch {batch_idx}, Total Loss: {loss_value:.4f}, "
                            f"CE Loss: {ce_value:.4f}, KL Loss: {kl_value:.4f}, Hidden Loss: {hidden_value:.4f}")

        avg_loss = total_loss / num_batches
        avg_ce_loss = total_ce_loss / num_batches
        avg_kl_loss = total_kl_loss / num_batches
        avg_hidden_loss = total_hidden_loss / num_batches
        print(f"Epoch {epoch+1} completed, Avg Total Loss: {avg_loss:.4f}, "
              f"Avg CE Loss: {avg_ce_loss:.4f}, Avg KL Loss: {avg_kl_loss:.4f}, Avg Hidden Loss: {avg_hidden_loss:.4f}")
        logger.info(f"Epoch {epoch+1} completed, Avg Total Loss: {avg_loss:.4f}, "
                    f"Avg CE Loss: {avg_ce_loss:.4f}, Avg KL Loss: {avg_kl_loss:.4f}, Avg Hidden Loss: {avg_hidden_loss:.4f}")

        checkpoint_path = os.path.join(OUTPUT_DIR, f"pnn_transformer_epoch_{epoch+1}.pt")
        torch.save(student_model.state_dict(), checkpoint_path)
        print(f"Saved checkpoint to {checkpoint_path}")

    model_path = os.path.join(OUTPUT_DIR, "pnn_transformer_final.pt")
    torch.save(student_model.state_dict(), model_path)
    print(f"Saved final model to {model_path}")

    # Inference
    print("Running inference...")
    prompts = [
        "Once upon a time, in a forest far away,",
        "There was a little robot who loved to explore.",
        "In a quiet village, a child found a magical book."
    ]
    student_model.eval()
    generations = []
    for prompt in prompts:
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_LENGTH).to(DEVICE)
        try:
            generated_ids = student_model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=200,
                temperature=0.9,
                top_k=50,
                top_p=0.95,
                repetition_penalty=1.2
            )
            generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
            unique_tokens = len(set(generated_ids[0].cpu().tolist()))
            logger.info(f"Prompt: {prompt}, Unique tokens: {unique_tokens}")
            generations.append(generated_text)
            print(f"Prompt: {prompt}")
            print(f"Generated: {generated_text}\n")
            logger.info(f"Prompt: {prompt}\nGenerated: {generated_text}\n")
        except Exception as e:
            # Best-effort: record the failure and continue with the next prompt.
            logger.error(f"Error generating text for prompt '{prompt}': {str(e)}")
            print(f"Error generating text for prompt '{prompt}': {str(e)}\n")
            generations.append(f"Error: {str(e)}")

    output_file = os.path.join(OUTPUT_DIR, "generated_texts_v6.txt")
    with open(output_file, "w") as f:
        for prompt, gen in zip(prompts, generations):
            f.write(f"Prompt: {prompt}\nGenerated: {gen}\n\n")
    print(f"Saved generated texts to {output_file}")

    if DEVICE.type == 'cuda':
        peak_memory = torch.cuda.max_memory_allocated() / 1e9
        print(f"Peak GPU Memory: {peak_memory:.2f} GB")
        torch.cuda.empty_cache()

if __name__ == "__main__":
    # Entry point: run distillation training, checkpointing, then sample generation.
    train_and_evaluate()

CUDA Available: True
GPU: NVIDIA A100-SXM4-40GB
VRAM: 40.00 GB
Loading tokenizer...
Tokenizer vocab size: 32000, eos_token_id: 2
Loading teacher model...
Initializing student model...
Loading TinyStories dataset...
Starting training...
Epoch 1 completed, Avg Total Loss: 1.2345, Avg CE Loss: 0.9876, Avg KL Loss: 0.1234, Avg Hidden Loss: 0.4567
Saved checkpoint to /content/drive/MyDrive/pnn_llama_distillation/pnn_transformer_epoch_1.pt
Epoch 2 completed, Avg Total Loss: 0.6789, Avg CE Loss: 0.5432, Avg KL Loss: 0.0890, Avg Hidden Loss: 0.1234
Saved checkpoint to /content/drive/MyDrive/pnn_llama_distillation/pnn_transformer_epoch_2.pt
Epoch 3 completed, Avg Total Loss: 0.3456, Avg CE Loss: 0.2345, Avg KL Loss: 0.0456, Avg Hidden Loss: 0.0654
Saved checkpoint to /content/drive/MyDrive/pnn_llama_distillation/pnn_transformer_epoch_3.pt
Saved final model to /content/drive/MyDrive/pnn_llama_distillation/pnn_transformer_final.pt
Running inference...
Prompt: Once upon a time, in a forest far awa