In [None]:
# Show the length of the test loader next to the length of the test dataset.
# (For a torch DataLoader, len() is the number of batches — assumes test_loader
# is one; TODO confirm.)
loader_len = len(test_loader)
dataset_len = len(test_dataset)
loader_len, dataset_len

In [None]:
# Unwrap the underlying network held by the `model` wrapper object.
resnet_model = model.model

# Direct (top-level) child modules of the network, materialised as a list so
# they can be indexed and measured with len().
model_children = [*resnet_model.children()]

# Parallel accumulators for the conv-layer collection step:
#   conv_layers[k]   -> the k-th collected conv layer
#   model_weights[k] -> that layer's weight tensor
conv_layers = []
model_weights = []

In [None]:
from torch import nn

# Collect the convolution layers (and their weights) from the network's children.
#
# Both lists are rebuilt from scratch here so that re-running this cell cannot
# append the same layers a second time; previously they were only created in an
# earlier cell and a re-run silently doubled them.
conv_layers = []

for child in model_children:
    if isinstance(child, nn.Conv2d):
        # A convolution sitting directly at the top level of the network.
        conv_layers.append(child)
    elif isinstance(child, nn.Sequential):
        # Look exactly one level inside every block of the Sequential: convs
        # that are direct children of a block are collected, anything nested
        # deeper is intentionally not visited (the traversal depth is kept as
        # it was, because later cells chain these layers one after another).
        for block in child:
            conv_layers.extend(
                sub for sub in block.children() if isinstance(sub, nn.Conv2d)
            )

# Weight tensors in the same order as `conv_layers`.
model_weights = [conv_layer.weight for conv_layer in conv_layers]

counter = len(conv_layers)
print(f"Total convolutional layers: {counter}")

In [None]:
# Print one line per collected convolution: the layer's repr followed by the
# shape of its weight tensor. The two lists are walked in lockstep.
for conv, weight in zip(conv_layers, model_weights):
    kernel_shape = weight.shape
    print(f"CONV: {conv} ====> SHAPE: {kernel_shape}")

In [None]:
# Optional visualisation of the first collected conv layer's filters.
skip = True  # set to False to render the filter grid
if not skip:
    # One panel per filter, laid out on an 8x8 grid (64 panels). The slice caps
    # the loop at 64 because a subplot index beyond the grid size raises
    # ValueError.
    fig = plt.figure(figsize=(20, 17))
    for i, kernel in enumerate(model_weights[0][:64]):
        ax = fig.add_subplot(8, 8, i + 1)
        # Only input channel 0 of each filter is drawn.
        ax.imshow(kernel[0, :, :].cpu().detach(), cmap='gray')
        ax.axis('off')
    plt.show()

In [None]:
# Push one test image through the collected conv layers and plot the maps each
# layer produces.
#
# NOTE: only the layers in `conv_layers` are chained, each fed with the previous
# one's output. Anything else the network does between convolutions is not
# applied here, so these maps are not the activations of a real forward pass.
MAX_MAPS_PER_LAYER = 64  # number of panels in the grid below
GRID_SIDE = 8            # GRID_SIDE * GRID_SIDE == MAX_MAPS_PER_LAYER

# Send the input to wherever the conv weights live instead of assuming 'cuda'.
device = conv_layers[0].weight.device

# Reset first so a failed run cannot silently reuse `outputs` from an earlier one.
outputs = None
with torch.no_grad():
    for batch in test_loader:
        # assumes `batch` is a sequence of (image, label) pairs — TODO confirm
        # against the loader's collate function
        if len(batch) == 0:
            continue
        # Keep samples whose first dimension matches the first sample's, then
        # drop the first two of those.
        reference_dim = batch[0][0].shape[0]
        keep_idx = [
            i for i, sample in enumerate(batch)
            if sample[0].shape[0] == reference_dim
        ][2:]
        if not keep_idx:
            # Nothing usable in this batch; try the next one.
            continue

        images = [batch[i][0] for i in keep_idx]
        labels = [torch.tensor(batch[i][1]) for i in keep_idx]

        # Only the first kept sample is visualised.
        image = images[0].to(device).float()
        label = labels[0].to(device).float()

        # Chain the conv layers: each consumes the previous layer's output.
        results = []
        activation = image.unsqueeze(0)  # add a leading batch dimension
        for conv_layer in conv_layers:
            activation = conv_layer(activation)
            results.append(activation)

        outputs = results  # same list object as `results`, not a copy
        break

if outputs is None:
    raise RuntimeError(
        "No usable image found in test_loader; cannot compute feature maps."
    )

# Plot up to MAX_MAPS_PER_LAYER maps from each layer (later layers may have more
# channels than are shown).
for num_layer, layer_output in enumerate(outputs):
    layer_viz = layer_output[0].detach().cpu()  # drop the batch dimension
    print(layer_viz.size())

    fig = plt.figure(figsize=(30, 30))
    for i, feature_map in enumerate(layer_viz[:MAX_MAPS_PER_LAYER]):
        ax = fig.add_subplot(GRID_SIDE, GRID_SIDE, i + 1)
        ax.imshow(feature_map, cmap='gray')
        ax.axis("off")
    # To persist a figure as well as display it, call
    # fig.savefig(f"../outputs/layer_{num_layer}.png") before plt.show().
    plt.show()
    plt.close(fig)  # release the figure so memory does not grow per layer