In [None]:
import os
import warnings
import glob
import json
import torch
import torch.nn.functional as F
from torchvision.models.detection import fasterrcnn_resnet50_fpn
import torchvision.transforms as T
from PIL import Image
from transformers import pipeline
from datasets import Dataset
from tqdm import tqdm

In [None]:
# Suppress all Python warnings to keep cell output short.
# NOTE(review): this is a blanket filter — it also hides useful notices such as
# torchvision's deprecation of `pretrained=` (torchvision >= 0.13 prefers `weights="DEFAULT"`).
warnings.filterwarnings("ignore")

# Query CUDA once and derive the torch device from the same flag.
gpu_available = torch.cuda.is_available()
device = torch.device("cuda" if gpu_available else "cpu")

# Faster R-CNN (ResNet-50 FPN backbone) with pretrained weights, moved to `device` and switched
# to inference mode. `.to()` and `.eval()` both return the module, so chaining them into the
# assignment keeps the cell from ending in a bare expression — which made Jupyter print the
# entire model repr.
model_detection = fasterrcnn_resnet50_fpn(pretrained=True).to(device).eval()

In [None]:
# COCO category names, indexed directly by the integer label ids the detector returns
# (the captioning loop does `COCO_INSTANCE_CATEGORY_NAMES[label]`).
# Index 0 is the background class; the 'N/A' entries are placeholders for COCO ids that
# have no category, so every real name stays at the position matching its label id.
# 91 entries in total (ids 0-90). Do not reorder or remove entries.
COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
    'traffic light', 'fire hydrant', 'N/A', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A', 'handbag',
    'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
    'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop',
    'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]

In [None]:
# PIL image -> float tensor (C, H, W) with values in [0, 1]; the detector takes a list of these.
transform = T.Compose([T.ToTensor()])

# Half-open range [start, end) of the (sorted) dataset file list processed in this run.
IMAGE_SLICE_START = 30000
IMAGE_SLICE_END = 40000

# glob() returns files in arbitrary, filesystem-dependent order, so slicing the raw result
# picks a different set of images on different machines/filesystems. Sorting makes the
# slice reproducible.
# NOTE(review): if earlier slices were produced from the unsorted listing, they may overlap
# or leave gaps relative to this sorted slice — re-run them with this cell if that matters.
image_paths = sorted(glob.glob(os.path.join("dataset", "*.jpg")))[IMAGE_SLICE_START:IMAGE_SLICE_END]

# Fail loudly instead of silently writing an empty caption file later.
assert image_paths, (
    f"No .jpg files found in ./dataset for slice [{IMAGE_SLICE_START}:{IMAGE_SLICE_END}]"
)

In [None]:
# FLAN-T5 (large) text-to-text model used to turn detected object names into one sentence.
# The pipeline's `device` argument is a CUDA ordinal (0 = first GPU) or -1 for CPU.
text_generator = pipeline(
    task="text2text-generation",
    model="google/flan-t5-large",
    device=-1 if not gpu_available else 0,
)

In [None]:
# Loop-invariant settings (previously re-created inside the loop on every image).
DET_THRESHOLD = 0.6  # minimum detector score for an object to be named in the prompt
CAPTION_MAX_LENGTH = 40  # `max_length` passed to the text generator
NO_OBJECTS_CAPTION = "No objects detected, so no sentence was created."
# Entries of COCO_INSTANCE_CATEGORY_NAMES that are placeholders, not real objects.
_PLACEHOLDER_NAMES = {"__background__", "N/A"}


def detect_object_names(img, threshold=DET_THRESHOLD):
    """Run the detector on one RGB PIL image and return names of confident detections.

    Args:
        img: RGB ``PIL.Image.Image`` to analyse.
        threshold: minimum detection score (inclusive) for a detection to be kept.

    Returns:
        List of COCO category names, one per kept detection, in detector output order
        (repeated objects appear repeatedly). Label ids outside the name table and
        placeholder classes ("__background__", "N/A") are dropped.
    """
    img_tensor = transform(img).to(device)
    with torch.no_grad():
        prediction = model_detection([img_tensor])[0]
    keep = prediction["scores"] >= threshold
    # .tolist() yields plain Python ints, usable directly as list indices.
    labels = prediction["labels"][keep].tolist()

    names = []
    for label in labels:
        if not 0 <= label < len(COCO_INSTANCE_CATEGORY_NAMES):
            continue
        name = COCO_INSTANCE_CATEGORY_NAMES[label]
        if name not in _PLACEHOLDER_NAMES:
            names.append(name)
    return names


def generate_caption(object_names, max_length=CAPTION_MAX_LENGTH):
    """Generate one descriptive sentence that uses the given object names.

    Args:
        object_names: list of object names (may be empty).
        max_length: ``max_length`` forwarded to the text generator.

    Returns:
        The generated sentence, or NO_OBJECTS_CAPTION when ``object_names`` is empty
        (the model is not called in that case).
    """
    if not object_names:
        return NO_OBJECTS_CAPTION
    # NOTE(review): there is no separator between "natural context" and the word list.
    # The prompt is kept byte-identical so captions stay comparable with earlier runs.
    prompt_text = "Imagine an Image and Generate a detailed, descriptive sentence using these words in a natural context " + ", ".join(object_names)
    return text_generator(prompt_text, max_length=max_length)[0]['generated_text']


results = {}  # image file name -> pseudo-caption
failed_images = []  # (path, error message) for images that could not be read

for image_path in tqdm(image_paths, desc="Processing Images", unit="img"):
    try:
        # `with` closes the file handle; convert() returns an independent in-memory image.
        with Image.open(image_path) as raw_img:
            img = raw_img.convert("RGB")
    except OSError as exc:  # corrupt/unreadable file (PIL.UnidentifiedImageError is an OSError)
        # Record and skip instead of aborting the whole batch and losing all captions so far.
        failed_images.append((image_path, str(exc)))
        continue
    results[os.path.basename(image_path)] = generate_caption(detect_object_names(img))

if failed_images:
    print(f"Skipped {len(failed_images)} unreadable image(s); see `failed_images`.")

In [None]:
# Write the {image file name: pseudo-caption} mapping as pretty-printed JSON
# (overwrites any existing file with this name).
with open("train_pseudo_caption.json", mode="w") as caption_file:
    json.dump(results, caption_file, indent=4)
print("Pseudo Caption Completed")