In [None]:
!pip install git+https://github.com/huggingface/transformers -q
!pip install accelerate -q
!pip install -q 'labelbox[data]' -q

import json
import os
import uuid

import labelbox
import requests
import torch
from labelbox import Client, MALPredictionImport
from labelbox.data.annotation_types import (
    Label, ImageData, ClassificationAnnotation, Text
)
from labelbox.schema.ontology import OntologyBuilder
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

## Enter your API key below. See [how to create an API key](https://docs.labelbox.com/reference/create-api-key).

In [None]:
# Read the Labelbox API key from the LABELBOX_API_KEY environment variable so
# the secret does not have to be saved inside the notebook. When the variable
# is not set, the placeholder below is used; replace it with your own key.
MYAPI = os.environ.get("LABELBOX_API_KEY", "API KEY")

# Authenticated client used by every Labelbox call in the rest of the notebook.
client = Client(MYAPI)

# Build the ontology used to create the project

In [None]:
# Two free-text classifications: "BLIP model prediction" (the name later used
# for the uploaded prelabel annotation) and "Human caption".
blip_prediction_field = labelbox.Classification(
    class_type=labelbox.Classification.Type.TEXT,
    name="BLIP model prediction",
)
human_caption_field = labelbox.Classification(
    class_type=labelbox.Classification.Type.TEXT,
    name="Human caption",
)

ontology_builder = labelbox.OntologyBuilder(
    classifications=[blip_prediction_field, human_caption_field]
)

# Create the ontology under the name "BLIP" for image media.
ontology = client.create_ontology(
    "BLIP",
    ontology_builder.asdict(),
    media_type=labelbox.MediaType.Image,
)

# Create project and attach the ontology

In [None]:
# Create an image project and connect the ontology built above to its editor.
project = client.create_project(
    name="BLIP Pre label",
    media_type=labelbox.MediaType.Image,
)
project.setup_editor(ontology)

# Read the ontology back from the project.
ontology_from_project = labelbox.OntologyBuilder.from_project(project)

# Export the data row IDs from the dataset so that they can be attached to the project

In [None]:
# Export params choose which optional fields are included in (or excluded from)
# each exported row. Make sure each of these fields is correctly grabbed.
export_params = dict(
    attachments=True,
    metadata_fields=True,
    data_row_details=True,
    project_details=True,
    performance_details=True,
)

# The export can be limited to a last_activity_at range.
# For context, last_activity_at captures the creation and modification of
# labels, metadata, status, comments and reviews.
# Note: filters are combined with AND logic, so one filter is usually enough.
filters = dict(
    last_activity_at=["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
)

dataset = client.get_dataset("Dataset ID")
export_task = dataset.export_v2(params=export_params, filters=filters)
export_task.wait_till_done()

# Errors are only reported; the result is still read afterwards.
if export_task.errors:
    print(export_task.errors)
export_result = export_task.result

# Preview the first three exported rows as pretty-printed JSON.
preview = json.dumps(export_result[:3], indent=4)
print("results: ", preview)

# Keep the IDs of at most the first 100 exported data rows.
data_row_ids = []
for exported_row in export_result[:100]:
    data_row_ids.append(exported_row["data_row"]["id"])

# Attach batch to the project

In [None]:
# Send the exported data rows to the project as a single batch.
batch_name = "Adding assets"  # name of the batch
batch_priority = 1  # priority between 1-5
batch = project.create_batch(batch_name, data_row_ids, batch_priority)

# Initialize and load a pre-trained BLIP-2 model

If a GPU is available, the model will be moved to the GPU to take advantage of its parallel processing capabilities, which can significantly speed up computations. If a GPU is not available, the model will run on the CPU (Central Processing Unit) instead.

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The processor and the model are loaded from the same checkpoint.
blip2_checkpoint = "Salesforce/blip2-opt-2.7b"
processor = Blip2Processor.from_pretrained(blip2_checkpoint)

# Weights are requested in half precision (float16).
# NOTE(review): float16 is requested even when device is "cpu"; half-precision
# inference on CPU may be slow or unsupported depending on the PyTorch
# version - confirm before relying on the CPU path.
model = Blip2ForConditionalGeneration.from_pretrained(
    blip2_checkpoint,
    torch_dtype=torch.float16,
)
model.to(device)

# Example of the image and the output of the model

In [None]:
# Download a sample image, show it, and print the caption BLIP-2 generates.
url = 'https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png'

# Fail fast on HTTP errors or a stalled download instead of handing an error
# page to PIL; the `with` block releases the connection afterwards.
with requests.get(url, stream=True, timeout=30) as response:
    response.raise_for_status()
    # convert() forces the image to be fully read before the stream closes.
    image = Image.open(response.raw).convert('RGB')

display(image.resize((596, 437)))

# Cast the inputs to the dtype the model was actually loaded with rather than
# assuming float16, so the two cannot drift apart.
inputs = processor(image, return_tensors="pt").to(device, model.dtype)

generated_ids = model.generate(**inputs, max_new_tokens=20)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(generated_text)

# Collect inferences to be used as prelabels

In [None]:
# Caption every data row queued in the project and wrap each caption in a
# Labelbox Label so it can be uploaded as a prelabel.
queued_data_rows = project.export_queued_data_rows()
ground_truth_list = []

for data_row in queued_data_rows:
    url = data_row["rowData"]

    # Fail fast on HTTP errors or a stalled download; the `with` block
    # releases the connection once the image has been read.
    with requests.get(url, stream=True, timeout=30) as response:
        response.raise_for_status()
        # Normalise to RGB, as the example cell does, so grayscale, palette
        # or RGBA assets do not break or skew the processor. convert() also
        # forces the image to be fully read before the stream closes.
        image = Image.open(response.raw).convert('RGB')

    # Cast the inputs to the dtype the model was actually loaded with.
    inputs = processor(image, return_tensors="pt").to(device, model.dtype)
    generated_ids = model.generate(**inputs, max_new_tokens=30)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()

    # The name must match the "BLIP model prediction" classification in the
    # project ontology.
    text_annotation = ClassificationAnnotation(
        name="BLIP model prediction",
        value=Text(answer=generated_text),
    )

    ground_truth_list.append(
        Label(
            data=ImageData(uid=data_row["id"]),
            annotations=[text_annotation],
        )
    )

# Upload prelabels to the project

In [None]:
  # Import the collected labels into the project as model-assisted-labeling
  # predictions, using a random UUID as the import name.
  upload_task = MALPredictionImport.create_from_objects(
      client=client,
      project_id=project.uid,
      name=str(uuid.uuid4()),
      predictions=ground_truth_list,
  )
  # Block until the import finishes, then print whatever errors it reported.
  upload_task.wait_until_done()
  print(upload_task.errors)