# ResNet50 with `tripy` - An Implementation and Demo

This notebook aims to demonstrate the implementation of a ResNet50 model using `tripy`. We'll explore the architecture, 
load pretrained weights, and run predictions on a sample dataset, showcasing how `tripy` can be used effectively for deep learning tasks.

### Objectives:
1. Implement ResNet50 using `tripy` and load pretrained weights.
2. Run predictions on sample images and visualize results.
3. Provide an intuitive and educational overview of ResNet components and their implementation in `tripy`.


# Setup

Install the necessary libraries, then load the pretrained ResNet50 weights and the inference dataset.

In [None]:
!pip install datasets matplotlib pillow

In [None]:
# Load necessary libraries
import torch
from transformers import ResNetForImageClassification, AutoImageProcessor
from datasets import load_dataset

# The model and its image processor must come from the same checkpoint, so name it once.
MODEL_ID = "microsoft/resnet-50"

# Load the pretrained ResNet50 model from Hugging Face in eval mode and move it to the GPU.
# (`eval()` and `to()` both return the module itself, so they can be chained.)
resnet_pretrained = ResNetForImageClassification.from_pretrained(MODEL_ID).eval().to('cuda')

# Load the image processor that belongs to the same checkpoint.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)

# Load a sample image from the COCO 2017 validation split.
dataset = load_dataset("rafaelpadilla/coco2017", split="val")
idx = 103  # You can choose any index
image = dataset[idx]['image']
image


In [None]:
# Baseline: classify `image` with the reference PyTorch model on the GPU.
inputs = processor(image, return_tensors="pt")["pixel_values"].to("cuda")

# Inference only, so skip autograd bookkeeping.
with torch.no_grad():
    logits = resnet_pretrained(inputs).logits

# Highest-scoring class index, printed via the checkpoint's index -> name mapping.
predicted_label = logits.argmax(dim=-1).item()
print(f"Predicted Label: {resnet_pretrained.config.id2label[predicted_label]}")


## Tripy Implementation
Now we implement the model definition in `tripy`. Most of the modules that ResNet50 needs are available in `tripy`; the few that are missing we define ourselves.

In [25]:
import tripy as tp

In [34]:
class BatchNorm(tp.Module):
    """Inference-only batch normalization over axis 1 (channels) of a rank-4 (N, C, H, W) tensor.

    Uses the stored running statistics, not batch statistics:
    ``y = weight * (x - running_mean) / sqrt(running_var + eps) + bias``.

    Parameters are stored with shape ``(num_features,)`` so that per-channel vectors from a
    pretrained state dict can be loaded as-is. They are reshaped to ``(1, num_features, 1, 1)``
    in ``__call__`` so they broadcast over the channel axis.
    ``num_batches_tracked`` is not used in the computation. It is kept so that this module
    exposes the same state-dict keys as the pretrained normalization layers; the key check
    before weight loading requires every pretrained key to exist here.
    """

    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps

        # Affine parameters (scale and shift).
        self.weight = tp.Parameter(tp.ones((num_features,), dtype=tp.float32))
        self.bias = tp.Parameter(tp.zeros((num_features,), dtype=tp.float32))

        # Running statistics; never updated here because this module only runs inference.
        self.running_mean = tp.Parameter(tp.zeros((num_features,), dtype=tp.float32))
        self.running_var = tp.Parameter(tp.ones((num_features,), dtype=tp.float32))
        self.num_batches_tracked = tp.Parameter(tp.Tensor(0, dtype=tp.int64))

    def __call__(self, x):
        # Per-channel vectors must line up with axis 1, not the trailing (width) axis.
        channel_shape = (1, self.num_features, 1, 1)
        mean = tp.reshape(self.running_mean, channel_shape)
        var = tp.reshape(self.running_var, channel_shape)
        scale = tp.reshape(self.weight, channel_shape)
        shift = tp.reshape(self.bias, channel_shape)

        # Normalize with the running statistics, then apply the affine transform.
        x = (x - mean) / tp.sqrt(var + self.eps)
        return scale * x + shift


class ResNetConvLayer(tp.Module):
    """Bias-free convolution, then ``BatchNorm``, then an optional ReLU.

    Args:
        in_channels: input channel count of the convolution.
        out_channels: output channel count of the convolution (and of the normalization).
        kernel_dims: convolution kernel size, forwarded to ``tp.Conv``.
        stride: convolution stride, forwarded to ``tp.Conv``.
        padding: convolution padding per spatial dim, forwarded to ``tp.Conv``.
        activation: when True, apply ``tp.relu`` after normalization.
    """

    def __init__(self, in_channels, out_channels, kernel_dims, stride=(1, 1), padding=((0, 0), (0, 0)), activation=True):
        super().__init__()
        # No conv bias: the following normalization supplies the shift.
        self.convolution = tp.Conv(
            in_channels, out_channels, kernel_dims=kernel_dims,
            stride=stride, padding=padding, bias=False
        )
        self.normalization = BatchNorm(out_channels)
        self.activation = activation

    def __call__(self, x):
        x = self.convolution(x)
        x = self.normalization(x)
        if self.activation:
            x = tp.relu(x)
        return x

In [35]:
class ResNetShortCut(tp.Module):
    """Projection shortcut for a residual connection.

    A bias-free 1x1 convolution (with the given stride) followed by batch normalization.
    Bottleneck layers use it when the residual path must change channel count or stride.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        self.convolution = tp.Conv(in_channels, out_channels, kernel_dims=(1, 1), stride=stride, bias=False)
        self.normalization = BatchNorm(out_channels)

    def __call__(self, x):
        projected = self.convolution(x)
        return self.normalization(projected)

class ResNetBottleNeckLayer(tp.Module):
    """ResNet bottleneck block: 1x1 reduce -> 3x3 (carries the stride) -> 1x1 expand, plus a residual add.

    The residual path is a ``ResNetShortCut`` projection whenever the channel count or the
    stride changes; otherwise it is the identity. ReLU is applied after the addition.
    """

    def __init__(self, in_channels, out_channels, bottleneck_channels, stride):
        super().__init__()

        needs_projection = in_channels != out_channels or stride != (1, 1)
        if needs_projection:
            self.shortcut = ResNetShortCut(in_channels, out_channels, stride)
        else:
            self.shortcut = lambda x: x

        reduce_conv = ResNetConvLayer(in_channels, bottleneck_channels, kernel_dims=(1, 1), stride=(1, 1))
        spatial_conv = ResNetConvLayer(
            bottleneck_channels, bottleneck_channels,
            kernel_dims=(3, 3), stride=stride, padding=((1, 1), (1, 1)),
        )
        # No ReLU here: it is applied after the residual addition instead.
        expand_conv = ResNetConvLayer(bottleneck_channels, out_channels, kernel_dims=(1, 1), stride=(1, 1), activation=False)
        self.layer = [reduce_conv, spatial_conv, expand_conv]

        self.activation = tp.relu

    def __call__(self, x):
        residual = self.shortcut(x)
        out = x
        for conv_layer in self.layer:
            out = conv_layer(out)
        return self.activation(out + residual)


Define the ResNet stage and encoder for the backbone model.

In [36]:
class ResNetStage(tp.Module):
    """A sequence of ``num_layers`` bottleneck layers.

    The first layer maps ``in_channels`` to ``out_channels`` and applies ``stride``.
    Every later layer keeps ``out_channels`` and uses stride ``(1, 1)``.
    """

    def __init__(self, num_layers, in_channels, out_channels, bottleneck_channels, stride):
        super().__init__()
        self.layers = []
        layer_in, layer_stride = in_channels, stride
        for _ in range(num_layers):
            self.layers.append(ResNetBottleNeckLayer(layer_in, out_channels, bottleneck_channels, layer_stride))
            # After the first layer, width and resolution stay fixed within the stage.
            layer_in, layer_stride = out_channels, (1, 1)

    def __call__(self, x):
        out = x
        for bottleneck in self.layers:
            out = bottleneck(out)
        return out


class ResNetEncoder(tp.Module):
    """Chain of ``ResNetStage`` modules applied in order.

    Args:
        layers_config: iterable of ``(num_layers, out_channels, bottleneck_channels, stride)``
            tuples, one per stage.
        in_channels: channel count of the tensor fed to the first stage. Defaults to 64,
            the output width of the stem convolution in ``ResNetEmbeddings``.
    """

    def __init__(self, layers_config, in_channels=64):
        super().__init__()
        self.stages = []
        for num_layers, out_channels, bottleneck_channels, stride in layers_config:
            self.stages.append(ResNetStage(num_layers, in_channels, out_channels, bottleneck_channels, stride))
            # Each stage's output width is the next stage's input width.
            in_channels = out_channels

    def __call__(self, x):
        for stage in self.stages:
            x = stage(x)
        return x


Define an embeddings class that applies the stem convolution and max pooling to pre-process the input before it is fed to the encoder.

In [45]:
class ResNetEmbeddings(tp.Module):
    """Network stem.

    A 7x7 stride-2 ``ResNetConvLayer`` (conv + BatchNorm + ReLU) maps 3 input channels to 64.
    It is followed by 3x3 stride-2 max pooling with one pixel of padding on each side.
    """

    def __init__(self):
        super().__init__()
        self.embedder = ResNetConvLayer(
            3, 64,
            kernel_dims=(7, 7),
            stride=(2, 2),
            padding=((3, 3), (3, 3)),
        )

    def __call__(self, x):
        stem_features = self.embedder(x)
        return tp.maxpool(stem_features, kernel_dims=(3, 3), stride=(2, 2), padding=((1, 1), (1, 1)))

Define an adaptive average pooling helper function and the ResNet backbone model itself.

In [46]:
def adaptive_avg_pool(x, output_size: tuple[int, int]):
    """Average-pool the two trailing (spatial) dims of a rank-4 tensor down to ``output_size``.

    Emulates adaptive average pooling with one fixed-window ``tp.avgpool``.
    - Per axis, the stride is ``in // out``.
    - The kernel is sized so that ``out`` windows, spaced by that stride, end exactly at the input edge.

    When every input size is divisible by its output size (e.g. ``(1, 1)``, i.e. global average
    pooling), the windows coincide with PyTorch's ``AdaptiveAvgPool2d``. Otherwise they may differ.
    """
    _, _, in_h, in_w = x.shape
    out_h, out_w = output_size

    def _window(in_dim, out_dim):
        # (out_dim - 1) strides plus one kernel must span exactly in_dim elements.
        step = in_dim // out_dim
        span = in_dim - (out_dim - 1) * step
        return int(span), int(step)

    kernel_h, stride_h = _window(in_h, out_h)
    kernel_w, stride_w = _window(in_w, out_w)
    return tp.avgpool(x, kernel_dims=(kernel_h, kernel_w), stride=(stride_h, stride_w))

class ResNetModel(tp.Module):
    """ResNet50 backbone: stem, four bottleneck stages, then average pooling to a 1x1 spatial map."""

    def __init__(self):
        super().__init__()
        self.embedder = ResNetEmbeddings()
        # One (num_layers, out_channels, bottleneck_channels, stride) tuple per stage;
        # 3 + 4 + 6 + 3 bottleneck layers is the ResNet50 layout.
        stage_specs = [
            (3, 256, 64, (1, 1)),
            (4, 512, 128, (2, 2)),
            (6, 1024, 256, (2, 2)),
            (3, 2048, 512, (2, 2)),
        ]
        self.encoder = ResNetEncoder(stage_specs)

    def __call__(self, x):
        embedded = self.embedder(x)
        encoded = self.encoder(embedded)
        return adaptive_avg_pool(encoded, output_size=(1, 1))




Define the ResNet classifier head and a wrapper class, `ResNetClassifier`, that combines the backbone model and the classifier head.

In [47]:
class ResNetClassifierHead(tp.Module):
    """Flatten pooled features from dim 1 onward and project them to class logits with one linear layer."""

    def __init__(self, in_features, num_classes):
        super().__init__()
        # "1" is not a valid identifier, hence setattr/getattr. The name presumably mirrors an
        # index-1 layer in the pretrained checkpoint's classifier, so the state-dict keys become
        # `classifier.1.*`. Verify against the key check before weight loading.
        setattr(self, "1", tp.Linear(in_features, num_classes))

    def __call__(self, x):
        flat = tp.flatten(x, start_dim=1)
        linear = getattr(self, "1")
        return linear(flat)
    
class ResNetClassifier(tp.Module):
    """Image classifier: ``ResNetModel`` backbone followed by a linear head over its 2048 channels.

    Args:
        num_classes: number of output logits produced by the head.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        self.resnet = ResNetModel()
        self.classifier = ResNetClassifierHead(2048, num_classes)

    def __call__(self, x):
        return self.classifier(self.resnet(x))

In [48]:
# Build the Tripy classifier; the pretrained weights are loaded into it further below.
tripy_model = ResNetClassifier(num_classes=1000)

## Load weights to Tripy Model

First, we check that every key in the pretrained state dict has a counterpart in the Tripy model; then we load the pretrained weights.

In [49]:
# Every key in the pretrained PyTorch state dict must exist in the Tripy model;
# the loader below silently skips keys it cannot place.
missing_keys = set(resnet_pretrained.state_dict().keys()) - set(tripy_model.state_dict().keys())
assert not missing_keys, f"Tripy model has no counterpart for pretrained keys: {sorted(missing_keys)}"

In [50]:
def load_pretrained_weights_tp(model, state_dict):
    """Copy matching weights from a PyTorch-style ``state_dict`` into a Tripy model.

    Only keys that already exist in ``model.state_dict()`` are copied. Keys the Tripy model
    does not define are ignored. Each copied value is made contiguous and wrapped in a
    ``tp.Parameter`` before ``model.load_state_dict`` is called.

    Args:
        model: the Tripy module to populate (e.g. the ``ResNetClassifier`` defined above).
        state_dict: mapping from parameter name to tensor. Values must support ``.contiguous()``,
            as PyTorch tensors do.

    Returns:
        ``model``, after its state dict has been loaded.
    """
    model_dict = model.state_dict()

    for key, value in state_dict.items():
        if key not in model_dict:
            continue  # no counterpart in the Tripy model
        # NOTE(review): made contiguous presumably so Tripy can import the buffer directly — confirm.
        model_dict[key] = tp.Parameter(value.contiguous())

    model.load_state_dict(model_dict)

    return model

# Transfer the pretrained PyTorch weights into the Tripy model.
pretrained_weights = resnet_pretrained.state_dict()
tripy_model = load_pretrained_weights_tp(model=tripy_model, state_dict=pretrained_weights)


In [51]:
# All-ones dummy input in (batch, channels, height, width) layout, used to smoke-test compilation.
input_shape = [1, 3, 224, 224]
x = tp.ones(input_shape, dtype=tp.float32)

In [52]:
# Compile the model for a float32 input of shape `input_shape`.
compiled_model = tp.compile(
    tripy_model,
    args=[tp.InputInfo(input_shape, dtype=tp.float32)],
)

In [53]:
# Smoke test: run the compiled model once on the all-ones input.
compiled_output = compiled_model(x)

## Tripy Demo

Now we test our `tripy` ResNet50 classification model on a few sample images and compare its predictions with the PyTorch reference model.

In [None]:
# Same preprocessing as for PyTorch, but returned as NumPy and uploaded as a float32 Tripy tensor on the GPU.
processed_image = processor(image, return_tensors="np")["pixel_values"]
tp_image = tp.Tensor(processed_image, dtype=tp.float32, device=tp.device("gpu"))
tp_image.shape


In [None]:
# Classify the preprocessed image with the compiled Tripy model and show the logits' shape.
tp_logits = compiled_model(tp_image)
tp_logits.shape


In [None]:
# Hand the Tripy logits to PyTorch via DLPack, take the arg-max over the last dim, keep the first batch entry.
tp_predicted_label = torch.from_dlpack(tp_logits).argmax(dim=-1).tolist()[0]
print(f"Predicted Label: {resnet_pretrained.config.id2label[tp_predicted_label]}")
image

In [57]:
def get_tp_pred(model, image, processor):
    """Return the arg-max class index that a compiled Tripy ``model`` predicts for ``image``.

    Steps:
    - ``processor`` produces a NumPy batch, which is uploaded to the GPU as a float32 Tripy tensor.
    - The resulting logits are passed to PyTorch through DLPack for the arg-max.
    - Only the first batch entry's index is returned.
    """
    pixel_values = processor(image, return_tensors="np")["pixel_values"]
    batch = tp.Tensor(pixel_values, dtype=tp.float32, device=tp.device("gpu"))
    logits = torch.from_dlpack(model(batch))
    return logits.argmax(dim=-1).tolist()[0]

def get_exact_answer(model, image, processor):
    """Return the arg-max class index that a reference PyTorch ``model`` predicts for ``image``.

    Args:
        model: a PyTorch module whose output has a ``.logits`` attribute
            (e.g. the Hugging Face ``resnet_pretrained`` loaded above).
        image: an input accepted by ``processor``.
        processor: callable returning a mapping with a ``"pixel_values"`` tensor when
            called with ``return_tensors="pt"``.

    Returns:
        The index of the highest logit, as a Python number. ``.item()`` requires the
        processed batch to contain exactly one image.
    """
    # Run on the device the model actually lives on instead of assuming CUDA;
    # fall back to CUDA (the previous hard-coded choice) for parameter-less models.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else 'cuda'

    inputs = processor(image, return_tensors="pt")['pixel_values'].to(device)

    with torch.no_grad():
        logits = model(inputs).logits

    return logits.argmax(-1).item()


In [None]:
# Visualize Tripy predictions next to each image and check them against the PyTorch reference.
import matplotlib.pyplot as plt

idxs = [1, 44, 10, 28, 105]

# squeeze=False keeps `axes` 2-D even for a single index, so iterating over axes[0] always works.
fig, axes = plt.subplots(1, len(idxs), figsize=(15, 3), squeeze=False)

for ax, nid in zip(axes[0], idxs):
    # Use a dedicated name so the notebook-level `image` from earlier cells is not overwritten.
    sample_image = dataset[nid]['image']
    ax.imshow(sample_image)

    tp_label_idx = get_tp_pred(compiled_model, sample_image, processor)
    ref_label_idx = get_exact_answer(resnet_pretrained, sample_image, processor)

    pred = resnet_pretrained.config.id2label[tp_label_idx]
    exact = resnet_pretrained.config.id2label[ref_label_idx]

    # The Tripy port must agree with the PyTorch model on every sample.
    assert pred == exact, f"{pred} vs {exact}"

    ax.set_title(pred, fontsize=10, fontweight='bold')
    ax.axis('off')  # hides the axis frame, ticks and tick labels

fig.tight_layout(pad=3.0)  # Increase padding between subplots
fig.suptitle("Visualized Predictions", fontsize=16, fontweight='bold', y=1.05)  # Title for the entire figure
plt.show()
