In [None]:
# add janus to os path
import sys
import os
from typing import List, Dict

import accelerate
from accelerate import Accelerator

from torch.nn.functional import cross_entropy

# Make the local Janus checkout importable. The membership guard keeps the
# cell idempotent: re-running it no longer stacks duplicate sys.path entries.
janus_path = os.path.abspath("../Janus/janus")
if janus_path not in sys.path:
    sys.path.append(janus_path)

# CUDA debugging switches, set through the environment for this process.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["TORCH_USE_CUDA_DSA"] = "1"


In [None]:
from abc import ABC

import numpy as np
import torch
from imagebind.models.imagebind_model import ImageBindModel
from imagebind import data
from torch import nn
from PIL import Image


from transformers import AutoModelForCausalLM

from Janus.janus.models import MultiModalityCausalLM, VLChatProcessor
from Janus.janus.models.processing_vlm import VLChatProcessorOutput


# Audio Projection

In [None]:
class AudioProjection(nn.Module):
    """Two-layer MLP that expands one audio embedding into a token sequence.

    Input ``[B, input_dim]`` is projected to a hidden width of
    ``scale_factor * output_dim``, passed through GELU, projected to
    ``sequal_len * output_dim`` and reshaped to ``[B, sequal_len, output_dim]``.
    """

    def __init__(self, input_dim, output_dim, sequal_len=32, scale_factor=2):
        super().__init__()
        self.scale_factor = scale_factor
        self.sequal_len = sequal_len
        self.output_dim = output_dim

        hidden_dim = scale_factor * output_dim
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, sequal_len * output_dim)

    def forward(self, x):
        """Return the projected sequence of shape ``[B, sequal_len, output_dim]``."""
        batch = x.shape[0]
        flat = self.fc2(self.act(self.fc1(x)))  # [B, sequal_len * output_dim]
        return flat.reshape(batch, self.sequal_len, self.output_dim)

In [None]:
class ImprovedAudioProjection(nn.Module):
    """MLP with LayerNorm that turns one audio embedding into a token sequence.

    ``num_layers`` Linear+LayerNorm stages map ``input_dim`` to ``output_dim``
    (hidden stages are ``2 * output_dim`` wide and are followed by the chosen
    activation and dropout). A final linear layer expands the result to
    ``seq_len * output_dim``, which is reshaped to ``[B, seq_len, output_dim]``,
    optionally L2-normalised along the last axis, and multiplied by ``scale_up``.

    ``activation`` may be ``'relu'``, ``'leaky_relu'`` or ``'swish'``; any other
    value selects GELU.
    """

    def __init__(self, input_dim, output_dim, seq_len=32, num_layers=2, dropout=0.1, activation='gelu', use_l2=True,
                 scale_up: float = 1.0):
        super().__init__()
        self.seq_len = seq_len
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.use_l2 = use_l2
        self.scale_up = scale_up

        # Named activations; anything not listed falls back to GELU.
        known_activations = (
            ('relu', nn.ReLU),
            ('leaky_relu', lambda: nn.LeakyReLU(0.2)),
            ('swish', nn.SiLU),
        )
        for name, make_activation in known_activations:
            if activation == name:
                self.activation = make_activation()
                break
        else:
            self.activation = nn.GELU()

        # Stack of Linear -> LayerNorm (-> activation -> Dropout for hidden stages).
        self.projection_layers = nn.ModuleList()
        last_stage = num_layers - 1
        prev_dim = input_dim
        for stage in range(num_layers):
            is_last = stage == last_stage
            width = output_dim if is_last else 2 * output_dim
            self.projection_layers.extend([nn.Linear(prev_dim, width), nn.LayerNorm(width)])
            if not is_last:
                # The same activation module instance is reused by every hidden stage.
                self.projection_layers.extend([self.activation, nn.Dropout(dropout)])
            prev_dim = width

        self.final_reshape = nn.Linear(output_dim, seq_len * output_dim)

    def forward(self, x):
        """Project ``x`` of shape ``[B, input_dim]`` to ``[B, seq_len, output_dim]``."""
        batch = x.shape[0]

        for stage in self.projection_layers:
            x = stage(x)

        tokens = self.final_reshape(x).reshape(batch, self.seq_len, self.output_dim)
        if self.use_l2:
            tokens = F.normalize(tokens, p=2, dim=-1)
        return tokens * self.scale_up


In [None]:

import torch
import torch.nn.functional as F

# Assuming ImprovedAudioProjection is already imported

# Smoke test: build an ImprovedAudioProjection and push one random batch
# through it on the GPU.

# Projection hyper-parameters (single source of truth for the constructor call below)
input_dim = 1024
output_dim = 2048
seq_len = 24
num_layers = 3
dropout = 0.1
activation = 'gelu'

# Initialize the model from the settings above instead of repeating literals,
# so editing a value in this cell actually changes the model that is built.
model = ImprovedAudioProjection(
    input_dim,
    output_dim,
    seq_len,
    num_layers=num_layers,
    dropout=dropout,
    activation=activation,
    use_l2=True,
    scale_up=1,
).cuda()

# if os.path.exists(f"../proj_seq_{seq_len}.pt"):
#     print("Loading model from checkpoint...")
#     model.load_state_dict(torch.load(f"../proj_seq_{seq_len}.pt"))

# Create random input tensor
batch_size = 4
input_tensor = torch.rand(batch_size, input_dim).cuda()

# Pass the input through the model
output = model(input_tensor)

# Print the output shape
print("Output shape:", output.shape)



In [None]:
from torch.utils.tensorboard import SummaryWriter

# Export the projection's graph for TensorBoard. Using the writer as a context
# manager closes it (flushing the event file) even if tracing the model raises.
with SummaryWriter('../runs/model_visualization_experiment') as writer:
    writer.add_graph(model, input_tensor)

In [None]:
# Scratch cell: summary statistics of random tensors.
logits = torch.rand(2, 576, 2048)

# Frobenius norm over the last two axes of each sample, averaged over the batch.
logits_norm = torch.linalg.norm(logits.float(), 'fro', dim=(-2, -1)).mean().item()

audio_input_float = torch.rand(2, 64, 2048).float()

# Scalar mean / std over every element, plus one Frobenius norm per sample.
audio_input_mean = audio_input_float.mean().item()
audio_input_std = audio_input_float.std().item()
audio_input_norm = torch.linalg.norm(audio_input_float, 'fro', dim=(-2, -1))

In [None]:
# Display the element-wise standard deviation computed in the statistics cell.
audio_input_std

In [None]:
# Display the smallest element of `output`.
output.min()

# Training config

In [None]:
class TrainConfig:
    """Static training configuration, read as class attributes.

    Copied from hw-multimodal-llm-solved; values still need to be adapted to
    this project.
    """

    log_level = "DEBUG"

    # Training
    num_epochs = 1
    train_batch_size = 2
    val_batch_size = 1
    log_grad_norm = True
    learning_rate = 1e-4
    gradient_accumulation_steps = 1

    # Run evaluation / checkpointing when `epoch % <value> == 0`.
    evaluate_every_epoch_mod = 4
    save_model_every_epoch_mod = 1
    device = "cuda:0"

    # Model

    # Optional text prompt whose embeddings are prepended to every training
    # sequence; train_loop reads this attribute, None disables the prompt.
    sys_prompt = None

    # Projector
    projector_input_dim = 1024

    # Data
    few_train_samples = None
    few_val_samples = 100
    dataloader_num_workers = 0

    train_dataset_path = ""
    audio_embeds_train_prefix = ""

    val_dataset_path = ""
    audio_embeds_val_prefix = ""


In [None]:
import yaml
import argparse
import pathlib
import random
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import pandas as pd
from torchvision.transforms import ToTensor

import logging
# import evaluate
# import datasets
from transformers.generation import GenerationConfig
from torch.optim import Adam
from torch.utils.data import DataLoader
import transformers
from tqdm.auto import tqdm
# import wandb
# from wandb import sdk as wandb_sdk
# import accelerate


In [None]:
# Load the Janus-Pro chat processor and model, then move the model to the GPU
# in bfloat16 and switch it to eval mode.
model_path = "deepseek-ai/Janus-Pro-1B"

prompt = "Abstract art for representing emotions"

vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

# Token ids of the plain prompt.
input_ids = torch.LongTensor(tokenizer.encode(prompt))

vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
vl_gpt.language_model.config._attn_implementation = 'eager'

vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()

In [None]:
# Build the image-generation prompt with the processor's SFT template, append
# the image start tag, and embed the resulting token ids as [1, T, D] bfloat16.
# NOTE(review): published Janus-Pro examples appear to use "<|User|>" and
# "<|Assistant|>" as role names — confirm "User"/"Assistant" matches the
# template of this checkpoint.
conversation = [
    {"role": "User", "content": "Art for representing emotions"},
    {"role": "Assistant", "content": ""},
]

sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
    conversations=conversation,
    sft_format=vl_chat_processor.sft_format,
    system_prompt="",
)
prompt = sft_format + vl_chat_processor.image_start_tag

prompt_ids = torch.LongTensor(vl_chat_processor.tokenizer.encode(prompt)).cuda()
prompt_embeds = vl_gpt.language_model.get_input_embeddings()(prompt_ids).unsqueeze(0).to(torch.bfloat16)

In [None]:
def generate_sample(music_embedding, batched_prompt_embeds, image_token_num_per_image, processor, model, file_prefix):
    """Sample images conditioned on audio + prompt embeddings and save them as JPEGs.

    The conditional prefix is ``music_embedding`` followed by
    ``batched_prompt_embeds`` along the sequence axis. Each conditional row is
    paired with an unconditional row (pad-token embeddings) and the two sets
    of logits are mixed with a fixed classifier-free-guidance weight before
    sampling one image token per step.

    Args:
        music_embedding: embeddings whose first axis is the number of images to sample.
        batched_prompt_embeds: prompt embeddings concatenated after the audio ones.
        image_token_num_per_image: number of image tokens to sample per image.
        processor: object providing ``pad_id``.
        model: model providing ``language_model``, ``gen_head``,
            ``prepare_gen_img_embeds`` and ``gen_vision_model``.
        file_prefix: file name prefix inside ``generated_samples/``.

    Returns:
        The decoded images as a ``torch.uint8`` tensor.
    """
    cfg_weight = 5
    temperature = 1
    img_size = 384
    patch_size = 16
    grid = img_size // patch_size

    parallel_size = music_embedding.shape[0]
    conditional_embeds = torch.concat([music_embedding, batched_prompt_embeds], dim=1)

    # Unconditional prefix: pad ids everywhere except the first and last position.
    unconditional_tokens = torch.zeros((1, conditional_embeds.shape[-2]), dtype=torch.int).cuda()
    unconditional_tokens[0, 1:-1] = processor.pad_id
    unconditional_embeds = model.language_model.get_input_embeddings()(unconditional_tokens)

    # Interleave rows: even rows are conditional, odd rows unconditional.
    gen_input_embeds = torch.zeros(
        (conditional_embeds.shape[0] * 2, conditional_embeds.shape[1], conditional_embeds.shape[2]),
        dtype=torch.bfloat16).cuda()
    gen_input_embeds[0::2] = conditional_embeds
    gen_input_embeds[1::2] = unconditional_embeds

    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda()
    step_inputs = gen_input_embeds
    past = None  # KV cache carried from one decoding step to the next
    for step in range(image_token_num_per_image):
        outputs = model.language_model.model(inputs_embeds=step_inputs, use_cache=True, past_key_values=past)
        past = outputs.past_key_values

        logits = model.gen_head(outputs.last_hidden_state[:, -1, :])
        logit_cond = logits[0::2, :]
        logit_uncond = logits[1::2, :]

        # Classifier-free guidance, then sample from the tempered distribution.
        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
        probs = torch.softmax(logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        generated_tokens[:, step] = next_token.squeeze(dim=-1)

        # Feed the same sampled token to both the conditional and unconditional row.
        paired_tokens = next_token.repeat(1, 2).view(-1)
        step_inputs = model.prepare_gen_img_embeds(paired_tokens).unsqueeze(dim=1)

    dec = model.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
                                             shape=[parallel_size, 8, grid, grid])
    # Map the decoder output from [-1, 1] to [0, 255].
    dec = torch.clamp((dec.to(torch.float32) + 1) / 2 * 255, min=0, max=255)

    frames = dec.cpu().numpy().transpose(0, 2, 3, 1).astype(np.uint8)

    os.makedirs('generated_samples', exist_ok=True)
    for idx in range(parallel_size):
        target = os.path.join('generated_samples', f"{file_prefix}_{idx}.jpg")
        Image.fromarray(frames[idx]).save(target)

    return dec.to(torch.uint8)

In [None]:
# Smoke-test generate_sample with a random 32-token audio prefix; the images
# are written to disk and the returned tensor is discarded.
music_embeds = torch.rand(1, 32, 2048, dtype=torch.bfloat16).cuda()
with torch.no_grad():
    generate_sample(
        music_embedding=music_embeds,
        batched_prompt_embeds=prompt_embeds,
        image_token_num_per_image=576,
        processor=vl_chat_processor,
        model=vl_gpt,
        file_prefix="image",
    )

# Data loading

In [None]:
# Load the matched image/audio table.
# NOTE(review): pd.read_pickle unpickles arbitrary Python objects, so this file
# must come from a trusted source.
matched_df_path = "../data/notebooks/matched_dataset_concat.pkl"
matched_df = pd.read_pickle(matched_df_path)

In [None]:
class ImageAudioDataset(Dataset):
    """Dataset over a DataFrame of paired images and precomputed embeddings.

    The frame must provide the columns ``image_path``, ``audio_path``,
    ``image_embedding`` and ``music_embedding``. Paths are normalised
    (backslashes to forward slashes, then made absolute) on a copy, so the
    DataFrame passed in by the caller is left untouched.

    Args:
        dataframe: table with the columns listed above.
        transform: optional callable applied to the loaded RGB PIL image.
            When omitted, the image is converted to a tensor and resized to
            384x384, which yields a leading batch axis of size 1.
            NOTE(review): the transform path does not add that batch axis
            itself — confirm custom transforms return a ``[1, C, H, W]``
            tensor if the items are later concatenated along dim 0.
    """

    def __init__(self, dataframe, transform=None):
        def _normalize(path):
            return os.path.abspath(path.replace("\\", "/"))

        # `assign` returns a new frame, so the caller's DataFrame is not mutated.
        self.dataframe = dataframe.assign(
            image_path=dataframe["image_path"].apply(_normalize),
            audio_path=dataframe["audio_path"].apply(_normalize),
        )
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx) -> dict:
        """Return the image tensor, both embeddings and the audio path of row ``idx``."""
        row = self.dataframe.iloc[idx]  # single positional lookup for all fields

        image = Image.open(row["image_path"]).convert("RGB")
        if self.transform:
            image = self.transform(image)
        else:
            image = ToTensor()(image)
            # resize to 384x384 (interpolate needs a batch axis, which is kept)
            image = nn.functional.interpolate(image.unsqueeze(0), size=(384, 384), mode='bilinear', align_corners=False)

        return {
            "audio_path": row["audio_path"],
            "image": image,
            "music_embedding": row["music_embedding"],
            "image_embedding": row["image_embedding"],
        }

In [None]:

# Wrap a 10-row sample of the table (fixed random_state, so the same rows are
# drawn on every run) in the dataset class.
dataset = ImageAudioDataset(matched_df.sample(n=10, random_state=42))

In [None]:
# Fetch the first item and display the shapes of its image and embeddings.
res = dataset[0]
res["image"].shape, res["music_embedding"].shape, res["image_embedding"].shape

In [None]:
from transformers.models.bart.modeling_bart import shift_tokens_right  # similar utility
from einops import rearrange


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int) -> torch.LongTensor:
    """Return ``input_ids`` shifted one position to the right, as int64.

    Column 0 of the result holds ``decoder_start_token_id`` and column ``j``
    (``j >= 1``) holds ``input_ids[:, j - 1]``; the last input column is
    dropped. This is the usual way to build teacher-forcing inputs.

    Raises:
        TypeError: if ``input_ids`` is not a tensor or either id is not an int.
    """
    type_checks = (
        (input_ids, torch.Tensor, "input_ids should be a torch.Tensor"),
        (pad_token_id, int, "pad_token_id should be an int"),
        (decoder_start_token_id, int, "decoder_start_token_id should be an int"),
    )
    for value, expected_type, message in type_checks:
        if not isinstance(value, expected_type):
            raise TypeError(message)

    shifted = torch.full_like(input_ids, pad_token_id)
    shifted[:, 0] = decoder_start_token_id
    shifted[:, 1:] = input_ids[:, :-1].clone()

    return shifted.long()



def get_collate_fn(gen_model: MultiModalityCausalLM):
    """Build a DataLoader ``collate_fn`` that tokenises images with ``gen_model``.

    The returned function concatenates the per-item image tensors along dim 0,
    encodes them with ``gen_model.gen_vision_model`` into discrete code ids,
    and returns a dict with:

    * ``image_ids``: the unshifted code ids ``[B, Hq * Wq]`` (int64), i.e. the
      prediction targets;
    * ``image_gen_embeds``: embeddings of the ids shifted right by one (start
      id 1), i.e. the teacher-forcing inputs. Slot ``t`` of the inputs holds
      code ``t - 1``, so the logit at slot ``t`` is scored against code ``t``;
    * ``music_embedding``: the stacked per-item audio embeddings;
    * ``images``: the concatenated image batch.

    Previously the shifted ids were returned as ``image_ids``, which made the
    target at every slot identical to that slot's own input token.
    """

    def collate_fn(items):
        images = torch.cat([item["image"] for item in items], dim=0)

        with torch.no_grad():
            # NOTE(review): the decoder output elsewhere in this notebook is
            # mapped from [-1, 1] to [0, 255]; confirm whether the encoder
            # expects inputs in [-1, 1] rather than the [0, 1] range of ToTensor.
            quant, _, info = gen_model.gen_vision_model.encode(
                images.to(dtype=torch.bfloat16).cuda())
            batch_size, _, Hq, Wq = quant.shape
            _, _, min_encoding_indices = info

            target_ids = min_encoding_indices.view(batch_size, Hq * Wq).to(torch.long)
            shifted_ids = shift_tokens_right(target_ids, 0, 1).cuda()
            gen_embeds = gen_model.prepare_gen_img_embeds(shifted_ids)

        return {
            "image_ids": target_ids,
            "image_gen_embeds": gen_embeds,
            "music_embedding": torch.stack(
                [torch.from_numpy(item["music_embedding"]) for item in items], dim=0),
            "images": images,
        }

    return collate_fn

### Dataloader example

In [None]:
# Iterate the dataloader once per epoch and inspect the first batch only.
for epoch in range(TrainConfig.num_epochs):
    dataloader = DataLoader(dataset, batch_size=TrainConfig.train_batch_size, shuffle=True,
                            num_workers=TrainConfig.dataloader_num_workers,
                            collate_fn=get_collate_fn(vl_gpt)
                            )
    for batch_idx, batch in enumerate(dataloader):
        # Example: Access data from the batch
        print(batch["images"].shape)
        print(batch["image_gen_embeds"].shape)
        print(batch["image_ids"].shape)
        print(batch["music_embedding"].shape)

        # Report the real batch size (the previous `len([])` always printed 0).
        print(f"Epoch {epoch + 1}, Batch {batch_idx + 1}: Processed {len(batch['images'])} samples")
        break


In [None]:
# Dtype probe; as a non-final expression its value is not rendered by the notebook.
batch["image_ids"].dtype
# Shift the batch ids right by one (pad id 0, start id 0) and embed the result.
s_image_ids = shift_tokens_right(batch["image_ids"], 0, 0)
gen_embeds = vl_gpt.prepare_gen_img_embeds(s_image_ids)

In [None]:
# Display the largest image token id in the batch.
batch["image_ids"].max()

In [None]:
# Display the generation-head section of the model config.
vl_gpt.config.gen_head_config

In [None]:
# Display the largest value among the batch's image input embeddings.
batch["image_gen_embeds"].max()

# Training Loop

In [None]:
def train_loop(accelerator, model, processor, projection, optimizer, train_dataloader, epoch, criterion, train_config,
               metric_logger=None,
               mock_run: bool = False):
    """Run one teacher-forced training epoch for the audio projection.

    Each step feeds ``[prompt?, projected audio, image input embeddings]`` to
    the language model backbone, applies the generation head and scores the
    logits of the trailing image slots against ``batch["image_ids"]``.

    Args:
        accelerator: ``accelerate.Accelerator`` used for accumulation and backward.
        model: multimodal model; put in eval mode, only used for the forward pass.
        processor: provides ``tokenizer`` for the optional system prompt.
        projection: module being trained; maps audio embeddings to token embeddings.
        optimizer: optimizer over the projection's parameters.
        train_dataloader: yields dicts with ``music_embedding``, ``image_ids``
            and ``image_gen_embeds``.
        epoch: epoch index, used for the progress bar only.
        criterion: loss taking ``[B, vocab, T]`` logits and ``[B, T]`` targets.
        train_config: config object; an optional ``sys_prompt`` attribute is
            embedded and prepended to every sequence.
        metric_logger: accepted for interface compatibility; not used here.
        mock_run: if True, random hidden states replace the backbone forward.

    Returns:
        Mean loss over the batches, or 0.0 for an empty dataloader.
    """
    model.eval()
    projection.train()
    progress_bar = tqdm(range(len(train_dataloader)), desc=f"Epoch {epoch}")
    total_loss = 0.0
    device = model.device

    # The config class may not define `sys_prompt`; treat that as "no prompt".
    prompt_embeds = None
    sys_prompt = getattr(train_config, "sys_prompt", None)
    if sys_prompt is not None:
        prompt_ids = torch.LongTensor(processor.tokenizer.encode(sys_prompt)).to(device)
        prompt_embeds = model.language_model.get_input_embeddings()(prompt_ids).to(torch.bfloat16).unsqueeze(0)

    for batch in train_dataloader:
        with accelerator.accumulate(projection):
            music_embedding = batch["music_embedding"].to(device)
            # Batch size comes from the input; `audio_input` does not exist yet at this point.
            B = music_embedding.shape[0]
            audio_input = projection(music_embedding).to(torch.bfloat16)

            image_ids = batch["image_ids"].to(device).long()
            image_gen_embeds = batch["image_gen_embeds"].to(device).to(torch.bfloat16)

            segments = [audio_input, image_gen_embeds]
            if prompt_embeds is not None:
                segments.insert(0, prompt_embeds.repeat(B, 1, 1))
            input_embeds = torch.concat(segments, dim=1)

            if mock_run:
                hidden_states = torch.rand(input_embeds.shape, device=device).to(torch.bfloat16)
            else:
                # No attention mask is passed: every position holds a real token, so
                # the default causal attention applies, as in the generation code.
                # `decoder_input_ids` is not an argument of this decoder-only backbone.
                outputs = model.language_model.model(inputs_embeds=input_embeds, use_cache=False,
                                                     past_key_values=None)
                hidden_states = outputs.last_hidden_state

            logits = model.gen_head(hidden_states).permute(0, 2, 1)  # [B, vocab, T]
            num_image_tokens = image_ids.shape[1]
            loss = criterion(logits[:, :, -num_image_tokens:].float(), image_ids)
            total_loss += loss.item()

            accelerator.backward(loss)
            optimizer.step()
            # Clear the projection's gradients; `model.zero_grad()` only touched
            # the backbone and let projection gradients pile up across steps.
            optimizer.zero_grad()
            progress_bar.update(1)
            progress_bar.set_description(f'Epoch={epoch} Loss={loss.item():.3f}')

    num_batches = len(train_dataloader)
    return total_loss / num_batches if num_batches else 0.0

In [None]:
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int) -> torch.Tensor:
    """
    Shifts input_ids one token to the right, prepending the decoder_start_token_id.
    This is a common way to create decoder_input_ids for teacher-forcing.

    The result is always int64 so it can be used directly as embedding indices
    or loss targets, matching the earlier definition of this helper that this
    one replaces when the cell is run.

    Raises:
        TypeError: if ``input_ids`` is not a tensor or either id is not an int.
    """
    if not isinstance(input_ids, torch.Tensor):
        raise TypeError("input_ids should be a torch.Tensor")
    if not isinstance(pad_token_id, int):
        raise TypeError("pad_token_id should be an int")
    if not isinstance(decoder_start_token_id, int):
        raise TypeError("decoder_start_token_id should be an int")

    shifted_input_ids = torch.full_like(input_ids, pad_token_id)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id

    # Keep the dtype contract of the first definition (no-op for int64 input).
    return shifted_input_ids.to(torch.long)

In [None]:
# Scratch experiment: run the backbone on [prompt, audio, image] embeddings
# once without and once with an attention mask that hides the first image slot,
# so the two sets of hidden states can be compared in the following cells.
B = batch["music_embedding"].shape[0]
audio_input = torch.rand(B, 64, 2048).to(torch.bfloat16).cuda()
prompt_embeds = torch.rand(1, 64, 2048).to(torch.bfloat16).cuda()

music_embedding = batch["music_embedding"].to(vl_gpt.device)
audio_attention_mask = torch.ones(B, audio_input.shape[1], dtype=torch.long)

image_ids = batch["image_ids"]
image_gen_embeds = batch["image_gen_embeds"].to(torch.bfloat16)
image_attention_mask = torch.ones(B, image_gen_embeds.shape[1], dtype=torch.long)
image_attention_mask[:, 0] = 0  # hide the first image slot

batched_prompt_embeds = prompt_embeds.repeat(B, 1, 1)
prompt_attention_mask = torch.ones(B, batched_prompt_embeds.shape[1], dtype=torch.long)
input_embeds = torch.concat([batched_prompt_embeds, audio_input, image_gen_embeds], dim=1)

# The masks are built on the CPU; move the combined mask next to the embeddings.
attention_mask = torch.concat([prompt_attention_mask, audio_attention_mask, image_attention_mask], dim=1)
attention_mask = attention_mask.to(input_embeds.device)

# Inference only, so skip autograd bookkeeping for both forward passes.
with torch.no_grad():
    unmasked_outputs = vl_gpt.language_model.model(inputs_embeds=input_embeds, use_cache=False, past_key_values=None)
    # Masked counterpart: previously the mask was built but never used, and
    # `outputs` was left undefined for the comparison cells.
    outputs = vl_gpt.language_model.model(inputs_embeds=input_embeds, attention_mask=attention_mask,
                                          use_cache=False, past_key_values=None)

In [None]:
# Display the shape of the last hidden state held in `outputs`.
outputs.last_hidden_state.shape

In [None]:
# Norm of the element-wise difference between the hidden states of
# `unmasked_outputs` and `outputs`; the last line displays it.
difference = torch.norm(unmasked_outputs.last_hidden_state - outputs.last_hidden_state)
difference

In [None]:
# Scratch experiment: forward 50 random embeddings with an all-zero attention
# mask and display the resulting hidden states.
im = torch.rand([1, 50, 2048]).cuda().to(torch.bfloat16)

# `use_cache` was misspelled as `suse_cache`, which is not a valid argument.
outputs = vl_gpt.language_model.model(inputs_embeds=im, attention_mask=torch.zeros(1, 50, dtype=torch.long).cuda(),
                                      use_cache=False, past_key_values=None)
outputs.last_hidden_state.detach()
# 0.4258,  0.1797,  1.1797,  ...,  0.3301, -0.4980,  0.1426

# Val loop

In [None]:
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore

In [None]:
@torch.no_grad()
def val_loop(model, processor, projection, val_dataloader, metrics: dict = None, epoch=1, no_loss=False,
             generate_freq=0):
    """Validate the projection: teacher-forced loss plus FID / Inception Score.

    For every batch the function (1) optionally computes the cross-entropy of
    the generation head on ``[projected audio, image input embeddings]`` and
    (2) samples one image per item with classifier-free guidance, then updates
    the FID and Inception Score metrics with the real and generated images.

    Args:
        model: multimodal model (language model, generation head, vision decoder).
        processor: provides ``pad_id``.
        projection: audio projection module under evaluation.
        val_dataloader: yields dicts with ``image_ids``, ``music_embedding``,
            ``image_gen_embeds`` and ``images``.
        metrics: optional dict with ``"fid"`` and ``"inception_score"`` metric
            objects; fresh ones are created when omitted.
        epoch: accepted for interface compatibility; not used.
        no_loss: skip the teacher-forced loss when True.
        generate_freq: accepted for interface compatibility; images are
            currently sampled for every batch.

    Returns:
        dict with ``loss``, ``num_batches``, ``imagebind_sim`` (placeholder 0),
        ``fid``, ``inception_score_mean`` and ``inception_score_std``; the
        metric entries are 0 if the metric computation raises RuntimeError.
    """
    criterion = nn.CrossEntropyLoss(ignore_index=processor.pad_id)
    device = model.device

    if not metrics:
        fid = FrechetInceptionDistance(feature=2048).to(device)
        inception_score = InceptionScore(feature='logits_unbiased', splits=10).to(device)
    else:
        fid = metrics['fid']
        inception_score = metrics['inception_score']

    # hardcoded values
    cfg_weight = 5
    temperature = 1
    img_size = 384
    patch_size = 16
    grid = img_size // patch_size

    model.eval()
    projection.eval()

    sumloss = 0.0
    num_batches = 0
    for batch in tqdm(val_dataloader):
        batch_input_ids = batch['image_ids'].to(device)
        music_embedding = projection(batch["music_embedding"].to(device)).to(torch.bfloat16)
        image_gen_embeds = batch["image_gen_embeds"].to(device).to(torch.bfloat16)
        num_image_tokens = batch_input_ids.shape[-1]

        input_embeds = torch.concat([music_embedding, image_gen_embeds], dim=1)

        if not no_loss:
            # Score real hidden states; the loss used to be computed on random
            # tensors with a hard-coded (1, 608, 2048) shape.
            outputs = model.language_model.model(inputs_embeds=input_embeds, use_cache=False,
                                                 past_key_values=None)
            logits = model.gen_head(outputs.last_hidden_state).permute(0, 2, 1)  # [B, vocab, T]
            loss = criterion(logits[:, :, -num_image_tokens:].float().cpu(), batch_input_ids.long().cpu())
            sumloss += loss.item()

        num_batches += 1

        # generate images and metrics
        # NOTE(review): the generation prefix is `input_embeds`, which also
        # contains the ground-truth image embeddings of this batch — confirm
        # the samples should be conditioned on them and not on the audio only.
        parallel_size = music_embedding.shape[0]
        unconditional_tokens = torch.zeros((1, input_embeds.shape[-2]), dtype=torch.int, device=device)
        unconditional_tokens[0, 1:-1] = processor.pad_id
        unconditional_embeds = model.language_model.get_input_embeddings()(unconditional_tokens)

        # Even rows are conditional, odd rows unconditional.
        gen_input_embeds = torch.zeros((input_embeds.shape[0] * 2, *input_embeds.shape[1:]),
                                       dtype=torch.bfloat16, device=device)
        gen_input_embeds[0::2] = input_embeds
        gen_input_embeds[1::2] = unconditional_embeds

        generated_tokens = torch.zeros((parallel_size, num_image_tokens), dtype=torch.int, device=device)
        step_inputs = gen_input_embeds
        past = None  # KV cache carried between decoding steps
        for step in range(num_image_tokens):
            outputs = model.language_model.model(inputs_embeds=step_inputs, use_cache=True,
                                                 past_key_values=past)
            past = outputs.past_key_values
            logits = model.gen_head(outputs.last_hidden_state[:, -1, :])
            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]

            # Classifier-free guidance, then sample one token per image.
            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
            probs = torch.softmax(logits / temperature, dim=-1)

            next_token = torch.multinomial(probs, num_samples=1)
            generated_tokens[:, step] = next_token.squeeze(dim=-1)

            # The conditional and unconditional row of a pair receive the same token.
            next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
            step_inputs = model.prepare_gen_img_embeds(next_token).unsqueeze(dim=1)

        dec = model.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
                                                 shape=[parallel_size, 8, grid, grid])
        # Map the decoder output from [-1, 1] to [0, 255].
        dec = torch.clamp((dec.to(torch.float32) + 1) / 2 * 255, min=0, max=255)

        visual_img = dec.to(torch.uint8)
        target_images = batch["images"].to(device)

        if target_images.dtype != torch.uint8:
            target_images = (target_images * 255.0).to(torch.uint8)
        if target_images.shape[1] != 3:
            target_images = target_images.permute(0, 3, 1, 2)

        fid.update(target_images, real=True)
        fid.update(visual_img, real=False)
        inception_score.update(visual_img)

    val_res = {
        "loss": sumloss / num_batches if num_batches > 0 else 0,
        "num_batches": num_batches,
        # mock values
        "imagebind_sim": 0,
    }
    try:
        # Use the local metric objects: `metrics` may be None, in which case
        # indexing it raised a TypeError that the handler below did not catch.
        val_res["fid"] = fid.compute()
        val_res["inception_score_mean"], val_res["inception_score_std"] = inception_score.compute()
    except RuntimeError:
        # Best effort: report zeros when the metrics cannot be computed.
        val_res["fid"] = 0
        val_res["inception_score_mean"], val_res["inception_score_std"] = 0, 0

    return val_res

In [None]:
# Run one validation pass on the sample dataset with a freshly initialised
# 2-token projection; the returned metrics dict is the cell's output.
metrics = dict(
    fid=FrechetInceptionDistance(feature=768).cuda(),
    inception_score=InceptionScore(feature='logits_unbiased', splits=10).cuda(),
)

proj = AudioProjection(1024, 2048, sequal_len=2, scale_factor=2).cuda()
val_dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=TrainConfig.dataloader_num_workers,
    collate_fn=get_collate_fn(vl_gpt),
)
val_loop(vl_gpt, vl_chat_processor, proj, val_dataloader, metrics=metrics)

In [None]:
# get samples in fid


In [None]:
def train(
        model: MultiModalityCausalLM,
        projection: AudioProjection,
        processor: VLChatProcessor,
        train_dataloader: DataLoader,
        val_dataloader: DataLoader,
        train_config: TrainConfig,
        device_placement=True,
        ):
    """Drive the epoch loop for the audio projection and track the best FID.

    Sets up the optimizer over the projection's parameters, wraps everything
    with ``accelerate`` and evaluates with ``val_loop`` on every epoch where
    ``epoch % train_config.evaluate_every_epoch_mod == 0``. The per-epoch
    training step is currently commented out, so only evaluation runs.

    Args:
        model: multimodal model used for the forward passes.
        projection: audio projection whose parameters are optimised.
        processor: chat processor handed to ``val_loop``.
        train_dataloader: training batches.
        val_dataloader: validation batches.
        train_config: provides ``learning_rate``, ``num_epochs``,
            ``gradient_accumulation_steps`` and ``evaluate_every_epoch_mod``.
        device_placement: forwarded to ``accelerate.Accelerator``.

    Returns:
        The lowest FID seen, or None if no evaluation ran.
    """
    # None marks "no evaluation yet"; 0 is not used as a sentinel any more.
    best_fid = None

    trainable_parameters = list(projection.parameters())
    optimizer = Adam(trainable_parameters, lr=train_config.learning_rate)
    criterion = nn.CrossEntropyLoss()  # used by the (currently disabled) training step

    # Gradient accumulation is configured at construction time.
    accelerator = accelerate.Accelerator(
        device_placement=device_placement,
        gradient_accumulation_steps=train_config.gradient_accumulation_steps,
    )

    model, projection, optimizer, train_dataloader, val_dataloader = accelerator.prepare(model, projection, optimizer,
                                                                                         train_dataloader,
                                                                                         val_dataloader)

    metrics = {
        "fid": FrechetInceptionDistance(feature=2048).to(model.device),
        "inception_score": InceptionScore(feature='logits_unbiased', splits=10).to(model.device)
    }

    for epoch in range(train_config.num_epochs):
        # train_loop(accelerator, model, projection, optimizer, train_dataloader, epoch=epoch, criterion=criterion,
        #            train_config=train_config)

        if epoch % train_config.evaluate_every_epoch_mod == 0:
            print("Evaluating model for epoch: ", epoch)
            # Start every evaluation from empty metric state; otherwise images
            # generated in earlier epochs keep contributing to later scores.
            for metric in metrics.values():
                metric.reset()

            validation_metrics = val_loop(model, processor, projection, val_dataloader, epoch=epoch, metrics=metrics,
                                          generate_freq=1)

            final_fid_score = metrics["fid"].compute()
            final_is_mean, final_is_std = metrics["inception_score"].compute()
            validation_metrics["fid"] = final_fid_score
            validation_metrics["is_mean"] = final_is_mean
            validation_metrics["is_std"] = final_is_std

            print(f"Epoch {epoch} validation metrics: {validation_metrics}")
            if best_fid is None or final_fid_score < best_fid:
                best_fid = final_fid_score
                print("New best fid: ", best_fid)

    return best_fid


def freeze_model(model):
    """Turn off gradient tracking for every parameter of ``model`` (in place)."""
    for param in model.parameters():
        param.requires_grad_(False)


In [None]:
# Fresh projection (default sequence length) and a shuffled training
# dataloader over the sample dataset.
proj = AudioProjection(1024, 2048, scale_factor=2)
train_dataloader = DataLoader(
    dataset,
    batch_size=TrainConfig.train_batch_size,
    shuffle=True,
    num_workers=TrainConfig.dataloader_num_workers,
    collate_fn=get_collate_fn(vl_gpt),
)
#
# train(
#     model=vl_gpt,
#     projection=proj,
#     processor=vl_chat_processor,
#     train_dataloader=train_dataloader,
#     val_dataloader=None,
#     train_config=TrainConfig,
# )

In [None]:
# Wrap model, projection, optimizer and dataloaders with accelerate, then
# write the accelerator state to a location named after the current `epoch`.
accelerator = accelerate.Accelerator(device_placement=True)
trainable_parameters = list(proj.parameters())
optimizer = Adam(trainable_parameters, lr=0.0001)
model, projection, optimizer, train_dataloader, val_dataloader = accelerator.prepare(
    vl_gpt,
    proj,
    optimizer,
    train_dataloader,
    val_dataloader,
)
accelerator.save_state(f"model_epoch_{epoch}.pt")

In [None]:
# Free GPU memory between experiments.
import gc

os.environ["TORCH_USE_CUDA_DSA"] = "1"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Collect unreachable Python objects first so the tensors they hold are
# released, then hand the now-unused cached blocks back to the driver.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Scratch check: concatenating a single [1, 3, 384, 384] tensor along dim 0
# keeps its shape.
ra = torch.rand(1, 3, 384, 384)
images = torch.cat((ra,), dim=0)
images.shape

In [None]:
class ConfigurableAudioProjection(nn.Module):
    """Configurable MLP that maps one audio vector to a sequence of embeddings.

    Two layouts are supported:

    * ``mlp_outputs_sequence=False``: the MLP ends at ``embedding_dim`` and a
      separate linear layer expands to ``sequence_length * embedding_dim``.
    * ``mlp_outputs_sequence=True``: the MLP's last layer emits the flattened
      sequence directly; hidden layers are
      ``int(max(input_dim, sequence_length * embedding_dim) * mlp_expansion_factor)`` wide.

    Every MLP stage is Linear -> LayerNorm, and every stage except the last is
    followed by the activation and dropout. The output is reshaped to
    ``[B, sequence_length, embedding_dim]``, optionally L2-normalised along the
    last axis and multiplied by ``scale_param``.

    Args:
        input_dim: size of the input audio feature vector.
        embedding_dim: size of each output embedding.
        sequence_length: number of output embeddings per input.
        mlp_layers: number of Linear+LayerNorm stages.
        mlp_expansion_factor: width multiplier of hidden stages.
        dropout: dropout probability after hidden stages.
        activation_fn: ``'relu'``, ``'leaky_relu'``, ``'swish'``/``'silu'``
            (case-insensitive); any other value selects GELU.
        use_final_l2_norm: L2-normalise each output embedding.
        use_learnable_scale: make ``scale_param`` an ``nn.Parameter``.
        initial_scale: initial (or fixed) scale. When not learnable it is
            stored as a buffer, or as the plain float ``1.0`` if it equals 1.
        mlp_outputs_sequence: selects the layout described above.
    """

    def __init__(self,
                 input_dim: int,
                 embedding_dim: int,
                 sequence_length: int,
                 mlp_layers: int = 2,
                 mlp_expansion_factor: int = 2,
                 dropout: float = 0.1,
                 activation_fn: str = 'gelu',
                 use_final_l2_norm: bool = True,
                 use_learnable_scale: bool = True,
                 initial_scale: float = 1.0,
                 mlp_outputs_sequence: bool = False
                 ):
        super().__init__()
        self.sequence_length = sequence_length
        self.embedding_dim = embedding_dim
        self.use_final_l2_norm = use_final_l2_norm
        self.use_learnable_scale = use_learnable_scale
        self.mlp_outputs_sequence = mlp_outputs_sequence

        # Case-insensitive activation lookup; unknown names fall back to GELU.
        activation_factories = {
            'relu': nn.ReLU,
            'leaky_relu': lambda: nn.LeakyReLU(0.2),
            'swish': nn.SiLU,
            'silu': nn.SiLU,
        }
        self.activation = activation_factories.get(activation_fn.lower(), nn.GELU)()

        flat_dim = sequence_length * embedding_dim

        self.projection_mlp = nn.ModuleList()
        current_dim = input_dim
        for layer_idx in range(mlp_layers):
            is_last = layer_idx == mlp_layers - 1
            if self.mlp_outputs_sequence:
                # The last stage emits the flattened sequence; hidden stages scale
                # the larger of the input width and the flattened width.
                width = flat_dim if is_last else int(max(input_dim, flat_dim) * mlp_expansion_factor)
            else:
                width = embedding_dim if is_last else embedding_dim * mlp_expansion_factor

            self.projection_mlp.append(nn.Linear(current_dim, width))
            self.projection_mlp.append(nn.LayerNorm(width))
            if not is_last:  # no activation/dropout after the final stage
                self.projection_mlp.append(self.activation)
                self.projection_mlp.append(nn.Dropout(dropout))
            current_dim = width

        if self.mlp_outputs_sequence:
            self.sequence_expansion_layer = None  # the MLP already outputs the flat sequence
        else:
            self.sequence_expansion_layer = nn.Linear(current_dim, flat_dim)

        # Scale: learnable parameter, fixed buffer, or the plain float 1.0.
        scale_value = float(initial_scale)
        if self.use_learnable_scale:
            self.scale_param = nn.Parameter(torch.tensor(scale_value))
        elif scale_value != 1.0:
            self.register_buffer('scale_param', torch.tensor(scale_value))
        else:
            self.scale_param = 1.0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` of shape ``[B, input_dim]`` to ``[B, sequence_length, embedding_dim]``.

        Raises:
            RuntimeError: if the projected tensor cannot be viewed in the target shape.
        """
        B = x.shape[0]

        for stage in self.projection_mlp:
            x = stage(x)

        expansion = self.sequence_expansion_layer
        if expansion is not None and not self.mlp_outputs_sequence:
            x = expansion(x)  # [B, sequence_length * embedding_dim]

        try:
            x = x.view(B, self.sequence_length, self.embedding_dim)
        except RuntimeError as e:
            raise RuntimeError(
                f"Failed to reshape tensor of shape {x.shape} to ({B}, {self.sequence_length}, {self.embedding_dim}). Original error: {e}")

        if self.use_final_l2_norm:
            x = F.normalize(x, p=2, dim=-1)

        # Multiplying by the float 1.0 is a harmless no-op.
        return x * self.scale_param