In [11]:
import torch
import numpy as np 
import pandas as pd
import os

import utils
from utils import debug, debugs, debugt
from matplotlib import pyplot as plt
from generators import Torch3DDataset
import generators
from torch.utils.data import DataLoader
import sqlite3
from hungarianmatcher import HungarianMatcher
from setcriterion import SetCriterion
from typing import Sequence
import itertools

import fishdetr3d as detr
# Reload the project modules so edits to them are picked up without restarting the kernel
# (presumably utils.reloader re-imports the given module, or the module defining a class — confirm).
utils.reloader(generators)
utils.reloader(Torch3DDataset)
utils.reloader(detr)
# `from generators import Torch3DDataset` bound the class object that existed at import time;
# re-bind the name after reloading `generators` so later cells construct the current class.
Torch3DDataset = generators.Torch3DDataset

In [2]:
TABLE = "bboxes_full"
DIR = "/mnt/blendervol/3d_data"
BATCHSIZE = 16

# num2str is used as the `classmap` for plotting (presumably class number -> name).
# NOTE(review): eval() executes whatever metadata.txt contains; if the file is a plain
# literal, ast.literal_eval would be safer — confirm the file format before switching.
with open(os.path.join(DIR, "metadata.txt")) as metadata_file:
    num2str = eval(metadata_file.read())

TORCH_CACHE_DIR = 'torch_cache'
torch.hub.set_dir(TORCH_CACHE_DIR)

In [3]:
# Open the training database read-only just long enough to count the images.
db_con = sqlite3.connect(f'file:{os.path.join(DIR,"bboxes.db")}?mode=ro', uri=True)
print("Getting number of images in database")
try:
    # int(): the query returns a numpy scalar; a plain int keeps the ranges below simple.
    n_data = int(pd.read_sql_query(f'SELECT COUNT(DISTINCT(imgnr)) FROM {TABLE}', db_con).values[0][0])
finally:
    db_con.close()

# 90/10 train/validation split over image numbers, computed once with integer arithmetic.
# The ranges are later used directly as imgnrs, which assumes imgnr runs 0..n_data-1 with
# no gaps — TODO confirm.
n_train = 9 * n_data // 10
TRAIN_RANGE = (0, n_train)
VAL_RANGE = (n_train, n_data)

Getting number of images in database


In [4]:
# Train and validation sets come from the same database, restricted to disjoint imgnr ranges;
# the test set lives in a sibling directory with a `_test` suffix.
# The positional `1` is presumably batch_size (cf. the keyword call in the plotting cell).
traingen, valgen = (
    Torch3DDataset(DIR, TABLE, 1, shuffle=False, imgnrs=range(*imgnr_bounds))
    for imgnr_bounds in (TRAIN_RANGE, VAL_RANGE)
)
testgen = Torch3DDataset(f"{DIR}_test", TABLE, 1, shuffle=False)

In [5]:
# Show each dataset's data directory and index range.
# debug() labels its output with the source expression it was called on (see the output
# below), so keep one call per dataset rather than looping over them.
debug(traingen)
debug(valgen)
debug(testgen)

(1, <module>) traingen: Torch3DDataset(data=/mnt/blendervol/3d_data, index_range=[0, 29890])
(2, <module>) valgen: Torch3DDataset(data=/mnt/blendervol/3d_data, index_range=[29891, 33212])
(3, <module>) testgen: Torch3DDataset(data=/mnt/blendervol/3d_data_test, index_range=[0, 63])


In [6]:
model = detr.FishDETR()

CHECKPOINT_PATH = 'fish_statedicts_3d/weights_2021-03-22/detr_statedicts_epoch18_train0.0505_val0.0506_2021-03-22T10:22:42.pth'
# The checkpoint is a dict; only its 'model_state_dict' entry is used here.
# NOTE(review): torch.load unpickles the file — only load checkpoints from trusted sources.
checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])

Encoder successfully loaded with pretrained weights


<All keys matched successfully>

In [7]:
BATCH_SIZE = 64
# shuffle=False: later cells tag this loader's first batch of predictions with consecutive
# image numbers, which is only correct if samples arrive in dataset-index order. Shuffling
# would silently attach predictions to the wrong images in the exported CSV.
testloader = DataLoader(
    dataset=testgen,
    batch_size=BATCH_SIZE,
    collate_fn=detr.collate,
    shuffle=False,
)

weight_dict = {'loss_ce': 1, 'loss_bbox': 1, 'loss_giou': 1, 'loss_smooth': 1}
losses = ['labels', 'boxes_smooth_l1']
matcher = HungarianMatcher(use_giou=False, smooth_l1=False)
# First argument is presumably the number of classes (cf. n_classes=6 in the plotting cell) — confirm.
criterion = SetCriterion(6, matcher, weight_dict, eos_coef=0.5, losses=losses)

In [8]:
# Take only the first batch from the test loader.
loader_iter = iter(testloader)
X, y = next(loader_iter)
del loader_iter

In [9]:
# A freshly constructed nn.Module is in training mode (dropout etc. active) and
# load_state_dict does not change that. eval_on_batch may already switch modes — model.eval()
# is idempotent, so calling it here is harmless either way and makes inference deterministic.
model.eval()
X_, y_ = detr.preprocess(X, y, None)
# No gradients are needed for inference.
with torch.no_grad():
    out, loss = model.eval_on_batch(X_, y_, criterion)

In [17]:
# Tag the first batch's predictions with consecutive image numbers. This assumes the test
# loader does not shuffle and that dataset index i is image number i (testgen reports
# index_range [0, 63]) — TODO confirm.
n_pred_imgs = min(BATCH_SIZE, len(testgen))
# Passed straight to detr.postprocess_to_df; presumably a score cut-off — confirm.
POSTPROCESS_THRESHOLD = 0.7
df = detr.postprocess_to_df(range(0, n_pred_imgs), out, POSTPROCESS_THRESHOLD)
df.to_csv('nogit_test_output.csv')

In [19]:
# Export every ground-truth row from the test dataset's database so it can be compared
# with the predictions in nogit_test_output.csv. Uses TABLE rather than repeating its name.
pd.read_sql(f"SELECT * FROM {TABLE}", testgen.con).to_csv("nogit_test_labels.csv", index=False)

## Visual check: ground-truth boxes on test images

In [None]:
# Draw the ground-truth boxes on the first grid[0]*grid[1] test images.
grid = (4, 4)
fig, axes = plt.subplots(*grid, figsize=(grid[1] * 5, grid[0] * 5))

# zip() stops early (leaving blank panels) if the test set has fewer images than panels.
for i, ax in zip(range(len(testgen)), np.ravel(axes)):
    # Distinct names so this loop does not overwrite the batch `X, y` used by the inference cells.
    sample_X, sample_y = testgen[i]
    left_img = sample_X[0]
    # Copy before editing: indexed `+=` writes back into the tensor, which would alter
    # sample_y['boxes'] itself (and the dataset's data, if it caches it).
    # Assumes the boxes are a torch tensor, like the image (see .permute below).
    boxes = sample_y['boxes'].clone()
    # Add half of columns 2-3 to columns 0-1 (presumably (x, y, w, h) top-left -> centre,
    # the convention utils.plot_bboxes expects — TODO confirm).
    boxes[:, [0, 1]] += boxes[:, [2, 3]] * 0.5
    # permute((1, 2, 0)) moves the first axis last, presumably CHW -> HWC for matplotlib.
    utils.plot_bboxes(left_img[0].permute((1, 2, 0)), classes=sample_y['labels'], boxes=boxes, classmap=num2str, ax=ax)

plt.show()
    