In [None]:
import logging
import sys
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

import os
import random
import json
import numpy as np
import torch
from pathlib import Path
import imageio.v3 as imageio
from PIL import Image
from minio import Minio
import torchvision.transforms.functional as F

from mmm.data_loading.geojson.GeoAnno import GeoAnno
from mmm.data_loading.geojson.utils import rasterize_annotations
from mmm.data_loading.geojson.WSIGeojsonDataset import WSIGeojsonDataset
from mmm.data_loading.s3 import get_args, get_kwargs, upload_img
from mmm.labelstudio_ext.utils import binary_mask_to_result
from mmm.labelstudio_ext.projects import LSProject, Client
from mmm.labelstudio_ext.LabelstudioCredentials import LabelstudioCredentials

# Importing data with GeoJSON labels

`geojson_paths` maps each GeoJSON annotation file to its associated image file, i.e. the `.tif` with the same stem in the same directory.

In [None]:
# Data locations, object-storage targets and the Label Studio project to import into.
datadir = Path("/data/jnm_data/v1/")
BUCKETNAME, UPLOAD_PREFIX, LS_PROJECTNAME = "dataroot", "jnm_data3", "jnquanti2"
mclient = Minio(*get_args(), **get_kwargs())

# Read Label Studio credentials from the environment so the token does not have to live
# in the notebook. The literal fallbacks keep the previous local-dev defaults working.
# TODO: drop the hard-coded token fallback once LABELSTUDIO_TOKEN is set everywhere.
ls_client = LabelstudioCredentials(
    url=os.environ.get("LABELSTUDIO_URL", "http://localhost:9505"),
    token=os.environ.get("LABELSTUDIO_TOKEN", "1234567890"),
).build_client()
trainval_proj = LSProject(LSProject.Config(name=LS_PROJECTNAME), ls_client)

# bucket_exists() only returns a bool; in the middle of a cell that value is silently
# discarded, so check it explicitly and fail before any upload is attempted.
if not mclient.bucket_exists(BUCKETNAME):
    raise RuntimeError(f"MinIO bucket {BUCKETNAME!r} does not exist")

# Map every GeoJSON annotation file to the .tif with the same stem in the same directory.
# Sorted so tasks are built and uploaded in a reproducible order.
geojson_paths = {p: p.parent / f"{p.stem}.tif" for p in sorted(datadir.glob("*.geojson"))}

# Fail up front if an annotation file has no matching image, instead of crashing
# partway through the upload loop with some images already uploaded.
missing_images = sorted(str(img_path) for img_path in geojson_paths.values() if not img_path.exists())
if missing_images:
    raise FileNotFoundError(f"No .tif image found for these GeoJSON annotations: {missing_images}")

def process_mask(mask: np.ndarray, crop_bottom: int = 68, size: tuple[int, int] = (224, 224)) -> np.ndarray:
    """Crop and resize a rasterized annotation mask.

    Applies the same bottom crop and output size as ``process_image`` so that masks stay
    pixel-aligned with the processed images.

    Args:
        mask: mask array whose first axis is rows (the crop is applied along axis 0).
        crop_bottom: number of rows removed from the bottom before resizing (0 = no crop).
        size: output (height, width).

    Returns:
        The cropped mask resized to ``size`` with nearest-neighbour interpolation, as a
        numpy array.
    """
    # Guard: mask[:-0] would return an empty array rather than the full mask.
    if crop_bottom:
        mask = mask[:-crop_bottom]
    # Nearest-neighbour keeps label values discrete. The default bilinear mode blends values
    # across label boundaries (and, for boolean masks, dilates them when cast back).
    resized = F.resize(
        torch.from_numpy(mask).unsqueeze(0),
        list(size),
        interpolation=F.InterpolationMode.NEAREST,
    )
    return resized[0].numpy()

def process_image(img: Image.Image, crop_bottom: int = 68, size: tuple[int, int] = (224, 224)) -> Image.Image:
    """Crop the bottom of a PIL image and resize it for Label Studio.

    Args:
        img: source PIL image.
        crop_bottom: number of pixel rows removed from the bottom before resizing.
        size: output (height, width).

    Returns:
        A new PIL image of ``size``.
    """
    img = img.crop((0, 0, img.width, img.height - crop_bottom))
    # Resize the PIL image directly. A round trip through a float tensor
    # (to_tensor -> to_pil_image) truncates values when converting back to 8-bit.
    return F.resize(img, list(size))

In [None]:
# Build the dataset over the geojson -> image mapping and collect its class names;
# list positions in CLASS_NAMES serve as the rasterization labels.
ds_config = WSIGeojsonDataset.Config()
ds = WSIGeojsonDataset(ds_config, geojson_paths)
CLASS_NAMES = ds.get_classes()

def results_from_geojson(annos: list[GeoAnno], img_path: Path) -> list:
    """Rasterize each GeoJSON annotation into a Label Studio brush result.

    Args:
        annos: annotations drawn on the image at ``img_path``.
        img_path: image file; only its pixel dimensions are used here.

    Returns:
        One result (as built by ``binary_mask_to_result``) per annotation, in ``annos`` order.
    """
    # Only the shape is needed; don't keep the full-resolution pixel array alive.
    height, width = imageio.imread(img_path).shape[:2]
    results = []
    for anno in annos:
        class_name = anno.get_class_name()
        binary_mask = rasterize_annotations(
            height,
            width,
            [anno],
            # Exactly one label for the single annotation being rasterized.
            # NOTE(review): assumes anno_labels[i] pairs with annotations[i] — confirm.
            anno_labels=[CLASS_NAMES.index(class_name)],
        )
        # Same crop/resize as the image so the mask stays aligned with the uploaded pixels.
        binary_mask = process_mask(binary_mask)
        results.append(binary_mask_to_result(binary_mask, class_name, brush_name="tag"))
    return results

# Upload every annotated image and attach its rasterized annotations as a Label Studio task.
res = []

for testcase in ds:
    wsi_path = testcase["wsi"]
    # Compute annotations first so a rasterization failure doesn't leave an orphaned upload.
    annotations = [{"result": results_from_geojson(testcase["annos"], wsi_path)}]
    # Close the source file as soon as the processed copy has been produced.
    with Image.open(wsi_path) as raw_img:
        img = process_image(raw_img)
    d = upload_img(mclient, BUCKETNAME, UPLOAD_PREFIX, img, wsi_path.stem)
    d["annotations"] = annotations
    res.append(d)

In [None]:
# Push the labeled tasks collected in `res` into the Label Studio project.
labelstudio_project = trainval_proj.get_project()
labelstudio_project.import_tasks(res)

# Importing unlabeled data

In [None]:
# Import images that have no GeoJSON annotations as unlabeled tasks.
# Collected in their own list (not `res`) so re-running the labeled-import cell
# afterwards cannot import these tasks a second time.
annotated_images = set(geojson_paths.values())  # hoisted set: O(1) membership per image
unlabeled_tasks = []
for imagepath in sorted(datadir.glob("*.tif")):
    if imagepath in annotated_images:
        print(f"Skipping {imagepath.stem} because it has annotations")
        continue
    # Close the source file as soon as the processed copy has been produced.
    with Image.open(imagepath) as raw_img:
        img = process_image(raw_img)
    d = upload_img(mclient, BUCKETNAME, UPLOAD_PREFIX, img, imagepath.stem)
    unlabeled_tasks.append(d)

# Skip the API call entirely when there is nothing to import.
if unlabeled_tasks:
    trainval_proj.get_project().import_tasks(unlabeled_tasks)