In [48]:
import torch
from transformers import BeitModel, BeitImageProcessor
from PIL import Image
import datasets
from modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMConfig
from tokenization_chatglm import ChatGLMTokenizer
from peft import (get_peft_model_state_dict, get_peft_model, LoraConfig)

In [49]:
# Re-import the local modeling_chatglm module so that edits made to the file on
# disk take effect without restarting the kernel.  Names bound earlier with
# `from modeling_chatglm import ...` keep pointing at the pre-reload objects;
# the model-loading cell therefore accesses the class via `modeling_chatglm.`.
import importlib
import modeling_chatglm
importlib.reload(modeling_chatglm)

<module 'modeling_chatglm' from '/Share/home/qiyifan/filebase/projects/multi-modal/src/modeling_chatglm.py'>

In [2]:
# Hard-coded special-token ids.  In the cells visible here only ID_PAD is read
# (data_collator uses it as the padding value for input_ids).
# NOTE(review): presumably these mirror the ChatGLM2 tokenizer vocabulary —
# confirm against the loaded tokenizer instead of trusting the literals.
ID_MASK = 64789
ID_gMASK = 64790
ID_sMASK = 64791
ID_SOP = 64792
ID_EOP = 64793
ID_BOS = 1
ID_EOS = 2
ID_PAD = 0

In [50]:
# Local checkpoint directories: the ChatGLM2-6B language model and the BEiT
# image-encoder checkpoint.
PATH_MODEL_PRETRAIN='/Share/home/qiyifan/filebase/source/chatglm2-6b'
ImageEncoderPath="/Share/home/qiyifan/filebase/source/beit-base-patch16-224-pt22k-ft22k"
tokenizer = ChatGLMTokenizer.from_pretrained(PATH_MODEL_PRETRAIN)
# Disabled experiment: reuse the EOS token as the pad token.
# tokenizer.pad_token = tokenizer.eos_token
# data_collator reads tokenizer.padding_side to decide which side to pad on.
tokenizer.padding_side = "left"  # Allow batched inference
# Go through the module attribute so the class from the reloaded module is used.
model = modeling_chatglm.ChatGLMForConditionalGeneration.from_pretrained(PATH_MODEL_PRETRAIN)

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

Some weights of ChatGLMForConditionalGeneration were not initialized from the model checkpoint at /Share/home/qiyifan/filebase/source/chatglm2-6b and are newly initialized: ['transformer.image_encoder.align.image_h_to_kv.bias', 'transformer.image_encoder.align.image_h_to_kv.weight', 'transformer.image_encoder.align.image_e_to_h.bias', 'transformer.image_encoder.align.image_e_to_h.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [51]:
# Initialise the model's image encoder from the BEiT checkpoint.  Uses the
# ImageEncoderPath constant defined with the other checkpoint paths instead of
# repeating the directory literal, so the path only has to be changed in one place.
model.transformer.init_image_model(ImageEncoderPath)

Some weights of the model checkpoint at /Share/home/qiyifan/filebase/source/beit-base-patch16-224-pt22k-ft22k were not used when initializing BeitModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing BeitModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BeitModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [52]:
def find_target_modules(model):
    """Return the names of the linear layers outside the image branch.

    Walks ``model.named_modules()`` and keeps every module whose type string
    contains ``"Linear"`` (a case-sensitive substring match, so any class or
    module path containing that text qualifies), skipping modules whose
    qualified name contains ``"image"``.

    Args:
        model: Object exposing ``named_modules()`` that yields
            ``(qualified_name, module)`` pairs, e.g. a ``torch.nn.Module``.

    Returns:
        list[str]: Fully-qualified module names, de-duplicated and sorted so
        the result is identical from run to run.
    """
    target_names = {
        name
        for name, module in model.named_modules()
        if "Linear" in str(type(module)) and "image" not in name
    }
    # A set has no stable iteration order for strings across interpreter runs;
    # sorting keeps the resulting LoRA target list reproducible.
    return sorted(target_names)

In [53]:
# Names of the non-image linear layers of the model, collected before the model
# is wrapped by get_peft_model; passed to LoraConfig as target_modules.
TARGET_MODULES = find_target_modules(model)

In [54]:
# LoRA adapter settings: which modules to adapt first, then the adapter
# hyper-parameters (rank, scaling, dropout, bias handling).
config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=TARGET_MODULES,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
)

In [55]:
# Wrap the base model with the LoRA adapters described by `config`.
# The name `model` is rebound to the PEFT wrapper from here on.
model = get_peft_model(model, config)

In [56]:
# Move the PEFT-wrapped model to the current CUDA device.
model = model.cuda()

In [57]:
# Report trainable vs. total parameter counts.
# NOTE(review): the "align" parameters are only marked trainable in the cell
# that follows, so this figure presumably does not include them — confirm.
model.print_trainable_parameters()

trainable params: 15,376,384 || all params: 6,406,606,784 || trainable%: 0.24000823709676264


In [58]:
# Mark every parameter whose name contains "align" as trainable; all other
# parameters are left exactly as they are.
for name, parameters in model.named_parameters():
    if "align" not in name:
        continue
    parameters.requires_grad = True

In [34]:
# Load the VQA-RAD dataset by its hub id (served from the local cache when
# present, as the log below shows).
medqa = datasets.load_dataset('flaviagiammarino/vqa-rad')

Found cached dataset parquet (/Share/home/qiyifan/.cache/huggingface/datasets/flaviagiammarino___parquet/flaviagiammarino--vqa-rad-d04980c9c3579419/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)


  0%|          | 0/2 [00:00<?, ?it/s]

In [35]:
# Image preprocessor loaded from the same BEiT checkpoint directory as the image
# encoder.  Reuses the ImageEncoderPath constant rather than repeating the path
# literal, so processor and encoder cannot drift apart.
feature_extractor = BeitImageProcessor.from_pretrained(ImageEncoderPath)

In [36]:
# Training split of the dataset; every later preprocessing cell starts from it.
train_data = medqa["train"]

In [80]:
# Add per-example length columns for the question and answer fields.
# NOTE(review): len() is applied to the raw, untokenized fields, so these
# presumably count characters rather than tokens — confirm before using them
# to choose max sequence lengths.
lengths = train_data.map(lambda x : {'seq_x_len': len(x['question']),
                                     "seq_y_len":len(x['answer'])})

Map:   0%|          | 0/1793 [00:00<?, ? examples/s]

In [37]:
def preprocess_function_train(datapoint):
    """Convert one VQA record into prompt ids, label ids and an image tensor.

    Relies on the notebook-level ``tokenizer`` and ``feature_extractor``.

    Args:
        datapoint: Mapping with ``'question'``, ``'answer'`` and ``'image'``
            entries.

    Returns:
        dict: ``"input_ids"`` (token ids of the built prompt), ``"labels"``
        (token ids of the answer followed by the tokenizer's EOS id) and
        ``"image_tensors"`` (the processor's ``pixel_values`` with the leading
        batch axis removed).  All three values are ``None`` when the question
        or the answer is empty/falsy.
    """
    prompt_column = 'question'
    response_column = 'answer'
    image_column = 'image'
    # Per-side truncation limits handed to the tokenizer.
    max_source_length = 128
    max_target_length = 128

    model_inputs = {
        "input_ids": None,
        "labels": None,
        "image_tensors": None
    }
    # Records without a question or an answer are passed through unprocessed.
    if not (datapoint[prompt_column] and datapoint[response_column]):
        return model_inputs

    query, answer = datapoint[prompt_column], datapoint[response_column]
    prompt = tokenizer.build_prompt(query)

    # Special tokens are added to the prompt side only.
    a_ids = tokenizer.encode(text=prompt, add_special_tokens=True, truncation=True,
                             max_length=max_source_length)
    b_ids = tokenizer.encode(text=answer, add_special_tokens=False, truncation=True,
                             max_length=max_target_length)

    model_inputs["input_ids"] = a_ids
    model_inputs["labels"] = b_ids + [tokenizer.eos_token_id]

    pixel_values = feature_extractor(datapoint[image_column], return_tensors='pt')["pixel_values"]
    # Drop only the leading batch axis.  A bare squeeze() would also remove any
    # other axis that happens to have size 1 and silently change the rank.
    model_inputs['image_tensors'] = pixel_values.squeeze(0)

    return model_inputs

In [46]:
# Sanity check: run the image processor on the first training image, using the
# same expression as preprocess_function_train.
image_tensor=feature_extractor(train_data[0]['image'], return_tensors='pt')["pixel_values"].squeeze()

In [47]:
# Inspect the processed image shape (the recorded output is [3, 224, 224]).
image_tensor.shape

torch.Size([3, 224, 224])

In [39]:
# Apply preprocess_function_train to every training record; the log below shows
# the result being served from the datasets cache.
train_data_prepared = train_data.map(preprocess_function_train)

Loading cached processed dataset at /Share/home/qiyifan/.cache/huggingface/datasets/flaviagiammarino___parquet/flaviagiammarino--vqa-rad-d04980c9c3579419/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7/cache-626963f56ef08d75.arrow


In [45]:
# Check that the stored image still has the expected shape after Dataset.map.
# NOTE(review): the torch.tensor(...) wrapper suggests the mapped column comes
# back as nested lists rather than a tensor — confirm.
torch.tensor(train_data_prepared[0]['image_tensors']).shape

torch.Size([3, 224, 224])

In [40]:
def data_collator(batch):
    """Collate preprocessed records into padded, equally sized tensors.

    Each record contributes ``input_ids + labels`` as the model input; the
    label tensor carries ``-100`` on every prompt and padding position, so only
    the answer tokens are real labels.  Relies on the notebook-level
    ``tokenizer`` (for ``padding_side``) and ``ID_PAD``.

    The batch length is ``min(200, longest prompt + answer + 1)``.  Records
    longer than that are truncated at the tail (answer side) so that every row
    has the same length and can be stacked.

    Args:
        batch: Iterable of mappings with ``"input_ids"`` and ``"labels"``
            (lists of ints) and ``"image_tensors"`` (tensor or nested lists).

    Returns:
        dict: ``"input_ids"`` and ``"labels"`` as long tensors of shape
        ``(batch, length)``, plus ``"image_tensors"`` stacked along a new
        leading axis.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    max_total_length = 200  # hard cap on prompt + answer length

    samples = [
        (record.get("input_ids"), record.get("labels"), record.get("image_tensors"))
        for record in batch
    ]
    if not samples:
        raise ValueError("data_collator received an empty batch")

    # The +1 keeps at least one padding slot on the longest uncapped record.
    longest = max(len(x) + len(y) + 1 for x, y, _ in samples)
    len_max_batch = min(max_total_length, longest)
    pad_left = tokenizer.padding_side == "left"

    # NOTE(review): no attention_mask is returned, so whether the padded
    # positions are masked depends on the model's forward — confirm.
    batch_input_ids = []
    batch_labels = []
    batch_image_tensors = []
    for x, y, image_tensor in samples:
        # Truncate before padding; without this a record longer than the cap
        # would produce a longer row than its neighbours and break the stack.
        input_ids = (x + y)[:len_max_batch]
        labels = ([-100] * len(x) + y)[:len_max_batch]
        len_padding = len_max_batch - len(input_ids)
        if pad_left:
            input_ids = [ID_PAD] * len_padding + input_ids
            labels = [-100] * len_padding + labels
        else:
            input_ids = input_ids + [ID_PAD] * len_padding
            labels = labels + [-100] * len_padding
        batch_input_ids.append(torch.tensor(input_ids, dtype=torch.long))
        batch_labels.append(torch.tensor(labels, dtype=torch.long))
        # as_tensor accepts nested lists and existing tensors alike, without
        # the copy-construct warning torch.tensor emits for tensor inputs.
        batch_image_tensors.append(torch.as_tensor(image_tensor))

    return {
        "input_ids": torch.stack(batch_input_ids),
        "labels": torch.stack(batch_labels),
        "image_tensors": torch.stack(batch_image_tensors),
    }

In [41]:
# Smoke test: collate the first four preprocessed records into one batch.
model_inputs = data_collator(train_data_prepared.select([0,1,2,3]))

In [42]:
# Move every tensor of the batch to the GPU, keeping the same keys.
model_inputs = {k: v.cuda() for k, v in model_inputs.items()}

In [None]:
# Forward pass on the collated batch; the dict keys (input_ids, labels,
# image_tensors) are passed as keyword arguments.
output = model(**model_inputs)

In [61]:
# Inspect the loss returned for the batch (labels were supplied in the call).
output.loss

tensor(7.7383, device='cuda:0', dtype=torch.float16, grad_fn=<ToCopyBackward0>)

In [None]:
# Repeats the earlier cell that marks the "align" parameters as trainable;
# running it again leaves requires_grad set to True.
for name , parameters in model.named_parameters():
    if "align" in name:
        parameters.requires_grad = True

In [63]:
# Snapshot of the model's state dict.  The next cell recomputes the same value
# on its first line, which makes this cell redundant.
state_dict = model.state_dict()

In [72]:
# Collect the trainable tensors that are not LoRA weights (in this notebook:
# the image "align" layers, as the keys printed below show).
state_dict = model.state_dict()
filtered_state_dict = {}
for k, v in model.named_parameters():
    # Skip LoRA weights and anything that is frozen.
    if 'lora' in k or not v.requires_grad:
        continue
    filtered_state_dict[k] = state_dict[k]

In [73]:
# List which parameters ended up in the filtered state dict.
filtered_state_dict.keys()

dict_keys(['base_model.model.transformer.image_encoder.align.image_e_to_h.weight', 'base_model.model.transformer.image_encoder.align.image_e_to_h.bias', 'base_model.model.transformer.image_encoder.align.image_h_to_kv.weight', 'base_model.model.transformer.image_encoder.align.image_h_to_kv.bias'])