In [2]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torchmetrics import AUROC

import transformers
from transformers import (
    LlamaConfig, 
    LlamaModel, 
    LlamaTokenizer, 
    AutoTokenizer, 
    AutoModelForCausalLM, 
    AutoModel
)
from peft import (
    LoraConfig, 
    IA3Config, 
    get_peft_model, 
    TaskType
)


In [3]:
# Raise the transformers log level to ERROR so that informational and
# warning messages emitted while loading models do not flood the cell output.
transformers.logging.set_verbosity_error()

In [4]:
# Target modules for PEFT: names of the sub-modules an adapter is attached to,
# keyed by the audio-encoder identifier used in the configuration.
OPERA_CT_TARGET_MODULES = ["qkv", "proj"]
OPERA_CE_TARGET_MODULES = ["conv", "fc", "linear"]
target_module_dict = dict(
    operaCT=OPERA_CT_TARGET_MODULES,
    operaCE=OPERA_CE_TARGET_MODULES,
)

# Projection-layer names for adapting the LLM: a query/value-only set and a
# set covering all four attention projections.
# NOTE(review): neither list is referenced in the visible cells — confirm use.
LLM_TARGET_MODULES = ["q_proj", "v_proj"]
LLM_TARGET_MODULES_ALLPROJ = ["q_proj", "k_proj", "v_proj", "o_proj"]

In [5]:
class OperaCTEncoder(nn.Module):
    """Minimal audio encoder: Conv1d -> ReLU -> mean over time -> Linear.

    The two layers are deliberately stored under the attribute names ``qkv``
    and ``proj`` so that PEFT can locate them by name.

    Args:
        in_dim: Number of input channels of the convolution.
        out_dim: Number of output channels of the convolution and width of
            the final linear projection.
    """

    def __init__(self, in_dim=256, out_dim=256):
        super().__init__()
        # kernel_size=3 with padding=1 keeps the time length unchanged.
        self.qkv = nn.Conv1d(in_dim, out_dim, kernel_size=3, padding=1)
        self.proj = nn.Linear(out_dim, out_dim)

    def forward(self, x):
        """Encode ``x`` of shape (batch_size, in_dim, time) into (batch_size, out_dim)."""
        activated = F.relu(self.qkv(x))      # (B, out_dim, time)
        pooled = activated.mean(dim=-1)      # global average over time -> (B, out_dim)
        return self.proj(pooled)             # (B, out_dim)

    def forward_window(self, x):
        """Placeholder for a patch-wise encoder.

        Splitting ``x`` into several patches (for example 64) could be
        implemented here; for now this simply returns ``self.forward(x)``.
        """
        return self.forward(x)

In [6]:
class FlattenHead(nn.Module):
    """Classification head: flatten the last two dims, then Linear + Dropout.

    Args:
        nf: Number of features after flattening (input width of the linear layer).
        out_dim: Output width of the linear layer.
        head_dropout: Dropout probability applied after the linear layer.
    """

    def __init__(self, nf, out_dim, head_dropout=0):
        super().__init__()
        self.flatten = nn.Flatten(start_dim=-2)
        self.linear = nn.Linear(nf, out_dim)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x, no_fc=False):
        """Return the flattened features, or the projected output.

        Args:
            x: Input whose last two dimensions are merged into one.
            no_fc: When true, skip the linear layer and dropout and return
                the flattened features directly.
        """
        flat = self.flatten(x)
        if not no_fc:
            return self.dropout(self.linear(flat))
        return flat

In [17]:
class RespLLM(nn.Module):
    """Audio + text classifier built around a Llama backbone.

    An audio encoder (optionally wrapped with a PEFT adapter) produces audio
    features, a linear aligner maps them to the LLM embedding width, and a
    ``FlattenHead`` turns LLM hidden states into class logits.

    Args:
        configs: Object exposing the attributes read below (``n_cls``,
            ``d_ff``, ``llm_dim``, ``audio_peft``, ``enc_dim``,
            ``patch_nums``, ``llm_peft``, ``llm_lora_*``, ``use_audio``,
            ``llm_model``, ``aligner``, ``head_dropout``; plus
            ``audio_lora_rank`` / ``audio_encoder`` when a PEFT adapter is
            used). An optional ``llm_model_name`` overrides the checkpoint id.

    Raises:
        NotImplementedError: For an unknown ``llm_model``, ``audio_peft`` or
            ``aligner`` value.

    Note:
        ``reinitialize_clf``, ``print_trainable``, ``reset_trainable`` and
        ``forward`` are implemented as module-level functions taking ``self``
        in later notebook cells. The methods at the bottom of this class
        delegate to them; the lookup happens at call time, so those cells
        must have been executed before the model is instantiated.
    """

    # Checkpoint used for ``llm_model == 'llama3'`` unless
    # ``configs.llm_model_name`` provides another id.
    DEFAULT_LLAMA3_ID = 'meta-llama/Llama-3.3-70B-Instruct'

    def __init__(self, configs):
        super().__init__()

        # Keep the hyper-parameters from configs
        self.n_cls = configs.n_cls
        self.d_ff = configs.d_ff
        self.d_llm = configs.llm_dim
        self.audio_peft = configs.audio_peft
        self.d_audio = configs.enc_dim
        self.patch_nums = configs.patch_nums
        self.head_nf = self.d_ff * self.patch_nums

        self.llm_peft = configs.llm_peft
        self.llm_lora_rank = configs.llm_lora_rank
        self.llm_lora_alpha = configs.llm_lora_alpha
        self.llm_lora_dropout = configs.llm_lora_dropout

        self.use_audio = configs.use_audio
        self.loss = nn.CrossEntropyLoss()
        self.validation_step_outputs = []
        self.test_step_outputs = []

        # LLM model initialization
        if configs.llm_model == 'llama3':
            model_id = getattr(configs, 'llm_model_name', self.DEFAULT_LLAMA3_ID)
            self.llama_config = LlamaConfig.from_pretrained(model_id)

            # The aligner output is concatenated with the LLM token
            # embeddings, so its width must equal the backbone hidden size.
            hidden_size = self.llama_config.hidden_size
            if hidden_size != self.d_llm:
                print(
                    f"configs.llm_dim={self.d_llm} does not match the LLM hidden size "
                    f"{hidden_size}; using {hidden_size}"
                )
                self.d_llm = hidden_size

            # trust_remote_code is intentionally not passed: the concrete
            # Llama classes never need to execute code shipped in a repo.
            self.llm_model = self._load_local_first(
                LlamaModel.from_pretrained,
                model_id,
                "Không tìm thấy model cục bộ, thử tải về ...",
                config=self.llama_config,
            )
            # Llama 3 checkpoints ship a fast BPE tokenizer rather than a
            # sentencepiece model, so the tokenizer class is resolved
            # automatically instead of forcing LlamaTokenizer.
            self.tokenizer = self._load_local_first(
                AutoTokenizer.from_pretrained,
                model_id,
                "Không tìm thấy tokenizer cục bộ, thử tải về ...",
            )
            # Tokenization in the forward pass uses padding=True, which
            # requires a pad token; fall back to EOS when none is defined.
            if self.tokenizer.pad_token is None and self.tokenizer.eos_token is not None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
        else:
            raise NotImplementedError("LLM model undefined")

        # Audio encoder, optionally wrapped with a PEFT adapter
        # (e.g. OperaCTEncoder -> LoRA)
        self.base_audio_encoder = OperaCTEncoder(
            in_dim=self.d_audio,
            out_dim=self.d_audio
        )

        if self.audio_peft == "lora":
            peft_config = LoraConfig(
                r=configs.audio_lora_rank,
                lora_alpha=32,
                lora_dropout=0.1,
                target_modules=target_module_dict[configs.audio_encoder]
            )
        elif self.audio_peft == "IA3":
            peft_config = IA3Config(
                target_modules=target_module_dict[configs.audio_encoder],
                feedforward_modules=['proj']
            )
        elif self.audio_peft in ("frozen", "full"):
            # These modes are handled by reset_trainable(); no adapter needed.
            peft_config = None
        else:
            raise NotImplementedError("Audio fine-tuning mode undefined")

        if peft_config is None:
            self.audio_encoder = self.base_audio_encoder
            self.audio_encoder.requires_grad_(self.audio_peft == "full")
        else:
            # Wrap the audio encoder with PEFT
            self.audio_encoder = get_peft_model(self.base_audio_encoder, peft_config)
            self.audio_encoder.print_trainable_parameters()

        # Aligner module (maps d_audio to d_llm)
        if configs.aligner == "projection":
            self.aligner = nn.Linear(self.d_audio, self.d_llm)
        else:
            raise NotImplementedError("Aligner module undefined")

        # Output projection head
        self.head_dropout = configs.head_dropout
        self.output_projection = FlattenHead(self.head_nf, self.n_cls, head_dropout=self.head_dropout)

        # Report the number of trainable parameters
        self.print_trainable()

    @staticmethod
    def _load_local_first(loader, model_id, not_found_message, **kwargs):
        """Call ``loader`` against the local cache first, then allow download."""
        try:
            return loader(model_id, local_files_only=True, **kwargs)
        except EnvironmentError:
            print(not_found_message)
            return loader(model_id, local_files_only=False, **kwargs)

    # ------------------------------------------------------------------
    # Bind the module-level implementations (defined in later cells) as
    # methods. Inside a method body the bare name resolves to the module
    # global, not to the method itself.
    # ------------------------------------------------------------------
    def reinitialize_clf(self, n_cls):
        """Replace the output head with a fresh one for ``n_cls`` classes."""
        return reinitialize_clf(self, n_cls)

    def print_trainable(self):
        """Print the number of parameters that require gradients."""
        return print_trainable(self)

    def reset_trainable(self):
        """Re-apply the ``requires_grad`` flags implied by the PEFT settings."""
        return reset_trainable(self)

    def forward(self, x_spectrogram, x_prompt, x_context, no_fc=False):
        """Run the model; see the module-level ``forward`` for details."""
        return forward(self, x_spectrogram, x_prompt, x_context, no_fc=no_fc)


In [18]:
def reinitialize_clf(self, n_cls):
    """Swap in a freshly initialised classification head.

    A new ``FlattenHead`` is built from the stored ``head_nf`` and
    ``head_dropout`` and assigned to ``self.output_projection``.

    Args:
        n_cls: Output width (number of classes) of the new head.
    """
    fresh_head = FlattenHead(
        self.head_nf,
        n_cls,
        head_dropout=self.head_dropout,
    )
    self.output_projection = fresh_head


In [19]:
def print_trainable(self):
    """Print how many parameter elements of ``self`` require gradients."""
    n_trainable = 0
    for param in self.parameters():
        if param.requires_grad:
            n_trainable += param.numel()
    print("total trainable parameters:", n_trainable)


In [20]:
def reset_trainable(self):
    """Set ``requires_grad`` on the sub-modules according to the PEFT modes.

    * LLM: frozen when ``self.llm_peft == "frozen"``; otherwise untouched.
    * Audio encoder: frozen for ``"frozen"``, fully trainable for ``"full"``;
      any other ``self.audio_peft`` value leaves it untouched.
    * Aligner and output head: always trainable.

    Finishes by calling ``self.print_trainable()``.
    """

    def _set_requires_grad(module, flag):
        for param in module.parameters():
            param.requires_grad = flag

    # Freeze the LLM when requested.
    if self.llm_peft == "frozen":
        _set_requires_grad(self.llm_model, False)
    # A LoRA mode for the LLM could be added here, e.g. by wrapping
    # self.llm_model with get_peft_model(...) when self.llm_peft == "lora".

    # Audio encoder: only the two explicit modes change anything.
    audio_flag = {"frozen": False, "full": True}.get(self.audio_peft)
    if audio_flag is not None:
        _set_requires_grad(self.audio_encoder, audio_flag)

    # Aligner and output projection are always trained.
    for always_trained in (self.aligner, self.output_projection):
        _set_requires_grad(always_trained, True)

    self.print_trainable()

In [21]:
def forward(self, x_spectrogram, x_prompt, x_context, no_fc=False):
    """Encode audio and text, run the LLM and classify from the last tokens.

    Args:
        x_spectrogram: Audio input passed to ``self.audio_encoder``.
        x_prompt: String or list of strings tokenized as the prompt.
        x_context: String or list of strings tokenized as the context.
        no_fc: Forwarded to ``self.output_projection``; when true the head
            returns the flattened features instead of logits.

    Returns:
        The output of ``self.output_projection`` applied to the first
        ``d_ff`` hidden channels of the last ``patch_nums`` positions.

    Raises:
        NotImplementedError: If ``self.patch_nums`` is neither 1 nor 64.
    """
    # ======= Encode audio =======
    if self.patch_nums == 1:
        x_enc = self.audio_encoder(x_spectrogram)        # (B, d_audio)
        enc_out = self.aligner(x_enc).unsqueeze(dim=1)   # (B, 1, d_llm)
    elif self.patch_nums == 64:
        x_enc = self.audio_encoder.forward_window(x_spectrogram)
        # Assumed to be (B, 64, d_llm); this requires forward_window to
        # return one feature vector per patch — TODO confirm.
        enc_out = self.aligner(x_enc)
    else:
        raise NotImplementedError

    device = x_enc.device

    # ======= Tokenize prompt/context and look up their embeddings =======
    tokenize_kwargs = dict(
        return_tensors="pt", padding=True, truncation=True, max_length=2048
    )
    prompt_tokens = self.tokenizer(x_prompt, **tokenize_kwargs)
    context_tokens = self.tokenizer(x_context, **tokenize_kwargs)

    prompt = prompt_tokens.input_ids.to(device)
    context = context_tokens.input_ids.to(device)
    # Keep the padding masks: without them the LLM attends to pad tokens
    # whenever the texts in a batch have different lengths.
    prompt_mask = prompt_tokens.attention_mask.to(device)
    context_mask = context_tokens.attention_mask.to(device)

    embed_tokens = self.llm_model.get_input_embeddings()
    prompt_embeddings = embed_tokens(prompt)      # (B, prompt_len, d_llm)
    context_embeddings = embed_tokens(context)    # (B, context_len, d_llm)

    # ======= Concatenate embeddings (text + audio) =======
    if self.use_audio:
        llama_enc_out = torch.cat([prompt_embeddings, context_embeddings, enc_out], dim=1)
        # Audio positions are never padding.
        audio_mask = torch.ones(
            enc_out.size(0), enc_out.size(1), dtype=prompt_mask.dtype, device=device
        )
        attention_mask = torch.cat([prompt_mask, context_mask, audio_mask], dim=1)
    else:
        llama_enc_out = torch.cat([prompt_embeddings, context_embeddings], dim=1)
        attention_mask = torch.cat([prompt_mask, context_mask], dim=1)

    # ======= Run the LLM =======
    dec_out = self.llm_model(
        inputs_embeds=llama_enc_out, attention_mask=attention_mask
    ).last_hidden_state                                  # (B, seq_len, d_llm)

    # Keep only the first d_ff hidden channels (a no-op cut if d_llm < d_ff).
    dec_out = dec_out[:, :, :self.d_ff]                  # (B, seq_len, d_ff)
    dec_out = dec_out.permute(0, 2, 1).contiguous()      # (B, d_ff, seq_len)

    # Keep the trailing patch_nums positions and feed them to the head.
    dec_out = dec_out[:, :, -self.patch_nums:]           # (B, d_ff, patch_nums)
    logits = self.output_projection(dec_out, no_fc=no_fc)

    return logits


In [22]:
class Config:
    """Hyper-parameters consumed by ``RespLLM`` for the smoke test below."""

    n_cls = 10                  # number of output classes
    d_ff = 512                  # number of LLM hidden channels kept for the head
    # Must equal the hidden size of the LLM backbone, because the aligner
    # output is concatenated with the LLM token embeddings. The checkpoint
    # hard-coded in RespLLM (meta-llama/Llama-3.3-70B-Instruct) uses 8192.
    llm_dim = 8192
    audio_peft = "lora"         # adapter type for the audio encoder
    enc_dim = 256               # audio encoder channel width
    patch_nums = 1              # number of trailing positions fed to the head
    # NOTE(review): only "frozen" is acted upon by reset_trainable; a "lora"
    # mode for the LLM is sketched there but commented out.
    llm_peft = "frozen"
    llm_lora_rank = 8
    llm_lora_alpha = 32
    llm_lora_dropout = 0.1
    use_audio = True            # append the audio embedding to the text embeddings
    llm_model = "llama3"        # selects the Llama checkpoint hard-coded in RespLLM
    audio_lora_rank = 4
    audio_encoder = "operaCT"   # key into target_module_dict
    aligner = "projection"      # linear projection from enc_dim to llm_dim
    head_dropout = 0.1

In [23]:
# Build an instance and run a trial forward pass
if __name__ == "__main__":
    configs = Config()
    model = RespLLM(configs)
    # Inference-only smoke test: switch off dropout so the output is
    # repeatable, and skip autograd bookkeeping for the LLM pass.
    model.eval()

    # Example inputs
    x_spectrogram = torch.randn(1, configs.enc_dim, 64)  # (batch=1, in_dim=enc_dim, time=64)
    x_prompt = "Đây là một prompt thử nghiệm."
    x_context = "Đây là phần context thử nghiệm."

    # Forward (call the module itself rather than .forward so hooks run)
    with torch.no_grad():
        output = model(x_spectrogram, x_prompt, x_context)
    print("Output shape:", output.shape)
    print("Output:", output)

config.json:   0%|          | 0.00/879 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


model.safetensors.index.json:   0%|          | 0.00/59.6k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/30 [00:00<?, ?it/s]

model-00001-of-00030.safetensors:   0%|          | 0.00/4.58G [00:00<?, ?B/s]

KeyboardInterrupt: 