In [13]:
import torch as T
from torch import nn
from torchinfo import summary
import logging

In [2]:
from src.models import SwinTransformer

In [14]:
logging.basicConfig(level=logging.DEBUG)


In [43]:
model_args = dict(
    img_size=28, 
    patch_size=2, 
    in_chans=1, 
    num_classes=10, 
    global_pool='avg',
    embed_dim=64, 
    depths=[2,], 
    num_heads=[8,], 
    head_dim=None,
    window_size=2, 
    mlp_ratio=4., 
    qkv_bias=True,
    drop_rate=0.1, 
    attn_drop_rate=0.1, 
    drop_path_rate=0.1,
    norm_layer=nn.LayerNorm, 
    ape=False, 
    patch_norm=True
)

In [44]:
model = SwinTransformer(**model_args)

DEBUG:src.models.swin_transformer:Input for window partition shape - torch.Size([1, 14, 14, 1])
DEBUG:src.models.swin_transformer:trying to change shape to 7, 2, 7, 2, 1


In [46]:
model

SwinTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(1, 64, kernel_size=(2, 2), stride=(2, 2))
    (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  )
  (pos_drop): Dropout(p=0.1, inplace=False)
  (layers): Sequential(
    (0): BasicLayer(
      (blocks): Sequential(
        (0): SwinTransformerBlock(
          (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=64, out_features=192, bias=True)
            (attn_drop): Dropout(p=0.1, inplace=False)
            (proj): Linear(in_features=64, out_features=64, bias=True)
            (proj_drop): Dropout(p=0.1, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=64, out_features=256, bias=True)
            (act): GELU()
            (drop1): Dropout(p=0.1, i

In [45]:
summary(model)

Layer (type:depth-idx)                             Param #
SwinTransformer                                    --
├─PatchEmbed: 1-1                                  --
│    └─Conv2d: 2-1                                 320
│    └─LayerNorm: 2-2                              128
├─Dropout: 1-2                                     --
├─Sequential: 1-3                                  --
│    └─BasicLayer: 2-3                             --
│    │    └─Sequential: 3-1                        100,112
├─LayerNorm: 1-4                                   128
├─Linear: 1-5                                      650
Total params: 101,194
Trainable params: 101,194
Non-trainable params: 0

In [33]:
inp = T.rand((1,1,28, 28))

In [26]:
inp.shape

torch.Size([1, 1, 24, 24])

In [47]:
a = model.forward(inp)

DEBUG:src.models.swin_transformer:Input shape - torch.Size([1, 1, 28, 28])
DEBUG:src.models.swin_transformer:After patch embedding - torch.Size([1, 196, 64])
DEBUG:src.models.swin_transformer:After postitional embed dropout - torch.Size([1, 196, 64])
DEBUG:src.models.swin_transformer:Input for Basic Layer - torch.Size([1, 196, 64])
DEBUG:src.models.swin_transformer:Input for window partition shape - torch.Size([1, 14, 14, 64])
DEBUG:src.models.swin_transformer:trying to change shape to 7, 2, 7, 2, 64
DEBUG:src.models.swin_transformer:Input for window partition shape - torch.Size([1, 14, 14, 64])
DEBUG:src.models.swin_transformer:trying to change shape to 7, 2, 7, 2, 64
DEBUG:src.models.swin_transformer:Output from Block Layer - torch.Size([1, 196, 64])
DEBUG:src.models.swin_transformer:After Swin Layers - torch.Size([1, 196, 64])
DEBUG:src.models.swin_transformer:After Norm - torch.Size([1, 196, 64])
DEBUG:src.models.swin_transformer:After Forward feature - torch.Size([1, 196, 64])
  r

In [28]:
12 * 12

144

In [None]:
def build_model_with_cfg(
        model_cls: Callable,
        variant: str,
        pretrained: bool,
        pretrained_cfg: Optional[Dict] = None,
        model_cfg: Optional[Any] = None,
        feature_cfg: Optional[Dict] = None,
        pretrained_strict: bool = False,
        pretrained_filter_fn: Optional[Callable] = None,
        pretrained_custom_load: bool = False,
        kwargs_filter: Optional[Tuple[str]] = None,
        **kwargs):
    """ Build model with specified default_cfg and optional model_cfg
    This helper fn aids in the construction of a model including:
      * handling default_cfg and associated pretrained weight loading
      * passing through optional model_cfg for models with config based arch spec
      * features_only model adaptation
      * pruning config / model adaptation
    Args:
        model_cls (nn.Module): model class
        variant (str): model variant name
        pretrained (bool): load pretrained weights
        pretrained_cfg (dict): model's pretrained weight/task config
        model_cfg (Optional[Dict]): model's architecture config
        feature_cfg (Optional[Dict]: feature extraction adapter config
        pretrained_strict (bool): load pretrained weights strictly
        pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights
        pretrained_custom_load (bool): use custom load fn, to load numpy or other non PyTorch weights
        kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model
        **kwargs: model args passed through to model __init__
    """
    pruned = kwargs.pop('pruned', False)
    features = False
    feature_cfg = feature_cfg or {}

    # resolve and update model pretrained config and model kwargs
    pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=pretrained_cfg)
    update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter)
    pretrained_cfg.setdefault('architecture', variant)

    # Setup for feature extraction wrapper done at end of this fn
    if kwargs.pop('features_only', False):
        features = True
        feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
        if 'out_indices' in kwargs:
            feature_cfg['out_indices'] = kwargs.pop('out_indices')

    # Build the model
    model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
    model.pretrained_cfg = pretrained_cfg
    model.default_cfg = model.pretrained_cfg  # alias for backwards compat
    
    if pruned:
        model = adapt_model_from_file(model, variant)

    # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
    num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
    if pretrained:
        if pretrained_custom_load:
            # FIXME improve custom load trigger
            load_custom_pretrained(model, pretrained_cfg=pretrained_cfg)
        else:
            load_pretrained(
                model,
                pretrained_cfg=pretrained_cfg,
                num_classes=num_classes_pretrained,
                in_chans=kwargs.get('in_chans', 3),
                filter_fn=pretrained_filter_fn,
                strict=pretrained_strict)

    # Wrap the model in a feature extraction module if enabled
    if features:
        feature_cls = FeatureListNet
        if 'feature_cls' in feature_cfg:
            feature_cls = feature_cfg.pop('feature_cls')
            if isinstance(feature_cls, str):
                feature_cls = feature_cls.lower()
                if 'hook' in feature_cls:
                    feature_cls = FeatureHookNet
                elif feature_cls == 'fx':
                    feature_cls = FeatureGraphNet
                else:
                    assert False, f'Unknown feature class {feature_cls}'
        model = feature_cls(model, **feature_cfg)
        model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg)  # add back default_cfg
        model.default_cfg = model.pretrained_cfg  # alias for backwards compat
    
    return model