In [6]:
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import LlamaConfig, LlamaModel, LlamaTokenizer, GPT2Config, GPT2Model, GPT2Tokenizer, BertConfig, BertModel, BertTokenizer, AutoTokenizer, AutoModelForCausalLM, AutoModel
import transformers
from peft import LoraConfig, TaskType, get_peft_model, IA3Config
import logging
import pytorch_lightning as pl
from torchmetrics import AUROC

In [7]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

transformers.logging.set_verbosity_error()

token = "redacted"

OPERA_CT_TARGET_MODULES = ["qkv", "proj"]
OPERA_CE_TARGET_MODULES = ['conv', 'fc', 'linear']
target_module_dict = {"operaCT": OPERA_CT_TARGET_MODULES, "operaCE": OPERA_CE_TARGET_MODULES}
LLM_TARGET_MODULES = ["q_proj", "v_proj"]
LLM_TARGET_MODULES_ALLPROJ = ["q_proj", "k_proj", "v_proj", "o_proj"]


In [8]:
class FlattenHead(nn.Module):
    def __init__(self, nf, out_dim, head_dropout=0):
        super().__init__()
        self.flatten = nn.Flatten(start_dim=-2)
        self.linear = nn.Linear(nf, out_dim)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x, no_fc=False):
        x = self.flatten(x)
        if no_fc:
            return x
        x = self.linear(x)
        x = self.dropout(x)
        return x

In [10]:
class RespLLM(nn.Module):
    def __init__(self, configs):
        super(RespLLM, self).__init__()

        # Loss and configuration parameters
        self.loss = nn.CrossEntropyLoss()
        self.n_cls = configs.n_cls
        self.validation_step_outputs = []
        self.test_step_outputs = []

        self.d_ff = configs.d_ff
        self.d_llm = configs.llm_dim
        self.audio_peft = configs.audio_peft
        self.d_audio = configs.enc_dim
        self.patch_nums = configs.patch_nums
        self.head_nf = self.d_ff * self.patch_nums

        self.llm_peft = configs.llm_peft
        self.llm_lora_rank = configs.llm_lora_rank
        self.llm_lora_alpha = configs.llm_lora_alpha
        self.llm_lora_dropout = configs.llm_lora_dropout

        self.use_audio = configs.use_audio

        # LLaMA model initialization
        if configs.llm_model == 'llama3':
            self.llama_config = LlamaConfig.from_pretrained('meta-llama/Meta-Llama-3-8B')  # 13.5G
            try:
                self.llm_model = LlamaModel.from_pretrained(
                    'meta-llama/Meta-Llama-3-8B',
                    trust_remote_code=True,
                    local_files_only=True,
                    config=self.llama_config,
                )
            except EnvironmentError:
                print("Local model files not found. Attempting to download...")
                self.llm_model = LlamaModel.from_pretrained(
                    'meta-llama/Meta-Llama-3-8B',
                    trust_remote_code=True,
                    local_files_only=False,
                    config=self.llama_config,
                )
            try:
                self.tokenizer = LlamaTokenizer.from_pretrained(
                    'meta-llama/Meta-Llama-3-8B',
                    trust_remote_code=True,
                    local_files_only=True
                )
            except EnvironmentError:
                print("Local tokenizer files not found. Attempting to download them...")
                self.tokenizer = LlamaTokenizer.from_pretrained(
                    'meta-llama/Meta-Llama-3-8B',
                    trust_remote_code=True,
                    local_files_only=False
                )
            else:
                print("end here")

        # Audio PEFT configuration
        if self.audio_peft == "lora":
            peft_config = LoraConfig(
                r=configs.audio_lora_rank,
                lora_alpha=32,
                lora_dropout=0.1,
                target_modules=target_module_dict[configs.audio_encoder]
            )
        elif self.audio_peft == "IA3":
            peft_config = IA3Config(
                target_modules=target_module_dict[configs.audio_encoder],
                feedforward_modules=['proj']
            )
        else:
            raise NotImplementedError("Audio fine-tuning mode undefined")

        self.audio_encoder = get_peft_model(self.audio_encoder, peft_config)
        self.audio_encoder.print_trainable_parameters()

        # Aligner module
        if configs.aligner == "projection":
            self.aligner = nn.Linear(self.d_audio, self.d_llm)
        else:
            raise NotImplementedError("Aligner module undefined")

        # Output projection
        self.head_dropout = configs.head_dropout
        self.output_projection = FlattenHead(self.head_nf, self.n_cls, head_dropout=self.head_dropout)

        self.print_trainable()

        


In [11]:
def reinitialize_clf(self, n_cls):
        self.output_projection = FlattenHead(self.head_nf, n_cls, head_dropout=self.head_dropout)

In [12]:
def print_trainable(self):
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print("total trainable parameters:", trainable_params)


In [13]:
def reset_trainable(self):
        if self.llm_peft == "lora":
            for name, param in self.audio_encoder.named_parameters():
                if "lora" in name:
                    param.requires_grad = True
                else:
                    param.requires_grad = False
        elif self.llm_peft == "frozen":
            for param in self.llm_model.parameters():
                param.requires_grad = False
        
        for param in self.aligner.parameters():
            param.requires_grad = True
        
        if self.audio_peft == "frozen":
            for param in self.audio_encoder.parameters():
                param.requires_grad = False
        elif self.audio_peft == "full":
            for param in self.audio_encoder.parameters():
                param.requires_grad = True
        
        for param in self.output_projection.parameters():
            param.requires_grad = True
        self.print_trainable()

In [14]:
def forward(self, x_spectrogram, x_prompt, x_context, no_fc=False):

        if self.patch_nums == 1:
            x_enc = self.audio_encoder(x_spectrogram)
            # print(x_enc.shape)
            enc_out = self.aligner(x_enc)
            enc_out = enc_out.unsqueeze(dim=1)
        elif self.patch_nums == 64:
            x_enc = self.audio_encoder.forward_window(x_spectrogram)
            # print(x_enc.shape)
            enc_out = self.aligner(x_enc)
        else:
            raise NotImplementedError

        prompt = self.tokenizer(x_prompt, return_tensors="pt", padding=True, truncation=True, max_length=2048).input_ids
        prompt_embeddings = self.llm_model.get_input_embeddings()(prompt.to(x_enc.device))  # (batch, prompt_token, dim)

        context = self.tokenizer(x_context, return_tensors="pt", padding=True, truncation=True, max_length=2048).input_ids
        context_embeddings = self.llm_model.get_input_embeddings()(context.to(x_enc.device))  # (batch, prompt_token, dim)

        # print(prompt_embeddings.shape, enc_out.shape)

        if self.use_audio:
            llama_enc_out = torch.cat([prompt_embeddings, context_embeddings, enc_out], dim=1)
        else:
            llama_enc_out = torch.cat([prompt_embeddings, context_embeddings], dim=1)

        dec_out = self.llm_model(inputs_embeds=llama_enc_out).last_hidden_state
        # print(dec_out.shape)
        dec_out = dec_out[:, :, :self.d_ff]
        # print(dec_out.shape)

        dec_out = dec_out.permute(0, 2, 1).contiguous()
        # print(dec_out.shape)

        dec_out = self.output_projection(dec_out[:, :, -self.patch_nums:], no_fc=no_fc)
        # print(dec_out.shape)
        return dec_out

In [15]:
class Config:
    n_cls = 10
    d_ff = 512
    llm_dim = 1024
    audio_peft = "lora"
    enc_dim = 256
    patch_nums = 1
    llm_peft = "lora"
    llm_lora_rank = 8
    llm_lora_alpha = 32
    llm_lora_dropout = 0.1
    use_audio = True
    llm_model = "llama3"
    audio_lora_rank = 4
    audio_encoder = "operaCT"
    aligner = "projection"
    head_dropout = 0.1

configs = Config()

# Create an instance of RespLLM
model = RespLLM(configs)

# Example inputs for the forward function
x_spectrogram = torch.randn(1, 256, 64)  # Example spectrogram input
x_prompt = "This is a test prompt."
x_context = "This is a test context."

# Call the forward function
output = model.forward(x_spectrogram, x_prompt, x_context)
print(output)

OSError: You are trying to access a gated repo.
Make sure to have access to it at https://huggingface.co/meta-llama/Meta-Llama-3-8B.
403 Client Error. (Request ID: Root=1-67e27daa-492d549c72ca496079f49ce2;055e6229-bdc4-4a8b-acee-b768501dfdc3)

Cannot access gated repo for url https://huggingface.co/meta-llama/Meta-Llama-3-8B/resolve/main/config.json.
Access to model meta-llama/Meta-Llama-3-8B is restricted and you are not in the authorized list. Visit https://huggingface.co/meta-llama/Meta-Llama-3-8B to ask for access.