In [1]:
# set os at top 
import os
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
os.environ['HF_HOME'] = '/data1/wln/hf_cache'

import torch
from transformers import AutoProcessor, Blip2ForConditionalGeneration, TrainingArguments, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# quantization config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# load model and processor
processor = AutoProcessor.from_pretrained('Salesforce/blip2-opt-2.7b')
base_model = Blip2ForConditionalGeneration.from_pretrained(
    'Salesforce/blip2-opt-2.7b', 
    local_files_only=True,
    quantization_config=bnb_config
)

base_model = prepare_model_for_kbit_training(base_model)
# set training args 
training_args = TrainingArguments(
    per_device_train_batch_size=4,
    gradient_checkpointing=True,
)
# lora config
lora_config = LoraConfig(
    r=16, #8
    lora_alpha=32, #16 
    lora_dropout=0.1, #0.05
    bias="none"
)

# get model for training
adapter_model = get_peft_model(base_model, lora_config)

  from .autonotebook import tqdm as notebook_tqdm
`low_cpu_mem_usage` was None, now default to True since model is quantized.
Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.22s/it]


In [2]:
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset

ds = load_dataset(
    "parquet", 
    data_files=['../dataset/0000.parquet', '../dataset/0001.parquet']
)

# split
train_ds = ds.filter(lambda x:x['split'] == 'train', num_proc=32)['train']
print(len(train_ds))
val_ds = ds.filter(lambda x: x['split'] == 'val', num_proc=32)['train']
test_ds = ds.filter(lambda x: x['split'] == 'test', num_proc=32)['train']

# convert a huggingface dataset type to pytorch dataset type
class ImageCaptionDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor
    
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        item = self.dataset[index]
        instruct = "A short image caption:"
        encoding = self.processor(
            images=item['image'], 
            text=instruct, 
            padding=True,
            truncation=True,
            return_tensors='pt'
        )
        encoding['label'] = item['caption'][0]
        return encoding

train_ds_pt = ImageCaptionDataset(train_ds, processor=processor)

def collator(batch):
    processed_batch = {}
    for key in batch[0].keys():
        if key != 'label':
            processed_batch[key] = torch.stack([b[key] for b in batch])
        else:
            labels = [b['label'] for b in batch]
            labels_pt = processor.tokenizer(
                labels,
                padding="max_length",
                truncation=True,
                max_length=128,
                return_tensors = 'pt'
            )
            labels_pt["input_ids"][labels_pt["input_ids"] == processor.tokenizer.pad_token_id] = -100
            processed_batch['labels_ids'] = labels_pt['input_ids']
            processed_batch['labels_attention_mask'] = labels_pt['attention_mask']
    
    return processed_batch
train_dataloader = DataLoader(train_ds_pt, shuffle=True, batch_size=training_args.per_device_train_batch_size, collate_fn=collator)
# batch = next(iter(train_dataloader))
# # Print the keys in the batch
# print("Batch Keys:", batch.keys())

# # Print shapes and types of each item in the batch
# for key, value in batch.items():
#     print(f"\nKey: {key}")
#     print(f"Type: {type(value)}")
#     print(f"Shape: {value.shape if isinstance(value, torch.Tensor) else 'N/A'}")
#     print(f"Sample Data: {value[0] if isinstance(value, torch.Tensor) else value}")



7033


In [None]:
from torch import autograd
from torch.amp import autocast, GradScaler
from transformers import get_scheduler
from torch.nn.utils import clip_grad_norm_
# Initialize optimizer and learning rate scheduler
optimizer = torch.optim.AdamW(adapter_model.parameters(), lr=1e-5, eps=1e-5)
num_training_steps = len(train_dataloader) * 2  # Assume 2 epochs
num_warmup_steps = int(num_training_steps * 0.1)  # 10% warmup

lr_scheduler = get_scheduler(
    "cosine",
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)

# Use mixed precision
scaler = GradScaler()
# lora train
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
adapter_model.to(device)

## pad
from torch.nn.functional import pad
def pad_to_length(tensor, max_length, pad_value=None):
    pad_value = pad_value if pad_value is not None else processor.tokenizer.pad_token_id
    # Ensure tensor is at least 2D
    if tensor.dim() == 1:
        tensor = tensor.unsqueeze(0)
    # Pad or truncate tensor to max_length
    return torch.nn.functional.pad(
        tensor, (0, max_length - tensor.shape[1]), value=pad_value
    ) if tensor.shape[1] < max_length else tensor[:, :max_length]
## save memory 
if training_args.gradient_checkpointing:
    adapter_model.gradient_checkpointing_enable()

## train mode
adapter_model.train()
loss_list=[]
for epoch in range(2):
    print("Epoch:", epoch)
    sum_loss_list = []
    for step, batch in enumerate(train_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        ## Forward pass
        input_ids = batch.pop('input_ids').squeeze(1) #instruct: "A short image caption:"
        pixel_values = batch['pixel_values'].squeeze(1) # encoded pixel_values
        labels = batch['labels_ids'].squeeze(1) # label tensors

        # Padding input_ids with 0 (default padding token for most tokenizers)
        input_ids = pad_to_length(input_ids, 128, pad_value=-100)
        # Padding attention_mask with 0 (default padding token for most tokenizers)
        attention_mask = pad_to_length(attention_mask, 128, pad_value=-100) # "a/m for instruct"
        # # Padding labels with -100 to ignore them in the loss function
        # labels = pad_to_length(labels, 128, pad_value=-100)

        ## feed to model
        with autocast():
            outputs = adapter_model(
                input_ids = input_ids,
                pixel_values = pixel_values, 
                labels = labels, 
            )
            loss = outputs.loss

        print(f"Step {step}, Loss: {loss.item()}")

        sum_loss_list.append(float(loss.item()))

        optimizer.zero_grad()

        ## Backward pass 
        scaler.scale(loss).backward()

        clip_grad_norm_(adapter_model.parameters(), max_norm=1.0)
        ## update weights
        scaler.step(optimizer)
        scaler.update()

        lr_scheduler.step()

        if step % 10 == 0: 
            generated_output = adapter_model.generate(pixel_values=pixel_values, max_new_tokens=20)
            print("Generated caption:", processor.batch_decode(generated_output, skip_special_tokens=True))

    avg_sum_loss = sum(sum_loss_list) / len(sum_loss_list)
    print(f"Epoch {epoch} - Avg Loss: {avg_sum_loss}")
    loss_list.append(avg_sum_loss)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


Step 100, Loss: 7.6420300102233885
Step 200, Loss: 6.096641244888306
Step 300, Loss: 5.741399111747742
Step 400, Loss: 5.573953895568848
Step 500, Loss: 5.532182145118713
Step 600, Loss: 5.477182502746582
Step 700, Loss: 5.425440158843994
Step 800, Loss: 5.351293129920959
Step 900, Loss: 5.300527973175049
Step 1000, Loss: 5.256822524070739
Step 1100, Loss: 5.234334244728088
Step 1200, Loss: 5.34716121673584
Step 1300, Loss: 5.271569843292236
Step 1400, Loss: 5.262822613716126
Step 1500, Loss: 5.254798636436463
Step 1600, Loss: 5.2519832038879395
