In [1]:
!pip install thop

Collecting thop
  Downloading thop-0.1.1.post2209072238-py3-none-any.whl.metadata (2.7 kB)
Downloading thop-0.1.1.post2209072238-py3-none-any.whl (15 kB)
Installing collected packages: thop
Successfully installed thop-0.1.1.post2209072238


In [2]:
import torch

num_classes = 15
IMAGE_SIZE = (224, 224)
# IMAGE_SIZE[0], IMAGE_SIZE[1]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
from thop import profile
from thop import clever_format
import torch

def display_params_flops(model):
    #params
    num_params = sum(p.numel() for p in model.parameters())
    num_params_millions = num_params / 1e6
    print(f"Number of parameters in millions: {num_params_millions:.2f} M")

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    num_params_millions = num_params / 1e6
    print(f"Number of trainable parameters in millions: {num_params_millions:.2f} M")


    #FLOPS
    input_size = (1, 1, IMAGE_SIZE[0], IMAGE_SIZE[1])  


    # Move the model to GPU if available
    if torch.cuda.is_available():
        model = model.cuda()

    # Use thop.profile to count FLOPs
    input_tensor = torch.randn(*input_size)
    if torch.cuda.is_available():
        input_tensor = input_tensor.cuda()
    flops, params = profile(model, inputs=(input_tensor,))

    # Convert FLOPs to gigaFLOPs and format the results
    flops, params = clever_format([flops, params], "%.2f")
    print(f"FLOPs: {flops}, Params: {params}")
    

### CNN Models

### 1. VGG-19

In [4]:
import torch
import torch.nn as nn
import torchvision.models as models


class VGG19Model(nn.Module):
    def __init__(self, num_classes, fine_tune=False):
        super(VGG19Model, self).__init__()

        # Add a convolutional layer at the top
        self.conv = nn.Conv2d(1, 3, kernel_size=3, stride=1, padding=1)  # Assuming input is grayscale (1 channel)

        self.vgg19 = models.vgg19(pretrained=True)
        
        if not fine_tune:
            # Freeze all layers except classifier layers
            for param in self.vgg19.parameters():
                param.requires_grad = False

            # Unfreeze the classifier layers
            for param in self.vgg19.classifier.parameters():
                param.requires_grad = True
            

        # Get the number of input features for the final fully connected layer
        in_features = self.vgg19.classifier[6].in_features

        # Replace the final fully connected layer with a new one for the specified number of classes
        self.vgg19.classifier[6] = nn.Linear(in_features, num_classes)
        
        

    def forward(self, x):
        x = self.conv(x)
        x = self.vgg19(x)
        return x
    
    
    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VGG19Model(num_classes, fine_tune=False) # change fine_tune as required
model.to(device)

# print(model)
print()

x = torch.randn(1, 1, IMAGE_SIZE[0], IMAGE_SIZE[1]).to(device)


output = model(x)
print("Model output's shape:", output.shape)
print(output) # logits 
display_params_flops(model)

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:03<00:00, 148MB/s]



Model output's shape: torch.Size([1, 15])
tensor([[ 0.1141, -0.1989,  0.0149,  0.3392,  0.0469, -0.3226,  0.3937,  0.6723,
         -0.3124, -0.3450,  0.0357, -0.4151, -0.0462,  0.2496,  0.0574]],
       device='cuda:0', grad_fn=<AddmmBackward0>)
Number of parameters in millions: 139.63 M
Number of trainable parameters in millions: 119.61 M
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
FLOPs: 19.63G, Params: 139.63M


### 2. EfficientNet

In [5]:
!pip install efficientnet_pytorch # for installing efficientnet model

import torch
import torch.nn as nn
import torchvision.models as models
from efficientnet_pytorch import EfficientNet

class EfficientNetModel(nn.Module):
    def __init__(self, num_classes, fine_tune=False):
        super(EfficientNetModel, self).__init__()
        self.efficientnet = EfficientNet.from_pretrained('efficientnet-b0', num_classes=num_classes, in_channels=1)
        
        if not fine_tune:
            for param in self.efficientnet.parameters():
                param.requires_grad = False

            for param in self.efficientnet._fc.parameters():
                param.requires_grad = True

        
        
    def forward(self, x):
        x = self.efficientnet(x)
        return x


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = EfficientNetModel(num_classes, fine_tune=False)
model.to(device)

# print(model)
print()

x = torch.randn(1, 1, IMAGE_SIZE[0], IMAGE_SIZE[1]).to(device)


output = model(x)
print("Model output's shape:", output.shape)
print(output) # logits 
display_params_flops(model)

Collecting efficientnet_pytorch
  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)
  Preparing metadata (setup.py) ... [?25l- done
Building wheels for collected packages: efficientnet_pytorch
  Building wheel for efficientnet_pytorch (setup.py) ... [?25l- done
[?25h  Created wheel for efficientnet_pytorch: filename=efficientnet_pytorch-0.7.1-py3-none-any.whl size=16428 sha256=5aa126a5656fd926a8c8974a568d3b2086ddf3843a5f110abab1d0e0006f16c4
  Stored in directory: /root/.cache/pip/wheels/03/3f/e9/911b1bc46869644912bda90a56bcf7b960f20b5187feea3baf
Successfully built efficientnet_pytorch
Installing collected packages: efficientnet_pytorch
Successfully installed efficientnet_pytorch-0.7.1


Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth" to /root/.cache/torch/hub/checkpoints/efficientnet-b0-355c32eb.pth
100%|██████████| 20.4M/20.4M [00:00<00:00, 155MB/s]


Loaded pretrained weights for efficientnet-b0

Model output's shape: torch.Size([1, 15])
tensor([[-0.0184, -0.1523,  0.0733,  0.0891, -0.0507, -0.0303, -0.0032, -0.0084,
          0.0310, -0.0577, -0.0176,  0.0120, -0.0025, -0.0212, -0.0970]],
       device='cuda:0', grad_fn=<AddmmBackward0>)
Number of parameters in millions: 4.03 M
Number of trainable parameters in millions: 0.02 M
[INFO] Register zero_ops() for <class 'torch.nn.modules.padding.ZeroPad2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
FLOPs: 27.03M, Params: 61.23K


### 3. ResNet

In [6]:
import torch
import torch.nn as nn
import torchvision.models as models

class ResNetModel(nn.Module):
    def __init__(self, num_classes, fine_tune=False):
        super(ResNetModel, self).__init__()        
         # Add a convolutional layer at the top
        self.conv = nn.Conv2d(1, 3, kernel_size=3, stride=1, padding=1)  # Assuming input is grayscale (1 channel)

        self.resnet = models.resnet152(pretrained=True)
        
        if not fine_tune:
            for param in self.resnet.parameters():
                param.requires_grad = False
            
            for param in self.resnet.fc.parameters():
                param.requires_grad = True
        

        in_features = self.resnet.fc.in_features

        # Replace the final fully connected layer with a new one for the specified number of classes
        self.resnet.fc = nn.Linear(in_features, num_classes)
        
        
    def forward(self, x):
        x = self.conv(x)
        x = self.resnet(x)
        return x
    
    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNetModel(num_classes, fine_tune=False)
model.to(device)

# print(model)
print()

x = torch.randn(1, 1, IMAGE_SIZE[0], IMAGE_SIZE[1]).to(device)


output = model(x)
print("Model output's shape:", output.shape)
print(output) # logits 
display_params_flops(model)

Downloading: "https://download.pytorch.org/models/resnet152-394f9c45.pth" to /root/.cache/torch/hub/checkpoints/resnet152-394f9c45.pth
100%|██████████| 230M/230M [00:01<00:00, 136MB/s]



Model output's shape: torch.Size([1, 15])
tensor([[ 0.3527,  0.2662, -0.2867,  0.5065,  0.1777, -0.1360,  0.1384,  0.3014,
          0.1933,  0.0919, -0.0363,  0.1002,  0.3396,  0.4503, -0.2354]],
       device='cuda:0', grad_fn=<AddmmBackward0>)
Number of parameters in millions: 58.17 M
Number of trainable parameters in millions: 0.03 M
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
FLOPs: 11.60G, Params: 58.17M


### 4. CoatNet ( 512 x 512 image size not compatible )

In [7]:
import torch
import torch.nn as nn
import torchvision.models as models

from torch import nn
from timm import create_model


class CoatNetModel(nn.Module):
    def __init__(self, num_classes, fine_tune=False):
        super(CoatNetModel, self).__init__()
        self.coatnet = create_model(
            'timm/coatnet_3_rw_224.sw_in12k', # accepts only 224x224 images
            pretrained=True,
            num_classes=num_classes,
            in_chans=1
        )
        
        if not fine_tune:
            for param in self.coatnet.parameters():
                param.requires_grad = False
            
            for param in self.coatnet.head.parameters():
                param.requires_grad = True
        

    def forward(self, x):
        x = self.coatnet(x)
        return x
    
    
    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CoatNetModel(num_classes, fine_tune=False)
model.to(device)

# print(model)
print()

x = torch.randn(1, 1, IMAGE_SIZE[0], IMAGE_SIZE[1]).to(device)


output = model(x)
print("Model output's shape:", output.shape)
print(output) # logits
display_params_flops(model)

model.safetensors:   0%|          | 0.00/727M [00:00<?, ?B/s]


Model output's shape: torch.Size([1, 15])
tensor([[ 1.0493,  0.4044,  0.6934, -2.3854,  2.4097,  0.5271, -2.4023, -2.9863,
         -0.0977,  1.1776, -2.9211, -1.2702, -1.6801,  2.0718, -0.2862]],
       device='cuda:0', grad_fn=<AddmmBackward0>)
Number of parameters in millions: 163.66 M
Number of trainable parameters in millions: 0.02 M
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_avgpool() for <class 'torch.nn.modules.pooling.AvgPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
FLOPs: 32.49G, Params: 163.28M


### 5. ConvNeXt

In [8]:
import torch
import torch.nn as nn
import torchvision.models as models

from torch import nn
from timm import create_model

class ConvNeXtModel(nn.Module):
    def __init__(self, num_classes, model_name="convnext_tiny", pretrained=True, fine_tune=False):
        super(ConvNeXtModel, self).__init__()
        self.convnext_model = create_model(
            model_name, 
            pretrained=pretrained,
            num_classes=num_classes,
            in_chans=1
        )
        
        if not fine_tune:
            for param in self.convnext_model.parameters():
                param.requires_grad = False
            
            for param in self.convnext_model.head.parameters():
                param.requires_grad = True
        
        
    def forward(self, x):
        x = self.convnext_model(x)
        return x
    
    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ConvNeXtModel(num_classes, fine_tune=False)
model.to(device)

# print(model)
print()

x = torch.randn(1, 1, IMAGE_SIZE[0], IMAGE_SIZE[1]).to(device)


output = model(x)
print("Model output's shape:", output.shape)
print(output) # logits 
display_params_flops(model)

model.safetensors:   0%|          | 0.00/114M [00:00<?, ?B/s]


Model output's shape: torch.Size([1, 15])
tensor([[-0.9079,  0.9408, -0.0791,  0.0431, -0.3826, -0.4962,  0.5288,  0.7123,
         -1.1503,  0.4666, -0.7218,  0.7928,  0.0454,  0.4059,  1.0176]],
       device='cuda:0', grad_fn=<AddmmBackward0>)
Number of parameters in millions: 27.83 M
Number of trainable parameters in millions: 0.01 M
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
FLOPs: 4.45G, Params: 27.81M


### 6. DenseNet

In [9]:
import torch
import torch.nn as nn
import torchvision.models as models

from torch import nn
from timm import create_model

class DenseNetModel(nn.Module):
    def __init__(self, num_classes, pretrained=True, fine_tune=False):
        super(DenseNetModel, self).__init__()
        self.densenet = create_model(
            'densenet121.tv_in1k', 
            pretrained=pretrained,
            num_classes=num_classes,
            in_chans=1
        )

        if not fine_tune:
            for param in self.densenet.parameters():
                param.requires_grad = False
            
            for param in self.densenet.global_pool.parameters():
                param.requires_grad = True
                
            for param in self.densenet.head_drop.parameters():
                param.requires_grad = True
                
            for param in self.densenet.classifier.parameters():
                param.requires_grad = True
        
        
    def forward(self, x):
        x = self.densenet(x)
        return x
    
    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DenseNetModel(num_classes, fine_tune=False)
model.to(device)

# print(model)
print()

x = torch.randn(1, 1, IMAGE_SIZE[0], IMAGE_SIZE[1]).to(device)


output = model(x)
print("Model output's shape:", output.shape)
print(output) # logits 
display_params_flops(model)

model.safetensors:   0%|          | 0.00/32.3M [00:00<?, ?B/s]


Model output's shape: torch.Size([1, 15])
tensor([[ 0.2303,  0.3872, -0.1187, -0.0159,  0.2146,  0.4215, -0.4641, -0.2141,
          0.3579, -0.0363, -0.2009,  0.7574, -0.1098, -0.0292, -0.1578]],
       device='cuda:0', grad_fn=<AddmmBackward0>)
Number of parameters in millions: 6.96 M
Number of trainable parameters in millions: 0.02 M
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_avgpool() for <class 'torch.nn.modules.pooling.AvgPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
FLOPs: 2.75G, Params: 6.

### Transformers

### 1. Swin Transformer (512 x 512 not compatible) only 224 or 384

In [10]:
import torch
import torch.nn as nn
import torchvision.models as models

from torch import nn
from timm import create_model


class SwinTransformerModel(nn.Module):
    def __init__(self, num_classes, fine_tune=False):
        super(SwinTransformerModel, self).__init__()
        self.swin = create_model(
            'swin_large_patch4_window7_224.ms_in22k', 
            pretrained=True,
            num_classes=num_classes,
            in_chans=1
        )
        
        if not fine_tune:
            for param in self.swin.parameters():
                param.requires_grad = False
            
            for param in self.swin.head.parameters():
                param.requires_grad = True
        

    def forward(self, x):
        x = self.swin(x)
        return x
    
    
    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SwinTransformerModel(num_classes, fine_tune=False)
model.to(device)

# print(model)
print()

x = torch.randn(1, 1, IMAGE_SIZE[0], IMAGE_SIZE[1]).to(device)


output = model(x)
print("Model output's shape:", output.shape)
print(output) # logits 
display_params_flops(model)

model.safetensors:   0%|          | 0.00/916M [00:00<?, ?B/s]


Model output's shape: torch.Size([1, 15])
tensor([[-0.2104,  0.0284, -0.3546, -0.2533,  0.2120,  0.7453, -0.1186, -0.7598,
          0.3066,  0.2589, -0.0577,  0.7392, -0.3076, -0.0241, -0.2682]],
       device='cuda:0', grad_fn=<AddmmBackward0>)
Number of parameters in millions: 195.01 M
Number of trainable parameters in millions: 0.02 M
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
FLOPs: 34.06G, Params: 194.92M


### 2. MViT (512 x 512 image not supported) 

In [11]:
import timm
import torch.nn as nn

class MViT(nn.Module):
    def __init__(self, num_classes, pretrained=True, fine_tune=False):
        super(MViT, self).__init__()
        self.mvit = timm.create_model('mvitv2_base.fb_in1k', 
           pretrained=pretrained, 
           num_classes=num_classes,
            in_chans=1)
        
        
        if not fine_tune:
            for param in self.mvit.parameters():
                param.requires_grad = False
            
            for param in self.mvit.head.parameters():
                param.requires_grad = True
        

    def forward(self, x):
        return self.mvit(x)
    

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MViT(num_classes, fine_tune=False)
model.to(device)

# print(model)
print()

x = torch.randn(1, 1, IMAGE_SIZE[0], IMAGE_SIZE[1]).to(device)


output = model(x)
print("Model output's shape:", output.shape)
print(output) # logits    
display_params_flops(model)

model.safetensors:   0%|          | 0.00/206M [00:00<?, ?B/s]


Model output's shape: torch.Size([1, 15])
tensor([[ 0.0954,  0.0023, -0.1967, -0.0303, -0.0759, -0.0975,  0.2859, -0.1323,
          0.3687,  0.3437, -0.0465,  0.2573, -0.2109, -0.0827,  0.0485]],
       device='cuda:0', grad_fn=<AddmmBackward0>)
Number of parameters in millions: 50.71 M
Number of trainable parameters in millions: 0.01 M
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
FLOPs: 8.83G, Params: 50.48M


### 3. DaViT

In [12]:
import torch
import torch.nn as nn
import torchvision.models as models

from torch import nn
from timm import create_model


class DaViT(nn.Module):
    def __init__(self, num_classes, fine_tune=False):
        super(DaViT, self).__init__()
        self.davit = create_model(
            'davit_base.msft_in1k', 
            pretrained=True,
            num_classes=num_classes,
            in_chans=1
        )
        

        if not fine_tune:
            for param in self.davit.parameters():
                param.requires_grad = False
            
            for param in self.davit.head.parameters():
                param.requires_grad = True
        
        

    def forward(self, x):
        x = self.davit(x)
        return x
    
    
    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DaViT(num_classes, fine_tune=False)
model.to(device)

# print(model)
print()

x = torch.randn(1, 1, IMAGE_SIZE[0], IMAGE_SIZE[1]).to(device)


output = model(x)
print("Model output's shape:", output.shape)
print(output) # logits 
display_params_flops(model)

pytorch_model.bin:   0%|          | 0.00/352M [00:00<?, ?B/s]


Model output's shape: torch.Size([1, 15])
tensor([[-0.2269,  0.1547, -0.0736,  0.1810,  0.1348, -0.2158, -0.2421,  0.1566,
          0.0277,  0.1135,  0.0463, -0.1975, -0.2434,  0.2884,  0.2682]],
       device='cuda:0', grad_fn=<AddmmBackward0>)
Number of parameters in millions: 86.93 M
Number of trainable parameters in millions: 0.02 M
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
FLOPs: 15.18G, Params: 86.88M


### 4. PVT

In [13]:
import timm
import torch.nn as nn


class PVT(nn.Module):
    def __init__(self, num_classes, pretrained=True, fine_tune=False):
        super(PVT, self).__init__()
        self.pvt = timm.create_model('pvt_v2_b5', 
           pretrained=pretrained, 
           num_classes=num_classes,
           in_chans=1)
        
#         model.pvt.head_drop
#         model.pvt.head

        if not fine_tune:
            for param in self.pvt.parameters():
                param.requires_grad = False
            
            for param in self.pvt.head_drop.parameters():
                param.requires_grad = True
            
            for param in self.pvt.head.parameters():
                param.requires_grad = True
        

    def forward(self, x):
        return self.pvt(x)

    
    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PVT(num_classes, fine_tune=False)
model.to(device)

# print(model)
print()

x = torch.randn(1, 1, IMAGE_SIZE[0], IMAGE_SIZE[1]).to(device)


output = model(x)
print("Model output's shape:", output.shape)
print(output) # logits 
display_params_flops(model)

model.safetensors:   0%|          | 0.00/328M [00:00<?, ?B/s]


Model output's shape: torch.Size([1, 15])
tensor([[-0.8003,  0.2739, -0.2441,  0.9140,  0.2193,  0.1956, -0.4293, -0.3385,
         -0.1421, -0.2328, -0.8976,  0.5361, -0.1937,  0.5554, -0.2802]],
       device='cuda:0', grad_fn=<AddmmBackward0>)
Number of parameters in millions: 81.44 M
Number of trainable parameters in millions: 0.01 M
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
FLOPs: 11.33G, Params: 81.38M


### 5. GC ViT (Not pretrained) 512 compatible version is very time consuming [1 hidden cell]

In [14]:
#!/usr/bin/env python3

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
# written by Ali Hatamizadeh and Pavlo Molchanov from NVResearch


import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_, DropPath, to_2tuple
from timm.models._registry import register_model
from timm.models._builder import build_model_with_cfg


def _cfg(url='', **kwargs):
    return {'url': url,
            'num_classes': num_classes, # set num_classes here
            'input_size': (1, IMAGE_SIZE[0], IMAGE_SIZE[1]),  # adjust input image size here
            'pool_size': None,
            'crop_pct': 0.875, 
            'interpolation': 'bicubic', 
            'fixed_input_size': True,
            'mean': (0.485, 0.456, 0.406), 
            'std': (0.229, 0.224, 0.225),
            **kwargs
            }


default_cfgs = {
    'gc_vit_xxtiny': _cfg(url='https://huggingface.co/nvidia/GCViT/resolve/main/gcvit_1k_xxtiny.pth.tar',
                          crop_pct=1.0, 
                          input_size=(3, 224, 224), 
                          crop_mode= 'center'),
    'gc_vit_xtiny': _cfg(url='https://huggingface.co/nvidia/GCViT/resolve/main/gcvit_1k_xtiny.pth.tar',
                         crop_pct=0.875, 
                         input_size=(3, 224, 224), 
                         crop_mode= 'center'),
    'gc_vit_tiny': _cfg(url='https://huggingface.co/nvidia/GCViT/resolve/main/gcvit_1k_tiny.pth.tar',
                        crop_pct=1.0, 
                        input_size=(3, 224, 224), 
                        crop_mode= 'center'),
    'gc_vit_tiny2': _cfg(url='https://huggingface.co/nvidia/GCViT/resolve/main/gcvit_1k_tiny2.pth.tar',
                         crop_pct=0.875, 
                         input_size=(3, 224, 224), 
                         crop_mode= 'center'),
    'gc_vit_small': _cfg(url='https://huggingface.co/nvidia/GCViT/resolve/main/gcvit_1k_small.pth.tar',
                         crop_pct=0.875, 
                         input_size=(3, 224, 224), 
                         crop_mode= 'center'),
    'gc_vit_small2': _cfg(url='https://huggingface.co/nvidia/GCViT/resolve/main/gcvit_1k_small2.pth.tar',
                          crop_pct=0.875, 
                          input_size=(3, 224, 224), 
                          crop_mode= 'center'),
    'gc_vit_base': _cfg(url='https://huggingface.co/nvidia/GCViT/resolve/main/gcvit_1k_base.pth.tar',
                        crop_pct=1.0, 
                        input_size=(3, 224, 224), 
                        crop_mode= 'center'),
    'gc_vit_large': _cfg(url='https://huggingface.co/nvidia/GCViT/resolve/main/gcvit_1k_large.pth.tar',
                         crop_pct=1.0, 
                         input_size=(3, 224, 224), 
                         crop_mode= 'center'),
    'gc_vit_large_224_21k': _cfg(url='https://drive.google.com/uc?export=download&id=1maGDr6mJkLyRTUkspMzCgSlhDzNRFGEf', 
                                 crop_pct=1.0, 
                                 input_size=(3, 224, 224), 
                                 crop_mode= 'center'),
    'gc_vit_large_384_21k': _cfg(url='https://huggingface.co/nvidia/GCViT/resolve/main/gcvit_21k_large_384.pth.tar', 
                                 crop_pct=1.0, 
                                 input_size=(3, 384, 384), 
                                 crop_mode='squash'),
    'gc_vit_large_512_21k': _cfg(url='https://huggingface.co/nvidia/GCViT/resolve/main/gcvit_21k_large_512.pth.tar', 
                                 crop_pct=1.0, 
                                 input_size=(3, 512, 512), 
                                 crop_mode='squash'),                             
}


def _to_channel_last(x):
    """
    Args:
        x: (B, C, H, W)

    Returns:
        x: (B, H, W, C)
    """
    return x.permute(0, 2, 3, 1)


def _to_channel_first(x):
    """
    Args:
        x: (B, H, W, C)

    Returns:
        x: (B, C, H, W)
    """
    return x.permute(0, 3, 1, 2)


def window_partition(x, window_size, h_w, w_w):
    """
    Args:
        x: (B, H, W, C)
        window_size: window size

    Returns:
        local window features (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, h_w, window_size, w_w, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W, h_w, w_w, B):
    """
    Args:
        windows: local window features (num_windows*B, window_size, window_size, C)
        window_size: Window size
        H: Height of image
        W: Width of image

    Returns:
        x: (B, H, W, C)
    """
    # B = int(windows.shape[0] // (H * W // window_size // window_size))
    x = windows.view(B, h_w, w_w, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class Mlp(nn.Module):
    """
    Multi-Layer Perceptron (MLP) block
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        """
        Args:
            in_features: input features dimension.
            hidden_features: hidden features dimension.
            out_features: output features dimension.
            act_layer: activation function.
            drop: dropout rate.
        """

        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class SE(nn.Module):
    """
    Squeeze and excitation block
    """

    def __init__(self,
                 inp,
                 oup,
                 expansion=0.25):
        """
        Args:
            inp: input features dimension.
            oup: output features dimension.
            expansion: expansion ratio.
        """

        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(oup, int(inp * expansion), bias=False),
            nn.GELU(),
            nn.Linear(int(inp * expansion), oup, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y


class ReduceSize(nn.Module):
    """
    Down-sampling block based on: "Hatamizadeh et al.,
    Global Context Vision Transformers <https://arxiv.org/abs/2206.09959>"
    """

    def __init__(self,
                 dim,
                 norm_layer=nn.LayerNorm,
                 keep_dim=False):
        """
        Args:
            dim: feature size dimension.
            norm_layer: normalization layer.
            keep_dim: bool argument for maintaining the resolution.
        """

        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(dim, dim, 3, 1, 1,
                      groups=dim, bias=False),
            nn.GELU(),
            SE(dim, dim),
            nn.Conv2d(dim, dim, 1, 1, 0, bias=False),
        )
        if keep_dim:
            dim_out = dim
        else:
            dim_out = 2 * dim
        self.reduction = nn.Conv2d(dim, dim_out, 3, 2, 1, bias=False)
        self.norm2 = norm_layer(dim_out)
        self.norm1 = norm_layer(dim)

    def forward(self, x):
        x = x.contiguous()
        x = self.norm1(x)
        x = _to_channel_first(x)
        x = x + self.conv(x)
        x = self.reduction(x)
        x = _to_channel_last(x)
        x = self.norm2(x)
        return x


class PatchEmbed(nn.Module):
    """
    Patch embedding block based on: "Hatamizadeh et al.,
    Global Context Vision Transformers <https://arxiv.org/abs/2206.09959>"
    """

    def __init__(self, in_chans=3, dim=96):
        """
        Args:
            in_chans: number of input channels.
            dim: feature size dimension.
        """

        super().__init__()
        self.proj = nn.Conv2d(in_chans, dim, 3, 2, 1)
        self.conv_down = ReduceSize(dim=dim, keep_dim=True)

    def forward(self, x):
        x = self.proj(x)
        x = _to_channel_last(x)
        x = self.conv_down(x)
        return x


class FeatExtract(nn.Module):
    """
    Feature extraction block based on: "Hatamizadeh et al.,
    Global Context Vision Transformers <https://arxiv.org/abs/2206.09959>"
    """

    def __init__(self, dim, keep_dim=False):
        """
        Args:
            dim: feature size dimension.
            keep_dim: bool argument for maintaining the resolution.
        """

        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(dim, dim, 3, 1, 1,
                      groups=dim, bias=False),
            nn.GELU(),
            SE(dim, dim),
            nn.Conv2d(dim, dim, 1, 1, 0, bias=False),
        )
        if not keep_dim:
            self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.keep_dim = keep_dim

    def forward(self, x):
        x = x.contiguous()
        x = x + self.conv(x)
        if not self.keep_dim:
            x = self.pool(x)
        return x


class WindowAttention(nn.Module):
    """
    Local window attention based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    """

    def __init__(self,
                 dim,
                 num_heads,
                 window_size,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 ):
        """
        Args:
            dim: feature size dimension.
            num_heads: number of attention head.
            window_size: window size.
            qkv_bias: bool argument for query, key, value learnable bias.
            qk_scale: bool argument to scaling query, key.
            attn_drop: attention dropout rate.
            proj_drop: output dropout rate.
        """

        super().__init__()
        window_size = (window_size, window_size)
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = torch.div(dim, num_heads, rounding_mode='floor')
        self.scale = qk_scale or head_dim ** -0.5
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, q_global):
        B_, N, C = x.shape
        head_dim = torch.div(C, self.num_heads, rounding_mode='floor')
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class WindowAttentionGlobal(nn.Module):
    """
    Global window attention based on: "Hatamizadeh et al.,
    Global Context Vision Transformers <https://arxiv.org/abs/2206.09959>"
    """

    def __init__(self,
                 dim,
                 num_heads,
                 window_size,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 ):
        """
        Args:
            dim: feature size dimension.
            num_heads: number of attention head.
            window_size: window size.
            qkv_bias: bool argument for query, key, value learnable bias.
            qk_scale: bool argument to scaling query, key.
            attn_drop: attention dropout rate.
            proj_drop: output dropout rate.
        """

        super().__init__()
        window_size = (window_size, window_size)
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = torch.div(dim, num_heads, rounding_mode='floor')
        self.scale = qk_scale or head_dim ** -0.5
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)
        self.qkv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, q_global):
        B_, N, C = x.shape
        B = q_global.shape[0]
        head_dim = torch.div(C, self.num_heads, rounding_mode='floor')
        B_dim = torch.div(B_, B, rounding_mode='floor')
        kv = self.qkv(x).reshape(B_, N, 2, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]
        q_global = q_global.repeat(1, B_dim, 1, 1, 1)
        q = q_global.reshape(B_, self.num_heads, N, head_dim)
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class GCViTBlock(nn.Module):
    """
    GCViT block based on: "Hatamizadeh et al.,
    Global Context Vision Transformers <https://arxiv.org/abs/2206.09959>"
    """

    def __init__(self,
                 dim,
                 input_resolution,
                 num_heads,
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 attention=WindowAttentionGlobal,
                 norm_layer=nn.LayerNorm,
                 layer_scale=None,
                 ):
        """
        Args:
            dim: feature size dimension.
            input_resolution: input image resolution.
            num_heads: number of attention head.
            window_size: window size.
            mlp_ratio: MLP ratio.
            qkv_bias: bool argument for query, key, value learnable bias.
            qk_scale: bool argument to scaling query, key.
            drop: dropout rate.
            attn_drop: attention dropout rate.
            drop_path: drop path rate.
            act_layer: activation function.
            attention: attention block type.
            norm_layer: normalization layer.
            layer_scale: layer scaling coefficient.
        """

        super().__init__()
        self.window_size = window_size
        self.norm1 = norm_layer(dim)

        self.attn = attention(dim,
                              num_heads=num_heads,
                              window_size=window_size,
                              qkv_bias=qkv_bias,
                              qk_scale=qk_scale,
                              attn_drop=attn_drop,
                              proj_drop=drop,
                              )

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        self.layer_scale = False
        if layer_scale is not None and type(layer_scale) in [int, float]:
            self.layer_scale = True
            self.gamma1 = nn.Parameter(layer_scale * torch.ones(dim), requires_grad=True)
            self.gamma2 = nn.Parameter(layer_scale * torch.ones(dim), requires_grad=True)
        else:
            self.gamma1 = 1.0
            self.gamma2 = 1.0

        inp_w = torch.div(input_resolution, window_size, rounding_mode='floor')
        self.num_windows = int(inp_w * inp_w)

    def forward(self, x, q_global):
        B, H, W, C = x.shape
        shortcut = x
        x = self.norm1(x)
        h_w = torch.div(H, self.window_size, rounding_mode='floor')
        w_w = torch.div(W, self.window_size, rounding_mode='floor')
        x_windows = window_partition(x, self.window_size, h_w, w_w)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        attn_windows = self.attn(x_windows, q_global)
        x = window_reverse(attn_windows, self.window_size, H, W, h_w, w_w, B)
        x = shortcut + self.drop_path(self.gamma1 * x)
        x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))
        return x


class GlobalQueryGen(nn.Module):
    """
    Global query generator based on: "Hatamizadeh et al.,
    Global Context Vision Transformers <https://arxiv.org/abs/2206.09959>"
    """

    def __init__(self,
                 dim,
                 input_resolution,
                 image_resolution,
                 window_size,
                 num_heads):
        """
        Args:
            dim: feature size dimension.
            input_resolution: input image resolution.
            window_size: window size.
            num_heads: number of heads.

        For instance, repeating log(56/7) = 3 blocks, with input window dimension 56 and output window dimension 7 at
        down-sampling ratio 2. Please check Fig.5 of GC ViT paper for details.
        """

        super().__init__()
        if input_resolution == image_resolution//4:
            self.to_q_global = nn.Sequential(
                FeatExtract(dim, keep_dim=False),
                FeatExtract(dim, keep_dim=False),
                FeatExtract(dim, keep_dim=False),
            )

        elif input_resolution == image_resolution//8:
            self.to_q_global = nn.Sequential(
                FeatExtract(dim, keep_dim=False),
                FeatExtract(dim, keep_dim=False),
            )

        elif input_resolution == image_resolution//16:

            if window_size == input_resolution:
                self.to_q_global = nn.Sequential(
                    FeatExtract(dim, keep_dim=True)
                )

            else:
                self.to_q_global = nn.Sequential(
                    FeatExtract(dim, keep_dim=True)
                )

        elif input_resolution == image_resolution//32:
            self.to_q_global = nn.Sequential(
                FeatExtract(dim, keep_dim=True)
            )

        self.resolution = input_resolution
        self.num_heads = num_heads
        self.N = window_size * window_size
        self.dim_head = torch.div(dim, self.num_heads, rounding_mode='floor')

    def forward(self, x):
        x = _to_channel_last(self.to_q_global(x))
        B = x.shape[0]
        x = x.reshape(B, 1, self.N, self.num_heads, self.dim_head).permute(0, 1, 3, 2, 4)
        return x


class GCViTLayer(nn.Module):
    """
    GCViT layer based on: "Hatamizadeh et al.,
    Global Context Vision Transformers <https://arxiv.org/abs/2206.09959>"
    """

    def __init__(self,
                 dim,
                 depth,
                 input_resolution,
                 image_resolution,
                 num_heads,
                 window_size,
                 downsample=True,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 layer_scale=None):
        """
        Args:
            dim: feature size dimension.
            depth: number of layers in each stage.
            input_resolution: input image resolution.
            window_size: window size in each stage.
            downsample: bool argument for down-sampling.
            mlp_ratio: MLP ratio.
            num_heads: number of heads in each stage.
            qkv_bias: bool argument for query, key, value learnable bias.
            qk_scale: bool argument to scaling query, key.
            drop: dropout rate.
            attn_drop: attention dropout rate.
            drop_path: drop path rate.
            norm_layer: normalization layer.
            layer_scale: layer scaling coefficient.
        """

        super().__init__()
        self.blocks = nn.ModuleList([
            GCViTBlock(dim=dim,
                       num_heads=num_heads,
                       window_size=window_size,
                       mlp_ratio=mlp_ratio,
                       qkv_bias=qkv_bias,
                       qk_scale=qk_scale,
                       attention=WindowAttention if (i % 2 == 0) else WindowAttentionGlobal,
                       drop=drop,
                       attn_drop=attn_drop,
                       drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                       norm_layer=norm_layer,
                       layer_scale=layer_scale,
                       input_resolution=input_resolution)
            for i in range(depth)])
        self.downsample = None if not downsample else ReduceSize(dim=dim, norm_layer=norm_layer)
        self.q_global_gen = GlobalQueryGen(dim, input_resolution, image_resolution, window_size, num_heads)

    def forward(self, x):
        q_global = self.q_global_gen(_to_channel_first(x))
        for blk in self.blocks:
            x = blk(x, q_global)
        if self.downsample is None:
            return x
        return self.downsample(x)


class GCViT(nn.Module):
    """
    GCViT based on: "Hatamizadeh et al.,
    Global Context Vision Transformers <https://arxiv.org/abs/2206.09959>"
    """

    def __init__(self,
                 dim,
                 depths,
                 window_size,
                 mlp_ratio,
                 num_heads,
                 resolution=224,
                 drop_path_rate=0.2,
                 in_chans=3,
                 num_classes=1000,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 norm_layer=nn.LayerNorm,
                 layer_scale=None,
                 **kwargs):
        """
        Args:
            dim: feature size dimension.
            depths: number of layers in each stage.
            window_size: window size in each stage.
            mlp_ratio: MLP ratio.
            num_heads: number of heads in each stage.
            resolution: input image resolution.
            drop_path_rate: drop path rate.
            in_chans: number of input channels.
            num_classes: number of classes.
            qkv_bias: bool argument for query, key, value learnable bias.
            qk_scale: bool argument to scaling query, key.
            drop_rate: dropout rate.
            attn_drop_rate: attention dropout rate.
            norm_layer: normalization layer.
            layer_scale: layer scaling coefficient.
        """
        super().__init__()

        num_features = int(dim * 2 ** (len(depths) - 1))
        self.num_classes = num_classes
        self.patch_embed = PatchEmbed(in_chans=in_chans, dim=dim)
        self.pos_drop = nn.Dropout(p=drop_rate)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        self.levels = nn.ModuleList()
        for i in range(len(depths)):
            level = GCViTLayer(dim=int(dim * 2 ** i),
                               depth=depths[i],
                               num_heads=num_heads[i],
                               window_size=window_size[i],
                               mlp_ratio=mlp_ratio,
                               qkv_bias=qkv_bias,
                               qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
                               norm_layer=norm_layer,
                               downsample=(i < len(depths) - 1),
                               layer_scale=layer_scale,
                               input_resolution=int(2 ** (-2 - i) * resolution),
                               image_resolution=resolution)
            self.levels.append(level)
        self.norm = norm_layer(num_features)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'rpb'}

    def forward_features(self, x):
        x = self.patch_embed(x)
        x = self.pos_drop(x)

        for level in self.levels:
            x = level(x)

        x = self.norm(x)
        x = _to_channel_first(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x



def _create_gc_vit(variant, pretrained=False, **kwargs):

    return build_model_with_cfg(
        GCViT,
        variant,
        pretrained,
        **kwargs,
    )


@register_model
def gc_vit_xxtiny(pretrained=False, **kwargs) -> GCViT:
    drop_path_rate = kwargs.pop("drop_path_rate", 0.2)
    model_kwargs = dict(depths=[2, 2, 6, 2],
                        num_heads=[2, 4, 8, 16],
                        window_size=[7, 7, 14, 7],
                        dim=64,
                        mlp_ratio=3,
                        drop_path_rate=drop_path_rate,
                        **kwargs
                        )
    model = _create_gc_vit('gc_vit_xxtiny', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def gc_vit_xtiny(pretrained=False, **kwargs) -> GCViT:
    drop_path_rate = kwargs.pop("drop_path_rate", 0.2)
    model_kwargs = dict(depths=[3, 4, 6, 5],
                        num_heads=[2, 4, 8, 16],
                        window_size=[7, 7, 14, 7],
                        dim=64,
                        mlp_ratio=3,
                        drop_path_rate=drop_path_rate,
                        **kwargs
                        )
    model = _create_gc_vit('gc_vit_xtiny', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def gc_vit_tiny(pretrained=False, **kwargs) -> GCViT:
    drop_path_rate = kwargs.pop("drop_path_rate", 0.2)
    model_kwargs = dict(depths=[3, 4, 19, 5],
                        num_heads=[2, 4, 8, 16],
                        window_size=[7, 7, 14, 7],
                        dim=64,
                        mlp_ratio=3,
                        drop_path_rate=drop_path_rate,
                        **kwargs
                        )
    model = _create_gc_vit('gc_vit_tiny', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def gc_vit_tiny2(pretrained=False, **kwargs) -> GCViT:
    drop_path_rate = kwargs.pop("drop_path_rate", 0.25)
    model_kwargs = dict(depths=[3, 4, 29, 5],
                        num_heads=[2, 4, 8, 16],
                        window_size=[7, 7, 14, 7],
                        dim=64,
                        mlp_ratio=3,
                        drop_path_rate=drop_path_rate,
                        **kwargs
                        )
    model = _create_gc_vit('gc_vit_tiny2', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def gc_vit_small(pretrained=False, **kwargs) -> GCViT:
    drop_path_rate = kwargs.pop("drop_path_rate", 0.3)
    model_kwargs = dict(depths=[3, 4, 19, 5],
                        num_heads=[3, 6, 12, 24],
                        window_size=[7, 7, 14, 7],
                        dim=96,
                        mlp_ratio=2,
                        drop_path_rate=drop_path_rate,
                        layer_scale=1e-5,
                        **kwargs
                        )
    model = _create_gc_vit('gc_vit_small', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def gc_vit_small2(pretrained=False, **kwargs) -> GCViT:
    drop_path_rate = kwargs.pop("drop_path_rate", 0.35)
    model_kwargs = dict(depths=[3, 4, 23, 5],
                        num_heads=[3, 6, 12, 24],
                        window_size=[7, 7, 14, 7],
                        dim=96,
                        mlp_ratio=3,
                        drop_path_rate=drop_path_rate,
                        layer_scale=1e-5,
                        **kwargs
                        )
    model = _create_gc_vit('gc_vit_small2', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def gc_vit_base(pretrained=False, **kwargs) -> GCViT:
    drop_path_rate = kwargs.pop("drop_path_rate", 0.5)
    model_kwargs = dict(depths=[3, 4, 19, 5],
                        num_heads=[4, 8, 16, 32],
                        window_size=[7, 7, 14, 7],
                        dim=128,
                        mlp_ratio=2,
                        drop_path_rate=drop_path_rate,
                        layer_scale=1e-5,
                        **kwargs
                        )
    
    
    model = _create_gc_vit('gc_vit_base', pretrained=pretrained, **model_kwargs)
#     in_features = model.head.in_features 
#     out_features = 5 # num_classes , set here
    
#     model.head = nn.Linear(in_features, out_features, bias=True)

    return model


@register_model
def gc_vit_large(pretrained=False, **kwargs) -> GCViT:
    drop_path_rate = kwargs.pop("drop_path_rate", 0.5)
    model_kwargs = dict(depths=[3, 4, 19, 5],
                        num_heads=[6, 12, 24, 48],
                        window_size=[7, 7, 14, 7],
                        dim=192,
                        mlp_ratio=2,
                        drop_path_rate=drop_path_rate,
                        layer_scale=1e-5,
                        **kwargs
                        )
    model = _create_gc_vit('gc_vit_large', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def gc_vit_large_224_21k(pretrained=False, **kwargs) -> GCViT:
    drop_path_rate = kwargs.pop("drop_path_rate", 0.5)
    model_kwargs = dict(depths=[3, 4, 19, 5],
                        num_heads=[6, 12, 24, 48],
                        window_size=[7, 7, 14, 7],
                        dim=192,
                        mlp_ratio=2,
                        drop_path_rate=drop_path_rate,
                        layer_scale=1e-5,
                        **kwargs
                        )
    model = _create_gc_vit('gc_vit_large_224_21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def gc_vit_large_384_21k(pretrained=False, **kwargs) -> GCViT:
    drop_path_rate = kwargs.pop("drop_path_rate", 0.1)
    model_kwargs = dict(depths=[3, 4, 19, 5],
                        num_heads=[6, 12, 24, 48],
                        window_size=[12, 12, 24, 12],
                        dim=192,
                        mlp_ratio=2,
                        drop_path_rate=drop_path_rate,
                        layer_scale=1e-5,
                        **kwargs
                        )
    model = _create_gc_vit('gc_vit_large_384_21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def gc_vit_large_512_21k(pretrained=False, **kwargs) -> GCViT:
    drop_path_rate = kwargs.pop("drop_path_rate", 0.1)
    model_kwargs = dict(depths=[3, 4, 19, 5],
                        num_heads=[6, 12, 24, 48],
                        window_size=[16, 16, 32, 16],
                        dim=192,
                        mlp_ratio=2,
                        drop_path_rate=drop_path_rate,
                        layer_scale=1e-5,
                        **kwargs
                        )
    model = _create_gc_vit('gc_vit_large_512_21k', pretrained=pretrained, **model_kwargs)
    return model

In [15]:
import torch

class GCViT_Model(nn.Module): 
    def __init__(self, num_classes):
        super(GCViT_Model, self).__init__()
        # Add a convolutional layer at the top
        self.conv = nn.Conv2d(1, 3, kernel_size=3, stride=1, padding=1)  # Assuming input is grayscale (1 channel)

#         self.gcvit = gc_vit_large_512_21k(pretrained=False) # for image_size 512 x 512
        self.gcvit = gc_vit_base(pretrained=False) # for image_size 224 x 224


    def forward(self, x):
        x = self.conv(x)
        return self.gcvit(x)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")    
model = GCViT_Model(num_classes)
model.to(device)

# print(model)
print()

x = torch.randn(1, 1, IMAGE_SIZE[0], IMAGE_SIZE[1]).to(device)


output = model(x)
print("Model output's shape:", output.shape)
print(output) # logits 
display_params_flops(model)


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]



Model output's shape: torch.Size([1, 15])
tensor([[-0.1446, -0.0298, -0.1130, -0.2961, -0.3105,  0.1483,  0.3173, -0.1134,
         -0.0269, -0.2434, -0.1346,  0.1015,  0.0903, -0.4193, -0.3536]],
       device='cuda:0', grad_fn=<AddmmBackward0>)
Number of parameters in millions: 89.31 M
Number of trainable parameters in millions: 89.31 M
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
FL

### 6. EfficientViT

In [16]:
import timm
import torch.nn as nn

# efficientvit_b0.r224_in1k # is very fast 
# 'efficientvit_l3.r384_in1k # time consuming 

class EfficientViT(nn.Module):
    def __init__(self, num_classes, pretrained=True, fine_tune=False):
        super(EfficientViT, self).__init__()
        self.efficientvit = timm.create_model('efficientvit_b0.r224_in1k', 
           pretrained=pretrained,
           num_classes=num_classes,
           in_chans=1)
        
#         model.efficientvit.head
        if not fine_tune:
            for param in self.efficientvit.parameters():
                param.requires_grad = False
            
            for param in self.efficientvit.head.parameters():
                param.requires_grad = True
            


    def forward(self, x):
        return self.efficientvit(x)
    
    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = EfficientViT(num_classes, fine_tune=False)
model.to(device)

# print(model)
print()

x = torch.randn(1, 1, IMAGE_SIZE[0], IMAGE_SIZE[1]).to(device)



output = model(x)
print("Model output's shape:", output.shape)
print(output) # logits 
display_params_flops(model)

model.safetensors:   0%|          | 0.00/13.7M [00:00<?, ?B/s]


Model output's shape: torch.Size([1, 15])
tensor([[ 0.5227,  0.4816, -0.1879,  0.4157,  0.4021,  0.2044,  0.2763, -0.0605,
         -0.4593, -0.0542, -0.1839, -0.3059, -0.0949, -0.5918, -0.9585]],
       device='cuda:0', grad_fn=<AddmmBackward0>)
Number of parameters in millions: 2.15 M
Number of trainable parameters in millions: 1.47 M
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'

### 7. MaxViT

In [17]:
import timm
import torch.nn as nn

# maxvit_tiny_tf_224.in1k # for image size 224 x 224
# maxvit_tiny_tf_512.in1k # for image size 512 x 512



class MaxVit(nn.Module):
    def __init__(self, num_classes, pretrained=True, fine_tune=False):
        super(MaxVit, self).__init__()
        self.maxvit = timm.create_model('maxvit_tiny_tf_224.in1k', 
           pretrained=pretrained,
           num_classes=num_classes,
           in_chans=1)
        
        
        if not fine_tune:
            for param in self.maxvit.parameters():
                param.requires_grad = False
            
            for param in self.maxvit.head.parameters():
                param.requires_grad = True
            


    def forward(self, x):
        return self.maxvit(x)
    
    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MaxVit(num_classes, fine_tune=False)
model.to(device)

# print(model)
print()

x = torch.randn(1, 1, IMAGE_SIZE[0], IMAGE_SIZE[1]).to(device)


output = model(x)
print("Model output's shape:", output.shape)
print(output) # logits 
display_params_flops(model)

model.safetensors:   0%|          | 0.00/124M [00:00<?, ?B/s]


Model output's shape: torch.Size([1, 15])
tensor([[-0.1084,  0.0446,  0.2134, -0.2865,  0.1619, -0.2049,  0.5878, -0.1520,
          0.0856,  0.0323,  0.0307, -0.1480, -0.2198,  0.0066,  0.0194]],
       device='cuda:0', grad_fn=<AddmmBackward0>)
Number of parameters in millions: 30.41 M
Number of trainable parameters in millions: 0.27 M
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
FLOPs: 5.33G, Params: 30.28M


### Proposed #1 DaViT_Unetr

In [18]:
!pip install monai

Collecting monai
  Downloading monai-1.3.0-202310121228-py3-none-any.whl.metadata (10 kB)
Downloading monai-1.3.0-202310121228-py3-none-any.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: monai
Successfully installed monai-1.3.0


In [19]:
import timm
import torch.nn as nn
from monai.networks.blocks.dynunet_block import UnetOutBlock
from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock


class DaViT_UnetR_Model(nn.Module):
    def __init__(self, num_classes, pretrained=True, fine_tune=False):
        super(DaViT_UnetR_Model, self).__init__()
        
#         self.davit = timm.create_model('davit_small.msft_in1k', pretrained=pretrained, num_classes=num_classes)
        
        self.davit = timm.create_model('davit_small.msft_in1k', pretrained=pretrained, features_only=True, in_chans=1)
        
        if not fine_tune:
            for param in self.davit.parameters():
                param.requires_grad = False
        
        
        spatial_dims = 2 
        in_channels = 1 # R,G,B
        feature_size = 96
        norm_name = "instance"
        hidden_size = 96
        res_block = True
        conv_block = False

        self.encoder1 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.encoder2 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size,
            out_channels=feature_size * 2,
            num_layer=2,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=1,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.encoder3 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size*2,
            out_channels=feature_size * 4,
            num_layer=1,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=1,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.encoder4 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size*4,
            out_channels=feature_size * 8,
            num_layer=0,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=1,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.decoder5 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size * 8,
            out_channels=feature_size * 8,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder4 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder3 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 4,
            out_channels=feature_size * 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder2 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 2,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        
        self.conv = nn.Sequential(
            nn.Conv2d(feature_size, 78, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(78, 50, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        
        # Classifier layer with convolution
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(2450, 1024),  # (DYNAMIC)Adjust the input size based on the output size of the convolutional layer
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, num_classes)
        )

    
    def forward(self, x_in):
        
        hidden_states_out = self.davit(x_in) # returns 4 lists
#         print("Length of hidden states from DaViT:", len(hidden_states_out))
#         for i in hidden_states_out:
#             print(i.shape)
#         print()


        enc1 = self.encoder1(x_in)
#         print("output from encoder1:", enc1.shape)
        
        x2 = hidden_states_out[0]
        enc2 = self.encoder2(x2)
#         print("output from encoder2:", enc2.shape)
        
        x3 = hidden_states_out[1]
        enc3 = self.encoder3(x3)
#         print("output from encoder3:", enc3.shape)
        
        
        x4 = hidden_states_out[2]
        enc4 = self.encoder4(x4)
#         print("output from encoder4:", enc4.shape)
        
#         print("All encoders OK\n")
        
        dec4 = hidden_states_out[3]
#         print("Input to decoder5:", dec4.shape, enc4.shape)
        dec3 = self.decoder5(dec4, enc4)
#         print("output from decoder5:", dec3.shape)
        
#         print("Input to decoder4:", dec3.shape, enc3.shape)
        dec2 = self.decoder4(dec3, enc3)
#         print("output from decoder4:", dec2.shape)
        
#         print("Input to decoder3:", dec2.shape, enc2.shape)
        dec1 = self.decoder3(dec2, enc2)
#         print("output from decoder3:", dec1.shape)
        
#         print("Input to decoder2:", dec1.shape, enc1.shape)
        out = self.decoder2(dec1, enc1) 
#         print("output from decoder2:", out.shape)
        

        
        conv_out = self.conv(out)
#         print(f"conv_out_shape:{conv_out.shape}")

        return self.classifier(conv_out)



device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DaViT_UnetR_Model(num_classes, fine_tune=False)
model.to(device)

# print(model)
print()

x = torch.randn(1, 1, IMAGE_SIZE[0], IMAGE_SIZE[1]).to(device)


output = model(x)
print("Model output's shape:", output.shape)
print(output) # logits 
display_params_flops(model)

2024-04-20 16:41:57.781980: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-20 16:41:57.782145: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-20 16:41:57.874898: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


pytorch_model.bin:   0%|          | 0.00/199M [00:00<?, ?B/s]


Model output's shape: torch.Size([1, 15])
tensor([[ 0.1955,  0.0616,  0.0698, -0.0919,  0.2435,  0.1863,  0.0584, -0.0458,
         -0.0207, -0.0806, -0.0196,  0.0023,  0.0011, -0.0114,  0.1238]],
       device='cuda:0', grad_fn=<AddmmBackward0>)
Number of parameters in millions: 78.91 M
Number of trainable parameters in millions: 29.95 M
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_relu() for <class 'torch.nn.modules.activation.LeakyReLU'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.instancenorm.InstanceNorm2d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose2d'>.
[INF

### Proposed #2 GWA_DaViT Model (10 cells) # works only with 224 x 224

In [20]:
""" PyTorch Feature Extraction Helpers

A collection of classes, functions, modules to help extract features from models
and provide a common interface for describing them.

The return_layers, module re-writing idea inspired by torchvision IntermediateLayerGetter
https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py

Hacked together by / Copyright 2020 Ross Wightman
"""
from collections import OrderedDict, defaultdict
from copy import deepcopy
from functools import partial
from typing import Dict, List, Sequence, Tuple, Union

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from timm.layers import Format


__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']


class FeatureInfo:

    def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
        prev_reduction = 1
        for i, fi in enumerate(feature_info):
            # sanity check the mandatory fields, there may be additional fields depending on the model
            assert 'num_chs' in fi and fi['num_chs'] > 0
            assert 'reduction' in fi and fi['reduction'] >= prev_reduction
            prev_reduction = fi['reduction']
            assert 'module' in fi
            fi.setdefault('index', i)
        self.out_indices = out_indices
        self.info = feature_info

    def from_other(self, out_indices: Tuple[int]):
        return FeatureInfo(deepcopy(self.info), out_indices)

    def get(self, key, idx=None):
        """ Get value by key at specified index (indices)
        if idx == None, returns value for key at each output index
        if idx is an integer, return value for that feature module index (ignoring output indices)
        if idx is a list/tupple, return value for each module index (ignoring output indices)
        """
        if idx is None:
            return [self.info[i][key] for i in self.out_indices]
        if isinstance(idx, (tuple, list)):
            return [self.info[i][key] for i in idx]
        else:
            return self.info[idx][key]

    def get_dicts(self, keys=None, idx=None):
        """ return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
        """
        if idx is None:
            if keys is None:
                return [self.info[i] for i in self.out_indices]
            else:
                return [{k: self.info[i][k] for k in keys} for i in self.out_indices]
        if isinstance(idx, (tuple, list)):
            return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx]
        else:
            return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}

    def channels(self, idx=None):
        """ feature channels accessor
        """
        return self.get('num_chs', idx)

    def reduction(self, idx=None):
        """ feature reduction (output stride) accessor
        """
        return self.get('reduction', idx)

    def module_name(self, idx=None):
        """ feature module name accessor
        """
        return self.get('module', idx)

    def __getitem__(self, item):
        return self.info[item]

    def __len__(self):
        return len(self.info)


class FeatureHooks:
    """ Feature Hook Helper

    This module helps with the setup and extraction of hooks for extracting features from
    internal nodes in a model by node name.

    FIXME This works well in eager Python but needs redesign for torchscript.
    """

    def __init__(
            self,
            hooks: Sequence[str],
            named_modules: dict,
            out_map: Sequence[Union[int, str]] = None,
            default_hook_type: str = 'forward',
    ):
        # setup feature hooks
        self._feature_outputs = defaultdict(OrderedDict)
        modules = {k: v for k, v in named_modules}
        for i, h in enumerate(hooks):
            hook_name = h['module']
            m = modules[hook_name]
            hook_id = out_map[i] if out_map else hook_name
            hook_fn = partial(self._collect_output_hook, hook_id)
            hook_type = h.get('hook_type', default_hook_type)
            if hook_type == 'forward_pre':
                m.register_forward_pre_hook(hook_fn)
            elif hook_type == 'forward':
                m.register_forward_hook(hook_fn)
            else:
                assert False, "Unsupported hook type"

    def _collect_output_hook(self, hook_id, *args):
        x = args[-1]  # tensor we want is last argument, output for fwd, input for fwd_pre
        if isinstance(x, tuple):
            x = x[0]  # unwrap input tuple
        self._feature_outputs[x.device][hook_id] = x

    def get_output(self, device) -> Dict[str, torch.tensor]:
        output = self._feature_outputs[device]
        self._feature_outputs[device] = OrderedDict()  # clear after reading
        return output


def _module_list(module, flatten_sequential=False):
    # a yield/iter would be better for this but wouldn't be compatible with torchscript
    ml = []
    for name, module in module.named_children():
        if flatten_sequential and isinstance(module, nn.Sequential):
            # first level of Sequential containers is flattened into containing model
            for child_name, child_module in module.named_children():
                combined = [name, child_name]
                ml.append(('_'.join(combined), '.'.join(combined), child_module))
        else:
            ml.append((name, name, module))
    return ml


def _get_feature_info(net, out_indices):
    feature_info = getattr(net, 'feature_info')
    if isinstance(feature_info, FeatureInfo):
        return feature_info.from_other(out_indices)
    elif isinstance(feature_info, (list, tuple)):
        return FeatureInfo(net.feature_info, out_indices)
    else:
        assert False, "Provided feature_info is not valid"


def _get_return_layers(feature_info, out_map):
    module_names = feature_info.module_name()
    return_layers = {}
    for i, name in enumerate(module_names):
        return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
    return return_layers


class FeatureDictNet(nn.ModuleDict):
    """ Feature extractor with OrderedDict return

    Wrap a model and extract features as specified by the out indices, the network is
    partially re-built from contained modules.

    There is a strong assumption that the modules have been registered into the model in the same
    order as they are used. There should be no reuse of the same nn.Module more than once, including
    trivial modules like `self.relu = nn.ReLU`.

    Only submodules that are directly assigned to the model class (`model.feature1`) or at most
    one Sequential container deep (`model.features.1`, with flatten_sequent=True) can be captured.
    All Sequential containers that are directly assigned to the original model will have their
    modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
    """
    def __init__(
            self,
            model: nn.Module,
            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
            out_map: Sequence[Union[int, str]] = None,
            output_fmt: str = 'NCHW',
            feature_concat: bool = False,
            flatten_sequential: bool = False,
    ):
        """
        Args:
            model: Model from which to extract features.
            out_indices: Output indices of the model features to extract.
            out_map: Return id mapping for each output index, otherwise str(index) is used.
            feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
                first element e.g. `x[0]`
            flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
        """
        super(FeatureDictNet, self).__init__()
        self.feature_info = _get_feature_info(model, out_indices)
        self.output_fmt = Format(output_fmt)
        self.concat = feature_concat
        self.grad_checkpointing = False
        self.return_layers = {}

        return_layers = _get_return_layers(self.feature_info, out_map)
        modules = _module_list(model, flatten_sequential=flatten_sequential)
        remaining = set(return_layers.keys())
        layers = OrderedDict()
        for new_name, old_name, module in modules:
            layers[new_name] = module
            if old_name in remaining:
                # return id has to be consistently str type for torchscript
                self.return_layers[new_name] = str(return_layers[old_name])
                remaining.remove(old_name)
            if not remaining:
                break
        assert not remaining and len(self.return_layers) == len(return_layers), \
            f'Return layers ({remaining}) are not present in model'
        self.update(layers)

    def set_grad_checkpointing(self, enable: bool = True):
        self.grad_checkpointing = enable

    def _collect(self, x) -> (Dict[str, torch.Tensor]):
        out = OrderedDict()
        for i, (name, module) in enumerate(self.items()):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # Skipping checkpoint of first module because need a gradient at input
                # Skipping last because networks with in-place ops might fail w/ checkpointing enabled
                # NOTE: first_or_last module could be static, but recalc in is_scripting guard to avoid jit issues
                first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
                x = module(x) if first_or_last_module else checkpoint(module, x)
            else:
                x = module(x)

            if name in self.return_layers:
                out_id = self.return_layers[name]
                if isinstance(x, (tuple, list)):
                    # If model tap is a tuple or list, concat or select first element
                    # FIXME this may need to be more generic / flexible for some nets
                    out[out_id] = torch.cat(x, 1) if self.concat else x[0]
                else:
                    out[out_id] = x
        return out

    def forward(self, x) -> Dict[str, torch.Tensor]:
        return self._collect(x)


class FeatureListNet(FeatureDictNet):
    """ Feature extractor with list return

    A specialization of FeatureDictNet that always returns features as a list (values() of dict).
    """
    def __init__(
            self,
            model: nn.Module,
            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
            output_fmt: str = 'NCHW',
            feature_concat: bool = False,
            flatten_sequential: bool = False,
    ):
        """
        Args:
            model: Model from which to extract features.
            out_indices: Output indices of the model features to extract.
            feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
                first element e.g. `x[0]`
            flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
        """
        super().__init__(
            model,
            out_indices=out_indices,
            output_fmt=output_fmt,
            feature_concat=feature_concat,
            flatten_sequential=flatten_sequential,
        )

    def forward(self, x) -> (List[torch.Tensor]):
        return list(self._collect(x).values())


class FeatureHookNet(nn.ModuleDict):
    """ FeatureHookNet

    Wrap a model and extract features specified by the out indices using forward/forward-pre hooks.

    If `no_rewrite` is True, features are extracted via hooks without modifying the underlying
    network in any way.

    If `no_rewrite` is False, the model will be re-written as in the
    FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one.

    FIXME this does not currently work with Torchscript, see FeatureHooks class
    """
    def __init__(
            self,
            model: nn.Module,
            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
            out_map: Sequence[Union[int, str]] = None,
            return_dict: bool = False,
            output_fmt: str = 'NCHW',
            no_rewrite: bool = False,
            flatten_sequential: bool = False,
            default_hook_type: str = 'forward',
    ):
        """

        Args:
            model: Model from which to extract features.
            out_indices: Output indices of the model features to extract.
            out_map: Return id mapping for each output index, otherwise str(index) is used.
            return_dict: Output features as a dict.
            no_rewrite: Enforce that model is not re-written if True, ie no modules are removed / changed.
                flatten_sequential arg must also be False if this is set True.
            flatten_sequential: Re-write modules by flattening first two levels of nn.Sequential containers.
            default_hook_type: The default hook type to use if not specified in model.feature_info.
        """
        super().__init__()
        assert not torch.jit.is_scripting()
        self.feature_info = _get_feature_info(model, out_indices)
        self.return_dict = return_dict
        self.output_fmt = Format(output_fmt)
        self.grad_checkpointing = False

        layers = OrderedDict()
        hooks = []
        if no_rewrite:
            assert not flatten_sequential
            if hasattr(model, 'reset_classifier'):  # make sure classifier is removed?
                model.reset_classifier(0)
            layers['body'] = model
            hooks.extend(self.feature_info.get_dicts())
        else:
            modules = _module_list(model, flatten_sequential=flatten_sequential)
            remaining = {
                f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
                for f in self.feature_info.get_dicts()
            }
            for new_name, old_name, module in modules:
                layers[new_name] = module
                for fn, fm in module.named_modules(prefix=old_name):
                    if fn in remaining:
                        hooks.append(dict(module=fn, hook_type=remaining[fn]))
                        del remaining[fn]
                if not remaining:
                    break
            assert not remaining, f'Return layers ({remaining}) are not present in model'
        self.update(layers)
        self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)

    def set_grad_checkpointing(self, enable: bool = True):
        self.grad_checkpointing = enable

    def forward(self, x):
        for i, (name, module) in enumerate(self.items()):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # Skipping checkpoint of first module because need a gradient at input
                # Skipping last because networks with in-place ops might fail w/ checkpointing enabled
                # NOTE: first_or_last module could be static, but recalc in is_scripting guard to avoid jit issues
                first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
                x = module(x) if first_or_last_module else checkpoint(module, x)
            else:
                x = module(x)
        out = self.hooks.get_output(x.device)
        return out if self.return_dict else list(out.values())

In [21]:
""" PyTorch FX Based Feature Extraction Helpers
Using https://pytorch.org/vision/stable/feature_extraction.html
"""
from typing import Callable, List, Dict, Union, Type

import torch
from torch import nn

# from ._features import _get_feature_info, _get_return_layers

try:
    from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
    has_fx_feature_extraction = True
except ImportError:
    has_fx_feature_extraction = False

# Layers we went to treat as leaf modules
from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame
from timm.layers.non_local_attn import BilinearAttnTransform
from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
from timm.layers.norm_act import (
    BatchNormAct2d,
    SyncBatchNormAct,
    FrozenBatchNormAct2d,
    GroupNormAct,
    GroupNorm1Act,
    LayerNormAct,
    LayerNormAct2d
)

__all__ = ['register_notrace_module', 'is_notrace_module', 'get_notrace_modules',
           'register_notrace_function', 'is_notrace_function', 'get_notrace_functions',
           'create_feature_extractor', 'FeatureGraphNet', 'GraphExtractNet']


# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
# BUT modules from timm.models should use the registration mechanism below
_leaf_modules = {
    BilinearAttnTransform,  # reason: flow control t <= 1
    # Reason: get_same_padding has a max which raises a control flow error
    Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
    CondConv2d,  # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0]),
    BatchNormAct2d,
    SyncBatchNormAct,
    FrozenBatchNormAct2d,
    GroupNormAct,
    GroupNorm1Act,
    LayerNormAct,
    LayerNormAct2d,
}

try:
    from timm.layers import InplaceAbn
    _leaf_modules.add(InplaceAbn)
except ImportError:
    pass


def register_notrace_module(module: Type[nn.Module]):
    """
    Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
    """
    _leaf_modules.add(module)
    return module


def is_notrace_module(module: Type[nn.Module]):
    return module in _leaf_modules


def get_notrace_modules():
    return list(_leaf_modules)


# Functions we want to autowrap (treat them as leaves)
_autowrap_functions = set()


def register_notrace_function(func: Callable):
    """
    Decorator for functions which ought not to be traced through
    """
    _autowrap_functions.add(func)
    return func


def is_notrace_function(func: Callable):
    return func in _autowrap_functions


def get_notrace_functions():
    return list(_autowrap_functions)


def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
    assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
    return _create_feature_extractor(
        model, return_nodes,
        tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}
    )


class FeatureGraphNet(nn.Module):
    """ A FX Graph based feature extractor that works with the model feature_info metadata
    """
    def __init__(self, model, out_indices, out_map=None):
        super().__init__()
        assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
        self.feature_info = _get_feature_info(model, out_indices)
        if out_map is not None:
            assert len(out_map) == len(out_indices)
        return_nodes = _get_return_layers(self.feature_info, out_map)
        self.graph_module = create_feature_extractor(model, return_nodes)

    def forward(self, x):
        return list(self.graph_module(x).values())


class GraphExtractNet(nn.Module):
    """ A standalone feature extraction wrapper that maps dict -> list or single tensor
    NOTE:
      * one can use feature_extractor directly if dictionary output is desired
      * unlike FeatureGraphNet, this is intended to be used standalone and not with model feature_info
      metadata for builtin feature extraction mode
      * create_feature_extractor can be used directly if dictionary output is acceptable

    Args:
        model: model to extract features from
        return_nodes: node names to return features from (dict or list)
        squeeze_out: if only one output, and output in list format, flatten to single tensor
    """
    def __init__(self, model, return_nodes: Union[Dict[str, str], List[str]], squeeze_out: bool = True):
        super().__init__()
        self.squeeze_out = squeeze_out
        self.graph_module = create_feature_extractor(model, return_nodes)

    def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]:
        out = list(self.graph_module(x).values())
        if self.squeeze_out and len(out) == 1:
            return out[0]
        return out

In [22]:
""" Model creation / weight loading / state_dict helpers

Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
import os
from collections import OrderedDict
from typing import Any, Callable, Dict, Optional, Union

import torch
try:
    import safetensors.torch
    _has_safetensors = True
except ImportError:
    _has_safetensors = False

_logger = logging.getLogger(__name__)

__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_state_dict', 'resume_checkpoint']


def clean_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    # 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training
    cleaned_state_dict = {}
    for k, v in state_dict.items():
        name = k[7:] if k.startswith('module.') else k
        cleaned_state_dict[name] = v
    return cleaned_state_dict


def load_state_dict(
        checkpoint_path: str,
        use_ema: bool = True,
        device: Union[str, torch.device] = 'cpu',
) -> Dict[str, Any]:
    if checkpoint_path and os.path.isfile(checkpoint_path):
        # Check if safetensors or not and load weights accordingly
        if str(checkpoint_path).endswith(".safetensors"):
            assert _has_safetensors, "`pip install safetensors` to use .safetensors"
            checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
        else:
            checkpoint = torch.load(checkpoint_path, map_location=device)

        state_dict_key = ''
        if isinstance(checkpoint, dict):
            if use_ema and checkpoint.get('state_dict_ema', None) is not None:
                state_dict_key = 'state_dict_ema'
            elif use_ema and checkpoint.get('model_ema', None) is not None:
                state_dict_key = 'model_ema'
            elif 'state_dict' in checkpoint:
                state_dict_key = 'state_dict'
            elif 'model' in checkpoint:
                state_dict_key = 'model'
        state_dict = clean_state_dict(checkpoint[state_dict_key] if state_dict_key else checkpoint)
        _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
        return state_dict
    else:
        _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
        raise FileNotFoundError()


def load_checkpoint(
        model: torch.nn.Module,
        checkpoint_path: str,
        use_ema: bool = True,
        device: Union[str, torch.device] = 'cpu',
        strict: bool = True,
        remap: bool = False,
        filter_fn: Optional[Callable] = None,
):
    if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
        # numpy checkpoint, try to load via model specific load_pretrained fn
        if hasattr(model, 'load_pretrained'):
            model.load_pretrained(checkpoint_path)
        else:
            raise NotImplementedError('Model cannot load numpy checkpoint')
        return

    state_dict = load_state_dict(checkpoint_path, use_ema, device=device)
    if remap:
        state_dict = remap_state_dict(state_dict, model)
    elif filter_fn:
        state_dict = filter_fn(state_dict, model)
    incompatible_keys = model.load_state_dict(state_dict, strict=strict)
    return incompatible_keys


def remap_state_dict(
        state_dict: Dict[str, Any],
        model: torch.nn.Module,
        allow_reshape: bool = True
):
    """ remap checkpoint by iterating over state dicts in order (ignoring original keys).
    This assumes models (and originating state dict) were created with params registered in same order.
    """
    out_dict = {}
    for (ka, va), (kb, vb) in zip(model.state_dict().items(), state_dict.items()):
        assert va.numel() == vb.numel(), f'Tensor size mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
        if va.shape != vb.shape:
            if allow_reshape:
                vb = vb.reshape(va.shape)
            else:
                assert False,  f'Tensor shape mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
        out_dict[ka] = vb
    return out_dict


def resume_checkpoint(
        model: torch.nn.Module,
        checkpoint_path: str,
        optimizer: torch.optim.Optimizer = None,
        loss_scaler: Any = None,
        log_info: bool = True,
):
    resume_epoch = None
    if os.path.isfile(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            if log_info:
                _logger.info('Restoring model state from checkpoint...')
            state_dict = clean_state_dict(checkpoint['state_dict'])
            model.load_state_dict(state_dict)

            if optimizer is not None and 'optimizer' in checkpoint:
                if log_info:
                    _logger.info('Restoring optimizer state from checkpoint...')
                optimizer.load_state_dict(checkpoint['optimizer'])

            if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
                if log_info:
                    _logger.info('Restoring AMP loss scaler state from checkpoint...')
                loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])

            if 'epoch' in checkpoint:
                resume_epoch = checkpoint['epoch']
                if 'version' in checkpoint and checkpoint['version'] > 1:
                    resume_epoch += 1  # start at the next epoch, old checkpoints incremented before save

                if log_info:
                    _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
        else:
            model.load_state_dict(checkpoint)
            if log_info:
                _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
        return resume_epoch
    else:
        _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
        raise FileNotFoundError()

In [23]:
import hashlib
import json
import logging
import os
from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Iterable, Optional, Union

import torch
from torch.hub import HASH_REGEX, download_url_to_file, urlparse

try:
    from torch.hub import get_dir
except ImportError:
    from torch.hub import _get_torch_home as get_dir

try:
    import safetensors.torch
    _has_safetensors = True
except ImportError:
    _has_safetensors = False

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

from timm import __version__
# from timm.models._pretrained import filter_pretrained_cfg

try:
    from huggingface_hub import (
        create_repo, get_hf_file_metadata,
        hf_hub_download, hf_hub_url,
        repo_type_and_id_from_hf_id, upload_folder)
    from huggingface_hub.utils import EntryNotFoundError
    hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
    _has_hf_hub = True
except ImportError:
    hf_hub_download = None
    _has_hf_hub = False

_logger = logging.getLogger(__name__)

__all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf',
           'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub']

# Default name for a weights file hosted on the Huggingface Hub.
HF_WEIGHTS_NAME = "pytorch_model.bin"  # default pytorch pkl
HF_SAFE_WEIGHTS_NAME = "model.safetensors"  # safetensors version
HF_OPEN_CLIP_WEIGHTS_NAME = "open_clip_pytorch_model.bin"  # default pytorch pkl
HF_OPEN_CLIP_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors"  # safetensors version


def get_cache_dir(child_dir=''):
    """
    Returns the location of the directory where models are cached (and creates it if necessary).
    """
    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    hub_dir = get_dir()
    child_dir = () if not child_dir else (child_dir,)
    model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir)
    os.makedirs(model_dir, exist_ok=True)
    return model_dir


def download_cached_file(url, check_hash=True, progress=False):
    if isinstance(url, (list, tuple)):
        url, filename = url
    else:
        parts = urlparse(url)
        filename = os.path.basename(parts.path)
    cached_file = os.path.join(get_cache_dir(), filename)
    if not os.path.exists(cached_file):
        _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        if check_hash:
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    return cached_file


def check_cached_file(url, check_hash=True):
    if isinstance(url, (list, tuple)):
        url, filename = url
    else:
        parts = urlparse(url)
        filename = os.path.basename(parts.path)
    cached_file = os.path.join(get_cache_dir(), filename)
    if os.path.exists(cached_file):
        if check_hash:
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
            if hash_prefix:
                with open(cached_file, 'rb') as f:
                    hd = hashlib.sha256(f.read()).hexdigest()
                    if hd[:len(hash_prefix)] != hash_prefix:
                        return False
        return True
    return False


def has_hf_hub(necessary=False):
    if not _has_hf_hub and necessary:
        # if no HF Hub module installed, and it is necessary to continue, raise error
        raise RuntimeError(
            'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
    return _has_hf_hub


def hf_split(hf_id: str):
    # FIXME I may change @ -> # and be parsed as fragment in a URI model name scheme
    rev_split = hf_id.split('@')
    assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.'
    hf_model_id = rev_split[0]
    hf_revision = rev_split[-1] if len(rev_split) > 1 else None
    return hf_model_id, hf_revision


def load_cfg_from_json(json_file: Union[str, os.PathLike]):
    with open(json_file, "r", encoding="utf-8") as reader:
        text = reader.read()
    return json.loads(text)


def download_from_hf(model_id: str, filename: str):
    hf_model_id, hf_revision = hf_split(model_id)
    return hf_hub_download(hf_model_id, filename, revision=hf_revision)


def load_model_config_from_hf(model_id: str):
    assert has_hf_hub(True)
    cached_file = download_from_hf(model_id, 'config.json')

    hf_config = load_cfg_from_json(cached_file)
    if 'pretrained_cfg' not in hf_config:
        # old form, pull pretrain_cfg out of the base dict
        pretrained_cfg = hf_config
        hf_config = {}
        hf_config['architecture'] = pretrained_cfg.pop('architecture')
        hf_config['num_features'] = pretrained_cfg.pop('num_features', None)
        if 'labels' in pretrained_cfg:  # deprecated name for 'label_names'
            pretrained_cfg['label_names'] = pretrained_cfg.pop('labels')
        hf_config['pretrained_cfg'] = pretrained_cfg

    # NOTE currently discarding parent config as only arch name and pretrained_cfg used in timm right now
    pretrained_cfg = hf_config['pretrained_cfg']
    pretrained_cfg['hf_hub_id'] = model_id  # insert hf_hub id for pretrained weight load during model creation
    pretrained_cfg['source'] = 'hf-hub'

    # model should be created with base config num_classes if its exist
    if 'num_classes' in hf_config:
        pretrained_cfg['num_classes'] = hf_config['num_classes']

    # label meta-data in base config overrides saved pretrained_cfg on load
    if 'label_names' in hf_config:
        pretrained_cfg['label_names'] = hf_config.pop('label_names')
    if 'label_descriptions' in hf_config:
        pretrained_cfg['label_descriptions'] = hf_config.pop('label_descriptions')

    model_args = hf_config.get('model_args', {})
    model_name = hf_config['architecture']
    return pretrained_cfg, model_name, model_args


def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
    assert has_hf_hub(True)
    hf_model_id, hf_revision = hf_split(model_id)

    # Look for .safetensors alternatives and load from it if it exists
    if _has_safetensors:
        for safe_filename in _get_safe_alternatives(filename):
            try:
                cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
                _logger.info(
                    f"[{model_id}] Safe alternative available for '{filename}' "
                    f"(as '{safe_filename}'). Loading weights using safetensors.")
                return safetensors.torch.load_file(cached_safe_file, device="cpu")
            except EntryNotFoundError:
                pass

    # Otherwise, load using pytorch.load
    cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
    _logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
    return torch.load(cached_file, map_location='cpu')


def save_config_for_hf(
        model,
        config_path: str,
        model_config: Optional[dict] = None,
        model_args: Optional[dict] = None
):
    model_config = model_config or {}
    hf_config = {}
    pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
    # set some values at root config level
    hf_config['architecture'] = pretrained_cfg.pop('architecture')
    hf_config['num_classes'] = model_config.pop('num_classes', model.num_classes)

    # NOTE these attr saved for informational purposes, do not impact model build
    hf_config['num_features'] = model_config.pop('num_features', model.num_features)
    global_pool_type = model_config.pop('global_pool', getattr(model, 'global_pool', None))
    if isinstance(global_pool_type, str) and global_pool_type:
        hf_config['global_pool'] = global_pool_type

    # Save class label info
    if 'labels' in model_config:
        _logger.warning(
            "'labels' as a config field for is deprecated. Please use 'label_names' and 'label_descriptions'."
            " Renaming provided 'labels' field to 'label_names'.")
        model_config.setdefault('label_names', model_config.pop('labels'))

    label_names = model_config.pop('label_names', None)
    if label_names:
        assert isinstance(label_names, (dict, list, tuple))
        # map label id (classifier index) -> unique label name (ie synset for ImageNet, MID for OpenImages)
        # can be a dict id: name if there are id gaps, or tuple/list if no gaps.
        hf_config['label_names'] = label_names

    label_descriptions = model_config.pop('label_descriptions', None)
    if label_descriptions:
        assert isinstance(label_descriptions, dict)
        # maps label names -> descriptions
        hf_config['label_descriptions'] = label_descriptions

    if model_args:
        hf_config['model_args'] = model_args

    hf_config['pretrained_cfg'] = pretrained_cfg
    hf_config.update(model_config)

    with config_path.open('w') as f:
        json.dump(hf_config, f, indent=2)


def save_for_hf(
        model,
        save_directory: str,
        model_config: Optional[dict] = None,
        model_args: Optional[dict] = None,
        safe_serialization: Union[bool, Literal["both"]] = False,
):
    assert has_hf_hub(True)
    save_directory = Path(save_directory)
    save_directory.mkdir(exist_ok=True, parents=True)

    # Save model weights, either safely (using safetensors), or using legacy pytorch approach or both.
    tensors = model.state_dict()
    if safe_serialization is True or safe_serialization == "both":
        assert _has_safetensors, "`pip install safetensors` to use .safetensors"
        safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
    if safe_serialization is False or safe_serialization == "both":
        torch.save(tensors, save_directory / HF_WEIGHTS_NAME)

    config_path = save_directory / 'config.json'
    save_config_for_hf(
        model,
        config_path,
        model_config=model_config,
        model_args=model_args,
    )


def push_to_hf_hub(
        model: torch.nn.Module,
        repo_id: str,
        commit_message: str = 'Add model',
        token: Optional[str] = None,
        revision: Optional[str] = None,
        private: bool = False,
        create_pr: bool = False,
        model_config: Optional[dict] = None,
        model_card: Optional[dict] = None,
        model_args: Optional[dict] = None,
        safe_serialization: Union[bool, Literal["both"]] = False,
):
    """
    Arguments:
        (...)
        safe_serialization (`bool` or `"both"`, *optional*, defaults to `False`):
            Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
            Can be set to `"both"` in order to push both safe and unsafe weights.
    """
    # Create repo if it doesn't exist yet
    repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)

    # Infer complete repo_id from repo_url
    # Can be different from the input `repo_id` if repo_owner was implicit
    _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
    repo_id = f"{repo_owner}/{repo_name}"

    # Check if README file already exist in repo
    try:
        get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
        has_readme = True
    except EntryNotFoundError:
        has_readme = False

    # Dump model and push to Hub
    with TemporaryDirectory() as tmpdir:
        # Save model weights and config.
        save_for_hf(
            model,
            tmpdir,
            model_config=model_config,
            model_args=model_args,
            safe_serialization=safe_serialization,
        )

        # Add readme if it does not exist
        if not has_readme:
            model_card = model_card or {}
            model_name = repo_id.split('/')[-1]
            readme_path = Path(tmpdir) / "README.md"
            readme_text = generate_readme(model_card, model_name)
            readme_path.write_text(readme_text)

        # Upload model and return
        return upload_folder(
            repo_id=repo_id,
            folder_path=tmpdir,
            revision=revision,
            create_pr=create_pr,
            commit_message=commit_message,
        )


def generate_readme(model_card: dict, model_name: str):
    readme_text = "---\n"
    readme_text += "tags:\n- image-classification\n- timm\n"
    readme_text += "library_name: timm\n"
    readme_text += f"license: {model_card.get('license', 'apache-2.0')}\n"
    if 'details' in model_card and 'Dataset' in model_card['details']:
        readme_text += 'datasets:\n'
        if isinstance(model_card['details']['Dataset'], (tuple, list)):
            for d in model_card['details']['Dataset']:
                readme_text += f"- {d.lower()}\n"
        else:
            readme_text += f"- {model_card['details']['Dataset'].lower()}\n"
        if 'Pretrain Dataset' in model_card['details']:
            if isinstance(model_card['details']['Pretrain Dataset'], (tuple, list)):
                for d in model_card['details']['Pretrain Dataset']:
                    readme_text += f"- {d.lower()}\n"
            else:
                readme_text += f"- {model_card['details']['Pretrain Dataset'].lower()}\n"
    readme_text += "---\n"
    readme_text += f"# Model card for {model_name}\n"
    if 'description' in model_card:
        readme_text += f"\n{model_card['description']}\n"
    if 'details' in model_card:
        readme_text += f"\n## Model Details\n"
        for k, v in model_card['details'].items():
            if isinstance(v, (list, tuple)):
                readme_text += f"- **{k}:**\n"
                for vi in v:
                    readme_text += f"  - {vi}\n"
            elif isinstance(v, dict):
                readme_text += f"- **{k}:**\n"
                for ki, vi in v.items():
                    readme_text += f"  - {ki}: {vi}\n"
            else:
                readme_text += f"- **{k}:** {v}\n"
    if 'usage' in model_card:
        readme_text += f"\n## Model Usage\n"
        readme_text += model_card['usage']
        readme_text += '\n'

    if 'comparison' in model_card:
        readme_text += f"\n## Model Comparison\n"
        readme_text += model_card['comparison']
        readme_text += '\n'

    if 'citation' in model_card:
        readme_text += f"\n## Citation\n"
        if not isinstance(model_card['citation'], (list, tuple)):
            citations = [model_card['citation']]
        else:
            citations = model_card['citation']
        for c in citations:
            readme_text += f"```bibtex\n{c}\n```\n"
    return readme_text


def _get_safe_alternatives(filename: str) -> Iterable[str]:
    """Returns potential safetensors alternatives for a given filename.

    Use case:
        When downloading a model from the Huggingface Hub, we first look if a .safetensors file exists and if yes, we use it.
        Main use case is filename "pytorch_model.bin" => check for "model.safetensors" or "pytorch_model.safetensors".
    """
    if filename == HF_WEIGHTS_NAME:
        yield HF_SAFE_WEIGHTS_NAME
    if filename == HF_OPEN_CLIP_WEIGHTS_NAME:
        yield HF_OPEN_CLIP_SAFE_WEIGHTS_NAME
    if filename not in (HF_WEIGHTS_NAME, HF_OPEN_CLIP_WEIGHTS_NAME) and filename.endswith(".bin"):
        yield filename[:-4] + ".safetensors"

In [24]:
import collections.abc
import math
import re
from collections import defaultdict
from itertools import chain
from typing import Any, Callable, Dict, Iterator, Tuple, Type, Union

import torch
from torch import nn as nn
from torch.utils.checkpoint import checkpoint

__all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_with_params', 'adapt_input_conv',
           'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq']


def model_parameters(model: nn.Module, exclude_head: bool = False):
    if exclude_head:
        # FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering
        return [p for p in model.parameters()][:-2]
    else:
        return model.parameters()


def named_apply(
        fn: Callable,
        module: nn.Module, name='',
        depth_first: bool = True,
        include_root: bool = False,
) -> nn.Module:
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        child_name = '.'.join((name, child_name)) if name else child_name
        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        fn(module=module, name=name)
    return module


def named_modules(
        module: nn.Module,
        name: str = '',
        depth_first: bool = True,
        include_root: bool = False,
):
    if not depth_first and include_root:
        yield name, module
    for child_name, child_module in module.named_children():
        child_name = '.'.join((name, child_name)) if name else child_name
        yield from named_modules(
            module=child_module, name=child_name, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        yield name, module


def named_modules_with_params(
        module: nn.Module,
        name: str = '',
        depth_first: bool = True,
        include_root: bool = False,
):
    if module._parameters and not depth_first and include_root:
        yield name, module
    for child_name, child_module in module.named_children():
        child_name = '.'.join((name, child_name)) if name else child_name
        yield from named_modules_with_params(
            module=child_module, name=child_name, depth_first=depth_first, include_root=True)
    if module._parameters and depth_first and include_root:
        yield name, module


MATCH_PREV_GROUP = (99999,)


def group_with_matcher(
        named_objects: Iterator[Tuple[str, Any]],
        group_matcher: Union[Dict, Callable],
        return_values: bool = False,
        reverse: bool = False
):
    if isinstance(group_matcher, dict):
        # dictionary matcher contains a dict of raw-string regex expr that must be compiled
        compiled = []
        for group_ordinal, (group_name, mspec) in enumerate(group_matcher.items()):
            if mspec is None:
                continue
            # map all matching specifications into 3-tuple (compiled re, prefix, suffix)
            if isinstance(mspec, (tuple, list)):
                # multi-entry match specifications require each sub-spec to be a 2-tuple (re, suffix)
                for sspec in mspec:
                    compiled += [(re.compile(sspec[0]), (group_ordinal,), sspec[1])]
            else:
                compiled += [(re.compile(mspec), (group_ordinal,), None)]
        group_matcher = compiled

    def _get_grouping(name):
        if isinstance(group_matcher, (list, tuple)):
            for match_fn, prefix, suffix in group_matcher:
                r = match_fn.match(name)
                if r:
                    parts = (prefix, r.groups(), suffix)
                    # map all tuple elem to int for numeric sort, filter out None entries
                    return tuple(map(float, chain.from_iterable(filter(None, parts))))
            return float('inf'),  # un-matched layers (neck, head) mapped to largest ordinal
        else:
            ord = group_matcher(name)
            if not isinstance(ord, collections.abc.Iterable):
                return ord,
            return tuple(ord)

    # map layers into groups via ordinals (ints or tuples of ints) from matcher
    grouping = defaultdict(list)
    for k, v in named_objects:
        grouping[_get_grouping(k)].append(v if return_values else k)

    # remap to integers
    layer_id_to_param = defaultdict(list)
    lid = -1
    for k in sorted(filter(lambda x: x is not None, grouping.keys())):
        if lid < 0 or k[-1] != MATCH_PREV_GROUP[0]:
            lid += 1
        layer_id_to_param[lid].extend(grouping[k])

    if reverse:
        assert not return_values, "reverse mapping only sensible for name output"
        # output reverse mapping
        param_to_layer_id = {}
        for lid, lm in layer_id_to_param.items():
            for n in lm:
                param_to_layer_id[n] = lid
        return param_to_layer_id

    return layer_id_to_param


def group_parameters(
        module: nn.Module,
        group_matcher,
        return_values: bool = False,
        reverse: bool = False,
):
    return group_with_matcher(
        module.named_parameters(), group_matcher, return_values=return_values, reverse=reverse)


def group_modules(
        module: nn.Module,
        group_matcher,
        return_values: bool = False,
        reverse: bool = False,
):
    return group_with_matcher(
        named_modules_with_params(module), group_matcher, return_values=return_values, reverse=reverse)


def flatten_modules(
        named_modules: Iterator[Tuple[str, nn.Module]],
        depth: int = 1,
        prefix: Union[str, Tuple[str, ...]] = '',
        module_types: Union[str, Tuple[Type[nn.Module]]] = 'sequential',
):
    prefix_is_tuple = isinstance(prefix, tuple)
    if isinstance(module_types, str):
        if module_types == 'container':
            module_types = (nn.Sequential, nn.ModuleList, nn.ModuleDict)
        else:
            module_types = (nn.Sequential,)
    for name, module in named_modules:
        if depth and isinstance(module, module_types):
            yield from flatten_modules(
                module.named_children(),
                depth - 1,
                prefix=(name,) if prefix_is_tuple else name,
                module_types=module_types,
            )
        else:
            if prefix_is_tuple:
                name = prefix + (name,)
                yield name, module
            else:
                if prefix:
                    name = '.'.join([prefix, name])
                yield name, module


def checkpoint_seq(
        functions,
        x,
        every=1,
        flatten=False,
        skip_last=False,
        preserve_rng_state=True
):
    r"""A helper function for checkpointing sequential models.

    Sequential models execute a list of modules/functions in order
    (sequentially). Therefore, we can divide such a sequence into segments
    and checkpoint each segment. All segments except run in :func:`torch.no_grad`
    manner, i.e., not storing the intermediate activations. The inputs of each
    checkpointed segment will be saved for re-running the segment in the backward pass.

    See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.

    .. warning::
        Checkpointing currently only supports :func:`torch.autograd.backward`
        and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`
        is not supported.

    .. warning:
        At least one of the inputs needs to have :code:`requires_grad=True` if
        grads are needed for model inputs, otherwise the checkpointed part of the
        model won't have gradients.

    Args:
        functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially.
        x: A Tensor that is input to :attr:`functions`
        every: checkpoint every-n functions (default: 1)
        flatten (bool): flatten nn.Sequential of nn.Sequentials
        skip_last (bool): skip checkpointing the last function in the sequence if True
        preserve_rng_state (bool, optional, default=True):  Omit stashing and restoring
            the RNG state during each checkpoint.

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`*inputs`

    Example:
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_seq(model, input_var, every=2)
    """
    def run_function(start, end, functions):
        def forward(_x):
            for j in range(start, end + 1):
                _x = functions[j](_x)
            return _x
        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = functions.children()
    if flatten:
        functions = chain.from_iterable(functions)
    if not isinstance(functions, (tuple, list)):
        functions = tuple(functions)

    num_checkpointed = len(functions)
    if skip_last:
        num_checkpointed -= 1
    end = -1
    for start in range(0, num_checkpointed, every):
        end = min(start + every - 1, num_checkpointed - 1)
        x = checkpoint(run_function(start, end, functions), x, preserve_rng_state=preserve_rng_state)
    if skip_last:
        return run_function(end + 1, len(functions) - 1, functions)(x)
    return x


def adapt_input_conv(in_chans, conv_weight):
    conv_type = conv_weight.dtype
    conv_weight = conv_weight.float()  # Some weights are in torch.half, ensure it's float for sum on CPU
    O, I, J, K = conv_weight.shape
    if in_chans == 1:
        if I > 3:
            assert conv_weight.shape[1] % 3 == 0
            # For models with space2depth stems
            conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
            conv_weight = conv_weight.sum(dim=2, keepdim=False)
        else:
            conv_weight = conv_weight.sum(dim=1, keepdim=True)
    elif in_chans != 3:
        if I != 3:
            raise NotImplementedError('Weight format not supported by conversion.')
        else:
            # NOTE this strategy should be better than random init, but there could be other combinations of
            # the original RGB input layer weights that'd work better for specific cases.
            repeat = int(math.ceil(in_chans / 3))
            conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
            conv_weight *= (3 / float(in_chans))
    conv_weight = conv_weight.to(conv_type)
    return conv_weight

In [25]:
import copy
from collections import deque, defaultdict
from dataclasses import dataclass, field, replace, asdict
from typing import Any, Deque, Dict, Tuple, Optional, Union


__all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg']


@dataclass
class PretrainedCfg:
    """
    """
    # weight source locations
    url: Optional[Union[str, Tuple[str, str]]] = None  # remote URL
    file: Optional[str] = None  # local / shared filesystem path
    state_dict: Optional[Dict[str, Any]] = None  # in-memory state dict
    hf_hub_id: Optional[str] = None  # Hugging Face Hub model id ('organization/model')
    hf_hub_filename: Optional[str] = None  # Hugging Face Hub filename (overrides default)

    source: Optional[str] = None  # source of cfg / weight location used (url, file, hf-hub)
    architecture: Optional[str] = None  # architecture variant can be set when not implicit
    tag: Optional[str] = None  # pretrained tag of source
    custom_load: bool = False  # use custom model specific model.load_pretrained() (ie for npz files)

    # input / data config
    input_size: Tuple[int, int, int] = (3, 224, 224)
    test_input_size: Optional[Tuple[int, int, int]] = None
    min_input_size: Optional[Tuple[int, int, int]] = None
    fixed_input_size: bool = False
    interpolation: str = 'bicubic'
    crop_pct: float = 0.875
    test_crop_pct: Optional[float] = None
    crop_mode: str = 'center'
    mean: Tuple[float, ...] = (0.485, 0.456, 0.406)
    std: Tuple[float, ...] = (0.229, 0.224, 0.225)

    # head / classifier config and meta-data
    num_classes: int = 1000
    label_offset: Optional[int] = None
    label_names: Optional[Tuple[str]] = None
    label_descriptions: Optional[Dict[str, str]] = None

    # model attributes that vary with above or required for pretrained adaptation
    pool_size: Optional[Tuple[int, ...]] = None
    test_pool_size: Optional[Tuple[int, ...]] = None
    first_conv: Optional[str] = None
    classifier: Optional[str] = None

    license: Optional[str] = None
    description: Optional[str] = None
    origin_url: Optional[str] = None
    paper_name: Optional[str] = None
    paper_ids: Optional[Union[str, Tuple[str]]] = None
    notes: Optional[Tuple[str]] = None

    @property
    def has_weights(self):
        return self.url or self.file or self.hf_hub_id

    def to_dict(self, remove_source=False, remove_null=True):
        return filter_pretrained_cfg(
            asdict(self),
            remove_source=remove_source,
            remove_null=remove_null
        )


def filter_pretrained_cfg(cfg, remove_source=False, remove_null=True):
    filtered_cfg = {}
    keep_null = {'pool_size', 'first_conv', 'classifier'}  # always keep these keys, even if none
    for k, v in cfg.items():
        if remove_source and k in {'url', 'file', 'hf_hub_id', 'hf_hub_id', 'hf_hub_filename', 'source'}:
            continue
        if remove_null and v is None and k not in keep_null:
            continue
        filtered_cfg[k] = v
    return filtered_cfg


@dataclass
class DefaultCfg:
    tags: Deque[str] = field(default_factory=deque)  # priority queue of tags (first is default)
    cfgs: Dict[str, PretrainedCfg] = field(default_factory=dict)  # pretrained cfgs by tag
    is_pretrained: bool = False  # at least one of the configs has a pretrained source set

    @property
    def default(self):
        return self.cfgs[self.tags[0]]

    @property
    def default_with_tag(self):
        tag = self.tags[0]
        return tag, self.cfgs[tag]

In [26]:
import os
import pkgutil
from copy import deepcopy

from torch import nn as nn

from timm.layers import Conv2dSame, BatchNormAct2d, Linear

__all__ = ['extract_layer', 'set_layer', 'adapt_model_from_string', 'adapt_model_from_file']


def extract_layer(model, layer):
    layer = layer.split('.')
    module = model
    if hasattr(model, 'module') and layer[0] != 'module':
        module = model.module
    if not hasattr(model, 'module') and layer[0] == 'module':
        layer = layer[1:]
    for l in layer:
        if hasattr(module, l):
            if not l.isdigit():
                module = getattr(module, l)
            else:
                module = module[int(l)]
        else:
            return module
    return module


def set_layer(model, layer, val):
    layer = layer.split('.')
    module = model
    if hasattr(model, 'module') and layer[0] != 'module':
        module = model.module
    lst_index = 0
    module2 = module
    for l in layer:
        if hasattr(module2, l):
            if not l.isdigit():
                module2 = getattr(module2, l)
            else:
                module2 = module2[int(l)]
            lst_index += 1
    lst_index -= 1
    for l in layer[:lst_index]:
        if not l.isdigit():
            module = getattr(module, l)
        else:
            module = module[int(l)]
    l = layer[lst_index]
    setattr(module, l, val)


def adapt_model_from_string(parent_module, model_string):
    separator = '***'
    state_dict = {}
    lst_shape = model_string.split(separator)
    for k in lst_shape:
        k = k.split(':')
        key = k[0]
        shape = k[1][1:-1].split(',')
        if shape[0] != '':
            state_dict[key] = [int(i) for i in shape]

    new_module = deepcopy(parent_module)
    for n, m in parent_module.named_modules():
        old_module = extract_layer(parent_module, n)
        if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame):
            if isinstance(old_module, Conv2dSame):
                conv = Conv2dSame
            else:
                conv = nn.Conv2d
            s = state_dict[n + '.weight']
            in_channels = s[1]
            out_channels = s[0]
            g = 1
            if old_module.groups > 1:
                in_channels = out_channels
                g = in_channels
            new_conv = conv(
                in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size,
                bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
                groups=g, stride=old_module.stride)
            set_layer(new_module, n, new_conv)
        elif isinstance(old_module, BatchNormAct2d):
            new_bn = BatchNormAct2d(
                state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
                affine=old_module.affine, track_running_stats=True)
            new_bn.drop = old_module.drop
            new_bn.act = old_module.act
            set_layer(new_module, n, new_bn)
        elif isinstance(old_module, nn.BatchNorm2d):
            new_bn = nn.BatchNorm2d(
                num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
                affine=old_module.affine, track_running_stats=True)
            set_layer(new_module, n, new_bn)
        elif isinstance(old_module, nn.Linear):
            # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
            num_features = state_dict[n + '.weight'][1]
            new_fc = Linear(
                in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
            set_layer(new_module, n, new_fc)
            if hasattr(new_module, 'num_features'):
                new_module.num_features = num_features
    new_module.eval()
    parent_module.eval()

    return new_module


def adapt_model_from_file(parent_module, model_variant):
    adapt_data = pkgutil.get_data(__name__, os.path.join('_pruned', model_variant + '.txt'))
    return adapt_model_from_string(parent_module, adapt_data.decode('utf-8').strip())

In [27]:
import dataclasses
import logging
import os
from copy import deepcopy
from typing import Optional, Dict, Callable, Any, Tuple

from torch import nn as nn
from torch.hub import load_state_dict_from_url

from timm.models._features import FeatureListNet, FeatureHookNet
from timm.models._features_fx import FeatureGraphNet
from timm.models._helpers import load_state_dict
from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf
from timm.models._manipulate import adapt_input_conv
from timm.models._pretrained import PretrainedCfg
from timm.models._prune import adapt_model_from_file
from timm.models._registry import get_pretrained_cfg

_logger = logging.getLogger(__name__)

# Global variables for rarely used pretrained checkpoint download progress and hash check.
# Use set_pretrained_download_progress / set_pretrained_check_hash functions to toggle.
_DOWNLOAD_PROGRESS = False
_CHECK_HASH = False
_USE_OLD_CACHE = int(os.environ.get('TIMM_USE_OLD_CACHE', 0)) > 0

__all__ = ['set_pretrained_download_progress', 'set_pretrained_check_hash', 'load_custom_pretrained', 'load_pretrained',
           'pretrained_cfg_for_features', 'resolve_pretrained_cfg', 'build_model_with_cfg']


def _resolve_pretrained_source(pretrained_cfg):
    cfg_source = pretrained_cfg.get('source', '')
    pretrained_url = pretrained_cfg.get('url', None)
    pretrained_file = pretrained_cfg.get('file', None)
    pretrained_sd = pretrained_cfg.get('state_dict', None)
    hf_hub_id = pretrained_cfg.get('hf_hub_id', None)

    # resolve where to load pretrained weights from
    load_from = ''
    pretrained_loc = ''
    if cfg_source == 'hf-hub' and has_hf_hub(necessary=True):
        # hf-hub specified as source via model identifier
        load_from = 'hf-hub'
        assert hf_hub_id
        pretrained_loc = hf_hub_id
    else:
        # default source == timm or unspecified
        if pretrained_sd:
            # direct state_dict pass through is the highest priority
            load_from = 'state_dict'
            pretrained_loc = pretrained_sd
            assert isinstance(pretrained_loc, dict)
        elif pretrained_file:
            # file load override is the second-highest priority if set
            load_from = 'file'
            pretrained_loc = pretrained_file
        else:
            old_cache_valid = False
            if _USE_OLD_CACHE:
                # prioritized old cached weights if exists and env var enabled
                old_cache_valid = check_cached_file(pretrained_url) if pretrained_url else False
            if not old_cache_valid and hf_hub_id and has_hf_hub(necessary=True):
                # hf-hub available as alternate weight source in default_cfg
                load_from = 'hf-hub'
                pretrained_loc = hf_hub_id
            elif pretrained_url:
                load_from = 'url'
                pretrained_loc = pretrained_url

    if load_from == 'hf-hub' and pretrained_cfg.get('hf_hub_filename', None):
        # if a filename override is set, return tuple for location w/ (hub_id, filename)
        pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename']
    return load_from, pretrained_loc


def set_pretrained_download_progress(enable=True):
    """ Set download progress for pretrained weights on/off (globally). """
    global _DOWNLOAD_PROGRESS
    _DOWNLOAD_PROGRESS = enable


def set_pretrained_check_hash(enable=True):
    """ Set hash checking for pretrained weights on/off (globally). """
    global _CHECK_HASH
    _CHECK_HASH = enable


def load_custom_pretrained(
        model: nn.Module,
        pretrained_cfg: Optional[Dict] = None,
        load_fn: Optional[Callable] = None,
):
    r"""Loads a custom (read non .pth) weight file

    Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls
    a passed in custom load fun, or the `load_pretrained` model member fn.

    If the object is already present in `model_dir`, it's deserialized and returned.
    The default value of `model_dir` is ``<hub_dir>/checkpoints`` where
    `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`.

    Args:
        model: The instantiated model to load weights into
        pretrained_cfg (dict): Default pretrained model cfg
        load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named
            'laod_pretrained' on the model will be called if it exists
    """
    pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
    if not pretrained_cfg:
        _logger.warning("Invalid pretrained config, cannot load weights.")
        return

    load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
    if not load_from:
        _logger.warning("No pretrained weights exist for this model. Using random initialization.")
        return
    if load_from == 'hf-hub':
        _logger.warning("Hugging Face hub not currently supported for custom load pretrained models.")
    elif load_from == 'url':
        pretrained_loc = download_cached_file(
            pretrained_loc,
            check_hash=_CHECK_HASH,
            progress=_DOWNLOAD_PROGRESS,
        )

    if load_fn is not None:
        load_fn(model, pretrained_loc)
    elif hasattr(model, 'load_pretrained'):
        model.load_pretrained(pretrained_loc)
    else:
        _logger.warning("Valid function to load pretrained weights is not available, using random initialization.")


def load_pretrained(
        model: nn.Module,
        pretrained_cfg: Optional[Dict] = None,
        num_classes: int = 1000,
        in_chans: int = 3,
        filter_fn: Optional[Callable] = None,
        strict: bool = True,
):
    """ Load pretrained checkpoint

    Args:
        model (nn.Module) : PyTorch model module
        pretrained_cfg (Optional[Dict]): configuration for pretrained weights / target dataset
        num_classes (int): num_classes for target model
        in_chans (int): in_chans for target model
        filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args)
        strict (bool): strict load of checkpoint

    """
    pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
    if not pretrained_cfg:
        raise RuntimeError("Invalid pretrained config, cannot load weights. Use `pretrained=False` for random init.")

    load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
    if load_from == 'state_dict':
        _logger.info(f'Loading pretrained weights from state dict')
        state_dict = pretrained_loc  # pretrained_loc is the actual state dict for this override
    elif load_from == 'file':
        _logger.info(f'Loading pretrained weights from file ({pretrained_loc})')
        if pretrained_cfg.get('custom_load', False):
            model.load_pretrained(pretrained_loc)
            return
        else:
            state_dict = load_state_dict(pretrained_loc)
    elif load_from == 'url':
        _logger.info(f'Loading pretrained weights from url ({pretrained_loc})')
        if pretrained_cfg.get('custom_load', False):
            pretrained_loc = download_cached_file(
                pretrained_loc,
                progress=_DOWNLOAD_PROGRESS,
                check_hash=_CHECK_HASH,
            )
            model.load_pretrained(pretrained_loc)
            return
        else:
            state_dict = load_state_dict_from_url(
                pretrained_loc,
                map_location='cpu',
                progress=_DOWNLOAD_PROGRESS,
                check_hash=_CHECK_HASH,
            )
    elif load_from == 'hf-hub':
        _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
        if isinstance(pretrained_loc, (list, tuple)):
            state_dict = load_state_dict_from_hf(*pretrained_loc)
        else:
            state_dict = load_state_dict_from_hf(pretrained_loc)
    else:
        model_name = pretrained_cfg.get('architecture', 'this model')
        raise RuntimeError(f"No pretrained weights exist for {model_name}. Use `pretrained=False` for random init.")

    if filter_fn is not None:
        try:
            state_dict = filter_fn(state_dict, model)
        except TypeError as e:
            # for backwards compat with filter fn that take one arg
            state_dict = filter_fn(state_dict)

    input_convs = pretrained_cfg.get('first_conv', None)
    if input_convs is not None and in_chans != 3:
        if isinstance(input_convs, str):
            input_convs = (input_convs,)
        for input_conv_name in input_convs:
            weight_name = input_conv_name + '.weight'
            try:
                state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name])
                _logger.info(
                    f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)')
            except NotImplementedError as e:
                del state_dict[weight_name]
                strict = False
                _logger.warning(
                    f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')

    classifiers = pretrained_cfg.get('classifier', None)
    label_offset = pretrained_cfg.get('label_offset', 0)
    if classifiers is not None:
        if isinstance(classifiers, str):
            classifiers = (classifiers,)
        if num_classes != pretrained_cfg['num_classes']:
            for classifier_name in classifiers:
                # completely discard fully connected if model num_classes doesn't match pretrained weights
                state_dict.pop(classifier_name + '.weight', None)
                state_dict.pop(classifier_name + '.bias', None)
            strict = False
        elif label_offset > 0:
            for classifier_name in classifiers:
                # special case for pretrained weights with an extra background class in pretrained weights
                classifier_weight = state_dict[classifier_name + '.weight']
                state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
                classifier_bias = state_dict[classifier_name + '.bias']
                state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]

    load_result = model.load_state_dict(state_dict, strict=strict)
    if load_result.missing_keys:
        _logger.info(
            f'Missing keys ({", ".join(load_result.missing_keys)}) discovered while loading pretrained weights.'
            f' This is expected if model is being adapted.')
    if load_result.unexpected_keys:
        _logger.warning(
            f'Unexpected keys ({", ".join(load_result.unexpected_keys)}) found while loading pretrained weights.'
            f' This may be expected if model is being adapted.')


def pretrained_cfg_for_features(pretrained_cfg):
    pretrained_cfg = deepcopy(pretrained_cfg)
    # remove default pretrained cfg fields that don't have much relevance for feature backbone
    to_remove = ('num_classes', 'classifier', 'global_pool')  # add default final pool size?
    for tr in to_remove:
        pretrained_cfg.pop(tr, None)
    return pretrained_cfg


def _filter_kwargs(kwargs, names):
    if not kwargs or not names:
        return
    for n in names:
        kwargs.pop(n, None)


def _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter):
    """ Update the default_cfg and kwargs before passing to model

    Args:
        pretrained_cfg: input pretrained cfg (updated in-place)
        kwargs: keyword args passed to model build fn (updated in-place)
        kwargs_filter: keyword arg keys that must be removed before model __init__
    """
    # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
    default_kwarg_names = ('num_classes', 'global_pool', 'in_chans')
    if pretrained_cfg.get('fixed_input_size', False):
        # if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size
        default_kwarg_names += ('img_size',)

    for n in default_kwarg_names:
        # for legacy reasons, model __init__args uses img_size + in_chans as separate args while
        # pretrained_cfg has one input_size=(C, H ,W) entry
        if n == 'img_size':
            input_size = pretrained_cfg.get('input_size', None)
            if input_size is not None:
                assert len(input_size) == 3
                kwargs.setdefault(n, input_size[-2:])
        elif n == 'in_chans':
            input_size = pretrained_cfg.get('input_size', None)
            if input_size is not None:
                assert len(input_size) == 3
                kwargs.setdefault(n, input_size[0])
        elif n == 'num_classes':
            default_val = pretrained_cfg.get(n, None)
            # if default is < 0, don't pass through to model
            if default_val is not None and default_val >= 0:
                kwargs.setdefault(n, pretrained_cfg[n])
        else:
            default_val = pretrained_cfg.get(n, None)
            if default_val is not None:
                kwargs.setdefault(n, pretrained_cfg[n])

    # Filter keyword args for task specific model variants (some 'features only' models, etc.)
    _filter_kwargs(kwargs, names=kwargs_filter)


def resolve_pretrained_cfg(
        variant: str,
        pretrained_cfg=None,
        pretrained_cfg_overlay=None,
) -> PretrainedCfg:
    model_with_tag = variant
    pretrained_tag = None
    if pretrained_cfg:
        if isinstance(pretrained_cfg, dict):
            # pretrained_cfg dict passed as arg, validate by converting to PretrainedCfg
            pretrained_cfg = PretrainedCfg(**pretrained_cfg)
        elif isinstance(pretrained_cfg, str):
            pretrained_tag = pretrained_cfg
            pretrained_cfg = None

    # fallback to looking up pretrained cfg in model registry by variant identifier
    if not pretrained_cfg:
        if pretrained_tag:
            model_with_tag = '.'.join([variant, pretrained_tag])
        pretrained_cfg = get_pretrained_cfg(model_with_tag)

    if not pretrained_cfg:
        _logger.warning(
            f"No pretrained configuration specified for {model_with_tag} model. Using a default."
            f" Please add a config to the model pretrained_cfg registry or pass explicitly.")
        pretrained_cfg = PretrainedCfg()  # instance with defaults

    pretrained_cfg_overlay = pretrained_cfg_overlay or {}
    if not pretrained_cfg.architecture:
        pretrained_cfg_overlay.setdefault('architecture', variant)
    pretrained_cfg = dataclasses.replace(pretrained_cfg, **pretrained_cfg_overlay)

    return pretrained_cfg


def build_model_with_cfg(
        model_cls: Callable,
        variant: str,
        pretrained: bool,
        pretrained_cfg: Optional[Dict] = None,
        pretrained_cfg_overlay: Optional[Dict] = None,
        model_cfg: Optional[Any] = None,
        feature_cfg: Optional[Dict] = None,
        pretrained_strict: bool = True,
        pretrained_filter_fn: Optional[Callable] = None,
        kwargs_filter: Optional[Tuple[str]] = None,
        **kwargs,
):
    """ Build model with specified default_cfg and optional model_cfg

    This helper fn aids in the construction of a model including:
      * handling default_cfg and associated pretrained weight loading
      * passing through optional model_cfg for models with config based arch spec
      * features_only model adaptation
      * pruning config / model adaptation

    Args:
        model_cls (nn.Module): model class
        variant (str): model variant name
        pretrained (bool): load pretrained weights
        pretrained_cfg (dict): model's pretrained weight/task config
        model_cfg (Optional[Dict]): model's architecture config
        feature_cfg (Optional[Dict]: feature extraction adapter config
        pretrained_strict (bool): load pretrained weights strictly
        pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights
        kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model
        **kwargs: model args passed through to model __init__
    """
    pruned = kwargs.pop('pruned', False)
    features = False
    feature_cfg = feature_cfg or {}

    # resolve and update model pretrained config and model kwargs
    pretrained_cfg = resolve_pretrained_cfg(
        variant,
        pretrained_cfg=pretrained_cfg,
        pretrained_cfg_overlay=pretrained_cfg_overlay
    )

    # FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model
    pretrained_cfg = pretrained_cfg.to_dict()

    _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter)

    # Setup for feature extraction wrapper done at end of this fn
    if kwargs.pop('features_only', False):
        features = True
        feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
        if 'out_indices' in kwargs:
            feature_cfg['out_indices'] = kwargs.pop('out_indices')

    # Instantiate the model
    if model_cfg is None:
        model = model_cls(**kwargs)
    else:
        model = model_cls(cfg=model_cfg, **kwargs)
    model.pretrained_cfg = pretrained_cfg
    model.default_cfg = model.pretrained_cfg  # alias for backwards compat

    if pruned:
        model = adapt_model_from_file(model, variant)

    # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
    num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
    if pretrained:
        load_pretrained(
            model,
            pretrained_cfg=pretrained_cfg,
            num_classes=num_classes_pretrained,
            in_chans=kwargs.get('in_chans', 3),
            filter_fn=pretrained_filter_fn,
            strict=pretrained_strict,
        )

    # Wrap the model in a feature extraction module if enabled
    if features:
        feature_cls = FeatureListNet
        output_fmt = getattr(model, 'output_fmt', None)
        if output_fmt is not None:
            feature_cfg.setdefault('output_fmt', output_fmt)
        if 'feature_cls' in feature_cfg:
            feature_cls = feature_cfg.pop('feature_cls')
            if isinstance(feature_cls, str):
                feature_cls = feature_cls.lower()
                if 'hook' in feature_cls:
                    feature_cls = FeatureHookNet
                elif feature_cls == 'fx':
                    feature_cls = FeatureGraphNet
                else:
                    assert False, f'Unknown feature class {feature_cls}'
        model = feature_cls(model, **feature_cfg)
        model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg)  # add back pretrained cfg
        model.default_cfg = model.pretrained_cfg  # alias for rename backwards compat (default_cfg -> pretrained_cfg)

    return model

In [28]:
""" Model Registry
Hacked together by / Copyright 2020 Ross Wightman
"""

import fnmatch
import re
import sys
import warnings
from collections import defaultdict, deque
from copy import deepcopy
from dataclasses import replace
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Sequence, Union, Tuple

# from ._pretrained import PretrainedCfg, DefaultCfg

__all__ = [
    'split_model_name_tag', 'get_arch_name', 'register_model', 'generate_default_cfgs',
    'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
    'get_pretrained_cfg_value', 'is_model_pretrained'
]

_module_to_models: Dict[str, Set[str]] = defaultdict(set)  # dict of sets to check membership of model in module
_model_to_module: Dict[str, str] = {}  # mapping of model names to module names
_model_entrypoints: Dict[str, Callable[..., Any]] = {}  # mapping of model names to architecture entrypoint fns
_model_has_pretrained: Set[str] = set()  # set of model names that have pretrained weight url present
_model_default_cfgs: Dict[str, PretrainedCfg] = {}  # central repo for model arch -> default cfg objects
_model_pretrained_cfgs: Dict[str, PretrainedCfg] = {}  # central repo for model arch.tag -> pretrained cfgs
_model_with_tags: Dict[str, List[str]] = defaultdict(list)  # shortcut to map each model arch to all model + tag names
_module_to_deprecated_models: Dict[str, Dict[str, Optional[str]]] = defaultdict(dict)
_deprecated_models: Dict[str, Optional[str]] = {}


def split_model_name_tag(model_name: str, no_tag: str = '') -> Tuple[str, str]:
    model_name, *tag_list = model_name.split('.', 1)
    tag = tag_list[0] if tag_list else no_tag
    return model_name, tag


def get_arch_name(model_name: str) -> str:
    return split_model_name_tag(model_name)[0]


def generate_default_cfgs(cfgs: Dict[str, Union[Dict[str, Any], PretrainedCfg]]):
    out = defaultdict(DefaultCfg)
    default_set = set()  # no tag and tags ending with * are prioritized as default

    for k, v in cfgs.items():
        if isinstance(v, dict):
            v = PretrainedCfg(**v)
        has_weights = v.has_weights

        model, tag = split_model_name_tag(k)
        is_default_set = model in default_set
        priority = (has_weights and not tag) or (tag.endswith('*') and not is_default_set)
        tag = tag.strip('*')

        default_cfg = out[model]

        if priority:
            default_cfg.tags.appendleft(tag)
            default_set.add(model)
        elif has_weights and not default_cfg.is_pretrained:
            default_cfg.tags.appendleft(tag)
        else:
            default_cfg.tags.append(tag)

        if has_weights:
            default_cfg.is_pretrained = True

        default_cfg.cfgs[tag] = v

    return out


def register_model(fn: Callable[..., Any]) -> Callable[..., Any]:
    # lookup containing module
    mod = sys.modules[fn.__module__]
    module_name_split = fn.__module__.split('.')
    module_name = module_name_split[-1] if len(module_name_split) else ''

    # add model to __all__ in module
    model_name = fn.__name__
    if hasattr(mod, '__all__'):
        mod.__all__.append(model_name)
    else:
        mod.__all__ = [model_name]  # type: ignore

    # add entries to registry dict/sets
    if model_name in _model_entrypoints:
        warnings.warn(
            f'Overwriting {model_name} in registry with {fn.__module__}.{model_name}. This is because the name being '
            'registered conflicts with an existing name. Please check if this is not expected.',
            stacklevel=2,
        )
    _model_entrypoints[model_name] = fn
    _model_to_module[model_name] = module_name
    _module_to_models[module_name].add(model_name)
    if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
        # this will catch all models that have entrypoint matching cfg key, but miss any aliasing
        # entrypoints or non-matching combos
        default_cfg = mod.default_cfgs[model_name]
        if not isinstance(default_cfg, DefaultCfg):
            # new style default cfg dataclass w/ multiple entries per model-arch
            assert isinstance(default_cfg, dict)
            # old style cfg dict per model-arch
            pretrained_cfg = PretrainedCfg(**default_cfg)
            default_cfg = DefaultCfg(tags=deque(['']), cfgs={'': pretrained_cfg})

        for tag_idx, tag in enumerate(default_cfg.tags):
            is_default = tag_idx == 0
            pretrained_cfg = default_cfg.cfgs[tag]
            model_name_tag = '.'.join([model_name, tag]) if tag else model_name
            replace_items = dict(architecture=model_name, tag=tag if tag else None)
            if pretrained_cfg.hf_hub_id and pretrained_cfg.hf_hub_id == 'timm/':
                # auto-complete hub name w/ architecture.tag
                replace_items['hf_hub_id'] = pretrained_cfg.hf_hub_id + model_name_tag
            pretrained_cfg = replace(pretrained_cfg, **replace_items)

            if is_default:
                _model_pretrained_cfgs[model_name] = pretrained_cfg
                if pretrained_cfg.has_weights:
                    # add tagless entry if it's default and has weights
                    _model_has_pretrained.add(model_name)

            if tag:
                _model_pretrained_cfgs[model_name_tag] = pretrained_cfg
                if pretrained_cfg.has_weights:
                    # add model w/ tag if tag is valid
                    _model_has_pretrained.add(model_name_tag)
                _model_with_tags[model_name].append(model_name_tag)
            else:
                _model_with_tags[model_name].append(model_name)  # has empty tag (to slowly remove these instances)

        _model_default_cfgs[model_name] = default_cfg

    return fn


def _deprecated_model_shim(deprecated_name: str, current_fn: Callable = None, current_tag: str = ''):
    def _fn(pretrained=False, **kwargs):
        assert current_fn is not None,  f'Model {deprecated_name} has been removed with no replacement.'
        current_name = '.'.join([current_fn.__name__, current_tag]) if current_tag else current_fn.__name__
        warnings.warn(f'Mapping deprecated model name {deprecated_name} to current {current_name}.', stacklevel=2)
        pretrained_cfg = kwargs.pop('pretrained_cfg', None)
        return current_fn(pretrained=pretrained, pretrained_cfg=pretrained_cfg or current_tag, **kwargs)
    return _fn


def register_model_deprecations(module_name: str, deprecation_map: Dict[str, Optional[str]]):
    mod = sys.modules[module_name]
    module_name_split = module_name.split('.')
    module_name = module_name_split[-1] if len(module_name_split) else ''

    for deprecated, current in deprecation_map.items():
        if hasattr(mod, '__all__'):
            mod.__all__.append(deprecated)
        current_fn = None
        current_tag = ''
        if current:
            current_name, current_tag = split_model_name_tag(current)
            current_fn = getattr(mod, current_name)
        deprecated_entrypoint_fn = _deprecated_model_shim(deprecated, current_fn, current_tag)
        setattr(mod, deprecated, deprecated_entrypoint_fn)
        _model_entrypoints[deprecated] = deprecated_entrypoint_fn
        _model_to_module[deprecated] = module_name
        _module_to_models[module_name].add(deprecated)
        _deprecated_models[deprecated] = current
        _module_to_deprecated_models[module_name][deprecated] = current


def _natural_key(string_: str) -> List[Union[int, str]]:
    """See https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/"""
    return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]


def _expand_filter(filter: str):
    """ expand a 'base_filter' to 'base_filter.*' if no tag portion"""
    filter_base, filter_tag = split_model_name_tag(filter)
    if not filter_tag:
        return ['.'.join([filter_base, '*']), filter]
    else:
        return [filter]


def list_models(
        filter: Union[str, List[str]] = '',
        module: str = '',
        pretrained: bool = False,
        exclude_filters: Union[str, List[str]] = '',
        name_matches_cfg: bool = False,
        include_tags: Optional[bool] = None,
) -> List[str]:
    """ Return list of available model names, sorted alphabetically

    Args:
        filter - Wildcard filter string that works with fnmatch
        module - Limit model selection to a specific submodule (ie 'vision_transformer')
        pretrained - Include only models with valid pretrained weights if True
        exclude_filters - Wildcard filters to exclude models after including them with filter
        name_matches_cfg - Include only models w/ model_name matching default_cfg name (excludes some aliases)
        include_tags - Include pretrained tags in model names (model.tag). If None, defaults
            set to True when pretrained=True else False (default: None)

    Returns:
        models - The sorted list of models

    Example:
        model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
        model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
    """
    if filter:
        include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
    else:
        include_filters = []

    if include_tags is None:
        # FIXME should this be default behaviour? or default to include_tags=True?
        include_tags = pretrained

    all_models: Set[str] = _module_to_models[module] if module else set(_model_entrypoints.keys())
    all_models = all_models - _deprecated_models.keys()  # remove deprecated models from listings

    if include_tags:
        # expand model names to include names w/ pretrained tags
        models_with_tags: Set[str] = set()
        for m in all_models:
            models_with_tags.update(_model_with_tags[m])
        all_models = models_with_tags
        # expand include and exclude filters to include a '.*' for proper match if no tags in filter
        include_filters = [ef for f in include_filters for ef in _expand_filter(f)]
        exclude_filters = [ef for f in exclude_filters for ef in _expand_filter(f)]

    if include_filters:
        models: Set[str] = set()
        for f in include_filters:
            include_models = fnmatch.filter(all_models, f)  # include these models
            if len(include_models):
                models = models.union(include_models)
    else:
        models = all_models

    if exclude_filters:
        if not isinstance(exclude_filters, (tuple, list)):
            exclude_filters = [exclude_filters]
        for xf in exclude_filters:
            exclude_models = fnmatch.filter(models, xf)  # exclude these models
            if len(exclude_models):
                models = models.difference(exclude_models)

    if pretrained:
        models = _model_has_pretrained.intersection(models)

    if name_matches_cfg:
        models = set(_model_pretrained_cfgs).intersection(models)

    return sorted(models, key=_natural_key)


def list_pretrained(
        filter: Union[str, List[str]] = '',
        exclude_filters: str = '',
) -> List[str]:
    return list_models(
        filter=filter,
        pretrained=True,
        exclude_filters=exclude_filters,
        include_tags=True,
    )


def get_deprecated_models(module: str = '') -> Dict[str, str]:
    all_deprecated = _module_to_deprecated_models[module] if module else _deprecated_models
    return deepcopy(all_deprecated)


def is_model(model_name: str) -> bool:
    """ Check if a model name exists
    """
    arch_name = get_arch_name(model_name)
    return arch_name in _model_entrypoints


def model_entrypoint(model_name: str, module_filter: Optional[str] = None) -> Callable[..., Any]:
    """Fetch a model entrypoint for specified model name
    """
    arch_name = get_arch_name(model_name)
    if module_filter and arch_name not in _module_to_models.get(module_filter, {}):
        raise RuntimeError(f'Model ({model_name} not found in module {module_filter}.')
    return _model_entrypoints[arch_name]


def list_modules() -> List[str]:
    """ Return list of module names that contain models / model entrypoints
    """
    modules = _module_to_models.keys()
    return sorted(modules)


def is_model_in_modules(
        model_name: str, module_names: Union[Tuple[str, ...], List[str], Set[str]]
) -> bool:
    """Check if a model exists within a subset of modules

    Args:
        model_name - name of model to check
        module_names - names of modules to search in
    """
    arch_name = get_arch_name(model_name)
    assert isinstance(module_names, (tuple, list, set))
    return any(arch_name in _module_to_models[n] for n in module_names)


def is_model_pretrained(model_name: str) -> bool:
    return model_name in _model_has_pretrained


def get_pretrained_cfg(model_name: str, allow_unregistered: bool = True) -> Optional[PretrainedCfg]:
    if model_name in _model_pretrained_cfgs:
        return deepcopy(_model_pretrained_cfgs[model_name])
    arch_name, tag = split_model_name_tag(model_name)
    if arch_name in _model_default_cfgs:
        # if model arch exists, but the tag is wrong, error out
        raise RuntimeError(f'Invalid pretrained tag ({tag}) for {arch_name}.')
    if allow_unregistered:
        # if model arch doesn't exist, it has no pretrained_cfg registered, allow a default to be created
        return None
    raise RuntimeError(f'Model architecture ({arch_name}) has no pretrained cfg registered.')


def get_pretrained_cfg_value(model_name: str, cfg_key: str) -> Optional[Any]:
    """ Get a specific model default_cfg value by key. None if key doesn't exist.
    """
    cfg = get_pretrained_cfg(model_name, allow_unregistered=False)
    return getattr(cfg, cfg_key, None)

In [29]:
""" DaViT: Dual Attention Vision Transformers

As described in https://arxiv.org/abs/2204.03645

Input size invariant transformer architecture that combines channel and spacial
attention in each block. The attention mechanisms used are linear in complexity.

DaViT model defs and weights adapted from https://github.com/dingmyu/davit, original copyright below

"""
# Copyright (c) 2022 Mingyu Ding
# All rights reserved.
# This source code is licensed under the MIT license
from functools import partial
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer, use_fused_attn
from timm.layers import NormMlpClassifierHead, ClassifierHead, RelPosBias, get_attn
# from ._builder import build_model_with_cfg
# from ._features_fx import register_notrace_function
# from ._manipulate import checkpoint_seq
# from ._registry import generate_default_cfgs, register_model

__all__ = ['DaVit']


class ConvPosEnc(nn.Module):
    def __init__(self, dim: int, k: int = 3, act: bool = False):
        super(ConvPosEnc, self).__init__()

        self.proj = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim)
        self.act = nn.GELU() if act else nn.Identity()

    def forward(self, x: Tensor):
        feat = self.proj(x)
        x = x + self.act(feat)
        return x


class Stem(nn.Module):
    """ Size-agnostic implementation of 2D image to patch embedding,
        allowing input size to be adjusted during model forward operation
    """

    def __init__(
            self,
            in_chs=3,
            out_chs=96,
            stride=4,
            norm_layer=LayerNorm2d,
    ):
        super().__init__()
        stride = to_2tuple(stride)
        self.stride = stride
        self.in_chs = in_chs
        self.out_chs = out_chs
        assert stride[0] == 4  # only setup for stride==4
        self.conv = nn.Conv2d(
            in_chs,
            out_chs,
            kernel_size=7,
            stride=stride,
            padding=3,
        )
        self.norm = norm_layer(out_chs)

    def forward(self, x: Tensor):
        B, C, H, W = x.shape
        x = F.pad(x, (0, (self.stride[1] - W % self.stride[1]) % self.stride[1]))
        x = F.pad(x, (0, 0, 0, (self.stride[0] - H % self.stride[0]) % self.stride[0]))
        x = self.conv(x)
        x = self.norm(x)
        return x


class Downsample(nn.Module):
    def __init__(
            self,
            in_chs,
            out_chs,
            norm_layer=LayerNorm2d,
    ):
        super().__init__()
        self.in_chs = in_chs
        self.out_chs = out_chs

        self.norm = norm_layer(in_chs)
        self.conv = nn.Conv2d(
            in_chs,
            out_chs,
            kernel_size=2,
            stride=2,
            padding=0,
        )

    def forward(self, x: Tensor):
        B, C, H, W = x.shape
        x = self.norm(x)
        x = F.pad(x, (0, (2 - W % 2) % 2))
        x = F.pad(x, (0, 0, 0, (2 - H % 2) % 2))
        x = self.conv(x)
        return x


class ChannelAttention(nn.Module):

    def __init__(self, dim, num_heads=8, qkv_bias=False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor):
        B, N, C = x.shape

        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        k = k * self.scale
        attention = k.transpose(-1, -2) @ v
        attention = attention.softmax(dim=-1)
        x = (attention @ q.transpose(-1, -2)).transpose(-1, -2)
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x


class ChannelBlock(nn.Module):

    def __init__(
            self,
            dim,
            num_heads,
            mlp_ratio=4.,
            qkv_bias=False,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
            ffn=True,
            cpe_act=False,
    ):
        super().__init__()

        self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
        self.ffn = ffn
        self.norm1 = norm_layer(dim)
        self.attn = ChannelAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)

        if self.ffn:
            self.norm2 = norm_layer(dim)
            self.mlp = Mlp(
                in_features=dim,
                hidden_features=int(dim * mlp_ratio),
                act_layer=act_layer,
            )
            self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        else:
            self.norm2 = None
            self.mlp = None
            self.drop_path2 = None

    def forward(self, x: Tensor):
        B, C, H, W = x.shape

        x = self.cpe1(x).flatten(2).transpose(1, 2)

        cur = self.norm1(x)
        cur = self.attn(cur)
        x = x + self.drop_path1(cur)

        x = self.cpe2(x.transpose(1, 2).view(B, C, H, W))

        if self.mlp is not None:
            x = x.flatten(2).transpose(1, 2)
            x = x + self.drop_path2(self.mlp(self.norm2(x)))
            x = x.transpose(1, 2).view(B, C, H, W)

        return x


def window_partition(x: Tensor, window_size: Tuple[int, int]):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
    return windows


@register_notrace_function  # reason: int argument is a Proxy
def window_reverse(windows: Tensor, window_size: Tuple[int, int], H: int, W: int):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    C = windows.shape[-1]
    x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
    return x


class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(self, dim, window_size, num_heads, qkv_bias=True):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()
        self.rel_pos = RelPosBias(window_size=window_size, num_heads=num_heads)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
#         self.qkv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        
        self.proj = nn.Linear(dim, dim)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x: Tensor, q_global: Optional[torch.Tensor] = None):
        B_, N, C = x.shape
#         print(x.shape, "x.shape")
#         qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
#         q, k, v = qkv.unbind(0)

#         if self.fused_attn:
#             x = F.scaled_dot_product_attention(q, k, v)
#         else:
#             q = q * self.scale
#             attn = (q @ k.transpose(-2, -1))
#             attn = self.softmax(attn)
#             x = attn @ v
#         print(q.shape, k.shape, v.shape, x.shape, "q.shape, k.shape, v.shape, x.shape")    
#         q_global = torch.rand(1, 7, 7, 128 )
#         print(x.shape, global_query.shape)
#         q_global = global_query
#         print(x.shape, q_global.shape, global_query.shape)
        kv = self.qkv(x)
#         print(kv.shape, "kv.shape")
        kv = kv.reshape(B_, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        _, k, v = kv.unbind(0)
#         print(k.shape, "k shape", v.shape, "v shape")
        q = q_global.repeat(B_ // q_global.shape[0], 1, 1, 1)
#         print(q.shape, "q shape", B_, N, self.num_heads, self.head_dim)
        q = q.reshape(B_, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
#         print(q.shape, "q.shape after reshape")  
        
        q = q * self.scale

        attn = q @ k.transpose(-2, -1).contiguous()  # NOTE contiguous() fixes an odd jit bug in PyTorch 2.0
        attn = self.rel_pos(attn)
        attn = attn.softmax(dim=-1)
#         attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
#         x = self.proj_drop(x)
        
#         x = x.transpose(1, 2).reshape(B_, N, C)
#         x = self.proj(x)
        return x

# global_query = torch.rand(1, 7, 7, 128 )
class MbConvBlock(nn.Module):
    """ A depthwise separable / fused mbconv style residual block with SE, `no norm.
    """
    def __init__(
            self,
            in_chs,
            out_chs=None,
            expand_ratio=1.0,
            attn_layer='se',
            bias=False,
            act_layer=nn.GELU,
    ):
        super().__init__()
        attn_kwargs = dict(act_layer=act_layer)
        if isinstance(attn_layer, str) and attn_layer == 'se' or attn_layer == 'eca':
            attn_kwargs['rd_ratio'] = 0.25
            attn_kwargs['bias'] = False
        attn_layer = get_attn(attn_layer)
        out_chs = out_chs or in_chs
        mid_chs = int(expand_ratio * in_chs)
#         print(in_chs, out_chs, mid_chs, "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
        self.conv_dw = nn.Conv2d(in_chs, mid_chs, 3, 1, 1, groups=in_chs, bias=bias)
        self.act = act_layer()
        self.se = attn_layer(mid_chs, **attn_kwargs)
        self.conv_pw = nn.Conv2d(mid_chs, out_chs, 1, 1, 0, bias=bias)

    def forward(self, x):
        shortcut = x
#         print(x.shape, "MBConvBlock.................................")
        x = self.conv_dw(x)
        
        x = self.act(x)
        x = self.se(x)
        x = self.conv_pw(x)
        x = x + shortcut
        return x

class FeatureBlock(nn.Module):
    def __init__(
            self,
            dim,
            levels=0,
            reduction='max',
            act_layer=nn.GELU,
    ):
        super().__init__()
        reductions = levels
        levels = max(1, levels)
        if reduction == 'avg':
            pool_fn = partial(nn.AvgPool2d, kernel_size=2)
        else:
            pool_fn = partial(nn.MaxPool2d, kernel_size=3, stride=2, padding=1)
        self.blocks = nn.Sequential()
        for i in range(levels):
            self.blocks.add_module(f'conv{i+1}', MbConvBlock(dim, act_layer=act_layer))
            if reductions:
                self.blocks.add_module(f'pool{i+1}', pool_fn())
                reductions -= 1

    def forward(self, x):
#         print(x.shape, "................................................................................")
        return self.blocks(x)

class SpatialBlock(nn.Module):
    r""" Windows Block.
    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(
            self,
            dim,
            num_heads,
            window_size=7,
            mlp_ratio=4.,
            qkv_bias=True,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
            ffn=True,
            cpe_act=False,
    ):
        super().__init__()
        self.dim = dim
        self.ffn = ffn
        self.num_heads = num_heads
        self.window_size = to_2tuple(window_size)
        self.mlp_ratio = mlp_ratio

        self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            self.window_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
        if self.ffn:
            self.norm2 = norm_layer(dim)
            mlp_hidden_dim = int(dim * mlp_ratio)
            self.mlp = Mlp(
                in_features=dim,
                hidden_features=mlp_hidden_dim,
                act_layer=act_layer,
            )
            self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        else:
            self.norm2 = None
            self.mlp = None
            self.drop_path1 = None

    def forward(self, x: Tensor, q_global: Optional[torch.Tensor] = None):
        B, C, H, W = x.shape
#         print(x.shape, "In spatial block")
        shortcut = self.cpe1(x).flatten(2).transpose(1, 2)

        x = self.norm1(shortcut)
        x = x.view(B, H, W, C)

        pad_l = pad_t = 0
        pad_r = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
        pad_b = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        x_windows = window_partition(x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C)

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, q_global)

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
        x = window_reverse(attn_windows, self.window_size, Hp, Wp)

        # if pad_r > 0 or pad_b > 0:
        x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)
        x = shortcut + self.drop_path1(x)

        x = self.cpe2(x.transpose(1, 2).view(B, C, H, W))

        if self.mlp is not None:
            x = x.flatten(2).transpose(1, 2)
            x = x + self.drop_path2(self.mlp(self.norm2(x)))
            x = x.transpose(1, 2).view(B, C, H, W)
#         print(x.shape, "In spatial block")
        return x


class DaVitStage(nn.Module):
    def __init__(
            self,
            in_chs,
            out_chs,
            dim,
            feat_size: Tuple[int, int],
            depth=1,
            downsample=True,
            attn_types=('spatial', 'channel'),
            num_heads=3,
            window_size=7,
            mlp_ratio=4,
            qkv_bias=True,
            drop_path_rates=(0, 0),
            norm_layer=LayerNorm2d,
            norm_layer_cl=nn.LayerNorm,
            ffn=True,
            cpe_act=False,
            global_norm= False
    ):
        super().__init__()

        self.grad_checkpointing = False

        # downsample embedding layer at the beginning of each stage
        if downsample:
            self.downsample = Downsample(in_chs, out_chs, norm_layer=norm_layer)
            dim = dim * 2
            feat_size = (feat_size[0] // 2, feat_size[1] // 2)
        else:
            self.downsample = nn.Identity()

        '''
         repeating alternating attention blocks in each stage
         default: (spatial -> channel) x depth
         
         potential opportunity to integrate with a more general version of ByobNet/ByoaNet
         since the logic is similar
        '''
        window_size = to_2tuple(window_size)
        self.feat_size = feat_size
#         print(feat_size, window_size, "feat_size, window_size,")
        feat_levels = int(math.log2(min(feat_size) / min(window_size)))
        self.global_norm = norm_layer_cl(dim) if global_norm else nn.Identity()
#         print(dim, feat_levels, "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        
        self.global_block = FeatureBlock(dim, feat_levels)
        stage_blocks = []
        self.spatialBlock = SpatialBlock(
                        dim=out_chs,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        drop_path=drop_path_rates[0],
                        norm_layer=norm_layer_cl,
                        ffn=ffn,
                        cpe_act=cpe_act,
                        window_size=window_size,
                    )
        self.channelBlock = ChannelBlock(
                        dim=out_chs,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        drop_path=drop_path_rates[0],
                        norm_layer=norm_layer_cl,
                        ffn=ffn,
                        cpe_act=cpe_act
                    )
        for block_idx in range(depth):
            dual_attention_block = []
            for attn_idx, attn_type in enumerate(attn_types):
                if attn_type == 'spatial':
                    dual_attention_block.append(SpatialBlock(
                        dim=out_chs,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        drop_path=drop_path_rates[block_idx],
                        norm_layer=norm_layer_cl,
                        ffn=ffn,
                        cpe_act=cpe_act,
                        window_size=window_size,
                    ))
                elif attn_type == 'channel':
                    dual_attention_block.append(ChannelBlock(
                        dim=out_chs,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        drop_path=drop_path_rates[block_idx],
                        norm_layer=norm_layer_cl,
                        ffn=ffn,
                        cpe_act=cpe_act
                    ))
            stage_blocks.append(nn.Sequential(*dual_attention_block))
        self.blocks = nn.Sequential(*stage_blocks)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    def forward(self, x: Tensor):
#         print(x.shape)
        x = self.downsample(x)
#         print(x.shape, "after downsample")
        global_query = self.global_block(x)

        # reshape NCHW --> NHWC for transformer blocks
#         x = x.permute(0, 2, 3, 1)
        global_query = self.global_norm(global_query.permute(0, 2, 3, 1))
#         print(global_query.shape, "global_query.shape")
        
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.spatialBlock(x, global_query)
            x = self.channelBlock(x)
#             x = self.blocks(x)
        return x


class DaVit(nn.Module):
    r""" DaViT
        A PyTorch implementation of `DaViT: Dual Attention Vision Transformers`  - https://arxiv.org/abs/2204.03645
        Supports arbitrary input sizes and pyramid feature extraction
        
    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        depths (tuple(int)): Number of blocks in each stage. Default: (1, 1, 3, 1)
        embed_dims (tuple(int)): Patch embedding dimension. Default: (96, 192, 384, 768)
        num_heads (tuple(int)): Number of attention heads in different layers. Default: (3, 6, 12, 24)
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
    """

    def __init__(
            self,
            in_chans=3,
            depths=(1, 1, 3, 1),
            embed_dim: int = 128,
            embed_dims=(96, 192, 384, 768),
            num_heads=(3, 6, 12, 24),
            img_size: Tuple[int, int] = 224,
            window_size=7,
            mlp_ratio=4,
            qkv_bias=True,
            norm_layer='layernorm2d',
            norm_layer_cl='layernorm',
            norm_eps=1e-5,
            attn_types=('spatial', 'channel'),
            ffn=True,
            cpe_act=False,
            drop_rate=0.,
            drop_path_rate=0.,
            num_classes=1000,
            global_pool='avg',
            head_norm_first=False,
    ):
        super().__init__()
        num_stages = len(embed_dims)
        assert num_stages == len(num_heads) == len(depths)
        norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
        norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps)
        self.num_classes = num_classes
        self.num_features = embed_dims[-1]
        self.drop_rate = drop_rate
        self.grad_checkpointing = False
        self.feature_info = []
        img_size = to_2tuple(img_size)
        feat_size = tuple(d // 4 for d in img_size)  # stem reduction by 4

        self.stem = Stem(in_chans, embed_dims[0], norm_layer=norm_layer)
        in_chs = embed_dims[0]

        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
        stages = []
        for stage_idx in range(num_stages):
            out_chs = embed_dims[stage_idx]
            
            stage_scale = 2 ** max(stage_idx - 1, 0)
            dim=embed_dim * stage_scale
#             print(feat_size[0], feat_size[1], embed_dim, stage_scale, feat_size[0] // stage_scale, feat_size[1] // stage_scale, ":::::::::::::::::::::::::::::::::::::::")
            
            stage = DaVitStage(
                in_chs,
                out_chs,
                dim=dim,
                feat_size=(feat_size[0] // stage_scale, feat_size[1] // stage_scale),
                depth=depths[stage_idx],
                downsample=stage_idx > 0,
                attn_types=attn_types,
                num_heads=num_heads[stage_idx],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_path_rates=dpr[stage_idx],
                norm_layer=norm_layer,
                norm_layer_cl=norm_layer_cl,
                ffn=ffn,
                cpe_act=cpe_act,
            )
            in_chs = out_chs
            stages.append(stage)
            self.feature_info += [dict(num_chs=out_chs, reduction=2, module=f'stages.{stage_idx}')]

        self.stages = nn.Sequential(*stages)

        # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
        # otherwise pool -> norm -> fc, the default DaViT order, similar to ConvNeXt
        # FIXME generalize this structure to ClassifierHead
        if head_norm_first:
            self.norm_pre = norm_layer(self.num_features)
            self.head = ClassifierHead(
                self.num_features,
                num_classes,
                pool_type=global_pool,
                drop_rate=self.drop_rate,
            )
        else:
            self.norm_pre = nn.Identity()
            self.head = NormMlpClassifierHead(
                self.num_features,
                num_classes,
                pool_type=global_pool,
                drop_rate=self.drop_rate,
                norm_layer=norm_layer,
            )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable
        for stage in self.stages:
            stage.set_grad_checkpointing(enable=enable)

    @torch.jit.ignore
    def get_classifier(self):
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool=None):
        self.head.reset(num_classes, global_pool=global_pool)

    def forward_features(self, x):
        x = self.stem(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.stages, x)
        else:
            x = self.stages(x)
        x = self.norm_pre(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        return self.head(x, pre_logits=True) if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x


def checkpoint_filter_fn(state_dict, model):
    """ Remap MSFT checkpoints -> timm """
    if 'head.fc.weight' in state_dict:
        return state_dict  # non-MSFT checkpoint

    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']

    import re
    out_dict = {}
    for k, v in state_dict.items():
        k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.downsample', k)
        k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k)
        k = k.replace('downsample.proj', 'downsample.conv')
        k = k.replace('stages.0.downsample', 'stem')
        k = k.replace('head.', 'head.fc.')
        k = k.replace('norms.', 'head.norm.')
        k = k.replace('cpe.0', 'cpe1')
        k = k.replace('cpe.1', 'cpe2')
        out_dict[k] = v
    return out_dict


def _create_davit(variant, pretrained=False, **kwargs):
    default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
    out_indices = kwargs.pop('out_indices', default_out_indices)

    model = build_model_with_cfg(
        DaVit,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
        **kwargs)

    return model


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.95, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.conv', 'classifier': 'head.fc',
        **kwargs
    }


# TODO contact authors to get larger pretrained models
default_cfgs = generate_default_cfgs({
    # official microsoft weights from https://github.com/dingmyu/davit
    'davit_tiny.msft_in1k': _cfg(
        hf_hub_id='timm/'),
    'davit_small.msft_in1k': _cfg(
        hf_hub_id='timm/'),
    'davit_base.msft_in1k': _cfg(
        hf_hub_id='timm/'),
    'davit_large': _cfg(),
    'davit_huge': _cfg(),
    'davit_giant': _cfg(),
})


@register_model
def davit_tiny(pretrained=False, **kwargs) -> DaVit:
    model_args = dict(depths=(1, 1, 3, 1), embed_dims=(96, 192, 384, 768), num_heads=(3, 6, 12, 24))
    return _create_davit('davit_tiny', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def davit_small(pretrained=False, **kwargs) -> DaVit:
    model_args = dict(depths=(1, 1, 9, 1), embed_dims=(96, 192, 384, 768), num_heads=(3, 6, 12, 24))
    return _create_davit('davit_small', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def davit_base(pretrained=False, **kwargs) -> DaVit:
    model_args = dict(depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024), num_heads=(4, 8, 16, 32))
    return _create_davit('davit_base', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def davit_large(pretrained=False, **kwargs) -> DaVit:
    model_args = dict(depths=(1, 1, 9, 1), embed_dims=(192, 384, 768, 1536), num_heads=(6, 12, 24, 48))
    return _create_davit('davit_large', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def davit_huge(pretrained=False, **kwargs) -> DaVit:
    model_args = dict(depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048), num_heads=(8, 16, 32, 64))
    return _create_davit('davit_huge', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def davit_giant(pretrained=False, **kwargs) -> DaVit:
    model_args = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072), num_heads=(12, 24, 48, 96))
    return _create_davit('davit_giant', pretrained=pretrained, **dict(model_args, **kwargs))

In [30]:
# with fine-tuning option
class GWA_DaViT(nn.Module):
    def __init__(self, num_classes, pretrained=True, fine_tune=False):
        super(GWA_DaViT, self).__init__()
        
        self.davit = davit_base(pretrained=pretrained, num_classes=num_classes, in_chans=1)
        
        if not fine_tune:
            # Freeze all layers except classifier layers
            for param in self.davit.parameters():
                param.requires_grad = False

            # Unfreeze the classifier layers
            for param in self.davit.head.parameters():
                param.requires_grad = True
        
        
    def forward(self, x):
        return self.davit(x)
    
    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GWA_DaViT(num_classes, fine_tune=False)
model.to(device)

# print(model)
print()

x = torch.randn(1, 1, IMAGE_SIZE[0], IMAGE_SIZE[1]).to(device)


output = model(x)
print("Model output's shape:", output.shape)
print(output) # logits    
display_params_flops(model)


Model output's shape: torch.Size([1, 15])
tensor([[ 0.2911,  0.0299,  0.1014,  0.2131, -0.0794, -0.2590, -0.0380,  0.0067,
          0.1378,  0.0958,  0.2344, -0.5217, -0.1359, -0.4020, -0.1893]],
       device='cuda:0', grad_fn=<AddmmBackward0>)
Number of parameters in millions: 122.78 M
Number of trainable parameters in millions: 0.02 M
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
FLOPs: 5.53G, Params: 38.57M


### Proposed 3 DaViT(BASE) + Unetr (v2)

In [31]:
!pip install monai



In [32]:
import timm
import torch.nn as nn
from monai.networks.blocks.dynunet_block import UnetOutBlock
from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock


class DaViT_UnetR_Modelv2(nn.Module):
    def __init__(self, num_classes, pretrained=True, fine_tune=False):
        super(DaViT_UnetR_Modelv2, self).__init__()
        
        self.davit = timm.create_model('davit_base.msft_in1k', pretrained=pretrained, features_only=True, in_chans=1)
        
        if not fine_tune:
            for param in self.davit.parameters():
                param.requires_grad = False
        
        
        spatial_dims = 2 
        in_channels = 1 # R,G,B
        feature_size = 128
        norm_name = "instance"
        hidden_size = 128
        res_block = True
        conv_block = False

        self.encoder1 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.encoder2 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size,
            out_channels=feature_size * 2,
            num_layer=2,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=1,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.encoder3 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size*2,
            out_channels=feature_size * 4,
            num_layer=1,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=1,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.encoder4 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size*4,
            out_channels=feature_size * 8,
            num_layer=0,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=1,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.decoder5 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size * 8,
            out_channels=feature_size * 8,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder4 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder3 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 4,
            out_channels=feature_size * 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder2 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 2,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        
        self.conv = nn.Sequential(
            nn.Conv2d(feature_size, 78, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(78, 50, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        
        # Classifier layer with convolution
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(2450, 1024),  # (DYNAMIC)Adjust the input size based on the output size of the convolutional layer
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, num_classes)
        )

    
    def forward(self, x_in):
        
        hidden_states_out = self.davit(x_in) # returns 4 lists
#         print("Length of hidden states from DaViT:", len(hidden_states_out))
#         for i in hidden_states_out:
#             print(i.shape)
#         print()


        enc1 = self.encoder1(x_in)
#         print("output from encoder1:", enc1.shape)
        
        x2 = hidden_states_out[0]
        enc2 = self.encoder2(x2)
#         print("output from encoder2:", enc2.shape)
        
        x3 = hidden_states_out[1]
        enc3 = self.encoder3(x3)
#         print("output from encoder3:", enc3.shape)
        
        
        x4 = hidden_states_out[2]
        enc4 = self.encoder4(x4)
#         print("output from encoder4:", enc4.shape)
        
#         print("All encoders OK\n")
        
        dec4 = hidden_states_out[3]
#         print("Input to decoder5:", dec4.shape, enc4.shape)
        dec3 = self.decoder5(dec4, enc4)
#         print("output from decoder5:", dec3.shape)
        
#         print("Input to decoder4:", dec3.shape, enc3.shape)
        dec2 = self.decoder4(dec3, enc3)
#         print("output from decoder4:", dec2.shape)
        
#         print("Input to decoder3:", dec2.shape, enc2.shape)
        dec1 = self.decoder3(dec2, enc2)
#         print("output from decoder3:", dec1.shape)
        
#         print("Input to decoder2:", dec1.shape, enc1.shape)
        out = self.decoder2(dec1, enc1) 
#         print("output from decoder2:", out.shape)
        

        
        conv_out = self.conv(out)
#         print(f"conv_out_shape:{conv_out.shape}")

        return self.classifier(conv_out)



device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DaViT_UnetR_Modelv2(num_classes, fine_tune=False)
model.to(device)

# print(model)
print()

x = torch.randn(1, 1, IMAGE_SIZE[0], IMAGE_SIZE[1]).to(device)


output = model(x)
print("Model output's shape:", output.shape)
print(output) # logits 
display_params_flops(model)


Model output's shape: torch.Size([1, 15])
tensor([[ 0.0553,  0.1761,  0.0403, -0.2128,  0.0069,  0.2055,  0.0741, -0.0338,
          0.1973,  0.2444,  0.2451,  0.0858,  0.0332,  0.0037, -0.0713]],
       device='cuda:0', grad_fn=<AddmmBackward0>)
Number of parameters in millions: 138.13 M
Number of trainable parameters in millions: 51.21 M
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_relu() for <class 'torch.nn.modules.activation.LeakyReLU'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.instancenorm.InstanceNorm2d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose2d'>.
[IN

### Proposed Model 4: Swin_Unetr

In [33]:
!pip install monai



In [34]:
import timm
import torch.nn as nn
from monai.networks.blocks.dynunet_block import UnetOutBlock
from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock


class Swin_Unetr(nn.Module):
    def __init__(self, num_classes, pretrained=True, fine_tune=False):
        super(Swin_Unetr, self).__init__()
        
        self.swin = timm.create_model(
            'swin_large_patch4_window7_224.ms_in22k', 
            pretrained=True,
            num_classes=num_classes,
            in_chans=1,
            features_only=True
        )
        
        
        if not fine_tune:
            for param in self.swin.parameters():
                param.requires_grad = False
        
        
        spatial_dims = 2 
        in_channels = 1 # R,G,B
        feature_size = 192
        norm_name = "instance"
        hidden_size = 192
        res_block = True
        conv_block = False

        self.encoder1 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.encoder2 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size,
            out_channels=feature_size * 2,
            num_layer=2,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=1,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.encoder3 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size*2,
            out_channels=feature_size * 4,
            num_layer=1,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=1,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.encoder4 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size*4,
            out_channels=feature_size * 8,
            num_layer=0,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=1,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.decoder5 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size * 8,
            out_channels=feature_size * 8,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder4 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder3 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 4,
            out_channels=feature_size * 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder2 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 2,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        
        self.conv = nn.Sequential(
            nn.Conv2d(feature_size, 78, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(78, 50, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        
        # Classifier layer with convolution
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(2450, 1024),  # (DYNAMIC)Adjust the input size based on the output size of the convolutional layer
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, num_classes)
        )

    
    def forward(self, x_in):
        
        hidden_states_out = self.swin(x_in) # returns 4 lists
#         print("Length of hidden states from Swin:", len(hidden_states_out))
#         for i in hidden_states_out:
#             print(i.shape)
#         print()
        
        ##we will permute each of the intermediate outputs from swin stages so that the channel comes first

        hidden_states_out = [t.permute(0, 3, 1, 2) for t in hidden_states_out]
        
        
        enc1 = self.encoder1(x_in)
#         print("output from encoder1:", enc1.shape)
        
        x2 = hidden_states_out[0]
        enc2 = self.encoder2(x2)
#         print("output from encoder2:", enc2.shape)
        
        x3 = hidden_states_out[1]
        enc3 = self.encoder3(x3)
#         print("output from encoder3:", enc3.shape)
        
        
        x4 = hidden_states_out[2]
        enc4 = self.encoder4(x4)
#         print("output from encoder4:", enc4.shape)
        
#         print("All encoders OK\n")
        
        dec4 = hidden_states_out[3]
#         print("Input to decoder5:", dec4.shape, enc4.shape)
        dec3 = self.decoder5(dec4, enc4)
#         print("output from decoder5:", dec3.shape)
        
#         print("Input to decoder4:", dec3.shape, enc3.shape)
        dec2 = self.decoder4(dec3, enc3)
#         print("output from decoder4:", dec2.shape)
        
#         print("Input to decoder3:", dec2.shape, enc2.shape)
        dec1 = self.decoder3(dec2, enc2)
#         print("output from decoder3:", dec1.shape)
        
#         print("Input to decoder2:", dec1.shape, enc1.shape)
        out = self.decoder2(dec1, enc1) 
#         print("output from decoder2:", out.shape)
        

        
        conv_out = self.conv(out)
#         print(f"conv_out_shape:{conv_out.shape}")

        return self.classifier(conv_out)



device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Swin_Unetr(num_classes, fine_tune=False)
model.to(device)

# print(model)
print()

x = torch.randn(1, 1, IMAGE_SIZE[0], IMAGE_SIZE[1]).to(device)


output = model(x)
print("Model output's shape:", output.shape)
print(output) # logits 
display_params_flops(model)


Model output's shape: torch.Size([1, 15])
tensor([[ 0.0679, -0.2037,  0.1124,  0.1723,  0.1938, -0.1984, -0.0166, -0.1033,
          0.1193,  0.0890,  0.0989, -0.1060, -0.0095, -0.1039,  0.0462]],
       device='cuda:0', grad_fn=<AddmmBackward0>)
Number of parameters in millions: 306.95 M
Number of trainable parameters in millions: 111.96 M
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_relu() for <class 'torch.nn.modules.activation.LeakyReLU'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.instancenorm.InstanceNo

### Proposed Model 5: CoAtNet Multiscale Pyramidal Attention 

In [35]:
!pip install monai



In [36]:
# _features.py

""" PyTorch Feature Extraction Helpers

A collection of classes, functions, modules to help extract features from models
and provide a common interface for describing them.

The return_layers, module re-writing idea inspired by torchvision IntermediateLayerGetter
https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py

Hacked together by / Copyright 2020 Ross Wightman
"""
from collections import OrderedDict, defaultdict
from copy import deepcopy
from functools import partial
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from timm.layers import Format


__all__ = [
    'FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet', 'FeatureGetterNet',
    'feature_take_indices'
]


def _take_indices(
        num_blocks: int,
        n: Optional[Union[int, List[int], Tuple[int]]],
) -> Tuple[Set[int], int]:
    if isinstance(n, int):
        assert n >= 0
        take_indices = {x for x in range(num_blocks - n, num_blocks)}
    else:
        take_indices = {num_blocks + idx if idx < 0 else idx for idx in n}
    return take_indices, max(take_indices)


def _take_indices_jit(
        num_blocks: int,
        n: Union[int, List[int], Tuple[int]],
) -> Tuple[List[int], int]:
    if isinstance(n, int):
        assert n >= 0
        take_indices = [num_blocks - n + i for i in range(n)]
    elif isinstance(n, tuple):
        # splitting this up is silly, but needed for torchscript type resolution of n
        take_indices = [num_blocks + idx if idx < 0 else idx for idx in n]
    else:
        take_indices = [num_blocks + idx if idx < 0 else idx for idx in n]
    return take_indices, max(take_indices)


def feature_take_indices(
        num_blocks: int,
        indices: Optional[Union[int, List[int], Tuple[int]]] = None,
) -> Tuple[List[int], int]:
    if indices is None:
        indices = num_blocks  # all blocks if None
    if torch.jit.is_scripting():
        return _take_indices_jit(num_blocks, indices)
    else:
        # NOTE non-jit returns Set[int] instead of List[int] but torchscript can't handle that anno
        return _take_indices(num_blocks, indices)


def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]:
    if isinstance(x, int):
        # if indices is an int, take last N features
        return tuple(range(-x, 0))
    return tuple(x)


OutIndicesT = Union[int, Tuple[int, ...]]


class FeatureInfo:

    def __init__(
            self,
            feature_info: List[Dict],
            out_indices: OutIndicesT,
    ):
        out_indices = _out_indices_as_tuple(out_indices)
        prev_reduction = 1
        for i, fi in enumerate(feature_info):
            # sanity check the mandatory fields, there may be additional fields depending on the model
            assert 'num_chs' in fi and fi['num_chs'] > 0
            assert 'reduction' in fi and fi['reduction'] >= prev_reduction
            prev_reduction = fi['reduction']
            assert 'module' in fi
            fi.setdefault('index', i)
        self.out_indices = out_indices
        self.info = feature_info

    def from_other(self, out_indices: OutIndicesT):
        out_indices = _out_indices_as_tuple(out_indices)
        return FeatureInfo(deepcopy(self.info), out_indices)

    def get(self, key: str, idx: Optional[Union[int, List[int]]] = None):
        """ Get value by key at specified index (indices)
        if idx == None, returns value for key at each output index
        if idx is an integer, return value for that feature module index (ignoring output indices)
        if idx is a list/tuple, return value for each module index (ignoring output indices)
        """
        if idx is None:
            return [self.info[i][key] for i in self.out_indices]
        if isinstance(idx, (tuple, list)):
            return [self.info[i][key] for i in idx]
        else:
            return self.info[idx][key]

    def get_dicts(self, keys: Optional[List[str]] = None, idx: Optional[Union[int, List[int]]] = None):
        """ return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
        """
        if idx is None:
            if keys is None:
                return [self.info[i] for i in self.out_indices]
            else:
                return [{k: self.info[i][k] for k in keys} for i in self.out_indices]
        if isinstance(idx, (tuple, list)):
            return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx]
        else:
            return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}

    def channels(self, idx: Optional[Union[int, List[int]]] = None):
        """ feature channels accessor
        """
        return self.get('num_chs', idx)

    def reduction(self, idx: Optional[Union[int, List[int]]] = None):
        """ feature reduction (output stride) accessor
        """
        return self.get('reduction', idx)

    def module_name(self, idx: Optional[Union[int, List[int]]] = None):
        """ feature module name accessor
        """
        return self.get('module', idx)

    def __getitem__(self, item):
        return self.info[item]

    def __len__(self):
        return len(self.info)


class FeatureHooks:
    """ Feature Hook Helper

    This module helps with the setup and extraction of hooks for extracting features from
    internal nodes in a model by node name.

    FIXME This works well in eager Python but needs redesign for torchscript.
    """

    def __init__(
            self,
            hooks: Sequence[str],
            named_modules: dict,
            out_map: Sequence[Union[int, str]] = None,
            default_hook_type: str = 'forward',
    ):
        # setup feature hooks
        self._feature_outputs = defaultdict(OrderedDict)
        modules = {k: v for k, v in named_modules}
        for i, h in enumerate(hooks):
            hook_name = h['module']
            m = modules[hook_name]
            hook_id = out_map[i] if out_map else hook_name
            hook_fn = partial(self._collect_output_hook, hook_id)
            hook_type = h.get('hook_type', default_hook_type)
            if hook_type == 'forward_pre':
                m.register_forward_pre_hook(hook_fn)
            elif hook_type == 'forward':
                m.register_forward_hook(hook_fn)
            else:
                assert False, "Unsupported hook type"

    def _collect_output_hook(self, hook_id, *args):
        x = args[-1]  # tensor we want is last argument, output for fwd, input for fwd_pre
        if isinstance(x, tuple):
            x = x[0]  # unwrap input tuple
        self._feature_outputs[x.device][hook_id] = x

    def get_output(self, device) -> Dict[str, torch.tensor]:
        output = self._feature_outputs[device]
        self._feature_outputs[device] = OrderedDict()  # clear after reading
        return output


def _module_list(module, flatten_sequential=False):
    # a yield/iter would be better for this but wouldn't be compatible with torchscript
    ml = []
    for name, module in module.named_children():
        if flatten_sequential and isinstance(module, nn.Sequential):
            # first level of Sequential containers is flattened into containing model
            for child_name, child_module in module.named_children():
                combined = [name, child_name]
                ml.append(('_'.join(combined), '.'.join(combined), child_module))
        else:
            ml.append((name, name, module))
    return ml


def _get_feature_info(net, out_indices: OutIndicesT):
    feature_info = getattr(net, 'feature_info')
    if isinstance(feature_info, FeatureInfo):
        return feature_info.from_other(out_indices)
    elif isinstance(feature_info, (list, tuple)):
        return FeatureInfo(net.feature_info, out_indices)
    else:
        assert False, "Provided feature_info is not valid"


def _get_return_layers(feature_info, out_map):
    module_names = feature_info.module_name()
    return_layers = {}
    for i, name in enumerate(module_names):
        return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
    return return_layers


class FeatureDictNet(nn.ModuleDict):
    """ Feature extractor with OrderedDict return

    Wrap a model and extract features as specified by the out indices, the network is
    partially re-built from contained modules.

    There is a strong assumption that the modules have been registered into the model in the same
    order as they are used. There should be no reuse of the same nn.Module more than once, including
    trivial modules like `self.relu = nn.ReLU`.

    Only submodules that are directly assigned to the model class (`model.feature1`) or at most
    one Sequential container deep (`model.features.1`, with flatten_sequent=True) can be captured.
    All Sequential containers that are directly assigned to the original model will have their
    modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
    """
    def __init__(
            self,
            model: nn.Module,
            out_indices: OutIndicesT = (0, 1, 2, 3, 4),
            out_map: Sequence[Union[int, str]] = None,
            output_fmt: str = 'NCHW',
            feature_concat: bool = False,
            flatten_sequential: bool = False,
    ):
        """
        Args:
            model: Model from which to extract features.
            out_indices: Output indices of the model features to extract.
            out_map: Return id mapping for each output index, otherwise str(index) is used.
            feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
                first element e.g. `x[0]`
            flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
        """
        super(FeatureDictNet, self).__init__()
        self.feature_info = _get_feature_info(model, out_indices)
        self.output_fmt = Format(output_fmt)
        self.concat = feature_concat
        self.grad_checkpointing = False
        self.return_layers = {}

        return_layers = _get_return_layers(self.feature_info, out_map)
        modules = _module_list(model, flatten_sequential=flatten_sequential)
        remaining = set(return_layers.keys())
        layers = OrderedDict()
        for new_name, old_name, module in modules:
            layers[new_name] = module
            if old_name in remaining:
                # return id has to be consistently str type for torchscript
                self.return_layers[new_name] = str(return_layers[old_name])
                remaining.remove(old_name)
            if not remaining:
                break
        assert not remaining and len(self.return_layers) == len(return_layers), \
            f'Return layers ({remaining}) are not present in model'
        self.update(layers)

    def set_grad_checkpointing(self, enable: bool = True):
        self.grad_checkpointing = enable

    def _collect(self, x) -> (Dict[str, torch.Tensor]):
        out = OrderedDict()
        for i, (name, module) in enumerate(self.items()):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # Skipping checkpoint of first module because need a gradient at input
                # Skipping last because networks with in-place ops might fail w/ checkpointing enabled
                # NOTE: first_or_last module could be static, but recalc in is_scripting guard to avoid jit issues
                first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
                x = module(x) if first_or_last_module else checkpoint(module, x)
            else:
                x = module(x)

            if name in self.return_layers:
                out_id = self.return_layers[name]
                if isinstance(x, (tuple, list)):
                    # If model tap is a tuple or list, concat or select first element
                    # FIXME this may need to be more generic / flexible for some nets
                    out[out_id] = torch.cat(x, 1) if self.concat else x[0]
                else:
                    out[out_id] = x
        return out

    def forward(self, x) -> Dict[str, torch.Tensor]:
        return self._collect(x)


class FeatureListNet(FeatureDictNet):
    """ Feature extractor with list return

    A specialization of FeatureDictNet that always returns features as a list (values() of dict).
    """
    def __init__(
            self,
            model: nn.Module,
            out_indices: OutIndicesT = (0, 1, 2, 3, 4),
            output_fmt: str = 'NCHW',
            feature_concat: bool = False,
            flatten_sequential: bool = False,
    ):
        """
        Args:
            model: Model from which to extract features.
            out_indices: Output indices of the model features to extract.
            feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
                first element e.g. `x[0]`
            flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
        """
        super().__init__(
            model,
            out_indices=out_indices,
            output_fmt=output_fmt,
            feature_concat=feature_concat,
            flatten_sequential=flatten_sequential,
        )

    def forward(self, x) -> (List[torch.Tensor]):
        return list(self._collect(x).values())


class FeatureHookNet(nn.ModuleDict):
    """ FeatureHookNet

    Wrap a model and extract features specified by the out indices using forward/forward-pre hooks.

    If `no_rewrite` is True, features are extracted via hooks without modifying the underlying
    network in any way.

    If `no_rewrite` is False, the model will be re-written as in the
    FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one.

    FIXME this does not currently work with Torchscript, see FeatureHooks class
    """
    def __init__(
            self,
            model: nn.Module,
            out_indices: OutIndicesT = (0, 1, 2, 3, 4),
            out_map: Optional[Sequence[Union[int, str]]] = None,
            return_dict: bool = False,
            output_fmt: str = 'NCHW',
            no_rewrite: bool = False,
            flatten_sequential: bool = False,
            default_hook_type: str = 'forward',
    ):
        """

        Args:
            model: Model from which to extract features.
            out_indices: Output indices of the model features to extract.
            out_map: Return id mapping for each output index, otherwise str(index) is used.
            return_dict: Output features as a dict.
            no_rewrite: Enforce that model is not re-written if True, ie no modules are removed / changed.
                flatten_sequential arg must also be False if this is set True.
            flatten_sequential: Re-write modules by flattening first two levels of nn.Sequential containers.
            default_hook_type: The default hook type to use if not specified in model.feature_info.
        """
        super().__init__()
        assert not torch.jit.is_scripting()
        self.feature_info = _get_feature_info(model, out_indices)
        self.return_dict = return_dict
        self.output_fmt = Format(output_fmt)
        self.grad_checkpointing = False

        layers = OrderedDict()
        hooks = []
        if no_rewrite:
            assert not flatten_sequential
            if hasattr(model, 'reset_classifier'):  # make sure classifier is removed?
                model.reset_classifier(0)
            layers['body'] = model
            hooks.extend(self.feature_info.get_dicts())
        else:
            modules = _module_list(model, flatten_sequential=flatten_sequential)
            remaining = {
                f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
                for f in self.feature_info.get_dicts()
            }
            for new_name, old_name, module in modules:
                layers[new_name] = module
                for fn, fm in module.named_modules(prefix=old_name):
                    if fn in remaining:
                        hooks.append(dict(module=fn, hook_type=remaining[fn]))
                        del remaining[fn]
                if not remaining:
                    break
            assert not remaining, f'Return layers ({remaining}) are not present in model'
        self.update(layers)
        self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)

    def set_grad_checkpointing(self, enable: bool = True):
        self.grad_checkpointing = enable

    def forward(self, x):
        for i, (name, module) in enumerate(self.items()):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # Skipping checkpoint of first module because need a gradient at input
                # Skipping last because networks with in-place ops might fail w/ checkpointing enabled
                # NOTE: first_or_last module could be static, but recalc in is_scripting guard to avoid jit issues
                first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
                x = module(x) if first_or_last_module else checkpoint(module, x)
            else:
                x = module(x)
        out = self.hooks.get_output(x.device)
        return out if self.return_dict else list(out.values())


class FeatureGetterNet(nn.ModuleDict):
    """ FeatureGetterNet

    Wrap models with a feature getter method, like 'get_intermediate_layers'

    """
    def __init__(
            self,
            model: nn.Module,
            out_indices: OutIndicesT = 4,
            out_map: Optional[Sequence[Union[int, str]]] = None,
            return_dict: bool = False,
            output_fmt: str = 'NCHW',
            norm: bool = False,
            prune: bool = True,
    ):
        """

        Args:
            model: Model to wrap.
            out_indices: Indices of features to extract.
            out_map: Remap feature names for dict output (WIP, not supported).
            return_dict: Return features as dictionary instead of list (WIP, not supported).
            norm: Apply final model norm to all output features (if possible).
        """
        super().__init__()
        if prune and hasattr(model, 'prune_intermediate_layers'):
            # replace out_indices after they've been normalized, -ve indices will be invalid after prune
            out_indices = model.prune_intermediate_layers(
                out_indices,
                prune_norm=not norm,
            )
            out_indices = list(out_indices)
        self.feature_info = _get_feature_info(model, out_indices)
        self.model = model
        self.out_indices = out_indices
        self.out_map = out_map
        self.return_dict = return_dict
        self.output_fmt = output_fmt
        self.norm = norm

    def forward(self, x):
        features = self.model.forward_intermediates(
            x,
            indices=self.out_indices,
            norm=self.norm,
            output_fmt=self.output_fmt,
            intermediates_only=True,
        )
        return features

In [37]:
# _features_fx.py

""" PyTorch FX Based Feature Extraction Helpers
Using https://pytorch.org/vision/stable/feature_extraction.html
"""
from typing import Callable, Dict, List, Optional, Union, Tuple, Type

import torch
from torch import nn

# from ._features import _get_feature_info, _get_return_layers

try:
    from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
    has_fx_feature_extraction = True
except ImportError:
    has_fx_feature_extraction = False

# Layers we went to treat as leaf modules
from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame
from timm.layers.non_local_attn import BilinearAttnTransform
from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
from timm.layers.norm_act import (
    BatchNormAct2d,
    SyncBatchNormAct,
    FrozenBatchNormAct2d,
    GroupNormAct,
    GroupNorm1Act,
    LayerNormAct,
    LayerNormAct2d
)

__all__ = ['register_notrace_module', 'is_notrace_module', 'get_notrace_modules',
           'register_notrace_function', 'is_notrace_function', 'get_notrace_functions',
           'create_feature_extractor', 'FeatureGraphNet', 'GraphExtractNet']


# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
# BUT modules from timm.models should use the registration mechanism below
_leaf_modules = {
    BilinearAttnTransform,  # reason: flow control t <= 1
    # Reason: get_same_padding has a max which raises a control flow error
    Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
    CondConv2d,  # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0]),
    BatchNormAct2d,
    SyncBatchNormAct,
    FrozenBatchNormAct2d,
    GroupNormAct,
    GroupNorm1Act,
    LayerNormAct,
    LayerNormAct2d,
}

try:
    from timm.layers import InplaceAbn
    _leaf_modules.add(InplaceAbn)
except ImportError:
    pass


def register_notrace_module(module: Type[nn.Module]):
    """
    Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
    """
    _leaf_modules.add(module)
    return module


def is_notrace_module(module: Type[nn.Module]):
    return module in _leaf_modules


def get_notrace_modules():
    return list(_leaf_modules)


# Functions we want to autowrap (treat them as leaves)
_autowrap_functions = set()


def register_notrace_function(func: Callable):
    """
    Decorator for functions which ought not to be traced through
    """
    _autowrap_functions.add(func)
    return func


def is_notrace_function(func: Callable):
    return func in _autowrap_functions


def get_notrace_functions():
    return list(_autowrap_functions)


def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
    assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
    return _create_feature_extractor(
        model, return_nodes,
        tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}
    )


class FeatureGraphNet(nn.Module):
    """ A FX Graph based feature extractor that works with the model feature_info metadata
    """
    def __init__(
            self,
            model: nn.Module,
            out_indices: Tuple[int, ...],
            out_map: Optional[Dict] = None,
    ):
        super().__init__()
        assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
        self.feature_info = _get_feature_info(model, out_indices)
        if out_map is not None:
            assert len(out_map) == len(out_indices)
        return_nodes = _get_return_layers(self.feature_info, out_map)
        self.graph_module = create_feature_extractor(model, return_nodes)

    def forward(self, x):
        return list(self.graph_module(x).values())


class GraphExtractNet(nn.Module):
    """ A standalone feature extraction wrapper that maps dict -> list or single tensor
    NOTE:
      * one can use feature_extractor directly if dictionary output is desired
      * unlike FeatureGraphNet, this is intended to be used standalone and not with model feature_info
      metadata for builtin feature extraction mode
      * create_feature_extractor can be used directly if dictionary output is acceptable

    Args:
        model: model to extract features from
        return_nodes: node names to return features from (dict or list)
        squeeze_out: if only one output, and output in list format, flatten to single tensor
    """
    def __init__(
            self,
            model: nn.Module,
            return_nodes: Union[Dict[str, str], List[str]],
            squeeze_out: bool = True,
    ):
        super().__init__()
        self.squeeze_out = squeeze_out
        self.graph_module = create_feature_extractor(model, return_nodes)

    def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]:
        out = list(self.graph_module(x).values())
        if self.squeeze_out and len(out) == 1:
            return out[0]
        return out

In [38]:
# ._builder.py


import dataclasses
import logging
import os
from copy import deepcopy
from typing import Optional, Dict, Callable, Any, Tuple

from torch import nn as nn
from torch.hub import load_state_dict_from_url

# from timm.models._features import FeatureListNet, FeatureDictNet, FeatureHookNet, FeatureGetterNet
from timm.models._features_fx import FeatureGraphNet
from timm.models._helpers import load_state_dict
from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf
from timm.models._manipulate import adapt_input_conv
from timm.models._pretrained import PretrainedCfg
from timm.models._prune import adapt_model_from_file
from timm.models._registry import get_pretrained_cfg

_logger = logging.getLogger(__name__)

# Global variables for rarely used pretrained checkpoint download progress and hash check.
# Use set_pretrained_download_progress / set_pretrained_check_hash functions to toggle.
_DOWNLOAD_PROGRESS = False
_CHECK_HASH = False
_USE_OLD_CACHE = int(os.environ.get('TIMM_USE_OLD_CACHE', 0)) > 0

__all__ = ['set_pretrained_download_progress', 'set_pretrained_check_hash', 'load_custom_pretrained', 'load_pretrained',
           'pretrained_cfg_for_features', 'resolve_pretrained_cfg', 'build_model_with_cfg']


def _resolve_pretrained_source(pretrained_cfg):
    cfg_source = pretrained_cfg.get('source', '')
    pretrained_url = pretrained_cfg.get('url', None)
    pretrained_file = pretrained_cfg.get('file', None)
    pretrained_sd = pretrained_cfg.get('state_dict', None)
    hf_hub_id = pretrained_cfg.get('hf_hub_id', None)

    # resolve where to load pretrained weights from
    load_from = ''
    pretrained_loc = ''
    if cfg_source == 'hf-hub' and has_hf_hub(necessary=True):
        # hf-hub specified as source via model identifier
        load_from = 'hf-hub'
        assert hf_hub_id
        pretrained_loc = hf_hub_id
    else:
        # default source == timm or unspecified
        if pretrained_sd:
            # direct state_dict pass through is the highest priority
            load_from = 'state_dict'
            pretrained_loc = pretrained_sd
            assert isinstance(pretrained_loc, dict)
        elif pretrained_file:
            # file load override is the second-highest priority if set
            load_from = 'file'
            pretrained_loc = pretrained_file
        else:
            old_cache_valid = False
            if _USE_OLD_CACHE:
                # prioritized old cached weights if exists and env var enabled
                old_cache_valid = check_cached_file(pretrained_url) if pretrained_url else False
            if not old_cache_valid and hf_hub_id and has_hf_hub(necessary=True):
                # hf-hub available as alternate weight source in default_cfg
                load_from = 'hf-hub'
                pretrained_loc = hf_hub_id
            elif pretrained_url:
                load_from = 'url'
                pretrained_loc = pretrained_url

    if load_from == 'hf-hub' and pretrained_cfg.get('hf_hub_filename', None):
        # if a filename override is set, return tuple for location w/ (hub_id, filename)
        pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename']
    return load_from, pretrained_loc


def set_pretrained_download_progress(enable=True):
    """ Set download progress for pretrained weights on/off (globally). """
    global _DOWNLOAD_PROGRESS
    _DOWNLOAD_PROGRESS = enable


def set_pretrained_check_hash(enable=True):
    """ Set hash checking for pretrained weights on/off (globally). """
    global _CHECK_HASH
    _CHECK_HASH = enable


def load_custom_pretrained(
        model: nn.Module,
        pretrained_cfg: Optional[Dict] = None,
        load_fn: Optional[Callable] = None,
):
    r"""Loads a custom (read non .pth) weight file

    Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls
    a passed in custom load fun, or the `load_pretrained` model member fn.

    If the object is already present in `model_dir`, it's deserialized and returned.
    The default value of `model_dir` is ``<hub_dir>/checkpoints`` where
    `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`.

    Args:
        model: The instantiated model to load weights into
        pretrained_cfg (dict): Default pretrained model cfg
        load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named
            'laod_pretrained' on the model will be called if it exists
    """
    pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
    if not pretrained_cfg:
        _logger.warning("Invalid pretrained config, cannot load weights.")
        return

    load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
    if not load_from:
        _logger.warning("No pretrained weights exist for this model. Using random initialization.")
        return
    if load_from == 'hf-hub':
        _logger.warning("Hugging Face hub not currently supported for custom load pretrained models.")
    elif load_from == 'url':
        pretrained_loc = download_cached_file(
            pretrained_loc,
            check_hash=_CHECK_HASH,
            progress=_DOWNLOAD_PROGRESS,
        )

    if load_fn is not None:
        load_fn(model, pretrained_loc)
    elif hasattr(model, 'load_pretrained'):
        model.load_pretrained(pretrained_loc)
    else:
        _logger.warning("Valid function to load pretrained weights is not available, using random initialization.")


def load_pretrained(
        model: nn.Module,
        pretrained_cfg: Optional[Dict] = None,
        num_classes: int = 1000,
        in_chans: int = 3,
        filter_fn: Optional[Callable] = None,
        strict: bool = True,
):
    """ Load pretrained checkpoint

    Args:
        model (nn.Module) : PyTorch model module
        pretrained_cfg (Optional[Dict]): configuration for pretrained weights / target dataset
        num_classes (int): num_classes for target model
        in_chans (int): in_chans for target model
        filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args)
        strict (bool): strict load of checkpoint

    """
    pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
    if not pretrained_cfg:
        raise RuntimeError("Invalid pretrained config, cannot load weights. Use `pretrained=False` for random init.")

    load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
    if load_from == 'state_dict':
        _logger.info(f'Loading pretrained weights from state dict')
        state_dict = pretrained_loc  # pretrained_loc is the actual state dict for this override
    elif load_from == 'file':
        _logger.info(f'Loading pretrained weights from file ({pretrained_loc})')
        if pretrained_cfg.get('custom_load', False):
            model.load_pretrained(pretrained_loc)
            return
        else:
            state_dict = load_state_dict(pretrained_loc)
    elif load_from == 'url':
        _logger.info(f'Loading pretrained weights from url ({pretrained_loc})')
        if pretrained_cfg.get('custom_load', False):
            pretrained_loc = download_cached_file(
                pretrained_loc,
                progress=_DOWNLOAD_PROGRESS,
                check_hash=_CHECK_HASH,
            )
            model.load_pretrained(pretrained_loc)
            return
        else:
            state_dict = load_state_dict_from_url(
                pretrained_loc,
                map_location='cpu',
                progress=_DOWNLOAD_PROGRESS,
                check_hash=_CHECK_HASH,
            )
    elif load_from == 'hf-hub':
        _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
        if isinstance(pretrained_loc, (list, tuple)):
            state_dict = load_state_dict_from_hf(*pretrained_loc)
        else:
            state_dict = load_state_dict_from_hf(pretrained_loc)
    else:
        model_name = pretrained_cfg.get('architecture', 'this model')
        raise RuntimeError(f"No pretrained weights exist for {model_name}. Use `pretrained=False` for random init.")

    if filter_fn is not None:
        try:
            state_dict = filter_fn(state_dict, model)
        except TypeError as e:
            # for backwards compat with filter fn that take one arg
            state_dict = filter_fn(state_dict)

    input_convs = pretrained_cfg.get('first_conv', None)
    if input_convs is not None and in_chans != 3:
        if isinstance(input_convs, str):
            input_convs = (input_convs,)
        for input_conv_name in input_convs:
            weight_name = input_conv_name + '.weight'
            try:
                state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name])
                _logger.info(
                    f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)')
            except NotImplementedError as e:
                del state_dict[weight_name]
                strict = False
                _logger.warning(
                    f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')

    classifiers = pretrained_cfg.get('classifier', None)
    label_offset = pretrained_cfg.get('label_offset', 0)
    if classifiers is not None:
        if isinstance(classifiers, str):
            classifiers = (classifiers,)
        if num_classes != pretrained_cfg['num_classes']:
            for classifier_name in classifiers:
                # completely discard fully connected if model num_classes doesn't match pretrained weights
                state_dict.pop(classifier_name + '.weight', None)
                state_dict.pop(classifier_name + '.bias', None)
            strict = False
        elif label_offset > 0:
            for classifier_name in classifiers:
                # special case for pretrained weights with an extra background class in pretrained weights
                classifier_weight = state_dict[classifier_name + '.weight']
                state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
                classifier_bias = state_dict[classifier_name + '.bias']
                state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]

    load_result = model.load_state_dict(state_dict, strict=strict)
    if load_result.missing_keys:
        _logger.info(
            f'Missing keys ({", ".join(load_result.missing_keys)}) discovered while loading pretrained weights.'
            f' This is expected if model is being adapted.')
    if load_result.unexpected_keys:
        _logger.warning(
            f'Unexpected keys ({", ".join(load_result.unexpected_keys)}) found while loading pretrained weights.'
            f' This may be expected if model is being adapted.')


def pretrained_cfg_for_features(pretrained_cfg):
    pretrained_cfg = deepcopy(pretrained_cfg)
    # remove default pretrained cfg fields that don't have much relevance for feature backbone
    to_remove = ('num_classes', 'classifier', 'global_pool')  # add default final pool size?
    for tr in to_remove:
        pretrained_cfg.pop(tr, None)
    return pretrained_cfg


def _filter_kwargs(kwargs, names):
    if not kwargs or not names:
        return
    for n in names:
        kwargs.pop(n, None)


def _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter):
    """ Update the default_cfg and kwargs before passing to model

    Args:
        pretrained_cfg: input pretrained cfg (updated in-place)
        kwargs: keyword args passed to model build fn (updated in-place)
        kwargs_filter: keyword arg keys that must be removed before model __init__
    """
    # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
    default_kwarg_names = ('num_classes', 'global_pool', 'in_chans')
    if pretrained_cfg.get('fixed_input_size', False):
        # if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size
        default_kwarg_names += ('img_size',)

    for n in default_kwarg_names:
        # for legacy reasons, model __init__args uses img_size + in_chans as separate args while
        # pretrained_cfg has one input_size=(C, H ,W) entry
        if n == 'img_size':
            input_size = pretrained_cfg.get('input_size', None)
            if input_size is not None:
                assert len(input_size) == 3
                kwargs.setdefault(n, input_size[-2:])
        elif n == 'in_chans':
            input_size = pretrained_cfg.get('input_size', None)
            if input_size is not None:
                assert len(input_size) == 3
                kwargs.setdefault(n, input_size[0])
        elif n == 'num_classes':
            default_val = pretrained_cfg.get(n, None)
            # if default is < 0, don't pass through to model
            if default_val is not None and default_val >= 0:
                kwargs.setdefault(n, pretrained_cfg[n])
        else:
            default_val = pretrained_cfg.get(n, None)
            if default_val is not None:
                kwargs.setdefault(n, pretrained_cfg[n])

    # Filter keyword args for task specific model variants (some 'features only' models, etc.)
    _filter_kwargs(kwargs, names=kwargs_filter)


def resolve_pretrained_cfg(
        variant: str,
        pretrained_cfg=None,
        pretrained_cfg_overlay=None,
) -> PretrainedCfg:
    model_with_tag = variant
    pretrained_tag = None
    if pretrained_cfg:
        if isinstance(pretrained_cfg, dict):
            # pretrained_cfg dict passed as arg, validate by converting to PretrainedCfg
            pretrained_cfg = PretrainedCfg(**pretrained_cfg)
        elif isinstance(pretrained_cfg, str):
            pretrained_tag = pretrained_cfg
            pretrained_cfg = None

    # fallback to looking up pretrained cfg in model registry by variant identifier
    if not pretrained_cfg:
        if pretrained_tag:
            model_with_tag = '.'.join([variant, pretrained_tag])
        pretrained_cfg = get_pretrained_cfg(model_with_tag)

    if not pretrained_cfg:
        _logger.warning(
            f"No pretrained configuration specified for {model_with_tag} model. Using a default."
            f" Please add a config to the model pretrained_cfg registry or pass explicitly.")
        pretrained_cfg = PretrainedCfg()  # instance with defaults

    pretrained_cfg_overlay = pretrained_cfg_overlay or {}
    if not pretrained_cfg.architecture:
        pretrained_cfg_overlay.setdefault('architecture', variant)
    pretrained_cfg = dataclasses.replace(pretrained_cfg, **pretrained_cfg_overlay)

    return pretrained_cfg


def build_model_with_cfg(
        model_cls: Callable,
        variant: str,
        pretrained: bool,
        pretrained_cfg: Optional[Dict] = None,
        pretrained_cfg_overlay: Optional[Dict] = None,
        model_cfg: Optional[Any] = None,
        feature_cfg: Optional[Dict] = None,
        pretrained_strict: bool = True,
        pretrained_filter_fn: Optional[Callable] = None,
        kwargs_filter: Optional[Tuple[str]] = None,
        **kwargs,
):
    """ Build model with specified default_cfg and optional model_cfg

    This helper fn aids in the construction of a model including:
      * handling default_cfg and associated pretrained weight loading
      * passing through optional model_cfg for models with config based arch spec
      * features_only model adaptation
      * pruning config / model adaptation

    Args:
        model_cls (nn.Module): model class
        variant (str): model variant name
        pretrained (bool): load pretrained weights
        pretrained_cfg (dict): model's pretrained weight/task config
        model_cfg (Optional[Dict]): model's architecture config
        feature_cfg (Optional[Dict]: feature extraction adapter config
        pretrained_strict (bool): load pretrained weights strictly
        pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights
        kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model
        **kwargs: model args passed through to model __init__
    """
    pruned = kwargs.pop('pruned', False)
    features = False
    feature_cfg = feature_cfg or {}

    # resolve and update model pretrained config and model kwargs
    pretrained_cfg = resolve_pretrained_cfg(
        variant,
        pretrained_cfg=pretrained_cfg,
        pretrained_cfg_overlay=pretrained_cfg_overlay
    )

    # FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model
    pretrained_cfg = pretrained_cfg.to_dict()

    _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter)

    # Setup for feature extraction wrapper done at end of this fn
    if kwargs.pop('features_only', False):
        features = True
        feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
        if 'out_indices' in kwargs:
            feature_cfg['out_indices'] = kwargs.pop('out_indices')

    # Instantiate the model
    if model_cfg is None:
        model = model_cls(**kwargs)
    else:
        model = model_cls(cfg=model_cfg, **kwargs)
    model.pretrained_cfg = pretrained_cfg
    model.default_cfg = model.pretrained_cfg  # alias for backwards compat

    if pruned:
        model = adapt_model_from_file(model, variant)

    # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
    num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
    if pretrained:
        load_pretrained(
            model,
            pretrained_cfg=pretrained_cfg,
            num_classes=num_classes_pretrained,
            in_chans=kwargs.get('in_chans', 3),
            filter_fn=pretrained_filter_fn,
            strict=pretrained_strict,
        )

    # Wrap the model in a feature extraction module if enabled
    if features:
        feature_cls = FeatureListNet
        output_fmt = getattr(model, 'output_fmt', None)
        if output_fmt is not None:
            feature_cfg.setdefault('output_fmt', output_fmt)
        if 'feature_cls' in feature_cfg:
            feature_cls = feature_cfg.pop('feature_cls')
            if isinstance(feature_cls, str):
                feature_cls = feature_cls.lower()
                if 'hook' in feature_cls:
                    feature_cls = FeatureHookNet
                elif feature_cls == 'dict':
                    feature_cls = FeatureDictNet
                elif feature_cls == 'fx':
                    feature_cls = FeatureGraphNet
                elif feature_cls == 'getter':
                    feature_cls = FeatureGetterNet
                else:
                    assert False, f'Unknown feature class {feature_cls}'
        model = feature_cls(model, **feature_cfg)
        model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg)  # add back pretrained cfg
        model.default_cfg = model.pretrained_cfg  # alias for rename backwards compat (default_cfg -> pretrained_cfg)

    return model

In [39]:
# _manipulate.py

import collections.abc
import math
import re
from collections import defaultdict
from itertools import chain
from typing import Any, Callable, Dict, Iterator, Tuple, Type, Union

import torch
from torch import nn as nn
from torch.utils.checkpoint import checkpoint

__all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_with_params', 'adapt_input_conv',
           'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq']


def model_parameters(model: nn.Module, exclude_head: bool = False):
    if exclude_head:
        # FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering
        return [p for p in model.parameters()][:-2]
    else:
        return model.parameters()


def named_apply(
        fn: Callable,
        module: nn.Module, name='',
        depth_first: bool = True,
        include_root: bool = False,
) -> nn.Module:
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        child_name = '.'.join((name, child_name)) if name else child_name
        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        fn(module=module, name=name)
    return module


def named_modules(
        module: nn.Module,
        name: str = '',
        depth_first: bool = True,
        include_root: bool = False,
):
    if not depth_first and include_root:
        yield name, module
    for child_name, child_module in module.named_children():
        child_name = '.'.join((name, child_name)) if name else child_name
        yield from named_modules(
            module=child_module, name=child_name, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        yield name, module


def named_modules_with_params(
        module: nn.Module,
        name: str = '',
        depth_first: bool = True,
        include_root: bool = False,
):
    if module._parameters and not depth_first and include_root:
        yield name, module
    for child_name, child_module in module.named_children():
        child_name = '.'.join((name, child_name)) if name else child_name
        yield from named_modules_with_params(
            module=child_module, name=child_name, depth_first=depth_first, include_root=True)
    if module._parameters and depth_first and include_root:
        yield name, module


MATCH_PREV_GROUP = (99999,)


def group_with_matcher(
        named_objects: Iterator[Tuple[str, Any]],
        group_matcher: Union[Dict, Callable],
        return_values: bool = False,
        reverse: bool = False
):
    if isinstance(group_matcher, dict):
        # dictionary matcher contains a dict of raw-string regex expr that must be compiled
        compiled = []
        for group_ordinal, (group_name, mspec) in enumerate(group_matcher.items()):
            if mspec is None:
                continue
            # map all matching specifications into 3-tuple (compiled re, prefix, suffix)
            if isinstance(mspec, (tuple, list)):
                # multi-entry match specifications require each sub-spec to be a 2-tuple (re, suffix)
                for sspec in mspec:
                    compiled += [(re.compile(sspec[0]), (group_ordinal,), sspec[1])]
            else:
                compiled += [(re.compile(mspec), (group_ordinal,), None)]
        group_matcher = compiled

    def _get_grouping(name):
        if isinstance(group_matcher, (list, tuple)):
            for match_fn, prefix, suffix in group_matcher:
                r = match_fn.match(name)
                if r:
                    parts = (prefix, r.groups(), suffix)
                    # map all tuple elem to int for numeric sort, filter out None entries
                    return tuple(map(float, chain.from_iterable(filter(None, parts))))
            return float('inf'),  # un-matched layers (neck, head) mapped to largest ordinal
        else:
            ord = group_matcher(name)
            if not isinstance(ord, collections.abc.Iterable):
                return ord,
            return tuple(ord)

    # map layers into groups via ordinals (ints or tuples of ints) from matcher
    grouping = defaultdict(list)
    for k, v in named_objects:
        grouping[_get_grouping(k)].append(v if return_values else k)

    # remap to integers
    layer_id_to_param = defaultdict(list)
    lid = -1
    for k in sorted(filter(lambda x: x is not None, grouping.keys())):
        if lid < 0 or k[-1] != MATCH_PREV_GROUP[0]:
            lid += 1
        layer_id_to_param[lid].extend(grouping[k])

    if reverse:
        assert not return_values, "reverse mapping only sensible for name output"
        # output reverse mapping
        param_to_layer_id = {}
        for lid, lm in layer_id_to_param.items():
            for n in lm:
                param_to_layer_id[n] = lid
        return param_to_layer_id

    return layer_id_to_param


def group_parameters(
        module: nn.Module,
        group_matcher,
        return_values: bool = False,
        reverse: bool = False,
):
    return group_with_matcher(
        module.named_parameters(), group_matcher, return_values=return_values, reverse=reverse)


def group_modules(
        module: nn.Module,
        group_matcher,
        return_values: bool = False,
        reverse: bool = False,
):
    return group_with_matcher(
        named_modules_with_params(module), group_matcher, return_values=return_values, reverse=reverse)


def flatten_modules(
        named_modules: Iterator[Tuple[str, nn.Module]],
        depth: int = 1,
        prefix: Union[str, Tuple[str, ...]] = '',
        module_types: Union[str, Tuple[Type[nn.Module]]] = 'sequential',
):
    prefix_is_tuple = isinstance(prefix, tuple)
    if isinstance(module_types, str):
        if module_types == 'container':
            module_types = (nn.Sequential, nn.ModuleList, nn.ModuleDict)
        else:
            module_types = (nn.Sequential,)
    for name, module in named_modules:
        if depth and isinstance(module, module_types):
            yield from flatten_modules(
                module.named_children(),
                depth - 1,
                prefix=(name,) if prefix_is_tuple else name,
                module_types=module_types,
            )
        else:
            if prefix_is_tuple:
                name = prefix + (name,)
                yield name, module
            else:
                if prefix:
                    name = '.'.join([prefix, name])
                yield name, module


def checkpoint_seq(
        functions,
        x,
        every=1,
        flatten=False,
        skip_last=False,
        preserve_rng_state=True
):
    r"""A helper function for checkpointing sequential models.

    Sequential models execute a list of modules/functions in order
    (sequentially). Therefore, we can divide such a sequence into segments
    and checkpoint each segment. All segments except run in :func:`torch.no_grad`
    manner, i.e., not storing the intermediate activations. The inputs of each
    checkpointed segment will be saved for re-running the segment in the backward pass.

    See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.

    .. warning::
        Checkpointing currently only supports :func:`torch.autograd.backward`
        and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`
        is not supported.

    .. warning:
        At least one of the inputs needs to have :code:`requires_grad=True` if
        grads are needed for model inputs, otherwise the checkpointed part of the
        model won't have gradients.

    Args:
        functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially.
        x: A Tensor that is input to :attr:`functions`
        every: checkpoint every-n functions (default: 1)
        flatten (bool): flatten nn.Sequential of nn.Sequentials
        skip_last (bool): skip checkpointing the last function in the sequence if True
        preserve_rng_state (bool, optional, default=True):  Omit stashing and restoring
            the RNG state during each checkpoint.

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`*inputs`

    Example:
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_seq(model, input_var, every=2)
    """
    def run_function(start, end, functions):
        def forward(_x):
            for j in range(start, end + 1):
                _x = functions[j](_x)
            return _x
        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = functions.children()
    if flatten:
        functions = chain.from_iterable(functions)
    if not isinstance(functions, (tuple, list)):
        functions = tuple(functions)

    num_checkpointed = len(functions)
    if skip_last:
        num_checkpointed -= 1
    end = -1
    for start in range(0, num_checkpointed, every):
        end = min(start + every - 1, num_checkpointed - 1)
        x = checkpoint(run_function(start, end, functions), x, preserve_rng_state=preserve_rng_state)
    if skip_last:
        return run_function(end + 1, len(functions) - 1, functions)(x)
    return x


def adapt_input_conv(in_chans, conv_weight):
    conv_type = conv_weight.dtype
    conv_weight = conv_weight.float()  # Some weights are in torch.half, ensure it's float for sum on CPU
    O, I, J, K = conv_weight.shape
    if in_chans == 1:
        if I > 3:
            assert conv_weight.shape[1] % 3 == 0
            # For models with space2depth stems
            conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
            conv_weight = conv_weight.sum(dim=2, keepdim=False)
        else:
            conv_weight = conv_weight.sum(dim=1, keepdim=True)
    elif in_chans != 3:
        if I != 3:
            raise NotImplementedError('Weight format not supported by conversion.')
        else:
            # NOTE this strategy should be better than random init, but there could be other combinations of
            # the original RGB input layer weights that'd work better for specific cases.
            repeat = int(math.ceil(in_chans / 3))
            conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
            conv_weight *= (3 / float(in_chans))
    conv_weight = conv_weight.to(conv_type)
    return conv_weight

In [40]:
import copy
from collections import deque, defaultdict
from dataclasses import dataclass, field, replace, asdict
from typing import Any, Deque, Dict, Tuple, Optional, Union


__all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg']


@dataclass
class PretrainedCfg:
    """
    """
    # weight source locations
    url: Optional[Union[str, Tuple[str, str]]] = None  # remote URL
    file: Optional[str] = None  # local / shared filesystem path
    state_dict: Optional[Dict[str, Any]] = None  # in-memory state dict
    hf_hub_id: Optional[str] = None  # Hugging Face Hub model id ('organization/model')
    hf_hub_filename: Optional[str] = None  # Hugging Face Hub filename (overrides default)

    source: Optional[str] = None  # source of cfg / weight location used (url, file, hf-hub)
    architecture: Optional[str] = None  # architecture variant can be set when not implicit
    tag: Optional[str] = None  # pretrained tag of source
    custom_load: bool = False  # use custom model specific model.load_pretrained() (ie for npz files)

    # input / data config
    input_size: Tuple[int, int, int] = (3, 224, 224)
    test_input_size: Optional[Tuple[int, int, int]] = None
    min_input_size: Optional[Tuple[int, int, int]] = None
    fixed_input_size: bool = False
    interpolation: str = 'bicubic'
    crop_pct: float = 0.875
    test_crop_pct: Optional[float] = None
    crop_mode: str = 'center'
    mean: Tuple[float, ...] = (0.485, 0.456, 0.406)
    std: Tuple[float, ...] = (0.229, 0.224, 0.225)

    # head / classifier config and meta-data
    num_classes: int = 1000
    label_offset: Optional[int] = None
    label_names: Optional[Tuple[str]] = None
    label_descriptions: Optional[Dict[str, str]] = None

    # model attributes that vary with above or required for pretrained adaptation
    pool_size: Optional[Tuple[int, ...]] = None
    test_pool_size: Optional[Tuple[int, ...]] = None
    first_conv: Optional[str] = None
    classifier: Optional[str] = None

    license: Optional[str] = None
    description: Optional[str] = None
    origin_url: Optional[str] = None
    paper_name: Optional[str] = None
    paper_ids: Optional[Union[str, Tuple[str]]] = None
    notes: Optional[Tuple[str]] = None

    @property
    def has_weights(self):
        return self.url or self.file or self.hf_hub_id

    def to_dict(self, remove_source=False, remove_null=True):
        return filter_pretrained_cfg(
            asdict(self),
            remove_source=remove_source,
            remove_null=remove_null
        )


def filter_pretrained_cfg(cfg, remove_source=False, remove_null=True):
    filtered_cfg = {}
    keep_null = {'pool_size', 'first_conv', 'classifier'}  # always keep these keys, even if none
    for k, v in cfg.items():
        if remove_source and k in {'url', 'file', 'hf_hub_id', 'hf_hub_id', 'hf_hub_filename', 'source'}:
            continue
        if remove_null and v is None and k not in keep_null:
            continue
        filtered_cfg[k] = v
    return filtered_cfg


@dataclass
class DefaultCfg:
    tags: Deque[str] = field(default_factory=deque)  # priority queue of tags (first is default)
    cfgs: Dict[str, PretrainedCfg] = field(default_factory=dict)  # pretrained cfgs by tag
    is_pretrained: bool = False  # at least one of the configs has a pretrained source set

    @property
    def default(self):
        return self.cfgs[self.tags[0]]

    @property
    def default_with_tag(self):
        tag = self.tags[0]
        return tag, self.cfgs[tag]

In [41]:
# _registry.py


""" Model Registry
Hacked together by / Copyright 2020 Ross Wightman
"""

import fnmatch
import re
import sys
import warnings
from collections import defaultdict, deque
from copy import deepcopy
from dataclasses import replace
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Sequence, Union, Tuple

# from ._pretrained import PretrainedCfg, DefaultCfg

__all__ = [
    'split_model_name_tag', 'get_arch_name', 'register_model', 'generate_default_cfgs',
    'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
    'get_pretrained_cfg_value', 'is_model_pretrained'
]

_module_to_models: Dict[str, Set[str]] = defaultdict(set)  # dict of sets to check membership of model in module
_model_to_module: Dict[str, str] = {}  # mapping of model names to module names
_model_entrypoints: Dict[str, Callable[..., Any]] = {}  # mapping of model names to architecture entrypoint fns
_model_has_pretrained: Set[str] = set()  # set of model names that have pretrained weight url present
_model_default_cfgs: Dict[str, PretrainedCfg] = {}  # central repo for model arch -> default cfg objects
_model_pretrained_cfgs: Dict[str, PretrainedCfg] = {}  # central repo for model arch.tag -> pretrained cfgs
_model_with_tags: Dict[str, List[str]] = defaultdict(list)  # shortcut to map each model arch to all model + tag names
_module_to_deprecated_models: Dict[str, Dict[str, Optional[str]]] = defaultdict(dict)
_deprecated_models: Dict[str, Optional[str]] = {}


def split_model_name_tag(model_name: str, no_tag: str = '') -> Tuple[str, str]:
    model_name, *tag_list = model_name.split('.', 1)
    tag = tag_list[0] if tag_list else no_tag
    return model_name, tag


def get_arch_name(model_name: str) -> str:
    return split_model_name_tag(model_name)[0]


def generate_default_cfgs(cfgs: Dict[str, Union[Dict[str, Any], PretrainedCfg]]):
    out = defaultdict(DefaultCfg)
    default_set = set()  # no tag and tags ending with * are prioritized as default

    for k, v in cfgs.items():
        if isinstance(v, dict):
            v = PretrainedCfg(**v)
        has_weights = v.has_weights

        model, tag = split_model_name_tag(k)
        is_default_set = model in default_set
        priority = (has_weights and not tag) or (tag.endswith('*') and not is_default_set)
        tag = tag.strip('*')

        default_cfg = out[model]

        if priority:
            default_cfg.tags.appendleft(tag)
            default_set.add(model)
        elif has_weights and not default_cfg.is_pretrained:
            default_cfg.tags.appendleft(tag)
        else:
            default_cfg.tags.append(tag)

        if has_weights:
            default_cfg.is_pretrained = True

        default_cfg.cfgs[tag] = v

    return out


def register_model(fn: Callable[..., Any]) -> Callable[..., Any]:
    # lookup containing module
    mod = sys.modules[fn.__module__]
    module_name_split = fn.__module__.split('.')
    module_name = module_name_split[-1] if len(module_name_split) else ''

    # add model to __all__ in module
    model_name = fn.__name__
    if hasattr(mod, '__all__'):
        mod.__all__.append(model_name)
    else:
        mod.__all__ = [model_name]  # type: ignore

    # add entries to registry dict/sets
    if model_name in _model_entrypoints:
        warnings.warn(
            f'Overwriting {model_name} in registry with {fn.__module__}.{model_name}. This is because the name being '
            'registered conflicts with an existing name. Please check if this is not expected.',
            stacklevel=2,
        )
    _model_entrypoints[model_name] = fn
    _model_to_module[model_name] = module_name
    _module_to_models[module_name].add(model_name)
    if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
        # this will catch all models that have entrypoint matching cfg key, but miss any aliasing
        # entrypoints or non-matching combos
        default_cfg = mod.default_cfgs[model_name]
        if not isinstance(default_cfg, DefaultCfg):
            # new style default cfg dataclass w/ multiple entries per model-arch
            assert isinstance(default_cfg, dict)
            # old style cfg dict per model-arch
            pretrained_cfg = PretrainedCfg(**default_cfg)
            default_cfg = DefaultCfg(tags=deque(['']), cfgs={'': pretrained_cfg})

        for tag_idx, tag in enumerate(default_cfg.tags):
            is_default = tag_idx == 0
            pretrained_cfg = default_cfg.cfgs[tag]
            model_name_tag = '.'.join([model_name, tag]) if tag else model_name
            replace_items = dict(architecture=model_name, tag=tag if tag else None)
            if pretrained_cfg.hf_hub_id and pretrained_cfg.hf_hub_id == 'timm/':
                # auto-complete hub name w/ architecture.tag
                replace_items['hf_hub_id'] = pretrained_cfg.hf_hub_id + model_name_tag
            pretrained_cfg = replace(pretrained_cfg, **replace_items)

            if is_default:
                _model_pretrained_cfgs[model_name] = pretrained_cfg
                if pretrained_cfg.has_weights:
                    # add tagless entry if it's default and has weights
                    _model_has_pretrained.add(model_name)

            if tag:
                _model_pretrained_cfgs[model_name_tag] = pretrained_cfg
                if pretrained_cfg.has_weights:
                    # add model w/ tag if tag is valid
                    _model_has_pretrained.add(model_name_tag)
                _model_with_tags[model_name].append(model_name_tag)
            else:
                _model_with_tags[model_name].append(model_name)  # has empty tag (to slowly remove these instances)

        _model_default_cfgs[model_name] = default_cfg

    return fn


def _deprecated_model_shim(deprecated_name: str, current_fn: Callable = None, current_tag: str = ''):
    def _fn(pretrained=False, **kwargs):
        assert current_fn is not None,  f'Model {deprecated_name} has been removed with no replacement.'
        current_name = '.'.join([current_fn.__name__, current_tag]) if current_tag else current_fn.__name__
        warnings.warn(f'Mapping deprecated model name {deprecated_name} to current {current_name}.', stacklevel=2)
        pretrained_cfg = kwargs.pop('pretrained_cfg', None)
        return current_fn(pretrained=pretrained, pretrained_cfg=pretrained_cfg or current_tag, **kwargs)
    return _fn


def register_model_deprecations(module_name: str, deprecation_map: Dict[str, Optional[str]]):
    mod = sys.modules[module_name]
    module_name_split = module_name.split('.')
    module_name = module_name_split[-1] if len(module_name_split) else ''

    for deprecated, current in deprecation_map.items():
        if hasattr(mod, '__all__'):
            mod.__all__.append(deprecated)
        current_fn = None
        current_tag = ''
        if current:
            current_name, current_tag = split_model_name_tag(current)
            current_fn = getattr(mod, current_name)
        deprecated_entrypoint_fn = _deprecated_model_shim(deprecated, current_fn, current_tag)
        setattr(mod, deprecated, deprecated_entrypoint_fn)
        _model_entrypoints[deprecated] = deprecated_entrypoint_fn
        _model_to_module[deprecated] = module_name
        _module_to_models[module_name].add(deprecated)
        _deprecated_models[deprecated] = current
        _module_to_deprecated_models[module_name][deprecated] = current


def _natural_key(string_: str) -> List[Union[int, str]]:
    """See https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/"""
    return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]


def _expand_filter(filter: str):
    """ expand a 'base_filter' to 'base_filter.*' if no tag portion"""
    filter_base, filter_tag = split_model_name_tag(filter)
    if not filter_tag:
        return ['.'.join([filter_base, '*']), filter]
    else:
        return [filter]


def list_models(
        filter: Union[str, List[str]] = '',
        module: str = '',
        pretrained: bool = False,
        exclude_filters: Union[str, List[str]] = '',
        name_matches_cfg: bool = False,
        include_tags: Optional[bool] = None,
) -> List[str]:
    """ Return list of available model names, sorted alphabetically

    Args:
        filter - Wildcard filter string that works with fnmatch
        module - Limit model selection to a specific submodule (ie 'vision_transformer')
        pretrained - Include only models with valid pretrained weights if True
        exclude_filters - Wildcard filters to exclude models after including them with filter
        name_matches_cfg - Include only models w/ model_name matching default_cfg name (excludes some aliases)
        include_tags - Include pretrained tags in model names (model.tag). If None, defaults
            set to True when pretrained=True else False (default: None)

    Returns:
        models - The sorted list of models

    Example:
        model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
        model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
    """
    if filter:
        include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
    else:
        include_filters = []

    if include_tags is None:
        # FIXME should this be default behaviour? or default to include_tags=True?
        include_tags = pretrained

    all_models: Set[str] = _module_to_models[module] if module else set(_model_entrypoints.keys())
    all_models = all_models - _deprecated_models.keys()  # remove deprecated models from listings

    if include_tags:
        # expand model names to include names w/ pretrained tags
        models_with_tags: Set[str] = set()
        for m in all_models:
            models_with_tags.update(_model_with_tags[m])
        all_models = models_with_tags
        # expand include and exclude filters to include a '.*' for proper match if no tags in filter
        include_filters = [ef for f in include_filters for ef in _expand_filter(f)]
        exclude_filters = [ef for f in exclude_filters for ef in _expand_filter(f)]

    if include_filters:
        models: Set[str] = set()
        for f in include_filters:
            include_models = fnmatch.filter(all_models, f)  # include these models
            if len(include_models):
                models = models.union(include_models)
    else:
        models = all_models

    if exclude_filters:
        if not isinstance(exclude_filters, (tuple, list)):
            exclude_filters = [exclude_filters]
        for xf in exclude_filters:
            exclude_models = fnmatch.filter(models, xf)  # exclude these models
            if len(exclude_models):
                models = models.difference(exclude_models)

    if pretrained:
        models = _model_has_pretrained.intersection(models)

    if name_matches_cfg:
        models = set(_model_pretrained_cfgs).intersection(models)

    return sorted(models, key=_natural_key)


def list_pretrained(
        filter: Union[str, List[str]] = '',
        exclude_filters: str = '',
) -> List[str]:
    return list_models(
        filter=filter,
        pretrained=True,
        exclude_filters=exclude_filters,
        include_tags=True,
    )


def get_deprecated_models(module: str = '') -> Dict[str, str]:
    all_deprecated = _module_to_deprecated_models[module] if module else _deprecated_models
    return deepcopy(all_deprecated)


def is_model(model_name: str) -> bool:
    """ Check if a model name exists
    """
    arch_name = get_arch_name(model_name)
    return arch_name in _model_entrypoints


def model_entrypoint(model_name: str, module_filter: Optional[str] = None) -> Callable[..., Any]:
    """Fetch a model entrypoint for specified model name
    """
    arch_name = get_arch_name(model_name)
    if module_filter and arch_name not in _module_to_models.get(module_filter, {}):
        raise RuntimeError(f'Model ({model_name} not found in module {module_filter}.')
    return _model_entrypoints[arch_name]


def list_modules() -> List[str]:
    """ Return list of module names that contain models / model entrypoints
    """
    modules = _module_to_models.keys()
    return sorted(modules)


def is_model_in_modules(
        model_name: str, module_names: Union[Tuple[str, ...], List[str], Set[str]]
) -> bool:
    """Check if a model exists within a subset of modules

    Args:
        model_name - name of model to check
        module_names - names of modules to search in
    """
    arch_name = get_arch_name(model_name)
    assert isinstance(module_names, (tuple, list, set))
    return any(arch_name in _module_to_models[n] for n in module_names)


def is_model_pretrained(model_name: str) -> bool:
    return model_name in _model_has_pretrained


def get_pretrained_cfg(model_name: str, allow_unregistered: bool = True) -> Optional[PretrainedCfg]:
    if model_name in _model_pretrained_cfgs:
        return deepcopy(_model_pretrained_cfgs[model_name])
    arch_name, tag = split_model_name_tag(model_name)
    if arch_name in _model_default_cfgs:
        # if model arch exists, but the tag is wrong, error out
        raise RuntimeError(f'Invalid pretrained tag ({tag}) for {arch_name}.')
    if allow_unregistered:
        # if model arch doesn't exist, it has no pretrained_cfg registered, allow a default to be created
        return None
    raise RuntimeError(f'Model architecture ({arch_name}) has no pretrained cfg registered.')


def get_pretrained_cfg_value(model_name: str, cfg_key: str) -> Optional[Any]:
    """ Get a specific model default_cfg value by key. None if key doesn't exist.
    """
    cfg = get_pretrained_cfg(model_name, allow_unregistered=False)
    return getattr(cfg, cfg_key, None)

In [42]:
""" MaxVit and CoAtNet Vision Transformer - CNN Hybrids in PyTorch

This is a from-scratch implementation of both CoAtNet and MaxVit in PyTorch.

99% of the implementation was done from papers, however last minute some adjustments were made
based on the (as yet unfinished?) public code release https://github.com/google-research/maxvit

There are multiple sets of models defined for both architectures. Typically, names with a
 `_rw` suffix are my own original configs prior to referencing https://github.com/google-research/maxvit.
These configs work well and appear to be a bit faster / lower resource than the paper.

The models without extra prefix / suffix' (coatnet_0_224, maxvit_tiny_224, etc), are intended to
match paper, BUT, without any official pretrained weights it's difficult to confirm a 100% match.

Papers:

MaxViT: Multi-Axis Vision Transformer - https://arxiv.org/abs/2204.01697
@article{tu2022maxvit,
  title={MaxViT: Multi-Axis Vision Transformer},
  author={Tu, Zhengzhong and Talebi, Hossein and Zhang, Han and Yang, Feng and Milanfar, Peyman and Bovik, Alan and Li, Yinxiao},
  journal={ECCV},
  year={2022},
}

CoAtNet: Marrying Convolution and Attention for All Data Sizes - https://arxiv.org/abs/2106.04803
@article{DBLP:journals/corr/abs-2106-04803,
  author    = {Zihang Dai and Hanxiao Liu and Quoc V. Le and Mingxing Tan},
  title     = {CoAtNet: Marrying Convolution and Attention for All Data Sizes},
  journal   = {CoRR},
  volume    = {abs/2106.04803},
  year      = {2021}
}

Hacked together by / Copyright 2022, Ross Wightman
"""

import math
from collections import OrderedDict
from dataclasses import dataclass, replace, field
from functools import partial
from typing import Callable, Optional, Union, Tuple, List

import torch
from torch import nn
from torch.jit import Final

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import Mlp, ConvMlp, DropPath, LayerNorm, ClassifierHead, NormMlpClassifierHead
from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, create_pool2d
from timm.layers import trunc_normal_tf_, to_2tuple, extend_tuple, make_divisible, _assert
from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf, use_fused_attn, resize_rel_pos_bias_table
# from ._builder import build_model_with_cfg
# from ._features_fx import register_notrace_function
# from ._manipulate import named_apply, checkpoint_seq
# from ._registry import generate_default_cfgs, register_model

__all__ = ['MaxxVitCfg', 'MaxxVitConvCfg', 'MaxxVitTransformerCfg', 'MaxxVit']


@dataclass
class MaxxVitTransformerCfg:
    dim_head: int = 32
    head_first: bool = True  # head ordering in qkv channel dim
    expand_ratio: float = 4.0
    expand_first: bool = True
    shortcut_bias: bool = True
    attn_bias: bool = True
    attn_drop: float = 0.
    proj_drop: float = 0.
    pool_type: str = 'avg2'
    rel_pos_type: str = 'bias'
    rel_pos_dim: int = 512  # for relative position types w/ MLP
    partition_ratio: int = 32
    window_size: Optional[Tuple[int, int]] = None
    grid_size: Optional[Tuple[int, int]] = None
    no_block_attn: bool = False  # disable window block attention for maxvit (ie only grid)
    use_nchw_attn: bool = False  # for MaxViT variants (not used for CoAt), keep tensors in NCHW order
    init_values: Optional[float] = None
    act_layer: str = 'gelu'
    norm_layer: str = 'layernorm2d'
    norm_layer_cl: str = 'layernorm'
    norm_eps: float = 1e-6

    def __post_init__(self):
        if self.grid_size is not None:
            self.grid_size = to_2tuple(self.grid_size)
        if self.window_size is not None:
            self.window_size = to_2tuple(self.window_size)
            if self.grid_size is None:
                self.grid_size = self.window_size


@dataclass
class MaxxVitConvCfg:
    block_type: str = 'mbconv'
    expand_ratio: float = 4.0
    expand_output: bool = True  # calculate expansion channels from output (vs input chs)
    kernel_size: int = 3
    group_size: int = 1  # 1 == depthwise
    pre_norm_act: bool = False  # activation after pre-norm
    output_bias: bool = True  # bias for shortcut + final 1x1 projection conv
    stride_mode: str = 'dw'  # stride done via one of 'pool', '1x1', 'dw'
    pool_type: str = 'avg2'
    downsample_pool_type: str = 'avg2'
    padding: str = ''
    attn_early: bool = False  # apply attn between conv2 and norm2, instead of after norm2
    attn_layer: str = 'se'
    attn_act_layer: str = 'silu'
    attn_ratio: float = 0.25
    init_values: Optional[float] = 1e-6  # for ConvNeXt block, ignored by MBConv
    act_layer: str = 'gelu'
    norm_layer: str = ''
    norm_layer_cl: str = ''
    norm_eps: Optional[float] = None

    def __post_init__(self):
        # mbconv vs convnext blocks have different defaults, set in post_init to avoid explicit config args
        assert self.block_type in ('mbconv', 'convnext')
        use_mbconv = self.block_type == 'mbconv'
        if not self.norm_layer:
            self.norm_layer = 'batchnorm2d' if use_mbconv else 'layernorm2d'
        if not self.norm_layer_cl and not use_mbconv:
            self.norm_layer_cl = 'layernorm'
        if self.norm_eps is None:
            self.norm_eps = 1e-5 if use_mbconv else 1e-6
        self.downsample_pool_type = self.downsample_pool_type or self.pool_type


@dataclass
class MaxxVitCfg:
    embed_dim: Tuple[int, ...] = (96, 192, 384, 768)
    depths: Tuple[int, ...] = (2, 3, 5, 2)
    block_type: Tuple[Union[str, Tuple[str, ...]], ...] = ('C', 'C', 'T', 'T')
    stem_width: Union[int, Tuple[int, int]] = 64
    stem_bias: bool = False
    conv_cfg: MaxxVitConvCfg = field(default_factory=MaxxVitConvCfg)
    transformer_cfg: MaxxVitTransformerCfg = field(default_factory=MaxxVitTransformerCfg)
    head_hidden_size: int = None
    weight_init: str = 'vit_eff'


class Attention2d(nn.Module):
    fused_attn: Final[bool]

    """ multi-head attention for 2D NCHW tensors"""
    def __init__(
            self,
            dim: int,
            dim_out: Optional[int] = None,
            dim_head: int = 32,
            bias: bool = True,
            expand_first: bool = True,
            head_first: bool = True,
            rel_pos_cls: Callable = None,
            attn_drop: float = 0.,
            proj_drop: float = 0.
    ):
        super().__init__()
        dim_out = dim_out or dim
        dim_attn = dim_out if expand_first else dim
        self.num_heads = dim_attn // dim_head
        self.dim_head = dim_head
        self.head_first = head_first
        self.scale = dim_head ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias)
        self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
#         print(f"Inside Attention2d forward: {x.shape}")
        B, C, H, W = x.shape

#         y = x.clone()
#         qkv_temp = self.qkv(y)

#         print(f"After self.qkv {qkv_temp.shape}")
        
        if self.head_first:
            q, k, v = self.qkv(x).view(B, self.num_heads, self.dim_head * 3, -1).chunk(3, dim=2)
#             print("self.head_first=True")
            
        else:
            q, k, v = self.qkv(x).reshape(B, 3, self.num_heads, self.dim_head, -1).unbind(1)
#             print("self.head_first=False")
            
        
#         print(f"q, k, v: {q.shape}, {k.shape}, {v.shape}")


        if self.fused_attn:
            # here
#             print("self.fused_attn = True")
            attn_bias = None
            if self.rel_pos is not None:
                #here
                attn_bias = self.rel_pos.get_bias()
#                 print("self.rel_pos is not None")
                
            elif shared_rel_pos is not None:
                attn_bias = shared_rel_pos
#                 print("shared_rel_pos is not None")
            
            
#             if attn_bias is not None:    
#                 print(f"attn_bias: {attn_bias.shape}")
#             else:
#                 print("attn_bias is none")
            
#             print(f"Now applying x = torch.nn.functional.scaled_dot_product_attention(\
#                 q.transpose(-1, -2).contiguous(),\
#                 k.transpose(-1, -2).contiguous(),\
#                 v.transpose(-1, -2).contiguous(),\
#                 attn_mask=attn_bias,\
#                 dropout_p=self.attn_drop.p if self.training else 0.,\
#             ).transpose(-1, -2).reshape(B, -1, H, W)")

            x = torch.nn.functional.scaled_dot_product_attention(
                q.transpose(-1, -2).contiguous(),
                k.transpose(-1, -2).contiguous(),
                v.transpose(-1, -2).contiguous(),
                attn_mask=attn_bias,
                dropout_p=self.attn_drop.p if self.training else 0.,
            ).transpose(-1, -2).reshape(B, -1, H, W)
        else:
#             print("self.fused_attn = False")
            q = q * self.scale
            attn = q.transpose(-2, -1) @ k
            if self.rel_pos is not None:
                attn = self.rel_pos(attn)
            elif shared_rel_pos is not None:
                attn = attn + shared_rel_pos
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
        
        
#         print(f"After attention mechanism: {x.shape}")
        x = self.proj(x)
#         print(f"After self.proj: {x.shape}")
        x = self.proj_drop(x)
#         print(f"After self.proj_drop: {x.shape}")
#         print('*'*100)
        return x


### PROPOSED ATTENTION MECHANISM ###
import torch.nn.functional as F

class Attention2dPyramidal(nn.Module):
    fused_attn: Final[bool]
        
    def __init__(self, num_levels: int, 
            dim: int,
            dim_out: Optional[int] = None,
            dim_head: int = 32,
            bias: bool = True,
            expand_first: bool = True,
            head_first: bool = True,
            rel_pos_cls: Callable = None,
            attn_drop: float = 0.,
            proj_drop: float = 0.):
        super().__init__()
        self.num_levels = num_levels
        self.attention_modules = nn.ModuleList([
            Attention2d(dim, dim_out, dim_head, expand_first, bias, rel_pos_cls, attn_drop, proj_drop) 
            for _ in range(num_levels)
        ])
        self.attn_bias = bias
        self.rel_pos_cls = rel_pos_cls
        

    def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
#         print(f"In Attention2dPyramidal forward num_levels={self.num_levels} and x={x.shape}:-")
    
        # Divide input into pyramidal levels
        pyramidal_levels = [x]
        for i in range(1, self.num_levels):
            downsample_factor = 2 * i
            level_input = F.avg_pool2d(x, kernel_size=downsample_factor, stride=downsample_factor)
#             print(f"Shape after downsampling(avgpool2D) {i}th level={level_input.shape}")
            pyramidal_levels.append(level_input)

        # Compute attention at each level
        attention_outputs = []
        for level, level_input in enumerate(pyramidal_levels):
            attention_output = self.attention_modules[level](level_input)
#             print(f"Shape after applying attention at {level}th level={attention_output.shape}")
            attention_outputs.append(attention_output)

        
        # Upsample and combine attention outputs starting from the last level
        combined_attention_output = attention_outputs[-1]
        for i in range(self.num_levels - 2, -1, -1): 
            # Interpolate the attention output of the current level to match the spatial dimensions of the level above it
            combined_attention_output = F.interpolate(combined_attention_output, size=attention_outputs[i].shape[2:], mode='bilinear', align_corners=False)
#             print(f"Shape after applying interpolation at {i+1}th level={combined_attention_output.shape}")

            # Combine the attention output of the current level with the level above it
            combined_attention_output += attention_outputs[i]
#             print(f"Shape after combining {i+1} and {i}th level={combined_attention_output.shape}")


        # Resize the final combined attention output to match the spatial dimensions of the input at the first level
        combined_attention_output = F.interpolate(combined_attention_output, size=pyramidal_levels[0].shape[2:], mode='bilinear', align_corners=False)
#         print(f"Final shape after (interpolation_last) Attention2dPyramidal : {combined_attention_output.shape}")
#         print("$"*100)
        return combined_attention_output

    def match_spatial_dimensions(self, tensor1, tensor2):
        """
        Pad or crop tensor1 to match the spatial dimensions of tensor2.
        """
        if tensor1.shape[2:] != tensor2.shape[2:]:
            diff_h = tensor2.shape[2] - tensor1.shape[2]
            diff_w = tensor2.shape[3] - tensor1.shape[3]
            pad_left = diff_w // 2
            pad_right = diff_w - pad_left
            pad_top = diff_h // 2
            pad_bottom = diff_h - pad_top
            tensor1 = F.pad(tensor1, (pad_left, pad_right, pad_top, pad_bottom))
        return tensor1
    

class AttentionCl(nn.Module):
    """ Channels-last multi-head attention (B, ..., C) """
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            dim_out: Optional[int] = None,
            dim_head: int = 32,
            bias: bool = True,
            expand_first: bool = True,
            head_first: bool = True,
            rel_pos_cls: Callable = None,
            attn_drop: float = 0.,
            proj_drop: float = 0.
    ):
        super().__init__()
        dim_out = dim_out or dim
        dim_attn = dim_out if expand_first and dim_out > dim else dim
        assert dim_attn % dim_head == 0, 'attn dim should be divisible by head_dim'
        self.num_heads = dim_attn // dim_head
        self.dim_head = dim_head
        self.head_first = head_first
        self.scale = dim_head ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, dim_attn * 3, bias=bias)
        self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim_attn, dim_out, bias=bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
        B = x.shape[0]
        restore_shape = x.shape[:-1]

        if self.head_first:
            q, k, v = self.qkv(x).view(B, -1, self.num_heads, self.dim_head * 3).transpose(1, 2).chunk(3, dim=3)
        else:
            q, k, v = self.qkv(x).reshape(B, -1, 3, self.num_heads, self.dim_head).transpose(1, 3).unbind(2)

        if self.fused_attn:
            attn_bias = None
            if self.rel_pos is not None:
                attn_bias = self.rel_pos.get_bias()
            elif shared_rel_pos is not None:
                attn_bias = shared_rel_pos

            x = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_bias,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            if self.rel_pos is not None:
                attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos)
            elif shared_rel_pos is not None:
                attn = attn + shared_rel_pos
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(restore_shape + (-1,))
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class LayerScale(nn.Module):
    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        gamma = self.gamma
        return x.mul_(gamma) if self.inplace else x * gamma


class LayerScale2d(nn.Module):
    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        gamma = self.gamma.view(1, -1, 1, 1)
        return x.mul_(gamma) if self.inplace else x * gamma


class Downsample2d(nn.Module):
    """ A downsample pooling module supporting several maxpool and avgpool modes
    * 'max' - MaxPool2d w/ kernel_size 3, stride 2, padding 1
    * 'max2' - MaxPool2d w/ kernel_size = stride = 2
    * 'avg' - AvgPool2d w/ kernel_size 3, stride 2, padding 1
    * 'avg2' - AvgPool2d w/ kernel_size = stride = 2
    """

    def __init__(
            self,
            dim: int,
            dim_out: int,
            pool_type: str = 'avg2',
            padding: str = '',
            bias: bool = True,
    ):
        super().__init__()
        assert pool_type in ('max', 'max2', 'avg', 'avg2')
        if pool_type == 'max':
            self.pool = create_pool2d('max', kernel_size=3, stride=2, padding=padding or 1)
        elif pool_type == 'max2':
            self.pool = create_pool2d('max', 2, padding=padding or 0)  # kernel_size == stride == 2
        elif pool_type == 'avg':
            self.pool = create_pool2d(
                'avg', kernel_size=3, stride=2, count_include_pad=False, padding=padding or 1)
        else:
            self.pool = create_pool2d('avg', 2, padding=padding or 0)

        if dim != dim_out:
            self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias)
        else:
            self.expand = nn.Identity()

    def forward(self, x):
#         print(f"In Downsample2D forward:- {x.shape}")
        x = self.pool(x)  # spatial downsample
#         print(f"After self.pool: {x.shape}")
        x = self.expand(x)  # expand chs
#         print(f"After self.expand: {x.shape}")
        return x


def _init_transformer(module, name, scheme=''):
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        if scheme == 'normal':
            nn.init.normal_(module.weight, std=.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif scheme == 'trunc_normal':
            trunc_normal_tf_(module.weight, std=.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif scheme == 'xavier_normal':
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        else:
            # vit like
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                if 'mlp' in name:
                    nn.init.normal_(module.bias, std=1e-6)
                else:
                    nn.init.zeros_(module.bias)


class TransformerBlock2d(nn.Module):
    """ Transformer block with 2D downsampling
    '2D' NCHW tensor layout

    Some gains can be seen on GPU using a 1D / CL block, BUT w/ the need to switch back/forth to NCHW
    for spatial pooling, the benefit is minimal so ended up using just this variant for CoAt configs.

    This impl was faster on TPU w/ PT XLA than the 1D experiment.
    """

    def __init__(
            self,
            dim: int,
            dim_out: int,
            stride: int = 1,
            rel_pos_cls: Callable = None,
            cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
            drop_path: float = 0.,
    ):
        super().__init__()
        norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps)
        act_layer = get_act_layer(cfg.act_layer)

        if stride == 2:
            self.shortcut = Downsample2d(dim, dim_out, pool_type=cfg.pool_type, bias=cfg.shortcut_bias)
            self.norm1 = nn.Sequential(OrderedDict([
                ('norm', norm_layer(dim)),
                ('down', Downsample2d(dim, dim, pool_type=cfg.pool_type)),
            ]))
        else:
            assert dim == dim_out
            self.shortcut = nn.Identity()
            self.norm1 = norm_layer(dim)

        ### MODIFIED HERE    
        self.attn = Attention2dPyramidal(
            4, # num_levels
            dim,
            dim_out,
            dim_head=cfg.dim_head,
            expand_first=cfg.expand_first,
            bias=cfg.attn_bias,
            rel_pos_cls=rel_pos_cls,
            attn_drop=cfg.attn_drop,
            proj_drop=cfg.proj_drop
        )
        self.ls1 = LayerScale2d(dim_out, init_values=cfg.init_values) if cfg.init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim_out)
        self.mlp = ConvMlp(
            in_features=dim_out,
            hidden_features=int(dim_out * cfg.expand_ratio),
            act_layer=act_layer,
            drop=cfg.proj_drop)
        self.ls2 = LayerScale2d(dim_out, init_values=cfg.init_values) if cfg.init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def init_weights(self, scheme=''):
        named_apply(partial(_init_transformer, scheme=scheme), self)

    def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
#         print(f"Inside TransformerBlock2d forward: {x.shape}")
        
#         y = x.clone()
        
#         y_shortcut = y.clone()
        
        
#         y = self.norm1(y)
#         print(f"After self.norm1 {y.shape}")
        
#         y = self.attn(y, shared_rel_pos=shared_rel_pos)
#         print(f"After self.attn(y, shared_rel_pos=shared_rel_pos): {y.shape} ")
        
#         y = self.ls1(y)
#         print(f"After self.ls1(y): {y.shape}")
        
#         y = self.drop_path1(y)
#         print(f"After self.drop_path1(y): {y.shape}")
        
#         y_shortcut = self.shortcut(y_shortcut)
#         print(f"After self.shortcut(y_shortcut): {y_shortcut.shape}")
        
#         y = y_shortcut + y
#         print(f"After y_shortcut + y : {y.shape}")
        
        
        x = self.shortcut(x) + self.drop_path1(self.ls1(self.attn(self.norm1(x), shared_rel_pos=shared_rel_pos)))
#         print(f"after self.shortcut(x) + self.drop_path1(self.ls1(self.attn(self.norm1(x), shared_rel_pos=shared_rel_pos))) : {x.shape}")
        
        
        
#         y = x.clone()
        
#         y_shortcut = x.clone()
        
#         y = self.norm2(y)
#         print(f"After self.norm2(y): {y.shape}")
        
#         y = self.mlp(y)
#         print(f"After self.mlp: {y.shape}")
        
#         y = self.ls2(y)
#         print(f"After self.ls2: {y.shape}")

#         y = self.drop_path2(y)
#         print(f"After self.drop_path2: {y.shape}")

        
#         y = y_shortcut + y
#         print(f"After y_shortcut + y: {y.shape}")
        
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
#         print(f"After x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))): {x.shape}")
        
        return x


def _init_conv(module, name, scheme=''):
    if isinstance(module, nn.Conv2d):
        if scheme == 'normal':
            nn.init.normal_(module.weight, std=.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif scheme == 'trunc_normal':
            trunc_normal_tf_(module.weight, std=.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif scheme == 'xavier_normal':
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        else:
            # efficientnet like
            fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
            fan_out //= module.groups
            nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out))
            if module.bias is not None:
                nn.init.zeros_(module.bias)


def num_groups(group_size, channels):
    if not group_size:  # 0 or None
        return 1  # normal conv with 1 group
    else:
        # NOTE group_size == 1 -> depthwise conv
        assert channels % group_size == 0
        return channels // group_size


class MbConvBlock(nn.Module):
    """ Pre-Norm Conv Block - 1x1 - kxk - 1x1, w/ inverted bottleneck (expand)
    """
    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            stride: int = 1,
            dilation: Tuple[int, int] = (1, 1),
            cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
            drop_path: float = 0.
    ):
        super(MbConvBlock, self).__init__()
        norm_act_layer = partial(get_norm_act_layer(cfg.norm_layer, cfg.act_layer), eps=cfg.norm_eps)
        mid_chs = make_divisible((out_chs if cfg.expand_output else in_chs) * cfg.expand_ratio)
        groups = num_groups(cfg.group_size, mid_chs)

        if stride == 2:
            self.shortcut = Downsample2d(
                in_chs, out_chs, pool_type=cfg.pool_type, bias=cfg.output_bias, padding=cfg.padding)
        else:
            self.shortcut = nn.Identity()

        assert cfg.stride_mode in ('pool', '1x1', 'dw')
        stride_pool, stride_1, stride_2 = 1, 1, 1
        if cfg.stride_mode == 'pool':
            # NOTE this is not described in paper, experiment to find faster option that doesn't stride in 1x1
            stride_pool, dilation_2 = stride, dilation[1]
            # FIXME handle dilation of avg pool
        elif cfg.stride_mode == '1x1':
            # NOTE I don't like this option described in paper, 1x1 w/ stride throws info away
            stride_1, dilation_2 = stride, dilation[1]
        else:
            stride_2, dilation_2 = stride, dilation[0]

        self.pre_norm = norm_act_layer(in_chs, apply_act=cfg.pre_norm_act)
        if stride_pool > 1:
            self.down = Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type, padding=cfg.padding)
        else:
            self.down = nn.Identity()
        self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=stride_1)
        self.norm1 = norm_act_layer(mid_chs)

        self.conv2_kxk = create_conv2d(
            mid_chs, mid_chs, cfg.kernel_size,
            stride=stride_2, dilation=dilation_2, groups=groups, padding=cfg.padding)

        attn_kwargs = {}
        if isinstance(cfg.attn_layer, str):
            if cfg.attn_layer == 'se' or cfg.attn_layer == 'eca':
                attn_kwargs['act_layer'] = cfg.attn_act_layer
                attn_kwargs['rd_channels'] = int(cfg.attn_ratio * (out_chs if cfg.expand_output else mid_chs))

        # two different orderings for SE and norm2 (due to some weights and trials using SE before norm2)
        if cfg.attn_early:
            self.se_early = create_attn(cfg.attn_layer, mid_chs, **attn_kwargs)
            self.norm2 = norm_act_layer(mid_chs)
            self.se = None
        else:
            self.se_early = None
            self.norm2 = norm_act_layer(mid_chs)
            self.se = create_attn(cfg.attn_layer, mid_chs, **attn_kwargs)

        self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=cfg.output_bias)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def init_weights(self, scheme=''):
        named_apply(partial(_init_conv, scheme=scheme), self)

    def forward(self, x):
#         print('-'*100)
#         print(f"Inside MbConvBlock forward:{x.shape}")
        shortcut = self.shortcut(x)
#         print(f"After self.shortcut:{x.shape}")

        x = self.pre_norm(x)
#         print(f"After self.pre_norm:{x.shape}")
        x = self.down(x)
#         print(f"After self.down:{x.shape}")


        # 1x1 expansion conv & norm-act
        x = self.conv1_1x1(x)
#         print(f"After self.conv1_1x1 {x.shape}")
        x = self.norm1(x)
#         print(f"After self.norm1 {x.shape}")


        # depthwise / grouped 3x3 conv w/ SE (or other) channel attention & norm-act
        x = self.conv2_kxk(x)
#         print(f"After self.conv2_kxk {x.shape}")
        if self.se_early is not None:
            x = self.se_early(x)
#             print(f"self.se_early is not NONE. After self.se_early {x.shape}")

        x = self.norm2(x)
#         print(f"After self.norm2 {x.shape}")

        if self.se is not None:
            # here
            x = self.se(x)
#             print(f"self.se is not NONE. After self.se: {x.shape}")

        # 1x1 linear projection to output width
        x = self.conv3_1x1(x)
#         print(f"After self.conv3_1x1: {x.shape}")

        x = self.drop_path(x) + shortcut
#         print(f"After self.drop_path(x) + shortcut: {x.shape}")
#         print('-'*100)
        return x


class ConvNeXtBlock(nn.Module):
    """ ConvNeXt Block
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: Optional[int] = None,
            kernel_size: int = 7,
            stride: int = 1,
            dilation: Tuple[int, int] = (1, 1),
            cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
            conv_mlp: bool = True,
            drop_path: float = 0.
    ):
        super().__init__()
        out_chs = out_chs or in_chs
        act_layer = get_act_layer(cfg.act_layer)
        if conv_mlp:
            norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps)
            mlp_layer = ConvMlp
        else:
            assert 'layernorm' in cfg.norm_layer
            norm_layer = LayerNorm
            mlp_layer = Mlp
        self.use_conv_mlp = conv_mlp

        if stride == 2:
            self.shortcut = Downsample2d(in_chs, out_chs)
        elif in_chs != out_chs:
            self.shortcut = nn.Conv2d(in_chs, out_chs, kernel_size=1, bias=cfg.output_bias)
        else:
            self.shortcut = nn.Identity()

        assert cfg.stride_mode in ('pool', 'dw')
        stride_pool, stride_dw = 1, 1
        # FIXME handle dilation?
        if cfg.stride_mode == 'pool':
            stride_pool = stride
        else:
            stride_dw = stride

        if stride_pool == 2:
            self.down = Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type)
        else:
            self.down = nn.Identity()

        self.conv_dw = create_conv2d(
            in_chs, out_chs, kernel_size=kernel_size, stride=stride_dw, dilation=dilation[1],
            depthwise=True, bias=cfg.output_bias)
        self.norm = norm_layer(out_chs)
        self.mlp = mlp_layer(out_chs, int(cfg.expand_ratio * out_chs), bias=cfg.output_bias, act_layer=act_layer)
        if conv_mlp:
            self.ls = LayerScale2d(out_chs, cfg.init_values) if cfg.init_values else nn.Identity()
        else:
            self.ls = LayerScale(out_chs, cfg.init_values) if cfg.init_values else nn.Identity()
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        shortcut = self.shortcut(x)
        x = self.down(x)
        x = self.conv_dw(x)
        if self.use_conv_mlp:
            x = self.norm(x)
            x = self.mlp(x)
            x = self.ls(x)
        else:
            x = x.permute(0, 2, 3, 1)
            x = self.norm(x)
            x = self.mlp(x)
            x = self.ls(x)
            x = x.permute(0, 3, 1, 2)

        x = self.drop_path(x) + shortcut
        return x


def window_partition(x, window_size: List[int]):
    B, H, W, C = x.shape
    _assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})')
    _assert(W % window_size[1] == 0, '')
    x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
    return windows


@register_notrace_function  # reason: int argument is a Proxy
def window_reverse(windows, window_size: List[int], img_size: List[int]):
    H, W = img_size
    C = windows.shape[-1]
    x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
    return x


def grid_partition(x, grid_size: List[int]):
    B, H, W, C = x.shape
    _assert(H % grid_size[0] == 0, f'height {H} must be divisible by grid {grid_size[0]}')
    _assert(W % grid_size[1] == 0, '')
    x = x.view(B, grid_size[0], H // grid_size[0], grid_size[1], W // grid_size[1], C)
    windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, grid_size[0], grid_size[1], C)
    return windows


@register_notrace_function  # reason: int argument is a Proxy
def grid_reverse(windows, grid_size: List[int], img_size: List[int]):
    H, W = img_size
    C = windows.shape[-1]
    x = windows.view(-1, H // grid_size[0], W // grid_size[1], grid_size[0], grid_size[1], C)
    x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, H, W, C)
    return x


def get_rel_pos_cls(cfg: MaxxVitTransformerCfg, window_size):
    rel_pos_cls = None
    if cfg.rel_pos_type == 'mlp':
        rel_pos_cls = partial(RelPosMlp, window_size=window_size, hidden_dim=cfg.rel_pos_dim)
    elif cfg.rel_pos_type == 'bias':
        rel_pos_cls = partial(RelPosBias, window_size=window_size)
    elif cfg.rel_pos_type == 'bias_tf':
        rel_pos_cls = partial(RelPosBiasTf, window_size=window_size)
    return rel_pos_cls


class PartitionAttentionCl(nn.Module):
    """ Grid or Block partition + Attn + FFN.
    NxC 'channels last' tensor layout.
    """

    def __init__(
            self,
            dim: int,
            partition_type: str = 'block',
            cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
            drop_path: float = 0.,
    ):
        super().__init__()
        norm_layer = partial(get_norm_layer(cfg.norm_layer_cl), eps=cfg.norm_eps)  # NOTE this block is channels-last
        act_layer = get_act_layer(cfg.act_layer)

        self.partition_block = partition_type == 'block'
        self.partition_size = to_2tuple(cfg.window_size if self.partition_block else cfg.grid_size)
        rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size)

        self.norm1 = norm_layer(dim)
        self.attn = AttentionCl(
            dim,
            dim,
            dim_head=cfg.dim_head,
            bias=cfg.attn_bias,
            head_first=cfg.head_first,
            rel_pos_cls=rel_pos_cls,
            attn_drop=cfg.attn_drop,
            proj_drop=cfg.proj_drop,
        )
        self.ls1 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * cfg.expand_ratio),
            act_layer=act_layer,
            drop=cfg.proj_drop)
        self.ls2 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def _partition_attn(self, x):
        img_size = x.shape[1:3]
        if self.partition_block:
            partitioned = window_partition(x, self.partition_size)
        else:
            partitioned = grid_partition(x, self.partition_size)

        partitioned = self.attn(partitioned)

        if self.partition_block:
            x = window_reverse(partitioned, self.partition_size, img_size)
        else:
            x = grid_reverse(partitioned, self.partition_size, img_size)
        return x

    def forward(self, x):
        x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x


class ParallelPartitionAttention(nn.Module):
    """ Experimental. Grid and Block partition + single FFN
    NxC tensor layout.
    """

    def __init__(
            self,
            dim: int,
            cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
            drop_path: float = 0.,
    ):
        super().__init__()
        assert dim % 2 == 0
        norm_layer = partial(get_norm_layer(cfg.norm_layer_cl), eps=cfg.norm_eps)  # NOTE this block is channels-last
        act_layer = get_act_layer(cfg.act_layer)

        assert cfg.window_size == cfg.grid_size
        self.partition_size = to_2tuple(cfg.window_size)
        rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size)

        self.norm1 = norm_layer(dim)
        self.attn_block = AttentionCl(
            dim,
            dim // 2,
            dim_head=cfg.dim_head,
            bias=cfg.attn_bias,
            head_first=cfg.head_first,
            rel_pos_cls=rel_pos_cls,
            attn_drop=cfg.attn_drop,
            proj_drop=cfg.proj_drop,
        )
        self.attn_grid = AttentionCl(
            dim,
            dim // 2,
            dim_head=cfg.dim_head,
            bias=cfg.attn_bias,
            head_first=cfg.head_first,
            rel_pos_cls=rel_pos_cls,
            attn_drop=cfg.attn_drop,
            proj_drop=cfg.proj_drop,
        )
        self.ls1 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * cfg.expand_ratio),
            out_features=dim,
            act_layer=act_layer,
            drop=cfg.proj_drop)
        self.ls2 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def _partition_attn(self, x):
        img_size = x.shape[1:3]

        partitioned_block = window_partition(x, self.partition_size)
        partitioned_block = self.attn_block(partitioned_block)
        x_window = window_reverse(partitioned_block, self.partition_size, img_size)

        partitioned_grid = grid_partition(x, self.partition_size)
        partitioned_grid = self.attn_grid(partitioned_grid)
        x_grid = grid_reverse(partitioned_grid, self.partition_size, img_size)

        return torch.cat([x_window, x_grid], dim=-1)

    def forward(self, x):
        x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x


def window_partition_nchw(x, window_size: List[int]):
    B, C, H, W = x.shape
    _assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})')
    _assert(W % window_size[1] == 0, '')
    x = x.view(B, C, H // window_size[0], window_size[0], W // window_size[1], window_size[1])
    windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, C, window_size[0], window_size[1])
    return windows


@register_notrace_function  # reason: int argument is a Proxy
def window_reverse_nchw(windows, window_size: List[int], img_size: List[int]):
    H, W = img_size
    C = windows.shape[1]
    x = windows.view(-1, H // window_size[0], W // window_size[1], C, window_size[0], window_size[1])
    x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, C, H, W)
    return x


def grid_partition_nchw(x, grid_size: List[int]):
    B, C, H, W = x.shape
    _assert(H % grid_size[0] == 0, f'height {H} must be divisible by grid {grid_size[0]}')
    _assert(W % grid_size[1] == 0, '')
    x = x.view(B, C, grid_size[0], H // grid_size[0], grid_size[1], W // grid_size[1])
    windows = x.permute(0, 3, 5, 1, 2, 4).contiguous().view(-1, C, grid_size[0], grid_size[1])
    return windows


@register_notrace_function  # reason: int argument is a Proxy
def grid_reverse_nchw(windows, grid_size: List[int], img_size: List[int]):
    H, W = img_size
    C = windows.shape[1]
    x = windows.view(-1, H // grid_size[0], W // grid_size[1], C, grid_size[0], grid_size[1])
    x = x.permute(0, 3, 4, 1, 5, 2).contiguous().view(-1, C, H, W)
    return x


class PartitionAttention2d(nn.Module):
    """ Grid or Block partition + Attn + FFN

    '2D' NCHW tensor layout.
    """

    def __init__(
            self,
            dim: int,
            partition_type: str = 'block',
            cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
            drop_path: float = 0.,
    ):
        super().__init__()
        norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps)  # NOTE this block is channels-last
        act_layer = get_act_layer(cfg.act_layer)

        self.partition_block = partition_type == 'block'
        self.partition_size = to_2tuple(cfg.window_size if self.partition_block else cfg.grid_size)
        rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size)

        self.norm1 = norm_layer(dim)
        self.attn = Attention2d(
            dim,
            dim,
            dim_head=cfg.dim_head,
            bias=cfg.attn_bias,
            head_first=cfg.head_first,
            rel_pos_cls=rel_pos_cls,
            attn_drop=cfg.attn_drop,
            proj_drop=cfg.proj_drop,
        )
        self.ls1 = LayerScale2d(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = ConvMlp(
            in_features=dim,
            hidden_features=int(dim * cfg.expand_ratio),
            act_layer=act_layer,
            drop=cfg.proj_drop)
        self.ls2 = LayerScale2d(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def _partition_attn(self, x):
        img_size = x.shape[-2:]
        if self.partition_block:
            partitioned = window_partition_nchw(x, self.partition_size)
        else:
            partitioned = grid_partition_nchw(x, self.partition_size)

        partitioned = self.attn(partitioned)

        if self.partition_block:
            x = window_reverse_nchw(partitioned, self.partition_size, img_size)
        else:
            x = grid_reverse_nchw(partitioned, self.partition_size, img_size)
        return x

    def forward(self, x):
        x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x


class MaxxVitBlock(nn.Module):
    """ MaxVit conv, window partition + FFN , grid partition + FFN
    """

    def __init__(
            self,
            dim: int,
            dim_out: int,
            stride: int = 1,
            conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
            transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
            drop_path: float = 0.,
    ):
        super().__init__()
        self.nchw_attn = transformer_cfg.use_nchw_attn

        conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock
        self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)

        attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path)
        partition_layer = PartitionAttention2d if self.nchw_attn else PartitionAttentionCl
        self.attn_block = None if transformer_cfg.no_block_attn else partition_layer(**attn_kwargs)
        self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs)

    def init_weights(self, scheme=''):
        if self.attn_block is not None:
            named_apply(partial(_init_transformer, scheme=scheme), self.attn_block)
        named_apply(partial(_init_transformer, scheme=scheme), self.attn_grid)
        named_apply(partial(_init_conv, scheme=scheme), self.conv)

    def forward(self, x):
        # NCHW format
        x = self.conv(x)

        if not self.nchw_attn:
            x = x.permute(0, 2, 3, 1)  # to NHWC (channels-last)
        if self.attn_block is not None:
            x = self.attn_block(x)
        x = self.attn_grid(x)
        if not self.nchw_attn:
            x = x.permute(0, 3, 1, 2)  # back to NCHW
        return x


class ParallelMaxxVitBlock(nn.Module):
    """ MaxVit block with parallel cat(window + grid), one FF
    Experimental timm block.
    """

    def __init__(
            self,
            dim,
            dim_out,
            stride=1,
            num_conv=2,
            conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
            transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
            drop_path=0.,
    ):
        super().__init__()

        conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock
        if num_conv > 1:
            convs = [conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)]
            convs += [conv_cls(dim_out, dim_out, cfg=conv_cfg, drop_path=drop_path)] * (num_conv - 1)
            self.conv = nn.Sequential(*convs)
        else:
            self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)
        self.attn = ParallelPartitionAttention(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path)

    def init_weights(self, scheme=''):
        named_apply(partial(_init_transformer, scheme=scheme), self.attn)
        named_apply(partial(_init_conv, scheme=scheme), self.conv)

    def forward(self, x):
        x = self.conv(x)
        x = x.permute(0, 2, 3, 1)
        x = self.attn(x)
        x = x.permute(0, 3, 1, 2)
        return x


class MaxxVitStage(nn.Module):
    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            stride: int = 2,
            depth: int = 4,
            feat_size: Tuple[int, int] = (14, 14),
            block_types: Union[str, Tuple[str]] = 'C',
            transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
            conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
            drop_path: Union[float, List[float]] = 0.,
    ):
        super().__init__()
        self.grad_checkpointing = False

        block_types = extend_tuple(block_types, depth)
        blocks = []
        for i, t in enumerate(block_types):
            block_stride = stride if i == 0 else 1
            assert t in ('C', 'T', 'M', 'PM')
            if t == 'C':
                conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock
                blocks += [conv_cls(
                    in_chs,
                    out_chs,
                    stride=block_stride,
                    cfg=conv_cfg,
                    drop_path=drop_path[i],
                )]
            elif t == 'T':
                rel_pos_cls = get_rel_pos_cls(transformer_cfg, feat_size)
                blocks += [TransformerBlock2d(
                    in_chs,
                    out_chs,
                    stride=block_stride,
                    rel_pos_cls=rel_pos_cls,
                    cfg=transformer_cfg,
                    drop_path=drop_path[i],
                )]
            elif t == 'M':
                blocks += [MaxxVitBlock(
                    in_chs,
                    out_chs,
                    stride=block_stride,
                    conv_cfg=conv_cfg,
                    transformer_cfg=transformer_cfg,
                    drop_path=drop_path[i],
                )]
            elif t == 'PM':
                blocks += [ParallelMaxxVitBlock(
                    in_chs,
                    out_chs,
                    stride=block_stride,
                    conv_cfg=conv_cfg,
                    transformer_cfg=transformer_cfg,
                    drop_path=drop_path[i],
                )]
            in_chs = out_chs
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
#         print(f"Inside MaxxVitStage x.shape: {x.shape}")
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
#         print("After self.blocks", x.shape)
        return x


class Stem(nn.Module):

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            kernel_size: int = 3,
            padding: str = '',
            bias: bool = False,
            act_layer: str = 'gelu',
            norm_layer: str = 'batchnorm2d',
            norm_eps: float = 1e-5,
    ):
        super().__init__()
        if not isinstance(out_chs, (list, tuple)):
            out_chs = to_2tuple(out_chs)

        norm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps)
        self.in_chs = in_chs # added this line
        self.out_chs = out_chs[-1]
        self.stride = 2

        self.conv1 = create_conv2d(in_chs, out_chs[0], kernel_size, stride=2, padding=padding, bias=bias)
        self.norm1 = norm_act_layer(out_chs[0])
        self.conv2 = create_conv2d(out_chs[0], out_chs[1], kernel_size, stride=1, padding=padding, bias=bias)

    def init_weights(self, scheme=''):
        named_apply(partial(_init_conv, scheme=scheme), self)

    def forward(self, x):
#         print("Inside Stem class", x.shape)
#         print(f"in_chs:{self.in_chs} and out_chs:{self.out_chs}")
        x = self.conv1(x)
#         print("after self.conv1", x.shape)
        x = self.norm1(x)
#         print("after self.norm1", x.shape)
        x = self.conv2(x)
#         print("after self.conv2", x.shape)
        
        return x


def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]):
    if cfg.window_size is not None:
        assert cfg.grid_size
        return cfg
    partition_size = img_size[0] // cfg.partition_ratio, img_size[1] // cfg.partition_ratio
    cfg = replace(cfg, window_size=partition_size, grid_size=partition_size)
    return cfg


def _overlay_kwargs(cfg: MaxxVitCfg, **kwargs):
    transformer_kwargs = {}
    conv_kwargs = {}
    base_kwargs = {}
    for k, v in kwargs.items():
        if k.startswith('transformer_'):
            transformer_kwargs[k.replace('transformer_', '')] = v
        elif k.startswith('conv_'):
            conv_kwargs[k.replace('conv_', '')] = v
        else:
            base_kwargs[k] = v
    cfg = replace(
        cfg,
        transformer_cfg=replace(cfg.transformer_cfg, **transformer_kwargs),
        conv_cfg=replace(cfg.conv_cfg, **conv_kwargs),
        **base_kwargs
    )
    return cfg


class MaxxVit(nn.Module):
    """ CoaTNet + MaxVit base model.

    Highly configurable for different block compositions, tensor layouts, pooling types.
    """

    def __init__(
            self,
            cfg: MaxxVitCfg,
            img_size: Union[int, Tuple[int, int]] = 224,
            in_chans: int = 3,
            num_classes: int = 1000,
            global_pool: str = 'avg',
            drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            **kwargs,
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        if kwargs:
            cfg = _overlay_kwargs(cfg, **kwargs)
        transformer_cfg = cfg_window_size(cfg.transformer_cfg, img_size)
        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.embed_dim = cfg.embed_dim[-1]
        self.drop_rate = drop_rate
        self.grad_checkpointing = False
        self.feature_info = []

        self.stem = Stem(
            in_chs=in_chans,
            out_chs=cfg.stem_width,
            padding=cfg.conv_cfg.padding,
            bias=cfg.stem_bias,
            act_layer=cfg.conv_cfg.act_layer,
            norm_layer=cfg.conv_cfg.norm_layer,
            norm_eps=cfg.conv_cfg.norm_eps,
        )
        stride = self.stem.stride
        self.feature_info += [dict(num_chs=self.stem.out_chs, reduction=2, module='stem')]
        feat_size = tuple([i // s for i, s in zip(img_size, to_2tuple(stride))])

        num_stages = len(cfg.embed_dim)
        assert len(cfg.depths) == num_stages
        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
        in_chs = self.stem.out_chs
        stages = []
        for i in range(num_stages):
            stage_stride = 2
            out_chs = cfg.embed_dim[i]
            feat_size = tuple([(r - 1) // stage_stride + 1 for r in feat_size])
            stages += [MaxxVitStage(
                in_chs,
                out_chs,
                depth=cfg.depths[i],
                block_types=cfg.block_type[i],
                conv_cfg=cfg.conv_cfg,
                transformer_cfg=transformer_cfg,
                feat_size=feat_size,
                drop_path=dpr[i],
            )]
            stride *= stage_stride
            in_chs = out_chs
            self.feature_info += [dict(num_chs=out_chs, reduction=stride, module=f'stages.{i}')]
        self.stages = nn.Sequential(*stages)

        final_norm_layer = partial(get_norm_layer(cfg.transformer_cfg.norm_layer), eps=cfg.transformer_cfg.norm_eps)
        self.head_hidden_size = cfg.head_hidden_size
        if self.head_hidden_size:
            self.norm = nn.Identity()
            self.head = NormMlpClassifierHead(
                self.num_features,
                num_classes,
                hidden_size=self.head_hidden_size,
                pool_type=global_pool,
                drop_rate=drop_rate,
                norm_layer=final_norm_layer,
            )
        else:
            # standard classifier head w/ norm, pooling, fc classifier
            self.norm = final_norm_layer(self.num_features)
            self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)

        # Weight init (default PyTorch init works well for AdamW if scheme not set)
        assert cfg.weight_init in ('', 'normal', 'trunc_normal', 'xavier_normal', 'vit_eff')
        if cfg.weight_init:
            named_apply(partial(self._init_weights, scheme=cfg.weight_init), self)

    def _init_weights(self, module, name, scheme=''):
        if hasattr(module, 'init_weights'):
            try:
                module.init_weights(scheme=scheme)
            except TypeError:
                module.init_weights()

    @torch.jit.ignore
    def no_weight_decay(self):
        return {
            k for k, _ in self.named_parameters()
            if any(n in k for n in ["relative_position_bias_table", "rel_pos.mlp"])}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        matcher = dict(
            stem=r'^stem',  # stem and embed
            blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        for s in self.stages:
            s.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool=None):
        self.num_classes = num_classes
        self.head.reset(num_classes, global_pool)

    def forward_features(self, x):
#         print(f"In forward_features: {x.shape}")
        x = self.stem(x)
#         print(f"After self.stem: {x.shape}")
        x = self.stages(x)
#         print(f"After self.stages: {x.shape}")
        x = self.norm(x)
#         print(f"After self.norm: {x.shape}")
        return x

    def forward_head(self, x, pre_logits: bool = False):
#         print(f"In forward_head function: {x.shape}")
#         print(f"pre_logits is {pre_logits}")
        return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)

    def forward(self, x):
#         print(f"In forward (main): {x.shape}")
        x = self.forward_features(x)
#         print(f"After forward_features (main): {x.shape}")
        x = self.forward_head(x)
#         print(f"After forward_head (main): {x.shape}")
        return x


def _rw_coat_cfg(
        stride_mode='pool',
        pool_type='avg2',
        conv_output_bias=False,
        conv_attn_early=False,
        conv_attn_act_layer='relu',
        conv_norm_layer='',
        transformer_shortcut_bias=True,
        transformer_norm_layer='layernorm2d',
        transformer_norm_layer_cl='layernorm',
        init_values=None,
        rel_pos_type='bias',
        rel_pos_dim=512,
):
    # 'RW' timm variant models were created and trained before seeing https://github.com/google-research/maxvit
    # Common differences for initial timm models:
    # - pre-norm layer in MZBConv included an activation after norm
    # - mbconv expansion calculated from input instead of output chs
    # - mbconv shortcut and final 1x1 conv did not have a bias
    # - SE act layer was relu, not silu
    # - mbconv uses silu in timm, not gelu
    # - expansion in attention block done via output proj, not input proj
    # Variable differences (evolved over training initial models):
    # - avg pool with kernel_size=2 favoured downsampling (instead of maxpool for coat)
    # - SE attention was between conv2 and norm/act
    # - default to avg pool for mbconv downsample instead of 1x1 or dw conv
    # - transformer block shortcut has no bias
    return dict(
        conv_cfg=MaxxVitConvCfg(
            stride_mode=stride_mode,
            pool_type=pool_type,
            pre_norm_act=True,
            expand_output=False,
            output_bias=conv_output_bias,
            attn_early=conv_attn_early,
            attn_act_layer=conv_attn_act_layer,
            act_layer='silu',
            norm_layer=conv_norm_layer,
        ),
        transformer_cfg=MaxxVitTransformerCfg(
            expand_first=False,
            shortcut_bias=transformer_shortcut_bias,
            pool_type=pool_type,
            init_values=init_values,
            norm_layer=transformer_norm_layer,
            norm_layer_cl=transformer_norm_layer_cl,
            rel_pos_type=rel_pos_type,
            rel_pos_dim=rel_pos_dim,
        ),
    )


def _rw_max_cfg(
        stride_mode='dw',
        pool_type='avg2',
        conv_output_bias=False,
        conv_attn_ratio=1 / 16,
        conv_norm_layer='',
        transformer_norm_layer='layernorm2d',
        transformer_norm_layer_cl='layernorm',
        window_size=None,
        dim_head=32,
        init_values=None,
        rel_pos_type='bias',
        rel_pos_dim=512,
):
    # 'RW' timm variant models were created and trained before seeing https://github.com/google-research/maxvit
    # Differences of initial timm models:
    # - mbconv expansion calculated from input instead of output chs
    # - mbconv shortcut and final 1x1 conv did not have a bias
    # - mbconv uses silu in timm, not gelu
    # - expansion in attention block done via output proj, not input proj
    return dict(
        conv_cfg=MaxxVitConvCfg(
            stride_mode=stride_mode,
            pool_type=pool_type,
            expand_output=False,
            output_bias=conv_output_bias,
            attn_ratio=conv_attn_ratio,
            act_layer='silu',
            norm_layer=conv_norm_layer,
        ),
        transformer_cfg=MaxxVitTransformerCfg(
            expand_first=False,
            pool_type=pool_type,
            dim_head=dim_head,
            window_size=window_size,
            init_values=init_values,
            norm_layer=transformer_norm_layer,
            norm_layer_cl=transformer_norm_layer_cl,
            rel_pos_type=rel_pos_type,
            rel_pos_dim=rel_pos_dim,
        ),
    )


def _next_cfg(
        stride_mode='dw',
        pool_type='avg2',
        conv_norm_layer='layernorm2d',
        conv_norm_layer_cl='layernorm',
        transformer_norm_layer='layernorm2d',
        transformer_norm_layer_cl='layernorm',
        window_size=None,
        no_block_attn=False,
        init_values=1e-6,
        rel_pos_type='mlp',  # MLP by default for maxxvit
        rel_pos_dim=512,
):
    # For experimental models with convnext instead of mbconv
    init_values = to_2tuple(init_values)
    return dict(
        conv_cfg=MaxxVitConvCfg(
            block_type='convnext',
            stride_mode=stride_mode,
            pool_type=pool_type,
            expand_output=False,
            init_values=init_values[0],
            norm_layer=conv_norm_layer,
            norm_layer_cl=conv_norm_layer_cl,
        ),
        transformer_cfg=MaxxVitTransformerCfg(
            expand_first=False,
            pool_type=pool_type,
            window_size=window_size,
            no_block_attn=no_block_attn,  # enabled for MaxxViT-V2
            init_values=init_values[1],
            norm_layer=transformer_norm_layer,
            norm_layer_cl=transformer_norm_layer_cl,
            rel_pos_type=rel_pos_type,
            rel_pos_dim=rel_pos_dim,
        ),
    )


def _tf_cfg():
    return dict(
        conv_cfg=MaxxVitConvCfg(
            norm_eps=1e-3,
            act_layer='gelu_tanh',
            padding='same',
        ),
        transformer_cfg=MaxxVitTransformerCfg(
            norm_eps=1e-5,
            act_layer='gelu_tanh',
            head_first=False,  # heads are interleaved (q_nh, q_hdim, k_nh, q_hdim, ....)
            rel_pos_type='bias_tf',
        ),
    )


model_cfgs = dict(
    # timm specific CoAtNet configs
    coatnet_pico_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(2, 3, 5, 2),
        stem_width=(32, 64),
        **_rw_max_cfg(  # using newer max defaults here
            conv_output_bias=True,
            conv_attn_ratio=0.25,
        ),
    ),
    coatnet_nano_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(3, 4, 6, 3),
        stem_width=(32, 64),
        **_rw_max_cfg(  # using newer max defaults here
            stride_mode='pool',
            conv_output_bias=True,
            conv_attn_ratio=0.25,
        ),
    ),
    coatnet_0_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 3, 7, 2),  # deeper than paper '0' model
        stem_width=(32, 64),
        **_rw_coat_cfg(
            conv_attn_early=True,
            transformer_shortcut_bias=False,
        ),
    ),
    coatnet_1_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 6, 14, 2),
        stem_width=(32, 64),
        **_rw_coat_cfg(
            stride_mode='dw',
            conv_attn_early=True,
            transformer_shortcut_bias=False,
        )
    ),
    coatnet_2_rw=MaxxVitCfg(
        embed_dim=(128, 256, 512, 1024),
        depths=(2, 6, 14, 2),
        stem_width=(64, 128),
        **_rw_coat_cfg(
            stride_mode='dw',
            conv_attn_act_layer='silu',
            #init_values=1e-6,
        ),
    ),
    coatnet_3_rw=MaxxVitCfg(
        embed_dim=(192, 384, 768, 1536),
        depths=(2, 6, 14, 2),
        stem_width=(96, 192),
        **_rw_coat_cfg(
            stride_mode='dw',
            conv_attn_act_layer='silu',
            init_values=1e-6,
        ),
    ),

    # Experimental CoAtNet configs w/ ImageNet-1k train (different norm layers, MLP rel-pos)
    coatnet_bn_0_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 3, 7, 2),  # deeper than paper '0' model
        stem_width=(32, 64),
        **_rw_coat_cfg(
            stride_mode='dw',
            conv_attn_early=True,
            transformer_shortcut_bias=False,
            transformer_norm_layer='batchnorm2d',
        )
    ),
    coatnet_rmlp_nano_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(3, 4, 6, 3),
        stem_width=(32, 64),
        **_rw_max_cfg(
            conv_output_bias=True,
            conv_attn_ratio=0.25,
            rel_pos_type='mlp',
            rel_pos_dim=384,
        ),
    ),
    coatnet_rmlp_0_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 3, 7, 2),  # deeper than paper '0' model
        stem_width=(32, 64),
        **_rw_coat_cfg(
            stride_mode='dw',
            rel_pos_type='mlp',
        ),
    ),
    coatnet_rmlp_1_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 6, 14, 2),
        stem_width=(32, 64),
        **_rw_coat_cfg(
            pool_type='max',
            conv_attn_early=True,
            transformer_shortcut_bias=False,
            rel_pos_type='mlp',
            rel_pos_dim=384,  # was supposed to be 512, woops
        ),
    ),
    coatnet_rmlp_1_rw2=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 6, 14, 2),
        stem_width=(32, 64),
        **_rw_coat_cfg(
            stride_mode='dw',
            rel_pos_type='mlp',
            rel_pos_dim=512,  # was supposed to be 512, woops
        ),
    ),
    coatnet_rmlp_2_rw=MaxxVitCfg(
        embed_dim=(128, 256, 512, 1024),
        depths=(2, 6, 14, 2),
        stem_width=(64, 128),
        **_rw_coat_cfg(
            stride_mode='dw',
            conv_attn_act_layer='silu',
            init_values=1e-6,
            rel_pos_type='mlp'
        ),
    ),
    coatnet_rmlp_3_rw=MaxxVitCfg(
        embed_dim=(192, 384, 768, 1536),
        depths=(2, 6, 14, 2),
        stem_width=(96, 192),
        **_rw_coat_cfg(
            stride_mode='dw',
            conv_attn_act_layer='silu',
            init_values=1e-6,
            rel_pos_type='mlp'
        ),
    ),

    coatnet_nano_cc=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(3, 4, 6, 3),
        stem_width=(32, 64),
        block_type=('C', 'C', ('C', 'T'), ('C', 'T')),
        **_rw_coat_cfg(),
    ),
    coatnext_nano_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(3, 4, 6, 3),
        stem_width=(32, 64),
        weight_init='normal',
        **_next_cfg(
            rel_pos_type='bias',
            init_values=(1e-5, None)
        ),
    ),

    # Trying to be like the CoAtNet paper configs
    coatnet_0=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 3, 5, 2),
        stem_width=64,
        head_hidden_size=768,
    ),
    coatnet_1=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 6, 14, 2),
        stem_width=64,
        head_hidden_size=768,
    ),
    coatnet_2=MaxxVitCfg(
        embed_dim=(128, 256, 512, 1024),
        depths=(2, 6, 14, 2),
        stem_width=128,
        head_hidden_size=1024,
    ),
    coatnet_3=MaxxVitCfg(
        embed_dim=(192, 384, 768, 1536),
        depths=(2, 6, 14, 2),
        stem_width=192,
        head_hidden_size=1536,
    ),
    coatnet_4=MaxxVitCfg(
        embed_dim=(192, 384, 768, 1536),
        depths=(2, 12, 28, 2),
        stem_width=192,
        head_hidden_size=1536,
    ),
    coatnet_5=MaxxVitCfg(
        embed_dim=(256, 512, 1280, 2048),
        depths=(2, 12, 28, 2),
        stem_width=192,
        head_hidden_size=2048,
    ),

    # Experimental MaxVit configs
    maxvit_pico_rw=MaxxVitCfg(
        embed_dim=(32, 64, 128, 256),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(24, 32),
        **_rw_max_cfg(),
    ),
    maxvit_nano_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(1, 2, 3, 1),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        **_rw_max_cfg(),
    ),
    maxvit_tiny_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        **_rw_max_cfg(),
    ),
    maxvit_tiny_pm=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(2, 2, 5, 2),
        block_type=('PM',) * 4,
        stem_width=(32, 64),
        **_rw_max_cfg(),
    ),

    maxvit_rmlp_pico_rw=MaxxVitCfg(
        embed_dim=(32, 64, 128, 256),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(24, 32),
        **_rw_max_cfg(rel_pos_type='mlp'),
    ),
    maxvit_rmlp_nano_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(1, 2, 3, 1),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        **_rw_max_cfg(rel_pos_type='mlp'),
    ),
    maxvit_rmlp_tiny_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        **_rw_max_cfg(rel_pos_type='mlp'),
    ),
    maxvit_rmlp_small_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        **_rw_max_cfg(
            rel_pos_type='mlp',
            init_values=1e-6,
        ),
    ),
    maxvit_rmlp_base_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 6, 14, 2),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        head_hidden_size=768,
        **_rw_max_cfg(
            rel_pos_type='mlp',
        ),
    ),

    maxxvit_rmlp_nano_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(1, 2, 3, 1),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        weight_init='normal',
        **_next_cfg(),
    ),
    maxxvit_rmlp_tiny_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        **_next_cfg(),
    ),
    maxxvit_rmlp_small_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(48, 96),
        **_next_cfg(),
    ),

    maxxvitv2_nano_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(1, 2, 3, 1),
        block_type=('M',) * 4,
        stem_width=(48, 96),
        weight_init='normal',
        **_next_cfg(
            no_block_attn=True,
            rel_pos_type='bias',
        ),
    ),
    maxxvitv2_rmlp_base_rw=MaxxVitCfg(
        embed_dim=(128, 256, 512, 1024),
        depths=(2, 6, 12, 2),
        block_type=('M',) * 4,
        stem_width=(64, 128),
        **_next_cfg(
            no_block_attn=True,
        ),
    ),
    maxxvitv2_rmlp_large_rw=MaxxVitCfg(
        embed_dim=(160, 320, 640, 1280),
        depths=(2, 6, 16, 2),
        block_type=('M',) * 4,
        stem_width=(80, 160),
        head_hidden_size=1280,
        **_next_cfg(
            no_block_attn=True,
        ),
    ),

    # Trying to be like the MaxViT paper configs
    maxvit_tiny_tf=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=64,
        stem_bias=True,
        head_hidden_size=512,
        **_tf_cfg(),
    ),
    maxvit_small_tf=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=64,
        stem_bias=True,
        head_hidden_size=768,
        **_tf_cfg(),
    ),
    maxvit_base_tf=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 6, 14, 2),
        block_type=('M',) * 4,
        stem_width=64,
        stem_bias=True,
        head_hidden_size=768,
        **_tf_cfg(),
    ),
    maxvit_large_tf=MaxxVitCfg(
        embed_dim=(128, 256, 512, 1024),
        depths=(2, 6, 14, 2),
        block_type=('M',) * 4,
        stem_width=128,
        stem_bias=True,
        head_hidden_size=1024,
        **_tf_cfg(),
    ),
    maxvit_xlarge_tf=MaxxVitCfg(
        embed_dim=(192, 384, 768, 1536),
        depths=(2, 6, 14, 2),
        block_type=('M',) * 4,
        stem_width=192,
        stem_bias=True,
        head_hidden_size=1536,
        **_tf_cfg(),
    ),
)


def checkpoint_filter_fn(state_dict, model: nn.Module):
    model_state_dict = model.state_dict()
    out_dict = {}
    for k, v in state_dict.items():
        if k.endswith('relative_position_bias_table'):
            m = model.get_submodule(k[:-29])
            if v.shape != m.relative_position_bias_table.shape or m.window_size[0] != m.window_size[1]:
                v = resize_rel_pos_bias_table(
                    v,
                    new_window_size=m.window_size,
                    new_bias_shape=m.relative_position_bias_table.shape,
                )

        if k in model_state_dict and v.ndim != model_state_dict[k].ndim and v.numel() == model_state_dict[k].numel():
            # adapt between conv2d / linear layers
            assert v.ndim in (2, 4)
            v = v.reshape(model_state_dict[k].shape)
        out_dict[k] = v
    return out_dict


def _create_maxxvit(variant, cfg_variant=None, pretrained=False, **kwargs):
    if cfg_variant is None:
        if variant in model_cfgs:
            cfg_variant = variant
        else:
            cfg_variant = '_'.join(variant.split('_')[:-1])
    return build_model_with_cfg(
        MaxxVit, variant, pretrained,
        model_cfg=model_cfgs[cfg_variant],
        feature_cfg=dict(flatten_sequential=True),
        pretrained_filter_fn=checkpoint_filter_fn,
        **kwargs)


def _cfg(url='', **kwargs):
    return {
        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.95, 'interpolation': 'bicubic',
        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
        'first_conv': 'stem.conv1', 'classifier': 'head.fc',
        'fixed_input_size': True,
        **kwargs
    }


default_cfgs = generate_default_cfgs({
    # timm specific CoAtNet configs, ImageNet-1k pretrain, fixed rel-pos
    'coatnet_pico_rw_224.untrained': _cfg(url=''),
    'coatnet_nano_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_nano_rw_224_sw-f53093b4.pth',
        crop_pct=0.9),
    'coatnet_0_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_0_rw_224_sw-a6439706.pth'),
    'coatnet_1_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth'
    ),

    # timm specific CoAtNet configs, ImageNet-12k pretrain w/ 1k fine-tune, fixed rel-pos
    'coatnet_2_rw_224.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    #'coatnet_3_rw_224.untrained': _cfg(url=''),

    # Experimental CoAtNet configs w/ ImageNet-12k pretrain -> 1k fine-tune (different norm layers, MLP rel-pos)
    'coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),

    # Experimental CoAtNet configs w/ ImageNet-1k train (different norm layers, MLP rel-pos)
    'coatnet_bn_0_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_bn_0_rw_224_sw-c228e218.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
        crop_pct=0.95),
    'coatnet_rmlp_nano_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_nano_rw_224_sw-bd1d51b3.pth',
        crop_pct=0.9),
    'coatnet_rmlp_0_rw_224.untrained': _cfg(url=''),
    'coatnet_rmlp_1_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'),
    'coatnet_rmlp_2_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_2_rw_224_sw-5ccfac55.pth'),
    'coatnet_rmlp_3_rw_224.untrained': _cfg(url=''),
    'coatnet_nano_cc_224.untrained': _cfg(url=''),
    'coatnext_nano_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnext_nano_rw_224_ad-22cb71c2.pth',
        crop_pct=0.9),

    # ImagenNet-12k pretrain CoAtNet
    'coatnet_2_rw_224.sw_in12k': _cfg(
        hf_hub_id='timm/',
        num_classes=11821),
    'coatnet_3_rw_224.sw_in12k': _cfg(
        hf_hub_id='timm/',
        num_classes=11821),
    'coatnet_rmlp_1_rw2_224.sw_in12k': _cfg(
        hf_hub_id='timm/',
        num_classes=11821),
    'coatnet_rmlp_2_rw_224.sw_in12k': _cfg(
        hf_hub_id='timm/',
        num_classes=11821),

    # Trying to be like the CoAtNet paper configs (will adapt if 'tf' weights are ever released)
    'coatnet_0_224.untrained': _cfg(url=''),
    'coatnet_1_224.untrained': _cfg(url=''),
    'coatnet_2_224.untrained': _cfg(url=''),
    'coatnet_3_224.untrained': _cfg(url=''),
    'coatnet_4_224.untrained': _cfg(url=''),
    'coatnet_5_224.untrained': _cfg(url=''),

    # timm specific MaxVit configs, ImageNet-1k pretrain or untrained
    'maxvit_pico_rw_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxvit_nano_rw_256.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-fb127241.pth',
        input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxvit_tiny_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_tiny_rw_224_sw-7d0dffeb.pth'),
    'maxvit_tiny_rw_256.untrained': _cfg(
        url='',
        input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxvit_tiny_pm_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),

    # timm specific MaxVit w/ MLP rel-pos, ImageNet-1k pretrain
    'maxvit_rmlp_pico_rw_256.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_pico_rw_256_sw-8d82f2c6.pth',
        input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxvit_rmlp_nano_rw_256.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth',
        input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxvit_rmlp_tiny_rw_256.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_tiny_rw_256_sw-bbef0ff5.pth',
        input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxvit_rmlp_small_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_small_rw_224_sw-6ef0ae4f.pth',
        crop_pct=0.9,
    ),
    'maxvit_rmlp_small_rw_256.untrained': _cfg(
        url='',
        input_size=(3, 256, 256), pool_size=(8, 8)),

    # timm specific MaxVit w/ ImageNet-12k pretrain and 1k fine-tune
    'maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
    ),
    'maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),

    # timm specific MaxVit w/ ImageNet-12k pretrain
    'maxvit_rmlp_base_rw_224.sw_in12k': _cfg(
        hf_hub_id='timm/',
        num_classes=11821,
    ),

    # timm MaxxViT configs (ConvNeXt conv blocks mixed with MaxVit transformer blocks)
    'maxxvit_rmlp_nano_rw_256.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_nano_rw_256_sw-0325d459.pth',
        input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxxvit_rmlp_tiny_rw_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxxvit_rmlp_small_rw_256.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_small_rw_256_sw-37e217ff.pth',
        input_size=(3, 256, 256), pool_size=(8, 8)),

    # timm MaxxViT-V2 configs (ConvNeXt conv blocks mixed with MaxVit transformer blocks, more width, no block attn)
    'maxxvitv2_nano_rw_256.sw_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'maxxvitv2_rmlp_large_rw_224.untrained': _cfg(url=''),

    'maxxvitv2_rmlp_base_rw_224.sw_in12k': _cfg(
        hf_hub_id='timm/',
        num_classes=11821),

    # MaxViT models ported from official Tensorflow impl
    'maxvit_tiny_tf_224.in1k': _cfg(
        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'maxvit_tiny_tf_384.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'maxvit_tiny_tf_512.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
    'maxvit_small_tf_224.in1k': _cfg(
        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'maxvit_small_tf_384.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'maxvit_small_tf_512.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
    'maxvit_base_tf_224.in1k': _cfg(
        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'maxvit_base_tf_384.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'maxvit_base_tf_512.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
    'maxvit_large_tf_224.in1k': _cfg(
        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'maxvit_large_tf_384.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'maxvit_large_tf_512.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),

    'maxvit_base_tf_224.in21k': _cfg(
        hf_hub_id='timm/',
        num_classes=21843),
    'maxvit_base_tf_384.in21k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'maxvit_base_tf_512.in21k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
    'maxvit_large_tf_224.in21k': _cfg(
        hf_hub_id='timm/',
        num_classes=21843),
    'maxvit_large_tf_384.in21k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'maxvit_large_tf_512.in21k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
    'maxvit_xlarge_tf_224.in21k': _cfg(
        hf_hub_id='timm/',
        num_classes=21843),
    'maxvit_xlarge_tf_384.in21k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'maxvit_xlarge_tf_512.in21k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
})


@register_model
def coatnet_pico_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_pico_rw_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_nano_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_nano_rw_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_0_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_0_rw_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_1_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_1_rw_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_2_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_2_rw_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_3_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_3_rw_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_bn_0_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_bn_0_rw_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_rmlp_nano_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_rmlp_nano_rw_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_rmlp_0_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_rmlp_0_rw_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_rmlp_1_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_rmlp_1_rw_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_rmlp_1_rw2_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_rmlp_1_rw2_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_rmlp_2_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_rmlp_2_rw_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_rmlp_2_rw_384(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_rmlp_2_rw_384', pretrained=pretrained, **kwargs)


@register_model
def coatnet_rmlp_3_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_rmlp_3_rw_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_nano_cc_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_nano_cc_224', pretrained=pretrained, **kwargs)


@register_model
def coatnext_nano_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnext_nano_rw_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_0_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_0_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_1_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_1_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_2_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_2_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_3_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_3_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_4_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_4_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_5_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_5_224', pretrained=pretrained, **kwargs)


@register_model
def maxvit_pico_rw_256(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_pico_rw_256', pretrained=pretrained, **kwargs)


@register_model
def maxvit_nano_rw_256(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_nano_rw_256', pretrained=pretrained, **kwargs)


@register_model
def maxvit_tiny_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_tiny_rw_224', pretrained=pretrained, **kwargs)


@register_model
def maxvit_tiny_rw_256(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_tiny_rw_256', pretrained=pretrained, **kwargs)


@register_model
def maxvit_rmlp_pico_rw_256(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_rmlp_pico_rw_256', pretrained=pretrained, **kwargs)


@register_model
def maxvit_rmlp_nano_rw_256(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_rmlp_nano_rw_256', pretrained=pretrained, **kwargs)


@register_model
def maxvit_rmlp_tiny_rw_256(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_rmlp_tiny_rw_256', pretrained=pretrained, **kwargs)


@register_model
def maxvit_rmlp_small_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_rmlp_small_rw_224', pretrained=pretrained, **kwargs)


@register_model
def maxvit_rmlp_small_rw_256(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_rmlp_small_rw_256', pretrained=pretrained, **kwargs)


@register_model
def maxvit_rmlp_base_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_rmlp_base_rw_224', pretrained=pretrained, **kwargs)


@register_model
def maxvit_rmlp_base_rw_384(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_rmlp_base_rw_384', pretrained=pretrained, **kwargs)


@register_model
def maxvit_tiny_pm_256(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs)


@register_model
def maxxvit_rmlp_nano_rw_256(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxxvit_rmlp_nano_rw_256', pretrained=pretrained, **kwargs)


@register_model
def maxxvit_rmlp_tiny_rw_256(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxxvit_rmlp_tiny_rw_256', pretrained=pretrained, **kwargs)


@register_model
def maxxvit_rmlp_small_rw_256(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxxvit_rmlp_small_rw_256', pretrained=pretrained, **kwargs)


@register_model
def maxxvitv2_nano_rw_256(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxxvitv2_nano_rw_256', pretrained=pretrained, **kwargs)


@register_model
def maxxvitv2_rmlp_base_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxxvitv2_rmlp_base_rw_224', pretrained=pretrained, **kwargs)


@register_model
def maxxvitv2_rmlp_base_rw_384(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxxvitv2_rmlp_base_rw_384', pretrained=pretrained, **kwargs)


@register_model
def maxxvitv2_rmlp_large_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxxvitv2_rmlp_large_rw_224', pretrained=pretrained, **kwargs)


@register_model
def maxvit_tiny_tf_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_tiny_tf_224', 'maxvit_tiny_tf', pretrained=pretrained, **kwargs)


@register_model
def maxvit_tiny_tf_384(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_tiny_tf_384', 'maxvit_tiny_tf', pretrained=pretrained, **kwargs)


@register_model
def maxvit_tiny_tf_512(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_tiny_tf_512', 'maxvit_tiny_tf', pretrained=pretrained, **kwargs)


@register_model
def maxvit_small_tf_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_small_tf_224', 'maxvit_small_tf', pretrained=pretrained, **kwargs)


@register_model
def maxvit_small_tf_384(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_small_tf_384', 'maxvit_small_tf', pretrained=pretrained, **kwargs)


@register_model
def maxvit_small_tf_512(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_small_tf_512', 'maxvit_small_tf', pretrained=pretrained, **kwargs)


@register_model
def maxvit_base_tf_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_base_tf_224', 'maxvit_base_tf', pretrained=pretrained, **kwargs)


@register_model
def maxvit_base_tf_384(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_base_tf_384', 'maxvit_base_tf', pretrained=pretrained, **kwargs)


@register_model
def maxvit_base_tf_512(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_base_tf_512', 'maxvit_base_tf', pretrained=pretrained, **kwargs)


@register_model
def maxvit_large_tf_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_large_tf_224', 'maxvit_large_tf', pretrained=pretrained, **kwargs)


@register_model
def maxvit_large_tf_384(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_large_tf_384', 'maxvit_large_tf', pretrained=pretrained, **kwargs)


@register_model
def maxvit_large_tf_512(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_large_tf_512', 'maxvit_large_tf', pretrained=pretrained, **kwargs)


@register_model
def maxvit_xlarge_tf_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_xlarge_tf_224', 'maxvit_xlarge_tf', pretrained=pretrained, **kwargs)


@register_model
def maxvit_xlarge_tf_384(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_xlarge_tf_384', 'maxvit_xlarge_tf', pretrained=pretrained, **kwargs)


@register_model
def maxvit_xlarge_tf_512(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_xlarge_tf_512', 'maxvit_xlarge_tf', pretrained=pretrained, **kwargs)



In [43]:
import torch
import torch.nn as nn
import torchvision.models as models


class CoAtNetMultiscalePyramidal(nn.Module):
    def __init__(self, num_classes, fine_tune=False):
        super(CoAtNetMultiscalePyramidal, self).__init__()

        # Add a convolutional layer at the top
        self.conv = nn.Conv2d(1, 3, kernel_size=3, stride=1, padding=1)  # Assuming input is grayscale (1 channel)

        ## set the num_classes here according to the pretraining experiment
        # loading the weights for coatnet module from DTD(47classes) or Flowers102(102classes)
        #pretrained weights
        
        # DTD weights (kaggle dataset version-2)
        self.coatnet = coatnet_3_rw_224(pretrained=False, in_chans=3, num_classes=47)
        model_path = '/kaggle/input/18-04-2024-dtd-coatnetmultiscalev1-multiclass-prwt/best_model_precision.pth'
        
        #OR (comment one of these)
        
        # Flowers102 weights (kaggle dataset version-1)
#         self.coatnet = coatnet_3_rw_224(pretrained=False, in_chans=3, num_classes=102)
#         model_path = '/kaggle/input/18-04-2024-flowers102-coatnetmultiscale-pretwts/best_model_precision.pth'

        
        self.coatnet.load_state_dict(torch.load(model_path, map_location=device))
        
        
        if not fine_tune:
            # Freeze all layers except classifier layers
            for param in self.coatnet.parameters():
                param.requires_grad = False

            # Unfreeze the classifier layers
            for param in self.coatnet.head.parameters():
                param.requires_grad = True
            

        # Get the number of input features for the final fully connected layer
        in_features = self.coatnet.head.fc.in_features

        # Replace the final fully connected layer with a new one for the specified number of classes
        self.coatnet.head.fc = nn.Linear(in_features, num_classes)
        
        

    def forward(self, x):
        x = self.conv(x)
        x = self.coatnet(x)
        return x
    
    
    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CoAtNetMultiscalePyramidal(num_classes, fine_tune=False) # change fine_tune as required
model.to(device)

# print(model)
print()

x = torch.randn(1, 1, IMAGE_SIZE[0], IMAGE_SIZE[1]).to(device)


output = model(x)
print("Model output's shape:", output.shape)
print(output) # logits 
display_params_flops(model)


Model output's shape: torch.Size([1, 15])
tensor([[-0.0622,  0.5884,  0.0112,  0.0047, -0.3784, -0.0343, -0.3528, -0.6443,
         -0.0919, -0.6064,  0.2091, -0.2661,  0.1060,  0.1185, -0.1239]],
       device='cuda:0', grad_fn=<AddmmBackward0>)
Number of parameters in millions: 309.49 M
Number of trainable parameters in millions: 0.02 M
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_avgpool() for <class 'torch.nn.modules.pooling.AvgPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
FLOPs: 34.96G, Params: 309.36M


In [44]:
model.coatnet.head

ClassifierHead(
  (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))
  (drop): Dropout(p=0.0, inplace=False)
  (fc): Linear(in_features=1536, out_features=15, bias=True)
  (flatten): Identity()
)