### Pipeline

1) Generate images
2) Generate text images with an empty (transparent) background
3) Combine the images with the texts using Replicate
4) Combine the new images into a video using Shotstack


In [None]:
# !pip install -q transformers accelerate torch
# !pip install -q diffusers==0.31.0 safetensors accelerate
# !nvidia-smi || true
# !pip install pillow
# !pip install sacremoses
# !pip install replicate

# # Fonts
# !sudo apt -y update >/dev/null
# !sudo apt -y install fonts-dejavu fonts-freefont-ttf fonts-noto-core >/dev/null
# !wget -q https://github.com/google/fonts/raw/main/ofl/lobster/Lobster-Regular.ttf -O /usr/share/fonts/truetype/lobster.ttf
# !wget -q https://github.com/google/fonts/raw/main/ofl/montserrat/Montserrat-Bold.ttf -O /usr/share/fonts/truetype/montserrat-bold.ttf
# !fc-cache -f

# # Fonts
# !apt-get -y update -qq
# !apt-get -y install -qq fonts-dejavu fonts-dejavu-core fonts-dejavu-extra fonts-freefont-ttf fonts-noto-core
# !fc-cache -f

import asyncio
import functools
import glob
import json
import os
import random
import re
import time
import unicodedata as ud
from pathlib import Path
from typing import List
from uuid import UUID, uuid4

import replicate
import requests
import torch
from diffusers import (
    DPMSolverMultistepScheduler,
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
)
from IPython.display import display
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer

In [None]:
# Compute device and matching precision: fp16 on CUDA, fp32 on CPU.
DEVICE   = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE    = torch.float16 if DEVICE == "cuda" else torch.float32
# Canvas size in pixels used by render_phrase() for the caption images.
W, H     = 512, 512
# Location of the LoRA weights for the Stable Diffusion pipeline, and the weight file name.
LORA_DIR = ""
WT_NAME  = "pytorch_lora_weights.safetensors"
# Hugging Face model id of the base text-to-image checkpoint (loaded with StableDiffusionPipeline).
BASE     = "Manojb/stable-diffusion-2-1-base"

# Russian -> English translation model, used by ru2en() to turn Russian topics into English.
ruen_tok = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ru-en")
ruen_mt  = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ru-en").to(DEVICE).eval()

@torch.inference_mode()
def ru2en(text: str) -> str:
    """Translate *text* from Russian to English with the opus-mt-ru-en model.

    The input is stripped first; at most 64 new tokens are generated and the
    decoded translation is returned stripped.
    """
    batch = ruen_tok(text.strip(), return_tensors="pt").to(DEVICE)
    generated = ruen_mt.generate(**batch, max_new_tokens=64)
    translations = ruen_tok.batch_decode(generated, skip_special_tokens=True)
    first = translations[0]
    return first.strip()

# Any Russian letter (including Ё/ё) marks the text as needing translation.
CYR = re.compile(r"[А-Яа-яЁё]")
def to_english(s: str) -> str:
    """Return *s* in English: translated via ru2en() if it contains Cyrillic, else just stripped."""
    if not CYR.search(s):
        return s.strip()
    return ru2en(s)

# Prompt-writing LLM: make_three_prompts() uses it to turn the English topic into three
# Stable Diffusion prompts. Loaded in DTYPE, moved to DEVICE and switched to eval mode.
q_model_id = "Qwen/Qwen2.5-1.5B-Instruct"
q_tok = AutoTokenizer.from_pretrained(q_model_id)
q_llm = AutoModelForCausalLM.from_pretrained(q_model_id, torch_dtype=DTYPE).to(DEVICE).eval()

# System prompt for q_llm: asks for exactly three short, concrete prompts for a postcard.
# The FORBIDDEN list is only a request to the model; sanitize_line() strips a subset of
# these terms afterwards.
# NOTE(review): the text says "Stable Diffusion XL", but the pipeline below loads an
# SD 2.1 base checkpoint with StableDiffusionPipeline. The wording only guides the LLM.
SYSTEM = (
    "You rewrite a short holiday/topic into three concise Stable Diffusion XL prompts for a festive POSTCARD.\n"
    "- Keep each line concrete and short (<= 48 words).\n"
    "- Focus on specific subjects and simple composition (what/where/background).\n"
    "- Allowed: pastel background, soft bokeh, glitter sparkles, glossy highlights, photomontage postcard style.\n"
    "- Must include: 'postcard, no text'.\n"
    "- FORBIDDEN: masterpiece, 8k, absurdres, by <artist>, nsfw, anatomy, camera brands, friends, family, group, people, girl, boy, man, woman\n"
    "- Avoid winter/snow/Christmas motifs unless the topic explicitly mentions them.\n"
    "- Output exactly three lines, no numbering."
)

# Few-shot examples appended to SYSTEM, written as "Topic -> prompt".
EXAMPLES = (
    "Examples:\n"
    "Valentine's Day -> picture of two kittens sitting inside a big pink heart, sparkles and roses, pastel background, no text\n"
    "International Women's Day -> pink roses and satin ribbon arranged neatly, pastel gradient background, soft bokeh, postcard, no text\n"
    "Easter -> yellow chick emerging from decorated egg shell, surrounded by spring flowers, pastel pink and green background, glitter sparkles, glossy highlights, photomontage postcard style, no humans, no text\n"
    "Christmas -> white dove flying over red roses, glowing sunset, sparkles, high saturation, no text\n"
)

# Spellings of "8 March" (International Women's Day) that the user or the MT model may
# produce, e.g. "8 March", "March 8", "8th of March", "the 8th March", "March 8th.".
_WOMENS_DAY_RX = re.compile(
    r"^(?:(?:the\s+)?8(?:th)?\s+(?:of\s+)?march|march\s+(?:the\s+)?8(?:th)?)\.?$"
)


def normalize_topic(en_text: str) -> str:
    """Map known holiday aliases to a canonical English name.

    Currently only the "8 March" family is mapped, to "International Women's Day". It is
    matched case-insensitively, with optional "the"/"th"/"of" and a trailing period.

    Args:
        en_text: Topic text, already in English.

    Returns:
        The canonical name, or *en_text* stripped of surrounding whitespace.
    """
    stripped = en_text.strip()
    if _WOMENS_DAY_RX.match(stripped.lower()):
        return "International Women's Day"
    return stripped

def build_chat_inputs(topic_en: str):
    """Tokenize the chat prompt asking q_llm for three prompts about *topic_en*.

    The system turn is SYSTEM plus EXAMPLES. The user turn names the topic. The
    chat template adds the assistant generation prefix, and the encoding is moved to DEVICE.
    """
    system_turn = {"role": "system", "content": SYSTEM + "\n" + EXAMPLES}
    user_turn = {"role": "user", "content": f"Topic: {topic_en}\nReturn 3 lines, one per prompt."}
    chat_text = q_tok.apply_chat_template(
        [system_turn, user_turn], tokenize=False, add_generation_prompt=True
    )
    return q_tok(chat_text, return_tensors="pt", padding=True).to(DEVICE)

# Terms removed from every generated prompt (mirrors part of SYSTEM's FORBIDDEN list).
BLACK = [
    r"\bmasterpiece\b", r"\b8k\b", r"\babsurdres\b",
    r"\b(anatomy|nsfw)\b", r"\bby\s+[A-Z][a-z]+\b", r"\bdslr\b"
]

# The "by <Artist>" rule relies on the capital letter of the name. Matching it with
# IGNORECASE (like the other rules) would also delete ordinary phrases such as
# "surrounded by spring flowers", so it is matched case-sensitively.
_CASE_SENSITIVE_BLACK = frozenset({r"\bby\s+[A-Z][a-z]+\b"})

# Tags every prompt must contain, and how many words they occupy.
_REQUIRED_TAGS = ("postcard", "no text")
_TAG_WORDS = sum(len(tag.split()) for tag in _REQUIRED_TAGS)


def sanitize_line(line: str, max_words=48) -> str:
    """Clean one line of LLM output into a Stable Diffusion prompt.

    Steps:
      * Strip list numbering ("1.", "2)") or a leading "-"/"*"/"•" bullet.
      * Drop an echoed "<topic> ->" prefix, since the few-shot examples use that format.
      * Remove BLACK terms and tidy commas and whitespace.
      * Cap the length and make sure "postcard" and "no text" are present.

    Args:
        line: One raw line of model output.
        max_words: Maximum number of whitespace-separated words in the result,
            including the tags. The body is truncated first, so the tags are never cut off.

    Returns:
        The cleaned prompt (never empty: at minimum "postcard, no text").
    """
    t = line.strip()
    t = re.sub(r"^\s*(?:\d+\s*[\.\)\-:]|[-*•])\s*", "", t)
    if "->" in t:
        t = t.split("->", 1)[1]
    for patt in BLACK:
        flags = 0 if patt in _CASE_SENSITIVE_BLACK else re.IGNORECASE
        t = re.sub(patt, "", t, flags=flags)
    t = re.sub(r"\s*,\s*", ", ", t)
    t = re.sub(r"(,\s*){2,}", ", ", t)
    # Removed terms can leave runs of spaces behind.
    t = re.sub(r"\s{2,}", " ", t).strip(" ,")

    words = t.split()
    if len(words) > max_words:
        # Reserve room for the mandatory tags so truncation cannot remove them.
        t = " ".join(words[:max(max_words - _TAG_WORDS, 0)]).rstrip(" ,")

    low = t.lower()
    missing = [tag for tag in _REQUIRED_TAGS if tag not in low]
    # Join only non-empty parts so an emptied line doesn't start with ", ".
    return ", ".join(part for part in [t, *missing] if part)

@torch.inference_mode()
def make_three_prompts(user_text: str, temperature=0.2) -> dict:
    """Turn a (Russian or English) topic into three Stable Diffusion prompts.

    Args:
        user_text: Holiday/topic as typed by the user.
        temperature: Sampling temperature for q_llm.

    Returns:
        A dict with three keys:
          * "topic_en": the normalized English topic.
          * "prompts": exactly three cleaned prompts. Missing lines are filled with a
            generic festive postcard prompt.
          * "negative": the shared negative prompt.
    """
    topic_en = normalize_topic(to_english(user_text))
    inputs = build_chat_inputs(topic_en)
    prompt_len = inputs["input_ids"].shape[-1]
    generated = q_llm.generate(
        **inputs,
        max_new_tokens=200,
        do_sample=True,
        temperature=temperature,
        top_p=0.8,
        repetition_penalty=1.05,
        eos_token_id=q_tok.eos_token_id,
        pad_token_id=q_tok.eos_token_id,
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    completion = q_tok.decode(generated[0, prompt_len:], skip_special_tokens=True)

    candidates = [ln.strip() for ln in completion.splitlines() if ln.strip()][:3]
    candidates += [""] * (3 - len(candidates))

    fallback = "festive postcard scene with fitting symbols, pastel background, soft bokeh, glitter sparkles, postcard, no text"
    prompts = [sanitize_line(c) if c else fallback for c in candidates]
    negative = "text, letters, watermark, logo, low quality, blurry, jpeg artifacts, duplicates, extra limbs, extra heads, people, bad anatomy"
    return {"topic_en": topic_en, "prompts": prompts, "negative": negative}



# Stable Diffusion text-to-image pipeline with a multistep DPM-Solver scheduler.
# The dtype follows DEVICE (fp16 on CUDA, fp32 on CPU). Half precision is poorly
# supported by PyTorch CPU kernels, so a hard-coded float16 breaks CPU-only runs.
pipe = StableDiffusionPipeline.from_pretrained(
    BASE, torch_dtype=DTYPE
).to(DEVICE)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

# Memory savers: sliced attention and sliced/tiled VAE decoding.
pipe.enable_attention_slicing()
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()

# LoRA: an empty LORA_DIR means "look in the current directory". Passing "" to
# load_lora_weights fails, so the load is only attempted when a location is configured
# or the weights file exists locally. Otherwise the base model is used.
_lora_src = LORA_DIR or "."
if LORA_DIR or os.path.isfile(os.path.join(_lora_src, WT_NAME)):
    pipe.load_lora_weights(_lora_src, weight_name=WT_NAME)
else:
    print(f"LoRA weights {WT_NAME!r} not found in {os.path.abspath(_lora_src)}; using the base model only.")


def gen(prompt: str, outfile: str, negative: str,
        steps=24, cfg=6.2, w=512, h=512, seed=1234):
    """Render one image with ``pipe``, save it to *outfile* and return it.

    Args:
        prompt: Positive prompt.
        outfile: Destination file path (format inferred by PIL from the extension).
        negative: Negative prompt.
        steps: Number of denoising steps.
        cfg: Classifier-free guidance scale.
        w, h: Output size in pixels.
        seed: Seed for a dedicated torch.Generator on DEVICE (reproducible output).

    Returns:
        The generated PIL image.
    """
    generator = torch.Generator(device=DEVICE).manual_seed(seed)
    result = pipe(
        prompt=prompt,
        negative_prompt=negative,
        num_inference_steps=steps,
        guidance_scale=cfg,
        width=w,
        height=h,
        generator=generator,
    )
    picture = result.images[0]
    picture.save(outfile)
    return picture

In [None]:
# Images: build three prompts from the user's topic and render one image per prompt.
INPUT_PATH = "../resources/"
user_text = "Новый год" # input()

# Generated images (and, later, caption images) are saved here; create it if missing.
os.makedirs(INPUT_PATH, exist_ok=True)

pack = make_three_prompts(user_text)
print("EN topic:", pack["topic_en"])
for i, p in enumerate(pack["prompts"], 1):
    print(f"Prompt {i}: {p}")
print("NEG:", pack["negative"])

imgs = []
image_filenames = []
for i, p in enumerate(pack["prompts"], 1):
    file_name = os.path.join(INPUT_PATH, uuid4().hex + ".png")
    try:
        img = gen(p, file_name, negative=pack["negative"], seed=1000+i)
    except Exception as e:
        # Best effort: report which prompt failed and continue with the others.
        print(f"Image {i} failed: {e!r}")
        continue
    image_filenames.append(file_name)
    imgs.append(img)
    display(img)

In [None]:
# Texts

# LLM used by load_llm() for caption phrases; overridable via the SLOGAN_LLM env var.
MODEL_NAME = os.environ.get("SLOGAN_LLM", "Qwen/Qwen2.5-1.5B-Instruct")
# Default number of caption phrases generate_and_display() produces.
K_OUT      = 5
# Seed for Python's `random` (font/colour picks) and torch's global RNG.
SEED       = 23
BG_DEFAULT = "transparent"      # "transparent" | "white"
# RGB colours the captions are randomly filled with.
PALETTE    = [(231,76,60),(46,204,113),(52,152,219),(155,89,182),(241,196,15),(20,20,20)]

# NOTE(review): SYS is not referenced by the code visible here (gen_phrases builds its
# own system prompt). Possibly dead; confirm before removing.
SYS = (
    "You output ultra-short keywords. "
    "Language: {lang}. One variant per line, no numbering. "
    "1 to {maxw} words. No punctuation, quotes, emoji, hashtags."
)

# NOTE(review): SAFE_FAMILIES is not referenced by the visible code; collect_fonts()
# hard-codes the same two families.
SAFE_FAMILIES = ("DejaVuSans", "NotoSans")

random.seed(SEED); torch.manual_seed(SEED)

def collect_fonts() -> list[Path]:
    """Find installed DejaVuSans*/NotoSans* TrueType fonts.

    Searches /usr/share/fonts and /usr/local/share/fonts recursively. The regular,
    bold and italic DejaVu/Noto files come first (in that order); everything else
    follows alphabetically by file name.
    """
    search_roots = ("/usr/share/fonts", "/usr/local/share/fonts")
    name_patterns = ("DejaVuSans*.ttf", "NotoSans*.ttf")
    found = [
        Path(hit)
        for root in search_roots
        for pattern in name_patterns
        for hit in glob.glob(os.path.join(root, "**", pattern), recursive=True)
    ]
    priority = {
        "DejaVuSans.ttf": 0, "DejaVuSans-Bold.ttf": 1, "DejaVuSans-Oblique.ttf": 2,
        "NotoSans-Regular.ttf": 3, "NotoSans-Bold.ttf": 4, "NotoSans-Italic.ttf": 5,
    }
    existing = [font for font in found if font.exists()]
    return sorted(existing, key=lambda font: (priority.get(font.name, 99), font.name))


def font_supports_text(font_path: Path, text: str) -> bool:
    """Heuristically check whether *font_path* can draw every character of *text*.

    Spaces and hyphens are skipped. Any other character counts as supported when its
    bounding box at size 28 has positive width and height. Returns False if the font
    cannot be loaded or a glyph cannot be measured.

    NOTE(review): fonts that draw a visible ".notdef" box for missing glyphs give
    that box a non-empty bbox, so this can return True for unsupported characters.
    """
    try:
        font = ImageFont.truetype(str(font_path), size=28)
    except Exception:
        return False
    for ch in (c for c in text if c not in " -"):
        try:
            box = font.getbbox(ch)
        except Exception:
            # Older Pillow API fallback (getsize was removed in Pillow 10).
            try:
                width, height = font.getsize(ch)
            except Exception:
                return False
            box = (0, 0, width, height)
        if not box:
            return False
        if box[2] - box[0] <= 0 or box[3] - box[1] <= 0:
            return False
    return True

@functools.lru_cache(maxsize=1)
def load_llm():
    """Load the caption LLM named by MODEL_NAME and return ``(tokenizer, model)``.

    The result is cached. gen_phrases() calls this on every invocation, and
    generate_and_display() can call gen_phrases() twice, so without the cache the
    model would be re-instantiated from disk each time. Call ``load_llm.cache_clear()``
    after changing MODEL_NAME to force a reload.

    NOTE(review): trust_remote_code=True runs code from the model repository, and
    MODEL_NAME can be overridden through the SLOGAN_LLM env var. Only point it at
    trusted repositories.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, torch_dtype=DTYPE, device_map="auto", trust_remote_code=True
    )
    return tok, model

def detect_lang(text: str) -> str:
    """Return "ru" if *text* contains any character from the Cyrillic block (U+0400-U+04FF), else "en"."""
    has_cyrillic = any("\u0400" <= ch <= "\u04FF" for ch in text)
    return "ru" if has_cyrillic else "en"

# Punctuation and symbols removed before script filtering. This is a single raw string;
# curly quotes are listed explicitly.
_PUNCT_RX = re.compile(r"[\"“”„'«».,:;!?/\\()\[\]{}*_+=~^`|]")


def sanitize(text: str, lang: str) -> str:
    """Reduce *text* to a short phrase in the target script.

    Steps:
      * NFKC-normalize and remove punctuation.
      * Keep only Cyrillic (``lang == "ru"``) or Latin (otherwise) letters, plus
        digits, hyphens and spaces.
      * Remove 4-digit runs such as years ("2025").
      * Collapse whitespace.

    Args:
        text: Arbitrary input text.
        lang: "ru" for Cyrillic output; anything else keeps Latin letters.

    Returns:
        The cleaned phrase (possibly empty).
    """
    t = ud.normalize("NFKC", text)
    t = t.replace("\u00A0", " ").replace("\u2009", " ").replace("\u202F", " ")
    t = _PUNCT_RX.sub("", t)
    # Turn newlines/tabs into spaces before the script filter would delete them.
    t = re.sub(r"\s+", " ", t)
    allow = r"[^\u0400-\u04FF0-9\- ]" if lang == "ru" else r"[^A-Za-z0-9\- ]"
    t = re.sub(allow, "", t)
    t = re.sub(r"\d{4}", "", t)
    # Collapse again: removing a year can leave a double space inside the phrase.
    return re.sub(r"\s+", " ", t).strip()

def gen_phrases(topic: str, lang: str, k: int) -> list[str]:
    """Ask the caption LLM for up to *k* distinct ultra-short greetings about *topic*.

    Each model call samples one candidate, which is cleaned to the target script
    (Cyrillic for ``lang == "ru"``, Latin otherwise). A candidate is kept only if it:
      * is 1-40 characters of letters, digits, hyphens and spaces;
      * has 1-4 words;
      * is not equal to the topic or to a chat-role word;
      * is not a case-insensitive duplicate.
    At most ``k * 10`` calls are made, so fewer than *k* phrases may be returned.

    Args:
        topic: Topic text shown to the model.
        lang: "ru" or "en". Any other value is treated like "en".
        k: Number of phrases wanted.

    Returns:
        The accepted phrases in generation order.
    """
    tok, model = load_llm()
    RX = re.compile(r"^[\u0400-\u04FF0-9\- ]{1,40}$") if lang=="ru" else re.compile(r"^[A-Za-z0-9\- ]{1,40}$")

    examples = {
        "ru": "Примеры:\nС Новым годом\nС Рождеством\nЗимняя сказка\nТёплых праздников",
        "en": "Examples:\nMerry Christmas\nHappy New Year\nHoliday cheer\nWarm wishes",
    }
    # Non-"ru" languages are filtered as Latin text, so they get the English examples.
    lang_examples = examples.get(lang, examples["en"])
    script_rule = "Use only Cyrillic letters (А-Я, а-я)." if lang=="ru" else "Use only Latin letters (A–Z, a–z)."
    # Explicit pad id for generate() (EOS as fallback) instead of letting it guess.
    pad_id = getattr(tok, "pad_token_id", None)
    if pad_id is None:
        pad_id = getattr(tok, "eos_token_id", None)

    def ask_one() -> str:
        """Sample one candidate phrase and clean it to the target script; may return ''."""
        sys_prompt = (
            f"Generate ONE ultra-short greeting/tagline. Language: {lang}. "
            "1–4 words. No quotes, punctuation, emoji, hashtags. "
            + script_rule + "\n" + lang_examples
        )
        user = f"Topic: {topic}\nReturn only the phrase on the first line."
        if hasattr(tok, "apply_chat_template"):
            prompt = tok.apply_chat_template(
                [{"role":"system","content":sys_prompt},
                 {"role":"user","content":user}],
                tokenize=False, add_generation_prompt=True
            )
        else:
            prompt = sys_prompt + "\n" + user + "\nAssistant:\n"

        enc = tok(prompt, return_tensors="pt")
        enc = {name: tensor.to(model.device) for name, tensor in enc.items()}
        out = model.generate(**enc, max_new_tokens=48, temperature=0.6, top_p=0.85,
                             do_sample=True, repetition_penalty=1.08,
                             eos_token_id=getattr(tok,"eos_token_id",None),
                             pad_token_id=pad_id)
        new = out[0, enc["input_ids"].shape[1]:]
        raw = tok.decode(new, skip_special_tokens=True)
        # Use the first line. If it is blank, fall back to the text before the first
        # comma/semicolon. `raw` can be empty (immediate EOS), so don't index blindly.
        lines = raw.splitlines()
        s = lines[0].strip() if lines else ""
        if not s:
            s = re.split(r"[,\uFF0C;]+", raw)[0].strip()
        s = ud.normalize("NFKC", s)
        s = re.sub(r"[\"“”„'«».,:;!?/\\()\[\]{}*_+=~^`|]", "", s)
        s = re.sub(r"\s+", " ", s).strip()
        s = re.sub(r"[^\u0400-\u04FF0-9\- ]","",s) if lang=="ru" else re.sub(r"[^A-Za-z0-9\- ]","",s)
        s = re.sub(r"\s+"," ",s).strip()
        return s

    res, seen, tries = [], set(), 0
    topic_norm = re.sub(r"\s+"," ", topic.lower()).strip()
    while len(res) < k and tries < k*10:
        tries += 1
        s = ask_one()
        if not s or not RX.match(s):
            continue
        if not (1 <= len(s.split()) <= 4):
            continue
        key = s.lower()
        # Reject echoes of the topic itself and leaked chat-role words.
        if key == topic_norm or key in {"system","assistant","user","topic"}:
            continue
        if key in seen:
            continue
        seen.add(key); res.append(s)
    return res


def fit_font(font_path: Path, text: str, draw: ImageDraw.ImageDraw, w: int, h: int):
    """Return the largest font (size 14..260) at which *text* fits inside w x h pixels.

    Binary-searches the size, measuring with ``draw.textbbox``. If the font file
    cannot be loaded at some size, PIL's default font is measured instead. When no
    size fits, PIL's default font is returned.
    """
    low, high = 14, 260
    chosen = None
    while low <= high:
        size = (low + high) // 2
        try:
            candidate = ImageFont.truetype(str(font_path), size=size)
        except Exception:
            candidate = ImageFont.load_default()
        left, top, right, bottom = draw.textbbox((0, 0), text, font=candidate)
        if (right - left) <= w and (bottom - top) <= h:
            # Fits: remember it and try larger sizes.
            chosen = candidate
            low = size + 1
        else:
            high = size - 1
    return chosen or ImageFont.load_default()

def render_phrase(text: str, font_path: Path, color=(20,20,20), bg="transparent"):
    """Render *text* centred on a W x H canvas and save it as a PNG in INPUT_PATH.

    Args:
        text: Phrase to draw.
        font_path: Preferred font. It is used when font_supports_text() says it can
            draw *text*. Otherwise the first installed "*Bold.ttf" DejaVu/Noto font is
            used, or the first font found.
        color: Fill colour of the glyphs.
        bg: "transparent" gives an RGBA canvas with a transparent background; any
            other value gives a white RGB canvas.

    Returns:
        ``(image, file_path)``: the PIL image and the path of the saved PNG.
    """
    if font_path and font_supports_text(font_path, text):
        use = font_path
    else:
        fonts = collect_fonts()
        fallback = fonts[0] if fonts else font_path
        use = next((p for p in fonts if p.name.endswith("Bold.ttf")), fallback)

    mode = "RGBA" if bg=="transparent" else "RGB"
    img  = Image.new(mode, (W,H), (0,0,0,0) if mode=="RGBA" else (255,255,255))
    draw = ImageDraw.Draw(img)
    # Margins: 30 px left/right, 20 px top/bottom.
    font = fit_font(use, text, draw, W-60, H-40)
    x0,y0,x1,y1 = draw.textbbox((0,0), text, font=font)
    # A bbox measured at (0, 0) can start at a non-zero offset. Subtract it so the
    # drawn glyphs, not the origin, end up centred.
    x = (W - (x1 - x0)) // 2 - x0
    y = (H - (y1 - y0)) // 2 - y0
    stroke = (0,0,0,200) if mode=="RGBA" else (0,0,0)
    draw.text((x,y), text, font=font, fill=color, stroke_width=2, stroke_fill=stroke)

    os.makedirs(INPUT_PATH, exist_ok=True)
    file_path = os.path.join(INPUT_PATH, uuid4().hex + '.png')
    img.save(file_path)
    return img, file_path


# main function
def generate_and_display(topic: str, lang: str="auto", bg: str=BG_DEFAULT, k: int=K_OUT):
    """Generate up to *k* caption phrases for *topic*, render each to a PNG and show it.

    Args:
        topic: Holiday/topic text (Russian or English).
        lang: "ru", "en", or "auto" to detect it from *topic*.
        bg: Background passed to render_phrase ("transparent" or "white").
        k: Number of phrases wanted.

    Returns:
        Paths of the saved caption images, in display order.

    Raises:
        SystemExit: if no DejaVu/Noto fonts are installed.
    """
    fonts = collect_fonts()
    if not fonts:
        raise SystemExit("No fonts found. Install DejaVu/Noto.")
    if lang == "auto":
        lang = detect_lang(topic)

    phrases = gen_phrases(topic, lang, k)
    if len(phrases) < k:
        # Top up with a second, slightly different request, skipping duplicates.
        seen = {p.lower() for p in phrases}
        for candidate in gen_phrases(topic + " variations", lang, k - len(phrases)):
            key = candidate.lower()
            if key not in seen:
                seen.add(key)
                phrases.append(candidate)
            if len(phrases) == k:
                break
    if not phrases:
        # Last resort: use the cleaned topic itself as the caption.
        phrases = [sanitize(topic, lang)]

    base = next((p for p in fonts if "DejaVuSans" in p.name), fonts[0])
    alt = next((p for p in fonts if "NotoSans" in p.name), fonts[min(1, len(fonts) - 1)])
    font_pool = [base, alt, *random.sample(fonts, k=min(3, len(fonts)))]

    saved_paths = []
    for phrase in phrases[:k]:
        chosen_font = random.choice(font_pool)
        chosen_color = random.choice(PALETTE)
        img, path = render_phrase(phrase, chosen_font, color=chosen_color, bg=bg)
        display(img)
        saved_paths.append(path)
    return saved_paths

In [None]:
# Making empty images with text: one caption per generated image, because the
# combination step pairs captions and images one-to-one.
text_filenames = generate_and_display(user_text, k=len(image_filenames))

In [None]:
# Instruction for the image-editing model. "First image" is the generated picture and
# "second image" is the transparent caption, matching the order of "image_input" in
# add_text_to_image().
COMBINATION_PROMPT = (
    "Add an inscription from the second image to the top or to the bottom of the first image. "
    "Do not add any extra text that is not on my images. "
    "You can move the text freely to the best location. "
    "The result should look like a card my grandmother might send me."
)

#TODO: Change aspect_ratio
async def add_text_to_image(text_image_path: str, 
                      main_image_path: str, 
                      output_dir: str = '../image_outputs',
                      output_size: str = '1K', 
                      aspect_ratio: str = '1:1',
                      save_locally: bool = True) -> tuple[str, str]:
    """Combine one generated picture with one caption image via Replicate (async).

    Steps:
      1. Upload both files with the Replicate SDK.
      2. Run ``bytedance/seedream-4`` with COMBINATION_PROMPT.
      3. Optionally save the first returned image to *output_dir* as a JPEG.

    The blocking SDK calls and file I/O run in the default thread-pool executor, so
    several calls can be awaited concurrently (e.g. with asyncio.gather).

    Requires a REPLICATE_API_TOKEN available to the replicate client. Note that a
    .env file is not loaded by this function; export the variable first.

    Args:
        text_image_path: Caption PNG (transparent background).
        main_image_path: Generated picture.
        output_dir: Where to save the result when *save_locally* is true. Created if missing.
        output_size: Model "size" parameter.
        aspect_ratio: Model "aspect_ratio" parameter.
        save_locally: Whether to download and save the first output image.

    Returns:
        ``(url_of_first_output, local_path)``. *local_path* is "" when not saved.

    Raises:
        RuntimeError: if the model returned no images.
    """
    loop = asyncio.get_running_loop()

    def upload_images():
        # Upload so the model can fetch the images by URL.
        with open(text_image_path, "rb") as file:
            text_image = replicate.files.create(file)
        with open(main_image_path, "rb") as file:
            generated_image = replicate.files.create(file)
        return text_image, generated_image

    text_image, generated_image = await loop.run_in_executor(None, upload_images)

    input_data = {
        "size": output_size,
        "prompt": COMBINATION_PROMPT, 
        "aspect_ratio": aspect_ratio,
        # Order matters: the prompt refers to the "first" (picture) and "second" (caption) image.
        "image_input": [generated_image.urls['get'], text_image.urls['get']]
    }

    def run_replicate():
        return replicate.run("bytedance/seedream-4", input=input_data)

    output = await loop.run_in_executor(None, run_replicate)
    if not output:
        raise RuntimeError(
            f"seedream-4 returned no images for {main_image_path!r} + {text_image_path!r}"
        )

    local_filename = ""
    if save_locally:
        def save_output():
            # Only the first output image is saved.
            os.makedirs(output_dir, exist_ok=True)
            path = os.path.join(output_dir, uuid4().hex + '.jpg')
            with open(path, "wb") as file:
                file.write(output[0].read())
            return path

        local_filename = await loop.run_in_executor(None, save_output)

    return str(output[0]), local_filename

In [None]:
# Will run for up to a couple of minutes
results = await asyncio.gather(*[
    add_text_to_image(text_filenames[i], image_filenames[i], output_dir="../image_outputs") 
    for i in range(len(text_filenames))
])

# Unpack tuples into two separate lists
images_with_text = [url for url, _ in results]
images_with_text_paths = [path for _, path in results]

print(f"Generated {len(images_with_text)} images")
print("URLs:", images_with_text)
print("Local paths:", images_with_text_paths)

In [None]:
# Show the combined-image URLs as this cell's output.
images_with_text

['https://replicate.delivery/xezq/84T4kLMzec3PcCP634DFdaqseBOH3BViJMfzqSfkzRm0bAzVB/tmpt48oksj1.jpg',
 'https://replicate.delivery/xezq/kimVbGRzidKaH1c3Xh6AxDkM1dZjW0QDVhCgiCmZ6VepDYuKA/tmpj7womds5.jpg',
 'https://replicate.delivery/xezq/chDiVJbKeQSEMSFe4IaeeHWyAKclwZXOeQ2LpNo702Cm9AmrC/tmpx0rq582q.jpg',
 'https://replicate.delivery/xezq/e1XHxROiewrFQ0Kj7OznK18qMz4QTOS08eWC2sSH1hWIQg5qA/tmpob44khff.jpg']

In [None]:
# Combining images into a video (Shotstack Edit API): one clip per combined image.
API_KEY = os.environ['SHOT_STACK_API_TOKEN']

# for production use "https://api.shotstack.io/v1/render"
url = "https://api.shotstack.io/stage/render"
headers = {
    "x-api-key": API_KEY,
    "Content-Type": "application/json"
}

CLIP_SECONDS = 3
# Transition per clip position. The first clip has no "in" transition. The list is
# cycled when there are more images than entries.
TRANSITIONS = [
    {"out": "slideRight"},
    {"in": "wipeLeft", "out": "fade"},
    {"in": "slideUp", "out": "slideDown"},
]

if not images_with_text:
    raise RuntimeError("No combined images available to put into a video.")

#TODO: Change aspect_ratio
payload = {
    "timeline": {
        "tracks": [
            {
                # Clips play back-to-back, each CLIP_SECONDS long.
                "clips": [
                    {
                        "asset": {"type": "image", "src": src},
                        "start": i * CLIP_SECONDS,
                        "length": CLIP_SECONDS,
                        "transition": TRANSITIONS[i % len(TRANSITIONS)],
                    }
                    for i, src in enumerate(images_with_text)
                ]
            }
        ]
    },
    "output": {"format": "mp4", "resolution": "hd", "aspectRatio": "1:1"}
}

response = requests.post(url, headers=headers, json=payload, timeout=60)
print(response.status_code, response.text)
response.raise_for_status()


201 {"success":true,"message":"Created","response":{"message":"Render Successfully Queued","id":"8ec53e29-1721-4685-88cc-c0afcf2df688"}}


In [None]:
# Poll the render job until Shotstack reports it "done" or "failed" (bounded by a timeout).
RENDER_ID = response.json()['response']['id']

# for production use f"https://api.shotstack.io/v1/render/{RENDER_ID}"
url = "https://api.shotstack.io/stage/render"
POLL_INTERVAL_S = 5
POLL_TIMEOUT_S = 600

deadline = time.monotonic() + POLL_TIMEOUT_S
while True:
    r = requests.get(url + f"/{RENDER_ID}", headers=headers, timeout=30)
    r.raise_for_status()
    render = r.json()['response']
    # Assumes every other status means the job is still in progress (per Shotstack docs; confirm).
    if render['status'] in ("done", "failed"):
        break
    if time.monotonic() > deadline:
        raise TimeoutError(
            f"Render {RENDER_ID} still {render['status']!r} after {POLL_TIMEOUT_S}s"
        )
    time.sleep(POLL_INTERVAL_S)

print(r.json())
if render['status'] == "failed":
    raise RuntimeError(f"Render {RENDER_ID} failed: {render.get('error')}")
print("Video url:", render['url'])

{'success': True, 'message': 'OK', 'response': {'id': '8ec53e29-1721-4685-88cc-c0afcf2df688', 'owner': '1mj5p7ly7k', 'plan': 'freeTrial', 'status': 'done', 'error': '', 'duration': 12, 'billable': 12, 'renderTime': 3708.53, 'url': 'https://shotstack-api-stage-output.s3-ap-southeast-2.amazonaws.com/1mj5p7ly7k/8ec53e29-1721-4685-88cc-c0afcf2df688.mp4', 'poster': None, 'thumbnail': None, 'data': {'output': {'format': 'mp4', 'resolution': 'hd', 'aspectRatio': '4:3'}, 'timeline': {'tracks': [{'clips': [{'start': 0, 'length': 3, 'asset': {'type': 'image', 'src': 'https://replicate.delivery/xezq/84T4kLMzec3PcCP634DFdaqseBOH3BViJMfzqSfkzRm0bAzVB/tmpt48oksj1.jpg'}, 'transition': {'out': 'slideRight'}}, {'start': 3, 'length': 3, 'asset': {'type': 'image', 'src': 'https://replicate.delivery/xezq/kimVbGRzidKaH1c3Xh6AxDkM1dZjW0QDVhCgiCmZ6VepDYuKA/tmpj7womds5.jpg'}, 'transition': {'in': 'wipeLeft', 'out': 'fade'}}, {'start': 6, 'length': 3, 'asset': {'type': 'image', 'src': 'https://replicate.delive