In [1]:
# !pip install -q git+https://github.com/huggingface/transformers.git

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import gc
from PIL import Image
import transformers
from transformers import pipeline
from datetime import date
import json
import os
from itertools import groupby
from label_studio_converter import brush
from label_studio_sdk import Client
import torch


# Label Studio connection settings. Values come from the environment so the API key
# is never saved inside the notebook; the defaults match a local Label Studio server.
LABEL_STUDIO_URL = os.environ.get('LABEL_STUDIO_URL', 'http://127.0.0.1:8080')
API_KEY = os.environ.get('LABEL_STUDIO_API_KEY', '')
PROJECT_ID = int(os.environ.get('LABEL_STUDIO_PROJECT_ID', '1'))


def show_mask(mask, ax, random_color=False, alpha=0.6):
    """Overlay one binary mask on a matplotlib Axes as a translucent colour layer.

    Args:
        mask: Array with h*w elements whose last two dimensions are (h, w),
            e.g. shape (h, w) or (1, h, w). Truthy pixels are painted; falsy
            pixels end up with alpha 0 and stay invisible.
        ax: Axes-like object; only ``ax.imshow`` is called.
        random_color: If True, draw a random RGB colour from the global
            ``np.random`` state; otherwise use a fixed blue (30, 144, 255).
        alpha: Opacity of painted pixels. Defaults to 0.6, the value that was
            previously hard-coded.
    """
    if random_color:
        rgb = np.random.random(3)
    else:
        rgb = np.array([30 / 255, 144 / 255, 255 / 255])
    color = np.concatenate([rgb, np.array([alpha])], axis=0)
    h, w = mask.shape[-2:]
    # (h, w, 1) * (1, 1, 4) broadcasts to an RGBA image; zero pixels get alpha 0.
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    # No `del mask` / gc.collect() here: deleting the local name does not free the
    # caller's array, and a full GC pass per mask only slows down plotting.

def show_masks_on_image(raw_image, masks):
    """Display ``raw_image`` with every mask in ``masks`` overlaid in a random colour.

    Args:
        raw_image: Image accepted by ``np.array`` (e.g. a PIL image).
        masks: Iterable of masks accepted by ``show_mask``. It may be empty, in
            which case only the image is shown (previously the trailing
            ``del mask`` raised UnboundLocalError in that case).
    """
    plt.imshow(np.array(raw_image))
    ax = plt.gca()
    # Freeze the axis limits at the image extent so the overlays cannot rescale them.
    ax.set_autoscale_on(False)
    for mask in masks:
        show_mask(mask, ax=ax, random_color=True)
    plt.axis("off")
    plt.show()

def extract_masks(files: list, generator, points_per_batch: int = 32):
    """Run a mask-generation pipeline over a batch of image files.

    Args:
        files: Paths (or file objects) of the images to segment.
        generator: Callable invoked as ``generator(images, points_per_batch=...)``
            with a list of RGB PIL images (here, the ``mask-generation`` pipeline).
        points_per_batch: Forwarded unchanged to ``generator``; defaults to 32 as before.

    Returns:
        Whatever ``generator`` returns for the list of images.
    """
    images = []
    for file in files:
        # The context manager closes the underlying file handle. convert() has
        # already produced an in-memory copy, so the image stays usable afterwards.
        with Image.open(file) as img:
            images.append(img.convert("RGB"))
    return generator(images, points_per_batch=points_per_batch)

In [3]:
# Automatic mask generator backed by the facebook/sam-vit-huge checkpoint.
# Use the first GPU when CUDA is available and fall back to CPU (device=-1)
# otherwise, instead of failing on machines without a GPU.
SAM_MODEL = "facebook/sam-vit-huge"
generator = pipeline(
    "mask-generation",
    model=SAM_MODEL,
    device=0 if torch.cuda.is_available() else -1,
)

In [4]:
# Open a client for the Label Studio server and ping it with check_connection()
ls = Client(url=LABEL_STUDIO_URL, api_key=API_KEY)
ls.check_connection()

# Look up the target project, then collect the IDs of every task it contains
project = ls.get_project(PROJECT_ID)
task_ids = project.get_tasks_ids()

# Report how many tasks the upload loop below will process
num_tasks = len(task_ids)
print(num_tasks)

530


In [5]:
# project.get_files_from_tasks([project.get_task(1)])

In [6]:
# masks = extract_masks([project.get_files_from_tasks([project.get_task(task_id)])[0] for task_id in task_ids], generator)

In [None]:
# Masks with fewer foreground pixels than this are treated as noise and not uploaded.
MIN_MASK_AREA = 150

for task_id in task_ids:
    # Resolve the task's image to a local file path. This only works when Label
    # Studio runs on the same machine as this notebook.
    image_path = project.get_files_from_tasks([project.get_task(task_id)])[0]
    outputs = extract_masks([image_path], generator)[0]
    # Free unused cached CUDA memory between images.
    torch.cuda.empty_cache()

    results = []
    for i, mask in enumerate(outputs['masks']):
        # count_nonzero works for all-True / all-False masks. The previous
        # np.unique(...)[1][1] lookup raised IndexError when only one value was present.
        if np.count_nonzero(mask) < MIN_MASK_AREA:
            continue
        height, width = mask.shape[:2]
        results.append({
            "original_width": width,
            "original_height": height,
            "image_rotation": 0,
            # NOTE(review): no "brushlabels" key is set in `value`, so regions
            # presumably arrive unlabeled -- confirm this is intended.
            "value": {
                "format": "rle",
                # Scale the boolean mask to 0/255 uint8 before RLE encoding.
                "rle": brush.mask2rle(mask * np.uint8(255)),
            },
            "id": str(i),
            "from_name": "tag",
            "to_name": "image",
            "type": "brushlabels",
        })

    project.create_prediction(task_id, result=results)
