In [1]:
import os

# Order GPUs by PCI bus id and expose only device 1 to this process.
# These must be set before torch initialises CUDA.
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID", "CUDA_VISIBLE_DEVICES": "1"})

In [2]:
# Checkpoint to evaluate: logs/<SUBMODEL>/checkpoint-<CHP_ID>.
CHP_ID = "1740"
SUBMODEL = "cond-detr-50"
MODEL_PATH = f"logs/{SUBMODEL}/checkpoint-{CHP_ID}"
# Minimum detection score kept by post-processing.
THR = 0.01
# NMS discards a box whose IoU with a higher-scoring box exceeds this value.
iou_threshold = 0.8
# Output tag; THR is written as a percentage (0.01 -> "THR1.000").
FILE_NAME = f"{SUBMODEL}_THR{THR*100:.3f}_IOU{iou_threshold:.3f}_ID{CHP_ID}_P"

In [3]:
from transformers import (
	AutoImageProcessor,
	AutoModelForObjectDetection,
	ConditionalDetrImageProcessor,
    ConditionalDetrForObjectDetection
)
from PIL import Image
import torch
from torchvision.ops import nms

import pandas as pd
import numpy as np

In [4]:
from zindi_code.dataset import load_and_format
from zindi_code import CLS_MAPPER

image_folder = "zindi_data/images"

test = load_and_format("zindi_data/ValDataset.csv")
# Fixed random_state so the preview rows are identical on every Restart & Run All.
test.sample(5, random_state=0)

Unnamed: 0,image_id,bbox,category_id,id
1069,id_yzmpc6cgm1.jpg,"[2429, 1281, 86, 100]",0,1069
2024,id_087dra2apu.jpg,"[2619, 2100, 86, 102]",0,2024
558,id_tt0yabytpy.jpg,"[2304, 1379, 238, 238]",1,558
1361,id_hmd5mr6xz8.jpg,"[1136, 2647, 95, 95]",0,1361
213,id_u0ixesq8l0.jpg,"[1527, 930, 238, 185]",1,213


In [5]:
model_pth = MODEL_PATH

# Processor and detector are both restored from the same checkpoint directory.
image_processor: ConditionalDetrImageProcessor = AutoImageProcessor.from_pretrained(model_pth)
model: ConditionalDetrForObjectDetection = AutoModelForObjectDetection.from_pretrained(model_pth)

# Move to the GPU and switch to inference mode (`eval()` is `train(False)`).
model = model.to("cuda").eval()

In [6]:
# Inspect the preprocessing config (resize, rescale, normalize, pad) loaded with the checkpoint.
image_processor

ConditionalDetrImageProcessor {
  "do_convert_annotations": true,
  "do_normalize": true,
  "do_pad": true,
  "do_rescale": true,
  "do_resize": true,
  "format": "coco_detection",
  "image_mean": [
    0.485,
    0.456,
    0.406
  ],
  "image_processor_type": "ConditionalDetrImageProcessor",
  "image_std": [
    0.229,
    0.224,
    0.225
  ],
  "pad_size": null,
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "longest_edge": 1333,
    "shortest_edge": 800
  }
}

In [7]:
# Class-index -> name mapping; `predicts` uses it to fill the category_id column.
model.config.id2label

{0: 'Trophozoite', 1: 'WBC'}

In [8]:
from transformers.image_transforms import center_to_corners_format
from torch import nn
from typing import Union, List, Tuple
from transformers.utils.generic import TensorType

# NOTE(review): `n_run` is not referenced in any visible cell — candidate for removal.
n_run = []

In [9]:
@torch.no_grad()
def make_predictions(images: list[Image.Image], threshold=None):
    """Run the detector on a batch of PIL images and post-process the raw outputs.

    Args:
        images: RGB PIL images. Each image's own ``size`` is used to rescale the
            predicted boxes back to original pixel coordinates.
        threshold: minimum score for a detection to be kept. ``None`` (default)
            uses the notebook-level ``THR``, read at call time.

    Returns:
        One dict per input image, as returned by
        ``image_processor.post_process_object_detection``. ``predicts`` reads the
        ``"scores"``, ``"labels"`` and ``"boxes"`` keys.
    """
    if threshold is None:
        threshold = THR
    # Send inputs to wherever the model lives instead of hard-coding "cuda".
    inputs = image_processor(images=images, return_tensors="pt").to(model.device)
    outputs = model(**inputs)
    # PIL's `size` is (width, height); post-processing expects (height, width).
    target_sizes = torch.tensor([image.size[::-1] for image in images])
    return image_processor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=target_sizes
    )

def load_transform(path: str):
    """Open ``path`` (relative to the notebook-level ``image_folder``) as an RGB PIL image.

    The source file is opened in a ``with`` block, so its handle is released
    before returning. ``convert`` returns a new, fully loaded image that does
    not depend on that handle.
    """
    with Image.open(os.path.join(image_folder, path)) as image:
        return image.convert("RGB")

def load_images(image_pths: list[str]):
    """Load each path in ``image_pths`` with ``load_transform``; order is preserved."""
    return list(map(load_transform, image_pths))


def predicts(image_pths: list[str]):
    """Detect objects in a batch of images and return one row per kept box.

    Detections from ``make_predictions`` are filtered with NMS at the
    notebook-level ``iou_threshold``. Boxes are converted from corner format
    (x1, y1, x2, y2) to (x, y, w, h). An image with no surviving detection gets
    a single placeholder row labelled ``"NEG"`` with a zero box and score 1.0.

    Args:
        image_pths: image file names relative to ``image_folder``.

    Returns:
        DataFrame with columns image_id, x, y, w, h, category_id, score.
    """
    images = load_images(image_pths)
    results = make_predictions(images)
    predictions = []
    for image_pth, result in zip(image_pths, results):
        rows = _detection_rows(image_pth, result)
        if not rows:
            # Placeholder so every input image appears in the output.
            rows = [[image_pth, 0, 0, 0, 0, "NEG", 1.]]
        predictions.extend(rows)
    return pd.DataFrame(
        predictions, columns=["image_id", "x", "y", "w", "h", "category_id", "score"]
    )


def _detection_rows(image_pth, result):
    """Convert one post-processed result into NMS-filtered [id, x, y, w, h, label, score] rows."""
    if not len(result["boxes"]):
        return []
    # NOTE(review): torchvision `nms` is class-agnostic, so a box of one class can
    # suppress an overlapping box of the other class. `batched_nms` with
    # result["labels"] would apply NMS per class — confirm which is intended.
    keep = nms(result["boxes"], result["scores"], iou_threshold)
    # Convert to Python once per image instead of calling .item() per element.
    scores = result["scores"][keep].tolist()
    labels = result["labels"][keep].tolist()
    boxes = result["boxes"][keep].tolist()
    rows = []
    for score, label, box in zip(scores, labels, boxes):
        x1, y1, x2, y2 = (round(coord, 2) for coord in box)
        rows.append([
            image_pth,
            x1,
            y1,
            # Round the difference too: subtracting two rounded floats reintroduces
            # noise such as 38.96000000000004 in the written CSV.
            round(x2 - x1, 2),
            round(y2 - y1, 2),
            model.config.id2label[label],
            round(score, 3),
        ])
    return rows

In [10]:
# Small smoke-test batch: the first 16 unique image ids of the validation set.
image_pths = test["image_id"].unique()[:16]
image_pths

array(['id_w8xnbd5rvm.jpg', 'id_ytq3slqkjm.jpg', 'id_e20xnaq5qn.jpg',
       'id_7fc9zyfy0e.jpg', 'id_6g52lmvz2y.jpg', 'id_z0i61ad0tq.jpg',
       'id_55a6sf8hbe.jpg', 'id_dg0icorzno.jpg', 'id_zdg96srigj.jpg',
       'id_ezd6x40fd0.jpg', 'id_ch6r0g46fr.jpg', 'id_4cotsn0obm.jpg',
       'id_0fdars2kkw.jpg', 'id_4wkzpeu6or.jpg', 'id_idjqlz4ppb.jpg',
       'id_by6e6shi2z.jpg'], dtype=object)

In [11]:
# Smoke test on the 16-image batch. NOTE(review): `results` is reassigned to a
# list of per-batch DataFrames by the full run further down.
results = predicts(image_pths)

In [12]:
# Row count per category_id value in the smoke-test output.
results["category_id"].value_counts()

category_id
Trophozoite    1101
WBC             119
Name: count, dtype: int64

In [13]:
from tqdm import tqdm

In [14]:
batch_size = 16
test_images = test["image_id"].unique()
# One DataFrame per batch of up to `batch_size` images. `range` stops at the last
# valid start index and has a length, so tqdm infers the exact number of batches.
results = [
    predicts(test_images[i : i + batch_size])
    for i in tqdm(range(0, len(test_images), batch_size))
]

100%|██████████| 18/18 [01:09<00:00,  3.84s/it]


In [15]:
# Stack the per-batch DataFrames into one table with a fresh 0..n-1 index.
predictions = pd.concat(results, ignore_index=True)

In [16]:
# Fixed random_state so the preview rows are reproducible across runs.
predictions.sample(10, random_state=0)

Unnamed: 0,image_id,x,y,w,h,category_id,score
7849,id_wv8ze1pfbz.jpg,1378.58,686.81,38.96,32.22,Trophozoite,0.126
18402,id_0iiow811re.jpg,2239.5,91.82,168.47,117.1,Trophozoite,0.052
15636,id_4f9wdugdot.jpg,3340.59,2009.79,73.9,73.21,Trophozoite,0.106
17282,id_auoihg3sqz.jpg,778.68,457.54,30.68,26.14,Trophozoite,0.047
19472,id_s6axfrvztf.jpg,2633.45,1271.18,70.29,78.29,Trophozoite,0.077
17226,id_auoihg3sqz.jpg,755.64,589.71,38.67,36.04,Trophozoite,0.251
18781,id_7dyurqpi1w.jpg,1249.49,2721.75,93.08,106.09,Trophozoite,0.065
12718,id_hmd5mr6xz8.jpg,3131.46,1150.58,79.8,88.07,Trophozoite,0.076
19272,id_66de5o5tvg.jpg,2211.52,2805.37,74.79,84.99,Trophozoite,0.19
9511,id_0o3fxfendl.jpg,980.44,170.2,34.09,31.98,Trophozoite,0.125


In [17]:
# Share of rows per category_id (proportions, not raw counts).
predictions["category_id"].value_counts(normalize=True)

category_id
Trophozoite    0.882925
WBC            0.117075
Name: proportion, dtype: float64

In [19]:
# Score distribution over all rows (any "NEG" placeholder rows carry score 1.0).
predictions["score"].describe()

count    20773.000000
mean         0.150609
std          0.131796
min          0.032000
25%          0.066000
50%          0.103000
75%          0.176000
max          0.736000
Name: score, dtype: float64

In [20]:
# Create the output folder if needed so the cell also works on a fresh checkout.
os.makedirs("zindi_data/validation", exist_ok=True)
predictions.to_csv(f"zindi_data/validation/prediction_{FILE_NAME}.csv", index=False)

In [21]:
# Path of the predictions CSV written by the previous cell.
f"zindi_data/validation/prediction_{FILE_NAME}.csv"

'zindi_data/validation/prediction_cond-detr-50_THR1.000_IOU0.800_ID1740_P.csv'