In [6]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
#lmmodel = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

In [7]:
tokenizer

PreTrainedTokenizerFast(name_or_path='meta-llama/Llama-3.2-1B', vocab_size=128000, model_max_length=131072, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|begin_of_text|>', 'eos_token': '<|end_of_text|>'}, clean_up_tokenization_spaces=True, added_tokens_decoder={
	128000: AddedToken("<|begin_of_text|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128001: AddedToken("<|end_of_text|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128002: AddedToken("<|reserved_special_token_0|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128003: AddedToken("<|reserved_special_token_1|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128004: AddedToken("<|finetune_right_pad_id|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128005: AddedToken("<|reserved_special_token_2|>", rstri

In [5]:
# Load model directly
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

processor = AutoProcessor.from_pretrained("openai/whisper-large-v2")
asrmodel = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-large-v2")


In [6]:
asrmodel

WhisperForConditionalGeneration(
  (model): WhisperModel(
    (encoder): WhisperEncoder(
      (conv1): Conv1d(80, 1280, kernel_size=(3,), stride=(1,), padding=(1,))
      (conv2): Conv1d(1280, 1280, kernel_size=(3,), stride=(2,), padding=(1,))
      (embed_positions): Embedding(1500, 1280)
      (layers): ModuleList(
        (0-31): 32 x WhisperEncoderLayer(
          (self_attn): WhisperSdpaAttention(
            (k_proj): Linear(in_features=1280, out_features=1280, bias=False)
            (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
            (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
            (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=1280, out_features=5120, bias=True)
          (fc2): Linear(in_features=5120, out_features=1280, bias

In [7]:
lmmodel

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb):

In [1]:
import torch
import torch.nn as nn
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.models.llama.modeling_llama import (
    LlamaRotaryEmbedding,
    apply_rotary_pos_emb,
    rotate_half,
)
from transformers.models.llama.configuration_llama import LlamaConfig



class WhisperEncoderLlamaDecoder(nn.Module):
    def __init__(self, 
                 freeze_whisper_encoder: bool = False,
                 freeze_llama_decoder: bool = False, config=LlamaConfig()):
        super().__init__()

        # 1. Whisper
        self.whisper = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-large-v2")
        # Whisper에서 encoder만 추출
        self.encoder = self.whisper.model.encoder

        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        # 2. Llama
        self.llama = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
        # llama.model -> LlamaModel, 그 안의 layers는 전체 Decoder Layers
        llama_layers = self.llama.model.layers

        # 3. Llama 디코더에서 맨 앞 2개 + 맨 뒤 2개 layer만 추출
        self.decoder_layers = nn.ModuleList([
            llama_layers[0],
            llama_layers[1],
            llama_layers[-2],
            llama_layers[-1],
        ])
        
        # Llama의 token embedding, RMSNorm 등도 필요하면 가져와야 함
        self.embed_tokens = self.llama.model.embed_tokens
        self.norm = self.llama.model.norm  # LlamaRMSNorm
        # lm_head 도 사용하려면 self.llama.lm_head 를 쓸 수도 있음
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        # 4. 차원 불일치 보정용 브릿지 레이어(Whisper enc -> Llama dec)
        #    Whisper encoder는 hidden_size=1280, Llama는 2048
        whisper_hidden_size = 1280
        llama_hidden_size = 2048
        
        self.bridge = nn.Linear(whisper_hidden_size, llama_hidden_size)
        

        # 5. (옵션) 학습 고정
        if freeze_whisper_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False
        if freeze_llama_decoder:
            for param in self.decoder_layers.parameters():
                param.requires_grad = False

    def _prepare_decoder_position_ids(self, seq_len: int, device: torch.device) -> torch.LongTensor:
        """
        LlamaModel.forward() 내부 아이디어를 참고하여,
        (1, seq_len) 형태의 position_ids를 생성합니다.
        
        예) seq_len=10 -> 텐서([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
        """
        return torch.arange(0, seq_len, dtype=torch.long, device=device).unsqueeze(0)

    def forward(self, 
                input_features: torch.Tensor, 
                #decoder_input_ids: torch.Tensor,
                attention_mask: torch.Tensor = None,
                labels=None,):
        """
        input_features: Whisper encoder에 들어갈 audio feature (batch, mel_bins, frames)
        decoder_input_ids: Llama decoder에 들어갈 텍스트 토큰 (batch, seq_len)
        attention_mask: 필요 시 디코더용 마스크
        
        반환: 최종 출력(예: logits)
        """

        # 1. Whisper encoder 전방향
        #    WhisperForConditionalGeneration의 input_features 모양에 맞춤
        #    (batch_size, feature_size, sequence_length) or (batch_size, sequence_length, feature_size)에 따라 다름
        encoder_outputs = self.encoder(input_features=input_features, return_dict=True)
        print("Whisper encoder 지남")
        # encoder_outputs: BaseModelOutput 클래스(또는 비슷한 dict)
        # encoder_outputs.last_hidden_state: (batch_size, seq_len, hidden_size=1280)
        encoder_hidden_states = encoder_outputs.last_hidden_state
        print("Encoder output shape:", encoder_hidden_states.shape)
        # 기대: (batch_size, audio_seq_len, 1280)

        # 2. 차원 브릿지
        encoder_hidden_states = self.bridge(encoder_hidden_states)
        print("After bridge shape:", encoder_hidden_states.shape)
        # 기대: (batch_size, audio_seq_len, 2048)
        # 3. Llama token 임베딩
        #    decoder_input_ids -> (batch_size, seq_len) => embed_tokens -> (batch_size, seq_len, hidden_size)

        if labels is not None:  # 훈련 시
            decoder_input_ids = labels[:, :-1]  # 정답의 마지막 토큰 제외 (Teacher Forcing)
        else:  # 추론 시
            decoder_input_ids = torch.full(
                (input_features.size(0), 1),  # (batch_size, 1)
                fill_value=self.llama.config.bos_token_id,
                dtype=torch.long,
                device=input_features.device,
        )
        hidden_states = self.embed_tokens(decoder_input_ids)
        print("Decoder token emb shape:", hidden_states.shape)
        # 기대: (batch_size, tgt_seq_len, 2048)
        # 여기서는 “encoder-decoder 구조”라기보다는, Llama의 일부 레이어를 순서대로 통과시킨다고만 간단히 가정
        # 실제로는 cross-attention 추가, causal mask, rotary embedding, position ids 등 처리가 필요함
        # 4. rotary embedding 준비
        seq_len = hidden_states.size(1)
        position_ids = self._prepare_decoder_position_ids(seq_len, device=hidden_states.device)
        # rotary_emb.forward(x=hidden_states, position_ids=position_ids) -> (cos, sin)
        cos, sin = self.rotary_emb(hidden_states, position_ids)


            # 4. Llama의 일부 레이어(맨 앞 2 + 맨 뒤 2개) 순회
        #    layer() 호출 결과는 (hidden_states, ...) 형태의 튜플이므로, 첫 번째 요소만 받아 다음 레이어로 전달
        for layer in self.decoder_layers:
            layer_outputs = layer(
                hidden_states,
                attention_mask=None,          # causal mask 대신 None (실제 사용하려면 mask 필요)
                position_ids=None,            # 이미 cos, sin 구했으므로 따로 전달X
                position_embeddings=(cos, sin),  # rotary embedding 직접 주입
                use_cache=False,
                output_attentions=False,
                encoder_hidden_states=None,   # vanilla Llama에는 cross-attn 없음
            )
            # LlamaDecoderLayer의 결과에서 hidden_states만 추출
            hidden_states = layer_outputs[0]

        # 4. Llama 마지막 norm
        hidden_states = self.norm(hidden_states)

        # 5. lm_head를 통해 vocab logits 예측 (선택)
        logits = self.llama.lm_head(hidden_states)

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=self.llama.config.pad_token_id)
            # logits: (batch_size, seq_len, vocab_size)
            # labels: (batch_size, seq_len)
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
            return {"loss": loss, "logits": logits}
    
            
        return {"logits": logits}



  from .autonotebook import tqdm as notebook_tqdm


In [9]:
lmmodel.config

LlamaConfig {
  "_attn_implementation_autoset": true,
  "_name_or_path": "meta-llama/Llama-3.2-1B",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "head_dim": 64,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 8192,
  "max_position_embeddings": 131072,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 16,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "factor": 32.0,
    "high_freq_factor": 4.0,
    "low_freq_factor": 1.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3"
  },
  "rope_theta": 500000.0,
  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.48.1",
  "use_cache": true,
  "vocab_size": 128256
}

In [2]:
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset
processor = AutoProcessor.from_pretrained("openai/whisper-large-v2")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = ds[0]["audio"]
input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features 

In [3]:

if __name__ == "__main__":
    # 예시: 모델 초기화
    custom_model = WhisperEncoderLlamaDecoder(
        freeze_whisper_encoder=True,
        freeze_llama_decoder=False
    )

    # 가짜 입력
    batch_size = 1
    #dummy_input_features = torch.randn(batch_size, 80, 3000)  # (batch, mel_bins, frames) 정도 가정
    dummy_decoder_input_ids = torch.randint(0, 1000, (batch_size, 16))  # 임의 토큰
    # forward
    outputs = custom_model(
        input_features=input_features
    )
    print(outputs.shape)  # (batch_size, seq_len, vocab_size) 형태가 기대됨


Whisper encoder 지남
Encoder output shape: torch.Size([1, 1500, 1280])
After bridge shape: torch.Size([1, 1500, 2048])
Decoder token emb shape: torch.Size([1, 1, 2048])


RuntimeError: The size of tensor a (64) must match the size of tensor b (128) at non-singleton dimension 3

AttributeError: 'WhisperEncoder' object has no attribute 'embed_proj'