In [1]:
from transformers.models.blip import BlipProcessor, BlipForConditionalGeneration
import pytorch_lightning as pl
from PIL import Image
from pathlib import Path
from typing import cast
from collections.abc import Sequence
import torch
from torch.utils.data import Dataset, DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import logging
import logging.config
import yaml

# Base directory that the log-file path used by DuplicateFilter is resolved
# against. Note: Path(".").parent is still Path("."), so this is the current
# working directory (as a relative path), not its parent.
this_path = Path(".").parent


class DuplicateFilter(logging.Filter):
    def __init__(self, name: str = "") -> None:
        super().__init__(name)
        self._past_messages = set()

    def filter(self, record: logging.LogRecord) -> bool:
        if record.msg in self._past_messages:
            return False
        self._past_messages.add(record.msg)
        print(
            record,
            record.msg,
            record.levelname,
            record.name,
            file=(this_path / Path(".logs/log.txt")).open("a"),
        )
        print(
            Path(".").absolute(),
            file=(this_path / Path(".logs/log.txt")).open("a"),
        )
        return True


# Configure logging from logging.yaml. The context manager closes the file as
# soon as it has been parsed instead of leaving the handle open.
with Path("logging.yaml").open() as config_file:
    logging.config.dictConfig(yaml.safe_load(config_file))
logger = logging.getLogger(__name__)

In [4]:
# Directory that is scanned for "*.jpg" files (relative to the working directory).
FOLDER = Path("./downloaded_images")

In [5]:
class BlipPL(pl.LightningModule):
    """Lightning wrapper around a pretrained BLIP image-captioning model.

    Args:
        model_name: Hugging Face checkpoint identifier used for both the
            processor and the model.
    """

    def __init__(
        self, model_name: str = "Salesforce/blip-image-captioning-base"
    ):
        super().__init__()
        # Processor and weights come from the same checkpoint name so the two
        # cannot drift apart.
        self.processor = cast(
            BlipProcessor,
            BlipProcessor.from_pretrained(model_name),
        )
        self.model = BlipForConditionalGeneration.from_pretrained(model_name)

    def predict_step(self, batch, batch_idx: int = 0) -> list[str]:
        """Generate captions for one already-preprocessed batch.

        Args:
            batch: Mapping that is unpacked as keyword arguments into
                ``self.model.generate``. It is expected to hold a
                ``"pixel_values"`` entry, whose shape is logged at debug level.
            batch_idx: Position of the batch in the dataloader. Accepted so
                the method matches the usual Lightning hook signature; unused.

        Returns:
            The generated sequences decoded to text with special tokens
            removed.
        """
        # Formatting a whole tensor batch is costly, so only build the debug
        # messages when they will actually be emitted.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"{batch = }")
            logger.debug(f"{batch['pixel_values'].shape = }")
        out = self.model.generate(**batch)
        return self.processor.batch_decode(out, skip_special_tokens=True)

In [6]:
# Instantiate the captioner and switch it to evaluation mode in one step
# (eval() hands back the module itself), then freeze that same instance.
model = BlipPL().eval()
model.freeze()

In [7]:
# Sort the paths: glob() yields files in an unspecified order, and a stable
# order is what lets captions be matched back to their files across runs.
image_paths = sorted(FOLDER.glob("*.jpg"))
if not image_paths:
    raise FileNotFoundError(f"No '*.jpg' files found in {FOLDER}")

images: Sequence[Image.Image] = [Image.open(p) for p in image_paths]
# Reuse the first opened image instead of globbing and opening it again.
image = images[0]
type(image).mro()

[PIL.JpegImagePlugin.JpegImageFile,
 PIL.ImageFile.ImageFile,
 PIL.Image.Image,
 object]

In [8]:
class ImageDataset(Dataset):
    """Dataset that runs one image through the BLIP processor per access.

    Args:
        images: Images to preprocess; indexed by integer position.
        model_name: Hugging Face checkpoint identifier the processor is
            loaded from.
    """

    def __init__(
        self,
        images: Sequence[Image.Image],
        model_name: str = "Salesforce/blip-image-captioning-base",
    ):
        self.images = images
        self.processor: BlipProcessor = BlipProcessor.from_pretrained(
            model_name
        )  # type: ignore
        self.processor.tokenizer.padding_side = "left"  # type: ignore

    def __len__(self) -> int:
        """Return the number of images in the dataset."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the processor output for image ``idx`` without its batch axis.

        Args:
            idx: Position of the image in ``self.images``.

        Returns:
            dict mapping each processor output key to its tensor with the
            leading axis removed.
        """
        encoded = self.processor(
            images=self.images[idx], padding=True, return_tensors="pt"
        )
        # The processor output for a single image carries a leading batch axis
        # of size 1, which has to go so the DataLoader can stack samples.
        # squeeze(0) removes only that axis; a bare squeeze() would also
        # collapse any other axis that happens to have size 1.
        return {key: tensor.squeeze(0) for key, tensor in encoded.items()}


class CaptionDataModule(pl.LightningDataModule):
    """Data module that serves an ``ImageDataset`` for testing and prediction.

    Args:
        images: Images handed to ``ImageDataset``.
        batch_size: Number of samples per batch. The default of 1 matches the
            ``DataLoader`` default that was used implicitly before.
    """

    def __init__(self, images: Sequence[Image.Image], batch_size: int = 1):
        super().__init__()
        self.batch_size = batch_size
        self.dataset = ImageDataset(images)

    def _make_dataloader(self) -> DataLoader:
        """Build the loader shared by the test and predict hooks."""
        # shuffle is left at its default (off) so outputs keep the input order.
        return DataLoader(self.dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return the dataloader used by ``trainer.test``."""
        return self._make_dataloader()

    def predict_dataloader(self):
        """Return the dataloader used by ``trainer.predict``."""
        return self._make_dataloader()

In [None]:
# Trainer that will drive caption generation (the predict call is a separate cell).
# devices=1 states the single-device setup explicitly instead of leaving the
# device count to automatic selection; multiple GPUs are not used yet.
trainer = pl.Trainer(accelerator="auto", devices=1)

/mnt/local/ryan/.venv12/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [10]:
# Wrap the opened images in the preprocessing dataset, batch them eight at a
# time, and let the trainer run the model's predict_step over every batch.
predict_loader = DataLoader(ImageDataset(images), batch_size=8)
captions = trainer.predict(model, dataloaders=predict_loader)

/mnt/local/ryan/.venv12/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=223` in the `DataLoader` to improve performance.


Predicting DataLoader 0:   6%|▌         | 36/584 [00:07<02:00,  4.57it/s]

NameError: name 'exit' is not defined

In [None]:
# Display the value returned by trainer.predict in the previous cell.
captions