In [2]:
# add janus to os path
import sys
import os
from typing import List

janus_path = os.path.abspath("../Janus/janus")
sys.path.append(janus_path)


In [3]:
from abc import ABC

import numpy as np
import torch
from imagebind.models.imagebind_model import ImageBindModel
from imagebind import data
from torch import nn
from PIL import Image

from transformers import AutoModelForCausalLM

from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.models.processing_vlm import VLChatProcessorOutput




Python version is above 3.10, patching the collections module.




# Audio Projection

In [4]:
class AudioProjection(nn.Module):
    def __init__(self, input_dim, output_dim, sequal_len=32, scale_factor=2):
        super(AudioProjection, self).__init__()
        self.scale_factor = scale_factor
        self.sequal_len = sequal_len
        self.output_dim = output_dim
        self.fc1 = nn.Linear(input_dim, scale_factor * output_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(scale_factor * output_dim, sequal_len * output_dim)

    def forward(self, x):
        B = x.shape[0]
        x = self.fc1(x)  # → [B, scale_factor * output_dim]
        x = self.act(x)
        x = self.fc2(x)  # → [B, sequal_len * output_dim]
        x = torch.reshape(x, (B, self.sequal_len, self.output_dim))

        return x

In [3]:
x = torch.randn(2, 1024)
with torch.no_grad():
    proj = AudioProjection(1024, 2048, scale_factor=2)
    res = proj(x)
print(res.shape)

NameError: name 'AudioProjection' is not defined

# Janus Image Generator

In [37]:
AVAILABLE_MODELS = ["Janus-1.3B", "Janus-Pro-1B", "JanusFlow-1.3B"]


class JanusImageGenerator:
    name = "Janus Image Generator"

    @classmethod
    def init_model(cls, config):
        model_name = config.get("model_name", "Janus-Pro-1B")
        cls.sys_prompt = config.get("sys_prompt", "Abstract art for representing emotions")
        if model_name not in AVAILABLE_MODELS:
            logging.warning(f"Model {model_name} not available. Using {AVAILABLE_MODELS[0]} instead.")
            model_name = AVAILABLE_MODELS[0]

        model_path = f"deepseek-ai/{model_name}"
        cls.vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
        cls.tokenizer = cls.vl_chat_processor.tokenizer

        cls.vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
            model_path, trust_remote_code=True
        )
        cls.vl_gpt = cls.vl_gpt.to(torch.bfloat16).cuda().eval()

        cls.model = cls.vl_gpt
        cls.audio_embeds_shape = {0: 2, 2: 2048}
        cls.audio_embeds_type = torch.bfloat16
        # hardcoded need to learn more about this
        cls.parallel_size = 1
        cls.img_size = 384
        cls.patch_size = 16

    @classmethod
    def _preprocess_input(cls, inputs):
        conversation = [
            {
                "role": "User",
                "content": inputs,
            },
            {"role": "Assistant", "content": ""},
        ]

        sft_format = cls.vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
            conversations=conversation,
            sft_format=cls.vl_chat_processor.sft_format,
            system_prompt="",
        )

        return sft_format + cls.vl_chat_processor.image_start_tag

    @classmethod
    def _postprocess_output(cls, outputs):
        # concated = torch.cat([outputs, torch.zeros_like(outputs)], dim=1)
        outputs = outputs.numpy()
        outputs = np.clip((outputs + 1) / 2 * 255, 0, 255)
        image = Image.fromarray(outputs.astype(np.uint8))

        return image

    @classmethod
    def invoke_model(cls, prompt: str,
                     temperature: float = 1,
                     cfg_weight: float = 5,
                     image_token_num_per_image: int = 576,
                     audio_embeds: torch.Tensor = None,
                     **kwargs):

        input_ids = cls.vl_chat_processor.tokenizer.encode(prompt)
        input_ids = torch.LongTensor(input_ids)

        tokens = torch.zeros((cls.parallel_size * 2, len(input_ids)), dtype=torch.int).cuda()
        for i in range(cls.parallel_size * 2):
            tokens[i, :] = input_ids
            if i % 2 != 0:
                tokens[i, 1:-1] = cls.vl_chat_processor.pad_id
        cls.vl_gpt.language_model.config._attn_implementation = 'eager'

        inputs_embeds = cls.vl_gpt.language_model.get_input_embeddings()(tokens)
        print(inputs_embeds.shape)
        if audio_embeds is not None:
            inputs_embeds = torch.cat([audio_embeds, inputs_embeds], dim=1)

        generated_tokens = torch.zeros((cls.parallel_size, image_token_num_per_image), dtype=torch.int).cuda()

        for i in range(image_token_num_per_image):
            outputs = cls.vl_gpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True,
                                                      past_key_values=outputs.past_key_values if i != 0 else None)
            hidden_states = outputs.last_hidden_state

            logits = cls.vl_gpt.gen_head(hidden_states[:, -1, :])
            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]

            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
            probs = torch.softmax(logits / temperature, dim=-1)

            next_token = torch.multinomial(probs, num_samples=1)
            generated_tokens[:, i] = next_token.squeeze(dim=-1)

            next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
            img_embeds = cls.vl_gpt.prepare_gen_img_embeds(next_token)
            inputs_embeds = img_embeds.unsqueeze(dim=1)

        dec: torch.Tensor = cls.vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
                                                                    shape=[cls.parallel_size, 8,
                                                                           cls.img_size // cls.patch_size,
                                                                           cls.img_size // cls.patch_size])
        dec = dec.to(torch.float32).cpu().permute(0, 2, 3, 1)[0]

        return dec

    @classmethod
    def generate(cls, inputs: str) -> Image:
        inputs = cls._preprocess_input(inputs)
        outputs = cls.invoke_model(inputs=inputs)
        return cls._postprocess_output(outputs)

    @classmethod
    def generate_from_embeds(cls, inputs: np.ndarray) -> Image:
        prompt = cls._preprocess_input(cls.sys_prompt)
        input_tensor = torch.from_numpy(inputs).to(cls.audio_embeds_type).cuda()

        if not all([input_tensor.shape[d] == s for d, s in cls.audio_embeds_shape.items()]):
            logger.error(f"Input tensor had shape {inputs.shape} was expected {cls.audio_embeds_shape}")

        output = cls.invoke_model(prompt, audio_embeds=input_tensor)
        image = cls._postprocess_output(output)
        return image


# Training config

In [4]:
class TrainConfig:  # copy from hw-multimodal-llm-solved need to be changed
    log_level = "DEBUG"

    # Training
    num_epochs = 5
    train_batch_size = 8
    val_batch_size = 1
    log_grad_norm = True
    learning_rate = 1e-4
    gradient_accumulation_steps = 1

    evaluate_every_epoch_mod = 4
    save_model_every_epoch_mod = 1
    device = "cuda:0"

    # Model

    # Projector
    projector_input_dim = 1024

    # Data
    few_train_samples = None
    few_val_samples = 100
    dataloader_num_workers = 0

    train_dataset_path = ""
    audio_embeds_train_prefix = ""

    val_dataset_path = ""
    audio_embeds_val_prefix = ""


In [5]:
import yaml
import argparse

import pathlib
import random
import torch
import torch.nn as nn

import logging
# import evaluate

# import datasets
from transformers.generation import GenerationConfig

from torch.optim import Adam
from torch.utils.data import DataLoader

import transformers

from tqdm.auto import tqdm

# import wandb
# from wandb import sdk as wandb_sdk

# import accelerate


In [None]:
image_bind = ImageBindModel()

In [6]:
from imagebind import ModalityType

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s')


def prepare_image_tokens(model: MultiModalityCausalLM, batch):
    """
    Get tokens from images in batch
    """

    vq_model = model.gen_vision_model
    if "images" not in batch:
        return

    quantized, _, info = vq_model.encode(batch["images"])
    image_tokens = info[2].view(batch, -1)

    batch["image_tokens"] = image_tokens

    return batch


def get_audio_embeds(audio_encoder: ImageBindModel, batch):
    """
    From file or call imagebind
    """
    audio_embeds = batch.get("audio_embeds", [])
    if audio_embeds:
        return audio_embeds

    inputs = {
        ModalityType.AUDIO: data.load_and_transform_audio_data(batch["audio_path"], TrainConfig.device),
    }

    with torch.no_grad():
        embeddings = audio_encoder(inputs)

    batch["audio_embeds"] = embeddings
    return batch


def get_image_embeds(image_encoder: ImageBindModel, batch):
    """
    From file or call imagebind
    """
    image_embeds = batch.get("image_embeds", [])
    if image_embeds:
        return image_embeds

    inputs = {
        ModalityType.VISION: data.load_and_transform_vision_data(batch["image_path"], TrainConfig.device),
    }

    with torch.no_grad():
        embeddings = image_encoder(inputs)

    batch["image_embeds"] = embeddings
    return batch


def save_model(train_config: TrainConfig, model: AudioProjection, path: pathlib.Path):
    path.mkdir(parents=True, exist_ok=True)
    logger.info(f"save model to {path}")

    model.save_pretrained(path)
    if train_config.llm_train_lora:
        model.lm_model.save_pretrained(path.joinpath("lora_adapter"))

    return



SyntaxError: invalid syntax. Perhaps you forgot a comma? (527465804.py, line 14)

In [None]:
def train_loop(accelerator: accelerate.Accelerator, model: AudioProjection, optimizer, train_dataloader: DataLoader,
               epoch, criterion, last_validation_bleu=0.0, train_config=None):
    model.train()
    sumloss = 0
    progress_bar = tqdm(range(len(train_dataloader)), desc=f'Epoch {epoch}')
    # janus eval

    for batch in train_dataloader:
        with accelerator.accumulate(model):
            audio_embeds = get_audio_embeds(image_bind, batch["audio_path"])
            image_tokens = prepare_image_tokens(model, batch)
            proj_embeds = model(audio_embeds)
            with torch.no_grad():
                tokens = JanusImageGenerator.generate_from_embeds(proj_embeds)

            loss = criterion(tokens, image_tokens)
            sumloss += loss.item()
            model.zero_grad()
            accelerator.backward(loss)
            optimizer.step()


# Training Pipelines research

In [6]:
# inputs_embeds — prompt embeds (text), image_bind_projection, image_embeds (concat with dim=1)
model_path = f"deepseek-ai/Janus-Pro-1B"

prompt = "Abstract art for representing emotions"

vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

input_ids = vl_chat_processor.tokenizer.encode(prompt)
input_ids = torch.LongTensor(input_ids)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message.
Some kwargs in processor config are unused and will not have any effect: sft_format, image_tag, mask_prompt, add_spe

In [7]:
vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True,
)
vl_gpt.language_model.config._attn_implementation = 'eager'

vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()

In [18]:
from transformers.models.bart.modeling_bart import shift_tokens_right  # similar utility
from einops import rearrange


def get_image_janus_embeds(imgs: List[Image.Image]):
    prepare = vl_chat_processor.process_one(prompt="<image_placeholder>", images=imgs)
    bs, n = prepare.pixel_values.unsqueeze(0).shape[0:2]
    images = rearrange(prepare.pixel_values.unsqueeze(0).to(torch.bfloat16).to("cuda:0"), "b n c h w -> (b n) c h w")
    image_embeds_shifted = shift_tokens_right(
        images,
        pad_token_id=vl_chat_processor.pad_id,
        decoder_start_token_id=vl_chat_processor.image_start_id
    )
    images_embeds = vl_gpt.aligner(vl_gpt.vision_model(image_embeds_shifted))

    return images_embeds

In [19]:
image_path = "generated_samples/img_0.jpg"
image = Image.open(image_path).convert("RGB")

image_embeds = get_image_janus_embeds([image])
image_embeds.shape

torch.Size([1, 576, 2048])

In [22]:
from torch.nn.functional import mse_loss, cross_entropy

random_tensor = torch.randn(1, 576, 2048).cuda().to(torch.bfloat16)
image_tensor = image_embeds = get_image_janus_embeds([image])
random_tensor = torch.concat([random_tensor, image_tensor], dim=1)

with torch.no_grad():
    outputs = vl_gpt.language_model.model(inputs_embeds=random_tensor, use_cache=False, past_key_values=None,
                                          decoder_input_ids=1)
    hidden_states = outputs.last_hidden_state

    print(hidden_states.shape)
    print(hidden_states.dtype)

    logits = vl_gpt.gen_head(hidden_states)
    labels_logits = logits[:, -576:, :]
    pred_logits = logits[:, :-576, :]

    labels_probs = torch.softmax(labels_logits, dim=-1)
    pred_probs = torch.softmax(pred_logits, dim=-1)
    print(pred_probs.shape)
    pred_probs = pred_probs.permute(0, 2, 1)
    pred_logits = pred_logits.permute(0, 2, 1)
    print(pred_probs.shape)
    # pred_tokens = torch.multinomial(pred_probs[0], num_samples=256)
    labels_tokens = torch.multinomial(labels_probs[0], num_samples=1)
    labels_tokens = labels_tokens.squeeze(-1).unsqueeze(0)
    print(labels_tokens.shape)
    loss = cross_entropy(pred_logits, labels_tokens, ignore_index=-100)

torch.Size([1, 1152, 2048])
torch.bfloat16
torch.Size([1, 576, 16384])
torch.Size([1, 16384, 576])
torch.Size([1, 576])


In [23]:
loss.item()  # 9.6875

11.75

In [49]:
# reshape pred_probs from torch.Size([1, 256, 16384]) to torch.Size([1, 16384, 256])
pred_probs = pred_probs.permute(0, 2, 1)
print(pred_probs.shape)  # Should print torch.Size([1, 16384, 256])

torch.Size([1, 16384, 256])


In [58]:
from janus.utils.io import load_pil_images

image_path = "generated_samples/img_0.jpg"

conversation = [
    {
        "role": "User",
        "content": "<image_placeholder>",
        "images": [image_path],
    },
    {"role": "Assistant", "content": ""},
]
pil_images = load_pil_images(conversation)
prepare_inputs = vl_chat_processor(
    conversations=conversation, images=pil_images, force_batchify=True
).to(vl_gpt.device)
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

inputs_embeds.shape

torch.Size([1, 623, 2048])

In [67]:
def prepare_inputs_embeds(
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
        images_seq_mask: torch.LongTensor,
        images_emb_mask: torch.LongTensor,
        **kwargs,
):
    bs, n = pixel_values.shape[0:2]
    images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
    # [b x n, T2, D]
    images_embeds = vl_gpt.aligner(vl_gpt.vision_model(images))

    # [b x n, T2, D] -> [b, n x T2, D]
    images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
    # [b, n, T2] -> [b, n x T2]
    images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
    print(images_seq_mask.shape)
    print(images_emb_mask.shape)
    # [b, T, D]
    input_ids[input_ids < 0] = 0  # ignore the image embeddings
    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(input_ids)

    # replace with the image embeddings
    inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
    print(images_embeds.shape)
    return inputs_embeds

In [68]:
inputs_embeds = prepare_inputs_embeds(**prepare_inputs)
inputs_embeds.shape

torch.Size([1, 623])
torch.Size([1, 576])
torch.Size([1, 576, 2048])


torch.Size([1, 623, 2048])

In [31]:
probs.shape
next = torch.multinomial(probs[0], num_samples=1)
print(next.shape)
print(next.detach().cpu())

torch.Size([6, 1])
tensor([[13255],
        [12526],
        [ 2021],
        [ 1168],
        [ 5789],
        [15969]])


In [35]:
a = vl_gpt.prepare_gen_img_embeds(next)
a.loss

AttributeError: 'Tensor' object has no attribute 'loss'

In [None]:
vision_model = vl_gpt.gen_vision_model
with torch.no_grad():
    quantized, a, info = vision_model.encode(image_tensor)
    image_embeds = info[2]  # Extract the embeddings
    print("Shape: ", image_embeds.shape, "Type: ", image_embeds.type)

image_embeds.shape

In [9]:
from torchvision import transforms
from einops import rearrange

image_path = "generated_samples/img_0.jpg"
image = Image.open(image_path).convert("RGB")

# preprocess = transforms.Compose([
#     transforms.Resize((224, 224)),  # Resize to the required input size
#     transforms.ToTensor(),  # Convert to a tensor
#     transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize
# ])

# image_tensor = preprocess(image).unsqueeze(0).to(torch.bfloat16).to("cuda:0")

images_outputs = vl_chat_processor.image_processor([image], return_tensors="pt")

prompt = "<image_placeholder>"

input_ids = vl_chat_processor.tokenizer.encode(prompt)
input_ids = torch.LongTensor(input_ids)

image_token_mask: torch.BoolTensor = input_ids == vl_chat_processor.image_id
image_indices = image_token_mask.nonzero()

input_ids, num_image_tokens = vl_chat_processor.add_image_token(
    image_indices=image_indices,
    input_ids=input_ids,
)

prepare = VLChatProcessorOutput(
    sft_format=prompt,
    input_ids=input_ids,
    pixel_values=images_outputs.pixel_values,
    num_image_tokens=num_image_tokens,
)
input_ids, num_image_tokens = input_ids.cuda(), num_image_tokens.cuda()
bs, n = images_outputs.pixel_values.unsqueeze(0).shape[0:2]
images = rearrange(images_outputs.pixel_values.unsqueeze(0).to(torch.bfloat16).to("cuda:0"), "b n c h w -> (b n) c h w")
images_embeds = vl_gpt.aligner(vl_gpt.vision_model(images))
images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
images_embeds.shape


NameError: name 'vl_gpt' is not defined

In [77]:
temp = image_token_mask.detach().cpu()
# get the indices of the True values
indices = torch.nonzero(temp, as_tuple=True)[0]
num_image_tokens

tensor([576], device='cuda:0', dtype=torch.int32)

In [78]:
from janus.utils.io import load_pil_images

conversation = [
    {
        "role": "User",
        "content": "<image_placeholder>",
        "images": ["generated_samples/img_0.jpg"],
    },
    {"role": "Assistant", "content": ""},
]

# load images and prepare for inputs
pil_images = load_pil_images(conversation)
prepare_inputs = vl_chat_processor(
    conversations=conversation, images=pil_images, force_batchify=True
).to(vl_gpt.device)

inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

In [79]:
inputs_embeds.shape

torch.Size([1, 623, 2048])

In [97]:
prepare = vl_chat_processor.process_one(prompt="<image_placeholder>", images=[image, image])

In [102]:
images_outputs = vl_chat_processor.image_processor([image, image], return_tensors="pt")
images_outputs["pixel_values"].shape

torch.Size([2, 3, 384, 384])

In [103]:
res = get_image_janus_embeds([image, image])

In [104]:
res.shape

torch.Size([2, 576, 2048])

In [80]:

prompt = "<image_placeholder>"

input_ids = vl_chat_processor.tokenizer.encode(prompt)
input_ids = torch.LongTensor(input_ids)

image_token_mask: torch.BoolTensor = input_ids == vl_chat_processor.image_id
image_indices = image_token_mask.nonzero()

input_ids, num_image_tokens = input_ids, num_image_tokens

input_ids, num_image_tokens = vl_chat_processor.add_image_token(
    image_indices=image_indices,
    input_ids=input_ids,
)
images_outputs = vl_chat_processor.image_processor([image], return_tensors="pt")

prepare = VLChatProcessorOutput(
    sft_format=prompt,
    input_ids=input_ids,
    pixel_values=images_outputs.pixel_values.cuda().to(torch.bfloat16),
    num_image_tokens=num_image_tokens,
)

batch_prepare = vl_chat_processor.batchify([prepare])

bs, n = images_outputs.pixel_values.unsqueeze(0).shape[0:2]
images = rearrange(images_outputs.pixel_values.unsqueeze(0).to(torch.bfloat16).to("cuda:0"), "b n c h w -> (b n) c h w")
images_embeds = vl_gpt.aligner(vl_gpt.vision_model(images))
images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
images_embeds.shape

torch.Size([1, 576, 2048])

In [12]:
print(input_ids.shape)
inputs_embeds = vl_gpt.language_model.get_input_embeddings()(input_ids.to("cuda:0"))
print(inputs_embeds.shape)

torch.Size([6])
torch.Size([6, 2048])


In [40]:
images_emb_mask = rearrange(num_image_tokens, "b n t -> b (n t)")

EinopsError:  Error while processing rearrange-reduction pattern "b n t -> b (n t)".
 Input tensor shape: torch.Size([1]). Additional info: {}.
 Wrong shape: expected 3 dims. Received 1-dim tensor.

In [46]:
inputs_embeds.shape

torch.Size([1, 623, 2048])

AttributeError: 'Image' object has no attribute 'shape'

In [None]:
pad_tokens = torch.zeros((1, len(text_embeds) + len(image_embeds) + len(audio_proj)), dtype=torch.int)
pad_tokens[1, 1:-1] = vl_chat_processor.pad_id
pad_embeds = vl_gpt.language_model.get_input_embeddings()(pad_tokens)
input_embeds = torch.concat([text_embeds, image_embeds, audio_proj], dim=1)
input_embeds = torch.stack([input_embeds, pad_embeds])