In [1]:
# add janus to os path
import sys
import os
from typing import List, Dict

import accelerate
from accelerate import Accelerator

from torch.nn.functional import cross_entropy

janus_path = os.path.abspath("../Janus/janus")
sys.path.append(janus_path)

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["TORCH_USE_CUDA_DSA"] = "1"


In [2]:
from abc import ABC

import numpy as np
import torch
from imagebind.models.imagebind_model import ImageBindModel
from imagebind import data
from torch import nn
from PIL import Image


from transformers import AutoModelForCausalLM

from Janus.janus.models import MultiModalityCausalLM, VLChatProcessor
from Janus.janus.models.processing_vlm import VLChatProcessorOutput




Python version is above 3.10, patching the collections module.




Python version is above 3.10, patching the collections module.


# Audio Projection

In [3]:
class AudioProjection(nn.Module):
    def __init__(self, input_dim, output_dim, sequal_len=32, scale_factor=2):
        super(AudioProjection, self).__init__()
        self.scale_factor = scale_factor
        self.sequal_len = sequal_len
        self.output_dim = output_dim
        self.fc1 = nn.Linear(input_dim, scale_factor * output_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(scale_factor * output_dim, sequal_len * output_dim)

    def forward(self, x):
        B = x.shape[0]
        x = self.fc1(x)  # → [B, scale_factor * output_dim]
        x = self.act(x)
        x = self.fc2(x)  # → [B, sequal_len * output_dim]
        x = torch.reshape(x, (B, self.sequal_len, self.output_dim))

        return x

# Training config

In [4]:
class TrainConfig:  # copy from hw-multimodal-llm-solved need to be changed
    log_level = "DEBUG"

    # Training
    num_epochs = 1
    train_batch_size = 2
    val_batch_size = 1
    log_grad_norm = True
    learning_rate = 1e-4
    gradient_accumulation_steps = 1

    evaluate_every_epoch_mod = 4
    save_model_every_epoch_mod = 1
    device = "cuda:0"

    # Model

    # Projector
    projector_input_dim = 1024

    # Data
    few_train_samples = None
    few_val_samples = 100
    dataloader_num_workers = 0

    train_dataset_path = ""
    audio_embeds_train_prefix = ""

    val_dataset_path = ""
    audio_embeds_val_prefix = ""


In [5]:
import yaml
import argparse
import pathlib
import random
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import pandas as pd
from torchvision.transforms import ToTensor

import logging
# import evaluate
# import datasets
from transformers.generation import GenerationConfig
from torch.optim import Adam
from torch.utils.data import DataLoader
import transformers
from tqdm.auto import tqdm
# import wandb
# from wandb import sdk as wandb_sdk
# import accelerate


In [6]:
model_path = f"deepseek-ai/Janus-Pro-1B"

prompt = "Abstract art for representing emotions"

vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

input_ids = vl_chat_processor.tokenizer.encode(prompt)
input_ids = torch.LongTensor(input_ids)
vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True,
)
vl_gpt.language_model.config._attn_implementation = 'eager'

vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message.
Some kwargs in processor config are unused and will not have any effect: mask_prompt, add_special_token, image_tag, 

# Data loading

In [7]:
matched_df_path = "../data/notebooks/matched_dataset_concat.pkl"
matched_df = pd.read_pickle(matched_df_path)

In [8]:
class ImageAudioDataset(Dataset):
    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.dataframe["image_path"] = self.dataframe["image_path"].apply(
            lambda x: os.path.abspath(x.replace("\\", "/")))
        self.dataframe["audio_path"] = self.dataframe["audio_path"].apply(
            lambda x: os.path.abspath(x.replace("\\", "/")))
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx) -> dict:
        image_path = self.dataframe.iloc[idx]["image_path"]
        audio_path = self.dataframe.iloc[idx]["audio_path"]
        image = Image.open(image_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        else:
            image = ToTensor()(image)
            # resize to 384x384
            image = nn.functional.interpolate(image.unsqueeze(0), size=(384, 384), mode='bilinear', align_corners=False)

        image_embedding = self.dataframe.iloc[idx]["image_embedding"]
        music_embedding = self.dataframe.iloc[idx]["music_embedding"]
        res = {
            "audio_path": audio_path,
            "image": image,
            "music_embedding": music_embedding,
            "image_embedding": image_embedding,
        }
        return res

In [9]:
dataset = ImageAudioDataset(matched_df)

In [10]:
from transformers.models.bart.modeling_bart import shift_tokens_right  # similar utility
from einops import rearrange


def shift_image_tokens(image_ids: torch.Tensor):

    image_embeds_shifted = shift_tokens_right(
        image_ids,
        pad_token_id=vl_chat_processor.pad_id,
        decoder_start_token_id=vl_chat_processor.image_start_id
    )
    return image_embeds_shifted


def get_collate_fn(gen_model: MultiModalityCausalLM, validation=False):
    def collate_fn(items):
        result = dict()
        with torch.no_grad():
            images = torch.stack([item["image"] for item in items], dim=1)
            quant, _, info = gen_model.gen_vision_model.encode(
                images.squeeze(0).to(dtype=torch.bfloat16).cuda())  # torch.Size([1, 3, 384, 384])
            B, C, Hq, Wq = quant.shape
            _, _, min_encoding_indices = info
            image_ids = min_encoding_indices.view(B, Hq * Wq)
            gen_embeds = gen_model.prepare_gen_img_embeds(image_ids)

        result["image_ids"] = image_ids.squeeze(-1)
        result["image_gen_embeds"] = gen_embeds
        result["music_embedding"] = torch.stack([torch.from_numpy(item["music_embedding"]) for item in items], dim=0)
        result["images"] = images

        return result

    return collate_fn

### Dataloader example

In [46]:
for epoch in range(TrainConfig.num_epochs):
    dataloader = DataLoader(dataset, batch_size=TrainConfig.train_batch_size, shuffle=True,
                            num_workers=TrainConfig.dataloader_num_workers,
                            collate_fn=get_collate_fn(vl_gpt)
                            )
    for batch_idx, batch in enumerate(dataloader):
        # Example: Access data from the batch
        print(batch["images"].shape)
        print(batch["image_gen_embeds"].shape)
        print(batch["image_ids"].shape)
        print(batch["music_embedding"].shape)

        # Simulate training step (e.g., forward pass, loss computation, backward pass)
        print(f"Epoch {epoch + 1}, Batch {batch_idx + 1}: Processed {len([])} samples")
        break


torch.Size([1, 2, 3, 384, 384])
torch.Size([2, 576, 2048])
torch.Size([2, 576])
torch.Size([2, 1024])
Epoch 1, Batch 1: Processed 0 samples


# Training Loop

In [17]:
def train_loop(accelerator: accelerate.Accelerator, model: MultiModalityCausalLM, projection: AudioProjection,
               optimizer, train_dataloader: DataLoader, epoch, criterion, last_validation_bleu=0.0, train_config=None):
    model.eval()
    projection.train()
    progress_bar = tqdm(range(len(train_dataloader)), desc="Epoch {}".format(epoch))
    for batch in train_dataloader:
        with accelerator.accumulate(projection):
            audio_input = projection(batch["music_embedding"]).to(torch.bfloat16)
            image_gen_embeds = batch["image_gen_embeds"].to(torch.bfloat16)
            image_ids = batch["image_ids"]

            # add sys prompt embeds to input
            input_embeds = torch.concat([audio_input, image_gen_embeds], dim=1)
            print(input_embeds.shape)
            print(input_embeds.type)
            outputs = model.language_model.model(inputs_embeds=input_embeds, use_cache=False, past_key_values=None,
                                          decoder_input_ids=1)
            hidden_states = outputs.last_hidden_state
            logits = model.gen_head(hidden_states)
            logits = logits.permute(0, 2, 1)
            shifted_image_ids = shift_image_tokens(image_ids)
            print(
                shifted_image_ids.shape, shifted_image_ids.dtype, logits.shape, sep="\n"
            )
            loss = criterion(logits[:, :, -576:].cpu(), image_ids.cpu())
            model.zero_grad()
            accelerator.backward(loss)
            optimizer.step()
            progress_bar.update(1)
            progress_bar.set_description(f'Epoch={epoch} Loss={loss.item():.3f}')

            step_metrics = {"train_loss": loss.item(), "epoch": epoch}
            if train_config.log_grad_norm:
                for name, parameter in projection.named_parameters():
                    if parameter.grad is not None:
                        parameter_grad_norm = parameter.grad.norm(2).item()
                    else:
                        parameter_grad_norm = 0.0
                    step_metrics[f'grad_norm_{name}'] = parameter_grad_norm

    return

# Val loop

In [44]:
@torch.no_grad()
def val_loop( model: MultiModalityCausalLM, processor: VLChatProcessor, projection: AudioProjection,
              val_dataloader: DataLoader, epoch=1, no_loss=False, captioning_metrics=None):

    criterion = nn.CrossEntropyLoss(ignore_index=processor.pad_id)

    sumloss = 0
    num_batches = 0

    generations = []
    target_generations = []

    # gen_params = {
    #     "do_sample": False,
    #     "early_stopping": True,
    #     "num_beams": 3,
    #     "repetition_penalty": 2.5,
    #     "remove_invalid_values": True,
    #     "eos_token_id": processor.eos_token_id,
    #     "pad_token_id": processor.eos_token_id,
    #     "forced_eos_token_id": processor.eos_token_id,
    #     "use_cache": True,
    #     "no_repeat_ngram_size": 4,
    #     "num_return_sequences": 1,
    #     "cfg_weight": 5,
    #     "temperature":1
    # }
    cfg_weight = 5
    temperature = 1
    img_size = 384
    patch_size = 16
    genconfig = GenerationConfig.from_model_config(model.language_model.config)

    model.eval()
    for batch in tqdm(val_dataloader):
        batch_input_ids = batch['image_ids'].to(model.device)
        caption_legth = batch_input_ids.shape[1]

        music_embedding = projection(batch["music_embedding"]).to(torch.bfloat16).cuda()
        image_gen_embeds = batch["image_gen_embeds"].to(torch.bfloat16)
        batch_input_embeds = music_embedding
        input_embeds = torch.concat([batch_input_embeds, image_gen_embeds], dim=1)

        if not no_loss:
            # outputs = model.language_model.model(inputs_embeds=input_embeds, use_cache=False, past_key_values=None,
            #                               decoder_input_ids=1)
            # hidden_states = outputs.last_hidden_state
            hidden_states = torch.rand(input_embeds.shape).to(torch.bfloat16).cuda()
            logits = model.gen_head(hidden_states)
            logits = logits.permute(0, 2, 1)

            shifted_image_ids = shift_image_tokens(batch_input_ids)

            loss = criterion(logits[:, :, -576:].cpu(), batch_input_ids.cpu())
            sumloss += loss.item()
            num_batches += 1

        parallel_size = music_embedding.shape[0]
        tokens = torch.zeros((parallel_size*2, len(input_ids)), dtype=torch.int).cuda()
        unconditional_tokens = torch.zeros((1, input_embeds.shape[-2]), dtype=torch.int).cuda()
        print(input_embeds.shape)
        print(unconditional_tokens.shape)
        unconditional_tokens[0, 1:-1] = processor.pad_id
        assert unconditional_tokens.shape[0] == 1
        # assert unconditional_tokens.shape[-1] == input_embeds.shape[-2]

        unconditional_embeds = model.language_model.get_input_embeddings()(unconditional_tokens)
        assert unconditional_embeds.shape[-1] == input_embeds.shape[-1]

        gen_input_embeds = torch.zeros((input_embeds.shape[0] * 2, *input_embeds.shape[1:]), dtype=torch.bfloat16).cuda()
        print(gen_input_embeds.shape)
        for i in range(parallel_size*2):
            if i % 2 != 0:
                gen_input_embeds[i] = unconditional_embeds
            else:
                gen_input_embeds[i] = input_embeds[i//2]

        generated_tokens = torch.zeros((parallel_size, batch_input_ids.shape[-1]), dtype=torch.int).cuda()

        for i in range(batch_input_ids.shape[-1]):
            # outputs = model.language_model.model(inputs_embeds=gen_input_embeds, use_cache=True, past_key_values=outputs.past_key_values if i != 0 else None)
            # hidden_states = outputs.last_hidden_state
            hidden_states = torch.rand((4, 578, 2048), dtype=torch.bfloat16).cuda()
            logits = model.gen_head(hidden_states[:, -1, :])
            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]

            logits = logit_uncond + cfg_weight * (logit_cond-logit_uncond)
            probs = torch.softmax(logits / temperature, dim=-1)

            next_token = torch.multinomial(probs, num_samples=1)
            generated_tokens[:, i] = next_token.squeeze(dim=-1)

            next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
            img_embeds = model.prepare_gen_img_embeds(next_token)
            inputs_embeds = img_embeds.unsqueeze(dim=1)

        dec = model.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int), shape=[parallel_size, 8, img_size//patch_size, img_size//patch_size])
        dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)

        dec = np.clip((dec + 1) / 2 * 255, 0, 255)

        visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
        visual_img[:, :, :] = dec

        os.makedirs('generated_samples', exist_ok=True)
        for i in range(parallel_size):
            save_path = os.path.join('generated_samples', "val_img_{}.jpg".format(i))
            Image.fromarray(visual_img[i]).save(save_path)


In [45]:
proj = AudioProjection(1024, 2048, scale_factor=2, sequal_len=2)
val_dataloader = DataLoader(dataset, batch_size=TrainConfig.train_batch_size, shuffle=True,
                            num_workers=TrainConfig.dataloader_num_workers,
                            collate_fn=get_collate_fn(vl_gpt)
                            )
val_loop(vl_gpt, vl_chat_processor, proj, val_dataloader )

  0%|          | 0/2743 [00:00<?, ?it/s]

torch.Size([2, 578, 2048])
torch.Size([1, 578])
torch.Size([4, 578, 2048])
torch.Size([2, 578, 2048])
torch.Size([1, 578])
torch.Size([4, 578, 2048])
torch.Size([2, 578, 2048])
torch.Size([1, 578])
torch.Size([4, 578, 2048])


KeyboardInterrupt: 

In [18]:
import evaluate


def train(
        model: MultiModalityCausalLM,
        projection: AudioProjection,
        processor: VLChatProcessor,
        train_dataloader: DataLoader,
        val_dataloader: DataLoader,
        train_config: TrainConfig,
        device_placement=True,
        ):
    trainable_parameters = list(projection.parameters())
    optimizer = Adam(trainable_parameters, lr=train_config.learning_rate)
    criterion = nn.CrossEntropyLoss()

    accelerator = accelerate.Accelerator(device_placement=device_placement)
    accelerator.gradient_accumulation_steps = train_config.gradient_accumulation_steps

    model, projection, optimizer, train_dataloader, val_dataloader = accelerator.prepare(model, projection, optimizer, train_dataloader, val_dataloader)
    # captioning_metrics = evaluate.combine(
    #     [
    #         evaluate.load("bleu", keep_in_memory=True),
    #         evaluate.load("rouge", keep_in_memory=True),
    #         evaluate.load("meteor", keep_in_memory=True),
    #     ]
    # )

    best_validation_bleu = 0.0
    last_validation_bleu = 0.0

    for epoch in range(train_config.num_epochs):
        train_loop(accelerator, model,projection, optimizer, train_dataloader, epoch=epoch, criterion=criterion, last_validation_bleu=last_validation_bleu, train_config=train_config)

    #     if epoch % train_config.evaluate_every_epoch_mod == 0:
    #         validation_metrics = val_loop(model, tokenizer, val_dataloader, epoch=epoch, captioning_metrics=captioning_metrics)
    #         logger.info(f"validation metrics {validation_metrics}")
    #
    #         last_validation_bleu = validation_metrics['validation/evaluate_bleu']
    #         metric_logger.log(validation_metrics)
    #
    #         if last_validation_bleu > best_validation_bleu:
    #             best_validation_bleu = last_validation_bleu
    #
    #             base_path_for_best_model = pathlib.Path(f"data/models/{metric_logger.name}/best/")
    #             save_model(train_config=train_config, model=model, path=base_path_for_best_model)
    #
    #     if epoch % train_config.save_model_every_epoch_mod == 0:
    #         base_path_for_model = pathlib.Path(f"data/models/{metric_logger.name}/last/")
    #         save_model(train_config=train_config, model=model, path=base_path_for_model)
    #
    # base_path_for_model = pathlib.Path(f"data/models/{metric_logger.name}/last/")
    # save_model(train_config=train_config, model=model, path=base_path_for_model)


def freeze_model(model):
    for p in model.parameters():
        p.requires_grad = False
    return


In [22]:
proj = AudioProjection(1024, 2048, scale_factor=2)
train_dataloader = DataLoader(dataset, batch_size=TrainConfig.train_batch_size, shuffle=True,
                            num_workers=TrainConfig.dataloader_num_workers,
                            collate_fn=get_collate_fn(vl_gpt)
                            )

train(
    model=vl_gpt,
    projection=proj,
    processor=vl_chat_processor,
    train_dataloader=train_dataloader,
    val_dataloader=None,
    train_config=TrainConfig,
)

Epoch 0:   0%|          | 0/2743 [00:00<?, ?it/s]

torch.Size([2, 608, 2048])
<built-in method type of Tensor object at 0x7f4280717fb0>
torch.Size([2, 576])
torch.int64
torch.Size([2, 16384, 608])
torch.Size([2, 608, 2048])
<built-in method type of Tensor object at 0x7f42807dbf60>
torch.Size([2, 576])
torch.int64
torch.Size([2, 16384, 608])
torch.Size([2, 608, 2048])
<built-in method type of Tensor object at 0x7f4280570270>
torch.Size([2, 576])
torch.int64
torch.Size([2, 16384, 608])
torch.Size([2, 608, 2048])
<built-in method type of Tensor object at 0x7f4281057330>
torch.Size([2, 576])
torch.int64
torch.Size([2, 16384, 608])
torch.Size([2, 608, 2048])
<built-in method type of Tensor object at 0x7f4280594810>
torch.Size([2, 576])
torch.int64
torch.Size([2, 16384, 608])
torch.Size([2, 608, 2048])
<built-in method type of Tensor object at 0x7f4281207920>
torch.Size([2, 576])
torch.int64
torch.Size([2, 16384, 608])
torch.Size([2, 608, 2048])
<built-in method type of Tensor object at 0x7f4280d265c0>


KeyboardInterrupt: 

In [25]:
# empty cuda
os.environ["TORCH_USE_CUDA_DSA"] = "1"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

torch.cuda.empty_cache()

import gc

gc.collect()

0