In [5]:
from PIL import Image
import requests
from transformers import Blip2Processor, Blip2ForConditionalGeneration
import torch

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# Load the BLIP-2 preprocessing pipeline and captioning model from the same
# checkpoint; the model weights are requested in half precision.
checkpoint = "Salesforce/blip2-opt-2.7b"
processor = Blip2Processor.from_pretrained(checkpoint)
model = Blip2ForConditionalGeneration.from_pretrained(checkpoint, torch_dtype=torch.float16)
# Kept as the cell's final expression so the module summary is still displayed.
model.to(device)

For stability purposes, it is recommended to have accelerate installed when using this model in torch.float16, please install it with `pip install accelerate`
Loading checkpoint shards: 100%|██████████| 2/2 [00:16<00:00,  8.08s/it]


Blip2ForConditionalGeneration(
  (vision_model): Blip2VisionModel(
    (embeddings): Blip2VisionEmbeddings(
      (patch_embedding): Conv2d(3, 1408, kernel_size=(14, 14), stride=(14, 14))
    )
    (encoder): Blip2Encoder(
      (layers): ModuleList(
        (0): Blip2EncoderLayer(
          (self_attn): Blip2Attention(
            (dropout): Dropout(p=0.0, inplace=False)
            (qkv): Linear(in_features=1408, out_features=4224, bias=True)
            (projection): Linear(in_features=1408, out_features=1408, bias=True)
          )
          (layer_norm1): LayerNorm((1408,), eps=1e-05, elementwise_affine=True)
          (mlp): Blip2MLP(
            (activation_fn): GELUActivation()
            (fc1): Linear(in_features=1408, out_features=6144, bias=True)
            (fc2): Linear(in_features=6144, out_features=1408, bias=True)
          )
          (layer_norm2): LayerNorm((1408,), eps=1e-05, elementwise_affine=True)
        )
        (1): Blip2EncoderLayer(
          (self_attn): 

In [7]:

import os
from collections import defaultdict
# NOTE(review): `all_images` is not read by any cell visible here; kept in
# case another cell relies on it -- confirm and remove if unused.
all_images = []

image_dir = "death-note-images"
# os.listdir returns entries in arbitrary order, so sort the file names to make
# the image order (and therefore the dataset row order) reproducible.
all_image_paths = [
    f"{image_dir}/{fname}"
    for fname in sorted(os.listdir(image_dir))
    if fname.endswith(".jpg")
]


In [8]:
from tqdm import tqdm
def generate_captions(image_file_paths, batch_size=8):
    """Generate a caption for each image using the notebook's BLIP-2 model.

    Relies on the notebook-level ``processor``, ``model`` and ``device``.

    Args:
        image_file_paths: Iterable of paths to image files readable by PIL.
        batch_size: Number of images passed to the model per generation step.

    Returns:
        A list of whitespace-stripped caption strings, accumulated batch by
        batch in the order of ``image_file_paths``.
    """
    # Imported locally under an alias: a later cell re-binds the global name
    # ``Image`` to ``datasets.Image``, which would break ``Image.open`` if this
    # function were called again afterwards.
    from PIL import Image as PILImage

    image_file_paths = list(image_file_paths)

    captions = []
    for start in tqdm(range(0, len(image_file_paths), batch_size)):
        # Open only the current batch instead of every file up front, so the
        # number of simultaneously open file handles is bounded by batch_size.
        batch_images = []
        for image_file_path in image_file_paths[start : start + batch_size]:
            # ``with`` closes the underlying file; ``convert`` returns a fully
            # loaded copy that stays valid after the file is closed.
            with PILImage.open(image_file_path) as image:
                batch_images.append(image.convert("RGB"))

        # Cast the inputs to the model's own dtype (float16 as loaded in this
        # notebook) rather than repeating the literal here.
        inputs = processor(images=batch_images, return_tensors="pt").to(device, model.dtype)
        generated_ids = model.generate(**inputs)
        generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
        captions.extend(text.strip() for text in generated_texts)
    return captions


In [9]:
all_captions = generate_captions(all_image_paths)
# Captions are paired with image paths by position when the dataset is built,
# so fail fast here if the two lists ever get out of step.
assert len(all_captions) == len(all_image_paths), (
    f"expected {len(all_image_paths)} captions, got {len(all_captions)}"
)

100%|██████████| 426/426 [05:40<00:00,  1.25it/s]


In [21]:
from datasets import load_dataset, Image, Dataset, DatasetDict, load_from_disk

In [11]:
# Empty split container; the "train" split is assigned to it below.
death_note_dataset = DatasetDict()

In [12]:
# Pair each image path with its caption by position, then cast the "image"
# column from plain path strings to the `datasets` Image feature.
death_note_dataset["train"] = Dataset.from_dict(
    {"image": all_image_paths, "text": all_captions}
).cast_column("image", Image())

In [19]:
# Persist the dataset to a directory given relative to the current working directory.
death_note_dataset.save_to_disk("stable-diffusion-dataset/text-to-death-note-blip-2")

                                                                                             

In [22]:
# Reload from the same relative path the dataset was saved to, instead of a
# machine-specific absolute path, so this round-trip check reads exactly what
# was just written and works on other machines.
new_dataset = load_from_disk("stable-diffusion-dataset/text-to-death-note-blip-2")

In [23]:
# Display the first training row as a quick check of the reloaded dataset.
new_dataset["train"][0]

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1280x720>,
 'text': 'a woman with long hair and a black phone'}