In [None]:
# CHANGE THESE
# This cell runs before the main imports cell, so it must import Path itself;
# otherwise a fresh kernel fails here with NameError.
from pathlib import Path

# Path to the YAML experiment config file (despite the name, a file, not a directory).
# The default Path() is '.', a placeholder: opening it fails until it is replaced.
CONFIG_DIR = Path()
# Hugging Face Hub id of the zero-shot (NLI) model deployed below.
# Alternative tried previously: "facebook/bart-large-mnli".
HUGGINGFACE_MODEL = 'valhalla/distilbart-mnli-12-1'

In [None]:
import sys
import os
import zipfile
from pathlib import Path
import pandas as pd
import scipy.special as sp
import numpy as np
from PIL import Image
import torch
import open_clip
import json
from pathlib import Path
import argparse
import pickle
import traceback
from transformers import pipeline
from sentence_transformers import SentenceTransformer, util
import yaml 
from multiprocessing import Pool
from sagemaker.huggingface import HuggingFaceModel

# Locate the repository root so the project-local `cp` / `utils` code can be imported.
# Inside a Jupyter kernel sys.argv[0] is the ipykernel launcher script, not this
# notebook, so fall back to the working directory (by default Jupyter starts the
# kernel in the notebook's folder). As a plain script, keep using sys.argv[0].
if 'ipykernel' in sys.modules:
    script_path = Path.cwd()
else:
    script_path = Path(os.path.dirname(os.path.abspath(sys.argv[0])))
base_path = script_path.parent.absolute()
# sys.path entries must be str: the import system skips non-str entries such as
# Path objects. base_path itself is needed for the `from utils.xxx import ...`
# lines below. The membership check keeps re-runs of this cell idempotent.
for _extra_path in (base_path, base_path / 'cp', base_path / 'utils'):
    if str(_extra_path) not in sys.path:
        sys.path.append(str(_extra_path))
from utils.pets_classes import PETS_CLASSES, PETS_GENERIC_CLASSES
from utils.fitz17k_classes import FITZ17K_CLASSES, FITZ17K_GENERIC_CLASSES
from utils.medmnist_classes import MEDMNIST_CLASSES, MEDMNIST_GENERIC_CLASSES
from utils.imagenet_classes import IMAGENET_CLASSES, IMAGENET_GENERIC_CLASSES
from utils.caltech256_classes import CALTECH256_CLASSES, CALTECH256_GENERIC_CLASSES

# Load the experiment configuration from the YAML file at CONFIG_DIR.
with open(CONFIG_DIR, "r") as yaml_file:
    # safe_load returns None for an empty document; normalise to {} so the
    # comprehension below does not crash on None.items().
    config = yaml.safe_load(yaml_file) or {}

# Every key ending in '_dir' holds a filesystem location; wrap it in Path so
# later cells can use `.name` and `/` on it. Other values are kept as-is.
config = {
    k: Path(v) if isinstance(k, str) and k.endswith('_dir') else v
    for k, v in config.items()
}

# Working copies of every store live directly under /home/sagemaker-user; only the
# final path component of each configured *_dir entry is reused.
_sagemaker_home = Path("/home/sagemaker-user")
CONTEXT_DIRECTORY = _sagemaker_home.joinpath(config['reverse_image_store_dir'].name)
IMAGE_PLAUSIBILITIES = _sagemaker_home.joinpath(config['image_plausibility_store_dir'].name)
CALIB_IMAGE_DIRECTORY = _sagemaker_home.joinpath(config['scraping_store_dir'].name)
DATASET = config['dataset']

# Map each supported dataset name to its (class labels, generic pseudo-labels) tables.
# LABELS drives the "main" zero-shot scores and PSEUDO_LABELS the "second" scores.
_DATASET_LABELS = {
    'MedMNIST': (MEDMNIST_CLASSES, MEDMNIST_GENERIC_CLASSES),
    'FitzPatrick17k': (FITZ17K_CLASSES, FITZ17K_GENERIC_CLASSES),
    'OxfordPets': (PETS_CLASSES, PETS_GENERIC_CLASSES),
    'ImageNet': (IMAGENET_CLASSES, IMAGENET_GENERIC_CLASSES),
    'Caltech256': (CALTECH256_CLASSES, CALTECH256_GENERIC_CLASSES),
}
# Fail fast: an unknown dataset previously left LABELS=None and PSEUDO_LABELS
# undefined, which only crashed much later in the scoring cells.
if DATASET not in _DATASET_LABELS:
    raise ValueError(
        f"Unsupported dataset {DATASET!r}; expected one of {sorted(_DATASET_LABELS)}"
    )
LABELS, PSEUDO_LABELS = _DATASET_LABELS[DATASET]

In [None]:
import sagemaker
import boto3

# Prefer the role attached to the current SageMaker environment. When
# get_execution_role() cannot resolve one it raises ValueError; in that case
# look the role up by its fixed name in IAM instead.
try:
    role = sagemaker.get_execution_role()
except ValueError:
    role = boto3.client('iam').get_role(
        RoleName='sagemaker_execution_role'
    )['Role']['Arn']

print(f"sagemaker role arn: {role}")

In [None]:
import subprocess

# Download the zipped reverse-image (caption) store and scraped-image store from S3.
# `--recursive` is intentionally NOT passed: it makes the CLI treat the source as
# a directory prefix, so a single .zip object would not be copied.
# check=True surfaces a failed download here instead of at the unzip step, and the
# argument list avoids shell interpolation of config values.
for _store_key in ('reverse_image_store_dir', 'scraping_store_dir'):
    subprocess.run(
        [
            "aws", "s3", "cp",
            f"s3://sagemaker-datasets-hwei0/{config[_store_key].name}.zip",
            "/home/sagemaker-user/",
        ],
        check=True,
    )

In [None]:
# Unpack both downloaded archives into the home directory they were saved to.
_sagemaker_home = Path("/home/sagemaker-user")
for _store_key in ('reverse_image_store_dir', 'scraping_store_dir'):
    _archive = _sagemaker_home / f"{config[_store_key].name}.zip"
    with zipfile.ZipFile(_archive, 'r') as zip_ref:
        zip_ref.extractall(_sagemaker_home)

In [None]:
# Deploy HUGGINGFACE_MODEL straight from the Hugging Face Hub (https://huggingface.co/models).
# Nothing is trained here: the model id and pipeline task are handed to the
# inference container through `env=` as environment variables.
hub = {
    'HF_MODEL_ID': HUGGINGFACE_MODEL,          # model id on hf.co/models
    'HF_TASK': 'zero-shot-classification',     # pipeline task served by the endpoint
}
huggingface_model = HuggingFaceModel(
    role=role,                                 # IAM role allowed to create the endpoint
    env=hub,
    py_version="py39",                         # Python version of the DLC
    pytorch_version="1.13",
    transformers_version="4.26",
)
# NOTE(review): the real-time endpoint keeps billing until predictor.delete_endpoint()
# is called; no cell in this notebook calls it.
predictor = huggingface_model.deploy(
    instance_type="ml.m5.xlarge",
    initial_instance_count=1,
)

In [None]:

def scores_converter(scores, labels):
    """Re-order a zero-shot response's scores to follow ``labels``.

    Args:
        scores: mapping with parallel ``'labels'`` and ``'scores'`` sequences
            (the shape of the endpoint responses used in this notebook).
        labels: the label order wanted in the output; may repeat entries.

    Returns:
        A list holding, for each entry of ``labels``, the score paired with it.
        If a label appears twice in ``scores['labels']``, the later score wins.

    Raises:
        KeyError: a requested label is missing from ``scores['labels']``.
    """
    score_by_label = {
        name: scores['scores'][idx] for idx, name in enumerate(scores['labels'])
    }
    return [score_by_label[name] for name in labels]

def _zero_shot_scores(text, candidate_labels, ordered_labels):
    """Score one text on the deployed endpoint; return scores ordered like ``ordered_labels``."""
    response = predictor.predict({
        "inputs": [text],
        "parameters": {'candidate_labels': candidate_labels, 'multi_label': True, 'use_cache': True}
    })
    return scores_converter(response, ordered_labels)


def score_generation(label, max_files=10, max_captions=10):
    """Compute and save zero-shot plausibility scores for one class folder.

    Each file in ``CONTEXT_DIRECTORY / label`` is processed unless its name ends
    in ``_debug.pkl`` or ``events.log``. For each processed file:

    * the matching ``<stem>.caption`` text in ``CALIB_IMAGE_DIRECTORY / label`` is
      scored against all ``LABELS`` values. The 1-D tensor is saved as
      ``IMAGE_PLAUSIBILITIES / label / <stem>_main``.
    * the first ``max_captions`` entries of the pickled caption list are scored
      against the ``PSEUDO_LABELS`` values. The 2-D tensor (one row per caption)
      is saved as ``<stem>_second``.

    A file is reported and skipped when its caption text cannot be read or its
    pickle holds fewer than two captions.

    Args:
        label: class sub-folder name (numeric strings in this notebook).
        max_files: stop once this many files were scored. ``None`` scores every
            file. The default of 10 keeps the original cap.
        max_captions: maximum pickled captions scored per file (must be >= 1).
    """
    print("Beginning Score Generation: {label}".format(label=label))
    out_dir = IMAGE_PLAUSIBILITIES / label
    os.makedirs(out_dir, exist_ok=True)
    # Loop-invariant label lists: build them once per folder, not once per file.
    main_labels = list(LABELS.values())
    pseudo_labels = list(PSEUDO_LABELS.values())
    pseudo_candidates = list(set(pseudo_labels))  # de-duplicated labels sent to the endpoint
    n_scored = 0
    for file in os.listdir(CONTEXT_DIRECTORY / label):
        if max_files is not None and n_scored >= max_files:
            break
        if file.endswith("_debug.pkl") or file.endswith("events.log"):
            continue
        # Everything before the first dot; must match how the .caption files were named.
        stem = file.split('.')[0]
        try:
            with open(CALIB_IMAGE_DIRECTORY / label / (stem + '.caption'), 'r') as read:
                title = "\n".join(line.rstrip() for line in read)
        except (OSError, UnicodeDecodeError):
            print('ERROR CAPTION LOAD')
            print(traceback.format_exc())
            continue
        # NOTE: pickle.load can run arbitrary code; only use archives from the trusted dataset bucket.
        with open(CONTEXT_DIRECTORY / label / file, 'rb') as pkl_file:
            captions = pickle.load(pkl_file)
        if len(captions) <= 1:
            print('ERROR # CAPTIONS:' + str(captions))
            continue
        # Main score: the image's own caption text vs. the fine-grained class labels.
        main_score = torch.tensor(_zero_shot_scores(title, main_labels, main_labels))
        # Second score: each context caption vs. the generic pseudo-labels.
        second_score = torch.stack([
            torch.tensor(_zero_shot_scores(caption, pseudo_candidates, pseudo_labels))
            for caption in captions[:max_captions]
        ])
        torch.save(main_score, out_dir / (stem + '_main'))
        torch.save(second_score, out_dir / (stem + '_second'))
        n_scored += 1

    sys.stdout.flush()


In [None]:


# Show the primary (first comma-separated) name of each class label.
labels = [label.split(',')[0] for label in LABELS.values()]
print(labels)

# Only purely numeric entries of CONTEXT_DIRECTORY are class folders; everything
# else (e.g. an events.log) is skipped. os.listdir yields str, so int() can only
# raise ValueError here.
dir_labels = []
for label in os.listdir(CONTEXT_DIRECTORY):
    try:
        int(label)
    except ValueError:
        continue
    dir_labels.append(label)

# Score each class folder in its own worker process.
# NOTE(review): workers see the notebook-defined score_generation / predictor only
# under the 'fork' start method (Linux default). The pool size reuses the
# `num_selenium_threads` setting; confirm that is the intended knob.
with Pool(processes=config['num_selenium_threads']) as executor:
    executor.map(score_generation, dir_labels)

In [None]:
import subprocess

# Upload the generated score tensors. With --recursive, the *contents* of
# IMAGE_PLAUSIBILITIES (one sub-folder per class) land at the bucket root.
# NOTE(review): confirm the bucket root, rather than a per-store prefix, is the
# intended destination.
# check=True raises if the upload fails instead of silently continuing.
subprocess.run(
    [
        "aws", "s3", "cp",
        str(IMAGE_PLAUSIBILITIES.resolve()),
        "s3://sagemaker-datasets-hwei0",
        "--recursive",
    ],
    check=True,
)