In [1]:
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoConfig
import numpy as np
import random
import time
from tqdm import tqdm

# Set random seeds for reproducibility
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# Load model and tokenizer
model_name = "xhan77/ssdlm"  # Replace with your model path if different
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {device}")

tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name, config=config).to(device)
model.eval()

def generate_text_with_mlm(prompt, max_length=100, num_samples=1, top_p=0.92, temperature=0.8):
    """
    Generate text using a masked language model through iterative infilling.
    This is more suitable for MLMs than naive autoregressive generation.
    
    Args:
        prompt (str): The input prompt to condition generation on
        max_length (int): Maximum length of generated text
        num_samples (int): Number of samples to generate
        top_p (float): Nucleus sampling parameter
        temperature (float): Sampling temperature
    
    Returns:
        list: Generated text samples
    """
    generated_texts = []
    
    # Generate multiple samples if requested
    for _ in range(num_samples):
        # Start with the prompt
        tokens = tokenizer.encode(prompt, return_tensors='pt').to(device)
        
        # Add mask tokens to reach the desired length
        num_masks = min(max_length - tokens.size(1), 20)  # Add masks in batches
        mask_tokens = torch.full((1, num_masks), tokenizer.mask_token_id, dtype=torch.long, device=device)
        tokens = torch.cat([tokens, mask_tokens], dim=1)
        
        # Keep track of which positions are masked
        masked_positions = list(range(tokens.size(1) - num_masks, tokens.size(1)))
        
        # Fill in the masks left to right
        with tqdm(total=num_masks) as pbar:
            while masked_positions:
                # Always fill the leftmost mask first
                position = masked_positions[0]
                
                # Forward pass to get predictions for all positions
                with torch.no_grad():
                    outputs = model(tokens)
                    logits = outputs.logits[0, position]
                    
                    # Apply temperature
                    if temperature != 1.0:
                        logits = logits / temperature
                    
                    # Apply top-p (nucleus) sampling
                    if top_p < 1.0:
                        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                        cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
                        
                        # Remove tokens with cumulative probability above the threshold
                        sorted_indices_to_remove = cumulative_probs > top_p
                        # Shift the indices to the right to keep the first token above the threshold
                        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                        sorted_indices_to_remove[..., 0] = 0
                        
                        # Apply the filtering
                        indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
                        logits[indices_to_remove] = float('-inf')
                    
                    # Sample from the filtered distribution
                    probabilities = torch.nn.functional.softmax(logits, dim=-1)
                    next_token_id = torch.multinomial(probabilities, 1)
                    
                    # Replace the mask with the predicted token
                    tokens[0, position] = next_token_id
                
                # Remove the current position from masked_positions
                masked_positions.pop(0)
                pbar.update(1)
                
                # If we've filled all masks and still need more tokens, add more masks
                if not masked_positions and tokens.size(1) < max_length:
                    next_num_masks = min(max_length - tokens.size(1), 20)
                    if next_num_masks > 0:
                        mask_tokens = torch.full((1, next_num_masks), tokenizer.mask_token_id, 
                                                dtype=torch.long, device=device)
                        tokens = torch.cat([tokens, mask_tokens], dim=1)
                        masked_positions = list(range(tokens.size(1) - next_num_masks, tokens.size(1)))
                        pbar.total += next_num_masks
                
                # Early stopping if we see an EOS token
                if next_token_id.item() == tokenizer.eos_token_id:
                    break
        
        # Decode the generated tokens
        generated_text = tokenizer.decode(tokens[0], skip_special_tokens=True)
        generated_texts.append(generated_text)
    
    return generated_texts

# Example usage
if __name__ == "__main__":
    prompts = [
        "Once upon a time, there was a magical kingdom",
        "The future of artificial intelligence is",
        "In a world where technology has advanced beyond our wildest dreams,"
    ]
    
    print("Model and tokenizer loaded successfully!")
    
    # Try with different temperature and top_p settings
    for prompt in prompts:
        print(f"\nPrompt: {prompt}")
        print("Generated text:")
        
        # Try with conservative parameters
        print("\nConservative parameters (temp=0.7, top_p=0.85):")
        start_time = time.time()
        generated_texts = generate_text_with_mlm(prompt, max_length=50, num_samples=1, top_p=0.85, temperature=0.7)
        end_time = time.time()
        print(f"Generation took {end_time - start_time:.2f} seconds")
        for i, text in enumerate(generated_texts, 1):
            print(f"Sample {i}: {text}")
        
        # Try with more creative parameters
        print("\nCreative parameters (temp=1.0, top_p=0.95):")
        start_time = time.time()
        generated_texts = generate_text_with_mlm(prompt, max_length=50, num_samples=1, top_p=0.95, temperature=1.0)
        end_time = time.time()
        print(f"Generation took {end_time - start_time:.2f} seconds")
        for i, text in enumerate(generated_texts, 1):
            print(f"Sample {i}: {text}")
        
        # Generate a longer sample
        print("\nLonger text generation (temp=0.8, top_p=0.9):")
        start_time = time.time()
        generated_texts = generate_text_with_mlm(prompt, max_length=75, num_samples=1, top_p=0.9, temperature=0.8)
        end_time = time.time()
        print(f"Generation took {end_time - start_time:.2f} seconds")
        for i, text in enumerate(generated_texts, 1):
            print(f"Sample {i}: {text}")


Using device: cuda




Model and tokenizer loaded successfully!

Prompt: Once upon a time, there was a magical kingdom
Generated text:

Conservative parameters (temp=0.7, top_p=0.85):


100%|██████████| 38/38 [00:01<00:00, 29.60it/s]


Generation took 1.37 seconds
Sample 1: Once upon a time, there was a magical kingdom...... ........ jQuery. Slash Sauce Sauce Sauce.... jQuery Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce

Creative parameters (temp=1.0, top_p=0.95):


100%|██████████| 38/38 [00:01<00:00, 33.62it/s]


Generation took 1.13 seconds
Sample 1: Once upon a time, there was a magical kingdom....... . ...... ... ..
.....". .. jQueryVERTISEMENT�ÛÛ municip Canaver Canaver Canaver Canaver

Longer text generation (temp=0.8, top_p=0.9):


100%|██████████| 63/63 [00:01<00:00, 36.33it/s]


Generation took 1.74 seconds
Sample 1: Once upon a time, there was a magical kingdom...... . . ...... ESPN. Safari. SVG SVG... Urug UCH UCH UCH UCH UCH UCH Scrib Scrib Scrib Scrib Scrib Scrib Scrib Scrib Scrib Scrib Scrib Scrib Scrib Scrib Scrib Scrib Scrib Scrib Scrib Scrib Scrib Scrib Scrib Scrib Scrib Scrib Scrib Scrib Scrib Scrib Scrib Scrib Scrib

Prompt: The future of artificial intelligence is
Generated text:

Conservative parameters (temp=0.7, top_p=0.85):


100%|██████████| 42/42 [00:01<00:00, 38.68it/s]


Generation took 1.09 seconds
Sample 1: The future of artificial intelligence is........... agriculture agriculture pdf. ebook. ebook..... genome...═.══════.═.════

Creative parameters (temp=1.0, top_p=0.95):


100%|██████████| 42/42 [00:01<00:00, 35.89it/s]


Generation took 1.17 seconds
Sample 1: The future of artificial intelligence is............ Gawker Gawker Gawker Gawker Gawker Gawker malaria Gawker Jones. Gonzalez. Gonzalez...MpServer.. ``(ADVERTISEMENT Gawker Gawker Gawker Gawker Gawker Gonzalez Gauntlet Gawker

Longer text generation (temp=0.8, top_p=0.9):


100%|██████████| 67/67 [00:01<00:00, 37.18it/s]


Generation took 1.80 seconds
Sample 1: The future of artificial intelligence is........... Gawker Gawker Gawker Gawker Gawker Gawker Gawker Gawker Gawker. . .

... ._. ._. .......... .......... .......... .......... .......... Sauce Sauce Sauce vegetable downloadable sauces vegetable sauces sauces Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce

Prompt: In a world where technology has advanced beyond our wildest dreams,
Generated text:

Conservative parameters (temp=0.7, top_p=0.85):


100%|██████████| 35/35 [00:00<00:00, 40.43it/s]


Generation took 0.87 seconds
Sample 1: In a world where technology has advanced beyond our wildest dreams,............... jQuery SVG SVG SVG SVG Urug Urug CSV Scrib Scrib Scrib Scrib Scrib Scrib Scrib Scrib Scrib Scrib Scrib Scrib

Creative parameters (temp=1.0, top_p=0.95):


100%|██████████| 35/35 [00:00<00:00, 40.37it/s]


Generation took 0.87 seconds
Sample 1: In a world where technology has advanced beyond our wildest dreams,................ url url aspect url adhesive certain adhesive conservative pdf html html HTML HTML HTML HTML HTML HTML HTML HTML

Longer text generation (temp=0.8, top_p=0.9):


100%|██████████| 60/60 [00:01<00:00, 36.59it/s]

Generation took 1.64 seconds
Sample 1: In a world where technology has advanced beyond our wildest dreams,............... gif tab tab foreskin Scrib url saliva Arduino Arduino Arduino Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce Sauce



