## Batched Feature Maps

Batched feature maps allow more flexible manipulation of data across batch sizes, and help with troubleshooting feature extraction errors that may be inconsistent across batches (e.g. with next-token-predicting LLMs). Below are two examples of their use:
- The canonical AlexNet example.
- A more complicated GPT example.

First, some setup...

In [1]:
import sys; sys.path += ['..']
# add deepjuice to your path

In [2]:
%load_ext autoreload
%autoreload complete

In [3]:
from deepjuice import * # the juice

In [4]:
from juicyfruits import NSDBenchmark
benchmark = NSDBenchmark()  # load brain data benchmark

model_uid = 'torchvision_alexnet_imagenet1k_v1'
model, preprocess = get_deepjuice_model(model_uid)

# here, we'll subset our images to simulate
# a slightly uneven split across batches:
image_subset = benchmark.image_paths[:160]

dataloader = get_data_loader(image_subset, preprocess, batch_size=64)

Initializing DeepJuice Benchmarks (JuicyFruits)
Loading DeepJuice NSDBenchmark: 
  Image Set: shared1000
  Voxel Set: ['EVC', 'OTC']
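
To see why 160 images give an uneven split, you can work the batch sizes out by hand. This is just a sketch in plain Python: `batch_sizes` is a hypothetical helper (not part of deepjuice), and it assumes the data loader keeps the final partial batch instead of dropping it.

```python
def batch_sizes(n_items, batch_size):
    """Return the size of each batch when the last partial batch is kept."""
    full, remainder = divmod(n_items, batch_size)
    return [batch_size] * full + ([remainder] if remainder else [])

sizes = batch_sizes(160, 64)
print(sizes)  # [64, 64, 32]
```

Two full batches of 64 plus a leftover batch of 32: that smaller last batch is the irregularity we want to exercise.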


**The Modified Extraction Procedure**

...works almost exactly the same as before, except there's a new keyword argument:

**batch_strategy**: 3 main variants
- *join* constructs empty tensors in full dimension, and fills them iteratively with each batch.
- *list* quite literally adds each batch of features to a list, and returns that list.
- *stack* wraps this list in a class that allows for easier manipulation of the underlying nested list.


*join* is the canonical version you're used to, and it is the most efficient. It turns out, though, that it can lead to a lot of degenerate cases if batch sizes are irregular, if transfer between tensor devices is slow, or for many other reasons. *list* does what it says on the tin: each new batch of features is dropped into a list, and that list is returned at the end of the function. *stack* is the version that directly uses the BatchedFeatureMaps class, but keep in mind that you can also initialize this class yourself from the output of the *list* version once it has been built.
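
The difference between *join* and *list* can be sketched with a toy example. Everything here is an illustration under loud assumptions: plain Python lists stand in for tensors, and `fake_batches`, `extract_join`, and `extract_list` are hypothetical helpers, not deepjuice functions.

```python
def fake_batches(n_inputs=160, batch_size=64, n_features=4):
    """Yield toy batches: each row is a list of n_features floats."""
    for start in range(0, n_inputs, batch_size):
        stop = min(start + batch_size, n_inputs)
        yield [[float(i)] * n_features for i in range(start, stop)]

def extract_join(batches, n_inputs, n_features):
    """'join'-style: preallocate the full output, then fill it batch by batch."""
    output = [[0.0] * n_features for _ in range(n_inputs)]
    cursor = 0
    for batch in batches:
        output[cursor:cursor + len(batch)] = batch
        cursor += len(batch)
    return output

def extract_list(batches):
    """'list'-style: keep every batch separate and return them all."""
    return [batch for batch in batches]

joined = extract_join(fake_batches(), n_inputs=160, n_features=4)
listed = extract_list(fake_batches())

print(len(joined))               # 160
print([len(b) for b in listed])  # [64, 64, 32]
```

The join-style version has to know the full output shape up front, which is exactly where irregular batches can bite; the list-style version defers that decision until after extraction.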

Let's have a look at the list version first:

In [5]:
# run the model on the first CUDA GPU; keep the extracted maps on the CPU:
devices = {'device': 'cuda:0', 'output_device': 'cpu'}

# notice that I'm invoking the get_feature_maps function directly
# this is the function internally called by the FeatureExtractor:
feature_map_list = get_feature_maps(model, dataloader, **devices,
                                   batch_strategy='list') # <- the new argument!

Extracting sample feature_maps with torchinfo (CUDA:0 to CPU)
Keeping 18 / 24 total maps (6 duplicates removed).


Feature Extraction (DataLoader):   0%|          | 0/3 [00:00<?, ?it/s]

So what does feature_map_list look like?

In [6]:
feature_map_list # notice the last of these feature_maps has only 32 inputs

[FeatureMaps Handle
  18 maps; 64 inputs; 262.60 MB 
  0 maps on GPU (0 duplicates),
 FeatureMaps Handle
  18 maps; 64 inputs; 262.60 MB 
  0 maps on GPU (0 duplicates),
 FeatureMaps Handle
  18 maps; 32 inputs; 131.30 MB 
  0 maps on GPU (0 duplicates)]

...in this example, feature_map_list is not a list of dictionaries as you might expect, but a list of what I call "FeatureMaps" handles. These behave almost EXACTLY like a dictionary, but don't pollute IPython with numerical printouts. They also give us access to a bunch of quick stats.

In [7]:
batch_one_maps = feature_map_list[0]

# note how these behave exactly like dictionaries:
for uid, feature_map in batch_one_maps.items():
    if 'Linear' in uid: # print linear layers:
        print(uid, [x for x in feature_map.shape])
    
# but also give you cool, quick stats:
print('\n Number of inputs:', batch_one_maps.get_input_size())

Linear-2-15 [64, 4096]
Linear-2-18 [64, 4096]
Linear-2-20 [64, 1000]

 Number of inputs: 64
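
If you're curious how a dictionary can stay this quiet when printed, here is a minimal sketch of the idea. `QuietMaps` is a hypothetical stand-in written for this example (it is not the deepjuice class), and nested lists stand in for tensors.

```python
class QuietMaps(dict):
    """Hypothetical handle: a dict whose repr is a summary, not its contents."""

    def get_input_size(self):
        # Every map shares the same leading (batch) dimension,
        # so the first map is enough to report the input count.
        if not self:
            return 0
        return len(next(iter(self.values())))

    def __repr__(self):
        return f"QuietMaps Handle\n  {len(self)} maps; {self.get_input_size()} inputs"

# 64 inputs, with 3 and 2 features respectively
maps = QuietMaps({
    "Linear-a": [[0.0] * 3 for _ in range(64)],
    "Linear-b": [[0.0] * 2 for _ in range(64)],
})

print(maps)               # summary only, no numerical printout
print(list(maps.keys()))  # ordinary dict methods still work
```

Subclassing `dict` keeps `.items()`, `.keys()`, and indexing intact, so only the printed representation changes.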


With this in mind, you can think of BatchedFeatureMaps as simply a wrapper around the FeatureMaps handles (don't worry, this is the most recursive this will get, I think...)

In [8]:
from deepjuice.extraction import BatchedFeatureMaps

batched_maps = BatchedFeatureMaps(feature_map_list)

In [9]:
batched_maps # the initial report

Batch Tensor Maps Handler
  Total Batch Count: 3
  # Total Inputs: 160
  # of Unique Feature Maps: 18
  No irregularities found.

In [10]:
# this returns an empty dictionary in this case:
batched_maps.get_irregular_shapes()

{}
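
For intuition, one way an irregularity check like this could work is to compare each map's per-input shape across batches. The sketch below is an assumption about the general idea, not the actual implementation behind `get_irregular_shapes`; `find_irregular_shapes` and the `Block-1` uid are made up for illustration.

```python
def find_irregular_shapes(batch_shapes):
    """Return {uid: set of per-input shapes} for uids whose shape varies.

    batch_shapes: list of {uid: shape tuple}, one dict per batch.
    The leading (batch) dimension is ignored, since the last batch
    is allowed to be smaller than the others.
    """
    seen = {}
    for shapes in batch_shapes:
        for uid, shape in shapes.items():
            seen.setdefault(uid, set()).add(tuple(shape[1:]))
    return {uid: dims for uid, dims in seen.items() if len(dims) > 1}

# Same trailing shape in every batch: nothing to report.
regular = [{"Linear-2-20": (64, 1000)}, {"Linear-2-20": (32, 1000)}]

# A sequence dimension that changes between batches (as can happen with LLMs).
ragged = [{"Block-1": (64, 12, 768)}, {"Block-1": (64, 9, 768)}]

print(find_irregular_shapes(regular))  # {}
print(find_irregular_shapes(ragged))
```

A smaller final batch (64 vs. 32 inputs) does not count as irregular here; only a change in the trailing dimensions does.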

Note that the batch_strategy argument can also be passed directly to a FeatureExtractor:

In [11]:
devices = {'device': 'cuda:0', 'output_device': 'cpu'}

extractor = FeatureExtractor(model, dataloader, **devices,
                             batch_strategy='stack',
                             max_memory_limit='16GB')

Extracting sample feature_maps with torchinfo (CUDA:0 to CPU)
FeatureExtractor Handle for AlexNet
  24 feature maps (+6 duplicates); 160 inputs
  Memory required for full extraction: 677.15 MB
  Memory usage limiting device set to: cpu
  Memory usage limit currently set to: 289.411 GB
  1 batch(es) required for current memory limit 
   Batch-001: 24 feature maps; 677.15 MB


In [12]:
for batched_feature_maps in tqdm(extractor, 'Global Progress'):
    feature_maps = batched_feature_maps.join_batches()
    for uid, feature_map in feature_maps.items():
        if 'Linear' in uid: # print linear layers:
            print(uid, [x for x in feature_map.shape])

Global Progress:   0%|          | 0/1 [00:00<?, ?it/s]

Joining all regularly shaped feature_maps into a set:


Joining Batched Feature Maps:   0%|          | 0/3 [00:00<?, ?it/s]

Linear-2-15 [160, 4096]
Linear-2-18 [160, 4096]
Linear-2-20 [160, 1000]
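
Conceptually, joining batches just means concatenating each map along the input dimension, in batch order. Here is a toy version of that idea; `join_toy_batches`, `make_batch`, and the `Linear-toy` uid are hypothetical names for this sketch, and nested lists again stand in for tensors.

```python
def join_toy_batches(batches):
    """Concatenate each uid's rows across batches, preserving batch order."""
    joined = {}
    for batch in batches:
        for uid, rows in batch.items():
            joined.setdefault(uid, []).extend(rows)
    return joined

def make_batch(n_inputs, n_features):
    """Build one toy batch holding a single feature map."""
    return {"Linear-toy": [[0.0] * n_features for _ in range(n_inputs)]}

batches = [make_batch(64, 5), make_batch(64, 5), make_batch(32, 5)]
joined_maps = join_toy_batches(batches)

rows = joined_maps["Linear-toy"]
print(len(rows), len(rows[0]))  # 160 5
```

This only works cleanly when every batch agrees on the per-input shape, which is presumably why the log above says it joins the regularly shaped feature maps.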


In [13]:
# this will also work if we add a flattening modification

extractor.modify_settings(flatten=True)

for batched_feature_maps in tqdm(extractor, 'Global Progress'):
    feature_maps = batched_feature_maps.join_batches()
    for uid, feature_map in feature_maps.items():
        if 'Conv2d' in uid: # print conv layers:
            print(uid, [x for x in feature_map.shape])

Global Progress:   0%|          | 0/1 [00:00<?, ?it/s]

Joining all regularly shaped feature_maps into a set:


Joining Batched Feature Maps:   0%|          | 0/3 [00:00<?, ?it/s]

Conv2d-2-1 [160, 193600]
Conv2d-2-4 [160, 139968]
Conv2d-2-7 [160, 64896]
Conv2d-2-9 [160, 43264]
Conv2d-2-11 [160, 43264]
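
The flattened sizes above are simply channels × height × width for each conv layer. As a quick sanity check, the arithmetic can be reproduced by hand; the per-layer shapes below are those of torchvision's AlexNet conv outputs, assuming the preprocessing yields 224 × 224 inputs.

```python
from math import prod

# Per-input (channels, height, width) of each AlexNet conv output,
# assuming 224 x 224 input images.
conv_shapes = {
    "Conv2d-2-1": (64, 55, 55),
    "Conv2d-2-4": (192, 27, 27),
    "Conv2d-2-7": (384, 13, 13),
    "Conv2d-2-9": (256, 13, 13),
    "Conv2d-2-11": (256, 13, 13),
}

# Flattening collapses everything except the input dimension.
flat_sizes = {uid: prod(shape) for uid, shape in conv_shapes.items()}

for uid, size in flat_sizes.items():
    print(uid, [160, size])
```

The first line printed is `Conv2d-2-1 [160, 193600]`, since 64 × 55 × 55 = 193600, matching the extractor's output.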
