In [1]:
import sys
import torch
from torch.utils.data import DataLoader
from transformers import LlamaTokenizer, LlamaForCausalLM, get_linear_schedule_with_warmup
from datasets import load_dataset
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

# Configuration
MODEL_NAME = "huggingface/llama"  # Replace with the actual model name/path
DATASET_NAME = "wikipedia"  # Wikimedia dataset
BATCH_SIZE = 4  # Adjust based on your GPU memory
GRADIENT_ACCUMULATION_STEPS = 8
LEARNING_RATE = 5e-5
EPOCHS = 3
TOP_PERCENT = 0.3  # Top 30%

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
model.to(device)
model.train()

# Load and preprocess dataset
# Here, we're using a subset for demonstration. Adjust as needed.
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

def tokenize_function(examples):
    return tokenizer(examples["text"], return_tensors="pt", truncation=True, padding='max_length', max_length=512)

tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])

# Convert to PyTorch tensors
tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])

# Create DataLoader
dataloader = DataLoader(tokenized_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Optimizer and scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
total_steps = len(dataloader) * EPOCHS // GRADIENT_ACCUMULATION_STEPS
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=total_steps//10, num_training_steps=total_steps)


Using device: cuda


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [2]:

# Training loop
model.train()
for epoch in range(EPOCHS):
    print(f"Epoch {epoch + 1}/{EPOCHS}")
    epoch_loss = 0.0
    progress_bar = tqdm(dataloader, desc="Training")
    
    optimizer.zero_grad()
    
    for step, batch in enumerate(progress_bar):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        
        # Forward pass
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
        logits = outputs.logits  # Shape: (batch_size, seq_length, vocab_size)
        print(logits.shape)
        print(logits)
        print(input_ids.shape)
        print(outputs.labels)
        sys.exit()
        # Shift tokens for next-token prediction
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = input_ids[:, 1:].contiguous()
        shift_attention = attention_mask[:, 1:].contiguous()
        
        loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
        pre_token_loss = loss_fn(shift_logits, )

        # Compute log probabilities
        log_probs = F.log_softmax(shift_logits, dim=-1)  # (batch_size, seq_length-1, vocab_size)
        
        # Gather log probabilities of the correct tokens
        shift_labels_flat = shift_labels.view(-1)
        log_probs_flat = log_probs.view(-1, log_probs.size(-1))
        token_log_probs = log_probs_flat[torch.arange(shift_labels_flat.size(0)), shift_labels_flat]  # (batch_size * (seq_length-1))
        token_log_probs = token_log_probs.view(shift_labels.size())  # (batch_size, seq_length-1)
        
        # Compute negative log-likelihood loss per token
        token_losses = -token_log_probs  # (batch_size, seq_length-1)
        
        # Mask padding tokens
        token_losses = token_losses * shift_attention  # Zero out losses for padding tokens
        print(token_losses)
        # Determine the threshold for top 30% losses
        # Compute the number of tokens to keep
        num_tokens = (shift_attention.sum()).item()
        if num_tokens == 0:
            continue  # Skip if no tokens to process
        k = int(num_tokens * TOP_PERCENT)
        if k == 0:
            k = 1  # Ensure at least one token is kept
        
        # Flatten the losses and filter out padding tokens
        losses_flat = token_losses.view(-1)
        attention_flat = shift_attention.view(-1)
        valid_losses = losses_flat[attention_flat == 1]
        
        if valid_losses.numel() == 0:
            continue  # Skip if no valid losses
        
        # Find the threshold
        threshold = torch.topk(valid_losses, k, largest=True, sorted=False).values.min()
        
        # Create a mask for top 30% losses
        mask = (token_losses >= threshold).float()
        
        # Apply the mask
        masked_losses = token_losses * mask
        
        # Compute the final loss
        if mask.sum() == 0:
            continue  # Avoid division by zero
        final_loss = masked_losses.sum() / mask.sum()
        
        # Backward pass
        final_loss.backward()
        epoch_loss += final_loss.item()
        
        # Gradient accumulation
        if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
        
        progress_bar.set_postfix({"Loss": final_loss.item()})
    
    avg_epoch_loss = epoch_loss / len(dataloader)
    print(f"Average Loss for Epoch {epoch + 1}: {avg_epoch_loss}")

# Optionally, save the model
# model.save_pretrained("llama_finetuned_wikimedia")

Epoch 1/3


  attn_output = torch.nn.functional.scaled_dot_product_attention(
Training:   0%|          | 0/1090 [00:10<?, ?it/s]

torch.Size([4, 512, 151936])
tensor([[[ 8.0257,  6.8792,  5.6502,  ..., -2.8272, -2.8270, -2.8272],
         [ 6.5129,  6.4042,  6.3487,  ..., -2.2473, -2.2474, -2.2472],
         [ 6.2590,  5.3727,  4.4408,  ..., -2.6565, -2.6559, -2.6566],
         ...,
         [11.2313, 13.6219, 15.3390,  ..., -3.8554, -3.8560, -3.8553],
         [11.5029, 13.5382, 15.4038,  ..., -3.8214, -3.8220, -3.8213],
         [11.7401, 13.6712, 15.9136,  ..., -3.7064, -3.7069, -3.7063]],

        [[ 7.4132,  7.1866,  5.4109,  ..., -2.6561, -2.6551, -2.6560],
         [ 4.9528, -4.0142,  2.1513,  ..., -1.9784, -1.9775, -1.9781],
         [ 6.3596,  6.1117,  5.5970,  ..., -5.7713, -5.7705, -5.7712],
         ...,
         [10.6778, 12.7759, 11.8910,  ..., -3.6970, -3.6975, -3.6970],
         [10.4127, 12.5988, 11.6687,  ..., -3.6961, -3.6964, -3.6961],
         [10.6908, 12.8864, 12.1103,  ..., -3.7378, -3.7381, -3.7378]],

        [[10.4748,  7.0472,  5.7916,  ..., -1.4985, -1.4971, -1.4984],
         [11.557




AttributeError: 'CausalLMOutputWithPast' object has no attribute 'labels'