In [None]:
import os

# Expose only one GPU, numbered in PCI bus order so indices match nvidia-smi.
# This must happen before CUDA is initialised (torch is imported in a later cell).
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID", "CUDA_VISIBLE_DEVICES": "1"})

In [None]:
CHP_ID = "348"  # suffix of the checkpoint-<id> directory to evaluate
SUBMODEL = "cond-detr-50"
# Built from SUBMODEL so the weights that are loaded and the output file name
# (FILE_NAME below) can never disagree about which model was used.
MODEL_PATH = f"logs/{SUBMODEL}/finetuned/finetuned/checkpoint-{CHP_ID}"
THR = 0.01  # score threshold passed to post_process_object_detection
iou_threshold = 0.6  # IoU threshold for the NMS step in predicts()
# Output file stem encoding the run settings (THR is written as a percentage).
FILE_NAME = f"{SUBMODEL}_THR{THR*100:.3f}_IOU{iou_threshold:.3f}_ID{CHP_ID}"

In [None]:
from transformers import (
	AutoImageProcessor,
	AutoModelForObjectDetection,
	ConditionalDetrImageProcessor,
    ConditionalDetrForObjectDetection
)
from PIL import Image
import torch
from torchvision.ops import nms

import pandas as pd
import numpy as np

In [None]:
from zindi_code.dataset import load_and_format
from zindi_code import CLS_MAPPER

# Directory that load_transform joins each image_id onto.
image_folder = "zindi_data/images"

# NOTE: despite the name, `test` is loaded from ValDataset.csv; predictions are
# written under zindi_data/validation/ at the end of the notebook.
test = load_and_format("zindi_data/ValDataset.csv")
# Fixed random_state so the preview is identical on every Restart & Run All.
test.sample(5, random_state=0)

In [None]:
model_pth = MODEL_PATH

# The processor (used for pre- and post-processing below) and the weights are
# restored from the same fine-tuned checkpoint directory.
image_processor: ConditionalDetrImageProcessor = AutoImageProcessor.from_pretrained(model_pth)
# from_pretrained already returns the model in eval mode; .eval() only states that intent.
model: ConditionalDetrForObjectDetection = (
    AutoModelForObjectDetection.from_pretrained(model_pth)
    .to("cuda")
    .eval()
)

In [None]:
# Inspect the image processor configuration restored from the checkpoint.
image_processor

In [None]:
# Inspect the label-id -> class-name mapping; predicts() uses it to name each detection.
model.config.id2label

In [None]:
from transformers.image_transforms import center_to_corners_format
from torch import nn
from typing import Union, List, Tuple
from transformers.utils.generic import TensorType

# NOTE(review): n_run is not referenced by any later cell in this notebook; candidate for removal.
n_run = list()

In [None]:
@torch.no_grad()
def make_predictions(images: list[Image.Image]):
	"""Run the detector on a batch of PIL images and post-process the raw outputs.

	Args:
		images: PIL images, preprocessed together as one batch.

	Returns:
		One dict per input image from ``image_processor.post_process_object_detection``
		(``predicts`` reads its "scores", "labels" and "boxes" entries), filtered with
		``threshold=THR`` and rescaled to each image's original size. An empty input
		returns an empty list instead of failing inside the processor.
	"""
	if not images:
		return []
	# Follow the model's actual device rather than hard-coding "cuda" a second time.
	inputs = image_processor(images=images, return_tensors="pt").to(model.device)
	outputs = model(**inputs)
	# PIL reports (width, height); post-processing expects (height, width) per image.
	target_sizes = torch.tensor([image.size[::-1] for image in images])
	return image_processor.post_process_object_detection(
		outputs, threshold=THR, target_sizes=target_sizes
	)

def load_transform(path: str):
	"""Load one image from ``image_folder`` as an RGB PIL image.

	Args:
		path: File name relative to ``image_folder``.

	Returns:
		An RGB ``PIL.Image.Image`` that no longer depends on the source file,
		whose handle is closed before returning.
	"""
	with Image.open(os.path.join(image_folder, path)) as image:
		# convert() loads the pixel data and returns an independent image, so the
		# result stays valid after the file is closed.
		return image.convert("RGB")

def load_images(image_pths: list[str]):
	"""Load every path in ``image_pths`` with ``load_transform``, preserving input order.

	Returns:
		A list of images, one per path.
	"""
	return list(map(load_transform, image_pths))


def predicts(image_pths: list[str]):
	"""Detect objects in a batch of image files and return one row per kept box.

	Boxes above ``THR`` are filtered with NMS at ``iou_threshold``. Every input
	image yields at least one row: images with no surviving box get a single
	placeholder row (box 0/0/0/0, category "NEG", score 1.0).

	Args:
		image_pths: Image file names relative to ``image_folder``.

	Returns:
		DataFrame with columns image_id, x, y, w, h, category_id, score, where
		(x, y) is the top-left corner and w/h the box size. Coordinates are derived
		from values rounded to 2 decimals, and scores are rounded to 3 decimals.
	"""
	images = load_images(image_pths)
	results = make_predictions(images)
	id2label = model.config.id2label
	rows = []
	for image_pth, result in zip(image_pths, results):
		boxes, scores, labels = result["boxes"], result["scores"], result["labels"]
		image_rows = []
		if len(boxes):
			# Class-agnostic NMS: overlapping boxes with different labels also suppress
			# each other. NOTE(review): use torchvision.ops.batched_nms with `labels`
			# if per-class suppression is intended.
			keep = nms(boxes, scores, iou_threshold)
			for score, label, box in zip(scores[keep], labels[keep], boxes[keep]):
				x1, y1, x2, y2 = (round(v, 2) for v in box.tolist())
				image_rows.append(
					[
						image_pth,
						x1,
						y1,
						x2 - x1,
						y2 - y1,
						id2label[label.item()],
						round(score.item(), 3),
					]
				)
		# Any image without a kept box (including an empty NMS result, which the old
		# `continue` silently dropped) gets the NEG placeholder row.
		if not image_rows:
			image_rows.append([image_pth, 0, 0, 0, 0, "NEG", 1.0])
		rows.extend(image_rows)
	return pd.DataFrame(
		rows, columns=["image_id", "x", "y", "w", "h", "category_id", "score"]
	)

In [None]:
# Smoke-test input: the first 16 unique image ids, in order of first appearance.
image_pths = pd.unique(test["image_id"])[:16]
image_pths

In [None]:
# Smoke test: run one batch end to end before the full pass.
# NOTE(review): `results` is reassigned in the full-run cell to a list of per-batch DataFrames.
results = predicts(image_pths)

In [None]:
# Class distribution of the smoke-test predictions ("NEG" = image with no kept box).
results["category_id"].value_counts()

In [None]:
from tqdm import tqdm

In [None]:
batch_size = 16
test_images = test["image_id"].unique()
# tqdm infers the exact batch count from len(range(...)); the former manual
# `len // batch_size + 1` over-counted by one whenever the length was a multiple
# of batch_size, and range() already stops before len(test_images).
results = [
	predicts(test_images[start : start + batch_size])
	for start in tqdm(range(0, len(test_images), batch_size))
]

In [None]:
# Stack the per-batch DataFrames into one frame with a fresh 0..n-1 index.
predictions = pd.concat(results, ignore_index=True)

In [None]:
# Fixed random_state so the preview is reproducible across runs.
predictions.sample(10, random_state=0)

In [None]:
# Fraction of prediction rows per class (values sum to 1).
predictions["category_id"].value_counts(normalize=True)

In [None]:
# Summary statistics of the rounded confidence scores (includes the 1.0 scores of NEG rows).
predictions["score"].describe()

In [None]:
# Preview the CSV path written by the final cell.
f"zindi_data/validation/prediction_{FILE_NAME}.csv"

In [None]:
# Rename to the output CSV's column names and derive the bottom-right corner
# from the top-left corner plus width/height (w and h are kept as columns).
predictions = (
	predictions
	.rename(
		columns={
			"image_id": "Image_ID",
			"x": "xmin",
			"y": "ymin",
			"category_id": "class",
			"score": "confidence",
		}
	)
	.assign(
		xmax=lambda frame: frame["xmin"] + frame["w"],
		ymax=lambda frame: frame["ymin"] + frame["h"],
	)
)

predictions.to_csv(f"zindi_data/validation/prediction_{FILE_NAME}.csv", index=False)