# Text-Image to Text Supervised Fine-Tuning with Align-Anything

This tutorial introduces how to perform supervised fine-tuning (SFT) on multimodal models using the Align-Anything framework.

## Prerequisites

- Align-Anything installed.
- A text-image-to-text dataset. In this tutorial, we use the [PKU-Alignment/Align-Anything-TI2T-Instruction-100K](https://huggingface.co/datasets/PKU-Alignment/Align-Anything-TI2T-Instruction-100K) dataset.
- The LLaVA-1.5-7b model, which you can download from [Hugging Face](https://huggingface.co/llava-hf/llava-1.5-7b-hf).
- A GPU with at least 70GB of memory.

> A GPU with less memory is also possible. We will provide a script using smaller models in the future.
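To get a feel for where a figure like 70GB comes from, here is a rough back-of-the-envelope estimate for full fine-tuning with AdamW. It counts four values per parameter: the weight, its gradient, and AdamW's two moment buffers. The numbers are assumptions for illustration, not measurements: about 7 billion parameters, all four values stored in bf16 (2 bytes each). Activations, image features, and framework overhead are ignored and take additional memory on top of this estimate.

```python
def estimate_training_memory_gb(num_params: float, bytes_per_value: int = 2, copies: int = 4) -> float:
    """Rough lower bound on GPU memory (1 GB = 1e9 bytes) for full fine-tuning.

    copies=4 counts weights, gradients, and AdamW's two moment buffers
    (exp_avg and exp_avg_sq). Activations and overhead are NOT included.
    """
    return num_params * bytes_per_value * copies / 1e9


# Assumption: ~7 billion parameters, everything kept in bf16 (2 bytes per value).
print(estimate_training_memory_gb(7e9))  # 56.0
```

With fp32 values (`bytes_per_value=4`) the same estimate doubles to 112 GB, which is one reason 7B-scale models are commonly fine-tuned in lower precision.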

## Loading Pre-trained Models

First, we need to load a pre-trained model. We'll use the LLaVA-1.5-7b model, which is a multimodal model capable of understanding both text and images.

```python
from align_anything.models.pretrained_model import load_pretrained_models
from align_anything.utils.multi_process import get_current_device

# Load the pre-trained model, tokenizer, and processor
model, tokenizer, processor = load_pretrained_models(
    "/path/to/llava-1.5-7b-hf",  # Replace with your model path
    model_max_length=4096,
    padding_side='right',
    trust_remote_code=True,
    modality=['image'],
)

# Move the model to the available device (GPU if available)
model = model.to(get_current_device())
```
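Before creating the optimizer, it can be useful to check how many parameters will actually be updated, since depending on how the loader is configured some submodules (for example the vision encoder) may be frozen. The helper `count_parameters` below is a hypothetical name written for this tutorial, not part of Align-Anything; it relies only on standard PyTorch `nn.Module.parameters()`, `Tensor.numel()`, and `requires_grad`. A tiny `nn.Linear` stands in for the real model.

```python
import torch.nn as nn


def count_parameters(module: nn.Module) -> tuple[int, int]:
    """Return (trainable, total) parameter counts for a PyTorch module."""
    total = 0
    trainable = 0
    for param in module.parameters():
        total += param.numel()
        if param.requires_grad:
            trainable += param.numel()
    return trainable, total


# Tiny stand-in module; with the tutorial's model, call count_parameters(model).
layer = nn.Linear(4, 2)           # 4*2 weights + 2 biases = 10 parameters
layer.bias.requires_grad_(False)  # freeze the bias to mimic a partially frozen model
print(count_parameters(layer))    # (8, 10)
```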

## Setting Up the Optimizer

For fine-tuning, we use the AdamW optimizer (Adam with decoupled weight decay), a common choice for training transformer models.

```python
from torch.optim import AdamW

# Initialize the optimizer with a learning rate of 1e-5
optimizer = AdamW(model.parameters(), lr=1e-5)
```
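To make the AdamW update concrete, here is a minimal pure-Python sketch of one step for a single scalar parameter, following the update rule documented for `torch.optim.AdamW`: decoupled weight decay plus bias-corrected first and second moment estimates. The defaults `betas=(0.9, 0.999)`, `eps=1e-8`, and `weight_decay=0.01` match PyTorch's; `adamw_step` itself is an illustrative helper, not a library function.

```python
import math


def adamw_step(param, grad, exp_avg, exp_avg_sq, step,
               lr=1e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01):
    """One AdamW update for a scalar parameter; `step` counts from 1."""
    beta1, beta2 = betas
    # Decoupled weight decay: shrink the weight directly, independent of the gradient.
    param *= 1 - lr * weight_decay
    # Exponential moving averages of the gradient and the squared gradient.
    exp_avg = beta1 * exp_avg + (1 - beta1) * grad
    exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad
    # Bias correction compensates for initialising both averages at zero.
    m_hat = exp_avg / (1 - beta1 ** step)
    v_hat = exp_avg_sq / (1 - beta2 ** step)
    param -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, exp_avg, exp_avg_sq


# On the first step m_hat == grad and v_hat == grad**2, so the update is
# roughly lr * sign(grad), regardless of how large the gradient is.
new_param, m, v = adamw_step(param=1.0, grad=250.0, exp_avg=0.0, exp_avg_sq=0.0, step=1)
print(new_param)
```

Because of this normalisation, the first update moves each weight by about `lr` (here `1e-5`), independent of the gradient's scale.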

## Configuring the Chat Template

Align-Anything uses chat templates to format the input for the model. Here, we use the *AA_TI2T* template, which is designed for Align-Anything text-image-to-text datasets.

```python
from align_anything.configs.template import ChatTemplate

train_template = ChatTemplate(
    formatter=processor,
    template="AA_TI2T",
)
```

The template is a dataset-specific formatter that maps each raw dataset sample to the model's input format. Here is the definition of the *AA_TI2T* template:

```python
@register_template('AA_TI2T')
class AA_TI2T(BaseFormatter):
    system_prompt: str = ''
    
    def format_supervised_sample(
        self, raw_sample: dict[str, Any]
    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        prompt = raw_sample['prompt']
        answer = raw_sample['response']
        image = raw_sample['image'].convert('RGBA')

        return [
            {
                'role': 'user',
                'content': [
                    {'type': 'image'},
                    {'type': 'text', 'text': prompt},
                ],
            },
            {'role': 'assistant', 'content': [{'type': 'text', 'text': answer}]},
        ], {'image': image}
```
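To see what the template produces, the sketch below applies the same message-building logic as `format_supervised_sample` to a plain dictionary (skipping the image and its PIL conversion) and then renders the messages into one string. The sample is made up, and `render_llava_style` is a hypothetical, simplified renderer loosely modelled on LLaVA-1.5's `USER: ... ASSISTANT: ...` prompt style; in practice the processor's own chat template does the rendering and may differ in whitespace and special tokens.

```python
from typing import Any


def format_supervised_messages(raw_sample: dict[str, Any]) -> list[dict[str, Any]]:
    """Mirror of AA_TI2T.format_supervised_sample, without the image conversion."""
    return [
        {
            'role': 'user',
            'content': [
                {'type': 'image'},
                {'type': 'text', 'text': raw_sample['prompt']},
            ],
        },
        {'role': 'assistant', 'content': [{'type': 'text', 'text': raw_sample['response']}]},
    ]


def render_llava_style(messages: list[dict[str, Any]]) -> str:
    """Hypothetical simplified renderer; <image> marks where image features are inserted."""
    parts = []
    for message in messages:
        pieces = []
        for item in message['content']:
            if item['type'] == 'image':
                pieces.append('<image>\n')
            else:
                pieces.append(item['text'])
        parts.append(f"{message['role'].upper()}: {''.join(pieces)}")
    return ' '.join(parts)


sample = {'prompt': 'What animal is shown?', 'response': 'A cat sitting on a sofa.'}
print(render_llava_style(format_supervised_messages(sample)))
# USER: <image>
# What animal is shown? ASSISTANT: A cat sitting on a sofa.
```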

## Creating the Dataset

We'll use the SupervisedDataset class to load our text-image-to-text dataset.

```python
from align_anything.datasets.text_image_to_text import SupervisedDataset

# Initialize the training dataset
train_dataset = SupervisedDataset(
    path="/path/to/Align-Anything-TI2T-Instruction-100K",  # Replace with your dataset path
    template=train_template,
    tokenizer=tokenizer,
    processor=processor,
    split="train",
    size=1000,  # Limit to 1000 samples for this tutorial
)
```
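In supervised fine-tuning the loss is usually computed only on the response tokens: the labels are a copy of `input_ids` in which every prompt position is replaced by `-100`, the default `ignore_index` of PyTorch's `CrossEntropyLoss`. The sketch below is a generic illustration of that masking with made-up token IDs, not a reproduction of `SupervisedDataset`'s internals; `build_labels` is a hypothetical helper.

```python
IGNORE_INDEX = -100  # PyTorch's CrossEntropyLoss skips targets equal to this value by default


def build_labels(prompt_ids: list[int], response_ids: list[int]) -> tuple[list[int], list[int]]:
    """Concatenate prompt and response token IDs; mask the prompt part of the labels."""
    input_ids = prompt_ids + response_ids
    labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
    return input_ids, labels


input_ids, labels = build_labels([1, 3148, 1001], [319, 6635, 2])
print(input_ids)  # [1, 3148, 1001, 319, 6635, 2]
print(labels)     # [-100, -100, -100, 319, 6635, 2]
```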

## Setting Up the DataLoader

The DataLoader will handle batching, shuffling, and loading the data.

```python
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler

# Create a DataLoader for our training dataset
train_dataloader = DataLoader(
    train_dataset,
    collate_fn=train_dataset.get_collator(),  # Custom collate function for our dataset
    sampler=RandomSampler(train_dataset),     # Randomly sample data
    batch_size=1,                             # Process one sample at a time
)
```
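With `batch_size=1` no padding is needed, but larger batches require the collate function to pad sequences of different lengths to a common length. The sketch below is a hypothetical, simplified collator (not Align-Anything's `get_collator()`) that right-pads, matching the `padding_side='right'` used when loading the model. It assumes a pad token ID of 0, marks real tokens with 1 in `attention_mask`, and pads `labels` with `-100` so padding never contributes to the loss. It returns plain lists; a real collator would also convert them to tensors and batch the image pixel values.

```python
IGNORE_INDEX = -100


def right_pad_collate(samples: list[dict[str, list[int]]], pad_token_id: int = 0) -> dict[str, list[list[int]]]:
    """Right-pad input_ids and labels in a batch and build the matching attention_mask."""
    max_len = max(len(s['input_ids']) for s in samples)
    batch = {'input_ids': [], 'attention_mask': [], 'labels': []}
    for s in samples:
        pad = max_len - len(s['input_ids'])
        batch['input_ids'].append(s['input_ids'] + [pad_token_id] * pad)
        batch['attention_mask'].append([1] * len(s['input_ids']) + [0] * pad)
        batch['labels'].append(s['labels'] + [IGNORE_INDEX] * pad)
    return batch


batch = right_pad_collate([
    {'input_ids': [5, 6, 7], 'labels': [-100, 6, 7]},
    {'input_ids': [8, 9], 'labels': [-100, 9]},
])
print(batch['input_ids'])       # [[5, 6, 7], [8, 9, 0]]
print(batch['attention_mask'])  # [[1, 1, 1], [1, 1, 0]]
print(batch['labels'])          # [[-100, 6, 7], [-100, 9, -100]]
```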

## Training Loop

Now we fine-tune the model for three epochs, saving the model, tokenizer, and processor to `./output` at the end of each epoch.

```python
import os
from collections import deque

import numpy as np
from tqdm import tqdm

num_epochs = 3
progress_bar = tqdm(total=num_epochs * len(train_dataloader), desc=f"Training epoch 1/{num_epochs}")
losses = deque(maxlen=100)  # recent losses, averaged for the progress-bar display
os.makedirs('./output', exist_ok=True)

model.train()  # switch to training mode once, before the loop
for epoch in range(num_epochs):
    progress_bar.set_description(f"Training epoch {epoch + 1}/{num_epochs}")
    for batch in train_dataloader:
        batch.pop('meta_info')  # drop metadata that the model's forward() does not take
        loss = model(**batch)['loss']
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        progress_bar.update(1)
        progress_bar.set_postfix(loss=np.mean(losses))

    # Save after each epoch (each save overwrites the previous one in ./output)
    model.save_pretrained('./output')
    tokenizer.save_pretrained('./output')
    processor.save_pretrained('./output')

progress_bar.close()
```
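When `labels` are part of the batch, Hugging Face causal language models compute `loss` as the mean token-level cross-entropy after a one-position shift: the logits at position t are scored against the label at position t+1, and positions labelled `-100` are skipped. We assume LLaVA's language model follows this convention. The pure-Python sketch below illustrates the computation on tiny made-up logits; `shifted_cross_entropy` is a hypothetical helper, not the model's actual implementation.

```python
import math

IGNORE_INDEX = -100


def shifted_cross_entropy(logits: list[list[float]], labels: list[int]) -> float:
    """Mean next-token cross-entropy, skipping IGNORE_INDEX targets."""
    total, count = 0.0, 0
    # logits[t] predicts labels[t + 1]; the last position has no target.
    for step_logits, target in zip(logits[:-1], labels[1:]):
        if target == IGNORE_INDEX:
            continue
        # Numerically stable log-softmax: subtract the max logit before exponentiating.
        max_logit = max(step_logits)
        log_sum_exp = max_logit + math.log(sum(math.exp(x - max_logit) for x in step_logits))
        total += log_sum_exp - step_logits[target]
        count += 1
    if count == 0:
        raise ValueError("no target tokens to score")
    return total / count


# Uniform logits over a 4-token vocabulary give a loss of log(4) per scored token.
logits = [[0.0, 0.0, 0.0, 0.0]] * 4
labels = [IGNORE_INDEX, IGNORE_INDEX, 2, 3]
print(shifted_cross_entropy(logits, labels))  # ≈ 1.3863 (= log 4)
```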


Here is the complete code:

```python
import os
from collections import deque

import numpy as np
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler
from tqdm import tqdm

from align_anything.configs.template import ChatTemplate
from align_anything.datasets.text_image_to_text import SupervisedDataset
from align_anything.models.pretrained_model import load_pretrained_models
from align_anything.utils.multi_process import get_current_device

# Load the pre-trained model, tokenizer, and processor
model, tokenizer, processor = load_pretrained_models(
    "/path/to/llava-1.5-7b-hf",  # Replace with your model path
    model_max_length=4096,
    padding_side='right',
    trust_remote_code=True,
    modality=['image'],
)
model = model.to(get_current_device())

optimizer = AdamW(model.parameters(), lr=1e-5)

train_template = ChatTemplate(
    formatter=processor,
    template="AA_TI2T",
)

train_dataset = SupervisedDataset(
    path="/path/to/Align-Anything-TI2T-Instruction-100K",  # Replace with your dataset path
    template=train_template,
    tokenizer=tokenizer,
    processor=processor,
    split="train",
    size=1000,  # Limit to 1000 samples for this tutorial
)

train_dataloader = DataLoader(
    train_dataset,
    collate_fn=train_dataset.get_collator(),
    sampler=RandomSampler(train_dataset),
    batch_size=1,
)

num_epochs = 3
progress_bar = tqdm(total=num_epochs * len(train_dataloader), desc=f"Training epoch 1/{num_epochs}")
losses = deque(maxlen=100)  # recent losses, averaged for the progress-bar display
os.makedirs('./output', exist_ok=True)

model.train()  # switch to training mode once, before the loop
for epoch in range(num_epochs):
    progress_bar.set_description(f"Training epoch {epoch + 1}/{num_epochs}")
    for batch in train_dataloader:
        batch.pop('meta_info')  # drop metadata that the model's forward() does not take
        loss = model(**batch)['loss']
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        progress_bar.update(1)
        progress_bar.set_postfix(loss=np.mean(losses))

    # Save after each epoch (each save overwrites the previous one in ./output)
    model.save_pretrained('./output')
    tokenizer.save_pretrained('./output')
    processor.save_pretrained('./output')

progress_bar.close()
```