In [2]:
!pip install transformers

Collecting transformers
  Downloading transformers-4.53.1-py3-none-any.whl.metadata (40 kB)
Collecting tokenizers<0.22,>=0.21 (from transformers)
  Downloading tokenizers-0.21.2-cp39-abi3-macosx_11_0_arm64.whl.metadata (6.8 kB)
Collecting safetensors>=0.4.3 (from transformers)
  Downloading safetensors-0.5.3-cp38-abi3-macosx_11_0_arm64.whl.metadata (3.8 kB)
Downloading transformers-4.53.1-py3-none-any.whl (10.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.8/10.8 MB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hDownloading tokenizers-0.21.2-cp39-abi3-macosx_11_0_arm64.whl (2.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.7/2.7 MB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading safetensors-0.5.3-cp38-abi3-macosx_11_0_arm64.whl (418 kB)
Installing collected packages: safetensors, tokenizers, transformers
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3/3[0m [transformers][0m 

In [16]:
!pip install evaluate

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting evaluate
  Downloading evaluate-0.4.4-py3-none-any.whl.metadata (9.5 kB)
Downloading evaluate-0.4.4-py3-none-any.whl (84 kB)
Installing collected packages: evaluate
Successfully installed evaluate-0.4.4


In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import (
    GPT2LMHeadModel, 
    GPT2Tokenizer, 
    AutoTokenizer,
    Trainer, 
    TrainingArguments, 
    DataCollatorForLanguageModeling
)
from datasets import load_dataset, Dataset as HFDataset
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score
import warnings
warnings.filterwarnings('ignore')

class NextWordPredictor:
    def __init__(self, model_name="gpt2", max_length=128):
        self.model_name = model_name
        self.max_length = max_length
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = GPT2LMHeadModel.from_pretrained(model_name)
        
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            
        self.model.to(self.device)
        
    def load_and_preprocess_data(self, dataset_name="wikitext", dataset_config="wikitext-2-raw-v1"):
        print(f"Loading {dataset_name} dataset...")
        
        
        if dataset_name == "wikitext":
            dataset = load_dataset("wikitext", dataset_config)
        else: 
            dataset = load_dataset(dataset_name, dataset_config)
        
      
        def tokenize_function(examples):
            tokenized = self.tokenizer(
                examples['text'],
                truncation=True,
                padding=True,
                max_length=self.max_length,
                return_tensors="pt"
            )
            tokenized["labels"] = tokenized["input_ids"].clone()
            
            return tokenized
        
     
        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=dataset["train"].column_names
        )
        
       
        tokenized_dataset = tokenized_dataset.filter(
            lambda x: len(x["input_ids"]) > 1
        )
        
        self.train_dataset = tokenized_dataset["train"]
        self.val_dataset = tokenized_dataset["validation"]
        self.test_dataset = tokenized_dataset["test"]
        
        print(f"Dataset loaded successfully!")
        print(f"Train samples: {len(self.train_dataset)}")
        print(f"Validation samples: {len(self.val_dataset)}")
        print(f"Test samples: {len(self.test_dataset)}")
        
    def fine_tune_model(self, output_dir="./fine_tuned_gpt2", epochs=3, batch_size=8):
        
        print("Starting fine-tuning...")
        
      
        training_args = TrainingArguments(
            output_dir=output_dir,
            overwrite_output_dir=True,
            num_train_epochs=epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            eval_strategy="steps",
            eval_steps=500,
            save_steps=1000,
            warmup_steps=100,
            logging_steps=100,
            prediction_loss_only=True,
            save_total_limit=2,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            dataloader_pin_memory=False,
            report_to=None,  #
        )
        
        
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,  
        )
        
     
        trainer = Trainer(
            model=self.model,
            args=training_args,
            data_collator=data_collator,
            train_dataset=self.train_dataset,
            eval_dataset=self.val_dataset,
        )
        
        
        trainer.train()
        
       
        trainer.save_model()
        self.tokenizer.save_pretrained(output_dir)
        
        print(f"Fine-tuning completed! Model saved to {output_dir}")
        
    def calculate_perplexity(self, dataset, batch_size=8):
        
        self.model.eval()
        total_loss = 0
        total_tokens = 0
        
      
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,
        )
        
        dataloader = DataLoader(
            dataset, 
            batch_size=batch_size, 
            shuffle=False,
            collate_fn=data_collator
        )
        
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Calculating perplexity"):
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['labels'].to(self.device)
                
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )
                
                loss = outputs.loss
                total_loss += loss.item() * input_ids.size(0)
                total_tokens += input_ids.size(0)
        
        avg_loss = total_loss / total_tokens
        perplexity = torch.exp(torch.tensor(avg_loss))
        
        return perplexity.item()
    
    def calculate_top_k_accuracy(self, dataset, k=5, batch_size=8):
        self.model.eval()
        correct_predictions = 0
        total_predictions = 0
        
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,
        )
        
        dataloader = DataLoader(
            dataset, 
            batch_size=batch_size, 
            shuffle=False,
            collate_fn=data_collator
        )
        
        with torch.no_grad():
            for batch in tqdm(dataloader, desc=f"Calculating top-{k} accuracy"):
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                
               
                for i in range(input_ids.size(1) - 1):
                    input_seq = input_ids[:, :i+1]
                    target_token = input_ids[:, i+1]
                    
                    outputs = self.model(input_seq)
                    logits = outputs.logits[:, -1, :]  # Get last token logits
                    
                    
                    top_k_tokens = torch.topk(logits, k, dim=-1).indices
                    
                    correct_predictions += (target_token.unsqueeze(1) == top_k_tokens).any(dim=1).sum().item()
                    total_predictions += input_ids.size(0)
        
        accuracy = correct_predictions / total_predictions
        return accuracy
    
    def predict_next_word(self, text, num_predictions=5, temperature=0.7):
        self.model.eval()
        
       
        inputs = self.tokenizer.encode(text, return_tensors="pt").to(self.device)
        
        with torch.no_grad():
            outputs = self.model(inputs)
            predictions = outputs.logits[0, -1, :]
            
            
            predictions = predictions / temperature
            
            
            probabilities = torch.softmax(predictions, dim=-1)
            
           
            top_predictions = torch.topk(probabilities, num_predictions)
            
            results = []
            for i in range(num_predictions):
                token_id = top_predictions.indices[i].item()
                probability = top_predictions.values[i].item()
                word = self.tokenizer.decode([token_id])
                results.append((word, probability))
            
            return results
    
    def generate_text(self, prompt, max_length=50, temperature=0.7, do_sample=True):
        self.model.eval()
        
        inputs = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        
        with torch.no_grad():
            outputs = self.model.generate(
                inputs,
                max_length=max_length,
                temperature=temperature,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.eos_token_id
            )
            
            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return generated_text
    
    def evaluate_model(self):
        
        print("Evaluating model...")
        
     
        train_perplexity = self.calculate_perplexity(self.train_dataset)
        val_perplexity = self.calculate_perplexity(self.val_dataset)
        test_perplexity = self.calculate_perplexity(self.test_dataset)
        
     
        train_top1_acc = self.calculate_top_k_accuracy(self.train_dataset, k=1)
        train_top5_acc = self.calculate_top_k_accuracy(self.train_dataset, k=5)
        
        val_top1_acc = self.calculate_top_k_accuracy(self.val_dataset, k=1)
        val_top5_acc = self.calculate_top_k_accuracy(self.val_dataset, k=5)
        
        test_top1_acc = self.calculate_top_k_accuracy(self.test_dataset, k=1)
        test_top5_acc = self.calculate_top_k_accuracy(self.test_dataset, k=5)
        
        results = {
            'perplexity': {
                'train': train_perplexity,
                'validation': val_perplexity,
                'test': test_perplexity
            },
            'top1_accuracy': {
                'train': train_top1_acc,
                'validation': val_top1_acc,
                'test': test_top1_acc
            },
            'top5_accuracy': {
                'train': train_top5_acc,
                'validation': val_top5_acc,
                'test': test_top5_acc
            }
        }
        
        return results
    
    def plot_results(self, results):
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        
   
        datasets = ['train', 'validation', 'test']
        perplexities = [results['perplexity'][d] for d in datasets]
        
        axes[0].bar(datasets, perplexities, color=['blue', 'orange', 'green'])
        axes[0].set_title('Perplexity by Dataset')
        axes[0].set_ylabel('Perplexity')
        
  
        top1_accs = [results['top1_accuracy'][d] for d in datasets]
        axes[1].bar(datasets, top1_accs, color=['blue', 'orange', 'green'])
        axes[1].set_title('Top-1 Accuracy by Dataset')
        axes[1].set_ylabel('Accuracy')
        
  
        top5_accs = [results['top5_accuracy'][d] for d in datasets]
        axes[2].bar(datasets, top5_accs, color=['blue', 'orange', 'green'])
        axes[2].set_title('Top-5 Accuracy by Dataset')
        axes[2].set_ylabel('Accuracy')
        
        plt.tight_layout()
        plt.show()


def main():
   
    print(" Next Word Predictor using Transformers")
    print("Initializing predictor...")
    
    
    predictor = NextWordPredictor(model_name="gpt2", max_length=128)
    
  
    predictor.load_and_preprocess_data()
    
    # Fine-tune the model
    # uncomment to fine tune - it takes a lot of time
    # predictor.fine_tune_model(epochs=1, batch_size=4) 
    
  
    print("\n Next Word Prediction Demo")
    test_sentences = [
        "The weather today is",
        "Machine learning is a",
        "Python programming language",
        "The quick brown fox"
    ]
    
    for sentence in test_sentences:
        predictions = predictor.predict_next_word(sentence, num_predictions=5)
        print(f"\nInput: '{sentence}'")
        print("Top 5 predictions:")
        for i, (word, prob) in enumerate(predictions, 1):
            print(f"  {i}. '{word}' (probability: {prob:.4f})")
    
   
    print("\n Text Generation Demo ")
    prompts = [
        "The future of artificial intelligence",
        "In the world of technology",
        "Climate change is"
    ]
    
    for prompt in prompts:
        generated = predictor.generate_text(prompt, max_length=50)
        print(f"\nPrompt: '{prompt}'")
        print(f"Generated: '{generated}'")
    
    
    print("\n Model Evaluation")
    
    small_test = predictor.test_dataset.select(range(min(100, len(predictor.test_dataset))))
    
    test_perplexity = predictor.calculate_perplexity(small_test, batch_size=4)
    test_top1_acc = predictor.calculate_top_k_accuracy(small_test, k=1, batch_size=4)
    test_top5_acc = predictor.calculate_top_k_accuracy(small_test, k=5, batch_size=4)
    
    print(f"Test Perplexity: {test_perplexity:.4f}")
    print(f"Test Top-1 Accuracy: {test_top1_acc:.4f}")
    print(f"Test Top-5 Accuracy: {test_top5_acc:.4f}")
    
if __name__ == "__main__":
    main()

 Next Word Predictor using Transformers
Initializing predictor...
Loading wikitext dataset...


Map:   0%|          | 0/4358 [00:00<?, ? examples/s]

Map:   0%|          | 0/36718 [00:00<?, ? examples/s]

Map:   0%|          | 0/3760 [00:00<?, ? examples/s]

Filter:   0%|          | 0/4358 [00:00<?, ? examples/s]

Filter:   0%|          | 0/36718 [00:00<?, ? examples/s]

Filter:   0%|          | 0/3760 [00:00<?, ? examples/s]

Dataset loaded successfully!
Train samples: 36718
Validation samples: 3760
Test samples: 4358

 Next Word Prediction Demo


The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.



Input: 'The weather today is'
Top 5 predictions:
  1. ' very' (probability: 0.0982)
  2. ' good' (probability: 0.0655)
  3. ' pretty' (probability: 0.0608)
  4. ' a' (probability: 0.0592)
  5. ' not' (probability: 0.0463)

Input: 'Machine learning is a'
Top 5 predictions:
  1. ' very' (probability: 0.1373)
  2. ' great' (probability: 0.0949)
  3. ' big' (probability: 0.0461)
  4. ' new' (probability: 0.0424)
  5. ' powerful' (probability: 0.0403)

Input: 'Python programming language'
Top 5 predictions:
  1. '.' (probability: 0.5082)
  2. ',' (probability: 0.2904)
  3. ' is' (probability: 0.0422)
  4. ' and' (probability: 0.0288)
  5. ' that' (probability: 0.0171)

Input: 'The quick brown fox'
Top 5 predictions:
  1. 'es' (probability: 0.3929)
  2. ' was' (probability: 0.0885)
  3. ' is' (probability: 0.0590)
  4. ''s' (probability: 0.0560)
  5. ',' (probability: 0.0516)

 Text Generation Demo 

Prompt: 'The future of artificial intelligence'
Generated: 'The future of artificial intell

`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.
Calculating perplexity: 100%|███████████████████| 25/25 [00:09<00:00,  2.74it/s]
Calculating top-1 accuracy: 100%|███████████████| 25/25 [10:50<00:00, 26.02s/it]
Calculating top-5 accuracy: 100%|███████████████| 25/25 [11:00<00:00, 26.44s/it]

Test Perplexity: 79.5863
Test Top-1 Accuracy: 0.0928
Test Top-5 Accuracy: 0.1679



