In [None]:
# Transformers installation
! pip install transformers datasets evaluate -q
! pip install jiwer
! pip install accelerate -U

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m542.0/542.0 kB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m9.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m12.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m401.2/401.2 kB[0m [31m15.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting jiwer
  Downloading jiwer-3.0.4-py3-none-any.whl (21 kB)
Collecting rapidfuzz<4,>=3 (from jiwer)
  Downloading rapidfuzz-3.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
from datasets import load_dataset, Dataset
from transformers import AutoProcessor, AutoModelForCausalLM
from evaluate import load
import torch
from datasets import load_dataset

In [None]:
from datasets import load_dataset, Dataset
from transformers import AutoProcessor, AutoModelForCausalLM
from evaluate import load
import torch
from datasets import load_dataset


# Dataset processing: load the One Piece BLIP-caption dataset and hold out 10% for testing.
ds = load_dataset("YaYaB/onepiece-blip-captions")
ds = ds["train"].train_test_split(test_size=0.1, seed=42)
train_ds = ds["train"]
test_ds = ds["test"]
del ds

# Training captions containing any of these keywords are dropped.
# NOTE(review): this is a plain substring test, so "man" also matches "woman", "human",
# "many", ... and "tie" matches "tied"/"ties". Confirm that is intended before switching
# to whole-word matching (which would change which rows are kept).
EXCLUDED_KEYWORDS = ("man", "shirt", "tie")

# Scan only the "text" column; iterating whole rows would also load each row's image.
exclude_idx = [
    index
    for index, text in enumerate(train_ds["text"])
    if any(keyword in text for keyword in EXCLUDED_KEYWORDS)
]
print("number of excluded items: {}".format(len(exclude_idx)))

# Build the exclusion set once so each membership test is O(1); constructing it inside
# the comprehension would rebuild it for every row (quadratic in the dataset size).
excluded = set(exclude_idx)
train_ds = train_ds.select([i for i in range(len(train_ds)) if i not in excluded])

print(train_ds)

In [None]:

checkpoint = "microsoft/git-base"
processor = AutoProcessor.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)


def transforms(example_batch):
    """Convert a batch of raw dataset rows into model inputs for GIT.

    Installed with `set_transform`, so it is applied when batches are accessed
    rather than ahead of time.

    Args:
        example_batch: mapping with an "image" column and a "text" (caption) column.

    Returns:
        The processor output (pixel values, token ids, attention mask) with an extra
        "labels" entry that is the very same object as "input_ids".
    """
    encoded = processor(
        images=list(example_batch["image"]),
        text=list(example_batch["text"]),
        padding="max_length",
    )
    # Captioning as causal LM: the targets are the caption tokens themselves.
    # NOTE(review): padded positions are not replaced with -100, so padding tokens are
    # presumably included in the loss - confirm whether that is intended.
    encoded["labels"] = encoded["input_ids"]
    return encoded


train_ds.set_transform(transforms)
test_ds.set_transform(transforms)


In [None]:
from transformers import TrainingArguments, Trainer

# Last path component of the checkpoint id ("microsoft/git-base" -> "git-base");
# [-1] also works for ids/paths without an organisation prefix.
model_name = checkpoint.split("/")[-1]

training_args = TrainingArguments(
    output_dir=f"{model_name}-blip",
    learning_rate=5e-5,
    num_train_epochs=1,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,
    save_total_limit=3,
    # Evaluate and checkpoint once per epoch. With "steps" and no eval_steps/save_steps,
    # both intervals fall back to 500 steps, which a short single-epoch run on this small
    # dataset likely never reaches: no evaluation, no checkpoint, and
    # load_best_model_at_end has nothing to load.
    # `eval_strategy` is the transformers>=4.41 name (formerly `evaluation_strategy`,
    # which newer releases no longer accept).
    eval_strategy="epoch",
    save_strategy="epoch",
    # Log training loss regularly instead of the 500-step default.
    logging_steps=10,
    # Keep the "image"/"text" columns: the on-the-fly `transforms` needs them.
    remove_unused_columns=False,
    label_names=["labels"],
    load_best_model_at_end=True,
)

wer = load("wer")


def preprocess_logits_for_metrics(logits, labels):
    """Reduce one eval batch of logits to predicted caption token ids.

    Trainer calls this per batch, so only (batch, seq_len) integer ids are accumulated
    over the eval set instead of full vocabulary-sized logits (which can exhaust memory).

    GIT prepends image-patch positions to the text positions, so the logits presumably
    cover image + text positions (confirm for the checkpoint in use); only the trailing
    `labels.shape[1]` positions, i.e. the caption tokens, are kept.
    """
    if isinstance(logits, tuple):  # some models return (logits, *extra_outputs)
        logits = logits[0]
    if labels is not None:
        logits = logits[:, -labels.shape[1]:, :]
    return logits.argmax(dim=-1)


def compute_metrics(eval_pred):
    """Word error rate between teacher-forced predictions and the reference captions.

    `eval_pred.predictions` holds the token ids returned by
    `preprocess_logits_for_metrics`. The prediction made at caption position t is for
    token t + 1, so predictions are shifted one step left to line up with the labels.

    Returns:
        {"wer_score": float}
    """
    predicted, labels = eval_pred
    pad_id = processor.tokenizer.pad_token_id

    predicted = predicted[:, :-1].copy()
    labels = labels[:, 1:].copy()
    # Trainer pads ragged batches with -100, which the tokenizer cannot decode.
    predicted[predicted == -100] = pad_id
    labels[labels == -100] = pad_id
    # Ignore whatever the model predicts at padded reference positions.
    predicted[labels == pad_id] = pad_id

    decoded_labels = processor.batch_decode(labels, skip_special_tokens=True)
    decoded_predictions = processor.batch_decode(predicted, skip_special_tokens=True)
    wer_score = wer.compute(predictions=decoded_predictions, references=decoded_labels)
    return {"wer_score": wer_score}


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)

trainer.train()

## Inference

In [None]:
# Colab only: mount Google Drive under /content/drive.
# NOTE(review): no visible cell reads from or writes to /content/drive (the Trainer's
# output_dir is a relative path), so this mount may be unnecessary - confirm before removing.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
from PIL import Image
import requests

from io import BytesIO

from IPython.display import display

# Caption a single image from the web with the fine-tuned model.
url = "https://ami.animecharactersdatabase.com/uploads/chars/12602-925960129.jpg"
response = requests.get(url, timeout=30)
response.raise_for_status()
# Decode from the fully read body (unlike `.raw`, `.content` undoes any content-encoding)
# and normalise to RGB so palette/alpha images are also accepted by the processor.
image = Image.open(BytesIO(response.content)).convert("RGB")
# A bare `image` in the middle of a cell is never rendered; display it explicitly.
display(image)

device = "cuda" if torch.cuda.is_available() else "cpu"
# Keep the model on the same device as its inputs, and disable dropout: after
# `trainer.train()` the model may still be in training mode.
model.to(device)
model.eval()

inputs = processor(images=image, return_tensors="pt").to(device)
pixel_values = inputs.pixel_values

generated_ids = model.generate(pixel_values=pixel_values, max_length=20)
generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_caption)