### Model Definition

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import numpy as np
import random


# ---- Constants -------------------------------------------------------------
manualSeed = 42
DEFAULT_THRESHOLD = 5e-3  # not referenced anywhere in the code visible here
GEN_KERNEL = 3  # kernel size handed to every TemplateBank
num_cf = 2  # number of coefficient sets each SConv2d creates

# ---- Reproducibility -------------------------------------------------------
# The Python, NumPy and PyTorch (CPU + CUDA) generators all get the same seed.
np.random.seed(manualSeed)
random.seed(manualSeed)
torch.cuda.manual_seed(manualSeed)
torch.manual_seed(manualSeed)

# cuDNN is switched off altogether and its autotuner is disabled as well.
torch.backends.cudnn.enabled = False
cudnn.benchmark = False

# ---- Device configuration --------------------------------------------------
device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
print("Device: ", device)


class TemplateBank(nn.Module):
    """Bank of learnable convolution kernels ("templates").

    The bank stores ``num_templates`` kernels of shape
    ``(out_planes, in_planes, kernel_size, kernel_size)`` in a single
    parameter of shape ``(num_templates, out_planes, in_planes, k, k)``.
    Calling the bank with a coefficient tensor returns the coefficient-weighted
    sum of the templates, i.e. one convolution weight.

    Args:
        num_templates: Number of kernels held by the bank (must be >= 1).
        in_planes: Input channels of every kernel.
        out_planes: Output channels of every kernel.
        kernel_size: Spatial size of the (square) kernels.

    Raises:
        ValueError: If ``num_templates`` is smaller than 1.
    """

    def __init__(self, num_templates, in_planes, out_planes, kernel_size):
        super(TemplateBank, self).__init__()
        if num_templates < 1:
            # torch.stack() on an empty list fails with an opaque message.
            raise ValueError(
                f"num_templates must be at least 1, got {num_templates}"
            )
        self.in_planes = in_planes
        self.out_planes = out_planes
        # Shape a coefficient tensor needs in order to broadcast over templates.
        self.coefficient_shape = (num_templates, 1, 1, 1, 1)
        self.kernel_size = kernel_size

        # Each template is initialised on its own so the Kaiming fan is
        # computed from a single 4-D kernel rather than from the 5-D stack.
        templates = []
        for _ in range(num_templates):
            template = torch.empty(out_planes, in_planes, kernel_size, kernel_size)
            nn.init.kaiming_normal_(template)
            templates.append(template)
        self.templates = nn.Parameter(
            torch.stack(templates)
        )  # this is what we will freeze later

    def forward(self, coefficients):
        """Mix the templates into one kernel.

        Args:
            coefficients: Tensor that broadcasts against ``self.templates``
                (``self.coefficient_shape`` is the shape used in this file).

        Returns:
            ``(self.templates * coefficients).sum(0)``.
        """
        weights = (self.templates * coefficients).sum(0)
        return weights

    def __repr__(self):
        # All four fields sit inside one balanced pair of parentheses.
        return (
            f"{self.__class__.__name__} ("
            f"num_templates={self.coefficient_shape[0]}, "
            f"kernel_size={self.kernel_size}, "
            f"in_planes={self.in_planes}, "
            f"out_planes={self.out_planes})"
        )


class SConv2d(nn.Module):
    """Convolution whose kernel is generated from a shared template bank.

    The layer owns ``num_cf`` coefficient tensors (module-level constant), each
    shaped like ``bank.coefficient_shape`` and initialised to zero. On every
    forward pass each coefficient tensor is sent through the bank, the
    resulting kernels are averaged, and the average is used for ``F.conv2d``.

    Args:
        bank: Object exposing ``coefficient_shape`` and callable on a
            coefficient tensor to produce a kernel.
        stride: Stride forwarded to ``F.conv2d``.
        padding: Padding forwarded to ``F.conv2d``.
    """

    # TARGET MODULE
    def __init__(self, bank, stride=1, padding=1):
        super(SConv2d, self).__init__()
        self.padding = padding
        self.stride = stride
        self.bank = bank
        self.num_templates = bank.coefficient_shape[0]

        coefficient_sets = []
        for _ in range(num_cf):
            coefficient_sets.append(nn.Parameter(torch.zeros(bank.coefficient_shape)))
        self.coefficients = nn.ParameterList(coefficient_sets)

    def forward(self, input):
        """Convolve ``input`` with the mean of the bank-generated kernels."""
        generated = [self.bank(coefficient) for coefficient in self.coefficients]
        kernel = torch.stack(generated).mean(0)
        return F.conv2d(input, kernel, stride=self.stride, padding=self.padding)


class CustomResidualBlock(nn.Module):
    """Two-convolution residual block with optional template-bank kernels.

    When both ``bank1`` and ``bank2`` are given, the two 3x3 convolutions are
    ``SConv2d`` layers that generate their kernels from those banks; otherwise
    ordinary bias-free ``nn.Conv2d`` layers are created.

    Args:
        in_channels: Channels of the block input.
        out_channels: Channels produced by both convolutions.
        stride: Stride of the first convolution (and of the shortcut).
        downsample: Optional module applied to the input before it is added
            to the residual branch. When ``None``, a 1x1 convolution followed
            by batch norm is created if ``stride != 1`` or the channel counts
            differ; otherwise the identity shortcut is used.
        bank1: Template bank for the first convolution, or ``None``.
        bank2: Template bank for the second convolution, or ``None``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        stride=1,
        downsample=None,
        bank1=None,
        bank2=None,
    ):
        super(CustomResidualBlock, self).__init__()
        self.bank1 = bank1
        self.bank2 = bank2

        # Ensure padding is always 1 for 3x3 convolutions
        # Explicit None checks: truthiness of a module depends on whether it
        # happens to define __len__/__bool__.
        if bank1 is not None and bank2 is not None:
            self.conv1 = SConv2d(bank1, stride=stride, padding=1)
            self.conv2 = SConv2d(bank2, stride=1, padding=1)
        else:
            self.conv1 = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=stride,
                padding=1,
                bias=False,
            )
            self.conv2 = nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            )

        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.bn2 = nn.BatchNorm2d(out_channels)

        if downsample is not None:
            # Honour the shortcut supplied by the caller instead of discarding
            # it and rebuilding an equivalent one.
            self.downsample = downsample
        elif stride != 1 or in_channels != out_channels:
            # Implement downsample as 1x1 convolution when needed
            self.downsample = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.downsample = None

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialise every sub-module of this block.

        ``nn.Conv2d`` weights get Kaiming-normal (fan-out, ReLU) values,
        ``nn.BatchNorm2d`` is set to weight 1 / bias 0, and every coefficient
        tensor of an ``SConv2d`` is initialised with ``nn.init.orthogonal_``.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, SConv2d):
                for coefficient in m.coefficients:
                    nn.init.orthogonal_(coefficient)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNetTPB(nn.Module):
    """Four-stage residual network whose later blocks share template banks.

    The stem is a 7x7 stride-2 convolution, batch norm, ReLU and a 3x3
    stride-2 max-pool. It is followed by four stages of 64/128/256/512
    channels, adaptive average pooling and a linear classifier.

    Args:
        block: Callable building one residual block. It is invoked as
            ``block(in_planes, planes, stride, downsample)`` for the first
            block of a stage and with ``in_channels``/``out_channels``/
            ``bank1``/``bank2`` keywords for the remaining ones.
        layers: Sequence of four integers, the number of blocks per stage.
        num_classes: Size of the classifier output.
    """

    def __init__(self, block, layers, num_classes=10):
        super(ResNetTPB, self).__init__()
        self.inplanes = 64
        self.layers = layers

        # Stem
        self.conv1 = nn.Conv2d(
            3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages layer1 .. layer4, built in order.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for index, (planes, stride) in enumerate(stage_specs):
            stage = self._make_layer(block, planes, layers[index], stride=stride)
            setattr(self, f"layer{index + 1}", stage)

        # Head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage: a plain first block plus ``blocks - 1`` bank-backed blocks.

        Two ``TemplateBank`` objects (one per convolution position) are created
        for the stage and handed to every block after the first.
        """
        if stride != 1 or self.inplanes != planes:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes, planes, kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(planes),
            )
        else:
            downsample = None

        stage_blocks = [block(self.inplanes, planes, stride, downsample)]

        # DYNAMICALLY CALCULATE THE NUMBER OF TEMPLATES TO USE FOR EACH RESIDUAL BLOCK
        # Both sizes below are identical, so this evaluates to max(1, blocks - 1).
        shared_block_count = blocks - 1
        params_per_conv = 9 * planes * planes
        params_per_template = 9 * planes * planes
        num_templates1 = max(
            1, int(shared_block_count * params_per_conv / params_per_template)
        )
        # You could potentially use a different calculation here
        num_templates2 = num_templates1

        print(
            f"Layer with {planes} planes, {blocks} blocks, using {num_templates1} templates for conv1 and {num_templates2} for conv2"
        )

        # Create separate TemplateBanks for conv1 and conv2
        tpbank1 = TemplateBank(num_templates1, planes, planes, GEN_KERNEL)
        tpbank2 = TemplateBank(num_templates2, planes, planes, GEN_KERNEL)

        self.inplanes = planes
        stage_blocks.extend(
            block(
                in_channels=self.inplanes,
                out_channels=planes,
                bank1=tpbank1,
                bank2=tpbank2,
            )
            for _ in range(shared_block_count)
        )

        return nn.Sequential(*stage_blocks)

    def forward(self, x):
        """Run stem, the four stages, pooling and the classifier on ``x``."""
        features = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage_name in ("layer1", "layer2", "layer3", "layer4"):
            features = getattr(self, stage_name)(features)

        pooled = torch.flatten(self.avgpool(features), 1)
        return self.fc(pooled)


### Load weights for non-target modules

In [None]:
from torchvision.models import resnet34, resnet18
# NOTE: this rebinds the name `resnet34` from the imported constructor to the
# instantiated model; every later use of `resnet34` refers to the model.
# NOTE(review): recent torchvision releases deprecate `pretrained=True` in
# favour of the `weights=` argument — confirm against the installed version.
resnet34 = resnet34(pretrained=True)
# resnet34 = resnet18(pretrained=True)

# Filled by load_weights() with every state-dict key it copies across.
FOUND = []

def load_weights(current_model, target_model):
    """Copy every compatible tensor from ``target_model`` into ``current_model``.

    A tensor is compatible when its key exists in ``current_model``'s state
    dict and both tensors have the same shape. Each copied key is appended to
    the module-level ``FOUND`` list; skipped keys are reported on stdout.

    Args:
        current_model: Model that receives the weights (modified in place).
        target_model: Model that provides the weights.

    Returns:
        ``current_model`` after loading the merged state dict.
    """
    own_state = current_model.state_dict()
    reference_state = target_model.state_dict()

    transferable = {}
    for k, tensor in reference_state.items():
        if k not in own_state:
            print(f"Key not found: {k}")
            continue
        if own_state[k].shape != tensor.shape:
            print(f"Shape mismatch for key: {k}")
            continue
        transferable[k] = tensor
        FOUND.append(k)

    # Overlay the transferable tensors on the model's own state and load it.
    own_state.update(transferable)
    current_model.load_state_dict(own_state)
    return current_model

# Build the template-bank network with a [3, 4, 6, 3] block layout and copy
# across every tensor of the reference model that fits.
my_model = load_weights(
    ResNetTPB(CustomResidualBlock, [3, 4, 6, 3], num_classes=1000), resnet34
)
my_model_state = list(my_model.state_dict())
print(f"Found: {len(FOUND)}")
print(f"Total: {len(my_model_state)}")

# Everything that was NOT copied from the reference model is a training target.
target_params = [key for key in my_model_state if key not in FOUND]
print(f"Target params: {len(target_params)}")

# Only the target parameters stay trainable; all copied ones are frozen.
for name, param in my_model.named_parameters():
    param.requires_grad = name in target_params

### Reconstruction loop

In [None]:
def _bank_reconstruction_loss(bank, coefficients, target_weight, criterion, scale, noise_std, training):
    """Return ``scale * criterion(mean_i bank(c_i), target_weight)`` for one convolution."""
    reconstructed = []
    for coefficient in coefficients:
        # Noise is only injected in training mode; a zero std is skipped
        # because adding ``randn * 0`` leaves the coefficient unchanged.
        if training and noise_std:
            coefficient = coefficient + torch.randn_like(coefficient) * noise_std
        reconstructed.append(bank(coefficient))
    mean_weight = torch.stack(reconstructed).mean(0)
    return criterion(mean_weight, target_weight) * scale


def reconstruction_loss_dynamic(current_model, pretrained_model, criterion=nn.SmoothL1Loss(), w1_weight=3.5, w2_weight=2.5, noise_std=0.0):
    """Measure how well the template banks reproduce the pretrained kernels.

    For every ``CustomResidualBlock`` in ``layer1`` .. ``layer4`` that has a
    template bank, the kernel generated from the bank (mean over the block's
    coefficient sets) is compared with the weight stored under
    ``layer{L}.{B}.conv{1,2}.weight`` in ``pretrained_model``'s state dict.

    Args:
        current_model: Model exposing ``layer1`` .. ``layer4``.
        pretrained_model: Model whose state dict provides the target kernels.
        criterion: Loss applied to (generated kernel, target kernel). The
            default instance is created once at definition time and shared
            between calls.
        w1_weight: Multiplier for every conv1 loss.
        w2_weight: Multiplier for every conv2 loss.
        noise_std: Standard deviation of Gaussian noise added to the
            coefficients while ``current_model`` is in training mode. The
            default of 0.0 adds no noise.

    Returns:
        Tuple ``(loss_dict, total_loss, w1_loss, w2_loss)``. ``loss_dict`` maps
        ``...bank{1,2}.templates`` and ``...conv{1,2}.coefficients.{i}`` keys
        to the loss tensor of the corresponding convolution, ``total_loss`` is
        the sum of all loss tensors (the float ``0.0`` when nothing matched),
        and ``w1_loss`` / ``w2_loss`` are the float sums per conv position.
    """
    corr_state_dict = pretrained_model.state_dict()
    loss_dict = {}
    total_loss = 0.0
    conv_losses = {"conv1": 0.0, "conv2": 0.0}

    # Determine the device of the current model
    model_device = next(current_model.parameters()).device

    conv_specs = (("conv1", "bank1", w1_weight), ("conv2", "bank2", w2_weight))
    stages = (
        current_model.layer1,
        current_model.layer2,
        current_model.layer3,
        current_model.layer4,
    )
    for layer_idx, layer in enumerate(stages, start=1):
        for block_idx, block in enumerate(layer):
            if not isinstance(block, CustomResidualBlock):
                continue
            prefix = f"layer{layer_idx}.{block_idx}"
            for conv_name, bank_name, scale in conv_specs:
                bank = getattr(block, bank_name)
                coefficients = getattr(getattr(block, conv_name), "coefficients", None)
                # Nothing to reconstruct without a bank, or when the block
                # fell back to a plain convolution that has no coefficients.
                if bank is None or coefficients is None:
                    continue

                target_weight = corr_state_dict[f"{prefix}.{conv_name}.weight"].to(model_device)
                loss = _bank_reconstruction_loss(
                    bank,
                    coefficients,
                    target_weight,
                    criterion,
                    scale,
                    noise_std,
                    current_model.training,
                )

                loss_dict[f"{prefix}.{bank_name}.templates"] = loss
                for i in range(len(coefficients)):
                    loss_dict[f"{prefix}.{conv_name}.coefficients.{i}"] = loss

                conv_losses[conv_name] += loss.item()
                total_loss = total_loss + loss

    return loss_dict, total_loss, conv_losses["conv1"], conv_losses["conv2"]
import copy
# Fit the trainable parameters of my_model to the pretrained kernels.
optimizer = torch.optim.RMSprop(my_model.parameters(), lr=0.2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.1)
best_model, best_model_loss = None, 1e9
my_model.train()

for epoch in range(3000):
    loss_dict, total_loss, conv1_loss, conv2_loss = reconstruction_loss_dynamic(
        my_model, resnet34
    )

    # The snapshot is taken before optimizer.step(), so the copy holds exactly
    # the parameters that produced total_loss.
    if total_loss < best_model_loss:
        best_model = copy.deepcopy(my_model)
        best_model_loss = total_loss.item()

    # One optimisation step with the gradient norm clipped to 1.
    total_loss.backward()
    torch.nn.utils.clip_grad_norm_(my_model.parameters(), 1)
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()

    # Progress report every 100 epochs.
    if epoch % 100 != 0:
        continue
    print(f"Epoch: {epoch} Total Loss: {total_loss.item()} ")
    print(f"Conv1 Loss: {conv1_loss} Conv2 Loss: {conv2_loss}")
    print(f"LR: {scheduler.get_last_lr()}")

### Evaluate Reconstruction

#### Comparing Cosine Similarity

In [None]:
def _block_conv_weight(block, conv_name, bank_name, from_bank):
    """Return the effective kernel of ``block.<conv_name>``, or ``None`` if unavailable."""
    conv = getattr(block, conv_name, None)
    if not from_bank:
        # Plain convolution: use its stored weight when it has one.
        return getattr(conv, "weight", None)
    bank = getattr(block, bank_name, None)
    if bank is None:
        return None
    # Bank-backed convolution: average the kernels of all coefficient sets.
    return torch.stack([bank(cf) for cf in conv.coefficients]).mean(0)


def evaluate_dynamic_reconstruction(current_model, pretrained_model):
    """Print and accumulate kernel cosine similarities against a reference model.

    For every block in ``layer1`` .. ``layer4`` the flattened conv1 / conv2
    kernels are compared with ``layer{L}.{B}.conv{1,2}.weight`` from
    ``pretrained_model``'s state dict. The first block of each stage and any
    block that is not a ``CustomResidualBlock`` are read from the stored
    ``weight``; all other blocks are rebuilt from their template bank. A
    kernel that cannot be obtained is reported as 0 and adds nothing to the
    sums. No gradients are recorded.

    Args:
        current_model: Model exposing ``layer1`` .. ``layer4``.
        pretrained_model: Reference model providing the target kernels.

    Returns:
        Tuple ``(w1_sum, w2_sum)``: cosine similarities summed over all blocks
        for conv1 and conv2. The per-block averages are only printed; they are
        reported as 0 when there are no blocks.
    """
    metric = nn.CosineSimilarity(dim=0)
    corr_state_dict = pretrained_model.state_dict()
    totals = {"conv1": 0.0, "conv2": 0.0}
    total_blocks = 0

    stages = [
        current_model.layer1,
        current_model.layer2,
        current_model.layer3,
        current_model.layer4,
    ]
    with torch.no_grad():
        for layer_idx, layer in enumerate(stages, 1):
            for block_idx, block in enumerate(layer):
                from_bank = block_idx != 0 and isinstance(block, CustomResidualBlock)
                scores = {}
                for conv_name, bank_name in (("conv1", "bank1"), ("conv2", "bank2")):
                    weight = _block_conv_weight(block, conv_name, bank_name, from_bank)
                    if weight is None:
                        scores[conv_name] = 0.0
                        continue
                    reference = corr_state_dict[
                        f"layer{layer_idx}.{block_idx}.{conv_name}.weight"
                    ].to(weight.device)  # compare on the model's own device
                    score = metric(weight.reshape(-1), reference.reshape(-1)).item()
                    scores[conv_name] = score
                    totals[conv_name] += score

                print(
                    f"Layer: {layer_idx}, Block: {block_idx}, "
                    f"W1: {scores['conv1']:.4f}, W2: {scores['conv2']:.4f}"
                )
                total_blocks += 1

    w1_loss = totals["conv1"]
    w2_loss = totals["conv2"]

    # Guard against a model without blocks instead of dividing by zero.
    divisor = max(total_blocks, 1)
    avg_w1 = w1_loss / divisor
    avg_w2 = w2_loss / divisor

    print(f"Average - W1: {avg_w1:.4f}, W2: {avg_w2:.4f}")
    return w1_loss, w2_loss

# NOTE: despite its name, `similarity_dict` receives the two-element tuple of
# summed conv1 / conv2 cosine similarities that the function returns.
similarity_dict = evaluate_dynamic_reconstruction(best_model, resnet34)

#### Comparing layerwise feature similarity

In [None]:
from PIL import Image
# NOTE(review): `calculate_parameters` is neither defined nor imported in the
# part of the file visible here — confirm it exists before running this cell.
calculate_parameters(my_model)
import torch
import torch.nn as nn
from PIL import Image
import os
import random
import numpy as np
# Load the batch from disk
# NOTE(review): torch.load unpickles the file, so only load batches that come
# from a trusted source.
batch = torch.load("<DATALOADER_BATCH_PATH>")

import torch
import torch.nn as nn

# Modified forward function to return intermediate features
def forward_with_intermediate(model, x, is_custom=True):
    """Run ``model`` on ``x`` and also return the output of each of its four stages.

    Args:
        model: Module exposing ``conv1``, ``bn1``, ``relu``, ``maxpool``,
            ``layer1`` .. ``layer4``, ``avgpool`` and ``fc``.
        x: Input passed to ``model.conv1``.
        is_custom: Only selects which message is printed.

    Returns:
        Tuple ``(output, features)`` where ``output`` is the result of
        ``model.fc`` and ``features`` lists the outputs of ``layer1`` ..
        ``layer4`` in that order.
    """
    print("Using custom model" if is_custom else "Using default model")

    # Stem: convolution -> batch norm -> ReLU -> max-pool.
    activation = model.maxpool(model.relu(model.bn1(model.conv1(x))))

    # Residual stages, keeping every stage output.
    intermediate_features = []
    for stage_name in ("layer1", "layer2", "layer3", "layer4"):
        activation = getattr(model, stage_name)(activation)
        intermediate_features.append(activation)

    # Head: pool, flatten, classify.
    pooled = torch.flatten(model.avgpool(activation), 1)
    return model.fc(pooled), intermediate_features

# Function to compare features layerwise
def compare_features_layerwise(model1, model2, batch):
    """Compare the stage outputs of two models on the same batch.

    Both models and the batch are moved to the CPU (the models in place). The
    forward passes run in evaluation mode and without gradient tracking, so
    BatchNorm layers use their stored statistics and neither model's running
    statistics are modified by the comparison. The training/evaluation mode
    each model had on entry is restored afterwards.

    Args:
        model1: First model (reported as the custom model).
        model2: Second model (reported as the default model).
        batch: Input tensor fed to both models.

    Returns:
        One float per stage: the features are averaged over all trailing
        (spatial) dimensions, the cosine similarity is taken along the channel
        dimension, and the result is averaged over the batch.
    """
    batch = batch.to("cpu")  # Assuming CPU for computation
    model1 = model1.to("cpu")
    model2 = model2.to("cpu")

    previous_modes = (model1.training, model2.training)
    model1.eval()
    model2.eval()
    try:
        with torch.no_grad():
            _, features1 = forward_with_intermediate(model1, batch, is_custom=True)
            _, features2 = forward_with_intermediate(model2, batch, is_custom=False)
    finally:
        # Leave the models in the mode the caller had them in.
        model1.train(previous_modes[0])
        model2.train(previous_modes[1])

    cos = nn.CosineSimilarity(dim=1)  # dim 1 is the channel dimension
    similarities = []

    for f1, f2 in zip(features1, features2):
        # Flatten the features except for the batch and channel dimensions;
        # reshape (unlike view) also accepts non-contiguous tensors.
        f1_flat = f1.reshape(f1.size(0), f1.size(1), -1)
        f2_flat = f2.reshape(f2.size(0), f2.size(1), -1)

        # Compare the spatially averaged channel descriptors.
        similarity = cos(f1_flat.mean(dim=2), f2_flat.mean(dim=2)).mean()
        similarities.append(similarity.item())

    return similarities




# Stage-wise feature similarity between the best reconstructed model and the
# reference model on the loaded batch.
similarities = compare_features_layerwise(best_model, resnet34, batch)
av_sim = sum(similarities) / len(similarities)
print(f"Average similarity: {av_sim:.4f}")

# Print the similarities, numbering the stages from 1.
for layer_number, sim in enumerate(similarities, start=1):
    print(f"Layer {layer_number} similarity: {sim:.4f}")

# Plot the same values.
import matplotlib.pyplot as plt

layer_numbers = range(1, len(similarities) + 1)
plt.figure(figsize=(10, 6))
plt.plot(layer_numbers, similarities, marker='o')
plt.title("Layer-wise Feature Similarity")
plt.xlabel("Layer")
plt.ylabel("Cosine Similarity")
plt.grid(True)
plt.show()