In [9]:
from typing import Generic, Optional, Tuple, List, Callable, Iterable

import numpy as np
from torchvision.models import resnet50
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import utils
from utils import debugt, debugs, debug

torch.hub.set_dir('torch_cache')
import fishdetr_batchboy as detr
import contextlib
from generators import TorchStereoDataset
import re
from matplotlib import pyplot as plt
import plotly.express as px

In [2]:
utils.seed_everything(42069)

In [3]:
DATA_DIR = '/mnt/blendervol/leftright_left_data'
TABLE = 'bboxes_std'

In [57]:
utils.reloader(detr)

traingen = TorchStereoDataset(DATA_DIR, TABLE, shuffle=False, imgnrs=range(8,8+16))
loader1 = torch.utils.data.DataLoader(
    traingen,
    batch_size=16,
    collate_fn=lambda x: tuple(zip(*x))
)

loader2 = torch.utils.data.DataLoader(
    traingen,
    batch_size=16,
    collate_fn=detr.collate
)

X1, __ = next(iter(loader1))
X2, __ = next(iter(loader2))

encoder = detr.Encoder()
encoder = encoder.backbone
encoder = encoder.eval()

Encoder successfully loaded with pretrained weights


In [153]:
class TestNet(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.mouth = nn.Conv2d(3, 3, 1)           
        
        self.throat = nn.Sequential(
            nn.Conv2d(3, 64, 7),
            nn.ReLU(),
        )
        
        self.tummy = nn.Sequential(*[
            nn.Conv2d(64, 64, 5),
            nn.MaxPool2d(2,2),
            nn.ReLU(),
            nn.BatchNorm2d(64, track_running_stats=False)
        ]*4)
        
        self.ass = nn.Sequential(
            nn.Conv2d(64, 16, 1),
            nn.ReLU()
        )
    
    def forward(self, x):
        x = self.mouth(x-0.42069)
        x = self.throat(x-0.42069)
        x = self.tummy(x-0.42069)
        x = self.ass(x-0.42069)
        return x
    
encoder = TestNet().eval()
# encoder = resnet50().eval()

In [154]:
@torch.no_grad()
def sanity_check_singles(X):
    X: Tuple[Tuple[torch.Tensor, torch.Tensor]]
    debugt(X[0])    
    output = torch.cat([encoder(l) for l, r in X], axis=0)
    debugs(output)
    return output
    
output1 = sanity_check_singles(X1)

[32m(4, sanity_check_singles)[0m X[0]: <class 'tuple'>, len: 2
[32m(6, sanity_check_singles)[0m output: torch.Size([16, 16, 21, 21])


In [155]:
@torch.no_grad()
def sanity_check_batch(X):    
    X: Tuple[torch.Tensor, torch.Tensor]
    debugt(X[0])    
    output = encoder(X[0])
    debugs(output)
    return output
    
output2 = sanity_check_batch(X2)

[32m(4, sanity_check_batch)[0m X[0]: <class 'torch.Tensor'>, len: 16
[32m(6, sanity_check_batch)[0m output: torch.Size([16, 16, 21, 21])


In [156]:
torch.allclose(output1, output2)

False

In [157]:
diffs = abs(output1 - output2)
ratios = output1 / output2

In [160]:
debug(output1[0][0])
debug(output2[0][0])

[32m(1, <module>)[0m output1[0][0]: tensor([[0.6074, 0.6005, 0.5155, 0.5521, 0.5438, 0.4940, 0.5542, 0.6002, 0.5728,
         0.5536, 0.4751, 0.5436, 0.5094, 0.4886, 0.5616, 0.5494, 0.5177, 0.4965,
         0.4882, 0.5905, 0.5594],
        [0.6626, 0.7094, 0.6092, 0.7164, 0.6890, 0.7108, 0.7574, 0.5188, 0.8079,
         0.5812, 0.6239, 0.6074, 0.6567, 0.6519, 0.6348, 0.6572, 0.6680, 0.6160,
         0.6235, 0.6290, 0.6468],
        [0.7182, 0.6998, 0.7156, 0.6991, 0.6183, 0.2626, 0.2836, 0.3091, 1.1766,
         0.7255, 0.6812, 0.6920, 0.5797, 0.6859, 0.6835, 0.6246, 0.6906, 0.6999,
         0.7528, 0.7062, 0.5898],
        [0.5585, 0.5849, 0.5826, 0.6588, 0.3421, 0.0000, 0.0000, 0.0000, 0.7393,
         0.4589, 0.6556, 0.6438, 0.6986, 0.6595, 0.6697, 0.7206, 0.6588, 0.6598,
         0.6342, 0.5828, 0.6038],
        [0.6537, 0.6701, 0.7223, 0.4426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0431,
         0.6704, 0.6755, 0.7621, 0.7437, 0.7421, 0.6089, 0.7513, 0.6691, 0.6603,
         0.6331,

In [91]:
px.histogram(diffs.flatten()).show()