In [1]:
import os

# Process-wide settings applied in a single call: route HTTP(S)/SOCKS traffic through the
# local proxy on port 7890 and set CUDA_VISIBLE_DEVICES to '9'.
os.environ.update({
    'https_proxy': 'http://127.0.0.1:7890',
    'http_proxy': 'http://127.0.0.1:7890',
    'all_proxy': 'socks5://127.0.0.1:7890',
    'CUDA_VISIBLE_DEVICES': '9',
})

In [2]:
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from PIL import Image
from transformers import IdeficsForVisionText2Text, AutoProcessor, Trainer, TrainingArguments, BitsAndBytesConfig
import torchvision.transforms as transforms

In [3]:

import torch
import torch.nn as nn

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Checkpoint to fine-tune; alternatives that were tried are kept below for reference.
# checkpoint = "HuggingFaceM4/tiny-random-idefics"
# checkpoint = "/data/LMS/ide/idefics-9b-chart/checkpoint-40/adapter_model.bin"
checkpoint = "HuggingFaceM4/idefics-9b"

# 4-bit NF4 quantization with double quantization and fp16 compute.
# Some special modules can't be quantized properly, so they are skipped.
bnb_config = BitsAndBytesConfig(
    llm_int8_skip_modules=["lm_head", "embed_tokens"],
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

processor = AutoProcessor.from_pretrained(checkpoint)

# Drop the `quantization_config` argument to load the original (non-quantized) model.
model = IdeficsForVisionText2Text.from_pretrained(
    checkpoint,
    quantization_config=bnb_config,
    device_map='auto',
)

# Disabled experiment: wrap the model in DataParallel, then move it to the target device.
# model = nn.DataParallel(model)
# model = model.to(device)

Loading checkpoint shards:   0%|          | 0/19 [00:00<?, ?it/s]

In [4]:
def check_inference(model, processor, prompts, max_new_tokens=50):
    """Generate a completion for ``prompts`` and print the first decoded sequence.

    Args:
        model: Object exposing ``generate(**kwargs)`` and a ``device`` attribute.
        processor: Callable as ``processor(prompts, return_tensors="pt")``; must also
            provide ``tokenizer`` and ``batch_decode``.
        prompts: Prompt content forwarded unchanged to ``processor``.
        max_new_tokens: Upper bound on the number of newly generated tokens.

    Returns:
        None. The decoded text of the first generated sequence is printed.
    """
    tokenizer = processor.tokenizer

    # Token ids of the image placeholder strings, passed to `generate` as banned sequences.
    bad_words = ["<image>", "<fake_token_around_image>"]
    bad_words_ids = tokenizer(bad_words, add_special_tokens=False).input_ids

    eos_token_id = tokenizer.convert_tokens_to_ids("</s>")

    # Place the inputs on the model's own device rather than on a notebook-level global,
    # so the function keeps working after the model is reloaded or the kernel restarted.
    inputs = processor(prompts, return_tensors="pt").to(model.device)

    # `early_stopping` is intentionally not passed: it only controls beam-based decoding
    # and no beams are requested here.
    generated_ids = model.generate(
        **inputs,
        eos_token_id=[eos_token_id],
        bad_words_ids=bad_words_ids,
        max_new_tokens=max_new_tokens,
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(generated_text)

In [None]:
import pandas as pd
from datasets import Dataset, DatasetDict
from sklearn.model_selection import train_test_split

def convert_to_rgb(image):
    """Return ``image`` as an RGB image, flattening any transparency onto white.

    A plain ``image.convert("RGB")`` produces a wrong background for transparent
    images, so non-RGB inputs are first converted to RGBA and alpha-composited over
    a white canvas. Images already in RGB mode are returned unchanged (same object).
    """
    if image.mode != "RGB":
        rgba = image.convert("RGBA")
        white_canvas = Image.new("RGBA", rgba.size, (255, 255, 255))
        return Image.alpha_composite(white_canvas, rgba).convert("RGB")
    return image

def ds_transforms(example_batch):
    """Turn a batch of (image path, caption) rows into processor inputs with labels.

    Args:
        example_batch: Mapping with a 'caption' column and an 'image_url' column; each
            'image_url' entry is opened with ``Image.open``.

    Returns:
        The processor output moved to ``device``, with ``labels`` set to ``input_ids``.
    """
    image_processor = processor.image_processor
    side = image_processor.image_size
    mean = image_processor.image_mean
    std = image_processor.image_std

    # RGB conversion -> light random crop resized to the model's input size -> tensor -> normalize.
    image_transform = transforms.Compose([
        convert_to_rgb,
        transforms.RandomResizedCrop((side, side), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # One [image, text] prompt per row; the full caption is used as the answer.
    prompts = []
    for row, caption in enumerate(example_batch['caption']):
        chart = Image.open(example_batch['image_url'][row])
        prompts.append([chart, f"Question: Describe this chart in detail. Answer: {caption}</s>"])

    inputs = processor(prompts, transform=image_transform, return_tensors="pt").to(device)

    # The labels are the input token ids themselves.
    inputs["labels"] = inputs["input_ids"]
    return inputs


# Load the caption/image table and build the train and eval datasets.
data1 = pd.read_csv("/data/LMS/ide/merged.csv")

# Fixed seed so the 80/20 train/eval partition is identical on every Restart & Run All.
SPLIT_SEED = 42
train_df, eval_df = train_test_split(data1, test_size=0.2, shuffle=True, random_state=SPLIT_SEED)

# Convert the pandas frames with the dedicated constructor; `preserve_index=False` keeps
# the shuffled pandas index from being added as an extra column.
train_ds = Dataset.from_pandas(train_df, preserve_index=False)
eval_ds = Dataset.from_pandas(eval_df, preserve_index=False)

# `ds_transforms` is applied on the fly whenever rows are accessed.
train_ds.set_transform(ds_transforms)
eval_ds.set_transform(ds_transforms)

print(f"train examples: {len(train_ds)}, eval examples: {len(eval_ds)}")

# LoRA
After specifying the low-rank adapter (LoRA) configuration, we wrap the base model in a `PeftModel` using the `get_peft_model` utility function.

In [None]:

# Second "/"-separated component of the checkpoint id.
model_name = checkpoint.split("/")[1]

# LoRA adapters on the attention query/key/value projections; bias terms are left untouched.
config = LoraConfig(
    target_modules=["q_proj", "k_proj", "v_proj"],
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

# Wrap the base model so that the adapters are attached.
model = get_peft_model(model, config)

In [None]:
# Report how many parameters are trainable in the current (PEFT-wrapped) model.
model.print_trainable_parameters()
# NOTE(review): the base model was loaded with device_map='auto', so this explicit move
# may be redundant — confirm before keeping it.
model = model.to(device)

In [None]:
from peft import set_peft_model_state_dict

# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
# `map_location="cpu"` avoids depending on the device the checkpoint was saved from.
weight = torch.load(
    "/data/LMS/ide/idefics-9b-result/checkpoint-40/adapter_model.bin",
    map_location="cpu",
)

# An adapter checkpoint does not carry the full model state dict, so a strict
# `model.load_state_dict(weight)` does not match the PEFT model's keys. The PEFT helper
# maps the saved adapter tensors onto the wrapped model instead.
set_peft_model_state_dict(model, weight)

In [None]:
# Short fine-tuning run: 40 optimizer steps, with evaluation/logging every 20 steps and a
# checkpoint at step 40. The effective batch size is 1 x 8 accumulation steps per device.
training_args = TrainingArguments(
    output_dir=f"{model_name}-chartnewdata",
    learning_rate=2e-4,
    fp16=True,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,
    dataloader_pin_memory=False,
    save_total_limit=3,
    evaluation_strategy="steps",
    save_strategy="steps",
    save_steps=40,
    eval_steps=20,
    logging_steps=20,
    max_steps=40,
    # Keep the raw 'caption' / 'image_url' columns: the dataset transform reads them.
    remove_unused_columns=False,
    push_to_hub=False,
    label_names=["labels"],
    load_best_model_at_end=True,
    # The string "none" disables all logging integrations; passing None does not.
    report_to="none",
    optim="paged_adamw_8bit",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
)

trainer.train()

In [5]:
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Read the adapter configuration stored in the fine-tuned checkpoint directory.
peft_model_id = "/data/LMS/ide/idefics-9b-result/checkpoint-40"
config = PeftConfig.from_pretrained(peft_model_id)

# Reload the quantized base model; no tokenizer or processor is loaded in this cell.
model = IdeficsForVisionText2Text.from_pretrained(checkpoint, quantization_config=bnb_config, device_map='auto')

# Attach the LoRA adapter weights to the base model.
model = PeftModel.from_pretrained(model, peft_model_id, device_map={"":0})

# The model is only used for generation from here on: make sure dropout (including the
# LoRA dropout configured for training) is disabled.
model.eval()

print("Peft model loaded")


Loading checkpoint shards:   0%|          | 0/19 [00:00<?, ?it/s]

Peft model loaded


# Push your new model to the hub!


In [None]:
# # Insert your "write" token. You should find it in the settings of your HF profile
# !huggingface-cli login

In [None]:
# model.push_to_hub(f"{model_name}-pokemon", private=False)

In [10]:
from PIL import Image
# Zero-shot check on a single chart: the image followed by the question used during fine-tuning.
image = Image.open("/data/LMS/4.png")
prompts = [image, "Question: Describe this chart in detail. Answer:"]
check_inference(model, processor, prompts, max_new_tokens=500)

Question: Describe this chart in detail. Answer: Number of neonatal deaths per 1000 births


In [12]:
from PIL import Image
# Charts used for the few-shot prompt: image3, image1 and image4 are the worked examples,
# image2 is the chart the model is asked to describe.
image3 = Image.open("/data/LMS/chart2textidata/Chart-to-text/pew_dataset/dataset/multiColumn/imgs/3.png")
image2 = Image.open("/data/LMS/chart2textidata/Chart-to-text/pew_dataset/dataset/multiColumn/imgs/74.png")
image1 = Image.open("/data/LMS/chart2textidata/Chart-to-text/pew_dataset/dataset/multiColumn/imgs/1.png")
image4 = Image.open("/data/LMS/chart2textidata/test_img_all/multi_col-1000.png")

# One flat prompt: three User/Assistant example turns, then the final User turn, ending
# with an open "Assistant:" for the model to complete.
# The list must NOT be followed by a trailing comma: that would silently turn `prompts`
# into a 1-tuple wrapping the list.
prompts = [
    # Example 1
    "User: Describe this chart in detail. This is the underlying table information of the chart: votes cast in most recent national election as a  | x of voting age population  | xof registered voters  | iceland 12017  | japan 12017  | turkey 120181  | sweden 120181  | australia 120191  | belgium 120191  | south korea 12017  | israel 120201  | netherlands 12017  | denmark 12019  | hungary 120181  | norway 12017  | finland 12019  | germany 12017  | france 12017  | mexico 120181  | poland 120201  | slovakia 120201  | italy 12018  | austria 12019  | greece 120191  | new zealand 120201  | canada 12019  | united kingdom 12019  | portugal 12019  | spain 12019  | lithuania 12019  | czech republic 12017  | colombia 120181  | ireland 120201  | estonia 12019  | united states 12016  | slovenia 120181  | latvia 12018  | chile 12017  | luxembourg 120181  | switzerland 120191  | 10  | 20  | 30  | 40  | 50  | 60  | 70  | 80  | 90  | 100 ",
    image3,
    "<end_of_utterance>",

    "\nAssistant: The 55.7% VAP turnout in 2016 puts the U.S. behind most of its peers in the Organization for Economic Cooperation and Development, most of whose members are highly developed democratic states. Looking at the most recent nationwide election in each OECD nation, the U.S. places 30th out of 35 nations for which data is available. The highest turnout rates among OECD nations were in Turkey (89% of voting-age population), Sweden (82.1%), Australia (80.8%), Belgium (77.9%) and South Korea (77.9%). Switzerland consistently has the lowest turnout in the OECD: In 2019 federal elections, barely 36% of the Swiss voting-age population voted. As a consequence, turnout comparisons based only on registered voters may not be very meaningful. For instance, U.S. turnout in 2016 was 86.8% of registered voters, fifth-highest among OECD countries and second-highest among those without compulsory voting. But registered voters in the U.S. are much more of a self-selected group, already more likely to vote because they took the trouble to register themselves.<end_of_utterance>",

    # Example 2
    "User: Describe this chart in detail. This is the underlying table information of the chart: x ofchildren living in poverty by race and ethnicity  | 60  | recessions  | 40  | 20  | 1980  | 1985  | 1990  | 1995  | 2000  | 2005  | 2010  | 2015  | asian  | white  | black  | hispanic ",
    image1,
    "<end_of_utterance>",

    "\nAssistant: Before the coronavirus outbreak sent the U.S. economy into a recession, the share of American children living in poverty was on a downward trajectory, reaching record lows across racial and ethnic groups, according to a new Pew Research Center analysis of U.S. Census Bureau data. In 2019, the year with the most recently available data, 14% of children under age 18, or 10.5 million children, were living in poverty, down from 22%, or 16.3 million, in 2010. All major racial and ethnic groups saw declines since 2010, but the greatest decreases were in the shares of Black and Hispanic children living in poverty. About two-in-ten Hispanic children (21%) were living in poverty in 2019, down from 35% in 2010. In 2019, 26% of Black children were impoverished, dropping from 39% in 2010. Even so, Black and Hispanic children were still about three times as likely as Asian (7%) and White (8%) children to be living in poverty. <end_of_utterance>",

    # Example 3
    "User: Describe this chart in detail. This is the underlying table information of the chart:Democrats with high science knowledge have more  | confidence in the scientific method  | % ofU.S. adults in each group who say the scientific method  | Can be used to produce any  | Generally produces  | conclusion the researcher wants  | accurate conclusions  | U. S. adults  | 35  | 63  | Among Republicans with  | science knowledge  | Can be used to produce any  | Generally produces  | conclusion the researcher wants  | accurate conclusions  | High  | 40  | 59  | Medium  | 47  | 52  | Low  | 48  | 51  | Among Democrats with  | science knowledge  | High  | 86  | 14  | Medium  | 31  | 67  | Low  | 46  | 52  | Note: Respondents who did not give an answer are not shown. See Methodology for details  | on index of science knowledge  | Source: Survey conducted Jan. 7-21, 2019  | Trust and Mistrust in Americans Views of Scientific Experts  | PEW RESEARCH CENTER ",
    image4,
    "<end_of_utterance>",

    "\nAssistant: Factual knowledge alone does not explain public confidence in the scientific method to produce sound conclusions. Overall, a 63% majority of Americans say the scientific method generally produces sound conclusions, while 35% think it can be used to produce “any result a researcher wants.” People’s level of knowledge can influence beliefs about these matters, but it does so through the lens of partisanship, a tendency known as motivated reasoning.<end_of_utterance>",

    # Query turn: the chart to describe
    "\nUser:",
    image2,
    "Describe this chart in detail.This is the underlying table information of the chart:among all american adults the 9 who use the internet by age  | 100  | 75  | 50  | 25  | 2000  | 2002  | 2004  | 2006  | 2008  | 2010  | 2012  | 2014  | 1829  | 3049  | 5064  | 65 or older <end_of_utterance>",

    "\nAssistant:",
]
check_inference(model, processor, prompts, max_new_tokens=500)

User: Describe this chart in detail. This is the underlying table information of the chart: votes cast in most recent national election as a  | x of voting age population  | xof registered voters  | iceland 12017  | japan 12017  | turkey 120181  | sweden 120181  | australia 120191  | belgium 120191  | south korea 12017  | israel 120201  | netherlands 12017  | denmark 12019  | hungary 120181  | norway 12017  | finland 12019  | germany 12017  | france 12017  | mexico 120181  | poland 120201  | slovakia 120201  | italy 12018  | austria 12019  | greece 120191  | new zealand 120201  | canada 12019  | united kingdom 12019  | portugal 12019  | spain 12019  | lithuania 12019  | czech republic 12017  | colombia 120181  | ireland 120201  | estonia 12019  | united states 12016  | slovenia 120181  | latvia 12018  | chile 12017  | luxembourg 120181  | switzerland 120191  | 10  | 20  | 30  | 40  | 50  | 60  | 70  | 80  | 90  | 100 <end_of_utterance>
Assistant: The 55.7% VAP turnout in 2016 puts the 