In [21]:
# !pip install -q git+https://github.com/huggingface/transformers.git

In [40]:
import numpy as np
import matplotlib.pyplot as plt
import gc
from PIL import Image
import transformers
from transformers import pipeline
from datetime import date
import json
import os
from itertools import groupby
from label_studio_converter import brush


# Folder holding the images to segment with SAM, and the folder name reserved
# for generated mask artefacts.
path_to_folder, output_folder = '../Pictures/data_for_training_filtered', 'SAM_masks'


def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    del mask
    gc.collect()

def show_masks_on_image(raw_image, masks):
  plt.imshow(np.array(raw_image))
  ax = plt.gca()
  ax.set_autoscale_on(False)
  for mask in masks:
      show_mask(mask, ax=ax, random_color=True)
  plt.axis("off")
  plt.show()
  del mask
  gc.collect()

def extract_masks(file, generator):
    raw_image = Image.open(file).convert("RGB")
    outputs = generator(raw_image, points_per_batch=64)
    return outputs

In [23]:
# Build the Hugging Face SAM (ViT-huge checkpoint) mask-generation pipeline on
# device 0. Loading the checkpoint is slow, so keep it in its own cell and
# re-use `generator` from later cells.
generator = pipeline(
    task="mask-generation",
    model="facebook/sam-vit-huge",
    device=0,
)

In [24]:
# Image files to segment.
# - Sorted, because os.listdir order is arbitrary; this keeps slices such as
#   fileList[:1] reproducible between runs.
# - Filtered to image extensions, so entries like `.DS_Store` or sub-folders
#   do not crash `Image.open` later.
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tif', '.tiff', '.webp'}
fileList = sorted(
    name
    for name in os.listdir(path_to_folder)
    if os.path.splitext(name)[1].lower() in IMAGE_EXTENSIONS
)

In [49]:
# Minimum number of foreground pixels for a SAM mask to be exported. Smaller
# masks are dropped as fragments.
MIN_MASK_AREA = 150
# How many images from `fileList` to process (a quick smoke-test default).
# Set to None to process every file.
MAX_FILES = 1


def _mask_to_result(mask, region_id):
    """Convert one binary mask into a Label Studio ``brushlabels`` result dict.

    Returns None when the mask has fewer than ``MIN_MASK_AREA`` foreground pixels.
    """
    mask = np.asarray(mask)
    # count_nonzero is correct for every mask. The previous
    # `np.unique(...)[1]` lookup raised IndexError on all-False / all-True
    # masks, where only one unique value exists.
    if np.count_nonzero(mask) < MIN_MASK_AREA:
        return None
    height, width = mask.shape[:2]
    return {
        "original_width": width,
        "original_height": height,
        "image_rotation": 0,
        "value": {
            "format": "rle",
            # brush.mask2rle expects a 2-D uint8 image; encode the mask as 0/255.
            "rle": brush.mask2rle(np.where(mask > 0, 255, 0).astype(np.uint8)),
        },
        "id": str(region_id),
        "from_name": "tag",
        "to_name": "image",
        "type": "brushlabels",
    }


def _build_task(file):
    """Run SAM on one image in ``path_to_folder`` and wrap the kept masks as a Label Studio task."""
    # Keep the '/' separator so the path stored in the JSON is identical on every OS.
    image_path = path_to_folder + '/' + file
    outputs = extract_masks(image_path, generator)
    results = []
    for region_id, mask in enumerate(outputs['masks']):
        result = _mask_to_result(mask, region_id)
        if result is not None:
            results.append(result)
    prediction = {
        "model_version": date.today().isoformat() + '_' + transformers.__version__,
        "result": results,
    }
    return {"data": {"image": image_path}, "predictions": [prediction]}


files_to_process = fileList if MAX_FILES is None else fileList[:MAX_FILES]
jsondata = [_build_task(file) for file in files_to_process]

with open("labelstudio_annos.json", "w") as json_file:
    json.dump(jsondata, json_file)

In [27]:
# show_masks_on_image(raw_image, masks)