# ResNet50 Image Classification with `tripy` - An Implementation and Demo

This notebook demonstrates how to implement a ResNet50 model using `tripy`. We'll explore the architecture, 
load pretrained weights, and run predictions on a sample dataset, showing how `tripy` can be used for inference with a ResNet50-based image classification model.

### Objectives:
1. Implement ResNet50 using `tripy` and load pretrained weights.
2. Run predictions on sample images and visualize results.


# Setup

Install the necessary libraries, then load the pretrained ResNet50 weights and the inference dataset.

In [None]:
!pip install datasets matplotlib pillow transformers torch

In [None]:
# Load necessary libraries
import torch
from transformers import ResNetForImageClassification, AutoImageProcessor
from datasets import load_dataset

# Load the pretrained ResNet50 model from Hugging Face
resnet_pretrained = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")
resnet_pretrained.eval()
resnet_pretrained = resnet_pretrained.to('cuda')

# Load the image processor
processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")

# Load a sample image from the dataset
dataset = load_dataset("comet-team/coco-500", split="train")
idx = 103
image = dataset[idx]['Image']
image

In [None]:
# Preprocess the image
inputs = processor(image, return_tensors="pt")['pixel_values'].to('cuda')

# Run the image through the model
with torch.no_grad():
    logits = resnet_pretrained(inputs).logits

# Get the predicted label
predicted_label = logits.argmax(-1).item()
print(f"Predicted Label: {resnet_pretrained.config.id2label[predicted_label]}")
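
The model returns raw logits, and `argmax` keeps only the single best class. To look at runner-up classes as well, the logits can be turned into probabilities with a softmax. The following is a minimal sketch in plain Python; `softmax`, `top_k`, and the example logits are illustrative and not part of the notebook:

```python
import math

def softmax(logits):
    """Convert raw scores to probabilities (numerically stable)."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k(logits, k=3):
    """Return the k (index, probability) pairs with the highest probability."""
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(i, probs[i]) for i in ranked[:k]]

# Hypothetical logits for a 5-class problem (not real model output).
example_logits = [1.0, 3.5, 0.2, 2.8, -1.0]
top3 = top_k(example_logits, k=3)
for index, prob in top3:
    print(f"class {index}: {prob:.3f}")
```

With the real model, the same idea should be expressible in PyTorch as `logits.softmax(-1).topk(3)`.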


## Tripy Implementation

In [4]:
import tripy as tp

We will define:

1. **ResNet Building Blocks**:
   - **`ResNetConvLayer`**: A configurable convolutional layer with batch normalization and optional ReLU activation.
   - **`ResNetBottleNeckLayer`**: A layer consisting of three `ResNetConvLayer`s with a shortcut connection for residual learning.

2. **`ResNetStage`**:
   - Each stage contains multiple bottleneck layers, increasing channels and reducing spatial dimensions.

3. **`ResNetEmbeddings`**:
   - Applies an initial convolution and max-pooling operation to reduce dimensions and prepare input for the encoder.

4. **`ResNetEncoder`**:
   - Comprises multiple stages of bottleneck layers to extract abstract features from the input.

5. **`ResNetModel` Backbone**:
   - Combines `ResNetEmbeddings` and `ResNetEncoder` to create the backbone for feature extraction.

6. **`ResNetClassifier`**:
   - Combines the backbone (`ResNetModel`) with a classifier head that outputs one logit per class.
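
As a quick sanity check on the name, the "50" in ResNet50 can be recovered by counting weighted layers. This sketch assumes the stage sizes used later in this notebook (3, 4, 6, and 3 bottleneck blocks) and follows the usual convention of not counting the shortcut projections; the variable names are illustrative:

```python
# Bottleneck blocks per encoder stage, as configured later in this notebook.
blocks_per_stage = [3, 4, 6, 3]
convs_per_block = 3    # 1x1 -> 3x3 -> 1x1 inside each bottleneck

stem_convs = 1         # the 7x7 convolution in ResNetEmbeddings
classifier_layers = 1  # the final fully connected layer

weighted_layers = stem_convs + convs_per_block * sum(blocks_per_stage) + classifier_layers
print(weighted_layers)  # 50
```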


### Our final image classification model is visualized below:

```text
Input Image (224 x 224 x 3)
        |
        v
+------------------------------------------------------------------------------+
|                               ResNetClassifier                               |
|                                                                              |
|   +----------------------------------------------------------+               |
|   |                    ResNetModel (Backbone)                |               |
|   |                                                          |               |
|   | +-----------------------+  +-------------------------+   |               |
|   | | ResNetEmbeddings      |  | ResNetEncoder           |   |               |
|   | | Conv7x7, 64, Stride 2 |->| (4 stages with          |   |               |
|   | | MaxPool 3x3, Stride 2 |  | bottlenecks):           |   |               |
|   | | Output: (56 x 56 x 64)|  | - Stage 1: (3 layers)   |   |               |
|   | +-----------------------+  | - Stage 2: (4 layers)   |   |               |
|   |                            | - Stage 3: (6 layers)   |   |               |
|   |                            | - Stage 4: (3 layers)   |   |               |
|   |                            | Output: (7 x 7 x 2048)  |   |               |
|   |                            +-------------------------+   |               |
|   |                                            |             |               |
|   |                                            v             |               |
|   |  +----------------------------------------------------+  |               |
|   |  | Average Pooling (kernel 7x7, stride 7x7)           |  |               |
|   |  | Output: (1 x 1 x 2048)                             |  |               |
|   |  +----------------------------------------------------+  |               |
|   +----------------------------------------------------------+               |
|                               |                                              |
|                               v                                              |
|   +-------------------------------------------------------+                  |
|   |                   Classifier                          |                  |
|   |                 Fully Connected Layer                 |                  | -> Output logits (1000 classes)
|   |               Output: (1000 classes)                  |                  |
|   +-------------------------------------------------------+                  |
+------------------------------------------------------------------------------+

```
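
The spatial sizes in this architecture can be checked with the standard output-size formula for convolution and pooling windows, `(size + 2 * padding - kernel) // stride + 1`. The sketch below traces a 224 x 224 input using the kernel, stride, and padding values from the code in this notebook; `conv_out` and `trace` are illustrative names:

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a convolution/pooling window along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1

size = 224
trace = [("input", size)]

size = conv_out(size, kernel=7, stride=2, padding=3)   # stem convolution
trace.append(("conv 7x7", size))
size = conv_out(size, kernel=3, stride=2, padding=1)   # max pooling
trace.append(("maxpool 3x3", size))

# Only the first block of each stage applies the stage stride (in its 3x3 conv).
for stage, stride in enumerate([1, 2, 2, 2], start=1):
    size = conv_out(size, kernel=3, stride=stride, padding=1)
    trace.append((f"stage {stage}", size))

size = conv_out(size, kernel=7, stride=7, padding=0)   # average pooling
trace.append(("avgpool 7x7", size))

for name, value in trace:
    print(f"{name:12s} {value} x {value}")
```

The encoder therefore hands a 7 x 7 feature map to the average pooling step, which reduces it to 1 x 1 before the classifier.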

In [5]:
class ResNetConvLayer(tp.Module):
    def __init__(self, in_channels, out_channels, kernel_dims, stride=(1, 1), padding=((0, 0), (0, 0)), activation=True):
        super(ResNetConvLayer, self).__init__()
        self.convolution = tp.Conv(
            in_channels, out_channels, kernel_dims=kernel_dims,
            stride=stride, padding=padding, bias=False
        )
        self.normalization = tp.BatchNorm(out_channels)
        self.activation = activation

    def __call__(self, x):
        x = self.convolution(x)
        x = self.normalization(x)
        if self.activation:
            x = tp.relu(x)
        return x
    
class ResNetShortCut(tp.Module):
    def __init__(self, in_channels, out_channels, stride):
        super(ResNetShortCut, self).__init__()
        self.convolution = tp.Conv(
            in_channels, out_channels, kernel_dims=(1, 1),
            stride=stride, bias=False
        )
        self.normalization = tp.BatchNorm(out_channels)

    def __call__(self, x):
        x = self.convolution(x)
        x = self.normalization(x)
        return x

In [6]:
class ResNetBottleNeckLayer(tp.Module):
    def __init__(self, in_channels, out_channels, bottleneck_channels, stride):
        super(ResNetBottleNeckLayer, self).__init__()

        self.shortcut = ResNetShortCut(in_channels, out_channels, stride) if in_channels != out_channels or stride != (1, 1) else lambda x: x
        self.layer = [
            ResNetConvLayer(in_channels, bottleneck_channels, kernel_dims=(1, 1), stride=(1, 1)),
            ResNetConvLayer(bottleneck_channels, bottleneck_channels, kernel_dims=(3, 3), stride=stride, padding=((1, 1), (1, 1))),
            ResNetConvLayer(bottleneck_channels, out_channels, kernel_dims=(1, 1), stride=(1, 1), activation=False),
        ]

        self.activation = tp.relu

    def __call__(self, x):
        identity = self.shortcut(x)
        for layer in self.layer:
            x = layer(x)
        x = x + identity
        x = self.activation(x)
        return x

class ResNetStage(tp.Module):
    def __init__(self, num_layers, in_channels, out_channels, bottleneck_channels, stride):
        super(ResNetStage, self).__init__()
        self.layers = []
        for i in range(num_layers):
            layer = ResNetBottleNeckLayer(
                in_channels if i == 0 else out_channels,
                out_channels, bottleneck_channels,
                stride if i == 0 else (1, 1)
            )
            self.layers.append(layer)

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
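
To see which blocks in a stage get a projection shortcut rather than an identity shortcut, the channel and stride bookkeeping from `ResNetStage` and `ResNetBottleNeckLayer` can be replayed without building any layers. This is a sketch of that logic only; `stage_plan` is a hypothetical helper, not part of the model:

```python
def stage_plan(num_layers, in_channels, out_channels, stride):
    """List (in_channels, out_channels, stride, needs_projection) for each block."""
    plan = []
    for i in range(num_layers):
        # Only the first block sees the stage's input channels and stride.
        block_in = in_channels if i == 0 else out_channels
        block_stride = stride if i == 0 else (1, 1)
        # Same condition the bottleneck layer uses to pick its shortcut.
        needs_projection = block_in != out_channels or block_stride != (1, 1)
        plan.append((block_in, out_channels, block_stride, needs_projection))
    return plan

# Second encoder stage of ResNet50: 4 blocks, 256 -> 512 channels, stride 2.
plan = stage_plan(4, 256, 512, (2, 2))
for block in plan:
    print(block)
```

Only the first block of the stage changes the tensor shape, so it is the only one that needs the 1x1 projection; the remaining blocks add the input back unchanged.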

In [7]:
class ResNetEmbeddings(tp.Module):
    def __init__(self):
        super(ResNetEmbeddings, self).__init__()
        self.embedder = ResNetConvLayer(3, 64, kernel_dims=(7, 7), stride=(2, 2), padding=((3, 3), (3, 3)))

    def __call__(self, x):
        x = self.embedder(x)
        x = tp.maxpool(x, kernel_dims=(3, 3), stride=(2, 2), padding=((1, 1), (1, 1)))
        return x


class ResNetEncoder(tp.Module):
    def __init__(self, layers_config):
        super(ResNetEncoder, self).__init__()
        self.stages = []
        in_channels = 64
        for num_layers, out_channels, bottleneck_channels, stride in layers_config:
            stage = ResNetStage(num_layers, in_channels, out_channels, bottleneck_channels, stride)
            self.stages.append(stage)
            in_channels = out_channels

    def __call__(self, x):
        for stage in self.stages:
            x = stage(x)
        return x

In [8]:
class ResNetModel(tp.Module):
    def __init__(self):
        super(ResNetModel, self).__init__()
        self.embedder = ResNetEmbeddings()
        layers_config = [
            (3, 256, 64, (1, 1)),
            (4, 512, 128, (2, 2)),
            (6, 1024, 256, (2, 2)),
            (3, 2048, 512, (2, 2)),
        ]
        self.encoder = ResNetEncoder(layers_config)

    def __call__(self, x):
        x = self.embedder(x)
        x = self.encoder(x)
        x = tp.avgpool(x, kernel_dims=(7, 7), stride=(7, 7)) # output size will be (1, 1)
        return x

In [9]:
class ResNetClassifier(tp.Module):
    def __init__(self, num_classes=1000):
        super(ResNetClassifier, self).__init__()
        self.resnet = ResNetModel()
        self.classifier = tp.Linear(2048, num_classes)

    def __call__(self, x):
        features = self.resnet(x)
        features = tp.flatten(features, start_dim=1)
        output = self.classifier(features)
        return output

In [10]:
# Instantiate the Tripy model
tripy_model = ResNetClassifier()

## Load weights to Tripy Model

First we ensure all state dict keys match and then we load the pretrained weights.

In [None]:
def load_pretrained_weights_tp(model, pretrained_state_dict):
    """
    This function loads weights from a state_dict into the custom TPResNet model,
    converting them into tp.Parameter before loading.
    """
    model_dict = model.state_dict()  # Get the model's state dict

    # Convert each weight to a tp.Parameter and update the model's state dict
    for k, v in pretrained_state_dict.items():
        if 'num_batches_tracked' in k:
            continue    # Skip num_batches_tracked: not needed for inference and not supported by tp.BatchNorm
        if 'classifier' in k:
            k = k.replace('classifier.1.', 'classifier.') # Drop the extra index (slight naming mismatch)

        assert k in model_dict, f"Unexpected key in pretrained weights: {k}"
        new_v = tp.Parameter(v.contiguous())
        model_dict[k] = new_v 

    model.load_state_dict(model_dict)
    return model

pretrained_state_dict = resnet_pretrained.state_dict()
tripy_model = load_pretrained_weights_tp(tripy_model, pretrained_state_dict)
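
The key matching step can also be checked in isolation, before any weights are converted. The sketch below works on plain lists of key names and reports keys that are missing or unexpected after renaming; `remap_keys` is a hypothetical helper, and the shortened key lists are illustrative rather than the full state dicts:

```python
def remap_keys(pretrained_keys, model_keys):
    """Map pretrained key -> model key and report anything left unmatched."""
    mapping = {}
    for key in pretrained_keys:
        if "num_batches_tracked" in key:
            continue  # bookkeeping counter, not a weight
        mapping[key] = key.replace("classifier.1.", "classifier.")
    mapped = set(mapping.values())
    missing = sorted(set(model_keys) - mapped)      # model keys with no source
    unexpected = sorted(mapped - set(model_keys))   # source keys with no target
    return mapping, missing, unexpected

# Hypothetical, shortened key lists for illustration only.
pretrained_keys = [
    "resnet.embedder.embedder.convolution.weight",
    "resnet.embedder.embedder.normalization.num_batches_tracked",
    "classifier.1.weight",
    "classifier.1.bias",
]
model_keys = [
    "resnet.embedder.embedder.convolution.weight",
    "classifier.weight",
    "classifier.bias",
]

mapping, missing, unexpected = remap_keys(pretrained_keys, model_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

Reporting both lists at once makes a naming mismatch easier to diagnose than stopping at the first failed assertion.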

Now we compile the model with a static input shape.

In [None]:
input_shape = [1, 3, 224, 224]
compiled_model = tp.compile(tripy_model, args=[tp.InputInfo(input_shape, dtype=tp.float32)])

In [None]:
# Run a dummy forward pass
x = tp.ones(input_shape, dtype=tp.float32)
compiled_output = compiled_model(x)
compiled_output.shape

## Tripy Demo

Now we will test our `tripy` ResNet50 classification model on a few sample images.

In [None]:
# Prepare the image for Tripy
processed_image = processor(image, return_tensors="np")['pixel_values']
tp_image = tp.Tensor(processed_image, dtype=tp.float32, device=tp.device("gpu"))
tp_image.shape
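
The processor output has shape `(1, 3, 224, 224)`, that is, batch, channels, height, width. The sketch below illustrates the layout change and per-channel normalization on a tiny image using plain Python lists. It assumes the commonly used ImageNet mean and standard deviation rather than reading them from the processor, and it omits any resizing or cropping the real processor may perform; `to_nchw` is a hypothetical helper:

```python
# Commonly used ImageNet statistics; assumed here, not read from the processor.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

def to_nchw(image_hwc):
    """Convert an H x W x 3 image with 0-255 values to a normalized 1 x 3 x H x W nested list."""
    height, width = len(image_hwc), len(image_hwc[0])
    channels = []
    for c in range(3):
        plane = [
            [(image_hwc[y][x][c] / 255.0 - MEAN[c]) / STD[c] for x in range(width)]
            for y in range(height)
        ]
        channels.append(plane)
    return [channels]  # leading batch dimension of size 1

# Tiny 2 x 2 RGB image standing in for a real 224 x 224 input.
tiny = [[(255, 0, 0), (0, 255, 0)], [(0, 0, 255), (255, 255, 255)]]
batch = to_nchw(tiny)
shape = (len(batch), len(batch[0]), len(batch[0][0]), len(batch[0][0][0]))
print(shape)  # (1, 3, 2, 2)
```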

In [None]:
# Run the image through the model
tp_logits = compiled_model(tp_image)
tp_logits.shape

In [None]:
# Get the predicted label
tp_predicted_label = torch.argmax(torch.from_dlpack(tp_logits), -1).tolist()[0]
print(f"Predicted Label: {resnet_pretrained.config.id2label[tp_predicted_label]}")
image

In [18]:
def get_tp_pred(model, image, processor):
    processed_image = processor(image, return_tensors="np")['pixel_values']
    tp_image = tp.Tensor(processed_image, dtype=tp.float32, device=tp.device("gpu"))
    tp_logits = model(tp_image)
    return torch.argmax(torch.from_dlpack(tp_logits), -1).tolist()[0]

def get_exact_answer(model, image, processor):
    inputs = processor(image, return_tensors="pt")['pixel_values'].to('cuda')

    with torch.no_grad():
        logits_pretrained = model(inputs).logits

    return logits_pretrained.argmax(-1).item()


Visualize a few more classification examples using the `tripy` model.

In [None]:

import matplotlib.pyplot as plt

idxs = [6, 44, 34, 25, 105]
fig, axes = plt.subplots(1, len(idxs), figsize=(15, 3)) 

# Loop through each index and corresponding subplot
for ax, nid in zip(axes, idxs):
    image = dataset[nid]['Image']
    ax.imshow(image)
    tp_predicted_label = get_tp_pred(compiled_model, image, processor)
    exact_answer = get_exact_answer(resnet_pretrained, image, processor)

    pred = resnet_pretrained.config.id2label[tp_predicted_label]
    exact = resnet_pretrained.config.id2label[exact_answer]

    assert pred == exact, f"{pred} vs {exact}"

    ax.set_title(pred, fontsize=10, fontweight='bold') 
    ax.axis('off') 
    ax.set_xticks([]) 
    ax.set_yticks([]) 


plt.tight_layout(pad=3.0)
plt.suptitle("Visualized Predictions", fontsize=16, fontweight='bold', y=1.05)
plt.show()
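
Checking that the labels match is a coarse test. A slightly finer comparison looks at the logits themselves: whether the top-1 class agrees and how far apart the two outputs are numerically. The following is a minimal sketch on plain lists; `compare_logits` and the example values are hypothetical and are not outputs of either model:

```python
def compare_logits(reference, candidate):
    """Return (same_top1, max_abs_diff) for two equally sized logit lists."""
    if len(reference) != len(candidate):
        raise ValueError("logit vectors must have the same length")
    top_ref = max(range(len(reference)), key=reference.__getitem__)
    top_cand = max(range(len(candidate)), key=candidate.__getitem__)
    max_abs_diff = max(abs(r - c) for r, c in zip(reference, candidate))
    return top_ref == top_cand, max_abs_diff

# Hypothetical logits, not real outputs of either model.
reference = [0.10, 2.50, -1.20, 0.70]
candidate = [0.11, 2.48, -1.19, 0.70]
same_top1, max_abs_diff = compare_logits(reference, candidate)
print(same_top1, round(max_abs_diff, 3))
```

Small numerical differences between the two implementations are expected, so a tolerance on the difference is usually more informative than exact equality.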
