In [1]:
import re
import csv
import random
from typing import List, Dict

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers.pipelines.text_generation import TextGenerationPipeline

In [2]:
class CFG:
    """Notebook-wide settings: model repository, pipeline task and compute device."""

    # Hugging Face Hub repository the tokenizer and model are loaded from.
    magic_promt_hf_repo = "Gustavosta/MagicPrompt-Stable-Diffusion"
    # Task name handed to transformers.pipeline().
    pipeline_task = "text-generation"
    # Use the first CUDA device when one is available, otherwise the CPU.
    if torch.cuda.is_available():
        device = "cuda:0"
    else:
        device = "cpu"

In [3]:
# Load the tokenizer and the causal LM from the configured repository,
# placing the model on the configured device as part of the same statement.
tokenizer = AutoTokenizer.from_pretrained(CFG.magic_promt_hf_repo)
model = AutoModelForCausalLM.from_pretrained(CFG.magic_promt_hf_repo).to(CFG.device)

# Report which device was selected.
print("Device:", CFG.device)

Downloading:   0%|          | 0.00/255 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/779k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.01M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/912 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/487M [00:00<?, ?B/s]

Device: cuda:0


In [4]:
# Wrap the already-loaded model and tokenizer in a text-generation pipeline.
# NOTE(review): no `device` argument is passed here — confirm the pipeline
# actually runs on CFG.device with the installed transformers version.
generation_pipe = pipeline(
    task=CFG.pipeline_task,
    model=model,
    tokenizer=tokenizer,
)

In [5]:
class PromptGenerator:
    """Generate cleaned text prompts with a text-generation pipeline.

    The pipeline is called repeatedly with an empty seed text until
    ``promts_count`` non-empty prompts have been collected.
    """

    # Matches any space-delimited token that has a dot inside it
    # (e.g. "site.com"). Written as a raw string: in a plain literal '\.'
    # is an invalid escape sequence and triggers a warning on current Python.
    _DOTTED_TOKEN_RE = re.compile(r'[^ ]+\.[^ ]+')

    def __init__(
        # The annotation is quoted so that defining the class does not
        # require the pipeline type to be importable.
        self, generation_pipe: "TextGenerationPipeline",
        promts_count: int = 1_000_000, batch_size: int = 500
    ) -> None:
        """Store the pipeline and the generation settings.

        Args:
            generation_pipe: Callable pipeline that returns a list of dicts
                with a ``'generated_text'`` key.
            promts_count: Number of prompts to return from ``__call__``.
            batch_size: Number of sequences requested per pipeline call
                (passed as ``num_return_sequences``).
        """
        self.generation_pipe: TextGenerationPipeline = generation_pipe
        self.promts_count: int = promts_count
        self.batch_size: int = batch_size

        # Seed text handed to the pipeline; empty means unconditional generation.
        self.initial_text: str = ""

        # Inclusive bounds for the `max_length` value drawn for each pipeline
        # call. Hugging Face generation counts `max_length` in tokens, not in
        # characters.
        self.bot_len_limit: int = 30
        self.top_len_limit: int = 100

    def process_response(self, response_promt: Dict) -> str:
        """Clean one pipeline result and return the prompt text.

        Args:
            response_promt: Dict with a ``'generated_text'`` entry.

        Returns:
            The text with dotted tokens, newlines and angle brackets removed
            and leading/trailing spaces, dots and commas stripped. May be an
            empty string.
        """
        resp = response_promt['generated_text'].strip()
        # Drop tokens that contain a dot between non-space characters.
        resp = self._DOTTED_TOKEN_RE.sub('', resp)
        resp = resp.replace("\n", "").replace("<", "").replace(">", "")
        resp = resp.strip(" .,")

        return resp

    def __call__(self) -> List[Dict]:
        """Generate prompts until ``promts_count`` non-empty ones exist.

        One ``max_length`` is drawn at random per pipeline call and is shared
        by every sequence of that batch.

        Returns:
            A list of ``{"prompt": <text>}`` dicts, truncated to exactly
            ``promts_count`` items.
        """
        prompts_list: List[Dict] = []
        while len(prompts_list) < self.promts_count:
            response = self.generation_pipe(
                self.initial_text,
                max_length=random.randint(self.bot_len_limit, self.top_len_limit),
                num_return_sequences=self.batch_size
            )

            for response_promt in response:
                processed_response_promt = self.process_response(response_promt)
                # Results that are empty after cleaning are discarded.
                if processed_response_promt:
                    prompts_list.append({
                        "prompt": processed_response_promt
                    })

        # The last batch may overshoot the target; trim the surplus.
        return prompts_list[:self.promts_count]

In [6]:
# Build a generator targeting five million prompts; batch_size keeps its default.
prompt_generator = PromptGenerator(
    generation_pipe=generation_pipe,
    promts_count=5_000_000,
)

In [7]:
# Run the generator; the result is a list of {"prompt": <text>} dicts.
prompts = prompt_generator()

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generati

In [8]:
def limited_prompt_print(prompts: List[Dict], limit: int = 10) -> None:
    """Print at most ``limit`` prompts, numbered from 1.

    Args:
        prompts: Iterable of dicts, each with a ``"prompt"`` entry.
        limit: Maximum number of prompts to print. Values of zero or less
            print nothing.
    """
    # Without this guard a non-positive limit is never reached by the
    # counter below and every prompt would be printed.
    if limit <= 0:
        return

    for i, prompt in enumerate(prompts, start=1):
        print(f"№{i} --", prompt["prompt"], "\n")

        if i >= limit:
            break

In [9]:
# Preview the first 10 generated prompts.
limited_prompt_print(prompts, limit=10)

№1 -- y'shtola casting a spell, d & d, fantasy, portrait, highly detailed, headshot, digital painting, trending on artstation, concept art, sharp focus, illustration, art by artgerm and greg rutkowski and magali vill 

№2 -- a fantasy style portrait painting of young mila kunis oil painting unreal 5 daz. rpg portrait, extremely detailed artgerm greg rutkowski greg hildebrandt tim hildebrandt 

№3 -- t - 8 0 0 photographic upper body portrait of a young handsome dark haired halfRobin devil, throne, decorated with black and white chryslerian patterns, as seen on artgerm, octane render, in the style of alphonse mucha 

№4 -- an achingly beautiful print of a Saturn V rocket lifting off from the launchpad by Raphael, Hopper, and Rene Magritte. detailed, romantic, enchanting, trending on artstation 

№5 -- in a dark corridor, a human made out of black hole energy flows of and energy, | gapmoe kuudere moody lighting stunning bokeh highlights sharp contrast | trending pixiv fanbox | by greg ru

In [10]:
def save_prompts_to_csv(prompts: List[Dict], file_name: str = "prompts.csv") -> None:
    """Write prompts to a single-column CSV file with a ``prompt`` header.

    Args:
        prompts: Dicts that each contain a ``"prompt"`` entry.
        file_name: Path of the CSV file to create or overwrite.
    """
    # Honour `file_name` instead of a fixed path. newline="" lets the csv
    # module control line endings, and UTF-8 keeps non-ASCII prompt text
    # independent of the platform's default encoding.
    with open(file_name, "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, ["prompt"])
        w.writeheader()
        w.writerows(prompts)

In [11]:
# Write the generated prompts to disk, leaving file_name at its default.
save_prompts_to_csv(prompts)