In [None]:
from typing import Generic, Optional, Tuple, List, Callable, Iterable, Mapping

import numpy as np
from torchvision.models import resnet50
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import utils
from utils import debugt, debugs, debug

torch.hub.set_dir('torch_cache')
import fishdetr_batchboy as detr
import contextlib
from generators import TorchStereoDataset
import re
from matplotlib import pyplot as plt
import plotly.express as px

import sys
sys.path.append('./detr_custom/')
from models.matcher import HungarianMatcher
from models.detr import SetCriterion
import os

In [None]:
# Re-bind `debugt` to whatever utils.reloader returns for it — presumably a freshly
# reloaded version of the helper so edits in utils are picked up without a kernel restart
# (reloader's exact semantics live in utils; confirm).
debugt = utils.reloader(debugt)

In [None]:
utils.seed_everything(42069)

# Try the GPU helper first. Any AssertionError raised in this try block (initialisation
# or the status printout) makes the notebook fall back to running on the CPU.
try:
    device = utils.pytorch_init_janus_gpu()
    print(f'Using device: {device} ({torch.cuda.get_device_name()})')
    print(utils.get_cuda_status(device))
except AssertionError as init_error:
    print(f'GPU could not initialize, got error: {init_error}')
    device = torch.device('cpu')
    print('Device is set to CPU')

In [None]:
# Dataset root (passed to TorchStereoDataset and searched for metadata.txt) and the name of
# the bbox table inside it. The root can be overridden with the FISHDETR_DATA_DIR environment
# variable so the notebook also runs where the volume is mounted somewhere else.
DATA_DIR = os.environ.get('FISHDETR_DATA_DIR', '/mnt/blendervol/leftright_left_data')
TABLE = 'bboxes_std'

In [None]:
traingen = TorchStereoDataset(DATA_DIR, TABLE, shuffle=False, imgnrs=range(8, 8 + 32))

# The same samples collated two ways so the per-sample and the batched code paths can be
# compared: loader1 only transposes the batch into per-field tuples, loader2 uses detr.collate.
loader1 = torch.utils.data.DataLoader(traingen, batch_size=32,
                                      collate_fn=lambda samples: tuple(zip(*samples)))
loader2 = torch.utils.data.DataLoader(traingen, batch_size=32, collate_fn=detr.collate)

# Take the first batch from each loader.
X1, y1 = next(iter(loader1))
X2, y2 = next(iter(loader2))

X1, X2 = list(X1), list(X2)
y1, y2 = detr.label_handler(y1, device), detr.label_handler(y2, device)

In [None]:
# Resize (an int size matches the smaller image edge to 800) and normalise with the
# standard ImageNet channel mean/std.
trans = T.Compose([
    T.Resize(800),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# NOTE(review): this mutates X2 in place — re-running the cell transforms the images twice.
for side in (0, 1):
    X2[side] = trans(X2[side]).to(device)

In [None]:
# Reload the local detr module (via utils.reloader), then build a fresh model on `device`.
utils.reloader(detr)
model = detr.FishDETR().to(device)
# nn.Module.eval() switches the encoder (and its children) to eval mode in place and returns
# the same module, so no re-assignment is needed.
model.encoder.eval()

In [None]:
@torch.no_grad()
def sanity_check_singles(X: List[Tuple[torch.Tensor, torch.Tensor]], model, device=None):
    """Run ``model`` separately on the first tensor of every pair and concatenate the results.

    Used to compare a per-sample forward pass against a batched one.

    Args:
        X: sequence of ``(left, right)`` tensor pairs; only ``left`` is fed to the model.
        model: callable applied to each ``left`` tensor.
        device: device each ``left`` tensor is moved to before the call.

    Returns:
        The per-sample outputs concatenated along dim 0 (gradients are not tracked).
    """
    return torch.cat([model(left.to(device)) for left, _right in X], dim=0)

@torch.no_grad()
def sanity_check_batch(X: Tuple[torch.Tensor, torch.Tensor], model: nn.Module, device: Optional[torch.device]=None):
    """Run ``model`` once on the batched first element ``X[0]``.

    Counterpart of ``sanity_check_singles`` for the batched code path.

    Args:
        X: pair of batched tensors; only ``X[0]`` is fed to the model.
        model: module (or callable) applied to ``X[0]``.
        device: device ``X[0]`` is moved to before the call.

    Returns:
        The model output (gradients are not tracked).
    """
    first_batch = X[0]
    return model(first_batch.to(device))

# with torch.no_grad():
#     output1 = sanity_check_singles(X1, model.encoder, device)
#     output2 = sanity_check_batch(X2, model.encoder, device)
#     debugs(output1)
#     debugs(output2)
#     print()
#     debug(output1[0])
#     debug(output2[0])
#     print()
#     debug(torch.allclose(output1, output2))
#     diff = output1 - output2
#     debug(diff.shape)
#     debug(abs(diff).max())
#     debug(abs(diff).mean())
#     debug(abs(diff).std())
#     fig = px.histogram(diff.flatten(), nbins=50)
#     fig.show()
#     fig = px.histogram(abs(diff.cpu().numpy().ravel()[::16]), nbins=200)
#     fig.show()

In [None]:
# Encoder features for both inputs of the batch: H[0] from X2[0], H[1] from X2[1]
# (presumably the left/right image batches). Both entries hold encoder outputs.
with torch.no_grad():
    H = [model.encoder(X2[i].to(device)) for i in range(2)]

In [None]:
# Inspect the first element of the transformed batch with utils.debugs
# (presumably prints a summary such as shape/stats — see utils).
debugs(X2[0])

In [None]:
@utils.interruptable
def train_head(model, epochs: int=1, X=None, y=None, log_every: int=20):
    """Train only ``model.decoder`` by repeatedly stepping on one fixed batch.

    The loss is the weighted sum of SetCriterion terms (with a HungarianMatcher), optimised
    with AdamW (lr=1e-4) over the decoder parameters only. Each "epoch" is a single optimizer
    step on the same batch.

    Args:
        model: module exposing a ``decoder`` submodule and accepting ``X`` as input.
        epochs: number of optimizer steps.
        X: model input; defaults to the notebook-global ``X2``.
        y: targets matching ``X``; defaults to the notebook-global ``y2``.
        log_every: print the current loss every ``log_every`` steps (0 disables printing).

    Returns:
        The model output of the last step, or ``None`` if ``epochs < 1``.
    """
    if X is None:
        X = X2
    if y is None:
        y = y2

    weight_dict = {'loss_ce': 1, 'loss_bbox': 1, 'loss_giou': 1}
    loss_names = ['labels', 'boxes', 'cardinality']
    matcher = HungarianMatcher()
    # Positional args 6 and 0.5: presumably num_classes and eos_coef as in the reference DETR
    # SetCriterion signature — confirm against detr_custom/models/detr.py.
    criterion = SetCriterion(6, matcher, weight_dict, 0.5, loss_names).to(device)
    optimizer = torch.optim.AdamW(model.decoder.parameters(), lr=1e-4)

    model.decoder.train()
    criterion.train()

    active_weights = criterion.weight_dict
    output = None  # stays None when no step is taken
    for step in range(1, epochs + 1):
        output = model(X)
        loss_dict = criterion(output, y)
        # Only terms that have a weight contribute to the optimised loss.
        loss = sum(loss_dict[k] * active_weights[k] for k in loss_dict if k in active_weights)

        optimizer.zero_grad()
        loss.backward()  # compute gradients
        optimizer.step()  # apply one gradient step

        if log_every and step % log_every == 0:
            print(loss.item())

    return output


output = train_head(model, 100000)

In [None]:
# utils.save_model(model, 'batch_overfit.pth')

In [None]:
# Reload the local detr module via utils.reloader before defining the evaluation helpers,
# which call detr.postprocess and detr.img_handler (presumably reloads in place; confirm in utils).
utils.reloader(detr)

In [None]:
def box_cxcywh_to_xyxy(x: torch.Tensor):
    """Convert boxes from (center_x, center_y, width, height) to (x_min, y_min, x_max, y_max).

    Coordinates are read from the last dimension, so any leading shape is supported:
    ``(N, 4) -> (N, 4)``, ``(B, N, 4) -> (B, N, 4)``. Units are unchanged (normalised boxes
    stay normalised).
    """
    x_c, y_c, w, h = x.unbind(-1)
    half_w, half_h = 0.5 * w, 0.5 * h
    return torch.stack([x_c - half_w, y_c - half_h, x_c + half_w, y_c + half_h], dim=-1)


def _class_label(cls, classmap):
    """Return ``classmap[int(cls)]`` when the lookup succeeds, else the numeric id as a string."""
    class_id = int(cls)
    if classmap is not None:
        try:
            return classmap[class_id]
        except (KeyError, IndexError, TypeError):
            pass  # unknown id or unusable map: fall back to the numeric label
    return str(class_id)


def plot_results(img, classes: Iterable, boxes: Iterable, classmap: Optional[Mapping[int, str]]=None, ax: Optional=None):
    """Draw an image with labelled bounding boxes.

    Args:
        img: image in (H, W, C) layout (anything ``np.array`` accepts); values are clipped
            to [0, 1] for display.
        classes: one class id per box; the label is ``classmap[id]`` when available,
            otherwise the id itself.
        boxes: (N, 4) boxes in normalised (cx, cy, w, h) format (tensor or nested sequence);
            they are scaled to pixels with the image height/width.
        classmap: optional id -> name lookup.
        ax: axes to draw on. If omitted, a new figure is created, its axes are hidden and
            the figure is shown.

    Returns:
        The axes that were drawn on.
    """
    # Remember whether we created the figure; `ax` is reassigned below, so testing
    # `ax is None` again at the end would always be False.
    owns_figure = ax is None
    if owns_figure:
        _, ax = plt.subplots(figsize=(16, 10))

    img = np.array(img)
    ax.imshow(img.clip(0, 1))

    if len(boxes) != 0:
        if not isinstance(boxes, torch.Tensor):
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
        h, w = img.shape[:2]
        # Detached CPU float copy so matplotlib receives plain Python floats.
        boxes = box_cxcywh_to_xyxy(boxes.detach().cpu().float())
        boxes[:, [0, 2]] *= w
        boxes[:, [1, 3]] *= h

        for cls, (xmin, ymin, xmax, ymax) in zip(classes, boxes.tolist()):
            ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                       fill=False, color='cyan', linewidth=3))
            ax.text(xmin, ymin, _class_label(cls, classmap), fontsize=11,
                    bbox=dict(facecolor='cyan', alpha=0.9))

    if owns_figure:
        ax.axis('off')
        plt.show()

    return ax

    
def eval_model(model, img: torch.Tensor, classmap: Optional[Mapping[int, str]]=None, ax: Optional=None):
    """Run ``model`` on an input pair and plot the post-processed detections of the first sample.

    Args:
        model: module called with the tuple ``(img[0], img[1])`` moved to ``device``; it is
            switched to eval mode and left there.
        img: indexable pair of model inputs (presumably left/right image batches); the first
            image of ``img[0]`` is used as the plot background.
        classmap: optional class id -> name lookup forwarded to ``plot_results``.
        ax: axes to draw on; ``None`` lets ``plot_results`` create its own figure.
    """
    first, second = img[0], img[1]
    with torch.no_grad():
        model.eval()
        prediction = model((first.to(device), second.to(device)))

        sample_boxes = prediction['pred_boxes'][0]
        sample_logits = prediction['pred_logits'][0]
        post_logits, post_boxes = detr.postprocess(sample_logits, sample_boxes)

        background = first[0].cpu().numpy().transpose((1, 2, 0))
        plot_results(background, post_logits, post_boxes, classmap, ax=ax)
        
        
def eval_compare_model(model: nn.Module, gen: Iterable, index: int=0, classmap: Optional[Mapping[int, str]]=None):
    """Show the model's prediction (left panel) next to the ground truth (right panel).

    Args:
        model: model passed to ``eval_model``.
        gen: indexable dataset whose items are ``(inputs, target)``; ``target`` must have
            ``'labels'`` and ``'boxes'`` entries.
        index: which sample of ``gen`` to show.
        classmap: optional class id -> name lookup for both panels.
    """
    sample, target = gen[index]
    _, (pred_ax, true_ax) = plt.subplots(1, 2, figsize=(15, 7))
    batched = detr.img_handler([sample])[0]
    eval_model(model, batched, classmap, pred_ax)
    background = sample[0][0].cpu().numpy().transpose((1, 2, 0))
    plot_results(background, target['labels'], target['boxes'], classmap, true_ax)
    pred_ax.set_title('Predicted')
    true_ax.set_title('Real')

import ast

# Class id -> display name lookup stored next to the data. Parsed with ast.literal_eval
# (literals only) instead of eval so the file cannot execute code, and the handle is closed.
# NOTE(review): assumes metadata.txt holds a plain Python literal such as a dict — confirm.
with open(os.path.join(DATA_DIR, "metadata.txt"), 'r') as metadata_file:
    num2name = ast.literal_eval(metadata_file.read())

# One prediction-vs-ground-truth figure per sample for the first 32 samples.
for i in range(32):
    eval_compare_model(model, traingen, index=i, classmap=num2name)
    plt.show()