In [1]:
# add janus to os path
import logging
import os
import sys
from typing import List

# Make a local Janus checkout importable. The membership check keeps re-runs of
# this cell from appending the same entry to sys.path again.
# NOTE(review): `from janus.models import ...` resolves the `janus` package from a
# directory that *contains* it; confirm "../Janus/janus" is the intended entry
# (it only matters when Janus is not pip-installed).
janus_path = os.path.abspath("../Janus/janus")
if janus_path not in sys.path:
    sys.path.append(janus_path)

# Make CUDA kernel launches synchronous so device-side errors surface at the
# failing op; this slows execution, so consider removing it for long runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


In [2]:
from abc import ABC

import numpy as np
import torch
from imagebind.models.imagebind_model import ImageBindModel
from imagebind import data
from torch import nn
from PIL import Image

from transformers import AutoModelForCausalLM

from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.models.processing_vlm import VLChatProcessorOutput




Python version is above 3.10, patching the collections module.




# Audio Projection

In [3]:
class AudioProjection(nn.Module):
    """Expand one audio embedding per row into a short sequence of embeddings.

    A two-layer MLP (Linear -> GELU -> Linear) maps each input row to
    ``sequal_len * output_dim`` values, which are reshaped into ``sequal_len``
    vectors of width ``output_dim``.

    Args:
        input_dim: width of each input row.
        output_dim: width of every produced embedding.
        sequal_len: number of embeddings produced per input row.
        scale_factor: hidden-layer width is ``scale_factor * output_dim``.
    """

    def __init__(self, input_dim, output_dim, sequal_len=32, scale_factor=2):
        super().__init__()
        hidden_dim = scale_factor * output_dim
        self.scale_factor = scale_factor
        self.sequal_len = sequal_len
        self.output_dim = output_dim
        # Registration order (fc1, act, fc2) is kept so parameter order is unchanged.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, sequal_len * output_dim)

    def forward(self, x):
        """Map ``x`` of shape ``[B, input_dim]`` to ``[B, sequal_len, output_dim]``."""
        hidden = self.act(self.fc1(x))   # [B, scale_factor * output_dim]
        flat = self.fc2(hidden)          # [B, sequal_len * output_dim]
        return flat.reshape(x.shape[0], self.sequal_len, self.output_dim)

In [4]:
# Smoke test: two random 1024-d rows should become 32 embeddings of width 2048 each.
x = torch.randn(2, 1024)
with torch.no_grad():
    proj = AudioProjection(input_dim=1024, output_dim=2048, scale_factor=2)
    res = proj(x)
print(res.shape)

torch.Size([2, 32, 2048])


# Janus Image Generator

In [None]:
AVAILABLE_MODELS = ["Janus-1.3B", "Janus-Pro-1B", "JanusFlow-1.3B"]


class JanusImageGenerator:
    """Class-level wrapper around a DeepSeek Janus model for image generation.

    All state lives on the class: call ``init_model`` once, then ``generate``
    (text prompt) or ``generate_from_embeds`` (``sys_prompt`` plus audio embeddings).
    """

    name = "Janus Image Generator"

    @classmethod
    def init_model(cls, config):
        """Load the processor and model named in ``config`` onto the GPU in bfloat16.

        Args:
            config: dict-like with optional keys ``"model_name"`` (default
                ``"Janus-Pro-1B"``; unknown names fall back to ``AVAILABLE_MODELS[0]``)
                and ``"sys_prompt"`` (the prompt used by ``generate_from_embeds``).
        """
        model_name = config.get("model_name", "Janus-Pro-1B")
        cls.sys_prompt = config.get("sys_prompt", "Abstract art for representing emotions")
        if model_name not in AVAILABLE_MODELS:
            logging.warning(f"Model {model_name} not available. Using {AVAILABLE_MODELS[0]} instead.")
            model_name = AVAILABLE_MODELS[0]

        model_path = f"deepseek-ai/{model_name}"
        cls.vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
        cls.tokenizer = cls.vl_chat_processor.tokenizer

        cls.vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
            model_path, trust_remote_code=True
        )
        cls.vl_gpt = cls.vl_gpt.to(torch.bfloat16).cuda().eval()

        cls.model = cls.vl_gpt
        # Generation constants (hard-coded; TODO revisit): images per call, output
        # resolution, and patch size (img_size // patch_size is the token-grid side).
        cls.parallel_size = 1
        cls.img_size = 384
        cls.patch_size = 16
        # Required audio-embedding sizes as {dim_index: size}. They are concatenated
        # with the text embeddings on dim 1, so dim 0 must equal the CFG batch
        # (2 * parallel_size) and dim 2 the embedding width.
        cls.audio_embeds_shape = {0: 2 * cls.parallel_size, 2: 2048}
        cls.audio_embeds_type = torch.bfloat16

    @classmethod
    def _preprocess_input(cls, inputs):
        """Format ``inputs`` as a one-turn User/Assistant conversation.

        Uses the processor's SFT template with an empty system prompt and appends
        the processor's ``image_start_tag``.
        """
        conversation = [
            {
                "role": "User",
                "content": inputs,
            },
            {"role": "Assistant", "content": ""},
        ]

        sft_format = cls.vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
            conversations=conversation,
            sft_format=cls.vl_chat_processor.sft_format,
            system_prompt="",
        )

        return sft_format + cls.vl_chat_processor.image_start_tag

    @classmethod
    def _postprocess_output(cls, outputs):
        """Convert a decoded ``[H, W, C]`` float tensor to a PIL image.

        Values are rescaled from [-1, 1] to [0, 255] and clipped.
        """
        pixels = np.clip((outputs.numpy() + 1) / 2 * 255, 0, 255)
        return Image.fromarray(pixels.astype(np.uint8))

    @classmethod
    @torch.no_grad()  # sampling only: don't build an autograd graph across 576 steps
    def invoke_model(cls, prompt: str,
                     temperature: float = 1,
                     cfg_weight: float = 5,
                     image_token_num_per_image: int = 576,
                     audio_embeds: torch.Tensor = None,
                     **kwargs):
        """Sample image tokens with classifier-free guidance and decode them.

        Args:
            prompt: fully formatted prompt (see ``_preprocess_input``).
            temperature: softmax temperature used for sampling.
            cfg_weight: guidance scale; 1 reduces to the conditional logits.
            image_token_num_per_image: number of tokens to sample; should equal
                ``(img_size // patch_size) ** 2`` for decoding.
            audio_embeds: optional embeddings prepended along dim 1 to the prompt
                embeddings of both the conditional and the unconditional rows.
            **kwargs: ignored; accepted for call-site compatibility.

        Returns:
            float32 CPU tensor ``[H, W, C]`` for the first generated image.
        """
        lm = cls.vl_gpt.language_model
        lm.config._attn_implementation = 'eager'

        input_ids = torch.LongTensor(cls.vl_chat_processor.tokenizer.encode(prompt))

        # Even rows: conditional (full prompt). Odd rows: unconditional CFG branch,
        # where every token except the first and last is replaced by padding.
        batch = cls.parallel_size * 2
        tokens = torch.zeros((batch, len(input_ids)), dtype=torch.int).cuda()
        for i in range(batch):
            tokens[i, :] = input_ids
            if i % 2 != 0:
                tokens[i, 1:-1] = cls.vl_chat_processor.pad_id

        inputs_embeds = lm.get_input_embeddings()(tokens)
        if audio_embeds is not None:
            # torch.cat needs matching dtype/device.
            audio_embeds = audio_embeds.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)
            inputs_embeds = torch.cat([audio_embeds, inputs_embeds], dim=1)

        generated_tokens = torch.zeros((cls.parallel_size, image_token_num_per_image), dtype=torch.int).cuda()

        past_key_values = None
        for i in range(image_token_num_per_image):
            outputs = lm.model(inputs_embeds=inputs_embeds, use_cache=True,
                               past_key_values=past_key_values)
            past_key_values = outputs.past_key_values
            hidden_states = outputs.last_hidden_state

            logits = cls.vl_gpt.gen_head(hidden_states[:, -1, :])
            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]

            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
            probs = torch.softmax(logits / temperature, dim=-1)

            next_token = torch.multinomial(probs, num_samples=1)
            generated_tokens[:, i] = next_token.squeeze(dim=-1)

            # Feed the sampled token to both the conditional and unconditional rows.
            next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
            img_embeds = cls.vl_gpt.prepare_gen_img_embeds(next_token)
            inputs_embeds = img_embeds.unsqueeze(dim=1)

        grid = cls.img_size // cls.patch_size
        dec: torch.Tensor = cls.vl_gpt.gen_vision_model.decode_code(
            generated_tokens.to(dtype=torch.int),
            shape=[cls.parallel_size, 8, grid, grid],
        )
        dec = dec.to(torch.float32).cpu().permute(0, 2, 3, 1)[0]

        return dec

    @classmethod
    def generate(cls, inputs: str, **kwargs) -> "Image.Image":
        """Generate an image from raw user text.

        Args:
            inputs: user text, formatted by ``_preprocess_input``.
            **kwargs: forwarded to ``invoke_model`` (e.g. ``temperature``, ``cfg_weight``).
        """
        prompt = cls._preprocess_input(inputs)
        # `prompt` is invoke_model's first positional parameter.
        outputs = cls.invoke_model(prompt, **kwargs)
        return cls._postprocess_output(outputs)

    @classmethod
    def generate_from_embeds(cls, inputs: np.ndarray, **kwargs) -> "Image.Image":
        """Generate an image from ``cls.sys_prompt`` conditioned on audio embeddings.

        Args:
            inputs: array whose sizes must match ``cls.audio_embeds_shape``.
            **kwargs: forwarded to ``invoke_model``.

        Raises:
            ValueError: if ``inputs`` has too few dims or a mismatched size.
        """
        prompt = cls._preprocess_input(cls.sys_prompt)
        input_tensor = torch.from_numpy(inputs).to(cls.audio_embeds_type)

        # Validate before touching the GPU; a mismatch would otherwise fail later
        # inside torch.cat with an opaque error.
        expected = cls.audio_embeds_shape
        if input_tensor.ndim <= max(expected, default=-1) or any(
            input_tensor.shape[d] != s for d, s in expected.items()
        ):
            raise ValueError(
                f"Input tensor had shape {inputs.shape}, expected sizes {expected} (dim: size)"
            )

        output = cls.invoke_model(prompt, audio_embeds=input_tensor.cuda(), **kwargs)
        return cls._postprocess_output(output)


# Training config

In [3]:
class TrainConfig:
    """Static training settings (logging, optimisation, projector, data).

    TODO: values were copied from hw-multimodal-llm-solved and still need to be
    adapted to this project.
    """

    log_level = "DEBUG"

    # --- Optimisation ---
    num_epochs = 5
    train_batch_size = 8
    val_batch_size = 1
    log_grad_norm = True
    learning_rate = 1.0e-4
    gradient_accumulation_steps = 1

    # --- Periodic evaluation / checkpointing ---
    evaluate_every_epoch_mod = 4
    save_model_every_epoch_mod = 1
    device = "cuda:0"

    # --- Model / projector ---
    projector_input_dim = 1024

    # --- Data ---
    few_train_samples = None  # presumably None means "use every sample" — confirm in the loader
    few_val_samples = 100
    dataloader_num_workers = 0

    train_dataset_path = ""
    audio_embeds_train_prefix = ""

    val_dataset_path = ""
    audio_embeds_val_prefix = ""


In [4]:
import yaml
import argparse

import pathlib
import random
import torch
import torch.nn as nn

import logging
# import evaluate

# import datasets
from transformers.generation import GenerationConfig

from torch.optim import Adam
from torch.utils.data import DataLoader

import transformers

from tqdm.auto import tqdm

# import wandb
# from wandb import sdk as wandb_sdk

# import accelerate


# Training Pipelines research

## Image decoding

In [5]:
# Processor/tokenizer for Janus-Pro-1B and the token ids of the reference prompt.
model_path = "deepseek-ai/Janus-Pro-1B"
prompt = "Abstract art for representing emotions"

vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
input_ids = torch.LongTensor(tokenizer.encode(prompt))

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message.
Some kwargs in processor config are unused and will not have any effect: ignore_id, image_tag, mask_prompt, sft_form

In [6]:
# Load the multimodal model in bfloat16 on the GPU, in eval mode, with eager attention.
vl_gpt: MultiModalityCausalLM = (
    AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
    .to(torch.bfloat16)
    .cuda()
    .eval()
)
vl_gpt.language_model.config._attn_implementation = 'eager'

### Image encoding and restoring

In [43]:
# Round-trip a saved image through the Janus VQ model: encode it to codebook ids,
# embed those ids, score them with `gen_head`, and decode the argmax ids to test.jpg.
# NOTE(review): the saved output of this cell is
#   RuntimeError: Input type (int) and bias type (c10::BFloat16) should be the same
# and no traceback was kept, so the failing line is unknown — TODO locate.
img = Image.open("generated_samples/img_0.jpg").convert("RGB")
prepare = vl_chat_processor.process_one(prompt="<image_placeholder>", images=[img])
# NOTE(review): `bs` and `n` are not used in this cell.
bs, n = prepare.pixel_values.unsqueeze(0).shape[0:2]

images = prepare.pixel_values.to(torch.bfloat16).cuda()

with torch.no_grad():
    quant, _, info = vl_gpt.gen_vision_model.encode(images)  # torch.Size([1, 3, 384, 384])
    B, C, Hq, Wq = quant.shape
    # Presumably the per-position codebook indices; flattened to [B, Hq * Wq] below.
    _, _, min_encoding_indices = info
    image_ids = min_encoding_indices.view(B, Hq * Wq)
    gen_embeds = vl_gpt.prepare_gen_img_embeds(image_ids)  # torch.Size([1, 576, 2048])
    # NOTE(review): gen_head is applied to the token embeddings directly, not to
    # language-model hidden states as in JanusImageGenerator.invoke_model, so its
    # argmax is not expected to reproduce image_ids.
    logits = vl_gpt.gen_head(gen_embeds)

    # Only consumed by the commented-out sampling line below.
    probs = torch.softmax(logits / 1, dim=-1)
    # gen_image_tokens = torch.multinomial(probs[0], num_samples=1)
    # Greedy choice: highest-scoring id per position.
    indices = logits.max(dim=-1).indices
    gen_image_tokens = indices
    # Hard-coded decode shape; 24x24 = 576 tokens, 8 presumably the code embedding dim.
    dec = vl_gpt.gen_vision_model.decode_code(gen_image_tokens.to(dtype=torch.int), shape=[1, 8, 24, 24])
    dec_temp = dec
    # Presumably NCHW -> NHWC, then rescale [-1, 1] -> [0, 255] for saving.
    dec = dec.to(torch.float32).detach().cpu().numpy().transpose(0, 2, 3, 1)

    dec = np.clip((dec + 1) / 2 * 255, 0, 255)
    visual_img = dec.astype(np.uint8)
    Image.fromarray(visual_img[0]).save("test.jpg")

RuntimeError: Input type (int) and bias type (c10::BFloat16) should be the same

In [47]:
# Decode the ground-truth codebook ids straight back to pixels (reconstruction check).
dec_temp = (
    vl_gpt.gen_vision_model
    .decode_code(image_ids.to(dtype=torch.int), shape=[1, 8, 24, 24])
    .to(torch.float32)
    .detach()
    .cpu()
    .numpy()
    .transpose(0, 2, 3, 1)
)
dec_temp = np.clip((dec_temp + 1) / 2 * 255, 0, 255)
visual_img = dec_temp.astype(np.uint8)
Image.fromarray(visual_img[0]).save("test.jpg")

In [35]:
# Highest-scoring id per position over the last (vocabulary) axis.
indices = torch.max(logits, dim=-1).indices
indices.shape

torch.Size([1, 576])

In [29]:

# Save the preprocessed model input as test3.jpg: CHW -> HWC, then [-1, 1] -> [0, 255].
img_arr = images[0].detach().to(torch.float).cpu().permute(1, 2, 0).numpy()
img_arr = np.clip((img_arr + 1) / 2 * 255, 0, 255)
image = Image.fromarray(img_arr.astype("uint8"))
image.save("test3.jpg")

In [14]:
from transformers.models.bart.modeling_bart import shift_tokens_right  # similar utility
from einops import rearrange


def shift_image_tokens(image_ids: torch.Tensor):
    """Shift ``image_ids`` one position to the right along the sequence.

    Delegates to BART's ``shift_tokens_right``: the result starts with
    ``vl_chat_processor.image_start_id`` and drops the last id, using
    ``vl_chat_processor.pad_id`` as the pad token.

    NOTE(review): image ids index the VQ codebook, while ``image_start_id`` comes
    from the text tokenizer — confirm it is a valid class for ``gen_head`` before
    using the result as a cross-entropy target.
    """
    return shift_tokens_right(
        image_ids,
        pad_token_id=vl_chat_processor.pad_id,
        decoder_start_token_id=vl_chat_processor.image_start_id,
    )


def get_image_janus_embeds(imgs: List[Image.Image], device: str = "cuda:0"):
    """Encode PIL images into Janus VQ codebook ids and their generation embeddings.

    Args:
        imgs: images to encode; preprocessed by ``vl_chat_processor``.
        device: device the pixel values are moved to; must match ``vl_gpt``
            (default ``"cuda:0"``, the previously hard-coded value).

    Returns:
        ``(image_ids, gen_embeds, gen_aligned)`` where ``image_ids`` is
        ``[N, Hq * Wq]``, ``gen_embeds`` is ``vl_gpt.gen_embed(image_ids)`` and
        ``gen_aligned`` is ``vl_gpt.gen_aligner(gen_embeds)`` (presumably what
        ``prepare_gen_img_embeds`` returns — confirm).
    """
    prepare = vl_chat_processor.process_one(prompt="<image_placeholder>", images=imgs)
    # Collapse the leading (batch, num_images) dims into one image batch for the encoder.
    pixel_values = rearrange(
        prepare.pixel_values.unsqueeze(0).to(torch.bfloat16).to(device),
        "b n c h w -> (b n) c h w",
    )
    quant, _, info = vl_gpt.gen_vision_model.encode(pixel_values)
    batch, _, h_q, w_q = quant.shape
    _, _, min_encoding_indices = info
    image_ids = min_encoding_indices.view(batch, h_q * w_q)
    gen_embeds = vl_gpt.gen_embed(image_ids)
    gen_aligned = vl_gpt.gen_aligner(gen_embeds)

    return image_ids, gen_embeds, gen_aligned

In [10]:
# Reference image used by the forward-pass experiments below.
image_path = "generated_samples/img_0.jpg"
with Image.open(image_path) as src:
    image = src.convert("RGB")

In [19]:
from torch.nn.functional import cross_entropy

# Placeholder for an audio-projection output; not wired into the input yet.
random_tensor = torch.randn(1, 1, 2048).cuda().to(torch.bfloat16)

with torch.no_grad():
    image_ids, image_embeds, image_aligned = get_image_janus_embeds([image])
    # Intended input: prompt embeds (text), image_bind_projection, image_embeds,
    # concatenated on dim=1. Only the image part is present so far.
    input_tensor = torch.concat([image_aligned], dim=1)

    # `decoder_input_ids` is not a LlamaModel argument, so it is not passed.
    outputs = vl_gpt.language_model.model(inputs_embeds=input_tensor, use_cache=False, past_key_values=None)

    hidden_states = outputs.last_hidden_state  # [1, 576, 2048] for one image (no prefix yet)

    logits = vl_gpt.gen_head(hidden_states)  # [B, T, V]
    print(logits.shape)
    probs = torch.softmax(logits[:, -576:, :], dim=-1)
    logits = logits.permute(0, 2, 1)  # [B, V, T]: class dim second, as cross_entropy expects
    shifted_image_ids = shift_image_tokens(image_ids)  # torch.Size([1, 576])
    # NOTE(review): hidden state t feeds the prediction of token t+1 (see invoke_model,
    # which samples from hidden_states[:, -1]); the natural pairing is
    # logits[:, :, :-1] vs image_ids[:, 1:]. shift_image_tokens shifts the other way —
    # confirm before enabling the loss.
    # loss = cross_entropy(logits[:, :, -576:], shifted_image_ids, ignore_index=-100)

torch.Size([1, 576, 16384])


In [20]:
# After the permute in the forward cell, `logits` is [B, V, T], so the vocabulary
# axis is dim=-2: normalise over it (softmax over dim=-1 mixed time steps and
# changed which id wins), then take the per-position argmax over the same axis.
probs = torch.softmax(logits[:, :, -576:], dim=-2)
indices = probs.max(dim=-2).indices
indices.shape

torch.Size([1, 576])

In [21]:
# Decode the predicted ids to pixels and save as test.jpg.
gen_image_tokens = indices
dec_temp = vl_gpt.gen_vision_model.decode_code(gen_image_tokens.to(dtype=torch.int), shape=[1, 8, 24, 24])
dec = np.clip(
    (dec_temp.to(torch.float32).detach().cpu().numpy().transpose(0, 2, 3, 1) + 1) / 2 * 255,
    0,
    255,
)
visual_img = dec.astype(np.uint8)
Image.fromarray(visual_img[0]).save("test.jpg")