In [1]:
import os
import gc
import sys
import argparse
import time
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import utils, transforms, datasets
from torchvision.models.detection.rpn import AnchorGenerator

from transforms import ToTensor, Resize, Compose, testTensor
from data.kitti_dataset import KITTI
from backbone.backbone_vgg import vgg16, vgg11, vgg13, vgg19
from backbone.backbone_resnet import resnet18, resnet34
from models.faster_rcnn_mod import FasterRCNN
from models.mask_rcnn_mod import MaskRCNN
from models.keypoint_rcnn_mod import KeypointRCNN
from models.retinanet_mod import RetinaNet

from ID.intrinsic_dimension import estimate, block_analysis
from scipy.spatial.distance import pdist,squareform

from tqdm import tqdm
from PIL import Image

In [2]:
backbone = vgg19(pretrained=False).features
backbone.out_channels = 512
anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),), aspect_ratios=((0.5,1.0,2.0),)) 
roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'], output_size=7, sampling_ratio=2)

In [3]:
root = os.getcwd()
model_path = os.path.join(root, "trained_model", "kitti", "retinanet", "vgg19")

In [4]:
#model = FasterRCNN(backbone, num_classes=10,rpn_anchor_generator= anchor_generator, box_roi_pool=roi_pooler )
model = RetinaNet(backbone, num_classes=10, anchor_generator=anchor_generator)
model.load_state_dict(torch.load(os.path.join(model_path, 'model.pt'), map_location=torch.device('cpu')))

<All keys matched successfully>

In [None]:
class WrappedModel(nn.Module):
	def __init__(self):
		super(WrappedModel, self).__init__()
		self.module = FasterRCNN(backbone, num_classes=10,rpn_anchor_generator= anchor_generator, box_roi_pool=roi_pooler ) # that I actually define.
	def forward(self, x):
		return self.module(x)

model = WrappedModel()
model.load_state_dict(torch.load(os.path.join(model_path, 'checkpoint_60.pt'), map_location=torch.device('cpu'))['model_state_dict'])

In [5]:
model.eval()

RetinaNet(
  (backbone): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1)

In [None]:
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

In [6]:
class test_data(object):
    def __init__(self, path, transforms):
        self.path = path
        self.transforms = transforms
        self.imgs = list(sorted(os.listdir(path)))
       
    
    def __getitem__(self, idx):
        print(idx)
        img_path = os.path.join(self.path, self.imgs[idx])
        
        image = Image.open(img_path).convert('RGB')
        if self.transforms is not None:
            image = self.transforms(image)
        return image
    
    def __len__(self):
        return len(self.imgs)

In [7]:
image_directory = 'D:/Dataset/KITTI/testing/horizontal_shift'

data = test_data(image_directory, transforms= testTensor())
indices = torch.randperm(len(data))
dataset = torch.utils.data.Subset(data, indices[:24])

def collate_fn(batch):
    return tuple(zip(*batch))
    

In [8]:
testdata = DataLoader(dataset, batch_size=1, shuffle=True)

In [27]:
image = next(iter(testdata))

tensor(2744)


In [28]:
out0 = image
print(out0.shape)
out1 = model.backbone[4](F.relu(model.backbone[2](F.relu(model.backbone[0](out0)))))
print(out1.shape)
out2 = model.backbone[9](F.relu(model.backbone[7](F.relu(model.backbone[5](out1)))))
print(out2.shape)
out3 = model.backbone[18](F.relu(model.backbone[16](F.relu(model.backbone[14](F.relu(model.backbone[12](F.relu(model.backbone[10](out2)))))))))        
print(out3.shape)
out4 = model.backbone[27](F.relu(model.backbone[25](F.relu(model.backbone[23](F.relu(model.backbone[21](F.relu(model.backbone[19](out3)))))))))    
print(out4.shape)
out5 = model.backbone[36](F.relu(model.backbone[34](F.relu(model.backbone[32](F.relu(model.backbone[30](F.relu(model.backbone[28](out4)))))))))
print(out5.shape)
out6 = F.relu(model.head.classification_head.conv[6](F.relu(model.head.classification_head.conv[4](F.relu(model.head.classification_head.conv[2](F.relu(model.head.classification_head.conv[0](out5))))))))
print(out6.shape)
out7 = model.head.classification_head.cls_logits(out6)
print(out7.shape)
out8 = F.relu(model.head.regression_head.conv[6](F.relu(model.head.regression_head.conv[4](F.relu(model.head.regression_head.conv[2](F.relu(model.head.regression_head.conv[0](out5))))))))
print(out8.shape)
out9 = model.head.regression_head.bbox_reg(out8)
print(out9.shape)

torch.Size([1, 3, 1200, 1200])
torch.Size([1, 64, 600, 600])
torch.Size([1, 128, 300, 300])
torch.Size([1, 256, 150, 150])
torch.Size([1, 512, 75, 75])
torch.Size([1, 512, 37, 37])
torch.Size([1, 512, 37, 37])
torch.Size([1, 150, 37, 37])
torch.Size([1, 512, 37, 37])
torch.Size([1, 60, 37, 37])


In [22]:
def computeID(r, nres, fraction):
    ID = []
    n = int(np.round(r.shape[0]*fraction))
    print(n)
    dist = squareform(pdist(r, 'euclidean'))
    print(dist)
    for i in range(nres):
        dist_s = dist
        perm = np.random.permutation(dist.shape[0])[0:n]
        print(perm)
        dist_s = dist_s[perm,:]
        dist_s = dist_s[:,perm]
        print(dist_s)
        ID.append(estimate(dist_s)[2])
    mean = np.mean(ID)
    error = np.std(ID)
    return mean,error

In [19]:
Out0 = out0.view(image.shape[0], -1).cpu().data
Out1 = out1.view(image.shape[0], -1).cpu().data
Out2 = out2.view(image.shape[0], -1).cpu().data
Out3 = out3.view(image.shape[0], -1).cpu().data
Out4 = out4.view(image.shape[0], -1).cpu().data
Out5 = out5.view(image.shape[0], -1).cpu().data
Out6 = out6.view(image.shape[0], -1).cpu().data
Out7 = out7.view(image.shape[0], -1).cpu().data
Out8 = out8.view(image.shape[0], -1).cpu().data
Out9 = out9.view(image.shape[0], -1).cpu().data
#Out10 = out10.view(image.shape[0], -1).cpu().data
#Out11 = out11.view(image.shape[0], -1).cpu().data

In [29]:
Out0 = torch.cat((Out0, out0.view(image.shape[0], -1).cpu().data),0)
Out1 = torch.cat((Out1, out1.view(image.shape[0], -1).cpu().data),0) 
Out2 = torch.cat((Out2, out2.view(image.shape[0], -1).cpu().data),0) 
Out3 = torch.cat((Out3, out3.view(image.shape[0], -1).cpu().data),0)                 
Out4 = torch.cat((Out4, out4.view(image.shape[0], -1).cpu().data),0) 
Out5 = torch.cat((Out5, out5.view(image.shape[0], -1).cpu().data),0) 
Out6 = torch.cat((Out6, out6.view(image.shape[0], -1).cpu().data),0) 
Out7 = torch.cat((Out7, out7.view(image.shape[0], -1).cpu().data),0)                 
Out8 = torch.cat((Out8, out8.view(image.shape[0], -1).cpu().data),0)
Out9 = torch.cat((Out9, out9.view(image.shape[0], -1).cpu().data),0) 
#Out10 = torch.cat((Out10, out10.view(image.shape[0], -1).cpu().data),0)                 
#Out11 = torch.cat((Out11, out11.view(image.shape[0], -1).cpu().data),0)

In [30]:
out = [Out0, Out1, Out2, Out3, Out4, Out5, Out6, Out7, Out8, Out9]

In [31]:
ID_all = []
for j in tqdm(range(0,12)):
    r = out[j]
    print(r)
    ID_all.append(computeID(r, 20, 0.9))
ID_all = np.array(ID_all)

  8%|▊         | 1/12 [00:00<00:01,  6.89it/s]

tensor([[0.5294, 0.5294, 0.5294,  ..., 0.2392, 0.2392, 0.2392],
        [0.4824, 0.4824, 0.4824,  ..., 0.1176, 0.1176, 0.1216],
        [0.2000, 0.1882, 0.2078,  ..., 0.0275, 0.0275, 0.0235],
        [1.0000, 1.0000, 1.0000,  ..., 0.4039, 0.4000, 0.4000]])
4
[[   0.          808.96829798  963.16618985  994.69562326]
 [ 808.96829798    0.          678.49197041  800.74922345]
 [ 963.16618985  678.49197041    0.         1028.47368525]
 [ 994.69562326  800.74922345 1028.47368525    0.        ]]
[1 0 2 3]
[[   0.          808.96829798  678.49197041  800.74922345]
 [ 808.96829798    0.          963.16618985  994.69562326]
 [ 678.49197041  963.16618985    0.         1028.47368525]
 [ 800.74922345  994.69562326 1028.47368525    0.        ]]
[0 2 3 1]
[[   0.          963.16618985  994.69562326  808.96829798]
 [ 963.16618985    0.         1028.47368525  678.49197041]
 [ 994.69562326 1028.47368525    0.          800.74922345]
 [ 808.96829798  678.49197041  800.74922345    0.        ]]
[2 0 3 1]


 17%|█▋        | 2/12 [00:00<00:04,  2.20it/s]

[[   0.         1174.02714423 1313.02625636 1239.15149002]
 [1174.02714423    0.         1223.90287462 1171.84099985]
 [1313.02625636 1223.90287462    0.         1337.12814542]
 [1239.15149002 1171.84099985 1337.12814542    0.        ]]
[0 1 2 3]
[[   0.         1174.02714423 1313.02625636 1239.15149002]
 [1174.02714423    0.         1223.90287462 1171.84099985]
 [1313.02625636 1223.90287462    0.         1337.12814542]
 [1239.15149002 1171.84099985 1337.12814542    0.        ]]
[3 0 1 2]
[[   0.         1239.15149002 1171.84099985 1337.12814542]
 [1239.15149002    0.         1174.02714423 1313.02625636]
 [1171.84099985 1174.02714423    0.         1223.90287462]
 [1337.12814542 1313.02625636 1223.90287462    0.        ]]
[0 2 1 3]
[[   0.         1313.02625636 1174.02714423 1239.15149002]
 [1313.02625636    0.         1223.90287462 1337.12814542]
 [1174.02714423 1223.90287462    0.         1171.84099985]
 [1239.15149002 1337.12814542 1171.84099985    0.        ]]
[3 1 2 0]
[[   0.     

 25%|██▌       | 3/12 [00:01<00:03,  2.40it/s]

[[   0.         1994.51085716 2331.43820959 2179.0271528 ]
 [1994.51085716    0.         2078.31185749 1920.51176346]
 [2331.43820959 2078.31185749    0.         2322.30941914]
 [2179.0271528  1920.51176346 2322.30941914    0.        ]]
[0 3 2 1]
[[   0.         2179.0271528  2331.43820959 1994.51085716]
 [2179.0271528     0.         2322.30941914 1920.51176346]
 [2331.43820959 2322.30941914    0.         2078.31185749]
 [1994.51085716 1920.51176346 2078.31185749    0.        ]]
[2 3 1 0]
[[   0.         2322.30941914 2078.31185749 2331.43820959]
 [2322.30941914    0.         1920.51176346 2179.0271528 ]
 [2078.31185749 1920.51176346    0.         1994.51085716]
 [2331.43820959 2179.0271528  1994.51085716    0.        ]]
[2 3 0 1]
[[   0.         2322.30941914 2331.43820959 2078.31185749]
 [2322.30941914    0.         2179.0271528  1920.51176346]
 [2331.43820959 2179.0271528     0.         1994.51085716]
 [2078.31185749 1920.51176346 1994.51085716    0.        ]]
[3 0 2 1]
[[   0.     

 33%|███▎      | 4/12 [00:01<00:02,  2.93it/s]

[[   0.         2144.00475697 2263.19791304 2415.15780201]
 [2144.00475697    0.         2082.38581208 2263.57479872]
 [2263.19791304 2082.38581208    0.         2391.84754139]
 [2415.15780201 2263.57479872 2391.84754139    0.        ]]
[1 0 3 2]
[[   0.         2144.00475697 2263.57479872 2082.38581208]
 [2144.00475697    0.         2415.15780201 2263.19791304]
 [2263.57479872 2415.15780201    0.         2391.84754139]
 [2082.38581208 2263.19791304 2391.84754139    0.        ]]
[2 0 3 1]
[[   0.         2263.19791304 2391.84754139 2082.38581208]
 [2263.19791304    0.         2415.15780201 2144.00475697]
 [2391.84754139 2415.15780201    0.         2263.57479872]
 [2082.38581208 2144.00475697 2263.57479872    0.        ]]
[2 0 3 1]
[[   0.         2263.19791304 2391.84754139 2082.38581208]
 [2263.19791304    0.         2415.15780201 2144.00475697]
 [2391.84754139 2415.15780201    0.         2263.57479872]
 [2082.38581208 2144.00475697 2263.57479872    0.        ]]
[3 0 1 2]
[[   0.     

 58%|█████▊    | 7/12 [00:01<00:00,  5.82it/s]

[2 0 1 3]
[[  0.         568.39083245 516.83166537 639.37841428]
 [568.39083245   0.         540.50274131 653.83160474]
 [516.83166537 540.50274131   0.         615.94805384]
 [639.37841428 653.83160474 615.94805384   0.        ]]
[2 0 1 3]
[[  0.         568.39083245 516.83166537 639.37841428]
 [568.39083245   0.         540.50274131 653.83160474]
 [516.83166537 540.50274131   0.         615.94805384]
 [639.37841428 653.83160474 615.94805384   0.        ]]
[2 3 0 1]
[[  0.         639.37841428 568.39083245 516.83166537]
 [639.37841428   0.         653.83160474 615.94805384]
 [568.39083245 653.83160474   0.         540.50274131]
 [516.83166537 615.94805384 540.50274131   0.        ]]
[2 1 0 3]
[[  0.         516.83166537 568.39083245 639.37841428]
 [516.83166537   0.         540.50274131 615.94805384]
 [568.39083245 540.50274131   0.         653.83160474]
 [639.37841428 615.94805384 653.83160474   0.        ]]
[0 1 2 3]
[[  0.         540.50274131 568.39083245 653.83160474]
 [540.50274

 83%|████████▎ | 10/12 [00:01<00:00,  5.32it/s]


[[  0.         351.47706036 338.36703256 230.51252389]
 [351.47706036   0.         438.08008662 352.66418619]
 [338.36703256 438.08008662   0.         285.57708411]
 [230.51252389 352.66418619 285.57708411   0.        ]]
[1 0 3 2]
[[  0.         230.51252389 351.47706036 338.36703256]
 [230.51252389   0.         352.66418619 285.57708411]
 [351.47706036 352.66418619   0.         438.08008662]
 [338.36703256 285.57708411 438.08008662   0.        ]]
[3 2 1 0]
[[  0.         438.08008662 351.47706036 352.66418619]
 [438.08008662   0.         338.36703256 285.57708411]
 [351.47706036 338.36703256   0.         230.51252389]
 [352.66418619 285.57708411 230.51252389   0.        ]]
[0 3 1 2]
[[  0.         352.66418619 230.51252389 285.57708411]
 [352.66418619   0.         351.47706036 438.08008662]
 [230.51252389 351.47706036   0.         338.36703256]
 [285.57708411 438.08008662 338.36703256   0.        ]]
[0 1 3 2]
[[  0.         230.51252389 352.66418619 285.57708411]
 [230.51252389   0. 




IndexError: list index out of range

In [32]:
ID_all

[(2.91245597650893, 0.0),
 (13.007608147211874, 1.7763568394002505e-15),
 (7.799266975461698, 8.881784197001252e-16),
 (12.148286037675783, 1.7763568394002505e-15),
 (12.291145559515655, 0.0),
 (27.08703253851392, 0.0),
 (4.901854063944784, 8.881784197001252e-16),
 (4.118577803823757, 8.881784197001252e-16),
 (61.986979348515376, 7.105427357601002e-15),
 (31.124509191323842, 7.105427357601002e-15)]

1