In [None]:
import os
import pickle
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from IPython.display import FileLink

In [None]:
pip install --upgrade torchvision

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms
from PIL import Image

In [None]:
# Run inference on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Show the selected device as this cell's output (sanity check before loading the model).
device

In [None]:
# Input photos, an optional earlier results pickle to resume from, and the
# path where this run's results are saved.
img_path, existing_file, out_file = (
    '../input/photos-for-object-detection/photos',
    '../input/picklebackups/img_objects.pickle',
    '../working/img_objects.pickle',
)

In [None]:
# Quick look at what is in the current working directory.
os.listdir(os.curdir)

In [None]:
# Faster R-CNN with a MobileNetV3-Large FPN backbone and pretrained COCO weights.
# `weights=...DEFAULT` replaces the `pretrained=True` / `pretrained_backbone=True`
# flags, which are deprecated since torchvision 0.13.
detection_model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(
    weights=torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT,
    progress=True,
    num_classes=91,  # label-space size of the pretrained weights (incl. background)
    trainable_backbone_layers=None,
)

In [None]:
# Move the detector to the selected device and switch it to inference mode, so it
# returns detections rather than training losses. Assigning the result (the same
# module object) keeps Jupyter from printing the full module repr.
detection_model = detection_model.to(device).eval()

In [None]:
def get_prediction(model, image, threshold, iou_threshold=0.5):
    """Run a detection model on one image and keep confident, non-overlapping boxes.

    Args:
        model: detection model in eval mode; ``model(image)[0]`` must be a dict
            with ``'boxes'`` (N x 4, x1/y1/x2/y2), ``'scores'`` (N) and
            ``'labels'`` (N) tensors.
        image: input accepted by ``model`` (in this notebook, a 1-image batch).
        threshold: keep only detections whose score is strictly greater than this.
        iou_threshold: IoU threshold for the extra class-agnostic NMS pass
            (default 0.5, the value previously hard-coded).

    Returns:
        A list of ``([(x1, y1), (x2, y2)], label, score)`` tuples in decreasing
        score order, or ``[()]`` when no detection passes ``threshold`` (sentinel
        kept so results stay compatible with previously saved pickles).
    """
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        preds = model(image)[0]

    # Class-agnostic NMS on top of whatever suppression the model already does,
    # so overlapping boxes with different labels collapse to the best-scoring one.
    # nms returns the kept indices sorted by decreasing score.
    keep = torchvision.ops.nms(preds['boxes'], preds['scores'], iou_threshold).tolist()

    boxes = preds['boxes'].detach().cpu().numpy()
    classes = preds['labels'].detach().cpu().numpy()
    scores = preds['scores'].detach().cpu().numpy()

    # Filter every kept detection by its own score. Looking indices up via
    # list.index() returned the first match for tied scores and could silently
    # drop detections that were above the threshold.
    detections = [
        ([(boxes[i][0], boxes[i][1]), (boxes[i][2], boxes[i][3])], classes[i], scores[i])
        for i in keep
        if scores[i] > threshold
    ]
    if not detections:
        return [()]
    return detections

In [None]:
class ImgDataset(Dataset):
    """Map-style dataset over every file in a directory of images.

    Each item is ``(tensor_image, name, size)``: the transformed RGB image, the
    file name without directory or final extension, and the PIL ``(width, height)``.
    """

    def __init__(self, main_dir, transform):
        self.main_dir = main_dir
        self.transform = transform
        # Sorted so iteration order (and hence checkpoint contents) is
        # reproducible; os.listdir() returns entries in arbitrary order.
        self.all_imgs = sorted(os.listdir(main_dir))

    def __len__(self):
        return len(self.all_imgs)

    @staticmethod
    def _image_key(filename):
        """Return ``filename`` stripped of its directory and final extension."""
        # splitext keeps inner dots ("a.b.jpg" -> "a.b"); cutting at the first
        # '.' mapped distinct files such as "a.b.jpg" and "a.c.jpg" to "a".
        return os.path.splitext(os.path.basename(filename))[0]

    def __getitem__(self, idx):
        img_loc = os.path.join(self.main_dir, self.all_imgs[idx])
        # The context manager releases the underlying file handle promptly.
        with Image.open(img_loc) as raw:
            image = raw.convert("RGB")
        tensor_image = self.transform(image)
        return tensor_image, self._image_key(img_loc), image.size

In [None]:
# Resume from an earlier run's results when that pickle exists; otherwise start empty.
found_objects = {}
if os.path.isfile(existing_file):
    with open(existing_file, 'rb') as backup:
        # NOTE: pickle.load can execute arbitrary code; only point existing_file
        # at a results file produced by this notebook.
        found_objects = pickle.load(backup)

In [None]:
# Only conversion is needed: PIL image -> tensor, which is what the detector consumes.
trsfm = transforms.Compose(transforms=[transforms.ToTensor()])

In [None]:
# One image per batch, in dataset order, loaded in the main process.
detect_dataset = ImgDataset(img_path, transform=trsfm)
detect_loader = DataLoader(
    detect_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    drop_last=True,
)

In [None]:
SCORE_THRESHOLD = 0.5     # minimum detection score kept by get_prediction
CHECKPOINT_EVERY = 10000  # save partial results whenever the total entry count hits a multiple of this

count = len(found_objects)
with torch.no_grad():  # inference only; no autograd graph needed
    for img, imgname, imgsize in tqdm(detect_loader):
        # With batch_size=1 the default collate wraps the name in a 1-element
        # tuple (presumably; verify against saved keys). Results are keyed by
        # that value as-is, matching earlier checkpoints.
        if imgname in found_objects:
            continue
        count += 1
        found_objects[imgname] = get_prediction(detection_model, img.to(device), SCORE_THRESHOLD)
        if count % CHECKPOINT_EVERY == 0:
            # Write to a temp file and atomically swap it in, so an interrupted
            # save never leaves a truncated checkpoint behind.
            tmp_file = out_file + '.tmp'
            with open(tmp_file, 'wb') as img_dict:
                pickle.dump(found_objects, img_dict, protocol=pickle.HIGHEST_PROTOCOL)
            os.replace(tmp_file, out_file)

In [None]:
# Final save of all results. Write to a temp file, then atomically replace
# out_file, so an interrupted write cannot corrupt an existing checkpoint.
tmp_file = out_file + '.tmp'
with open(tmp_file, 'wb') as img_dict:
    pickle.dump(found_objects, img_dict, protocol=pickle.HIGHEST_PROTOCOL)
os.replace(tmp_file, out_file)

In [None]:
# Download link for the saved results. It is derived from out_file so the link
# always points at the file that was actually written.
FileLink(os.path.relpath(out_file))