# PyTorch implementation matching the state dict of microsoft/resnet-50

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResNetConvLayer(nn.Module):
    """Bias-free 2D convolution, then batch normalization, then an optional ReLU.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Convolution kernel size, forwarded to ``nn.Conv2d``.
        stride: Convolution stride (default 1).
        padding: Convolution padding (default 0).
        activation: When true the block ends with ``nn.ReLU``; otherwise with
            ``nn.Identity`` so the normalized output is returned untouched.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation=True):
        super().__init__()
        # The convolution carries no bias term of its own (bias=False).
        conv_options = dict(stride=stride, padding=padding, bias=False)
        self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size, **conv_options)
        self.normalization = nn.BatchNorm2d(out_channels)
        if activation:
            self.activation = nn.ReLU()
        else:
            self.activation = nn.Identity()

    def forward(self, x):
        """Run convolution -> normalization -> activation on ``x``."""
        return self.activation(self.normalization(self.convolution(x)))

class ResNetShortCut(nn.Module):
    """Projection used on the residual path: a bias-free 1x1 convolution plus batch norm.

    Args:
        in_channels: Channels of the tensor entering the residual branch.
        out_channels: Channels the projection must produce.
        stride: Stride of the 1x1 convolution.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        projection_options = dict(kernel_size=1, stride=stride, bias=False)
        self.convolution = nn.Conv2d(in_channels, out_channels, **projection_options)
        self.normalization = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Project ``x`` with the 1x1 convolution and normalize the result."""
        return self.normalization(self.convolution(x))

class ResNetBottleNeckLayer(nn.Module):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions added to a shortcut.

    Args:
        in_channels: Channels of the block input.
        out_channels: Channels of the block output.
        bottleneck_channels: Channels used by the first two convolutions.
        stride: Stride of the 3x3 convolution (and of the shortcut projection).
    """

    def __init__(self, in_channels, out_channels, bottleneck_channels, stride):
        super().__init__()
        # A projection is only required when the residual would otherwise not
        # line up with the main branch (channel count or spatial size changes).
        needs_projection = in_channels != out_channels or stride != 1
        if needs_projection:
            self.shortcut = ResNetShortCut(in_channels, out_channels, stride)
        else:
            self.shortcut = nn.Identity()

        reduce = ResNetConvLayer(in_channels, bottleneck_channels, kernel_size=1, stride=1)
        spatial = ResNetConvLayer(
            bottleneck_channels, bottleneck_channels, kernel_size=3, stride=stride, padding=1
        )
        # The last convolution has no activation; ReLU is applied after the sum.
        expand = ResNetConvLayer(
            bottleneck_channels, out_channels, kernel_size=1, stride=1, activation=False
        )
        self.layer = nn.Sequential(reduce, spatial, expand)
        self.activation = nn.ReLU()

    def forward(self, x):
        """Return ``ReLU(layer(x) + shortcut(x))``."""
        residual = self.shortcut(x)
        out = self.layer(x)
        out += residual
        return self.activation(out)

class ResNetStage(nn.Module):
    """A sequence of ``num_layers`` bottleneck blocks.

    Only the first block receives ``in_channels`` and ``stride``; every later
    block maps ``out_channels`` to ``out_channels`` with stride 1.
    """

    def __init__(self, num_layers, in_channels, out_channels, bottleneck_channels, stride):
        super().__init__()
        first_block = ResNetBottleNeckLayer(in_channels, out_channels, bottleneck_channels, stride)
        remaining_blocks = [
            ResNetBottleNeckLayer(out_channels, out_channels, bottleneck_channels, stride=1)
            for _ in range(1, num_layers)
        ]
        self.layers = nn.Sequential(first_block, *remaining_blocks)

    def forward(self, x):
        """Feed ``x`` through all blocks in order."""
        out = self.layers(x)
        return out

class ResNetEncoder(nn.Module):
    """Stack of ResNet stages built from ``layers_config``.

    Args:
        layers_config: Iterable of
            ``(num_layers, out_channels, bottleneck_channels, stride)`` tuples,
            one per stage. The first stage is fed 64 channels; each following
            stage is fed the previous stage's ``out_channels``.
    """

    def __init__(self, layers_config):
        super().__init__()
        self.stages = nn.ModuleList()
        previous_channels = 64
        for depth, width, bottleneck_width, stage_stride in layers_config:
            self.stages.append(
                ResNetStage(depth, previous_channels, width, bottleneck_width, stage_stride)
            )
            previous_channels = width

    def forward(self, x):
        """Apply every stage in order and return the final feature map."""
        hidden = x
        for stage in self.stages:
            hidden = stage(hidden)
        return hidden

class ResNetEmbeddings(nn.Module):
    """Network stem: a 7x7 stride-2 conv layer (3 -> 64 channels) followed by 3x3 stride-2 max pooling."""

    def __init__(self):
        super().__init__()
        stem_options = dict(kernel_size=7, stride=2, padding=3)
        self.embedder = ResNetConvLayer(3, 64, **stem_options)
        self.pooler = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        """Embed the input image batch and downsample it with max pooling."""
        return self.pooler(self.embedder(x))

class ResNetModel(nn.Module):
    """ResNet-50 style backbone: stem, four bottleneck stages, global average pooling.

    The submodule names (``embedder``, ``encoder``, ``pooler``) determine the
    keys of this module's state dict.
    """

    def __init__(self):
        super().__init__()
        self.embedder = ResNetEmbeddings()
        # One tuple per stage: (num_layers, out_channels, bottleneck_channels, stride)
        stage_specs = [
            (3, 256, 64, 1),
            (4, 512, 128, 2),
            (6, 1024, 256, 2),
            (3, 2048, 512, 2),
        ]
        self.encoder = ResNetEncoder(stage_specs)
        # Collapse the spatial dimensions to 1x1.
        self.pooler = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Return the pooled features produced from the image batch ``x``."""
        features = self.encoder(self.embedder(x))
        return self.pooler(features)


In [6]:
import torch
import torch.nn as nn
from transformers import AutoModel  # This is just for the Hugging Face example

# Your provided ResNet code with Bottleneck and ResNet classes is assumed to be defined above.

def load_pretrained_weights(model, state_dict):
    """
    Copy every compatible tensor from ``state_dict`` into ``model``.

    An entry is compatible when the model has a state-dict entry with the same
    key *and* the same shape. Entries unknown to the model are ignored; entries
    whose shape differs are skipped with a warning instead of aborting the whole
    load. A second warning lists how many of the model's own entries received no
    pre-trained value, so a naming mismatch does not go unnoticed.

    Args:
        model: The ``nn.Module`` to load weights into (modified in place).
        state_dict: Mapping of parameter/buffer names to tensors.

    Returns:
        The same ``model`` instance, for call chaining.
    """
    import warnings

    own_state = model.state_dict()
    compatible = {}
    shape_mismatches = []
    for name, tensor in state_dict.items():
        target = own_state.get(name)
        if target is None:
            # Checkpoint entry with no counterpart in the model.
            continue
        if target.shape != tensor.shape:
            shape_mismatches.append(name)
            continue
        compatible[name] = tensor

    if shape_mismatches:
        warnings.warn(
            f"Skipped {len(shape_mismatches)} pre-trained tensor(s) with mismatched "
            f"shapes: {shape_mismatches[:5]}"
        )
    not_loaded = [name for name in own_state if name not in compatible]
    if not_loaded:
        warnings.warn(
            f"{len(not_loaded)} of {len(own_state)} model tensor(s) were not found in "
            f"the pre-trained weights and keep their current values: {not_loaded[:5]}"
        )

    # strict=False lets load_state_dict accept the partial mapping; entries that
    # are absent simply keep the values the model already holds.
    model.load_state_dict(compatible, strict=False)

    return model

# Build the custom ResNet-50 defined above (randomly initialised) and move it to the GPU
model = ResNetModel().to('cuda')


# Fetch the reference Hugging Face ResNet-50 checkpoint and take its state dict
# as the source of pre-trained weights
pretrained_model = AutoModel.from_pretrained("microsoft/resnet-50")
pretrained_weights = pretrained_model.state_dict()


In [7]:
# Copy the pre-trained tensors whose keys match into the custom model
model = load_pretrained_weights(model, pretrained_weights)

# Print the module tree to inspect the architecture
# (this shows structure only; it does not confirm which weights were loaded)
print(model)


ResNetModel(
  (embedder): ResNetEmbeddings(
    (embedder): ResNetConvLayer(
      (convolution): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (normalization): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): ReLU()
    )
    (pooler): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (encoder): ResNetEncoder(
    (stages): ModuleList(
      (0): ResNetStage(
        (layers): Sequential(
          (0): ResNetBottleNeckLayer(
            (shortcut): ResNetShortCut(
              (convolution): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (normalization): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
            (layer): Sequential(
              (0): ResNetConvLayer(
                (convolution): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (normalizatio

In [9]:
# Example forward pass with dummy input
input_tensor = torch.randn(1, 3, 224, 224).to('cuda')  # A batch of images with shape (1, 3, 224, 224)
# Switch to inference behaviour first: in training mode BatchNorm normalizes
# with the statistics of this single random batch and overwrites the running
# statistics that were just loaded from the checkpoint.
model.eval()
# No gradients are needed for a shape check, so skip autograd bookkeeping.
with torch.no_grad():
    output = model(input_tensor)
print("Output shape:", output.shape)

Output shape: torch.Size([1, 2048, 1, 1])


# Tripy Implementation

In [10]:
import tripy as tp

class TPBatchNorm(tp.Module):
    """Per-channel normalization of a 4D input followed by a scale and shift.

    Mean and variance are computed from the input itself on every call,
    reducing over dimensions 0, 2 and 3 (everything except dimension 1).

    NOTE(review): no running mean/variance is stored, so this matches
    training-mode batch norm rather than the inference behaviour of the PyTorch
    model above, and the attribute names (``gamma``/``beta``) differ from the
    ``weight``/``bias`` keys of ``nn.BatchNorm2d`` - confirm before loading a
    checkpoint into it.
    NOTE(review): ``tp.var`` is called with its default settings; check whether
    that is the biased or unbiased estimator when comparing against PyTorch.

    Args:
        num_features: Size of dimension 1 (channels) of the expected input.
        eps: Constant added to the variance before the square root.
    """

    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps

        # Per-channel scale (gamma) and shift (beta), initialised to the identity transform
        self.gamma = tp.ones((num_features,), dtype=tp.float32)
        self.beta = tp.zeros((num_features,), dtype=tp.float32)

    def __call__(self, x):
        # Statistics over every dimension except the channel dimension;
        # keepdim=True leaves them shaped to broadcast against x.
        reduce_dims = (0, 2, 3)
        mean = tp.mean(x, dim=reduce_dims, keepdim=True)
        variance = tp.var(x, dim=reduce_dims, keepdim=True)

        # Normalize the input
        x_normalized = (x - mean) / tp.sqrt(variance + self.eps)

        # gamma/beta are stored as 1D (num_features,) tensors. Used directly
        # they would line up with the LAST dimension of x; reshape them so they
        # line up with dimension 1 (channels) instead.
        channel_shape = (1, self.num_features, 1, 1)
        gamma = tp.reshape(self.gamma, channel_shape)
        beta = tp.reshape(self.beta, channel_shape)

        # Apply the learned scaling (gamma) and shifting (beta)
        return gamma * x_normalized + beta


class TPIdentity(tp.Module):
    """Pass-through module: calling it returns its argument unchanged."""

    def __init__(self):
        super().__init__()

    def __call__(self, x):
        """Return ``x`` as-is."""
        return x

class TPResNetConvLayer(tp.Module):
    """Bias-free ``tp.Conv``, then ``TPBatchNorm``, then an optional ReLU.

    Args:
        in_channels: Channels of the incoming tensor.
        out_channels: Channels produced by the convolution.
        kernel_dims: Kernel dimensions, forwarded to ``tp.Conv``.
        stride: Convolution stride, forwarded to ``tp.Conv``.
        padding: Convolution padding, forwarded to ``tp.Conv``.
        activation: When true ``tp.relu`` is applied last; otherwise a
            ``TPIdentity`` is used and the normalized output is returned.
    """

    def __init__(self, in_channels, out_channels, kernel_dims, stride=(1, 1), padding=((0, 0), (0, 0)), activation=True):
        super().__init__()
        # kernel_dims, stride and padding are handed to tp.Conv exactly as received.
        conv_options = dict(kernel_dims=kernel_dims, stride=stride, padding=padding, bias=False)
        self.convolution = tp.Conv(in_channels, out_channels, **conv_options)
        self.normalization = TPBatchNorm(out_channels)
        if activation:
            self.activation = tp.relu
        else:
            self.activation = TPIdentity()

    def __call__(self, x):
        """Run convolution -> normalization -> activation on ``x``."""
        return self.activation(self.normalization(self.convolution(x)))

class TPResNetShortCut(tp.Module):
    """Residual-path projection: a bias-free 1x1 ``tp.Conv`` followed by ``TPBatchNorm``.

    Args:
        in_channels: Channels entering the projection.
        out_channels: Channels the projection produces.
        stride: Stride forwarded to ``tp.Conv``.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        projection_options = dict(kernel_dims=(1, 1), stride=stride, bias=False)
        self.convolution = tp.Conv(in_channels, out_channels, **projection_options)
        self.normalization = TPBatchNorm(out_channels)

    def __call__(self, x):
        """Project ``x`` and normalize the result."""
        return self.normalization(self.convolution(x))

class TPResNetBottleNeckLayer(tp.Module):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 conv layers added to a shortcut.

    Args:
        in_channels: Channels of the block input.
        out_channels: Channels of the block output.
        bottleneck_channels: Channels used by the first two conv layers.
        stride: Stride of the 3x3 conv layer and of the shortcut projection;
            compared against the tuple ``(1, 1)`` to decide whether a
            projection is needed.
    """

    def __init__(self, in_channels, out_channels, bottleneck_channels, stride):
        super().__init__()
        # Project the residual only when channels or stride change; otherwise pass it through.
        needs_projection = in_channels != out_channels or stride != (1, 1)
        if needs_projection:
            self.shortcut = TPResNetShortCut(in_channels, out_channels, stride)
        else:
            self.shortcut = TPIdentity()

        self.conv1 = TPResNetConvLayer(
            in_channels, bottleneck_channels, kernel_dims=(1, 1), stride=(1, 1)
        )
        self.conv2 = TPResNetConvLayer(
            bottleneck_channels,
            bottleneck_channels,
            kernel_dims=(3, 3),
            stride=stride,
            padding=((1, 1), (1, 1)),
        )
        # No activation here: ReLU is applied after the residual sum.
        self.conv3 = TPResNetConvLayer(
            bottleneck_channels, out_channels, kernel_dims=(1, 1), stride=(1, 1), activation=False
        )
        self.activation = tp.relu

    def __call__(self, x):
        """Return ``relu(conv3(conv2(conv1(x))) + shortcut(x))``."""
        residual = self.shortcut(x)
        hidden = self.conv3(self.conv2(self.conv1(x)))
        return self.activation(hidden + residual)

class TPResNetStage(tp.Module):
    """``num_layers`` bottleneck blocks applied one after another.

    Each block is stored twice: as the attribute ``layer_<i>`` and in the
    plain list ``self.layers``, which ``__call__`` iterates in order. Only
    block 0 receives ``in_channels`` and ``stride``; later blocks map
    ``out_channels`` to ``out_channels`` with stride ``(1, 1)``.
    """

    def __init__(self, num_layers, in_channels, out_channels, bottleneck_channels, stride):
        super().__init__()
        self.layers = []
        for i in range(num_layers):
            is_first = i == 0
            block_in_channels = in_channels if is_first else out_channels
            block_stride = stride if is_first else (1, 1)
            block = TPResNetBottleNeckLayer(
                block_in_channels, out_channels, bottleneck_channels, block_stride
            )
            setattr(self, f'layer_{i}', block)
            self.layers.append(block)

    def __call__(self, x):
        """Feed ``x`` through every block in creation order."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden

class TPResNetEncoder(tp.Module):
    """Stack of ``TPResNetStage`` objects built from ``layers_config``.

    Args:
        layers_config: Sequence of
            ``(num_layers, out_channels, bottleneck_channels, stride)`` tuples.
            The first stage is fed 64 channels; each later stage is fed the
            previous stage's ``out_channels``. Every stage is stored both as
            the attribute ``stage_<idx>`` and in the list ``self.stages``.
    """

    def __init__(self, layers_config):
        super().__init__()
        self.stages = []
        previous_channels = 64
        for idx, spec in enumerate(layers_config):
            depth, width, bottleneck_width, stage_stride = spec
            stage = TPResNetStage(depth, previous_channels, width, bottleneck_width, stage_stride)
            setattr(self, f'stage_{idx}', stage)
            self.stages.append(stage)
            previous_channels = width

    def __call__(self, x):
        """Apply every stage in order and return the result."""
        hidden = x
        for stage in self.stages:
            hidden = stage(hidden)
        return hidden

class TPMaxPool2d(tp.Module):
    """Module wrapper that stores pooling settings and calls ``tp.maxpool`` with them."""

    def __init__(self, kernel_dims, stride=None, padding=((0, 0), (0, 0))):
        """
        Store the pooling configuration.

        :param kernel_dims: Size of the pooling window.
        :param stride: Step of the window; falls back to ``kernel_dims`` when None.
        :param padding: Padding forwarded to ``tp.maxpool``, given as a tuple of tuples.
        """
        super().__init__()
        self.kernel_dims = kernel_dims
        self.stride = kernel_dims if stride is None else stride
        self.padding = padding

    def __call__(self, x):
        """Max-pool ``x`` with the stored kernel, stride and padding."""
        pool_options = dict(kernel_dims=self.kernel_dims, stride=self.stride, padding=self.padding)
        return tp.maxpool(x, **pool_options)

class TPResNetEmbeddings(tp.Module):
    """Network stem: a 7x7 stride-2 conv layer (3 -> 64 channels) followed by 3x3 stride-2 max pooling."""

    def __init__(self):
        super().__init__()
        self.embedder = TPResNetConvLayer(
            3, 64, kernel_dims=(7, 7), stride=(2, 2), padding=((3, 3), (3, 3))
        )
        self.pooler = TPMaxPool2d(
            kernel_dims=(3, 3), stride=(2, 2), padding=((1, 1), (1, 1))
        )

    def __call__(self, x):
        """Embed ``x`` and downsample it with max pooling."""
        return self.pooler(self.embedder(x))

class TPAdaptiveAvgPool2d(tp.Module):
    """Average pooling whose kernel and stride are derived from a target output size.

    Args:
        output_size: ``(H_out, W_out)`` pair describing the desired spatial size.
    """

    def __init__(self, output_size):
        super().__init__()
        self.output_size = output_size

    def __call__(self, x):
        """Pool ``x`` (unpacked as four dimensions) towards ``self.output_size``."""
        _, _, in_h, in_w = x.shape
        out_h, out_w = self.output_size

        # stride = in // out; the kernel covers whatever the strided windows
        # leave over so the last window ends at the input edge.
        step_h = in_h // out_h
        step_w = in_w // out_w
        window_h = in_h - (out_h - 1) * step_h
        window_w = in_w - (out_w - 1) * step_w

        kernel = (int(window_h), int(window_w))
        step = (int(step_h), int(step_w))
        return tp.avgpool(x, kernel_dims=kernel, stride=step)

class TPResNetModel(tp.Module):
    """Tripy counterpart of the ResNet-50 style backbone: stem, four stages, average pooling."""

    def __init__(self):
        super().__init__()
        self.embedder = TPResNetEmbeddings()
        # One tuple per stage: (num_layers, out_channels, bottleneck_channels, stride)
        stage_specs = [
            (3, 256, 64, (1, 1)),
            (4, 512, 128, (2, 2)),
            (6, 1024, 256, (2, 2)),
            (3, 2048, 512, (2, 2)),
        ]
        self.encoder = TPResNetEncoder(stage_specs)
        self.pooler = TPAdaptiveAvgPool2d(output_size=(1, 1))

    def __call__(self, x):
        """Return the pooled features computed from ``x``."""
        features = self.encoder(self.embedder(x))
        return self.pooler(features)


In [12]:
# A `!` command runs in a throwaway subshell, so an `export` there never reaches
# the kernel process. Set the variable in this process's environment instead.
# NOTE(review): CUDA reads this variable when it initializes; if a GPU has
# already been used in this kernel, restart and run this cell first.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [13]:
# Example usage: wrap a random (1, 3, 224, 224) float32 torch tensor as a Tripy
# tensor on the GPU, then build the Tripy model.
# Note: `model` is rebound here and no longer refers to the PyTorch ResNetModel.
x= tp.Tensor(torch.randn((1, 3, 224, 224), dtype=torch.float32), device=tp.device("gpu"))
model = TPResNetModel()


In [14]:
x

tensor(
    [[[[-0.1311, 1.0774, 1.3947,  ..., 1.0397, 0.2198, -0.2758],
       [-1.0221, -1.4959, 0.7427,  ..., 0.3901, -0.9117, 1.0383],
       [-1.2783, -0.7854, -1.2019,  ..., -0.6904, 1.3004, 2.9297],
       ...,
       [1.0976, 0.0842, 0.2624,  ..., -0.3132, -2.8922, -0.8280],
       [0.1549, -1.5338, -0.1673,  ..., -0.5271, 0.7771, -0.6724],
       [0.4640, -0.4680, -1.5712,  ..., 0.9485, -0.4994, 1.7285]],

      [[-1.1836, -0.9162, -0.8744,  ..., 0.2399, -0.0623, -1.0218],
       [0.4253, 2.1474, -0.2283,  ..., 0.4382, 0.5894, 0.5940],
       [0.9455, 0.1197, 1.0935,  ..., -0.1383, -0.3859, -0.0111],
       ...,
       [0.3884, -0.3844, -0.5540,  ..., -0.2329, 0.3538, -1.1639],
       [-0.7965, 0.6928, 1.2354,  ..., -0.1626, -0.6443, -1.0238],
       [-0.6482, 1.6554, 0.6752,  ..., -0.2146, 2.0221, -0.4599]],

      [[2.1775, 0.5714, 1.0615,  ..., -1.0824, 0.6238, -0.7028],
       [-0.3998, 0.9117, -0.0165,  ..., -1.0873, 0.8513, 0.2896],
       [0.5579, -0.6037, -0.0752,  ...

In [17]:
eager_output = model(x) 

KeyboardInterrupt: 

In [None]:
eager_output.shape

In [15]:
input_shape = [1,3,224,224]

In [16]:
import time

# Start from a fresh, uncompiled Tripy model
model = TPResNetModel()

# Time compilation for the input shape held in `input_shape`; `model` is
# rebound to whatever tp.compile returns.
compile_start_time = time.perf_counter()
model = tp.compile(model, args=[tp.InputInfo(input_shape, dtype=tp.float32)])
compile_end_time = time.perf_counter()
print(f"Compilation took {compile_end_time - compile_start_time} seconds.")

KeyboardInterrupt: 

In [41]:
output = model(x)

AttributeError: 'function' object has no attribute 'eval'

# TODO
- Explain the intuition behind the implementation and make the notebook more readable