In [1]:
import os.path

from shared_code.utility.schemas.spark_table_schema import image_table_schema, tokenize_caption_schema
from shared_code.utility.spark.set_environ import *
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, GPT2TokenizerFast, AutoTokenizer
from PIL import Image

os.environ['PYSPARK_PYTHON'] = sys.executable
os.environ['PYSPARK_DRIVER_PYTHON'] = sys.executable
set_azure_env()

from shared_code.utility.storage.table import TableAdapter

spark_builder = SparkSession \
	.builder \
	.appName('add-missing-captions') \
	.master("local[6]") \
	.config("spark.cores.max", "1") \
	.config("spark.executor.instances", "1") \
	.config("spark.executor.cores", "1") \
	.config("spark.executor.cores", "1") \
	.config("spark.executor.instances", "1") \
	.config("spark.driver.memory", "10g") \
	.config("spark.memory.offHeap.enabled", True) \
	.config("spark.memory.offHeap.size", "10g") \
	.config("spark.executor.cores", "1")

spark = spark_builder.getOrCreate()

In [2]:
table_name = "training"
table_adapter: TableAdapter = TableAdapter()
raw_data = table_adapter.get_all_entities(table_name)
spark_df = spark.createDataFrame(raw_data, schema=image_table_schema)
print("== Data Loaded For Captions ==")

== Data Loaded For Captions ==


In [3]:
def get_vit_caption():
	model = VisionEncoderDecoderModel.from_pretrained("D:\\models\\vit-gpt2-image-captioning")
	feature_extractor = ViTFeatureExtractor.from_pretrained("D:\\models\\vit-gpt2-image-captioning")
	tokenizer = AutoTokenizer.from_pretrained("D:\\models\\vit-gpt2-image-captioning")
	return model, feature_extractor, tokenizer


model, feature_extractor, tokenizer = get_vit_caption()


def caption_image_vit(image_path: str) -> str:
	try:
		max_length = 32

		num_beams = 4

		gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

		images = []

		i_image = Image.open(image_path)
		if i_image.mode != "RGB":
			i_image = i_image.convert(mode="RGB")

		images.append(i_image)

		print(f":: Predicting image: {image_path}")

		pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values

		output_ids = model.generate(pixel_values, **gen_kwargs)

		print(f":: Decoding output for image: {image_path}")
		predictions = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

		prediction = [prediction.strip() for prediction in predictions]

		print(f":: Completed prediction for image: {image_path}")
		if len(prediction) > 0:
			return prediction[0]
		else:
			return None

	except Exception as e:
		print(f":: Process Failed For {image_path} with {e}")
		return None

In [4]:
def get_default_caption():
	model = VisionEncoderDecoderModel.from_pretrained("D:\\models\\image-caption-generator")
	feature_extractor = ViTFeatureExtractor.from_pretrained("D:\\models\\image-caption-generator")
	tokenizer = GPT2TokenizerFast.from_pretrained("D:\\models\\image-caption-generator\\tokenizer")
	return model, feature_extractor, tokenizer


blip_model, blip_feature_extractor, blip_tokenizer = get_default_caption()


def caption_image_blip(image_path: str) -> str:
	try:
		img = Image.open(image_path)
		if img.mode != 'RGB':
			img = img.convert(mode="RGB")

		pixel_values = blip_feature_extractor(images=[img], return_tensors="pt").pixel_values

		max_length = 128
		num_beams = 4

		# get model prediction
		output_ids = blip_feature_extractor.generate(pixel_values, num_beams=num_beams, max_length=max_length)

		# decode the generated prediction
		predictions = blip_tokenizer.decode(output_ids[0], skip_special_tokens=True)
		return predictions

	except Exception as e:
		print(f"Error in caption_image: {e}")
		return None

In [5]:
def get_caption_image_name(image_path: str, caption: str, row_key: str, partition_key: str):
	if not os.path.exists(image_path):
		return ""
	try:
		if caption is None or caption == "" or caption == "NaN" or caption == "None" or len(caption) < 5:
			vit_caption = caption_image_vit(image_path)
			record = table_adapter.get_entity("training", partition_key, row_key)
			record["updated_caption"] = vit_caption
			table_adapter.upsert_entity_to_table("training", record)
			return caption
		else:
			return caption
	except:
		return ""


def get_blip_caption_image_name(image_path: str, caption: str, row_key: str, partition_key: str):
	if not os.path.exists(image_path):
		return ""
	try:
		if caption is None or caption == "" or caption == "NaN" or caption == "None" or len(caption) < 5:
			_blip_caption = caption_image_blip(image_path)
			record = table_adapter.get_entity("training", partition_key, row_key)
			record["caption"] = _blip_caption
			table_adapter.upsert_entity_to_table("training", record)
			return _blip_caption
		else:
			return caption
	except:
		return ""

In [15]:
remaining_captions = spark_df.select("image", "small_image", "updated_caption", "caption", "RowKey", "PartitionKey", "Exists").collect()
print(f"Need to processes: {len(remaining_captions)} image captions")

Need to processes: 8600 image captions


In [18]:
for i, elem in enumerate(list(remaining_captions)):
	try:
		if not elem['Exists']:
			continue
		if i % 100 == 0:
			print(f"Remaining: {i} / {len(remaining_captions)}")
		elem = elem.asDict()
		# print(elem)

		if elem is not None:
			pass
		else:
			continue

		if isinstance(elem, type(None)):
			continue

		if not os.path.exists(elem['image']):
			entity = table_adapter.get_entity("training", elem['PartitionKey'], elem['RowKey'])
			entity['Exists'] = False
			table_adapter.upsert_entity_to_table("training", entity)
			continue

		if len(elem['updated_caption']) > 5 and len(elem['caption']) > 5:
			continue

		if os.path.exists(elem['small_image']):
			entity = table_adapter.get_entity("training", elem['PartitionKey'], elem['RowKey'])
			blip_caption = None
			vit_result_caption = None
			try:
				vit_result_caption = caption_image_vit(elem['small_image'])
				blip_caption = caption_image_vit(elem['small_image'])
			except:
				result = None

			entity["caption"] = blip_caption
			entity["updated_caption"] = vit_result_caption
			table_adapter.upsert_entity_to_table("training", entity)
		else:
			print("File not found: ", elem['small_image'])
			entity = table_adapter.get_entity("training", elem['PartitionKey'], elem['RowKey'])
			entity['exists'] = False
			table_adapter.upsert_entity_to_table("training", entity)
	except ConnectionResetError:
		print("Connection Reset Error")
		table_adapter = TableAdapter()
		continue
	except Exception as e:
		# print(e)
		continue

Remaining: 0 / 8600
Remaining: 100 / 8600
Remaining: 200 / 8600
Remaining: 400 / 8600
Remaining: 700 / 8600
Remaining: 800 / 8600
Remaining: 900 / 8600
Remaining: 1000 / 8600
Remaining: 1100 / 8600
Remaining: 1300 / 8600
Remaining: 1500 / 8600
Remaining: 1600 / 8600
Remaining: 1700 / 8600
Remaining: 1900 / 8600
Remaining: 2000 / 8600
Remaining: 2100 / 8600
Remaining: 2200 / 8600
Remaining: 2300 / 8600
Remaining: 2500 / 8600
Remaining: 2600 / 8600
Remaining: 2800 / 8600
Remaining: 3000 / 8600
Remaining: 3100 / 8600
Remaining: 3200 / 8600
Remaining: 3300 / 8600
Remaining: 3400 / 8600
Remaining: 3500 / 8600
Remaining: 3600 / 8600
Remaining: 3700 / 8600
Remaining: 3800 / 8600
Remaining: 3900 / 8600
Remaining: 4000 / 8600
Remaining: 4100 / 8600
Remaining: 4200 / 8600
File not found:  
Remaining: 4300 / 8600
Remaining: 4400 / 8600
Remaining: 4500 / 8600
Remaining: 4600 / 8600
Remaining: 4700 / 8600
Remaining: 4800 / 8600
Remaining: 4900 / 8600
Remaining: 5000 / 8600
Remaining: 5100 / 8600
Re

In [None]:
refreshed = spark.createDataFrame(table_adapter.get_all_entities("training"), schema=image_table_schema)

display(refreshed.limit(10).toPandas())

# refreshed.write.parquet("D:\\data\\processed\\reddit_images_processed_vit_caption.parquet", mode="overwrite")