In [None]:
import torchvision
import os
import numpy as np
import torch
from PIL import Image
import pandas as pd
import cv2
import pickle
from tqdm.notebook import tqdm
import torchvision.transforms as T
import matplotlib.pyplot as plt

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pickle file that is unpickled into `valid_imgs` below.
VALID_IMGS_PATH = "valid_imgs.pkl"

# Checkpoint handed to `torch.load` to obtain the model.
MODEL_PATH = "trained_models/checkpoint-29.pt"

# Only predictions whose score is strictly greater than this are drawn.
MIN_CONFIDENCE = 0.65

In [None]:
class CellDataset(torch.utils.data.Dataset):
    """Dataset that serves images from a collection of file paths.

    Each item is the image at the corresponding path, opened with PIL,
    converted to RGB and passed through ``T.ToTensor()``. Only the image is
    returned; there are no labels.
    """

    def __init__(self, img_paths):
        """Store the paths and build the transform pipeline.

        Args:
            img_paths: Sized, integer-indexable collection of image file paths.
        """
        self.img_paths = img_paths
        self.transforms = T.Compose([T.ToTensor()])

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and return it transformed.

        Args:
            idx: Position of the image in ``img_paths``.

        Returns:
            The result of applying ``self.transforms`` to the RGB image.
        """
        # Image.open is lazy and keeps the underlying file open. convert()
        # produces an independent in-memory copy, so the file can be closed
        # right away instead of waiting for garbage collection.
        with Image.open(self.img_paths[idx]) as img:
            x = img.convert("RGB")
        return self.transforms(x)

    def __len__(self):
        """Return the number of image paths."""
        return len(self.img_paths)

In [None]:
# Load the validation image paths.
# NOTE: unpickling can execute arbitrary code -- only load files from a trusted source.
with open(VALID_IMGS_PATH, "rb") as f:
    valid_imgs = pickle.load(f)

# map_location lets a checkpoint that was saved on a GPU be restored on a
# CPU-only machine, matching the CPU fallback used for DEVICE.
# NOTE(review): the checkpoint is used as a whole model object (eval()/to() are
# called on the result). Recent PyTorch releases default torch.load to
# weights_only=True, which rejects such files; pass weights_only=False there
# if the checkpoint is trusted.
model = torch.load(MODEL_PATH, map_location=DEVICE)
model.eval()
model.to(DEVICE)

ds = CellDataset(valid_imgs)
dl = torch.utils.data.DataLoader(ds, 1)  # batch size 1, default (sequential) order

# One prediction per image, in the same order as `valid_imgs`.
raw_preds = []
with torch.no_grad():
    for batch in tqdm(dl, total=len(dl)):
        for pred in model(batch.to(DEVICE)):
            # Move tensors off the device as we go so predictions for the whole
            # dataset do not pile up in GPU memory.
            if isinstance(pred, dict):
                pred = {
                    key: (value.cpu() if torch.is_tensor(value) else value)
                    for key, value in pred.items()
                }
            raw_preds.append(pred)

In [None]:
# Draw the confident predicted boxes on up to 15 distinct, randomly chosen
# validation images and display each one.
img_paths = list(valid_imgs)
# A truncated permutation samples without replacement, so no image is shown
# twice, and it also handles datasets with fewer than 15 (or zero) images.
rnd_ids = np.random.permutation(len(img_paths))[:15]
for idx in rnd_ids:
    img_path = img_paths[int(idx)]
    preds = raw_preds[int(idx)]

    bgr = cv2.imread(str(img_path))
    if bgr is None:
        # cv2.imread reports a missing/unreadable file by returning None
        # instead of raising; skip it rather than fail inside cvtColor.
        print(f"Could not read image: {img_path}")
        continue
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # Keep only the boxes whose score exceeds the confidence threshold.
    scores = preds["scores"].cpu().numpy()
    above_threshold_ids = scores > MIN_CONFIDENCE
    boxes = preds["boxes"].cpu().numpy()[above_threshold_ids]

    for box in boxes:
        # Plain Python ints are the safest coordinate type for OpenCV drawing calls.
        x1, y1, x2, y2 = box.astype(int).tolist()
        # A fresh random colour per box makes neighbouring boxes easier to tell apart.
        color = tuple(int(np.random.rand() * 255) for _ in range(3))
        cv2.rectangle(img, (x1, y1), (x2, y2), color, 5)

    plt.figure(figsize=(12,12))
    plt.imshow(img)
    plt.axis("off")
    plt.show()