In [1]:
import torch
import torchinfo
# Seed the global torch RNG so the dummy inputs (and every output computed
# from them below) are identical on each Restart & Run All.
SEED = 0
BATCH_SIZE = 8
torch.manual_seed(SEED)

# Random (batch, 3, H, W) dummy inputs: the search input is 256x256 and the
# template input is 128x128, matching the sizes used in the summaries below.
search = torch.rand(BATCH_SIZE, 3, 256, 256)
template = torch.rand(BATCH_SIZE, 3, 128, 128)

In [2]:
from tracking.basic_model.et_tracker import ET_Tracker
from tracking.basic_model.cmt_et_tracker import CMT_ET_Tracker
from tracking.basic_model.wavelet_et_tracker import WAVE_ET_Tracker

In [3]:
# Build the baseline tracker and its two variants with identical constructor
# settings so their summaries and outputs can be compared side by side.
tracker_kwargs = dict(linear_reg=True)
ettracker = ET_Tracker(**tracker_kwargs)
cmttracker = CMT_ET_Tracker(**tracker_kwargs)
wavetracker = WAVE_ET_Tracker(**tracker_kwargs)

In [4]:
# Layer-by-layer summary of the baseline tracker for a single input pair.
torchinfo.summary(
    ettracker,
    input_size=(
        (1, 3, 128, 128),  # template
        (1, 3, 256, 256),  # search
    ),
)

Layer (type:depth-idx)                                  Output Shape              Param #
ET_Tracker                                              [1, 4, 16, 16]            --
├─ChildNet_FCN: 1-1                                     [1, 96, 8, 8]             --
│    └─Conv2d: 2-1                                      [1, 16, 64, 64]           432
│    └─BatchNorm2d: 2-2                                 [1, 16, 64, 64]           32
│    └─Swish: 2-3                                       [1, 16, 64, 64]           --
│    └─Sequential: 2-4                                  [1, 96, 8, 8]             --
│    │    └─Sequential: 3-1                             [1, 16, 64, 64]           744
│    │    └─Sequential: 3-2                             [1, 24, 32, 32]           34,424
│    │    └─Sequential: 3-3                             [1, 40, 16, 16]           159,168
│    │    └─Sequential: 3-4                             [1, 80, 8, 8]             387,800
│    │    └─Sequential: 3-5                 

In [5]:
# Layer-by-layer summary of the CMT variant for a single input pair.
torchinfo.summary(
    cmttracker,
    input_size=(
        (1, 3, 128, 128),  # template
        (1, 3, 256, 256),  # search
    ),
)

Layer (type:depth-idx)                                  Output Shape              Param #
CMT_ET_Tracker                                          [1, 4, 16, 16]            --
├─ChildNet_FCN: 1-1                                     [1, 96, 8, 8]             --
│    └─Conv2d: 2-1                                      [1, 16, 64, 64]           432
│    └─BatchNorm2d: 2-2                                 [1, 16, 64, 64]           32
│    └─Swish: 2-3                                       [1, 16, 64, 64]           --
│    └─Sequential: 2-4                                  [1, 96, 8, 8]             --
│    │    └─Sequential: 3-1                             [1, 16, 64, 64]           744
│    │    └─Sequential: 3-2                             [1, 24, 32, 32]           34,424
│    │    └─Sequential: 3-3                             [1, 40, 16, 16]           159,168
│    │    └─Sequential: 3-4                             [1, 80, 8, 8]             387,800
│    │    └─Sequential: 3-5                 

In [6]:
# Layer-by-layer summary of the wavelet variant for a single input pair.
torchinfo.summary(
    wavetracker,
    input_size=(
        (1, 3, 128, 128),  # template
        (1, 3, 256, 256),  # search
    ),
)

Layer (type:depth-idx)                                  Output Shape              Param #
WAVE_ET_Tracker                                         [1, 4, 16, 16]            --
├─ChildNet_FCN: 1-1                                     [1, 96, 8, 8]             --
│    └─Conv2d: 2-1                                      [1, 16, 64, 64]           432
│    └─BatchNorm2d: 2-2                                 [1, 16, 64, 64]           32
│    └─Swish: 2-3                                       [1, 16, 64, 64]           --
│    └─Sequential: 2-4                                  [1, 96, 8, 8]             --
│    │    └─Sequential: 3-1                             [1, 16, 64, 64]           744
│    │    └─Sequential: 3-2                             [1, 24, 32, 32]           34,424
│    │    └─Sequential: 3-3                             [1, 40, 16, 16]           159,168
│    │    └─Sequential: 3-4                             [1, 80, 8, 8]             387,800
│    │    └─Sequential: 3-5                 

In [7]:
# Run every tracker on the same (template, search) batch and keep the raw
# outputs, keyed by name (insertion order: et, cmt, wave).
#
# This is inference only:
#   * eval() stops the BatchNorm2d layers (see the summaries above) from
#     updating their running statistics with the random dummy inputs;
#   * no_grad() avoids building autograd graphs nobody will backprop through.
# NOTE(review): eval() leaves the models in eval mode afterwards; call
# .train() explicitly before any later training cell.
DEVICE = torch.device("cpu")

trackers = {
    "et_output": ettracker,
    "cmt_output": cmttracker,
    "wave_output": wavetracker,
}

outputs = {}
with torch.no_grad():
    for output_name, tracker in trackers.items():
        # nn.Module.to() already moves every registered parameter and buffer,
        # so no separate per-parameter `.data` copy is needed.
        tracker.to(DEVICE).eval()
        outputs[output_name] = tracker(template.to(DEVICE), search.to(DEVICE))

In [8]:
# Names under which each tracker's outputs were stored.
output_names = outputs.keys()
output_names

dict_keys(['et_output', 'cmt_output', 'wave_output'])

In [9]:
# Print every tensor shape each tracker returned, labelled with its output key,
# so the three variants can be compared entry by entry.
# Each model's output is treated as a mapping of name -> tensor (the previous
# version iterated it with .values()).
for model_name, model_outputs in outputs.items():
    print(f"For Model {model_name}:")
    for output_key, tensor in model_outputs.items():
        print(f"  {output_key}: {tensor.shape}")

For Model et_output:
torch.Size([8, 1, 16, 16])
torch.Size([8, 4, 16, 16])
For Model cmt_output:
torch.Size([8, 1, 16, 16])
torch.Size([8, 4, 16, 16])
For Model wave_output:
torch.Size([8, 1, 16, 16])
torch.Size([8, 4, 16, 16])
