# Text-Image-to-Text SFT Training with the Align-Anything Framework

This tutorial shows how to use the Align-Anything framework to run supervised fine-tuning (SFT) on a multimodal model.

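## Prerequisites

- Align-Anything is installed.
- A text-image-to-text dataset. This tutorial uses the [PKU-Alignment/Align-Anything-TI2T-Instruction-100K](https://huggingface.co/datasets/PKU-Alignment/Align-Anything-TI2T-Instruction-100K) dataset.
- The LLaVA-1.5-7b model, which can be downloaded from [here](https://huggingface.co/llava-hf/llava-1.5-7b-hf).
- A GPU with at least 70 GB of memory.
> GPUs with less memory may also work. We plan to provide scripts that use smaller models in the future.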
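
A back-of-the-envelope estimate helps explain why the memory requirement is this high for full-parameter fine-tuning with AdamW. The sketch below assumes roughly 7 billion parameters and that weights, gradients, and both AdamW moment buffers are stored in 2-byte precision (bf16). These are illustrative assumptions, not measurements of this tutorial's setup, and the function name is hypothetical.

```python
def estimate_training_memory_gb(
    num_params: float,
    bytes_per_weight: int = 2,      # assumption: bf16 weights
    bytes_per_grad: int = 2,        # assumption: bf16 gradients
    bytes_per_adam_state: int = 2,  # assumption: bf16 moment buffers
) -> float:
    """Rough lower bound on GPU memory for full fine-tuning with AdamW.

    Counts weights, gradients, and the two AdamW moment buffers only.
    Activations and framework overhead are NOT included.
    """
    bytes_per_param = bytes_per_weight + bytes_per_grad + 2 * bytes_per_adam_state
    return num_params * bytes_per_param / 1e9


# Assumed parameter count of roughly 7 billion for a 7B-class model.
static_gb = estimate_training_memory_gb(7e9)
print(f"weights + grads + optimizer states: ~{static_gb:.0f} GB")
```

Under these assumptions the static state alone takes about 56 GB, and activations come on top of that. With 4-byte precision everywhere, the same count doubles.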

## Load the Pre-trained Model

First, we load a pre-trained model. We use LLaVA-1.5-7b, a multimodal model that takes text and images as input and generates text.

In [None]:
from align_anything.models.pretrained_model import load_pretrained_models
from align_anything.utils.multi_process import get_current_device

# Load the pre-trained model, tokenizer, and processor
model, tokenizer, processor = load_pretrained_models(
    "/path/to/llava-1.5-7b-hf",  # Replace with your model path
    model_max_length=4096,
    padding_side='right',
    trust_remote_code=True,
    modality=['image'],
)

# Move the model to the available device (GPU if available)
model = model.to(get_current_device())

## Set Up the Optimizer

For fine-tuning, we use the AdamW optimizer, a widely used choice.

In [2]:
from torch.optim import AdamW

# Initialize the optimizer with a learning rate of 1e-5
optimizer = AdamW(model.parameters(), lr=1e-5)
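
To make the optimizer less of a black box, here is a minimal sketch of one AdamW update on a single scalar parameter. It is a plain-Python illustration of the update rule rather than PyTorch's implementation; the function name `adamw_step` is hypothetical, and the default values mirror the documented defaults of `torch.optim.AdamW` except for `lr`, which is set to the 1e-5 used above.

```python
import math


def adamw_step(param, grad, m, v, step, lr=1e-5, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01):
    """Apply one AdamW update to a single scalar parameter.

    Returns the updated (param, m, v). `step` is 1-based.
    """
    # Decoupled weight decay: shrink the parameter directly,
    # independently of the gradient-based update.
    param = param * (1.0 - lr * weight_decay)

    # Exponential moving averages of the gradient and its square.
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad * grad

    # Bias correction compensates for m and v starting at zero.
    m_hat = m / (1.0 - beta1 ** step)
    v_hat = v / (1.0 - beta2 ** step)

    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


p, m, v = 1.0, 0.0, 0.0
p, m, v = adamw_step(p, grad=0.5, m=m, v=v, step=1)
print(p)  # slightly below 1.0: one decay step plus one update of about lr
```

On the first step the bias-corrected ratio `m_hat / sqrt(v_hat)` is close to 1 in magnitude, so the parameter moves by roughly `lr` regardless of the gradient's scale.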

## Configure the Chat Template

Align-Anything uses a Chat Template to format the model's input. Here we use the *AA_TI2T* Chat Template, which is designed for the Align-Anything-TI2T-Instruction-100K dataset.

In [3]:
from align_anything.configs.template import ChatTemplate

train_template = ChatTemplate(
    formatter=processor,
    template="AA_TI2T",
)

A Chat Template is a dataset-specific formatter that maps raw dataset samples to the model's input format. The *AA_TI2T* template is defined as follows:

```python
@register_template('AA_TI2T')
class AA_TI2T(BaseFormatter):
    system_prompt: str = ''
    
    def format_supervised_sample(
        self, raw_sample: dict[str, Any]
    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        prompt = raw_sample['prompt']
        answer = raw_sample['response']
        image = raw_sample['image'].convert('RGBA')

        return [
            {
                'role': 'user',
                'content': [
                    {'type': 'image'},
                    {'type': 'text', 'text': prompt},
                ],
            },
            {'role': 'assistant', 'content': [{'type': 'text', 'text': answer}]},
        ], {'image': image}
```
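
To see what this template returns, the following sketch copies the method body into a standalone function and runs it on a made-up sample. `FakeImage` is a hypothetical stand-in for a PIL image that only mimics `convert`, and the sample values are invented for illustration.

```python
from typing import Any


class FakeImage:
    """Hypothetical stand-in for a PIL image; only mimics `convert`."""

    def __init__(self, mode: str = 'RGB') -> None:
        self.mode = mode

    def convert(self, mode: str) -> 'FakeImage':
        return FakeImage(mode)


def format_supervised_sample(
    raw_sample: dict[str, Any],
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
    """Standalone copy of the AA_TI2T formatting logic shown above."""
    prompt = raw_sample['prompt']
    answer = raw_sample['response']
    image = raw_sample['image'].convert('RGBA')

    return [
        {
            'role': 'user',
            'content': [
                {'type': 'image'},
                {'type': 'text', 'text': prompt},
            ],
        },
        {'role': 'assistant', 'content': [{'type': 'text', 'text': answer}]},
    ], {'image': image}


# Made-up sample with the three fields the template reads.
sample = {'prompt': 'What is in the picture?', 'response': 'A cat.', 'image': FakeImage()}
conversation, multimodal = format_supervised_sample(sample)

for turn in conversation:
    print(turn['role'], [part['type'] for part in turn['content']])
print(multimodal['image'].mode)
```

The first return value is a two-turn conversation in which the user turn carries an image placeholder followed by the prompt text; the second carries the converted image itself.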

## Create the Dataset

We use the SupervisedDataset class to load our text-image-to-text dataset.

In [4]:
from align_anything.datasets.text_image_to_text import SupervisedDataset

# Initialize the training dataset
train_dataset = SupervisedDataset(
    path="/path/to/Align-Anything-TI2T-Instruction-100K",  # Replace with your dataset path
    template=train_template,
    tokenizer=tokenizer,
    processor=processor,
    split="train",
    size=1000,  # Limit to 1000 samples for this tutorial
)
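
A common convention in SFT is to compute the loss only on the response tokens by masking the prompt positions in the labels. The sketch below illustrates that convention with made-up token IDs. Whether SupervisedDataset builds its labels exactly this way is an assumption here, and `build_labels` is a hypothetical helper, not part of Align-Anything.

```python
IGNORE_INDEX = -100  # default `ignore_index` of PyTorch's cross-entropy loss


def build_labels(prompt_ids: list[int], response_ids: list[int]) -> dict[str, list[int]]:
    """Concatenate prompt and response, masking the prompt in the labels."""
    input_ids = prompt_ids + response_ids
    labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
    return {'input_ids': input_ids, 'labels': labels}


# Hypothetical token IDs, for illustration only.
example = build_labels(prompt_ids=[1, 15, 27, 42], response_ids=[88, 91, 2])
print(example['input_ids'])  # [1, 15, 27, 42, 88, 91, 2]
print(example['labels'])     # [-100, -100, -100, -100, 88, 91, 2]
```

Positions labeled `-100` contribute nothing to the loss, so gradients come only from predicting the response.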

## Set Up the DataLoader

The DataLoader handles batching, shuffling, and loading the data.

In [5]:
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler

# Create a DataLoader for our training dataset
train_dataloader = DataLoader(
    train_dataset,
    collate_fn=train_dataset.get_collator(),  # Custom collate function for our dataset
    sampler=RandomSampler(train_dataset),     # Randomly sample data
    batch_size=1,                             # Process one sample at a time
)
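
A collate function turns a list of samples into one batch. The sketch below shows the text side of that job, right-padding variable-length sequences to a common length, in plain Python. It is not the collator returned by `get_collator()`; the names `pad_right` and `collate`, the pad token ID of 0, and the sample values are all assumptions made for illustration.

```python
IGNORE_INDEX = -100  # label value skipped by the loss


def pad_right(sequences: list[list[int]], pad_value: int) -> list[list[int]]:
    """Pad every sequence on the right to the length of the longest one."""
    max_len = max(len(seq) for seq in sequences)
    return [seq + [pad_value] * (max_len - len(seq)) for seq in sequences]


def collate(samples: list[dict[str, list[int]]], pad_token_id: int) -> dict[str, list[list[int]]]:
    """Build a padded batch with an attention mask marking real tokens."""
    return {
        'input_ids': pad_right([s['input_ids'] for s in samples], pad_token_id),
        'labels': pad_right([s['labels'] for s in samples], IGNORE_INDEX),
        'attention_mask': pad_right([[1] * len(s['input_ids']) for s in samples], 0),
    }


# Hypothetical samples of different lengths.
samples = [
    {'input_ids': [1, 5, 6, 2], 'labels': [-100, -100, 6, 2]},
    {'input_ids': [1, 7, 2], 'labels': [-100, 7, 2]},
]
batch = collate(samples, pad_token_id=0)
print(batch['input_ids'])       # [[1, 5, 6, 2], [1, 7, 2, 0]]
print(batch['attention_mask'])  # [[1, 1, 1, 1], [1, 1, 1, 0]]
print(batch['labels'])          # [[-100, -100, 6, 2], [-100, 7, 2, -100]]
```

With `batch_size=1`, as in this tutorial, each batch holds a single sequence, so this kind of text padding changes nothing; it starts to matter with larger batches.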

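## Training Loop

We now fine-tune the model for 3 epochs. The model, tokenizer, and processor are saved to `./output` after each epoch, so each save overwrites the previous one.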

In [None]:
from tqdm import tqdm
from collections import deque
import numpy as np
import os

progress_bar = tqdm(range(3*len(train_dataloader)), desc="Training for 1/3 epochs...")
losses = deque(maxlen=100)
os.makedirs('./output', exist_ok=True)

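for epoch in range(3):
    progress_bar.set_description(f"Training for {epoch+1}/3 epochs...")
    model.train()  # Make sure the model is in training mode
    for batch in train_dataloader:
        batch.pop('meta_info')  # Remove metadata before passing the batch to the model
        loss = model(**batch)['loss']
        optimizer.zero_grad()  # Clear gradients left over from the previous step
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        progress_bar.update(1)
        progress_bar.set_postfix(loss=np.mean(losses))  # Mean of up to the last 100 losses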

    # Save the model after each epoch
    model.save_pretrained('./output')
    tokenizer.save_pretrained('./output')
    processor.save_pretrained('./output')
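
The loss shown in the progress bar is a sliding-window average: `deque(maxlen=100)` keeps only the 100 most recent values. The toy run below shows the same mechanism with a window of 3 and made-up loss values, using a plain Python average in place of `np.mean`.

```python
from collections import deque

window = deque(maxlen=3)  # the training loop uses maxlen=100; 3 keeps the demo short
running_means = []
for loss in [4.0, 2.0, 3.0, 1.0, 2.0]:  # made-up loss values
    window.append(loss)  # once the window is full, the oldest value is dropped
    running_means.append(sum(window) / len(window))

print(running_means)  # [4.0, 3.0, 3.0, 2.0, 2.0]
```

Averaging over a window smooths the step-to-step noise that a batch size of 1 produces, while still tracking recent progress.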


The complete code is as follows:

In [None]:
import os
from tqdm import tqdm
from collections import deque
import numpy as np

from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler
from torch.optim import AdamW

from align_anything.models.pretrained_model import load_pretrained_models
from align_anything.datasets.text_image_to_text import SupervisedDataset
from align_anything.configs.template import ChatTemplate
from align_anything.utils.multi_process import get_current_device


model, tokenizer, processor = load_pretrained_models(
    "/path/to/llava-1.5-7b-hf",
    model_max_length=4096,
    padding_side='right',
    trust_remote_code=True,
    modality=['image'],
)

model = model.to(get_current_device())

optimizer = AdamW(model.parameters(), lr=1e-5)

train_template = ChatTemplate(
    formatter=processor,
    template="AA_TI2T",
)

train_dataset = SupervisedDataset(
    path="/path/to/Align-Anything-TI2T-Instruction-100K",  # Replace with your dataset path
    template=train_template,
    tokenizer=tokenizer,
    processor=processor,
    split="train",
    size=1000,  # Limit to 1000 samples for this tutorial
)

train_dataloader = DataLoader(
    train_dataset,
    collate_fn=train_dataset.get_collator(),
    sampler=RandomSampler(train_dataset),
    batch_size=1,
)

progress_bar = tqdm(range(3*len(train_dataloader)), desc="Training for 1/3 epochs...")
losses = deque(maxlen=100)
os.makedirs('./output', exist_ok=True)

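for epoch in range(3):
    progress_bar.set_description(f"Training for {epoch+1}/3 epochs...")
    model.train()  # Make sure the model is in training mode
    for batch in train_dataloader:
        batch.pop('meta_info')  # Remove metadata before passing the batch to the model
        loss = model(**batch)['loss']
        optimizer.zero_grad()  # Clear gradients left over from the previous step
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        progress_bar.update(1)
        progress_bar.set_postfix(loss=np.mean(losses))  # Mean of up to the last 100 losses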

    # Save the model after each epoch
    model.save_pretrained('./output')
    tokenizer.save_pretrained('./output')
    processor.save_pretrained('./output')