In [1]:
from collections import OrderedDict

import torch
from torch import nn
import torch.nn.functional as F
import cv2
import matplotlib.pyplot as plt

from torchvision.ops import misc as misc_nn_ops
from torchvision.ops import MultiScaleRoIAlign
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection import FasterRCNN
import torchvision.models.detection.image_list as image_list
from torchvision.models.detection.rpn import AnchorGenerator
from jde_rcnn import Jde_RCNN

from utils.datasets import LoadImagesAndLabels, collate_fn

In [2]:
# ResNet-50 feature extractor wrapped in a Feature Pyramid Network,
# initialised from pretrained weights.
backbone = resnet_fpn_backbone('resnet50', pretrained=True)

# Channel count advertised to whatever consumes the backbone. It matches the
# 256-channel feature maps printed in the backbone shape check further down.
backbone.out_channels = 256

In [3]:
# Build the project's Jde_RCNN model on top of the FPN backbone. The training
# cell below reports detection losses plus a `loss_reid` term for this model.
# NOTE(review): num_ID is presumably the number of identity classes for the
# re-ID head -- confirm against jde_rcnn.py.
model = Jde_RCNN(backbone, num_ID=100)


In [4]:
# Two random 3-channel images that share a width of 400 but differ in height,
# so the model is exercised on a batch of unequal-sized inputs.
x = [torch.rand(3, height, 400) for height in (300, 500)]

# One ground-truth box and one integer label per image.
# NOTE(review): boxes are presumably (x1, y1, x2, y2) in pixels -- confirm
# against what Jde_RCNN expects.
gt_boxes = (
    [395.66, 234.70, 400.00, 242.86],
    [397.21, 11.319, 400.00, 19.564],
)
gt_labels = (1, 2)

targets = [
    OrderedDict(
        boxes=torch.tensor([box], dtype=torch.float32),
        labels=torch.tensor([label], dtype=torch.int64),
    )
    for box, label in zip(gt_boxes, gt_labels)
]
targets[0]


OrderedDict([('boxes', tensor([[395.6600, 234.7000, 400.0000, 242.8600]])),
             ('labels', tensor([1]))])

In [5]:
## training

# The model returns a dict of losses only while in training mode; in eval mode
# (see the inference cell) it returns predictions instead. Switch mode
# explicitly so this cell gives the same result regardless of which cells
# were executed before it.
model.train()

# Forward pass with images and ground truth; the result maps loss names
# (e.g. 'loss_classifier', 'loss_reid') to scalar tensors.
losses = model(x, targets)
losses

{'loss_box_reg': tensor(4.9268e-06, grad_fn=<DivBackward0>),
 'loss_classifier': tensor(0.7046, grad_fn=<NllLossBackward>),
 'loss_objectness': tensor(0.6933, grad_fn=<BinaryCrossEntropyWithLogitsBackward>),
 'loss_reid': tensor(4.5977, grad_fn=<NllLossBackward>),
 'loss_rpn_box_reg': tensor(0.2620, grad_fn=<DivBackward0>)}

In [4]:
# Switch to inference mode: called without targets, the model returns a list
# with one prediction dict per input image (see the recorded output).
model.eval()

# Fresh random batch; note this rebinds `x` used by the training cell.
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]

# No gradients are needed for inference; disabling autograd avoids building a
# graph and keeps memory usage down.
with torch.no_grad():
    predictions = model(x)
predictions

[{'boxes': tensor([[3.9566e+02, 2.3470e+02, 4.0000e+02, 2.4286e+02],
          [3.9721e+02, 1.1319e+01, 4.0000e+02, 1.9564e+01],
          [3.9252e+02, 2.1509e+02, 4.0000e+02, 2.2330e+02],
          [3.9716e+02, 2.3614e+02, 4.0000e+02, 2.4439e+02],
          [3.9823e+02, 2.3759e+02, 4.0000e+02, 2.4579e+02],
          [3.9384e+02, 2.1175e+02, 4.0000e+02, 2.2809e+02],
          [3.9845e+02, 2.5485e+02, 4.0000e+02, 2.7114e+02],
          [3.9380e+02, 2.4303e+02, 4.0000e+02, 2.5939e+02],
          [3.9815e+02, 1.2736e+01, 4.0000e+02, 2.0915e+01],
          [3.9258e+02, 2.1209e+02, 4.0000e+02, 2.2030e+02],
          [3.9106e+02, 6.7673e+00, 4.0000e+02, 1.4984e+01],
          [3.9563e+02, 2.5266e+02, 4.0000e+02, 2.6086e+02],
          [3.9577e+02, 9.3759e+01, 4.0000e+02, 1.0189e+02],
          [3.9563e+02, 2.5564e+02, 4.0000e+02, 2.6384e+02],
          [3.9825e+02, 9.3600e+01, 4.0000e+02, 1.0183e+02],
          [3.9820e+02, 2.8253e+02, 4.0000e+02, 2.9072e+02],
          [3.9399e+02, 2.3162e+

In [5]:
# Probe the backbone on its own with one random batched tensor of shape
# (1, 3, 1024, 768). This rebinds `x` from a list of images to a single
# 4-D tensor, which the anchor cells below rely on.
x = torch.rand((1, 3, 1024, 768))

# The FPN backbone returns its feature maps in a mapping (type shown below).
output = backbone(x)
type(output)

collections.OrderedDict

In [6]:
# List every feature level returned by the backbone with its tensor shape.
for level_name, feature_map in output.items():
    print(level_name, feature_map.shape)

0 torch.Size([1, 256, 256, 192])
1 torch.Size([1, 256, 128, 96])
2 torch.Size([1, 256, 64, 48])
3 torch.Size([1, 256, 32, 24])
pool torch.Size([1, 256, 16, 12])


In [7]:
# One anchor size per feature level; the backbone emits five levels
# ('0'-'3' and 'pool'), hence five entries.
anchor_sizes = ((32,), (64,), (128,), (256,), (512,))

# AnchorGenerator expects one tuple of aspect ratios per feature level.
# The previous `((1/3),)` was just the flat tuple `(0.333...,)` (the inner
# parentheses are only grouping) and worked solely because AnchorGenerator
# broadcasts a flat tuple to every level. Spell the structure out so it
# mirrors `anchor_sizes` and cannot silently fall out of step with it.
# NOTE(review): torchvision defines aspect ratio as height / width, so 1/3
# yields anchors three times wider than tall. If tall (pedestrian-shaped)
# anchors were intended this should be 3.0 -- confirm before training.
aspect_ratios = ((1 / 3,),) * len(anchor_sizes)

anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)

In [8]:
# ImageList pairs the batched tensor with the (height, width) of each image.
# Derive the sizes from `x` itself instead of hard-coding (640, 640), which
# contradicted the actual (1, 3, 1024, 768) tensor passed alongside it.
image_sizes = [tuple(img.shape[-2:]) for img in x]
image_list_ = image_list.ImageList(x, image_sizes)

# Generate anchors for every feature level of the backbone output; the
# result is a list of anchor tensors (shapes are printed in the next cell).
anchor = anchor_generator(image_list_, list(output.values()))

In [9]:
# Report the shape of every anchor tensor produced by the generator.
for anchors_for_image in anchor:
    print(anchors_for_image.shape)

torch.Size([65472, 4])


In [2]:
# Text file handed to the dataset loader, given relative to the notebook's
# working directory.
# NOTE(review): presumably a list of CityPersons training image paths --
# confirm against utils/datasets.py.
path = './data/citypersons/trainImages.txt'

# Project dataset wrapper from utils.datasets.
# NOTE(review): the second argument is presumably the target image size as
# (width, height) -- confirm against LoadImagesAndLabels.
data = LoadImagesAndLabels(path, (2048,1024))

In [3]:
# Batched, shuffled loader over the dataset. Batches are assembled by the
# project's own collate_fn, and incomplete final batches are dropped.
dataloader = torch.utils.data.DataLoader(
    data,
    batch_size=16,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    drop_last=True,
    collate_fn=collate_fn,
)

In [24]:
# Smoke-test the loader by iterating over every batch and printing a progress
# marker every 100 batches. Note that this rebinds `targets` in the notebook
# namespace.
# NOTE(review): the recorded run fails inside collate_fn with
# "expected Tensor ... but got numpy.ndarray" at torch.stack(imgs, 0); the
# dataset appears to yield numpy images, so the fix belongs in
# utils/datasets.py rather than in this cell.
for i, batch in enumerate(dataloader):
    imgs, targets, _, _, targets_len = batch
    if not i % 100:
        print(i)

TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/hunter/miniconda3/envs/mot/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/hunter/miniconda3/envs/mot/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/home/hunter/Document/torch/myproject/utils/datasets.py", line 329, in collate_fn
    imgs = torch.stack(imgs, 0)
TypeError: expected Tensor as element 0 in argument 0, but got numpy.ndarray
