In [3]:
%%writefile multimodel_train.py
import math
from typing import List, Optional, Tuple, Union
import sys
from PIL import Image
from pathlib import Path
import torch
import os
from accelerate import Accelerator
from torch import nn
from transformers import (
    AutoModel,
    AutoTokenizer,
    AutoConfig,
    RobertaModel,
    RobertaConfig,
)
import sys
from sklearn.model_selection import train_test_split
import os
from tqdm.notebook import tqdm
import numpy as np
import torch.nn as nn
from dataclasses import dataclass
from typing import Any, Dict, Generator, List, Optional, Tuple
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from transformers.tokenization_utils_base import AddedToken
from transformers import AutoModel
from accelerate import Accelerator, DistributedDataParallelKwargs
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import warnings

warnings.simplefilter("ignore")
from transformers import ViTFeatureExtractor, ViTModel
from einops import rearrange
class RotaryEmbedding(nn.Module):
    """
    Precompute rotary position angles for attention heads.

    The per-head width is ``hidden_size // num_attention_heads``. Inverse
    frequencies ``1 / 10000 ** (2i / head_dim)`` are kept in the persistent
    buffer ``inv_freq``.

    Args:
        config (object): must provide ``hidden_size`` and ``num_attention_heads``.

    Calling the module with ``seq_len`` returns the angle table
    ``position * inv_freq`` of shape ``(1, seq_len, len(inv_freq))``.
    """

    def __init__(self, config):
        super().__init__()
        head_dim = int(config.hidden_size // config.num_attention_heads)
        exponents = torch.arange(0, head_dim, 2).float() / head_dim
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))

    def forward(self, seq_len):
        inv_freq = self.inv_freq
        positions = torch.arange(seq_len, device=inv_freq.device).type_as(inv_freq)
        # Outer product: one row of angles per position.
        angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0)
        return angles.unsqueeze(0)


# Copied from transformers
def rotate_half(x):
    """Rotate half the hidden dims of the input.

    Splits the last dimension at ``x.shape[-1] // 2`` into a front and a back
    part and returns ``cat(-back, front)`` along that dimension.
    """
    split = x.shape[-1] // 2
    front = x[..., :split]
    back = x[..., split:]
    return torch.cat([back.neg(), front], dim=-1)


def apply_rotary_pos_emb(q, k, freqs, unsqueeze_dim=1) -> Tuple[torch.Tensor, torch.Tensor]:
    """Applies Rotary Position Embedding (RoPE) to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor; must broadcast against the same
            cos/sin tables as ``q``.
        freqs (`torch.Tensor`): Precomputed rotation angles, e.g. the
            ``(1, seq_len, head_dim // 2)`` table returned by ``RotaryEmbedding``.
            The angles are duplicated along the last dimension to ``head_dim``
            before taking cos/sin.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which the cos/sin tables are unsqueezed so they
            broadcast against ``q`` and ``k``. With tables of shape
            ``[batch, seq_len, head_dim]`` and q/k of shape
            ``[batch, heads, seq_len, head_dim]`` use 1; for q/k of shape
            ``[batch, seq_len, heads, head_dim]`` use 2.

    Returns:
        `tuple(torch.Tensor, torch.Tensor)`: the rotated query and key tensors.
        The cos/sin tables are cast to ``q.dtype``.
    """
    emb = torch.cat((freqs, freqs), dim=-1)
    cos = emb.cos().to(dtype=q.dtype)
    sin = emb.sin().to(dtype=q.dtype)
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    q_embed = (q * cos) + (rotate_half(q) * sin)

    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
_ACT_ = {
    "gelu": nn.GELU(),
    "leaky_relu": nn.LeakyReLU(),
    "relu6": nn.ReLU6(),
    "sigmoid": nn.Sigmoid(),
    "silu": nn.SiLU(),
    "swish": nn.SiLU(),
    "tanh": nn.Tanh(),
}


class FeedForward(nn.Module):
    """Position-wise feed-forward sublayer with a post-norm residual.

    Computes ``LayerNorm(dropout(out(act(intermediate(hidden_state)))) + input_tensor)``.

    Args:
        config: must provide ``hidden_size``, ``hidden_dropout_prob`` and
            ``layer_norm_eps``; ``intermediate_size`` and ``hidden_act`` are optional.
        multiplier: inner width as a multiple of ``hidden_size``, used only when
            ``config.intermediate_size`` is missing or None.
    """

    def __init__(self, config, multiplier: Union[int, float] = 4) -> None:
        super().__init__()
        intermediate_size = getattr(config, "intermediate_size", None)
        # Multiply before truncating so fractional multipliers (e.g. 2.5) are honoured.
        self.intermediate_size = (
            int(multiplier * config.hidden_size)
            if intermediate_size is None
            else intermediate_size
        )

        self.intermediate = nn.Linear(config.hidden_size, self.intermediate_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Unknown or missing activation names fall back to GELU.
        act_fn = _ACT_.get(getattr(config, "hidden_act", None), None)
        self.act_fn = act_fn if act_fn else nn.GELU()
        self.out = nn.Linear(self.intermediate_size, config.hidden_size)

    def forward(
        self, hidden_state: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        """Apply the FFN to ``hidden_state`` and add ``input_tensor`` as the residual."""
        output = self.intermediate(hidden_state)
        output = self.act_fn(output)
        output = self.out(output)
        output = self.dropout(output)
        output = self.layernorm(output + input_tensor)
        return output
class AbsoluteEncoding(nn.Module):
    """Learned absolute position embeddings for positions ``0 .. max_position_embeddings - 1``.

    Calling the module with ``size`` returns the embeddings of the first
    ``size`` positions with shape ``(1, size, hidden_size)``.
    """

    def __init__(self, config) -> None:
        super().__init__()
        # No padding_idx: the indices here are plain positions 0..N-1, so reusing
        # the *token* pad id would pin that position's vector to zero and exclude
        # it from gradient updates.
        self.pos_embeddings = nn.Embedding(
            config.max_position_embeddings,
            config.hidden_size,
        )
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)),
            persistent=False,
        )
        self.max_size = config.max_position_embeddings

    def forward(self, size: int) -> torch.Tensor:
        """Return position embeddings for positions ``[0, size)``.

        Raises:
            ValueError: if ``size`` exceeds ``max_position_embeddings``.
        """
        if self.max_size < size:
            raise ValueError(
                f"The sequence length ({size}) is more than the config max_position_embeddings {self.max_size}"
            )
        return self.pos_embeddings(self.position_ids[:, :size])


# Learned position-embedding modules, keyed by DecoderModel's ``pos_embedding_type``
# ("rope" is handled separately inside DecoderModel and is intentionally absent).
_position_embeddings = dict(absolute=AbsoluteEncoding)
class AttentionSelfOutput(nn.Module):
    """Output projection of an attention block followed by a post-norm residual.

    Computes ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``.
    """

    def __init__(self, config, bias: Optional[bool] = True):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width, bias=bias)
        self.layernorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        projected = self.dropout(self.dense(hidden_states))
        return self.layernorm(projected + input_tensor)


class DecoderAttention(nn.Module):
    """Multi-head self-attention with optional RoPE and an optional KV cache.

    Separate query/key/value projections are split into heads, optionally
    rotated with ``apply_rotary_pos_emb``, optionally extended with cached
    keys/values, scored with scaled dot products plus the additive
    ``attention_mask``, and finished by ``AttentionSelfOutput`` (projection,
    dropout, residual with the block input, LayerNorm).
    """

    def __init__(self, config, layer_idx: int) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        self.head_size = int(config.hidden_size // config.num_attention_heads)
        self.attention_bias = getattr(config, "attention_bias", True)
        self.layer_idx = layer_idx
        self.query = nn.Linear(
            config.hidden_size, config.hidden_size, bias=self.attention_bias
        )
        self.key = nn.Linear(
            config.hidden_size, config.hidden_size, bias=self.attention_bias
        )
        self.value = nn.Linear(
            config.hidden_size, config.hidden_size, bias=self.attention_bias
        )
        self.out = AttentionSelfOutput(config=config, bias=self.attention_bias)
        self.num_attention_heads = config.num_attention_heads
        self.scaling = 1 / math.sqrt(self.head_size)

        # NOTE(review): this flag is only reported; forward() below always uses
        # the explicit einsum path, not scaled_dot_product_attention.
        self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention")
        if not self.flash and self.layer_idx == 0:  # avoid printing once per layer
            print("WARNING: Flash Attention requires PyTorch >= 2.0")

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: torch.Tensor,
        freqs: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        start_pos: Optional[int] = 0,
    ) -> torch.Tensor:
        """Attend over ``hidden_state`` (and cached keys/values when ``use_cache``).

        Raises:
            ValueError: if ``use_cache`` is set but no ``cache`` is attached to this layer.
        """
        q = self.query(hidden_state)
        k = self.key(hidden_state)
        v = self.value(hidden_state)
        # (batch, seq, heads * head_size) -> (batch, heads, seq, head_size)
        q = rearrange(q, "b l (h d) -> b h l d", h=self.num_attention_heads)
        k = rearrange(k, "b l (h d) -> b h l d", h=self.num_attention_heads)
        v = rearrange(v, "b l (h d) -> b h l d", h=self.num_attention_heads)

        if freqs is not None:
            q, k = apply_rotary_pos_emb(q, k, freqs)

        if use_cache:
            cache = getattr(self, "cache", None)
            if cache is None:
                raise ValueError(
                    "you need to setup cache for every attention layer with model.setup_cache()"
                )
            k, v = cache.update(k, v, start_pos)

        attn_weights = torch.einsum("b h i d , b h j d -> b h i j", q, k) * self.scaling

        if attention_mask is not None:  # the mask may be longer than the keys; slice it
            causal_mask = attention_mask[:, :, :, : k.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # Softmax in float32 for stability, then cast back so the value product
        # does not mix float32 weights with half/bfloat16 values (einsum does
        # not promote mismatched dtypes).
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(v.dtype)
        out = torch.einsum("b h i j , b h j d -> b h i d", attn_weights, v)

        out = rearrange(out, "b h l d -> b l (h d)")

        return self.out(out, hidden_state)
class DecoderLayer(nn.Module):
    """One post-norm decoder block: self-attention sublayer, then feed-forward sublayer.

    Both sublayers apply residual + LayerNorm internally (see
    AttentionSelfOutput and FeedForward).
    """

    def __init__(self, config, layer_idx: int = 0, attention_type: str = None) -> None:
        super().__init__()
        use_gqa = attention_type == "gqa"
        # NOTE(review): DecoderAttentionGqa is not defined in the visible part of
        # this file; attention_type="gqa" relies on it being provided elsewhere.
        self.attention = (
            DecoderAttentionGqa(config, layer_idx=layer_idx)
            if use_gqa
            else DecoderAttention(config, layer_idx=layer_idx)
        )
        if use_gqa and layer_idx == 0:  # avoid printing once per layer
            print("Decoder Using GQA Attention")
        self.feed_forward = FeedForward(config)
        self.layer_idx = layer_idx

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: torch.Tensor,
        freqs: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        start_pos: Optional[int] = 0,
    ) -> torch.Tensor:
        """Run attention then the FFN and return the block output."""
        attn_out = self.attention(
            hidden_state=hidden_state,
            attention_mask=attention_mask,
            freqs=freqs,
            use_cache=use_cache,
            start_pos=start_pos,
        )
        # The FFN residual is the attention sublayer's output. Passing the raw
        # block input here would drop the attention result from the residual stream.
        return self.feed_forward(attn_out, attn_out)


class LMHead(nn.Module):
    """Projection head: dense -> GELU -> LayerNorm -> vocabulary logits.

    NOTE: an identical ``LMHead`` is defined again later in this module, and that
    later definition is the one bound to the name after import.
    """

    def __init__(self, config) -> None:
        super().__init__()
        hidden = config.hidden_size
        self.dense = nn.Linear(hidden, hidden)
        self.layer_norm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)

        self.vocab = nn.Linear(hidden, config.vocab_size)
        # One bias Parameter shared by the head and its output projection.
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.vocab.bias = self.bias

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        activated = nn.GELU()(self.dense(hidden_state))
        normalized = self.layer_norm(activated)
        return self.vocab(normalized)
# There are two main ways to store the KV cache. One keeps it inside the model, as Llama does
# (https://github.com/meta-llama/llama/blob/main/llama/model.py), so every attention layer has its own kv-cache storage;
# the other keeps the kv-cache of all attention layers in a single shared storage, as Hugging Face Transformers does.


from dataclasses import dataclass
from typing import Any, Dict, Generator, List, Optional, Tuple
import torch


class DynamicCache:
    """
    KV cache that grows by concatenation along the sequence axis.

    Holds one key tensor and one value tensor, i.e. the states of a single
    attention layer. Tensors are expected as
    ``[batch_size, num_heads, seq_len, head_dim]``; new states are appended
    on dim ``-2``.
    """

    def __init__(self, config, is_gqa: Optional[bool] = False) -> None:
        # ``is_gqa`` is accepted for signature parity with StaticCache and is unused.
        self.key_cache: torch.Tensor = None
        self.value_cache: torch.Tensor = None
        self._seen_tokens = False
        self.maxlen = config.max_position_embeddings

    def __len__(self) -> int:
        """Number of cached positions (size of dim -2), or 0 before the first update."""
        return 0 if self.key_cache is None else self.key_cache.shape[-2]

    def update(
        self, key_states: torch.Tensor, value_states: torch.Tensor, start_pos: int = 0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Append ``key_states``/``value_states`` and return the full cached tensors.

        The first call stores clones of the inputs; later calls concatenate on
        dim -2. ``start_pos`` is accepted for interface parity with
        StaticCache and is ignored.
        """
        if self.key_cache is not None:
            self.key_cache = torch.cat([self.key_cache, key_states], dim=-2)
            self.value_cache = torch.cat([self.value_cache, value_states], dim=-2)
            return self.key_cache, self.value_cache

        self._seen_tokens = True
        self.key_cache = key_states.clone()
        self.value_cache = value_states.clone()
        return self.key_cache, self.value_cache

    def get(self) -> Tuple[torch.Tensor]:
        """Return ``(key_cache, value_cache)``; raises ValueError if nothing is cached."""
        if not self._seen_tokens:
            raise ValueError("there is no token available in kv-cache")
        return self.key_cache, self.value_cache

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Return the number of cached positions; ``layer_idx`` is ignored."""
        return len(self)

    def get_max_length(self) -> Optional[int]:
        """A dynamic cache has no maximum length, so this is always None."""
        return None


class StaticCache:
    """
    Fixed-size KV cache with pre-allocated buffers (suitable for torch.compile).

    Keys and values live in two tensors of shape
    ``[1, heads, config.max_position_embeddings, head_size]``. ``heads`` is
    ``config.num_key_value_heads`` when ``is_gqa`` is True, otherwise
    ``config.num_attention_heads``. Incoming states are expected as
    ``[batch_size, heads, seq_len, head_dim]``; only batch size 1 is supported.

    Raises:
        ValueError: if ``is_gqa`` is True but the config has no ``num_key_value_heads``.
    """

    def __init__(self, config, is_gqa: Optional[bool] = False) -> None:
        self.head_size = int(config.hidden_size // config.num_attention_heads)
        self.heads = None
        if is_gqa:
            self.heads = getattr(config, "num_key_value_heads", None)
            if self.heads is None:
                raise ValueError(
                    "you are using is_gqa=True and config.num_key_value_heads is not available"
                )
        if self.heads is None:
            self.heads = config.num_attention_heads

        cache_shape = (1, self.heads, config.max_position_embeddings, self.head_size)
        self.key_cache: torch.Tensor = torch.zeros(cache_shape)
        self.value_cache: torch.Tensor = torch.zeros(cache_shape)
        self._seen_tokens = False

    def update(
        self, k: torch.Tensor, v: torch.Tensor, start_pos: int = 0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Write ``k``/``v`` into positions ``[start_pos, start_pos + seq_len)``.

        Returns the cached keys and values for positions ``[0, start_pos + seq_len)``.

        Raises:
            ValueError: if the write would run past the pre-allocated length.
            AssertionError: if the batch size is not 1.
        """
        bsz, head, seqlen, _ = k.shape
        capacity = self.key_cache.size(2)
        # Bound the *end* of the write: a write that starts late can overflow
        # even when seqlen alone fits. Validate before mutating any state.
        if start_pos + seqlen > capacity:
            raise ValueError(
                f"{tuple(k.shape)} at start_pos={start_pos} does not fit in the k_cache "
                f"of shape {tuple(self.key_cache.shape)}"
            )

        assert bsz == 1, "Only support batch size 1"

        self._seen_tokens = True
        # NOTE(review): despite its name this records the length of the most
        # recent update; get() slices to it.
        self.first_update_len = seqlen

        # Follow the device and dtype of the incoming states.
        self.key_cache = self.key_cache.to(k)
        self.value_cache = self.value_cache.to(v)

        end = start_pos + seqlen
        self.key_cache[:bsz, :, start_pos:end] = k
        self.value_cache[:bsz, :, start_pos:end] = v

        return self.key_cache[:bsz, :, :end], self.value_cache[:bsz, :, :end]

    def get(self) -> Tuple[torch.Tensor]:
        """Return cached keys/values sliced to the first ``first_update_len`` positions.

        Raises:
            ValueError: if ``update`` has never been called.
        """
        if not self._seen_tokens:
            raise ValueError("there is no token available in kv-cache")
        k = self.key_cache[:, :, : self.first_update_len]
        v = self.value_cache[:, :, : self.first_update_len]
        return k, v

    def __len__(self) -> int:
        """Return 0 before the first update; afterwards the pre-allocated length.

        NOTE(review): unlike DynamicCache this is the buffer capacity, not the
        number of positions written so far.
        """
        if not self._seen_tokens:
            return 0
        return self.key_cache.shape[2]
@dataclass
class DecoderOutput:
    """Result container returned by ``VisionLanguageModel.forward``: the LM-head ``logits``."""

    logits: torch.Tensor


class LMHead(nn.Module):
    """Language-modelling head: dense -> GELU -> LayerNorm -> vocabulary logits.

    The output projection's bias is the module-level ``bias`` Parameter, so the
    two names refer to the same tensor.
    """

    def __init__(self, config) -> None:
        super().__init__()
        hidden = config.hidden_size
        self.dense = nn.Linear(hidden, hidden)
        self.layer_norm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)

        self.vocab = nn.Linear(hidden, config.vocab_size)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.vocab.bias = self.bias

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        activated = nn.GELU()(self.dense(hidden_state))
        normalized = self.layer_norm(activated)
        # Project back to the vocabulary size (shared bias included).
        return self.vocab(normalized)


class DecoderModel(nn.Module):
    """Stack of ``DecoderLayer``s applied to pre-computed input embeddings.

    ``word_embeddings`` lives here, but ``forward`` receives embeddings; the
    caller (``VisionLanguageModel``) embeds ``input_ids`` itself. Position
    information comes from one of two sources:

    * a learned module registered in ``_position_embeddings``, which is added
      to the input embeddings; or
    * with ``pos_embedding_type == "rope"``, a precomputed rotary angle table
      that is passed to every layer.

    With any other type, no position information is used.
    """

    def __init__(
        self,
        config,
        pos_embedding_type: Optional[str] = "absolute",
        attention_type: str = None,
    ) -> None:
        super().__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size,
            config.hidden_size,
            padding_idx=getattr(config, "pad_token_id", None),
        )
        position_cls = _position_embeddings.get(pos_embedding_type, None)
        self.position_embeddings = (
            position_cls(config) if position_cls is not None else None
        )
        # Rotary angle table from RotaryEmbedding; stays None unless rope is
        # selected, so forward() can tell which scheme is active.
        self.emb_freq = None
        if pos_embedding_type == "rope":
            self.emb_freq = RotaryEmbedding(config)(config.max_position_embeddings)
            print(
                "Encoder Ignoring sinusoidal or absolute position embeddings because rope,is enable"
            )
        self.all_layer = nn.ModuleList(
            [
                DecoderLayer(config, layer_idx, attention_type=attention_type)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )

    def _init_weights(self, module: nn.Module) -> None:
        """Init Linear/Embedding weights from N(0, (0.02 / sqrt(2 * num_layers))^2)."""
        # math.sqrt: the depth is a Python int, which torch.sqrt does not accept.
        std = 0.02 / math.sqrt(2 * len(self.all_layer))
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: torch.Tensor,
        use_cache: Optional[bool] = False,
        start_pos: Optional[int] = 0,
    ) -> torch.Tensor:
        """Run all decoder layers over ``hidden_state`` (batch, seq_len, hidden).

        ``start_pos`` is the absolute position of the first token in
        ``hidden_state`` (non-zero when decoding with a KV cache).
        """
        _bsz, seqlen, _ = hidden_state.shape
        end_pos = start_pos + seqlen

        freqs = None
        if self.position_embeddings is not None:
            pos_info = self.position_embeddings(end_pos)[:, start_pos:end_pos, :].to(
                hidden_state.device
            )
            hidden_state = hidden_state + pos_info
        elif self.emb_freq is not None:
            freqs = self.emb_freq[:, start_pos:end_pos].to(hidden_state.device)

        for layer in self.all_layer:
            hidden_state = layer(
                hidden_state=hidden_state,
                attention_mask=attention_mask,
                freqs=freqs,
                use_cache=use_cache,
                start_pos=start_pos,
            )
        return hidden_state

    @classmethod
    def from_config(cls, config) -> nn.Module:
        return cls(config)
def _update_causal_mask(
    attention_mask,
    token_type_ids,
    inputs_embeds,
    cache_position,
    is_training: bool = False,
):


    dtype = inputs_embeds.dtype


    min_dtype = torch.finfo(dtype).min


    sequence_length = inputs_embeds.shape[1]


    target_length = attention_mask.shape[-1]

    causal_mask = torch.full(
        (sequence_length, target_length),
        fill_value=min_dtype,
        dtype=dtype,
        device=inputs_embeds.device,
    )


    # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
    if sequence_length != 1:
        if is_training:


            causal_mask = torch.triu(causal_mask, diagonal=1)
        else:


            causal_mask[:, :sequence_length] = 0.0  # when using kv-cache


    causal_mask *= torch.arange(
        target_length, device=inputs_embeds.device
    ) > cache_position.reshape(-1, 1)


    causal_mask = causal_mask[None, None, :, :].expand(
        inputs_embeds.shape[0], 1, -1, -1
    )


    if attention_mask is not None:


        causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit


        mask_length = attention_mask.shape[-1]


        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
            :, None, None, :
        ].to(causal_mask.device)


        padding_mask = padding_mask == 0


        causal_mask[:, :, :, :mask_length] = causal_mask[
            :, :, :, :mask_length
        ].masked_fill(padding_mask, min_dtype)
    return causal_mask
class VisionLanguageModel(nn.Module):
    """Decoder-only language model conditioned on vision-encoder features.

    Image features from ``encoder`` (its ``last_hidden_state``) are written,
    via ``masked_scatter``, into the decoder input embeddings at every position
    whose input id equals ``image_token_index``. The sequence then runs through
    ``DecoderModel`` and ``LMHead``.

    Args:
        encoder: vision model called as ``encoder(pixel_values=...)``.
        decoder_config: config for ``DecoderModel`` and ``LMHead``.
        decoder_pos_embedding_type: ``"absolute"`` or ``"rope"``.
        decoder_attention_type: ``"gqa"`` to request grouped-query attention.
        image_token_index: token id of the image placeholder (defaults to the
            previously hard-coded 128001).
    """

    def __init__(
        self,
        encoder,
        decoder_config,
        decoder_pos_embedding_type: Optional[str] = "absolute",
        decoder_attention_type: str = None,
        image_token_index: int = 128001,
    ) -> None:
        super().__init__()
        self.is_gqa = decoder_attention_type == "gqa"
        self.encoder = encoder
        self.decoder = DecoderModel(
            config=decoder_config,
            pos_embedding_type=decoder_pos_embedding_type,
            attention_type=decoder_attention_type,
        )
        self.lm_head = LMHead(config=decoder_config)
        self.image_token_index = image_token_index

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = False,
        start_pos: Optional[int] = 0,
    ):
        """Compute LM logits for ``input_ids`` (with image features merged in).

        Training mode is inferred as ``use_cache == False`` with
        ``token_type_ids`` supplied. ``start_pos`` is the absolute position of
        the first token; it is used only for text-only steps outside training.
        """
        inputs_embeds = self.decoder.word_embeddings(input_ids)
        is_training = use_cache == False and token_type_ids is not None  # noqa: E712
        seq_len = inputs_embeds.shape[1]

        # Training and image-prefill steps cover positions 0..seq_len-1; a
        # text-only (cached decoding) step continues from start_pos.
        first_pos = 0 if (is_training or pixel_values is not None) else start_pos
        cache_position = torch.arange(
            first_pos, first_pos + seq_len, device=inputs_embeds.device
        )

        if pixel_values is not None:
            image_features = self.get_encoder_output(pixel_values)

            special_image_mask = (input_ids == self.image_token_index).unsqueeze(-1)
            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
                inputs_embeds.device
            )

            image_features = image_features.to(
                inputs_embeds.device, inputs_embeds.dtype
            )
            # Fills the masked positions with image features in order.
            inputs_embeds = inputs_embeds.masked_scatter(
                special_image_mask, image_features
            )

        causal_mask = _update_causal_mask(
            attention_mask, token_type_ids, inputs_embeds, cache_position, is_training
        )

        decoder_output = self.decoder(
            hidden_state=inputs_embeds,
            attention_mask=causal_mask,
            use_cache=use_cache,
            start_pos=start_pos,
        )
        decoder_output = self.lm_head(decoder_output)
        return DecoderOutput(logits=decoder_output)

    def get_decoder(self) -> DecoderModel:
        return self.decoder

    def get_encoder_output(self, pixel_values: torch.Tensor) -> object:
        """Return the vision encoder's ``last_hidden_state`` for ``pixel_values``."""
        return self.encoder(pixel_values=pixel_values).last_hidden_state

    def _setup_cache(self, config, cls: Optional[object] = StaticCache) -> None:
        """Attach a fresh ``cls(config, is_gqa=...)`` cache to every attention layer."""
        for layer in self.decoder.all_layer:
            layer.attention.cache = cls(config, is_gqa=self.is_gqa)

    def _clean_cache(self) -> None:
        """Detach the KV cache from every attention layer."""
        for layer in self.decoder.all_layer:
            layer.attention.cache = None

    @classmethod
    def from_config(
        cls,
        encoder,
        decoder_config,
        decoder_pos_embedding_type: Optional[str] = "absolute",
        decoder_attention_type: str = None,
        image_token_index: int = 128001,
    ) -> nn.Module:
        return cls(
            encoder,
            decoder_config,
            decoder_pos_embedding_type,
            decoder_attention_type,
            image_token_index=image_token_index,
        )

def build_string_from_input(
    prompt,
    bos_token,
    image_seq_len,
    image_token,
    num_images=1,
):
    """Return ``prompt`` followed by ``image_seq_len * num_images`` copies of ``image_token``.

    ``bos_token`` is accepted for call compatibility but is not inserted.
    """
    placeholders = image_token * image_seq_len * num_images
    return "{}{}".format(prompt, placeholders)



def get_model_inputs(prompt, tokenizer, suffix=None, max_length=248, image_seq_length=197):
    """Tokenize ``prompt`` + image placeholders, optionally paired with ``suffix``.

    Args:
        prompt: text placed before the ``"<image>"`` placeholder tokens.
        tokenizer: Hugging Face tokenizer that knows the ``"<image>"`` token.
        suffix: optional second segment (e.g. the target caption). When not
            None it is passed as ``text_pair`` and ``token_type_ids`` are returned.
        max_length: padded/truncated sequence length.
        image_seq_length: number of ``"<image>"`` placeholders. It must match the
            number of image feature vectors scattered into them (the file uses 197).

    Returns:
        The tokenizer output with PyTorch tensors, padded to ``max_length``.
    """
    IMAGE_TOKEN = "<image>"
    input_string = build_string_from_input(
        prompt, tokenizer.bos_token, image_seq_length, IMAGE_TOKEN
    )

    has_suffix = suffix is not None
    # For a pair, the default "longest_first" strategy trims the longer segment,
    # which is the prompt (it carries all image placeholders). That would drop
    # trailing <image> tokens and misalign the image features. Trim only the suffix.
    truncation = "only_second" if has_suffix else True

    return tokenizer(
        input_string,
        text_pair=suffix,
        return_token_type_ids=has_suffix,
        padding="max_length",
        add_special_tokens=False,
        max_length=max_length,
        truncation=truncation,
        return_tensors="pt",
    )

class ImgDataset(Dataset):
    """Image-caption dataset producing tokenized prompt/caption pairs plus pixel values.

    ``df`` must have an ``image`` column (file path) and a ``caption`` column
    (text). Each item is a dict with ``pixel_values`` and every field returned
    by ``get_model_inputs``, each without its leading batch dimension.
    """

    def __init__(self, df, tokenizer, feature_extractor):
        self.df = df
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        caption = self.df.caption.iloc[idx] + self.tokenizer.eos_token
        img_path = self.df.image.iloc[idx]
        # Close the file handle deterministically; convert() returns an
        # independent in-memory image.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        prompt = self.tokenizer.bos_token + "<Caption>"

        inputs = get_model_inputs(prompt, self.tokenizer, caption)
        pixel_values = self.feature_extractor(img, return_tensors="pt").pixel_values

        # Remove only the leading size-1 batch axis added by return_tensors="pt".
        sample = {"pixel_values": pixel_values.squeeze(0)}
        sample.update({key: value.squeeze(0) for key, value in inputs.items()})
        return sample



def loss_fn(logits, labels, attention_mask, config):
    """Next-token cross-entropy loss.

    The logits at position ``t`` are scored against ``labels[..., t + 1]``.
    When ``attention_mask`` is given, its last ``L - 1`` columns select which
    shifted positions contribute. Label value ``-100`` is ignored in either case.
    """
    pred = logits[..., :-1, :]
    target = labels[..., 1:]
    if attention_mask is None:
        pred = pred.contiguous()
        target = target.contiguous()
    else:
        # Align the 2-D mask with the shifted sequence (crop from the left in
        # case it is longer than the logits, e.g. with prefix tuning).
        keep = attention_mask[:, -pred.shape[1]:].to(logits.device) != 0
        pred = pred[keep].contiguous()
        target = target[keep.to(target.device)].contiguous()

    criterion = torch.nn.CrossEntropyLoss(ignore_index=-100)
    return criterion(pred.view(-1, config.vocab_size), target.view(-1))


def prep(batch_size=8):
    """Build the caption DataLoader, the vision-language model and its decoder config.

    Reads Flickr30k captions and COCO-2014 train captions from hard-coded
    ``../input`` paths, shuffles their union, and registers an ``<image>``
    special token on the DeBERTa-v3 tokenizer. It then builds a
    ``VisionLanguageModel``:

    * a ViT encoder;
    * an 8-layer rope decoder;
    * word embeddings shared with the pretrained DeBERTa embedding matrix.

    Args:
        batch_size: DataLoader batch size.

    Returns:
        ``(train_loader, multimodel, config)``, where ``config`` is the decoder config.
    """
    import json

    df = pd.read_csv("../input/flickr30k/captions.txt")
    df["image"] = "../input/flickr30k/Images/" + df["image"]

    BASE_PATH = "../input/coco-image-caption"
    with open(
        f"{BASE_PATH}/annotations_trainval2014/annotations/captions_train2014.json",
        "r",
        encoding="utf-8",
    ) as f:
        data = json.load(f)["annotations"]

    img_cap_pairs = [
        ["%012d.jpg" % sample["image_id"], sample["caption"]] for sample in data
    ]
    df1 = pd.DataFrame(img_cap_pairs, columns=["image", "caption"])
    df1["image"] = df1["image"].apply(
        lambda x: f"{BASE_PATH}/train2014/train2014/COCO_train2014_{x}"
    )
    df = pd.concat([df, df1])

    # Shuffle, drop incomplete rows, and re-index 0..N-1 (concat keeps both
    # sources' overlapping indices; ImgDataset reads rows positionally).
    df = df.sample(frac=1).dropna().reset_index(drop=True)

    feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
    model_ckpt = "microsoft/deberta-v3-base"
    tokenizer = AutoTokenizer.from_pretrained(model_ckpt, truncation_side="right")

    IMAGE_TOKEN = "<image>"
    image_token = AddedToken(IMAGE_TOKEN, normalized=False, special=True)
    tokenizer.add_special_tokens({"additional_special_tokens": [image_token]})
    image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)

    vit_encoder = ViTModel.from_pretrained("google/vit-base-patch16-224")
    model = AutoModel.from_pretrained(model_ckpt)
    model.resize_token_embeddings(len(tokenizer))
    config = model.config
    config.num_hidden_layers = 8
    multimodel = VisionLanguageModel(vit_encoder, config, decoder_pos_embedding_type="rope")
    # Use the id the tokenizer actually assigned to <image> instead of relying on
    # the model's built-in default matching it.
    # NOTE(review): the training loop in main() masks labels with the literal
    # 128001; keep that in sync with image_token_id.
    multimodel.image_token_index = image_token_id
    multimodel.decoder.word_embeddings.weight = model.embeddings.word_embeddings.weight

    train_dataset = ImgDataset(
        df,
        tokenizer=tokenizer,
        feature_extractor=feature_extractor,
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    return train_loader, multimodel, config
    

def main():
    """Train the vision-language model produced by ``prep()`` with Accelerate.

    Builds the data loader, model and model config via ``prep()``, wraps the
    model, an AdamW optimizer and the loader with ``accelerator.prepare`` and
    trains for ``epochs`` epochs with gradient-norm clipping at 1.0.

    The per-step loss and learning rate are shown on the main process's
    progress bar. At the end of each epoch, the loss averaged over every batch
    seen by every process is printed. After the final epoch the main process
    saves the unwrapped model's ``state_dict`` under ``save_dir``.
    """
    # Let DDP tolerate parameters that receive no gradient in a given step.
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

    device = accelerator.device

    if accelerator.is_main_process:
        accelerator.print(f"Using device: {device}")
        accelerator.print(f"Number of processes: {accelerator.num_processes}")
        accelerator.print(f"Distributed type: {accelerator.distributed_type}")

    epochs = 1
    learning_rate = 1e-5
    weight_decay = 0.0

    # Token id masked out of the labels in addition to padding. Presumably the
    # id of the "<image>" special token added in prep() (prep() does not return
    # image_token_id) -- TODO confirm against the tokenizer.
    image_token_id = 128001

    base_dir = "path"
    save_dir = os.path.join(base_dir, "checkpoints")

    # Only the main process touches the filesystem.
    if accelerator.is_main_process:
        os.makedirs(save_dir, exist_ok=True)
        accelerator.print(f"Saving checkpoints to: {save_dir}")

    if accelerator.is_main_process:
        accelerator.print("Initializing Dataloader and Model...")

    train_loader, model, config = prep()

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )

    model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)

    if accelerator.is_main_process:
        accelerator.print("Starting training...")

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0  # sum of this process's per-batch losses

        progress_bar = tqdm(
            train_loader,
            desc=f"Epoch {epoch+1}/{epochs}",
            disable=not accelerator.is_main_process,
            file=sys.stdout,
        )
        for data in progress_bar:
            input_ids = data["input_ids"]
            # Replace image-placeholder and padding positions with -100, which
            # is presumably the ignore index used by loss_fn (TODO confirm).
            labels = input_ids.masked_fill(input_ids == image_token_id, -100)
            labels = labels.masked_fill(input_ids == config.pad_token_id, -100)

            with accelerator.accumulate(model):
                pred = model(**data)
                loss = loss_fn(pred.logits, labels, data["attention_mask"], config)
                accelerator.backward(loss)

                # Clip only when gradients are actually synchronised this step.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)

                optimizer.step()
                optimizer.zero_grad()

            loss_value = loss.detach().float().item()
            epoch_loss += loss_value

            if accelerator.is_main_process:
                progress_bar.set_postfix(
                    {"loss": loss_value, "lr": optimizer.param_groups[0]["lr"]}
                )

        # Sum the per-process loss totals, then divide by the total number of
        # batches processed across all processes (len(train_loader) is the
        # per-process batch count after prepare()).
        gathered_epoch_loss = (
            accelerator.gather(torch.tensor(epoch_loss, device=device)).sum().item()
        )
        avg_loss = gathered_epoch_loss / (len(train_loader) * accelerator.num_processes)

        if accelerator.is_main_process:
            accelerator.print(f"Epoch {epoch+1} : loss = {avg_loss}")

        if accelerator.is_main_process and epoch == epochs - 1:
            unwrapped_model = accelerator.unwrap_model(model)
            torch.save(
                {"model_state_dict": unwrapped_model.state_dict()},
                f"{save_dir}/multimodel_rope_{epoch+1}.pt",
            )
            accelerator.print(f"Checkpoint saved at epoch {epoch+1}")

        # Keep ranks in step so no process races ahead of the checkpoint write.
        accelerator.wait_for_everyone()

    # Called on every process (as in the Accelerate docs) so each rank runs its
    # own end-of-training teardown.
    accelerator.end_training()
    if accelerator.is_main_process:
        print("Training completed!")


if __name__ == "__main__":
    main()


Overwriting multimodel_train.py


In [None]:
! accelerate launch --num_processes=2 ../working/multimodel_train.py

2025-07-13 14:21:07.598200: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-07-13 14:21:07.598199: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1752416467.766157      98 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1752416467.766146      99 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1752416467.816873      98 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
E0000 00:00:1752416467.816883      99 cuda_blas.cc:1