In [None]:
!pip freeze | grep torch

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import transforms
from torch.utils.data import Dataset
from torchvision import models
from torchvision.models.feature_extraction import (
    create_feature_extractor,
    get_graph_node_names,
)

from tqdm import tqdm

import numpy as np

import os
import PIL.Image

In [3]:
class MuriDataset(Dataset):
    def __init__(self, root: str, transforms=None):

        self.root = root
        self.transforms = transforms
        img_path_list = os.listdir(root)
        img_path_list.sort()

        self.img_path_list = img_path_list

    def __len__(self):
        return len(self.img_path_list)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root, self.img_path_list[idx])
        # print(f"{img_name = }")
        image = PIL.Image.open(img_name).convert("RGB")

        if self.transforms:
            image = self.transforms(image)

        return image

class BlurPoolConv2d(torch.nn.Module):

    # Purpose: This class creates a convolutional layer that first applies a blurring filter to the input before performing the convolution operation.
    # Condition: The function apply_blurpool iterates over all layers of the model and replaces convolution layers (ch.nn.Conv2d) with BlurPoolConv2d if they have a stride greater than 1 and at least 16 input channels.
    # Preventing Aliasing: Blurring the output of convolution layers (especially those with strides greater than 1) helps to reduce aliasing effects. Aliasing occurs when high-frequency signals are sampled too sparsely, leading to incorrect representations.
    # Smooth Transitions: Applying a blur before downsampling ensures that transitions between pixels are smooth, preserving important information in the feature maps.
    # Stabilizing Training: Blurring can help stabilize training by reducing high-frequency noise, making the model less sensitive to small changes in the input data.
    def __init__(self, conv):
        super().__init__()
        default_filter = torch.tensor([[[[1, 2, 1], [2, 4, 2], [1, 2, 1]]]]) / 16.0
        filt = default_filter.repeat(conv.in_channels, 1, 1, 1)
        self.conv = conv
        self.register_buffer("blur_filter", filt)

    def forward(self, x):
        blurred = F.conv2d(
            x,
            self.blur_filter,
            stride=1,
            padding=(1, 1),
            groups=self.conv.in_channels,
            bias=None,
        )
        return self.conv.forward(blurred)
        
def apply_blurpool(mod: torch.nn.Module):
    for name, child in mod.named_children():
        if isinstance(child, torch.nn.Conv2d) and (
            np.max(child.stride) > 1 and child.in_channels >= 16
        ):
            setattr(mod, name, BlurPoolConv2d(child))
        else:
            apply_blurpool(child)

def remove_prefix(state_dict, prefix):
    """
    Remove a prefix from the state_dict keys.

    Args:
    state_dict (dict): State dictionary from which the prefix will be removed.
    prefix (str): Prefix to be removed.

    Returns:
    dict: State dictionary with prefix removed from keys.
    """
    return {key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)}

def match_and_load_weights(checkpoint_state_dict, model, prefix='module.'):
    """
    Match weights from checkpoint_state_dict with model's state_dict and load them into the model.

    Args:
    checkpoint_state_dict (dict): State dictionary from checkpoint.
    model (torch.nn.Module): The model instance.
    prefix (str): Prefix to be removed from checkpoint keys.

    Returns:
    None
    """
    # Remove the prefix from checkpoint state dict keys
    cleaned_checkpoint_state_dict = remove_prefix(checkpoint_state_dict, prefix)
    
    model_state_dict = model.state_dict()
    matched_weights = {}

    # Iterate over the cleaned checkpoint state dict
    for ckpt_key, ckpt_weight in cleaned_checkpoint_state_dict.items():
        if ckpt_key in model_state_dict:
            # If the layer name matches, add to the matched_weights dict
            matched_weights[ckpt_key] = ckpt_weight
        else:
            print(f"Layer {ckpt_key} from checkpoint not found in the model state dict.")
    
    return matched_weights

def get_dataset(root: str, input_size: int = 256):
    
    normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
    imagenet_transform = transforms.Compose(
        [
            transforms.Resize(input_size),
            transforms.ToTensor(),
            normalize,
        ]
    )
    
    return MuriDataset(root=root, transforms=imagenet_transform)


def get_prefix(task: str = "imagenet"):
    if task == "imagenet":
        return "module."

    elif task == "memory":
        return "model.model."

    else:
        raise NotImplementedError
    
def get_model(model_name, checkpoint_path: str = None, layer_name: str = None, use_blurpool: bool = True, task: str = "imagenet"):
    """
    Create a model from torchvision.models and load weights from checkpoint if provided.

    Args:
    model_name (str): Name of the model to be created.
    checkpoint_path (str): Path to the checkpoint file.
    layer_name (str): Name of the layer to extract features from.
    use_blurpool (bool): Whether to use BlurPoolConv2d for convolution layers with stride > 1.
    task (str): Whether to be imagenet, memory, or combine
    

    Returns:
    torch.nn.Module: The model instance.
    """
    model = getattr(models, model_name)(pretrained=False)
    
    if use_blurpool:
        apply_blurpool(model)
    
    if checkpoint_path:
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        prefix = get_prefix(task)
        matched_weights = match_and_load_weights(checkpoint, model, prefix=prefix)
        model.load_state_dict(matched_weights)
    
    if layer_name:
        model = create_feature_extractor(model, [layer_name])
    
    return model

In [8]:
# Runnning this script will extract features from the specified layer of the model for the images in the dataset.
# Please specify the root path of the dataset
# Please specify the model name
# Please specify the path of the checkpoint
# Please specify the layer name

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

root = ''' Please specify the path of the dataset here '''
ds = get_dataset(root)


model_name = ''' Please specify the model name here ''' # example: resnet50
checkpoint_path = ''' Please specify the path of the checkpoint here ''' 
layer_name = ''' Please specify the layer name here '''
task = "memory" # it could be memory, imagenet, or combine

model = get_model(model_name, checkpoint_path, layer_name, task)

model.eval()

outputs = []
for i in tqdm(range(len(ds))):
    x = ds[i]
    x = x.unsqueeze_(0)
    x = x.to(device)
    output = model(x)
    if device == "cpu":
        output = output['it'].flatten(start_dim=0).detach().numpy().reshape(1, -1)
    else:
        output = output['it'].flatten(start_dim=0).detach().cpu().numpy().reshape(1, -1)
    outputs.append(output)

outputs = np.concatenate(outputs, axis=0) # Shape should be (1320, outout_dim_of_it)
print(f"{outputs.shape = }")


In [None]:
# Sample Code

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

root = "/home/soroush1/projects/def-kohitij/soroush1/training_fast_publish_faster/data/muri1320"
ds = get_dataset(root)


model_name = "resnet50" # example: resnet50
checkpoint_path = "/home/soroush1/projects/def-kohitij/soroush1/training_fast_publish_faster/resnet50_logs/resnet50-0/final_weights.pt"
layer_name = "layer3.2.bn1"
task = "imagenet" # it could be memory, imagenet, or combine

model = get_model(model_name, checkpoint_path, layer_name, task)

model.eval()

outputs = []
for i in tqdm(range(len(ds))):
    x = ds[i]
    x = x.unsqueeze_(0)
    x = x.to(device)
    output = model(x)
    if device == "cpu":
        output = output['it'].flatten(start_dim=0).detach().numpy().reshape(1, -1)
    else:
        output = output['it'].flatten(start_dim=0).detach().cpu().numpy().reshape(1, -1)
    outputs.append(output)

outputs = np.concatenate(outputs, axis=0)
print(f"{outputs.shape = }")


# Pseudocode Implementation

In [14]:
model_name = "resnet50"
checkpoint_path = "/home/soroush1/projects/def-kohitij/soroush1/training_fast_publish_faster/weights/resnet50-1/checkpoint_epoch_90_0.77.pth"
layer_name = "layer3.2.bn1"
use_blurpool = True

model = getattr(models, model_name)(weights=None)

if use_blurpool:
    apply_blurpool(model)
    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
    
checkpoint_state_dict = torch.load(checkpoint_path, map_location=device)
checkpoint_state_dict = checkpoint_state_dict["model_state_dict"]

update_state_dict = match_and_load_weights(checkpoint_state_dict, model)
model.load_state_dict(update_state_dict)

model = create_feature_extractor(model, return_nodes={layer_name: "it"})
model.eval()

outputs = []
for i in tqdm(range(len(ds))):
    x = ds[i]
    x = x.unsqueeze_(0)
    x = x.to(device)
    output = model(x)
    if device == "cpu":
        output = output['it'].flatten(start_dim=0).detach().numpy().reshape(1, -1)
    else:
        output = output['it'].flatten(start_dim=0).detach().cpu().numpy().reshape(1, -1)
    outputs.append(output)

outputs = np.concatenate(outputs, axis=0)
print(f"{outputs.shape = }")


100%|██████████| 1320/1320 [00:07<00:00, 176.44it/s]

outputs.shape = (1320, 65536)





In [15]:
model.state_dict()["conv1.weight"][:10, 0, 0, 0]

tensor([-0.0114, -0.0323,  0.0010,  0.0307, -0.0376,  0.0025, -0.0234, -0.0277,
         0.0225,  0.0516], device='cuda:0')

In [25]:
model_name = "resnet50"
checkpoint_path = "/home/soroush1/projects/def-kohitij/soroush1/training_fast_publish_faster/vgg19_weights/vgg19-0/checkpoint_epoch_90_0.63.pth"
layer_name = "layer3.2.bn1"
use_blurpool = True

model = getattr(models, model_name)(weights=None)

if use_blurpool:
    apply_blurpool(model)
    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
    
checkpoint_state_dict = torch.load(checkpoint_path, map_location=device)
checkpoint_state_dict = checkpoint_state_dict["state_dict"]

update_state_dict = match_and_load_weights(checkpoint_state_dict, model, prefix='model.model.')
model.load_state_dict(update_state_dict)

KeyError: 'state_dict'