In [1]:
import torch
import torch.nn as nn

In [2]:
from backbone import Backbone
from neck import Neck

In [3]:
back_model = Backbone(hid_dim=96, layers=(2, 2, 2, 2), heads=(3, 6, 12, 24))

# Run one dummy forward pass so `back_out` and `feature_maps` exist for the
# shape-inspection and neck cells that follow; without it those cells raise
# NameError on a fresh kernel.
# - no_grad: this pass is only for inspecting shapes, so don't keep an autograd
#   graph for a 1024x1024 input alive in notebook globals. (no_grad rather than
#   inference_mode, so the outputs can still be fed to the neck with grad on.)
# - A locally seeded Generator makes the dummy input reproducible without
#   reseeding the global RNG.
with torch.no_grad():
    back_out, feature_maps = back_model(
        torch.randn(1, 3, 1024, 1024, generator=torch.Generator().manual_seed(0))
    )

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Display the shape of the backbone's main output alongside the shape of every
# tensor in `feature_maps`.
(back_out.shape, [fmap.shape for fmap in feature_maps])

In [4]:
# Neck built with the same hid_dim, stage depths and head counts as the backbone.
# NOTE(review): channels=768 presumably equals the backbone's final-stage width
# (96 * 8) — confirm against Backbone.
neck_model = Neck(
    hid_dim=96,
    layers=(2, 2, 2, 2),
    heads=(3, 6, 12, 24),
    channels=768,
)

In [None]:
# permute(0, 3, 1, 2) moves the last axis of back_out to position 1 before it is
# handed to the neck together with the backbone feature maps.
# NOTE(review): presumably a channels-last -> channels-first conversion;
# confirm the expected layout in neck.py.
neck_out = neck_model(
    back_out.permute(0, 3, 1, 2),
    feature_maps,
)

In [None]:
# Show the neck's output shape (Tensor.size() returns the same torch.Size as .shape).
neck_out.size()

In [5]:
from model import Model
from head import Head

In [6]:
# Single-class head on 96-channel features.
head_model = Head(
    in_channels=96,  # presumably the neck's output width (hid_dim) — confirm
    num_classes=1,
)

In [7]:
# Wrap backbone, neck and head in a single Model (passed positionally, in that order).
mod = Model(
    back_model,
    neck_model,
    head_model,
)

In [8]:
# Smoke-test the combined model on one dummy 1024x1024 RGB-shaped input.
# - no_grad: this is a shape check only, so don't retain an autograd graph for
#   the whole forward pass inside the `mod_out` global.
# - A locally seeded Generator makes the dummy input reproducible without
#   reseeding the global RNG (model weights are not seeded by this cell).
with torch.no_grad():
    mod_out = mod(
        torch.randn(1, 3, 1024, 1024, generator=torch.Generator().manual_seed(0))
    )

Input shape: torch.Size([1, 32, 32, 768]), Feature map shape: torch.Size([1, 32, 32, 768])
features:  torch.Size([1, 256, 256, 96])
torch.Size([1, 96, 256, 256])
1


In [11]:
from prettytable import PrettyTable

def count_parameters(model) -> int:
    """Print a per-parameter table of ``model``'s trainable weights and their total.

    Only parameters whose ``requires_grad`` flag is set are listed and counted.
    Rows follow the order of ``model.named_parameters()``.

    Args:
        model: object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        int: total number of elements across all trainable parameters
        (0 if there are none).
    """
    # (qualified parameter name, element count) for every trainable tensor.
    trainable = [
        (param_name, tensor.numel())
        for param_name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]

    summary = PrettyTable(["Modules", "Parameters"])
    for param_name, count in trainable:
        summary.add_row([param_name, count])

    grand_total = sum(count for _, count in trainable)
    print(summary)
    print(f"Total Trainable Params: {grand_total}")
    return grand_total
    
# Trainable-parameter breakdown for the backbone; the returned total is the cell output.
count_parameters(model=back_model)

+-------------------------------------------------------------+------------+
|                           Modules                           | Parameters |
+-------------------------------------------------------------+------------+
|       model.stage1.patch_partition.patch_merge.weight       |    4608    |
|        model.stage1.patch_partition.patch_merge.bias        |     96     |
|    model.stage1.layers.0.0.attention_block.fn.norm.weight   |     96     |
|     model.stage1.layers.0.0.attention_block.fn.norm.bias    |     96     |
|    model.stage1.layers.0.0.attention_block.fn.fn.pos_emb    |     9      |
|      model.stage1.layers.0.0.attention_block.fn.fn.tau      |     1      |
| model.stage1.layers.0.0.attention_block.fn.fn.to_qkv.weight |   27648    |
| model.stage1.layers.0.0.attention_block.fn.fn.to_out.weight |    9216    |
|  model.stage1.layers.0.0.attention_block.fn.fn.to_out.bias  |     96     |
|       model.stage1.layers.0.0.mlp_block.fn.norm.weight      |     96     |

20383856

In [12]:
# Trainable-parameter breakdown for the neck; the returned total is the cell output.
count_parameters(model=neck_model)

+-------------------------------------------------------------+------------+
|                           Modules                           | Parameters |
+-------------------------------------------------------------+------------+
|    model.stage1.layers.0.0.attention_block.fn.norm.weight   |    768     |
|     model.stage1.layers.0.0.attention_block.fn.norm.bias    |    768     |
|    model.stage1.layers.0.0.attention_block.fn.fn.pos_emb    |     9      |
|      model.stage1.layers.0.0.attention_block.fn.fn.tau      |     1      |
| model.stage1.layers.0.0.attention_block.fn.fn.to_qkv.weight |   221184   |
| model.stage1.layers.0.0.attention_block.fn.fn.to_out.weight |   73728    |
|  model.stage1.layers.0.0.attention_block.fn.fn.to_out.bias  |    768     |
|       model.stage1.layers.0.0.mlp_block.fn.norm.weight      |    768     |
|        model.stage1.layers.0.0.mlp_block.fn.norm.bias       |    768     |
|   model.stage1.layers.0.0.mlp_block.fn.fn.network.0.weight  |   294912   |

7106768