In [1]:
import numpy as np
import pandas as pd
from PIL import Image
import glob
import json
import random
import matplotlib.pyplot as plt
from tqdm import tqdm
from ultralytics import YOLO
from torch.utils.data import Dataset
from torchvision import models
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as T

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class clsmodel(torch.nn.Module):
    """Binary classifier: ResNet-18 feature extractor plus a two-layer head.

    ``forward`` returns the sigmoid of a single logit per input sample.
    """

    def __init__(self):
        super().__init__()
        backbone = models.resnet18(pretrained=True)
        # Head: backbone feature width -> 1024 -> 1.
        self.linear1 = torch.nn.Linear(backbone.fc.in_features, 1024)
        self.linear2 = torch.nn.Linear(1024, 1)
        self.sigmoid = torch.nn.Sigmoid()
        self.relu = torch.nn.ReLU()
        # Keep every ResNet child except the final fully connected layer.
        trunk_layers = list(backbone.children())[:-1]
        self.model = torch.nn.Sequential(*trunk_layers)

    def forward(self, x):
        """Return sigmoid outputs of shape (batch, 1) for the input batch ``x``."""
        features = torch.flatten(self.model(x), start_dim=1)
        hidden = self.relu(self.linear1(features))
        return self.sigmoid(self.linear2(hidden))

In [3]:
# Device that hosts the patch classifier.
device = 'cuda:4'

# YOLO detector that proposes cell bounding boxes.
model = YOLO('/workspace/jay/DDP/Ocelot/yolo_binary/runs/detect/train/weights/best.pt')

# Patch classifier applied to a crop around every detected box centre.
classifier = clsmodel().to(device)
# map_location restores the checkpoint tensors directly onto `device`;
# without it torch.load puts them back on whichever device they were saved
# from, which fails if that device is not available on this machine.
classifier_state = torch.load(
    '/workspace/jay/DDP/Ocelot/classifier/ckpts_v1/17_0.8209.pt',
    map_location=device,
)
classifier.load_state_dict(classifier_state)
classifier = classifier.eval()

# Sorted so that images are processed in a deterministic order.
files = sorted(glob.glob('/workspace/jay/DDP/Ocelot/ocelot2023/images/train/cell/*.jpg'))



In [4]:
# Skeleton of the predictions file: "points" starts empty and is filled
# with one entry per detected cell by the inference loop.
pred_json = dict(
    type="Multiple points",
    num_images=len(files),
    points=[],
    version=dict(major=1, minor=0),
)

In [5]:
# Half the side length (in pixels) of the square crop fed to the classifier.
PATCH_HALF = 64

# Start from an empty list so that re-running this cell does not append
# duplicate points to the ones collected by a previous run.
pred_json["points"] = []

for file in tqdm(files):
    # The numeric file stem minus one is used as the image index in the output.
    idx = int(file.split('/')[-1][:-4]) - 1
    # Force three channels: the ResNet-based classifier expects RGB input.
    img = np.array(Image.open(file).convert('RGB'))
    height, width = img.shape[:2]

    out = model.predict(file, conf=0.2, iou=0.5)
    boxes = out[0].cpu().numpy().boxes.data

    for box in boxes:
        # Box centre, clamped to valid pixel coordinates of this image.
        x = min(width - 1, max(0, int((box[0] + box[2]) / 2)))
        y = min(height - 1, max(0, int((box[1] + box[3]) / 2)))

        # Crop window around the centre; it is truncated at the image
        # border, so patches near an edge are smaller than the full window.
        row_start, row_stop = max(0, y - PATCH_HALF), min(height, y + PATCH_HALF)
        col_start, col_stop = max(0, x - PATCH_HALF), min(width, x + PATCH_HALF)
        patch = img[row_start:row_stop, col_start:col_stop]

        # Scale pixel values to [-0.5, 0.5] and reorder HWC -> CHW.
        patch = patch / 255 - 0.5
        patch = torch.tensor(np.moveaxis(patch, -1, 0), dtype=torch.float32)
        patch = patch.unsqueeze(0).to(device)

        with torch.no_grad():
            prob = classifier(patch)
        prob = prob.cpu().numpy()[0][0]

        # The classifier outputs the probability of class 1; report the
        # confidence of whichever class is actually predicted.
        if prob <= 0.5:
            clas = 0
            prob = 1 - prob
        else:
            clas = 1

        point = {
            "name": f"image_{idx}",
            # Class labels in the output are 1-based.
            "point": [int(x), int(y), int(clas) + 1],
            # Classifier confidence for the predicted class.
            "probability": float(prob),
        }
        pred_json["points"].append(point)
                                

  0%|                                                                                                                                                                              | 0/400 [00:00<?, ?it/s]Ultralytics YOLOv8.0.20 🚀 Python-3.8.13 torch-1.13.1+cu117 CUDA:0 (NVIDIA A100-SXM4-80GB, 81251MiB)
Model summary (fused): 218 layers, 25840918 parameters, 0 gradients, 78.7 GFLOPs
 25%|█████████████████████████████████████████                                                                                                                           | 100/400 [02:05<05:38,  1.13s/it]Exception in thread Thread-35:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/urllib3/connectionpool.py", line 703, in urlopen
    httplib_response = self._make_request(
  File "/opt/conda/lib/python3.8/site-packages/urllib3/connectionpool.py", line 386, in _make_request
    self._validate_conn(conn)
  File "/opt/conda/lib/python3.8/site-packages/urllib3/connectionpool.py", l

In [6]:
# Serialize the collected predictions and write them to the evaluation folder.
serialized = json.dumps(pred_json)
with open("/workspace/jay/DDP/Ocelot/ocelot23algo/evaluation/yolo_pred_binary.json", "w") as out_file:
    out_file.write(serialized)
    print("JSON file saved")

JSON file saved


213