# Задача

Нужно обучить LLM понимать модальность аудио.

Нужно реализовать архитектуру аналогичную LLaVA ([github](https://github.com/haotian-liu/LLaVA), [paper](https://arxiv.org/abs/2304.08485))
но для аудио.

Нужно обучить адаптер аудио для LLM по аналогии с LLava

В качестве LLM будем использовать [TinyLlama](https://github.com/jzhang38/TinyLlama).

В качестве CLIP-like энкодера аудио -- ImageBind ([github](https://github.com/facebookresearch/ImageBind), [paper](https://arxiv.org/abs/2305.05665))

# Данные

Для обучения возьмем задачу описания аудио (audio captioning) на датасете [`Clotho`](https://zenodo.org/records/3490684)



# Модели



## Вкратце про ImageBind

[Github](https://github.com/facebookresearch/ImageBind)

Это CLIP-like экнодер для 5 модальностей.

Идея и метод обучения точно такой же как у CLIP. Из нового -- это больше данных разных модальностей. Крутость модельки в том, что через модальность изображений, получилось обучить совместные эмбэддинги для всех модальностей. То есть можно сделать zero-shot классификатор для Audio, хотя модальность текста и аудио ни разу во время обучения не взаимодействовали - все обучалось через изображения (то есть изображение-аудио / изображение-текст).

Мы будем использовать только энкодер для аудио. И воспользуемся свойствами этих эмбэддингов -- в векторном представлении аудио должна быть информация о семантике аудио


## Вкратце про LLaVA

За мостик между эмбэддингами `LLM` и модальностью изображений отвечает предобученный `CLIP`. Через `CLIP` получаем совместные векторные представления для картинок (используем именно `pooled` `CLIP` эмбэддинги). Каждой картинки соответствует один вектор. Не забываем нормализовать эти векторные предславления. Чтобы перевести этот эмбэддинг в пространство эмбэддингов `LLM` используем просто линейное преобразование. Это линейное преобразование переводит `CLIP` эмбэддинг в `N` эмбэддингов `LLM`.

```python
image_embeddings = clip_model.encode_image( images ) # [ bs, 1, clip_hidden_dim ]
image_embeddings = F.normalize(image_embeddings)

llm_embeddings = adapter.projection( image_embeddings ) # [ bs, adapter.projection_llm_tokens, llm_hidden_dim ]

# далее используем llm_embeddings как эмбэддинги картинки/аудио для задачи `CausalLM`
```

Эмбэддинги новой модальности дополнительно заворачиваются в специальные токены. Потом обычно идут текстовые эмбэддинги:

```
<audio_start> *audio_token_1*, *audio_token_2*, ..., *audio_token_N* <audio_end> <bos_token> *text_token_1*, ... *text_token_M* <eos_token>
```


В домашке в некоторых местах по аналогии есть аббревиатура `LLaAA` -- это расшифровывается как Large Language and Audio Assistant. По аналогии с `LLaVA` -- Large Language and Vision Assistant

## Процесс обучения

Во время обучения `CLIP` енкодер всегда заморожен.

Обучение делится на 2 этапа:

1) LLM заморожена. Обучаем только адаптер.

2) Одновременно и файнтюним LLM через Peft, и дообучаем адаптер. Код-заготовка есть в шаблоне домашки. Но это не входит в обязательную часть домашки

## Обучение проекции

Перед обучением адаптера может потребоваться рестарт рантайма колаба, чтобы освободились ресурсы от `ImageBind`, который мы подгружали во время предобработки данных.

## Метрики качества

Во время обучения метрики качества вычисляются только для небольшого количества семплов из валидационного датасета. Это сделано, чтобы уменьшить длительность валидации.

---

# Предобработка данных

In [7]:
!pip install -q aac_datasets

In [9]:
# загрузим данные
# всего тут потребуется ~10GB
from aac_datasets import Clotho

dataset_prefix = "./data"

val_dataset = Clotho(root="./data", download=False, subset='val')
dataset = Clotho(root="./data", download=False, subset='dev')

# val_dataset = Clotho(root="./data", download=True, subset='val')
# dataset = Clotho(root="./data", download=True, subset='dev')


In [10]:
val_dataset

Clotho(subset=val, size=1045, num_columns=14, version=v2.1)

In [11]:
dataset

Clotho(subset=dev, size=3839, num_columns=14, version=v2.1)

In [23]:
!pip install -q evaluate accelerate rouge_score

In [24]:
!pip install -q soundfile sox

In [8]:
!pip install -q git+https://github.com/facebookresearch/ImageBind 

In [22]:
!pip install -q rouge_score evaluate peft==0.5.0

In [1]:
import torch
import torch.nn as nn

# тк ImagebindAudioEncoder будет использоваться только для
# подготовки данных и прямо сейчас не нужен во время обучения
# для того, чтобы удовлетворить интерфейсу аудио энкодера соберем
# класс-заглушку для аудио энкодера
class DummyAudioEncoder(nn.Module):

    hidden_size = 1024

    def encode_audio(self, audio_melspec_values):
        return torch.zeros([audio_melspec_values.shape[0], 1, self.hidden_size], device=audio_melspec_values.device)


In [12]:
# from imagebind.models import imagebind_model
# from imagebind.models.imagebind_model import ModalityType

class ImagebindAudioEncoder(nn.Module):

    hidden_size = 1024

    def __init__(self, pretrained=True) -> None:
        super().__init__()
        self.imagebind = imagebind_model.imagebind_huge(pretrained=pretrained)
        return

    def encode_audio(self, audio_melspec_values):
        inputs = {
            ModalityType.AUDIO: audio_melspec_values,
        }
        return self.imagebind(inputs)[ModalityType.AUDIO]



In [2]:
from typing import Callable
import torch
from tqdm.auto import tqdm
import pathlib


import pandas as pd
from datasets import Dataset

# from imagebind import data as imagebind_data


@torch.no_grad()
def preprocess_clotho_data(
        clotho_base_path: pathlib.Path,
        split_name: str,
        generate_melspec: Callable,
        generate_audio_embeddings: Callable,
        preprocessor_suffix="",
        device='cpu'
        ):

    clotho_hf_dataset_base_path = clotho_base_path.joinpath("clotho_hf_dataset") ## what???
    clotho_wav_base_path = clotho_base_path.joinpath(f'clotho_audio_files/{split_name}')

    if preprocessor_suffix != "":
        preprocessor_suffix = "_" + preprocessor_suffix

    clotho_melspec_base_path = clotho_base_path.joinpath(f'clotho_melspec_processed{preprocessor_suffix}/{split_name}')
    clotho_melspec_base_path.mkdir(parents=True, exist_ok=True)

    clotho_audio_embeds_base_path = clotho_base_path.joinpath(f'clotho_audio_embeds_processed{preprocessor_suffix}/{split_name}')
    clotho_audio_embeds_base_path.mkdir(parents=True, exist_ok=True)

    # audio captioning
    df_clotho = pd.read_csv(clotho_base_path.joinpath(f"clotho_csv_files/clotho_captions_{split_name}.csv"))

    melspec_processed_items = []

    for i, clotho_item in tqdm(df_clotho.iterrows(), total=len(df_clotho), desc=split_name):

        melspec_file_name = clotho_item['file_name'].replace(".wav", ".pt")
        melspec_full_path = clotho_melspec_base_path.joinpath(melspec_file_name)

        # process melspec
        full_file_path = clotho_wav_base_path.joinpath(clotho_item['file_name'])
        audio_melspec_values = generate_melspec(full_file_path)
        torch.save(audio_melspec_values, melspec_full_path)

        clotho_item["melspec_file_name"] = melspec_file_name

        # process audio embeds
        audio_embeds_last_hidden_state_file_name = clotho_item['file_name'].replace(".wav", ".pt")
        audio_embeds_full_path = clotho_audio_embeds_base_path.joinpath(audio_embeds_last_hidden_state_file_name)
        audio_embeds_last_hidden_state = generate_audio_embeddings(audio_melspec_values.to(device))
        torch.save(audio_embeds_last_hidden_state, audio_embeds_full_path)

        clotho_item["audio_embeds_last_hidden_state_file_name"] = audio_embeds_last_hidden_state_file_name

        melspec_processed_items.append(dict(clotho_item))

    newmelspec_dataset = Dataset.from_list(melspec_processed_items)

    melspec_dataset_path = clotho_hf_dataset_base_path.joinpath(f'clotho_{split_name}{preprocessor_suffix}.dataset')
    print(f"saved {split_name} to dir:", melspec_dataset_path)
    newmelspec_dataset.save_to_disk(dataset_path=str(melspec_dataset_path))


In [None]:
clotho_base_path = pathlib.Path('data/CLOTHO_v2.1')

device = 'cuda' if torch.cuda.is_available() else 'cpu'

imagebind_audio_encoder = ImagebindAudioEncoder().to(device)

Downloading imagebind weights to .checkpoints/imagebind_huge.pth ...


100%|██████████| 4.47G/4.47G [00:56<00:00, 85.1MB/s]


In [None]:
def generate_melspec_imagebind(audio_file_path):
    return imagebind_data.load_and_transform_audio_data(audio_paths=[audio_file_path], device='cpu')

# imagebind_audio_encoder closure
def generate_audio_embeddings_imagebind(audio_melspec_values):
    return imagebind_audio_encoder.encode_audio(audio_melspec_values)

preprocess_clotho_data(
    clotho_base_path,
    split_name="development",
    generate_melspec=generate_melspec_imagebind,
    generate_audio_embeddings=generate_audio_embeddings_imagebind,
    device=device,
    preprocessor_suffix='imagebind_single'
)

preprocess_clotho_data(
    clotho_base_path,
    split_name="validation",
    generate_melspec=generate_melspec_imagebind,
    generate_audio_embeddings=generate_audio_embeddings_imagebind,
    device=device,
    preprocessor_suffix='imagebind_single'
)

development:   0%|          | 0/3839 [00:00<?, ?it/s]

saved development to dir: data/CLOTHO_v2.1/clotho_hf_dataset/clotho_development_imagebind_single.dataset


Saving the dataset (0/1 shards):   0%|          | 0/3839 [00:00<?, ? examples/s]

validation:   0%|          | 0/1045 [00:00<?, ?it/s]

saved validation to dir: data/CLOTHO_v2.1/clotho_hf_dataset/clotho_validation_imagebind_single.dataset


Saving the dataset (0/1 shards):   0%|          | 0/1045 [00:00<?, ? examples/s]

In [3]:
# pip install transformers

In [4]:
import torch
from torch.utils.data import DataLoader, Dataset
import datasets
from tqdm import tqdm

# Описываем модельку

In [5]:
# Важно!
# Эта ячейка во время автогрейдинга будет копироваться в отдельный файл
# Пожалуйста, не меняйте публичный интерфейс для класса Llaaa
# Не рекомендуется прописывать в эту ячейку новые зависимости - их может не быть
# на раннере в окружении автогрейдера.

# Места, которые нужно дополнить кодом начинаются на TODO


import os
import warnings

from typing import Union, Optional, Any

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import PreTrainedModel, PretrainedConfig, LlamaForCausalLM, LlamaConfig, AutoTokenizer

import logging
logger = logging.getLogger(__name__)


class TinyLlamaLanguageModel(LlamaForCausalLM):

    def encode_text(self, intput_ids=None):
        return self.model.embed_tokens(intput_ids)


def llama_lm():
    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v0.4"

    model = TinyLlamaLanguageModel.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    return model, tokenizer


class LlaaaConfig(PretrainedConfig): # изучить PretrainedConfig
    def __init__(
        self,
        lm_hidden_size=2048,     # should match language model hidden size
        audio_hidden_size=1024,  # should match audio encoder hidden size
        modality_tokens=64,
        **kwargs,
    ) -> None:
        if len(kwargs) > 0:
            warnings.warn(
                f"Unknown {len(kwargs)} keywords arguments for {self.__class__.__name__}. (found {tuple(kwargs.keys())})"
            )

        super().__init__()

        self.lm_hidden_size = lm_hidden_size
        self.audio_hidden_size = audio_hidden_size
        self.modality_tokens = modality_tokens
        
# class ResidualMLPProjector(nn.Module):
#     def __init__(self, in_channels, hidden_channels, out_channels):
#         super().__init__()
#         # self.lin_proj = nn.Linear(in_channels, out_channels)
#         self.proj = nn.Sequential(
#             nn.Linear(in_channels, hidden_channels),
#             nn.GELU(),
#             nn.Linear(hidden_channels, out_channels)
#         )  
#         self.act = nn.GELU()
        
#     def forward(self, x):
#         return self.proj(x)# + self.lin_proj(x)
class ResidualMLPProjector(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.lin_proj = nn.Linear(in_channels, out_channels)
        self.proj = nn.Sequential(
            nn.Linear(in_channels, hidden_channels),
            nn.GELU(),
            nn.Linear(hidden_channels, out_channels)
        )  
        self.act = nn.GELU()
        
    def forward(self, x):
        return self.proj(x) + self.lin_proj(x)
        
def init_weights(model):
    if isinstance(model, nn.Linear):
        torch.nn.init.normal_(model.weight, mean=0, std=0.02)
        torch.nn.init.constant_(model.bias, 0)    
        
class Llaaa(PreTrainedModel):

    config_class = LlaaaConfig

    start_audio_token_id = 0
    end_audio_token_id = 1

    def __init__(self, lm_model: TinyLlamaLanguageModel, audio_encoder, config: LlaaaConfig):
        super().__init__(config)

        self.lm_model = lm_model
        self.audio_encoder = audio_encoder

        lm_hidden_size = self.config.lm_hidden_size
        audio_hidden_size = self.config.audio_hidden_size

        self.audio_tokens_embeddings = nn.Embedding(2, lm_hidden_size)
        self.projection = ResidualMLPProjector(audio_hidden_size, lm_hidden_size, self.config.modality_tokens*lm_hidden_size)
        #nn.Linear(audio_hidden_size, self.config.modality_tokens*lm_hidden_size)
        
        self.start_audio_token_id_tensor = torch.tensor(self.start_audio_token_id)
        self.end_audio_token_id_tensor = torch.tensor(self.end_audio_token_id)
        
        return

    def reinitialize_weights(self, std=0.02):
        self.projection.apply(init_weights)
        # nn.init.normal_(self.projection.weight, mean=0, std=std)
        # nn.init.constant_(self.projection.bias, 0)
        nn.init.normal_(self.audio_tokens_embeddings.weight, mean=0, std=std)

        return

    def encode_text(self, intput_ids=None):
        return self.lm_model.encode_text(intput_ids)

    def encode_audio(self, audio_melspec_values):
        audio_embeddings = self.audio_encoder.encode_audio(audio_melspec_values)
        audio_embeddings = F.normalize(audio_embeddings, dim=-1)
        return audio_embeddings


    def prepare_audio_inputs(self, input_ids=None, inputs_embeds=None, audio_embeds=None, attention_mask=None):
        """
        В этом методе нужно:
        * "завернуть" эмбэддинги аудио в специальные токены начала и окончания модальности аудио,
        * вычислить проекцию аудио токенов в N токенов пространства LLM,
        * посчитать маску вниминия с учетом добавленных токенов  (подробнее) 
        """

        if input_ids is not None:
            if inputs_embeds is not None:
                logger.info("using inputs_embeds with input_ids the same time! inputs_embeds will be fist and then concatenated with input_ids")
            inputs_embeds = self.encode_text(input_ids)
        
        # audio_embeds_projection expected shape [ bs, seq_len, llama_hidden_dim * self.config.modality_tokens ] 
        audio_embeds_projection = self.projection(audio_embeds).view(-1, self.config.modality_tokens, self.config.lm_hidden_size)

        start_audio_token = self.audio_tokens_embeddings(self.start_audio_token_id_tensor.to(self.device))\
                            .expand(audio_embeds.shape[0], 1, self.config.lm_hidden_size)
        end_audio_token = self.audio_tokens_embeddings(self.end_audio_token_id_tensor.to(self.device))\
                            .expand(audio_embeds.shape[0], 1, self.config.lm_hidden_size)
        audio_embeds_projection = torch.cat([start_audio_token, audio_embeds_projection, end_audio_token], axis=1)
        
        if inputs_embeds is not None:
            inputs_embeds = torch.cat([audio_embeds_projection, inputs_embeds.to(self.device)], axis=1)
        else:
            inputs_embeds = audio_embeds_projection
        if attention_mask is not None:
            audio_attention_mask = torch.ones(audio_embeds_projection.size(0), audio_embeds_projection.size(1)).to(self.device)
            attention_mask = torch.cat([audio_attention_mask, attention_mask], dim=1)


        return {
            "inputs_embeds":  inputs_embeds,
            "attention_mask": attention_mask,
        }

    def forward(self, input_ids=None, inputs_embeds=None, attention_mask=None, **kwargs):
        # audio_embeds  [ bs, seq_len, embed_dim ]
        # inputs_embeds [ bs, seq_len, embed_dim ]
        # подумать!!!
        llama_out = self.lm_model.forward(input_ids=input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask)

        return llama_out

    @classmethod
    def from_pretrained(cls, lm_model, audio_encoder, model_id: str):
        # Чтобы сохранились только модули проекции, надо переопределить
        # методы сохранения и загрузки модели
        # Иначе сохранять и загружать мы будем и аудио энкодер, и LLM
        # Этого мы не хотим, потому что так сохраненная модель будет
        # намного больше занимать места на диске и просто это неэффективно.
        # Слой проекции должен отвечать только за проекцию.
        # Если добавляете новые модули в класс, не забудьте дописать их в
        # метод сохранения и загрузки.

        llaaa_config = LlaaaConfig.from_pretrained(model_id)
        llaaa_model = cls(lm_model, audio_encoder, llaaa_config)

        projection_file_name = "projection.pt"
        audio_tokens_embeddings_file_name = "audio_tokens_embeddings.pt"

        llaaa_model.projection.load_state_dict(torch.load(projection_file_name))
        llaaa_model.audio_tokens_embeddings.load_state_dict(torch.load(audio_tokens_embeddings_file_name))

        return llaaa_model

    def save_pretrained(self, save_directory: str):
        if os.path.isfile(save_directory):
            raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")

        os.makedirs(save_directory, exist_ok=True)

        projection_file_name = "projection.pt"
        audio_tokens_embeddings_file_name = "audio_tokens_embeddings.pt"

        projection_file_name = os.path.join(save_directory, projection_file_name)
        audio_tokens_embeddings_file_name = os.path.join(save_directory, audio_tokens_embeddings_file_name)

        torch.save(self.projection.state_dict(), projection_file_name)
        torch.save(self.audio_tokens_embeddings.state_dict(), audio_tokens_embeddings_file_name)
        self.config.save_pretrained(save_directory)

        #
        # TODO Сохраняем веса
        # и конфиг из self.config (см https://huggingface.co/docs/transformers/v4.29.1/en/main_classes/configuration#transformers.PretrainedConfig)
        #

        return

# Обучаем модельку

In [6]:
# pip install wandb

In [7]:
# wandb - менее глюченая альтернатива tensorboard
# и имхо даже более удобная

import wandb
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msmir_kko[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [8]:
class TrainConfig:
    log_level = "DEBUG"

    # Training
    num_epochs = 5
    train_batch_size = 8
    val_batch_size = 1
    log_grad_norm = True
    learning_rate = 1e-4
    gradient_accumulation_steps = 1

    evaluate_every_epoch_mod = 4
    save_model_every_epoch_mod = 1

    # Model
    llaaa_from_pretrained = None
    modality_tokens = 32
    llm_train_lora = False
    llm_lora_from_pretrained = None

    # Data
    few_train_samples = None
    few_val_samples = 100
    dataloader_num_workers = 0

    train_dataset_path = "data/CLOTHO_v2.1/clotho_hf_dataset/clotho_development_imagebind_single.dataset/"
    audio_embeds_train_prefix = "data/CLOTHO_v2.1/clotho_audio_embeds_processed_imagebind_single/development/"

    val_dataset_path = "data/CLOTHO_v2.1/clotho_hf_dataset/clotho_validation_imagebind_single.dataset/"
    audio_embeds_val_prefix = "data/CLOTHO_v2.1/clotho_audio_embeds_processed_imagebind_single/validation/"


In [9]:
# pip install accelerate

In [10]:
import yaml
import argparse

import pathlib
import random
import torch
import torch.nn as nn

import logging
import evaluate

import datasets
from transformers.generation import GenerationConfig

from torch.optim import Adam
from torch.utils.data import DataLoader

import transformers

from tqdm.auto import tqdm

import wandb
from wandb import sdk as wandb_sdk

import accelerate


In [11]:
# валидация

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s')


def prepare_model_inputs_from_batch(model: Llaaa, batch):
    if 'audio_embeds_last_hidden_state' in batch:
        audio_embeds_last_hidden_state = batch['audio_embeds_last_hidden_state'].to(model.device)
    else:
        audio_melspec_values = batch['pixel_values'].to(model.device)
        audio_embeds_last_hidden_state = model.encode_audio(audio_melspec_values)

    inputs_embeds = model.encode_text(batch['input_ids'].to(model.device))

    model_inputs_with_audio = model.prepare_audio_inputs(
        inputs_embeds=inputs_embeds,
        attention_mask=batch['attention_mask'].to(model.device),
        audio_embeds=audio_embeds_last_hidden_state,
    )

    return {
        "inputs_embeds":  model_inputs_with_audio["inputs_embeds"],
        "attention_mask": model_inputs_with_audio["attention_mask"],
    }


def get_audio_embeds_last_hidden_state(model, batch):
    if 'audio_embeds_last_hidden_state' in batch:
        audio_embeds_last_hidden_state = batch['audio_embeds_last_hidden_state'].to(model.device)
    else:
        audio_melspec_values = batch['pixel_values'].to(model.device)
        audio_embeds_last_hidden_state = model.encode_audio(audio_melspec_values)
    return audio_embeds_last_hidden_state # зачем?


def save_model(train_config: TrainConfig, model: Llaaa, path: pathlib.Path):
    path.mkdir(parents=True, exist_ok=True)
    logger.info(f"save model to {path}")

    model.save_pretrained(path)
    if train_config.llm_train_lora:
        model.lm_model.save_pretrained(path.joinpath("lora_adapter"))

    return


@torch.no_grad()
def compute_validation_metrics(generations, target_generations, captioning_metrics=None):
    validation_metrics = {}
    if captioning_metrics is not None:
        evaluate_bleu_results = captioning_metrics.compute(predictions=generations, references=target_generations)
        logger.info(f"evaluate_bleu_results {evaluate_bleu_results}")

        validation_metrics["validation/evaluate_bleu"] = evaluate_bleu_results['bleu'] * 100
        validation_metrics["validation/evaluate_rouge1"] = evaluate_bleu_results['rouge1']
        validation_metrics["validation/evaluate_rouge2"] = evaluate_bleu_results['rouge2']
        validation_metrics["validation/evaluate_rougeL"] = evaluate_bleu_results['rougeL']
        validation_metrics["validation/evaluate_rougeLsum"] = evaluate_bleu_results['rougeLsum']
        validation_metrics["validation/evaluate_meteor"] = evaluate_bleu_results['meteor']

    return validation_metrics


@torch.no_grad()
def val_loop(model: Llaaa, tokenizer, val_dataloader: DataLoader, epoch, captioning_metrics=None):

    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    logger.info(f"go validation {epoch}")

    generations = []
    target_generations = []

    gen_params = {
        "do_sample": False,
        "early_stopping": True,
        "num_beams": 3,
        "repetition_penalty": 2.5,
        "remove_invalid_values": True,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.eos_token_id,
        "forced_eos_token_id": tokenizer.eos_token_id,
        "use_cache": True,
        "no_repeat_ngram_size": 4,
        "num_return_sequences": 1,
    }
    genconfig = GenerationConfig.from_model_config(model.lm_model.config)

    model.eval()
    for batch in tqdm(val_dataloader):

        batch_input_ids = batch['input_ids'].to(model.device)
        caption_legth = batch_input_ids.shape[1]

        audio_embeds_last_hidden_state = get_audio_embeds_last_hidden_state(model, batch)

        model_inputs_with_only_audio = model.prepare_audio_inputs(
            audio_embeds=audio_embeds_last_hidden_state,
        )

        genconfig.max_length = caption_legth

        all_generation_params = {
            'generation_config': genconfig,
            'max_new_tokens': caption_legth,
            **model_inputs_with_only_audio,
            **gen_params,
        }

        model_generation = model.lm_model.generate(**all_generation_params)
        generated_sentences = tokenizer.batch_decode(model_generation, skip_special_tokens=True)
        for sentence in generated_sentences:
            sentence: str
            sentence = sentence.replace("\n", " ")
            generations.append(sentence)

        one_audio_references = []
        all_references = tokenizer.batch_decode(batch['all_input_ids'], skip_special_tokens=True)
        assert len(all_references) % 5 == 0, f'len(all_references) {len(all_references)}'
        for i, reference in enumerate(all_references):
            reference: str
            reference = reference.replace("\n", " ")
            one_audio_references.append(reference)
            if (i+1) % 5 == 0:
                target_generations.append(one_audio_references)
                one_audio_references = []

    assert len(generations) > 0, f"len(generations)={len(generations)}"
    assert len(target_generations) == len(generations), f"len(target_generations) == len(generations): {len(target_generations)} == {len(generations)}"

    validation_metrics = compute_validation_metrics(generations, target_generations, captioning_metrics=captioning_metrics)

    return validation_metrics


In [12]:
# Обучение

def train_loop(accelerator: accelerate.Accelerator, model: Llaaa, optimizer, train_dataloader: DataLoader, epoch, criterion, last_validation_bleu=0.0, train_config=None):
    model.train()
    progress_bar = tqdm(range(len(train_dataloader)), desc=f'Epoch {epoch}')
    for batch in train_dataloader:
        with accelerator.accumulate(model):
            model_inputs_with_audio = prepare_model_inputs_from_batch(model, batch)
            model_prediction = model.forward(**model_inputs_with_audio)

            # model_prediction
            #
            # TODO вычисляем лосс!
            model_prediction = model_prediction.logits[:, model.config.modality_tokens + 2 + 1:, :]
            loss = criterion(model_prediction[:, :-1, :].contiguous().view(-1, model_prediction.shape[-1]), 
                             batch['input_ids'][:, 2:].contiguous().view(-1).to(model.device))
            # Не забываем, что лосс мы вычисляем для задачи next token prediction
            # Как было в домашке по трансформерам -- не забываем сделать сдвиг токена
            # Тк output'ом для текущего токена должен быть следующий токен.
            #
            # TODO И затем наше любимое - вычисляем градиент, шаг оптимизации, зануляем градиент)
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()

            progress_bar.update(1)
            progress_bar.set_description(f'Epoch={epoch} Loss={loss.item():.3f} BLEU={last_validation_bleu:.3f}')

            step_metrics = {"train_loss": loss.item(), "epoch": epoch}
            if train_config.log_grad_norm:
                for name, parameter in model.projection.named_parameters():
                    if parameter.grad is not None:
                        parameter_grad_norm = parameter.grad.norm(2).item()
                    else:
                        parameter_grad_norm = 0.0
                    step_metrics[f'grad_norm_{name}'] = parameter_grad_norm

            metric_logger.log(step_metrics)
            # end accelerator accumulation
        # end train loop

    return


def train(
        model: Llaaa,
        tokenizer: transformers.AutoTokenizer,
        train_dataloader: DataLoader,
        val_dataloader: DataLoader,
        train_config: TrainConfig,
        metric_logger: wandb_sdk.wandb_run.Run,
        device_placement=True,
        ):

    trainable_parameters = list(model.audio_tokens_embeddings.parameters()) + list(model.projection.parameters())
    optimizer = Adam(trainable_parameters, lr=train_config.learning_rate)

    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    accelerator = accelerate.Accelerator(device_placement=device_placement)
    accelerator.gradient_accumulation_steps = train_config.gradient_accumulation_steps
    model, optimizer, train_dataloader, val_dataloader = accelerator.prepare(model, optimizer, train_dataloader, val_dataloader)

    captioning_metrics = evaluate.combine(
        [
            evaluate.load("bleu", keep_in_memory=True),
            evaluate.load("rouge", keep_in_memory=True),
            evaluate.load("meteor", keep_in_memory=True),
        ]
    )

    best_validation_bleu = 0.0
    last_validation_bleu = 0.0
    for epoch in range(train_config.num_epochs):

        train_loop(accelerator, model, optimizer, train_dataloader, epoch=epoch, criterion=criterion, last_validation_bleu=last_validation_bleu, train_config=train_config)

        if epoch % train_config.evaluate_every_epoch_mod == 0:
            validation_metrics = val_loop(model, tokenizer, val_dataloader, epoch=epoch, captioning_metrics=captioning_metrics)
            logger.info(f"validation metrics {validation_metrics}")

            last_validation_bleu = validation_metrics['validation/evaluate_bleu']
            metric_logger.log(validation_metrics)

            if last_validation_bleu > best_validation_bleu:
                best_validation_bleu = last_validation_bleu

                base_path_for_best_model = pathlib.Path(f"data/models/{metric_logger.name}/best/")
                save_model(train_config=train_config, model=model, path=base_path_for_best_model)

        if epoch % train_config.save_model_every_epoch_mod == 0:
            base_path_for_model = pathlib.Path(f"data/models/{metric_logger.name}/last/")
            save_model(train_config=train_config, model=model, path=base_path_for_model)

    base_path_for_model = pathlib.Path(f"data/models/{metric_logger.name}/last/")
    save_model(train_config=train_config, model=model, path=base_path_for_model)

def freeze_model(model):
    for p in model.parameters():
        p.requires_grad = False
    return


In [13]:
def data_preloader(audio_embeds_path_prefix):

    def _data_preloader(items):
        result = {
            "audio_embeds_last_hidden_state": [],
        }

        for k in items.keys():
            k: str
            if k.startswith('caption_'):
                result[k] = items[k]

        for audio_embeds_path in items["audio_embeds_last_hidden_state_file_name"]:
            audio_embeds_full_path = pathlib.Path(audio_embeds_path_prefix).joinpath(audio_embeds_path)
            audio_embeds = torch.load(audio_embeds_full_path, map_location='cpu')
            result["audio_embeds_last_hidden_state"].append(audio_embeds)

        return result

    return _data_preloader


def get_collate_fn(tokenizer, validation=False):
    def collate_fn(items):
        result = dict()
        # random select caption
        current_caption_i = random.randint(1, 5)
        tokenizer_input = [item[f'caption_{current_caption_i}'] for item in items]
        tokenized_caption = tokenizer(tokenizer_input, padding=True)
        result['input_ids'] = torch.tensor(tokenized_caption['input_ids'])
        result['attention_mask'] = torch.tensor(tokenized_caption['attention_mask'])
        # result['pixel_values'] = torch.cat([x['pixel_values'] for x in items], dim=0)
        result['audio_embeds_last_hidden_state'] = torch.cat([x['audio_embeds_last_hidden_state'] for x in items], dim=0)

        if validation:
            all_captions = []
            for item in items:
                for current_caption_i in range(1, 6):
                    all_captions.append(item[f'caption_{current_caption_i}'])

            tokenized_caption = tokenizer(all_captions, padding=True)
            result['all_input_ids'] = torch.tensor(tokenized_caption['input_ids'])
            result['all_attention_mask'] = torch.tensor(tokenized_caption['attention_mask'])
        return result
    return collate_fn


def get_train_dataloader(
        train_config: TrainConfig, llaaa: Llaaa, tokenizer,
        train_dataset_path,
        audio_embeds_train_prefix):

    audio_captions_dataset_train: datasets.Dataset = datasets.load_from_disk(train_dataset_path)
    if train_config.few_train_samples is not None:
        audio_captions_dataset_train = audio_captions_dataset_train.select(range(train_config.few_train_samples))

    audio_captions_dataset_train.set_transform(data_preloader(audio_embeds_train_prefix))
    return DataLoader(audio_captions_dataset_train, collate_fn=get_collate_fn(tokenizer), batch_size=train_config.train_batch_size, num_workers=train_config.dataloader_num_workers, shuffle=True, drop_last=True)


def get_val_dataloader(
        train_config: TrainConfig, llaaa: Llaaa, tokenizer,
        val_dataset_path,
        audio_embeds_val_prefix):

    audio_captions_dataset_val: datasets.Dataset = datasets.load_from_disk(val_dataset_path)
    if train_config.few_val_samples is not None:
        audio_captions_dataset_val = audio_captions_dataset_val.select(range(train_config.few_val_samples))

    audio_captions_dataset_val.set_transform(data_preloader(audio_embeds_val_prefix))

    return DataLoader(audio_captions_dataset_val,
                      collate_fn=get_collate_fn(tokenizer, validation=True),
                      batch_size=train_config.val_batch_size)



# Предобучаем проекцию в адаптере

In [14]:
# 1st iter (5 epoch)
# {'validation/evaluate_bleu': 6.417021607421673,
#  'validation/evaluate_rouge1': 0.33109226810102377,
#  'validation/evaluate_rouge2': 0.11095769776340605,
#  'validation/evaluate_rougeL': 0.2752364273184304,
#  'validation/evaluate_rougeLsum': 0.27553070368757215,
#  'validation/evaluate_meteor': 0.32073160908059584}
# 2nd iter (10 epoch)
# {'validation/evaluate_bleu': 6.40709213312877,
#  'validation/evaluate_rouge1': 0.32817260743542964,
#  'validation/evaluate_rouge2': 0.11000342343713507,
#  'validation/evaluate_rougeL': 0.27342381749238054,
#  'validation/evaluate_rougeLsum': 0.2733547105500546,
#  'validation/evaluate_meteor': 0.32166420512484795}
# 2nd iter (6 epoch + less leraning rate)
# 4th iter (покрутили MLP)
# 5th iter (добавили skip-connection)
# {'validation/evaluate_bleu': 7.77831788507012,
#  'validation/evaluate_rouge1': 0.34466447200431,
#  'validation/evaluate_rouge2': 0.12496668780912443,
#  'validation/evaluate_rougeL': 0.29159389419317716,
#  'validation/evaluate_rougeLsum': 0.2913731156012332,
#  'validation/evaluate_meteor': 0.3413907628246414}

In [15]:
# TODO в этой ячейке я кое-что убрал, но не скажу что)
# Подсказка про это была в текстовом описании в начале ноутбука.
# И скорее всего, без этого, вам не хватит видео-памяти чтобы обучить адаптер.
# Не нужно хранить градиенты для параметров `LLM`. Скорее всего, у вас не хватит для этого памяти.
# И это не нужно, тк мы не оптимизируем на этом этапе веса `LLM`.
train_config = TrainConfig()

log_level = logging.getLevelName(train_config.log_level)
logger.info("loglevel %s", train_config.log_level)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

logger.info("load audio encoder")
audio_encoder = DummyAudioEncoder()
audio_encoder.to(device)

logger.info("load language model")

lm_model, tokenizer = llama_lm()
freeze_model(lm_model) # TODO
lm_model.to(device)

mmllm_config = LlaaaConfig()
model = Llaaa(lm_model=lm_model, audio_encoder=audio_encoder, config=mmllm_config)

logger.info("reinitialize weights")
model.reinitialize_weights()

trainable_parameters_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_parameters_count = sum(p.numel() for p in model.parameters())
logger.info(f"trainable model parameters: {trainable_parameters_count}")
logger.info(f"total model parameters: {total_parameters_count}")
logger.info(f"trainable model llm parameters: {sum(p.numel() for p in model.lm_model.parameters() if p.requires_grad)}")
logger.info(f"trainable model audio encoder parameters: {sum(p.numel() for p in model.audio_encoder.parameters() if p.requires_grad)}")

logger.info("load train dataloader")
train_dataloader = get_train_dataloader(
    train_config, model, tokenizer,
    train_dataset_path=train_config.train_dataset_path,
    audio_embeds_train_prefix=train_config.audio_embeds_train_prefix,
)
logger.info("load val dataloader")
val_dataloader = get_val_dataloader(
    train_config, model, tokenizer,
    val_dataset_path=train_config.val_dataset_path,
    audio_embeds_val_prefix=train_config.audio_embeds_val_prefix,
)

logger.info("run training")

with wandb.init(project="hw-multimodal-llm") as metric_logger:
    train(
        model=model,
        tokenizer=tokenizer,
        metric_logger=metric_logger,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        train_config=train_config,
    )



2024-06-23 14:08:33,292 INFO loglevel DEBUG
2024-06-23 14:08:33,292 INFO load audio encoder
2024-06-23 14:08:33,293 INFO load language model
2024-06-23 14:08:33,296 DEBUG Starting new HTTPS connection (1): huggingface.co:443
2024-06-23 14:08:33,551 DEBUG https://huggingface.co:443 "HEAD /TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/config.json HTTP/1.1" 200 0
2024-06-23 14:08:33,757 DEBUG https://huggingface.co:443 "HEAD /TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/config.json HTTP/1.1" 200 0
2024-06-23 14:08:37,478 DEBUG https://huggingface.co:443 "HEAD /TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/generation_config.json HTTP/1.1" 200 0
2024-06-23 14:08:38,171 DEBUG https://huggingface.co:443 "HEAD /TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
2024-06-23 14:08:42,122 INFO reinitialize weights
2024-06-23 14:08:44,970 INFO trainable model parameters: 405018624
2024-06-23 14:08:44,971 INFO total model parameters: 1505079296
2024-06-23 14

2024-06-23 14:08:49,094 DEBUG Starting new HTTPS connection (1): s3.amazonaws.com:443
2024-06-23 14:08:49,494 DEBUG https://s3.amazonaws.com:443 "HEAD /datasets.huggingface.co/datasets/metrics/evaluate-metric/bleu/evaluate-metric/bleu.py HTTP/1.1" 404 0
2024-06-23 14:08:49,510 DEBUG Starting new HTTPS connection (1): huggingface.co:443
2024-06-23 14:08:49,710 DEBUG https://huggingface.co:443 "HEAD /spaces/evaluate-metric/bleu/resolve/v0.4.2/bleu.py HTTP/1.1" 404 0
2024-06-23 14:08:49,714 DEBUG Starting new HTTPS connection (1): huggingface.co:443
2024-06-23 14:08:50,027 DEBUG https://huggingface.co:443 "HEAD /spaces/evaluate-metric/bleu/resolve/main/bleu.py HTTP/1.1" 200 0
2024-06-23 14:08:50,053 DEBUG Starting new HTTPS connection (1): github.com:443
2024-06-23 14:08:50,352 DEBUG https://github.com:443 "HEAD /tensorflow/nmt/raw/master/nmt/scripts/bleu.py HTTP/1.1" 302 0
2024-06-23 14:08:50,364 DEBUG Starting new HTTPS connection (1): raw.githubusercontent.com:443
2024-06-23 14:08:50,6

Epoch 0:   0%|          | 0/479 [00:00<?, ?it/s]

  attn_output = torch.nn.functional.scaled_dot_product_attention(
2024-06-23 14:12:50,234 INFO go validation 0


  0%|          | 0/100 [00:00<?, ?it/s]

2024-06-23 14:14:11,562 INFO Using default tokenizer.
2024-06-23 14:14:13,677 INFO evaluate_bleu_results {'bleu': 0.08146237253385034, 'precisions': [0.46244477172312226, 0.12003179650238474, 0.046632124352331605, 0.017013232514177693], 'brevity_penalty': 1.0, 'length_ratio': 1.3235867446393763, 'translation_length': 1358, 'reference_length': 1026, 'rouge1': 0.33741654690026357, 'rouge2': 0.11199968503994058, 'rougeL': 0.2781781588743347, 'rougeLsum': 0.27817155944114846, 'meteor': 0.291387277906247}
2024-06-23 14:14:13,678 INFO validation metrics {'validation/evaluate_bleu': 8.146237253385035, 'validation/evaluate_rouge1': 0.33741654690026357, 'validation/evaluate_rouge2': 0.11199968503994058, 'validation/evaluate_rougeL': 0.2781781588743347, 'validation/evaluate_rougeLsum': 0.27817155944114846, 'validation/evaluate_meteor': 0.291387277906247}
2024-06-23 14:14:13,680 INFO save model to data\models\expert-firebrand-14\best
2024-06-23 14:14:16,000 INFO save model to data\models\expert-f

Epoch 1:   0%|          | 0/479 [00:00<?, ?it/s]

2024-06-23 14:18:16,211 INFO save model to data\models\expert-firebrand-14\last


Epoch 2:   0%|          | 0/479 [00:00<?, ?it/s]

2024-06-23 14:22:18,273 INFO save model to data\models\expert-firebrand-14\last


Epoch 3:   0%|          | 0/479 [00:00<?, ?it/s]

2024-06-23 14:26:17,263 INFO save model to data\models\expert-firebrand-14\last


Epoch 4:   0%|          | 0/479 [00:00<?, ?it/s]

2024-06-23 14:30:17,584 INFO go validation 4


  0%|          | 0/100 [00:00<?, ?it/s]

2024-06-23 14:31:36,620 INFO Using default tokenizer.
2024-06-23 14:31:37,376 INFO evaluate_bleu_results {'bleu': 0.10243499306253456, 'precisions': [0.4823529411764706, 0.13531914893617022, 0.06325581395348837, 0.02666666666666667], 'brevity_penalty': 1.0, 'length_ratio': 1.2426900584795322, 'translation_length': 1275, 'reference_length': 1026, 'rouge1': 0.36547326897435495, 'rouge2': 0.13485497090983375, 'rougeL': 0.30783176273814994, 'rougeLsum': 0.3063263358411153, 'meteor': 0.29316970322128555}
2024-06-23 14:31:37,379 INFO validation metrics {'validation/evaluate_bleu': 10.243499306253456, 'validation/evaluate_rouge1': 0.36547326897435495, 'validation/evaluate_rouge2': 0.13485497090983375, 'validation/evaluate_rougeL': 0.30783176273814994, 'validation/evaluate_rougeLsum': 0.3063263358411153, 'validation/evaluate_meteor': 0.29316970322128555}
2024-06-23 14:31:37,380 INFO save model to data\models\expert-firebrand-14\best
2024-06-23 14:31:39,895 INFO save model to data\models\expert

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▁▁▁▁▃▃▃▃▃▃▃▃▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆████████
grad_norm_lin_proj.bias,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
grad_norm_lin_proj.weight,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
grad_norm_proj.0.bias,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
grad_norm_proj.0.weight,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
grad_norm_proj.2.bias,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
grad_norm_proj.2.weight,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,█▇▃▅▅▄▇▄▄▅▂▅▄▇▆▄▄▃▅▅▆▃▇▅▄▂▅▇▄▄▃▂▂▃▅▁▄▂▆▃
validation/evaluate_bleu,▁█
validation/evaluate_meteor,▁█

0,1
epoch,4.0
grad_norm_lin_proj.bias,0.0
grad_norm_lin_proj.weight,0.0
grad_norm_proj.0.bias,0.0
grad_norm_proj.0.weight,0.0
grad_norm_proj.2.bias,0.0
grad_norm_proj.2.weight,0.0
train_loss,2.43445
validation/evaluate_bleu,10.2435
validation/evaluate_meteor,0.29317


2024-06-23 14:31:49,935 DEBUG Starting new HTTPS connection (1): o151352.ingest.sentry.io:443
2024-06-23 14:31:50,369 DEBUG https://o151352.ingest.sentry.io:443 "POST /api/4504800232407040/envelope/ HTTP/1.1" 200 0


In [16]:
model.save_pretrained('C:\\Users\\arssm\\0Jupyter_Life\\DL\\LLM')

## Посчитаем метрики на всем валидационном датасете

In [17]:
captioning_metrics = evaluate.combine(
        [
            evaluate.load("bleu", keep_in_memory=True),
            evaluate.load("rouge", keep_in_memory=True),
            evaluate.load("meteor", keep_in_memory=True),
        ]
    )

2024-06-23 14:32:16,833 DEBUG Starting new HTTPS connection (1): s3.amazonaws.com:443
2024-06-23 14:32:17,269 DEBUG https://s3.amazonaws.com:443 "HEAD /datasets.huggingface.co/datasets/metrics/evaluate-metric/bleu/evaluate-metric/bleu.py HTTP/1.1" 404 0
2024-06-23 14:32:17,279 DEBUG Starting new HTTPS connection (1): huggingface.co:443
2024-06-23 14:32:17,476 DEBUG https://huggingface.co:443 "HEAD /spaces/evaluate-metric/bleu/resolve/v0.4.2/bleu.py HTTP/1.1" 404 0
2024-06-23 14:32:17,483 DEBUG Starting new HTTPS connection (1): huggingface.co:443
2024-06-23 14:32:17,717 DEBUG https://huggingface.co:443 "HEAD /spaces/evaluate-metric/bleu/resolve/main/bleu.py HTTP/1.1" 200 0
2024-06-23 14:32:17,737 DEBUG Starting new HTTPS connection (1): github.com:443
2024-06-23 14:32:18,006 DEBUG https://github.com:443 "HEAD /tensorflow/nmt/raw/master/nmt/scripts/bleu.py HTTP/1.1" 302 0
2024-06-23 14:32:18,013 DEBUG Starting new HTTPS connection (1): raw.githubusercontent.com:443
2024-06-23 14:32:18,2

In [18]:
train_config.few_val_samples = None
train_config.val_batch_size = 8
full_val_dataloader = get_val_dataloader(
    train_config, model, tokenizer,
    val_dataset_path=train_config.val_dataset_path,
    audio_embeds_val_prefix=train_config.audio_embeds_val_prefix,
)


validation_metrics = val_loop(model, tokenizer, full_val_dataloader, epoch=-1, captioning_metrics=captioning_metrics)
validation_metrics

2024-06-23 14:32:24,372 INFO go validation -1


  0%|          | 0/131 [00:00<?, ?it/s]

2024-06-23 14:37:19,325 INFO Using default tokenizer.
2024-06-23 14:37:25,824 INFO evaluate_bleu_results {'bleu': 0.0777831788507012, 'precisions': [0.4381200465554509, 0.11566066596070126, 0.0443176831943835, 0.016299973168768446], 'brevity_penalty': 1.0, 'length_ratio': 1.7429482225656878, 'translation_length': 18043, 'reference_length': 10352, 'rouge1': 0.34466447200431, 'rouge2': 0.12496668780912443, 'rougeL': 0.29159389419317716, 'rougeLsum': 0.2913731156012332, 'meteor': 0.3413907628246414}


{'validation/evaluate_bleu': 7.77831788507012,
 'validation/evaluate_rouge1': 0.34466447200431,
 'validation/evaluate_rouge2': 0.12496668780912443,
 'validation/evaluate_rougeL': 0.29159389419317716,
 'validation/evaluate_rougeLsum': 0.2913731156012332,
 'validation/evaluate_meteor': 0.3413907628246414}

In [None]:
# Заархивируем данные
# В репозиторий в этот раз я предлагаю закоммитить веса через [git-lfs](https://git-lfs.com/)
# На стороне гитхаба стоит ограничение на 50MB. Если грузить без lfs
# Вот тут гуглится туториал по добавлению файликов в lfs - https://josh-ops.com/posts/add-files-to-git-lfs/

# Альтернативный вариант разбить файлик на 10+ кусков-архивов
# но это не очень удобно - см `zip -rs 40M ...`

# !cd data/models/lunar-dragon-12/ && mv best llaaa_pretrained  && zip llaaa_pretrained.zip llaaa_pretrained/

  adding: llaaa_pretrained/ (stored 0%)
  adding: llaaa_pretrained/audio_tokens_embeddings.pt (deflated 11%)
  adding: llaaa_pretrained/projection.pt (deflated 7%)
  adding: llaaa_pretrained/config.json (deflated 21%)


In [17]:
# import os
# import shutil
# import zipfile

# # Переход в директорию (не обязательно)
# os.chdir('data/models/lunar-dragon-12/')

# # Переименование директории
# src = 'best'
# dst = 'llaaa_pretrained'
# if os.path.exists(src):
#     shutil.move(src, dst)

# # Архивирование директории
# shutil.make_archive('llaaa_pretrained', 'zip', dst)

# # Проверка результата
# print("Переименование и архивирование завершено.")


Переименование и архивирование завершено.


> **Важно!** Дальнейшие действия не оценятся автогрейдером. Если справились с первичным обучениеп адаптера, загрузите решение на гитхаб. Потом, если есть желание, можете попробовать получить бонусные баллы.



# Файнтюним LoRA адаптер LLM'ки вместе с дообучением нашего адаптера для аудио модальности

> **Важно!** Дальнейшие действия не оценятся автогрейдером. Если справились с первичным обучениеп адаптера, загрузите решение на гитхаб. Потом, если есть желание, можете попробовать получить бонусные баллы.


`peft` -- это библиотека от HF для эффктивного файнтюна LLM.

Чаще всего используется метод [LoRA](https://huggingface.co/docs/peft/main/en/conceptual_guides/lora) для файнтюна LLM -- настоятельно рекомендую почитать доку про этот метод в HF.

Хотя есть и другие методы:
* [Prompt-based methods](https://huggingface.co/docs/peft/main/en/task_guides/prompt_based_methods)
* [IA3](https://huggingface.co/docs/peft/main/en/task_guides/ia3)



In [28]:
from peft import PeftModel, LoraConfig
from peft import PeftConfig
import peft

def peft_llama_lm(from_pretrained=None):
    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v0.4"
    lm_model = TinyLlamaLanguageModel.from_pretrained(model_name)

    if from_pretrained is None:
        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM"
        )
        lm_model = PeftTinyLlamaLanguageModel(lm_model, lora_config)
    else:
        lm_model = PeftTinyLlamaLanguageModel.from_pretrained(lm_model, from_pretrained)

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    return lm_model, tokenizer




class PeftTinyLlamaLanguageModel(PeftModel):

    def encode_text(self, intput_ids=None):
        #          .peft.llama_model
        return self.model.model.embed_tokens(intput_ids)

    @classmethod
    def from_pretrained(
        cls,
        model: PreTrainedModel,
        model_id: Union[str, os.PathLike],
        adapter_name: str = "default",
        is_trainable: bool = False,
        config: Optional[PeftConfig] = None,
        **kwargs: Any,
    ):
        from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING

        # load the config
        if config is None:
            config = PEFT_TYPE_TO_CONFIG_MAPPING[
                PeftConfig._get_peft_type(
                    model_id,
                    subfolder=kwargs.get("subfolder", None),
                    revision=kwargs.get("revision", None),
                    cache_dir=kwargs.get("cache_dir", None),
                    use_auth_token=kwargs.get("use_auth_token", None),
                )
            ].from_pretrained(model_id, **kwargs)
        elif isinstance(config, PeftConfig):
            config.inference_mode = not is_trainable
        else:
            raise ValueError(f"The input config must be a PeftConfig, got {config.__class__}")

        if (getattr(model, "hf_device_map", None) is not None) and len(
            set(model.hf_device_map.values()).intersection({"cpu", "disk"})
        ) > 0:
            peft.remove_hook_from_submodules(model)

        if config.is_prompt_learning and is_trainable:
            raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.")
        else:
            config.inference_mode = not is_trainable

        model = cls(model, config, adapter_name)
        model.load_adapter(model_id, adapter_name, is_trainable=is_trainable, **kwargs)
        return model

In [29]:
class PeftTrainConfig:
    log_level = "DEBUG"
    # Training
    num_epochs = 15
    train_batch_size = 8
    val_batch_size = 1
    log_grad_norm = True
    learning_rate = 1e-4
    gradient_accumulation_steps = 1

    evaluate_every_epoch_mod = 4
    save_model_every_epoch_mod = 1

    # Model
    llaaa_from_pretrained = "data/models/swept-forest-9/best/"
    modality_tokens = 32
    llm_train_lora = True
    llm_lora_from_pretrained = None

    # Data
    few_train_samples = None
    few_val_samples = 100
    dataloader_num_workers = 0

    train_dataset_path = "data/CLOTHO_v2.1/clotho_hf_dataset/clotho_development_imagebind_single.dataset/"
    audio_embeds_train_prefix = "data/CLOTHO_v2.1/clotho_audio_embeds_processed_imagebind_single/development/"

    val_dataset_path = "data/CLOTHO_v2.1/clotho_hf_dataset/clotho_validation_imagebind_single.dataset/"
    audio_embeds_val_prefix = "data/CLOTHO_v2.1/clotho_audio_embeds_processed_imagebind_single/validation/"


In [22]:
!dir /Q data\models\swept-forest-9\best

 Том в устройстве C не имеет метки.
 Серийный номер тома: 56BB-6321

 Содержимое папки C:\Users\arssm\0Jupyter_Life\DL\LLM\data\models\swept-forest-9\best

23.06.2024  00:51    <DIR>          DESK-ARS\arssm         .
23.06.2024  00:51    <DIR>          DESK-ARS\arssm         ..
23.06.2024  01:04            17 772 DESK-ARS\arssm         audio_tokens_embeddings.pt
23.06.2024  01:04               125 DESK-ARS\arssm         config.json
23.06.2024  01:04       537 396 778 DESK-ARS\arssm         projection.pt
               3 файлов    537 414 675 байт
               2 папок  534 493 265 920 байт свободно


In [30]:
peft_train_config = PeftTrainConfig()

lm_model, tokenizer = peft_llama_lm(from_pretrained=peft_train_config.llm_lora_from_pretrained)

logger.info("load llaaa weights: %s", peft_train_config.llaaa_from_pretrained)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

audio_encoder = DummyAudioEncoder()
audio_encoder.to(device)

freeze_model(audio_encoder)

llaaa_from_pretrained = pathlib.Path(peft_train_config.llaaa_from_pretrained)
model = Llaaa.from_pretrained(lm_model, audio_encoder, llaaa_from_pretrained)

logger.info("load train dataloader")
train_dataloader = get_train_dataloader(
    peft_train_config, model, tokenizer,
    train_dataset_path=peft_train_config.train_dataset_path,
    audio_embeds_train_prefix=peft_train_config.audio_embeds_train_prefix,
)
logger.info("load val dataloader")
val_dataloader = get_val_dataloader(
    peft_train_config, model, tokenizer,
    val_dataset_path=peft_train_config.val_dataset_path,
    audio_embeds_val_prefix=peft_train_config.audio_embeds_val_prefix,
)

logger.info("run training")


# audio_tokens_embeddings = torch.load(llaaa_from_pretrained.joinpath('audio_tokens_embeddings.pt'))
# projection = torch.load(llaaa_from_pretrained.joinpath('projection.pt'))
# model.audio_tokens_embeddings.load_state_dict(audio_tokens_embeddings)
# model.projection.load_state_dict(projection)


2024-06-23 01:23:35,779 DEBUG Resetting dropped connection: huggingface.co
2024-06-23 01:23:36,003 DEBUG https://huggingface.co:443 "HEAD /TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/config.json HTTP/1.1" 200 0
2024-06-23 01:23:36,197 DEBUG https://huggingface.co:443 "HEAD /TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/config.json HTTP/1.1" 200 0
2024-06-23 01:23:42,086 DEBUG https://huggingface.co:443 "HEAD /TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/generation_config.json HTTP/1.1" 200 0
2024-06-23 01:23:43,740 DEBUG https://huggingface.co:443 "HEAD /TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
2024-06-23 01:23:43,851 INFO load llaaa weights: data/models/swept-forest-9/best/
2024-06-23 01:23:45,645 INFO load train dataloader
2024-06-23 01:23:45,744 INFO load val dataloader
2024-06-23 01:23:45,760 INFO run training


In [31]:
with wandb.init(project="hw-multimodal-llm-peft") as metric_logger:
    train(
        model=model,
        tokenizer=tokenizer,
        metric_logger=metric_logger,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        train_config=peft_train_config,
        # device_placement=False,
    )

2024-06-23 01:24:05,811 DEBUG Starting new HTTPS connection (1): s3.amazonaws.com:443
2024-06-23 01:24:06,254 DEBUG https://s3.amazonaws.com:443 "HEAD /datasets.huggingface.co/datasets/metrics/evaluate-metric/bleu/evaluate-metric/bleu.py HTTP/1.1" 404 0
2024-06-23 01:24:06,269 DEBUG Starting new HTTPS connection (1): huggingface.co:443
2024-06-23 01:24:06,472 DEBUG https://huggingface.co:443 "HEAD /spaces/evaluate-metric/bleu/resolve/v0.4.2/bleu.py HTTP/1.1" 404 0
2024-06-23 01:24:06,474 DEBUG Starting new HTTPS connection (1): huggingface.co:443
2024-06-23 01:24:06,660 DEBUG https://huggingface.co:443 "HEAD /spaces/evaluate-metric/bleu/resolve/main/bleu.py HTTP/1.1" 200 0
2024-06-23 01:24:06,673 DEBUG Starting new HTTPS connection (1): github.com:443
2024-06-23 01:24:06,941 DEBUG https://github.com:443 "HEAD /tensorflow/nmt/raw/master/nmt/scripts/bleu.py HTTP/1.1" 302 0
2024-06-23 01:24:06,944 DEBUG Starting new HTTPS connection (1): raw.githubusercontent.com:443
2024-06-23 01:24:07,1

Epoch 0:   0%|          | 0/479 [00:00<?, ?it/s]

2024-06-23 01:28:07,100 INFO go validation 0


  0%|          | 0/100 [00:00<?, ?it/s]

2024-06-23 01:29:41,752 INFO Using default tokenizer.
2024-06-23 01:29:42,494 INFO evaluate_bleu_results {'bleu': 0.05584922636961504, 'precisions': [0.4437299035369775, 0.09003496503496504, 0.028735632183908046, 0.00847457627118644], 'brevity_penalty': 1.0, 'length_ratio': 1.2124756335282652, 'translation_length': 1244, 'reference_length': 1026, 'rouge1': 0.3301201860184502, 'rouge2': 0.09150369168575267, 'rougeL': 0.2688028954758369, 'rougeLsum': 0.26922956923436414, 'meteor': 0.26987091217259257}
2024-06-23 01:29:42,495 INFO validation metrics {'validation/evaluate_bleu': 5.584922636961505, 'validation/evaluate_rouge1': 0.3301201860184502, 'validation/evaluate_rouge2': 0.09150369168575267, 'validation/evaluate_rougeL': 0.2688028954758369, 'validation/evaluate_rougeLsum': 0.26922956923436414, 'validation/evaluate_meteor': 0.26987091217259257}
2024-06-23 01:29:42,498 INFO save model to data\models\honest-dream-1\best
2024-06-23 01:29:43,661 INFO save model to data\models\honest-dream-

Epoch 1:   0%|          | 0/479 [00:00<?, ?it/s]

2024-06-23 01:33:37,399 INFO save model to data\models\honest-dream-1\last


Epoch 2:   0%|          | 0/479 [00:00<?, ?it/s]

2024-06-23 01:37:32,717 INFO save model to data\models\honest-dream-1\last


Epoch 3:   0%|          | 0/479 [00:00<?, ?it/s]

2024-06-23 01:41:27,182 INFO save model to data\models\honest-dream-1\last


Epoch 4:   0%|          | 0/479 [00:00<?, ?it/s]

2024-06-23 01:45:21,211 INFO go validation 4


  0%|          | 0/100 [00:00<?, ?it/s]

2024-06-23 01:46:54,969 INFO Using default tokenizer.
2024-06-23 01:46:55,638 INFO evaluate_bleu_results {'bleu': 0.06729058632974053, 'precisions': [0.4549800796812749, 0.09956709956709957, 0.03601895734597156, 0.012565445026178011], 'brevity_penalty': 1.0, 'length_ratio': 1.2231968810916178, 'translation_length': 1255, 'reference_length': 1026, 'rouge1': 0.3509998605468112, 'rouge2': 0.1040502279401258, 'rougeL': 0.2951558765056362, 'rougeLsum': 0.2959805223813886, 'meteor': 0.28016096154448117}
2024-06-23 01:46:55,640 INFO validation metrics {'validation/evaluate_bleu': 6.729058632974053, 'validation/evaluate_rouge1': 0.3509998605468112, 'validation/evaluate_rouge2': 0.1040502279401258, 'validation/evaluate_rougeL': 0.2951558765056362, 'validation/evaluate_rougeLsum': 0.2959805223813886, 'validation/evaluate_meteor': 0.28016096154448117}
2024-06-23 01:46:55,641 INFO save model to data\models\honest-dream-1\best
2024-06-23 01:46:56,404 INFO save model to data\models\honest-dream-1\la

Epoch 5:   0%|          | 0/479 [00:00<?, ?it/s]

2024-06-23 01:50:52,073 INFO save model to data\models\honest-dream-1\last


Epoch 6:   0%|          | 0/479 [00:00<?, ?it/s]

2024-06-23 01:54:45,827 INFO save model to data\models\honest-dream-1\last


Epoch 7:   0%|          | 0/479 [00:00<?, ?it/s]

2024-06-23 01:58:41,856 INFO save model to data\models\honest-dream-1\last


Epoch 8:   0%|          | 0/479 [00:00<?, ?it/s]

2024-06-23 02:02:35,793 INFO go validation 8


  0%|          | 0/100 [00:00<?, ?it/s]

2024-06-23 02:04:10,747 INFO Using default tokenizer.
2024-06-23 02:04:11,434 INFO evaluate_bleu_results {'bleu': 0.07288516500218728, 'precisions': [0.44251968503937006, 0.10341880341880341, 0.037383177570093455, 0.016494845360824743], 'brevity_penalty': 1.0, 'length_ratio': 1.2378167641325537, 'translation_length': 1270, 'reference_length': 1026, 'rouge1': 0.3384924526186996, 'rouge2': 0.11473489266146347, 'rougeL': 0.30009764744214595, 'rougeLsum': 0.3002459060825564, 'meteor': 0.29204298233883047}
2024-06-23 02:04:11,435 INFO validation metrics {'validation/evaluate_bleu': 7.288516500218728, 'validation/evaluate_rouge1': 0.3384924526186996, 'validation/evaluate_rouge2': 0.11473489266146347, 'validation/evaluate_rougeL': 0.30009764744214595, 'validation/evaluate_rougeLsum': 0.3002459060825564, 'validation/evaluate_meteor': 0.29204298233883047}
2024-06-23 02:04:11,437 INFO save model to data\models\honest-dream-1\best
2024-06-23 02:04:12,204 INFO save model to data\models\honest-drea

Epoch 9:   0%|          | 0/479 [00:00<?, ?it/s]

2024-06-23 02:08:06,214 INFO save model to data\models\honest-dream-1\last


Epoch 10:   0%|          | 0/479 [00:00<?, ?it/s]

2024-06-23 02:12:00,774 INFO save model to data\models\honest-dream-1\last


Epoch 11:   0%|          | 0/479 [00:00<?, ?it/s]

2024-06-23 02:15:55,493 INFO save model to data\models\honest-dream-1\last


Epoch 12:   0%|          | 0/479 [00:00<?, ?it/s]

2024-06-23 02:19:48,779 INFO go validation 12


  0%|          | 0/100 [00:00<?, ?it/s]

2024-06-23 02:21:24,670 INFO Using default tokenizer.
2024-06-23 02:21:25,380 INFO evaluate_bleu_results {'bleu': 0.06550806533177979, 'precisions': [0.4449648711943794, 0.10499576629974598, 0.03515263644773358, 0.011213047910295617], 'brevity_penalty': 1.0, 'length_ratio': 1.2485380116959064, 'translation_length': 1281, 'reference_length': 1026, 'rouge1': 0.3379166568280968, 'rouge2': 0.1039989182843962, 'rougeL': 0.28585280349147557, 'rougeLsum': 0.2866509743696778, 'meteor': 0.2789338735482382}
2024-06-23 02:21:25,382 INFO validation metrics {'validation/evaluate_bleu': 6.550806533177979, 'validation/evaluate_rouge1': 0.3379166568280968, 'validation/evaluate_rouge2': 0.1039989182843962, 'validation/evaluate_rougeL': 0.28585280349147557, 'validation/evaluate_rougeLsum': 0.2866509743696778, 'validation/evaluate_meteor': 0.2789338735482382}
2024-06-23 02:21:25,384 INFO save model to data\models\honest-dream-1\last


Epoch 13:   0%|          | 0/479 [00:00<?, ?it/s]

2024-06-23 02:25:20,018 INFO save model to data\models\honest-dream-1\last


Epoch 14:   0%|          | 0/479 [00:00<?, ?it/s]

2024-06-23 02:29:13,106 INFO save model to data\models\honest-dream-1\last
2024-06-23 02:29:14,574 INFO save model to data\models\honest-dream-1\last


VBox(children=(Label(value='0.001 MB of 0.022 MB uploaded\r'), FloatProgress(value=0.044143563292499466, max=1…

0,1
epoch,▁▁▁▁▁▂▂▂▃▃▃▃▃▃▃▃▄▄▄▅▅▅▅▅▅▅▆▆▆▇▇▇▇▇▇▇▇███
grad_norm_bias,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
grad_norm_weight,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,▇█▄▇▅▆▅▅▆▇▄▄▄▂▄▂▂▃▅▇▅▄▃▅▃▂▅▄▅▂▂▂▁▁▃▂▁▂▃▂
validation/evaluate_bleu,▁▆█▅
validation/evaluate_meteor,▁▄█▄
validation/evaluate_rouge1,▁█▄▄
validation/evaluate_rouge2,▁▅█▅
validation/evaluate_rougeL,▁▇█▅
validation/evaluate_rougeLsum,▁▇█▅

0,1
epoch,14.0
grad_norm_bias,0.0
grad_norm_weight,0.0
train_loss,1.59014
validation/evaluate_bleu,6.55081
validation/evaluate_meteor,0.27893
validation/evaluate_rouge1,0.33792
validation/evaluate_rouge2,0.104
validation/evaluate_rougeL,0.28585
validation/evaluate_rougeLsum,0.28665


2024-06-23 02:29:19,081 DEBUG Resetting dropped connection: o151352.ingest.sentry.io
2024-06-23 02:29:19,480 DEBUG https://o151352.ingest.sentry.io:443 "POST /api/4504800232407040/envelope/ HTTP/1.1" 200 0
