In [None]:
import os
import numpy as np 
import pandas as pd 
from datetime import datetime
import time
import random
from tqdm.auto import tqdm


#Torch
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SequentialSampler, RandomSampler
import torchvision.transforms as T

#sklearn
from sklearn.model_selection import StratifiedKFold
from skimage import io

################# DETR FUNCTIONS FOR LOSS ########################
import sys
sys.path.append('./detr_custom/')

from models.matcher import HungarianMatcher
from models.detr import SetCriterion
#################################################################

import matplotlib.pyplot as plt

#Glob
from glob import glob

from typing import Iterable, Sequence, List, Tuple, Dict, Optional, Union, Any
from types import ModuleType
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
from generators import BlenderStandardDataset, TorchStandardDataset

In [None]:
def seed_everything(seed: int) -> None:
    '''
    Seed every RNG used in this notebook and make cuDNN deterministic.

    Parameters
    ----------
    seed : int
        Seed applied to Python's ``random``, ``PYTHONHASHSEED``, NumPy and torch
        (CPU and all CUDA devices).
    '''
    random.seed(seed)
    # Only affects subprocesses started from here on; this interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN choose algorithms by timing them, which is nondeterministic
    # and undoes the `deterministic` flag above.
    torch.backends.cudnn.benchmark = False
    
def pytorch_init_janus_gpu(device_id: int = 1):
    '''
    Select GPU ``device_id`` on the two-GPU "janus" machine after sanity-checking the hardware.

    The hardware checks run *before* ``torch.cuda.set_device``. On a machine without CUDA,
    or with fewer GPUs, the function therefore fails with ``AssertionError`` (what the
    calling cell catches) rather than with whatever error CUDA initialisation raises
    inside ``set_device``.

    Parameters
    ----------
    device_id : int
        CUDA ordinal to use (default 1, as before).

    Returns
    -------
    torch.device
        ``torch.device('cuda', device_id)``.

    Raises
    ------
    AssertionError
        If CUDA is unavailable, the GPU count or name is unexpected, or the device could not be selected.
    '''
    assert torch.cuda.is_available(), 'GPU not available'
    assert torch.cuda.device_count() == 2, 'Cannot find both GPUs'
    # NOTE(review): checks the name of GPU 0 even though `device_id` may differ; kept as in the original.
    assert torch.cuda.get_device_name(0) == 'GeForce RTX 2080 Ti', 'Wrong GPU name'

    torch.cuda.set_device(device_id)
    assert torch.cuda.current_device() == device_id, 'Using wrong GPU'
    return torch.device('cuda', device_id)

def reloader(module_or_member: Union[ModuleType, Any]):
    '''
    Reload a module, or the module defining a given object, and return the fresh version.

    Useful in a notebook after editing a local ``.py`` file without restarting the kernel.

    Parameters
    ----------
    module_or_member : module or object with ``__module__`` and ``__name__``
        A module is reloaded directly. For anything else (function, class, ...), its
        defining module is reloaded and the re-created attribute of the same name is returned.

    Returns
    -------
    The reloaded module, or the freshly defined member.
    '''
    import importlib  # not imported at the top of the notebook

    if isinstance(module_or_member, ModuleType):
        # importlib.reload returns the (re-executed) module object itself.
        return importlib.reload(module_or_member)

    module = importlib.reload(importlib.import_module(module_or_member.__module__))
    return getattr(module, module_or_member.__name__)

In [None]:
seed = 42069
seed_everything(seed)

# Prefer the dedicated Janus GPU. Fall back to CPU when the sanity checks fail (AssertionError)
# or CUDA itself errors out while selecting the device (RuntimeError).
try:
    device = pytorch_init_janus_gpu()
    print(f'Using device: {device} ({torch.cuda.get_device_name()})')
except (AssertionError, RuntimeError) as e:
    print('GPU could not initialize, got error:', e)
    device = torch.device('cpu')
    print('Device is set to CPU')

In [None]:
TORCH_CACHE_DIR = 'torch_cache'
torch.hub.set_dir(TORCH_CACHE_DIR)

# Build DETR-R50 from the repository checkout inside the hub cache (source='local'),
# with pretrained weights, and move it to the selected device.
model = torch.hub.load(
    os.path.join(TORCH_CACHE_DIR, 'facebookresearch_detr_master'),
    model='detr_resnet50',
    pretrained=True,
    source='local',
).to(device)

In [None]:
# Test image for inference. Reminder: PIL's Image.rotate() leaves black bars in the
# uncovered corners, so the image is used unrotated.
img = Image.open('test_image2.png')
w, h = img.width, img.height

In [None]:
# DETR preprocessing: resize (shorter side to 800 px), convert to a float tensor,
# then normalise with the ImageNet channel mean/std.
transform = T.Compose(
    [
        T.Resize(800),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [None]:
# Preprocess the test image; the batch dimension is added at inference time (unsqueeze(0)).
input_img = transform(img)

In [None]:
# nn.Module instances start in training mode and nothing above calls .eval(), so dropout
# would still be active. Switch to eval mode for deterministic predictions.
model.eval()
with torch.no_grad():
    output = model(input_img.unsqueeze(0).to(device))

In [None]:
# Strip the batch dimension: a single image was passed through the model.
logits, boxes = output['pred_logits'][0], output['pred_boxes'][0]

In [None]:
# Locations/config for the synthetic Blender object-detection data.
TORCH_CACHE_DIR = 'torch_cache'
DATASET_DIR = '/mnt/blendervol/objdet_std_data'
SQL_TABLE = 'bboxes_std'
WEIGHTS_DIR = 'fish_statedicts'

torchdataset = TorchStandardDataset(
    DATASET_DIR,
    SQL_TABLE,
    1,  # NOTE(review): meaning of this positional argument is defined in generators — confirm
    shuffle=False,
    imgnrs=range(32),  # first 32 images
)

In [None]:
def sanity_dataset(gen: TorchStandardDataset, idx: int = 1):
    '''
    Plot sample ``idx`` of ``gen`` with its target boxes overlaid, as a visual label check.

    Each box is read as normalized (cx, cy, w, h): a dot is drawn at the centre and a
    rectangle around it. x coordinates are scaled by the image width, y by the height.

    NOTE(review): a later cell converts this dataset's boxes with ``box_xywh_to_cxcywh``,
    which suggests they may actually be top-left (x, y, w, h). Confirm the format against
    ``generators.TorchStandardDataset``.

    Parameters
    ----------
    gen : TorchStandardDataset
        Dataset returning ``(image_tensor, targets)`` with ``targets['boxes']``.
    idx : int
        Index of the sample to show (default 1, as before).
    '''
    img, targets = gen[idx]
    img = img.numpy().transpose((1, 2, 0))  # CHW -> HWC for imshow
    boxes = targets['boxes'].numpy()

    h, w = img.shape[:2]
    # (x, y) pairs must be scaled by (width, height), not (height, width).
    scale = np.array([w, h])

    ax = plt.gca()
    ax.imshow(img)

    for box in boxes:
        cxy = box[:2]
        ax.add_patch(plt.Circle(cxy * scale, 5, edgecolor='k'))

        bw, bh = box[2], box[3]
        tlxy = (cxy - (bw / 2, bh / 2)) * scale
        ax.add_patch(plt.Rectangle(tlxy, bw * w, bh * h, fill=False, lw=2, color='red', alpha=0.2))

sanity_dataset(torchdataset)

In [None]:
def box_xywh_to_cxcywh(bboxes: torch.Tensor):
    '''
    Convert boxes from top-left form (x, y, w, h) to centre form (cx, cy, w, h).

           w                  w
      x ------->          -------->
      |                  |
    h |           ==>  h |    x
      |                  |
      v                  v

    Width and height are unchanged; only the first two coordinates shift by half the size.
    Works on any tensor whose last dimension holds the 4 coordinates, e.g. (N, 4) or a
    single (4,) box. The input is not modified: the result is a detached copy.
    '''
    newtensor = bboxes.detach().clone()
    # Plain assignment (not +=) so integer inputs keep the original cast-on-assign behaviour.
    newtensor[..., 0] = newtensor[..., 0] + newtensor[..., 2] * 0.5  # cx = x + w / 2
    newtensor[..., 1] = newtensor[..., 1] + newtensor[..., 3] * 0.5  # cy = y + h / 2
    return newtensor
    
# Spot-check the conversion on the first sample's targets. Index the pair directly rather
# than unpacking into `__`, which would shadow IPython's output-history variable.
y__ = torchdataset[0][1]
box_xywh_to_cxcywh(y__['boxes'])

In [None]:
def postprocess2(logits: torch.Tensor, boxes: torch.Tensor, threshold: float = 0.5):
    '''
    Keep queries whose best foreground class probability exceeds ``threshold``.

    Parameters
    ----------
    logits : torch.Tensor
        (num_queries, num_classes + 1) class logits; the last column is treated as "no object".
    boxes : torch.Tensor
        Boxes aligned row-by-row with ``logits``.
    threshold : float
        Minimum softmax probability (ignoring the last column) for a query to be kept.

    Returns
    -------
    labels : torch.Tensor
        Long tensor with the best foreground class id of each kept query.
    kept_boxes : torch.Tensor
        Rows of ``boxes`` for the kept queries (zero rows when nothing passes).
    '''
    probas = logits.softmax(-1)[:, :-1]
    # Take the label from the same foreground-only distribution the score came from,
    # so "no object" can never be returned as a label.
    scores, labels = probas.max(-1)
    keepmask = scores > threshold
    # No special case for "nothing kept": boolean indexing yields correctly-typed empty
    # tensors on the same device.
    return labels[keepmask], boxes[keepmask]

logits_, boxes_ = postprocess2(logits, boxes)

def plot_results2(img: Image.Image, classes: Iterable, boxes: Iterable):
    '''
    Draw normalized (cx, cy, w, h) boxes on ``img`` as cyan rectangles.

    Parameters
    ----------
    img : PIL.Image.Image
        Image to draw on; its pixel size scales the normalized boxes.
    classes : Iterable
        One entry per box. Currently only zipped against ``boxes`` (not drawn).
    boxes : Iterable
        Normalized (cx, cy, w, h) boxes, as a tensor (on any device) or nested sequence.
    '''
    npimg = np.array(img)
    h, w = npimg.shape[:2]

    # matplotlib turns patch coordinates into NumPy arrays, which fails for CUDA tensors.
    # Hand it plain Python floats instead.
    if isinstance(boxes, torch.Tensor):
        boxes = boxes.detach().cpu().tolist()

    fig, ax = plt.subplots(1, 1, figsize=(10, 7))
    ax.imshow(npimg)

    for _cls, (bx, by, bw, bh) in zip(classes, boxes):
        # Centre -> top-left corner, all in pixels.
        ax.add_patch(plt.Rectangle((w * bx - bw * w / 2, h * by - bh * h / 2), bw * w, bh * h,
                                   fill=False, color='cyan', linewidth=2))

plot_results2(img, logits_, boxes_)

In [None]:
def postprocess(logits: torch.Tensor, boxes: torch.Tensor, threshold: float = 0.7):
    '''
    Keep queries whose best foreground class probability exceeds ``threshold`` (default 0.7).

    Parameters
    ----------
    logits : torch.Tensor
        (num_queries, num_classes + 1) class logits; the last column is treated as "no object".
    boxes : torch.Tensor
        Boxes aligned row-by-row with ``logits``.
    threshold : float
        Minimum softmax probability (ignoring the last column) for a query to be kept.

    Returns
    -------
    labels : torch.Tensor
        Long tensor with the best foreground class id of each kept query.
    kept_boxes : torch.Tensor
        Rows of ``boxes`` for the kept queries.
    '''
    probas = logits.softmax(-1)[:, :-1]
    # Label comes from the foreground-only distribution, matching the score used for the cut.
    scores, labels = probas.max(-1)
    keepmask = scores > threshold
    return labels[keepmask], boxes[keepmask]

# Note: `logits_` holds predicted class ids (not raw logits) for the kept queries.
logits_, boxes_ = postprocess(logits=logits, boxes=boxes)

In [None]:
# COCO classes
# Label names indexed by class id (91 entries, ids 0..90); plot_results looks names up as CLASSES[cls].
# 'N/A' entries are placeholders for ids without a category, so list position == class id.
CLASSES = [
    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush'
]

def box_cxcywh_to_xyxy(x: torch.Tensor):
    '''
    Convert boxes from centre form (cx, cy, w, h) to corner form (x0, y0, x1, y1).

    Operates on the last dimension, so (N, 4), a single (4,) box, or any batch of boxes works.
    Returns a new tensor; the input is not modified.
    '''
    x_c, y_c, w, h = x.unbind(-1)
    half_w, half_h = 0.5 * w, 0.5 * h
    return torch.stack([x_c - half_w, y_c - half_h, x_c + half_w, y_c + half_h], dim=-1)

def plot_results(img: Image.Image, classes: Iterable, boxes: Iterable):
    '''
    Draw detections on ``img`` with COCO label names and show the figure.

    Parameters
    ----------
    img : PIL.Image.Image
        Image the predictions refer to; ``img.size`` scales the normalized boxes to pixels.
    classes : Iterable
        Integer class ids, one per box, looked up in ``CLASSES`` for the label text.
        Ids outside ``CLASSES`` get a rectangle but no label.
    boxes : torch.Tensor
        (N, 4) normalized (cx, cy, w, h) boxes, on any device.
    '''
    # box_cxcywh_to_xyxy returns a fresh tensor, so the in-place scaling below does not touch
    # the caller's boxes. Move to CPU first: matplotlib cannot convert CUDA tensors.
    boxes_xyxy = box_cxcywh_to_xyxy(torch.as_tensor(boxes)).detach().cpu()

    w, h = img.size
    boxes_xyxy[:, [0, 2]] *= w
    boxes_xyxy[:, [1, 3]] *= h

    fig, ax = plt.subplots(figsize=(16, 10))
    ax.imshow(img)

    for cls, (xmin, ymin, xmax, ymax) in zip(classes, boxes_xyxy.tolist()):
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color='cyan', linewidth=3))
        cls = int(cls)
        # Explicit bounds check instead of a bare `except:` that hid every error.
        if 0 <= cls < len(CLASSES):
            ax.text(xmin, ymin, CLASSES[cls], fontsize=11, bbox=dict(facecolor='cyan', alpha=0.9))

    ax.axis('off')
    plt.show()

plot_results(img, logits_, boxes_)