<h1 align=center>Reshape the Saved Layer Data into 4D Arrays</h1>

In [1]:
import os
import numpy as np

In [None]:
# Root folder (relative to the current working directory) that holds one
# sub-folder of saved .npz files per model key.
output_dir = os.getcwd() + '/data/output'

In [None]:
# Per-model mapping: layer name -> (channels, height, width).
# reshape() unflattens each layer's saved (batch, features) array into
# (batch, channels, height, width), so channels * height * width must equal
# the number of features stored on disk for that layer.
# NOTE(review): alexnet and resnet-50-robust previously listed some layers
# twice with identical shapes (a dict keeps only the last duplicate, so the
# extra entries had no effect). They were removed; if other layers were
# intended there, add them explicitly.
models = {
  'alexnet': {
    'features.2': (64, 27, 27),
    'features.7': (384, 13, 13),
    'features.12': (256, 6, 6),
  },
  'resnet-50-robust': {
    'layer3.0.downsample.0': (1024, 14, 14),
    'layer4.0.downsample.0': (2048, 7, 7),
  },
  'resnet152_imagenet_full': {
    'layer1.0.bn1': (64, 56, 56),
    'layer3.3.bn3': (1024, 14, 14),
    'layer3.0.bn3': (1024, 14, 14),
    'layer3.34.bn3': (1024, 14, 14),
  },
  'resnext101_32x32d_wsl': {
    'layer1.0.relu': (256, 56, 56),
    'layer3.0.relu': (1024, 14, 14),
    'layer2.0.relu': (512, 28, 28),
    'layer3.21.relu': (1024, 14, 14),
  },
  'convnext_small_imagenet_100_seed-0': {
    'features.5.2.block.0': (384, 14, 14),
    'features.5.17.block.0': (384, 14, 14),
    'features.4.0': (192, 28, 28),
    'features.5.9.block.0': (384, 14, 14),
  },
  'convnext_small_imagenet_10_seed-0': {
    'features.5.2.block.0': (384, 14, 14),
    'features.5.17.block.0': (384, 14, 14),
    'features.4.0': (192, 28, 28),
    'features.5.9.block.0': (384, 14, 14),
  },
  'resnext101_32x48d_wsl': {
    'layer2.2.relu': (512, 28, 28),
    'layer3.0.relu': (1024, 14, 14),
    'layer2.0.relu': (512, 28, 28),
    'layer3.20.relu': (1024, 14, 14),
  },
  'resnet50_ecoset_full': {
    'layer1.0.bn1': (64, 56, 56),
    'layer4.0.conv2': (512, 7, 7),
    'layer3.0.conv1': (256, 28, 28),
    'layer4.0.relu': (2048, 7, 7),
  },
  'resnet50_imagenet_100_seed-0': {
    'layer1.0.conv1': (64, 56, 56),
    'layer3.5.bn3': (1024, 14, 14),
    'layer3.0.conv1': (256, 28, 28),
    'layer4.0.relu': (2048, 7, 7),
  },
  'resnet101_ecoset_full': {
    'layer1.0.bn1': (64, 56, 56),
    'layer3.4.relu': (1024, 14, 14),
    'layer3.0.bn3': (1024, 14, 14),
    'layer4.0.relu': (2048, 7, 7),
  },
  'resnext101_32x8d_wsl': {
    'layer2.3.relu': (512, 28, 28),
    'layer3.4.relu': (1024, 14, 14),
    'layer2.1.relu': (512, 28, 28),
    'layer3.3.relu': (1024, 14, 14),
  },
  'convnext_small_imagenet_full_seed-0': {
    'features.5.2.block.0': (384, 14, 14),
    'features.5.17.block.0': (384, 14, 14),
    'features.4.0': (192, 28, 28),
    'features.5.9.block.0': (384, 14, 14),
  },
  'convnext_tiny_imagenet_full_seed-0': {
    'features.6.0': (384, 14, 14),
    'features.5.4.block.0': (384, 14, 14),
    'features.4.0': (192, 28, 28),
  },
  'convnext_base_imagenet_full_seed-0': {
    'features.5.7.block.0': (512, 14, 14),
    'features.5.12.block.0': (512, 14, 14),
    'features.4.0': (256, 28, 28),
    'features.5.11.block.0': (512, 14, 14),
  },
  'resnet50_tutorial': {
    'layer2': (512, 28, 28),
    'layer3': (1024, 14, 14),
  },
  'resnet101_imagenet_full': {
    'layer1.0.bn1': (64, 56, 56),
    'layer4.0.bn1': (512, 14, 14),
    'layer3.0.bn3': (1024, 14, 14),
    'layer4.0.relu': (2048, 7, 7),
  },
  'convnext_large_imagenet_full_seed-0': {
    # NOTE(review): unlike every other entry this tuple puts the 768-sized
    # axis last; reshape() will still unpack it positionally as
    # (channels, height, width) = (14, 14, 768). Presumably this layer's
    # output is channels-last — confirm against the model.
    'features.5.7.block.5': (14, 14, 768),
    'features.5.7.block.0': (768, 14, 14),
    'features.4.1': (768, 14, 14),
    'features.5.11.block.0': (768, 14, 14),
  },
  'resnet50_imagenet_full': {
    'layer1.0.conv1': (64, 56, 56),
    'layer3.5.bn3': (1024, 14, 14),
    'layer3.0.conv1': (256, 28, 28),
    'layer4.0.relu': (2048, 7, 7),
  },
  'resnet18_imagenet_full': {
    'layer1.0.bn1': (64, 56, 56),
    'layer3.0.conv2': (256, 14, 14),
    'layer2.0.bn2': (128, 28, 28),
    'layer4.0.bn1': (512, 7, 7),
  },
  'resnet152_ecoset_full': {
    'layer1.0.bn1': (64, 56, 56),
    'layer3.3.bn3': (1024, 14, 14),
    'layer3.0.bn3': (1024, 14, 14),
    'layer4.0.relu': (2048, 7, 7),
  },
  'resnet18_ecoset_full': {
    'layer1.0.conv1': (64, 56, 56),
    'layer3.0.conv1': (256, 14, 14),
    'layer2.0.bn2': (128, 28, 28),
    'layer4.0.bn1': (512, 7, 7),
  },
  'resnet-152_v2_pytorch': {
    'avgpool': (2048, 1, 1),
    'layer4.1.relu': (2048, 10, 10),
    'layer4.1.bn2': (512, 10, 10),
  },
  'resnet34_ecoset_full': {
    'layer1.0.bn1': (64, 56, 56),
    'layer3.1.conv1': (256, 14, 14),
    'layer3.0.conv1': (256, 14, 14),
    'layer4.0.conv1': (512, 7, 7),
  },
  'resnet18_imagenet21kP': {
    'layer2.0.relu': (128, 28, 28),
    'layer4.0.relu': (512, 7, 7),
  },
}

In [None]:
def _target_shape(layer_name, shapes):
  """Return the (channels, height, width) tuple registered for ``layer_name``.

  Raises:
    ValueError: if ``shapes`` has no entry for ``layer_name``.
  """
  if layer_name not in shapes:
    raise ValueError(f"Shape for layer {layer_name} not defined.")
  return tuple(shapes[layer_name])


def _reshape_layer(layer_name, layer_data, target):
  """Reshape one layer's (batch, features) array to (batch, *target).

  An array that is already 4D with trailing dims equal to ``target`` (i.e. a
  file processed by an earlier run) is returned unchanged, so re-running the
  notebook is safe.

  Raises:
    ValueError: if the array is neither 2D nor already in the target 4D shape,
      or if prod(target) does not equal the stored feature count.
  """
  if layer_data.ndim == 4 and layer_data.shape[1:] == target:
    return layer_data
  if layer_data.ndim != 2:
    raise ValueError(
      f"Expected a 2D (batch, features) array for {layer_name}, "
      f"got shape {layer_data.shape}"
    )
  batch_size, neuroids = layer_data.shape
  channels, height, width = target
  # Explicit raise rather than `assert`, so the check survives `python -O`.
  if channels * height * width != neuroids:
    raise ValueError(
      f"Mismatch in reshaping dimensions for {layer_name}: "
      f"Expected {channels * height * width}, got {neuroids}"
    )
  return layer_data.reshape(batch_size, channels, height, width)


def _replace_npz(file_path, arrays):
  """Overwrite ``file_path`` with ``arrays`` saved as an .npz archive.

  The archive is written to a sibling temporary file first and then moved over
  the original with os.replace, so an interrupted write does not leave a
  truncated file in place of the only copy of the data.
  """
  tmp_path = f"{file_path}.tmp"
  try:
    # Passing an open file object stops np.savez from appending '.npz'.
    with open(tmp_path, 'wb') as fh:
      np.savez(fh, **arrays)
    os.replace(tmp_path, file_path)
  except BaseException:
    if os.path.exists(tmp_path):
      os.remove(tmp_path)
    raise


def reshape(key, shapes, base_dir=None, file_indices=range(1, 26)):
  """Reshape the flattened layer arrays in a model's ``key-<i>.npz`` files to 4D.

  For each existing ``<base_dir>/<key>/key-<i>.npz``, every stored array of
  shape (batch, features) is reshaped to (batch, channels, height, width)
  using ``shapes[layer_name]``, and the file is overwritten with the result.
  Files whose layers are all already in the target shape are left untouched.

  Args:
    key: Model name; also the sub-folder name under ``base_dir``.
    shapes: Mapping of layer name -> (channels, height, width).
    base_dir: Root folder of the per-model sub-folders. Defaults to the
      notebook-level ``output_dir`` (looked up at call time).
    file_indices: Indices ``i`` of the ``key-<i>.npz`` files to process;
      defaults to 1..25. Missing files are skipped with a message.

  Raises:
    ValueError: if a layer has no registered shape, has an unexpected rank,
      or its feature count does not match its registered shape.
  """
  root = output_dir if base_dir is None else base_dir
  key_folder = os.path.join(root, key)
  for i in file_indices:
    file_path = os.path.join(key_folder, f'key-{i}.npz')
    if not os.path.exists(file_path):
      print(f"File {file_path} does not exist, skipping...")
      continue
    print(f"\nProcessing file: {file_path}")
    reshaped_data = {}
    changed = False
    # The context manager closes the NpzFile before the file is replaced;
    # data[name] materializes each array fully, so nothing is read afterwards.
    with np.load(file_path) as data:
      for layer_name in data.files:
        print(f"  Processing layer: {layer_name}")
        layer_data = data[layer_name]
        target = _target_shape(layer_name, shapes)
        reshaped = _reshape_layer(layer_name, layer_data, target)
        reshaped_data[layer_name] = reshaped
        if reshaped.shape == layer_data.shape:
          print(f"{layer_name} already has shape {reshaped.shape}, skipping")
        else:
          changed = True
          print(f"Reshaped {layer_name} to {reshaped.shape}")
    if not changed:
      print(f"Nothing to reshape, leaving file untouched: {file_path}")
      continue
    _replace_npz(file_path, reshaped_data)
    print(f"Overwritten file with reshaped data: {file_path}")

Reshape the data of every model listed in `models`. Note that each `key-*.npz` file is overwritten with its reshaped version.

In [None]:
# Process every model's files, in the order the models are listed above.
for model_name in models:
  reshape(model_name, models[model_name])