# PyTorch Model Inspection

In [113]:
from torch.utils.data import random_split, DataLoader, Subset, TensorDataset
from torchvision import datasets, transforms
from ultralytics import YOLO
import torch.nn as nn
import torch
from torchinfo import summary

### Get the model output of every layer

In [2]:
# Preprocessing for the example images: resize, then convert to a tensor.
# torchvision's Resize interprets a 2-tuple as (height, width), so this yields
# 640-high by 480-wide images (not 640x640).
data_transforms = transforms.Compose([
    transforms.Resize((640, 480)),  # (height, width) = (640, 480)
    transforms.ToTensor(),  # PIL image -> float tensor
])

In [69]:
# Load the example images with ImageFolder and take the first batch only.
data_path = 'example_images'
own_dataset = datasets.ImageFolder(root=data_path, transform=data_transforms)

train_loader = DataLoader(own_dataset, batch_size=1, shuffle=False)

# Each batch is (images, labels); keep the image tensor of the first batch.
# next(iter(...)) fetches just that batch instead of loading and keeping every
# batch in memory as the previous list comprehension did.
image = next(iter(train_loader))[0]

In [105]:
yolo_model = YOLO("yolov8n-pose.pt")

# The first child of the wrapped torch network holds the ordered YOLO layers.
# Build it once instead of re-listing the children for every slice.
yolo_layers = list(yolo_model.model.children())[0]
# Layers 0..21 come before the pose head; layer 22 is the pose head.
POSE_HEAD_INDEX = 22

yolo_features = nn.Sequential(*yolo_layers[:POSE_HEAD_INDEX])
yolo_pose_layer = nn.Sequential(*yolo_layers[POSE_HEAD_INDEX:])

# All children of the pose head except the last one registered.
yolo_pose_stem = nn.Sequential(*list(yolo_pose_layer[0].children())[:-1])

# NOTE(review): nn.Sequential feeds each layer only the previous layer's output.
# If some YOLO layers expect inputs routed from earlier layers (e.g. concatenations),
# calling yolo_features(x) directly may fail or differ from the real model. These
# containers share the original layer objects, so hooks registered on them still
# fire when the full model runs. Confirm before using them for a forward pass.

In [80]:
depth = 0  # legacy module-level counter; print_size now tracks nesting itself


def print_size(module, input, output, depth=0):
    """Forward hook that prints the size of every tensor in a module's output.

    Register it with ``module.register_forward_hook(print_size)``. Tuples and lists
    in ``output`` are walked recursively, and each tensor is printed together with
    its nesting depth: 0 when the output itself is a tensor, 1 inside one
    container, 2 inside a container nested in a container, and so on.

    Args:
        module: The module whose forward pass produced ``output``.
        input: The module's forward inputs (unused; required by the hook signature).
        output: A tensor, or an arbitrarily nested tuple/list of tensors.
        depth: Current nesting level; leave the default when registering as a hook.
    """
    name = module.__class__.__name__
    if isinstance(output, (tuple, list)):
        # Pass the depth down explicitly so sibling elements share the same level
        # (the old global counter drifted after the first element).
        for element in output:
            print_size(module, input, element, depth + 1)
    elif isinstance(output, torch.Tensor):
        print(f"depth: {depth}, {name}: {output.size()}")
    else:
        # Non-tensor leaves (e.g. None) have no .size(); report their type instead.
        print(f"depth: {depth}, {name}: {type(output).__name__}")

# Attach print_size to every layer of yolo_features so each forward pass prints
# the output sizes. The handles are kept so the hooks can be detached later with
# handle.remove(). Hooks from a previous run of this cell are removed first, so
# re-running it does not print everything twice.
for handle in globals().get("hook_handles", []):
    handle.remove()
hook_handles = [layer.register_forward_hook(print_size) for layer in yolo_features.children()]

In [None]:
class FinetunedYOLO(nn.Module):
    """YOLO network followed by a small trainable classification head.

    The wrapped network's output (a tensor or a nested tuple/list of tensors) is
    flattened into one 1-D vector and passed to ``classifier``.

    NOTE(review): torch.flatten flattens every dimension, including the batch
    dimension, so a batch of N images becomes one long vector. That only makes
    sense for batch_size=1; confirm this is intended.
    """

    def __init__(self, yolo_model, h1, h2):
        """Build the model.

        Args:
            yolo_model: Object exposing the torch network as ``.model`` (such as an
                ``ultralytics.YOLO`` instance).
            h1: Hidden width of the classifier head (assumed meaning, see below).
            h2: Output width of the classifier head (assumed meaning, see below).
        """
        # Subclass nn.Module rather than nn.ModuleList: an empty ModuleList has
        # len() == 0, which makes `bool(model)` False, and its list API is unused.
        super().__init__()
        # modules() yields the root module first, so [0:1] wraps the *entire* YOLO
        # network, not just its first block. Its parameters are registered here too.
        self.first_yolo_block = nn.Sequential(*list(yolo_model.model.modules())[0:1])
        # NOTE(review): the original head was an `...` placeholder, which
        # nn.Sequential rejects with TypeError. This two-layer MLP is an assumed
        # stand-in; LazyLinear infers the flattened feature size on the first call.
        self.classifier = nn.Sequential(
            nn.LazyLinear(h1),
            nn.ReLU(),
            nn.Linear(h1, h2),
        )

    def forward(self, x):
        """Run YOLO on ``x``, flatten all of its output tensors, and classify them."""
        x = self.first_yolo_block(x)
        x = self.flatten_tensors(x)
        return self.classifier(x)

    def recursive_flatten(self, tensor_struct):
        """Recursively collect tensors from a nested structure.

        Returns a list of 1-D tensors, one per tensor found in ``tensor_struct``,
        in depth-first, left-to-right order. Items that are neither tensors nor
        lists/tuples (e.g. None) are skipped.
        """
        if isinstance(tensor_struct, torch.Tensor):
            return [torch.flatten(tensor_struct)]
        if isinstance(tensor_struct, (list, tuple)):
            # Recurse into recursive_flatten itself. Calling flatten_tensors here
            # returned one tensor, whose 0-d elements torch.cat cannot concatenate.
            return [t for item in tensor_struct for t in self.recursive_flatten(item)]
        return []

    def flatten_tensors(self, tensor_struct):
        """Concatenate every tensor in ``tensor_struct`` into a single 1-D tensor.

        torch.cat raises if ``tensor_struct`` contains no tensors at all.
        """
        # All parts are already 1-D, so the concatenation is 1-D as well.
        return torch.cat(self.recursive_flatten(tensor_struct))

In [150]:
def _collect_tensors(struct):
    """Recursively gather every tensor in a nested list/tuple, each flattened to 1-D."""
    if isinstance(struct, torch.Tensor):
        return [torch.flatten(struct)]
    if isinstance(struct, (list, tuple)):
        return [t for item in struct for t in _collect_tensors(item)]
    return []  # non-tensor leaves (e.g. None) are ignored


# modules() yields the root module first, so this wraps the *whole* network,
# not just its first block.
first_yolo_block = nn.Sequential(*list(yolo_model.model.modules())[0:1])
with torch.no_grad():  # inspection only: no autograd graph needed
    output = first_yolo_block(image)

# Concatenate all output tensors into one 1-D vector.
flattened = torch.cat(_collect_tensors(output))
flattened.size()


tensor([11.1705, 18.5389, 27.3171,  ..., -3.3760, -3.2733, -3.0503])

In [None]:
# Rebuild a pose model from its parts (the layers before the pose head, then the
# pose head's children minus the last) and compare its summary with the original.
yolo_layers = list(yolo_model.model.children())[0]
POSE_HEAD_INDEX = 22  # layers 0..21 precede the pose head; layer 22 is the pose head

yolo_backbone = nn.Sequential(*yolo_layers[:POSE_HEAD_INDEX])
pose_head = nn.Sequential(*yolo_layers[POSE_HEAD_INDEX:POSE_HEAD_INDEX + 1])
# All children of the pose head except the last one registered.
pose_stem = nn.Sequential(*list(pose_head[0].children())[:-1])
yolo_pose = nn.Sequential(*yolo_backbone, pose_stem)

# NOTE(review): nn.Sequential chains modules linearly. yolo_pose is used only for
# summaries here and may not be runnable as a forward pass; confirm before calling
# it. summary() is given no input_size/input_data, so it reports parameters only.
# `display` is the IPython/Jupyter display helper; it is not imported in this file.
display(summary(yolo_model))
summary(yolo_pose)