## Text2Picture Generation Using DeepSeek's Janus

Created by: [Alex Jenkins](https://alexj.io)

Professors: [Dr. Francesco Fedele](https://ce.gatech.edu/directory/person/francesco-fedele) and [Dr. Mark Leibert](https://lmc.gatech.edu/people/person/mark-leibert)

### MAKE SURE TO SELECT THE T4 GPU (RUNTIME → CHANGE RUNTIME TYPE) BEFORE RUNNING!

#### Version: 02/06/2025

Model Card: [View here](https://huggingface.co/deepseek-ai/Janus-Pro-1B).

Copyright (c) 2025, [Georgia Institute of Technology](https://www.gatech.edu).


In [None]:
#@title 1). Install packages and download the model
#@markdown ⬅️ Press to install and prepare the model.
#@markdown Settings will appear once finished.

# Fetch the Janus source repository from GitHub into the current directory.
! git clone https://github.com/deepseek-ai/Janus
# Switch the notebook's working directory into the fresh clone.
%cd Janus
# Install the cloned package in editable mode so `janus` becomes importable.
! pip install -e .

import os
import torch
import PIL.Image
import numpy as np
import ipywidgets as widgets
from transformers import AutoModel
from transformers import AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images
from IPython.display import display

# Checkpoint identifier handed to every `from_pretrained` call below
model_path = "deepseek-ai/Janus-Pro-1B"

# The chat processor also exposes the tokenizer
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

# Load the model (remote code enabled), cast it to bfloat16, move it to the
# GPU and switch it to evaluation mode
vl_gpt = (
    AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
    .to(torch.bfloat16)
    .cuda()
    .eval()
)

# Interactive controls; their current values are read when generation runs
temperature_slider = widgets.FloatSlider(
    description="Temperature", min=0.1, max=2.0, step=0.1, value=1.0
)
parallel_size_slider = widgets.IntSlider(
    description="Parallel Size", min=1, max=4, step=1, value=2
)
cfg_weight_slider = widgets.FloatSlider(
    description="CFG Weight", min=1.0, max=10.0, step=0.5, value=5.0
)
image_token_num_slider = widgets.IntSlider(
    description="Image Tokens", min=128, max=1024, step=64, value=576
)
img_size_slider = widgets.IntSlider(
    description="Image Size", min=256, max=512, step=64, value=384
)
patch_size_slider = widgets.IntSlider(
    description="Patch Size", min=8, max=32, step=8, value=16
)

# Render all controls, in the same order as they were created
display(
    temperature_slider,
    parallel_size_slider,
    cfg_weight_slider,
    image_token_num_slider,
    img_size_slider,
    patch_size_slider,
)

In [None]:
#@title 2). Run the model (ETA: ~30 seconds per image)
#@markdown ⬅️ Press to run model, you will be asked for the prompt once loaded.

import ipywidgets as widgets
from IPython.display import display, Image
import io

# Ask the user for the text description of the picture
userInputPrompt = input("Input your prompt:\n")

# Two-turn conversation: the user's request followed by an empty assistant
# turn for the model to fill in
conversation = [
    dict(role="<|User|>", content=userInputPrompt),
    dict(role="<|Assistant|>", content=""),
]

# Render the conversation with the processor's SFT template (no system prompt)
sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
    conversations=conversation,
    sft_format=vl_chat_processor.sft_format,
    system_prompt="",
)

# Append the processor's image start tag to the rendered prompt
prompt = f"{sft_format}{vl_chat_processor.image_start_tag}"

@torch.inference_mode()
def generate(
    mmgpt: MultiModalityCausalLM,
    vl_chat_processor: VLChatProcessor,
    prompt: str,
    temperature: float,
    parallel_size: int,
    cfg_weight: float,
    image_token_num_per_image: int,
    img_size: int,
    patch_size: int,
):
    """Sample image tokens for *prompt*, decode them and display the pictures.

    The prompt is tokenized and duplicated into ``2 * parallel_size`` rows:
    even rows keep the full prompt (conditional), odd rows have everything
    except the first and last token replaced by the processor's pad id
    (unconditional). At every step the two sets of logits are blended with
    ``cfg_weight`` and one image token per picture is sampled.

    Args:
        mmgpt: Model providing ``language_model``, ``gen_head``,
            ``prepare_gen_img_embeds`` and ``gen_vision_model``.
        vl_chat_processor: Processor providing ``tokenizer`` and ``pad_id``.
        prompt: Fully formatted prompt text.
        temperature: Divisor applied to the blended logits before softmax.
        parallel_size: Number of pictures generated side by side.
        cfg_weight: Guidance weight applied to ``cond - uncond`` logits.
        image_token_num_per_image: Number of tokens sampled per picture.
        img_size: Image size; only ``img_size // patch_size`` is used, as the
            side length of the token grid handed to the decoder.
        patch_size: See ``img_size``.

    Returns:
        None. The pictures are shown as an ``HBox`` of image widgets.

    Raises:
        ValueError: If ``image_token_num_per_image`` does not equal
            ``(img_size // patch_size) ** 2``.
    """
    # The decoder is asked to arrange the sampled tokens into a square grid,
    # so the token count must match. Check before the slow sampling loop
    # rather than failing after it.
    grid_size = img_size // patch_size
    if image_token_num_per_image != grid_size * grid_size:
        raise ValueError(
            f"image_token_num_per_image ({image_token_num_per_image}) must equal "
            f"(img_size // patch_size) ** 2 = {grid_size * grid_size} "
            f"for img_size={img_size} and patch_size={patch_size}"
        )

    prompt_ids = torch.LongTensor(vl_chat_processor.tokenizer.encode(prompt))

    # One row per (picture, branch): even rows conditional, odd rows unconditional.
    tokens = prompt_ids.to(dtype=torch.int).cuda().repeat(parallel_size * 2, 1)
    tokens[1::2, 1:-1] = vl_chat_processor.pad_id

    inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens)

    generated_tokens = torch.zeros(
        (parallel_size, image_token_num_per_image), dtype=torch.int
    ).cuda()

    past_key_values = None  # no cache on the first step
    for step in range(image_token_num_per_image):
        outputs = mmgpt.language_model.model(
            inputs_embeds=inputs_embeds,
            use_cache=True,
            past_key_values=past_key_values,
        )
        past_key_values = outputs.past_key_values

        # Only the last position predicts the next image token.
        logits = mmgpt.gen_head(outputs.last_hidden_state[:, -1, :])
        logit_cond = logits[0::2, :]
        logit_uncond = logits[1::2, :]

        # Blend the conditional and unconditional branches.
        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
        probs = torch.softmax(logits / temperature, dim=-1)

        next_token = torch.multinomial(probs, num_samples=1)
        generated_tokens[:, step] = next_token.squeeze(dim=-1)

        # Feed the same sampled token to both branches of each picture:
        # (P, 1) -> (P, 2) -> flat [t0, t0, t1, t1, ...].
        paired_tokens = next_token.repeat(1, 2).view(-1)
        inputs_embeds = mmgpt.prepare_gen_img_embeds(paired_tokens).unsqueeze(dim=1)

    # NOTE(review): 8 is the channel count the decoder shape is given here;
    # presumably the code embedding width of the vision model — confirm.
    dec = mmgpt.gen_vision_model.decode_code(
        generated_tokens,
        shape=[parallel_size, 8, grid_size, grid_size],
    )
    # Move channels last for PIL, then map [-1, 1] to [0, 255].
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    dec = np.clip((dec + 1) / 2 * 255, 0, 255)

    # Display images in Jupyter Notebook
    image_widgets = []
    for i in range(parallel_size):
        img_buffer = io.BytesIO()
        PIL.Image.fromarray(dec[i].astype(np.uint8)).save(img_buffer, format="JPEG")
        image_widgets.append(
            widgets.Image(value=img_buffer.getvalue(), format='jpg')
        )

    display(widgets.HBox(image_widgets))

import gc

# Run generation with the current slider values. The cleanup sits in
# `finally` so it also runs when generation fails (for example on a CUDA
# out-of-memory error) instead of being skipped.
try:
    generate(
        vl_gpt,
        vl_chat_processor,
        prompt,
        temperature_slider.value,
        parallel_size_slider.value,
        cfg_weight_slider.value,
        image_token_num_slider.value,
        img_size_slider.value,
        patch_size_slider.value,
    )
finally:
    # Garbage collection to free up resources
    gc.collect()
    torch.cuda.empty_cache()
