# Using ResNet18 for Feature Extraction

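In this notebook, we build a feature extractor on top of pretrained torchvision ResNet-family backbones (ResNet18, ResNet50, Wide ResNet50-2 and ResNeXt50-32x4d). The worked example at the end uses the `wide_resnet50` backbone.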

# Importing packages

In [1]:
import torch
import torch.nn.functional as F 
import torch.nn as nn
from torchvision import models

  from .autonotebook import tqdm as notebook_tqdm


In [3]:

# Registry of supported backbones: maps the name accepted by
# FeatureExtraction(backbone_name=...) to the torchvision constructor
# that builds the corresponding model.
backbones = {
    "resnet18": models.resnet18,
    "wide_resnet50": models.wide_resnet50_2,
    "resnext50_32x4d": models.resnext50_32x4d,
    "resnet50": models.resnet50,
}


# Feature extraction class

In [32]:
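# backbone_name: key into `backbones` selecting which pretrained model to build
# layer_indices: positions in [layer1, layer2, layer3, layer4] whose outputs should be extracted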
class FeatureExtraction(nn.Module):
    """Pretrained ResNet-style backbone that returns concatenated intermediate feature maps.

    The backbone is used for inference only: it is put in evaluation mode at
    construction time and every forward pass runs under ``torch.no_grad()``.

    Args:
        backbone_name: Key into the module-level ``backbones`` registry.
        layer_indices: Indices into ``[layer1, layer2, layer3, layer4]`` selecting
            which stage outputs to keep, in the given order. ``None`` keeps all four.
        device: Device the backbone is moved to (skipped when ``None``). Batches
            passed to ``forward`` must already live on the same device.
    """

    def __init__(self, backbone_name, layer_indices, device):
        super().__init__()
        # "IMAGENET1K_V1" is the checkpoint the deprecated boolean form
        # `weights=True` resolved to; naming it keeps the same weights while
        # avoiding torchvision's deprecation warning.
        backbone = backbones[backbone_name](weights="IMAGENET1K_V1")
        # Modules start in training mode, where BatchNorm normalises with the
        # statistics of the current batch and updates its running averages
        # (even under no_grad). Feature extraction needs the stored statistics.
        backbone.eval()
        # Move the weights to the requested device so they match the inputs.
        if device is not None:
            backbone = backbone.to(device)
        self.backbone = backbone
        self.device = device
        self.layer_indices = layer_indices

    def forward(self, batch, layer_hook=None):
        """Run the backbone stem and residual stages and return the selected features.

        Args:
            batch: Input tensor fed to the backbone's first convolution
                (presumably ``(N, 3, H, W)`` images -- confirm against caller).
            layer_hook: Optional callable applied to each selected feature map
                before concatenation.

        Returns:
            The result of ``concatenate_layers`` on the selected (and optionally
            hooked) feature maps.
        """
        with torch.no_grad():
            # Stem: conv -> batch norm -> ReLU -> max pool.
            stem = self.backbone.conv1(batch)
            stem = self.backbone.bn1(stem)
            stem = self.backbone.relu(stem)
            stem = self.backbone.maxpool(stem)

            # Residual stages; each one consumes the previous stage's output.
            layer1 = self.backbone.layer1(stem)
            layer2 = self.backbone.layer2(layer1)
            layer3 = self.backbone.layer3(layer2)
            layer4 = self.backbone.layer4(layer3)
            layers = [layer1, layer2, layer3, layer4]

            if self.layer_indices is not None:
                layers = [layers[i] for i in self.layer_indices]

            if layer_hook:
                layers = [layer_hook(layer) for layer in layers]

            return concatenate_layers(layers)
                            

# Concatenating the features

In [31]:

def concatenate_layers(layers):
    """Concatenate feature maps along dim 1 at a common spatial size.

    The last two dimensions of the first entry define the target size. Any
    other entry whose last two dimensions differ is resized to it with
    nearest-neighbour interpolation; entries that already match are used as-is.

    Args:
        layers: Non-empty sequence of tensors (presumably ``(N, C, H, W)``
            feature maps -- confirm against caller).

    Returns:
        A single tensor produced by ``torch.cat(..., dim=1)``.
    """
    target_hw = layers[0].shape[-2:]
    aligned = []
    for feature_map in layers:
        if feature_map.shape[-2:] != target_hw:
            feature_map = F.interpolate(feature_map, size=target_hw, mode="nearest")
        aligned.append(feature_map)
    return torch.cat(aligned, dim=1)


# Initialize the FeatureExtraction

In [38]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# "wide_resnet50" selects models.wide_resnet50_2 from the `backbones` registry;
# indices 0 and 1 keep the outputs of layer1 and layer2.
extractor = FeatureExtraction(
    backbone_name="wide_resnet50",
    layer_indices=[0, 1],
    device=device,
)

Downloading: "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth" to C:\Users\abdulgader/.cache\torch\hub\checkpoints\wide_resnet50_2-95faca4d.pth
100%|██████████| 132M/132M [00:04<00:00, 31.9MB/s] 


# Transforms - Preprocessing

In [29]:

from torchvision import transforms

# Preprocessing pipeline applied to each PIL image before it reaches the backbone.
transform = transforms.Compose(
    [
        # Resize to a fixed 224x224 input.
        transforms.Resize(size=(224, 224)),
        # Convert the PIL image to a tensor.
        transforms.ToTensor(),
        # Per-channel normalisation; these look like the standard ImageNet
        # mean/std used with torchvision's pretrained models -- confirm.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


# Example

In [39]:
from PIL import Image
# Load the sample image as 3-channel RGB. The context manager releases the
# underlying file handle as soon as the converted copy has been made.
with Image.open("img1.jpg") as source_image:
    img = source_image.convert("RGB")

# Preprocess, add a leading batch dimension, and move the batch to `device`.
batch = transform(img).unsqueeze(0).to(device)
embeddings = extractor(batch)

print("embeddings",  embeddings.shape)

embeddings torch.Size([1, 768, 56, 56])


In [None]:

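# NOTE: apparently left over from an earlier version that returned one tensor per layer.
# `embeddings` is now a single concatenated tensor, so these lines (and the saved
# output below) no longer match what the extractor returns.
# print("Layer1",  embeddings[0].shape)
# print("Layer2",  embeddings[1].shape)
# print("Layer3",  embeddings[2].shape)
# print("Layer4",  embeddings[3].shape)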

Layer1 torch.Size([1, 64, 56, 56])
Layer2 torch.Size([1, 128, 28, 28])
Layer3 torch.Size([1, 256, 14, 14])
Layer4 torch.Size([1, 512, 7, 7])
