In [6]:
import torch
import torch.nn as nn
from torchvision import transforms

from tqdm import tqdm

import sys

# Put the parent directory on the module search path (only once, so re-running
# this cell does not add duplicates). Presumably this is what lets the
# `datasets` package imported further down resolve -- verify against the repo layout.
parent_dir = '..'
already_registered = parent_dir in sys.path
if not already_registered:
    sys.path.append(parent_dir)

In [7]:
from datasets.Muri.MuriDataset import MuriDataset

In [8]:
# Size argument handed to transforms.Resize below.
input_size = 256

# Per-channel mean / std used to normalize the image tensors.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=imagenet_mean, std=imagenet_std)

# Preprocessing applied to every image, in order: resize -> tensor -> normalize.
preprocessing_steps = [
    transforms.Resize(input_size),
    transforms.ToTensor(),
    normalize,
]
imagenet_transform = transforms.Compose(preprocessing_steps)

# Directory passed to the dataset as its root.
root = "/home/soroush1/projects/def-kohitij/soroush1/training_fast_publish_faster/data/muri1320"

ds = MuriDataset(root=root, transforms=imagenet_transform)

In [9]:
from torchvision import models
from torchvision.models.feature_extraction import (
    create_feature_extractor,
    get_graph_node_names,
)

import torch.nn.functional as F
import numpy as np

class BlurPoolConv2d(torch.nn.Module):
    """Convolution wrapper that blurs its input before convolving.

    The input is first filtered channel-by-channel with a fixed 3x3 kernel
    ``[[1, 2, 1], [2, 4, 2], [1, 2, 1]] / 16`` (stride 1, padding 1, so the
    spatial size is unchanged) and the result is then fed to the wrapped
    convolution.

    Motivation: low-pass filtering ahead of a strided convolution is intended
    to reduce aliasing caused by downsampling high-frequency content, giving
    smoother feature maps. Which convolutions get wrapped is decided by
    ``apply_blurpool`` in this file.

    Args:
        conv: The convolution module to wrap. Its ``in_channels`` attribute
            determines how many copies of the blur kernel are created.

    Attributes:
        conv: The wrapped convolution (its parameters stay under ``conv.*``
            in the state dict).
        blur_filter: Non-trainable buffer of shape ``(in_channels, 1, 3, 3)``.
    """

    def __init__(self, conv):
        super().__init__()
        default_filter = torch.tensor([[[[1, 2, 1], [2, 4, 2], [1, 2, 1]]]]) / 16.0
        # One copy of the kernel per input channel: (in_channels, 1, 3, 3) is
        # the weight layout expected by a grouped conv with groups == in_channels.
        filt = default_filter.repeat(conv.in_channels, 1, 1, 1)
        self.conv = conv
        # Registered as a buffer so it follows .to()/.cuda() and is saved in
        # the state dict, without being treated as a trainable parameter.
        self.register_buffer("blur_filter", filt)

    def forward(self, x):
        """Blur ``x`` per channel, then apply the wrapped convolution."""
        blurred = F.conv2d(
            x,
            self.blur_filter,
            stride=1,
            padding=(1, 1),
            groups=self.conv.in_channels,
            bias=None,
        )
        # Call the module itself rather than ``.forward`` so that forward
        # hooks / pre-hooks registered on the wrapped conv are executed.
        return self.conv(blurred)
    
    
def apply_blurpool(mod: torch.nn.Module):
    """Recursively wrap eligible convolutions of ``mod`` in ``BlurPoolConv2d``.

    A direct or nested child is replaced in place when it is a
    ``torch.nn.Conv2d`` whose largest stride is greater than 1 and which has
    at least 16 input channels. All other children are searched recursively.

    Children that are already ``BlurPoolConv2d`` instances are skipped, so
    calling this function more than once on the same model (e.g. when a
    notebook cell is re-run) does not stack a second blur on top of the first.

    Args:
        mod: Module whose children are inspected and modified in place.

    Returns:
        None. ``mod`` is mutated.
    """
    for name, child in mod.named_children():
        if isinstance(child, BlurPoolConv2d):
            # Already wrapped: descending would find the inner strided conv
            # and wrap it again.
            continue
        if isinstance(child, torch.nn.Conv2d) and (
            max(child.stride) > 1 and child.in_channels >= 16
        ):
            setattr(mod, name, BlurPoolConv2d(child))
        else:
            apply_blurpool(child)

def remove_prefix(state_dict, prefix):
    """
    Strip a leading prefix from the keys of a state dict.

    Keys that start with ``prefix`` have it removed; keys that do not start
    with it are kept unchanged instead of being discarded, so a checkpoint
    saved without the prefix passes through intact rather than coming back
    empty.

    Args:
    state_dict (dict): State dictionary whose keys are to be cleaned.
    prefix (str): Prefix to strip from the start of each key.

    Returns:
    dict: New dictionary with the same values, in the same order. If
    stripping makes two keys identical, the later entry wins.
    """
    prefix_len = len(prefix)
    return {
        (key[prefix_len:] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

def match_and_load_weights(checkpoint_state_dict, model, prefix='module.'):
    """
    Select the checkpoint entries whose (prefix-stripped) names exist in the model.

    The checkpoint keys are first cleaned with ``remove_prefix``. Every cleaned
    key that is also a key of ``model.state_dict()`` is collected; every other
    key is reported with a printed message. Nothing is loaded into the model
    here -- the caller is expected to pass the result to ``load_state_dict``.

    Args:
    checkpoint_state_dict (dict): State dictionary from checkpoint.
    model (torch.nn.Module): The model instance.
    prefix (str): Prefix to be removed from checkpoint keys.

    Returns:
    dict: The subset of the cleaned checkpoint whose keys are present in the
    model's state dict, in checkpoint order.
    """
    stripped = remove_prefix(checkpoint_state_dict, prefix)
    known_keys = model.state_dict()

    selected = {}
    for name in stripped:
        if name not in known_keys:
            # Report checkpoint entries the model has no slot for.
            print(f"Layer {name} from checkpoint not found in the model state dict.")
            continue
        selected[name] = stripped[name]

    return selected

In [14]:
# --- Configuration -----------------------------------------------------------
model_name = "resnet50"
checkpoint_path = "/home/soroush1/projects/def-kohitij/soroush1/training_fast_publish_faster/weights/resnet50-1/checkpoint_epoch_90_0.77.pth"
layer_name = "layer3.2.bn1"  # graph node whose activations are extracted
use_blurpool = True

# --- Build the model and load the checkpoint ---------------------------------
# Architecture only (weights=None); the weights come from the checkpoint below.
model = getattr(models, model_name)(weights=None)

if use_blurpool:
    # Must happen before load_state_dict so the module names match the checkpoint.
    apply_blurpool(model)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = "cpu"

model.to(device)

checkpoint_state_dict = torch.load(checkpoint_path, map_location=device)
checkpoint_state_dict = checkpoint_state_dict["model_state_dict"]

# Keep only the checkpoint entries the model knows about, then load them.
update_state_dict = match_and_load_weights(checkpoint_state_dict, model)
model.load_state_dict(update_state_dict)

# From here on `model` returns a dict {"it": activations of `layer_name`}.
model = create_feature_extractor(model, return_nodes={layer_name: "it"})
model.eval()

# --- Extract features, one image at a time -----------------------------------
outputs = []
# Inference only: disabling autograd avoids building a graph for every image.
with torch.no_grad():
    for i in tqdm(range(len(ds))):
        # Out-of-place unsqueeze: do not mutate the tensor returned by the dataset.
        x = ds[i].unsqueeze(0).to(device)
        output = model(x)["it"]
        # .cpu() leaves CPU tensors as they are, so a single code path covers
        # both devices; each image becomes one flattened row of shape (1, n_features).
        output = output.flatten(start_dim=0).cpu().numpy().reshape(1, -1)
        outputs.append(output)

# Stack the per-image rows into an (n_images, n_features) array.
outputs = np.concatenate(outputs, axis=0)
print(f"{outputs.shape = }")
# x = torch.cat(imgs, dim=0)
# print(f"{x.size() = }")

# import time

# start = time.time()
# output = model(x)

# if device == "cpu":
#     output = output['it'].flatten(start_dim=0).detach().numpy().reshape(1, -1)
# else:
#     output = output['it'].flatten(start_dim=0).detach().cpu().numpy().reshape(1, -1)
# # print time taken to extract features
# print(f"Time taken to extract features: {time.time() - start:.2f} seconds")

# print(f"{output.shape = }")

100%|██████████| 1320/1320 [00:07<00:00, 176.44it/s]

outputs.shape = (1320, 65536)



