# PyTorch Model Inspection

In [None]:
from torch.utils.data import random_split, DataLoader, Subset, TensorDataset
from torchvision import datasets, transforms
from ultralytics import YOLO
import torch

### Get the model output of every layer

In [None]:
# Preprocessing applied to every example image: resize, then convert to a tensor.
# NOTE(review): an earlier comment said 640x640 while the code uses (640, 480);
# confirm which size is intended. Both sides are multiples of 32, which the
# YOLOv8 strides need for tensor inputs.
data_transforms = transforms.Compose([
    transforms.Resize((640, 480)),  # Resize to (height, width) = (640, 480)
    transforms.ToTensor(),  # Convert the PIL image to a CHW float tensor in [0, 1]
])

In [None]:
# Load the example images. ImageFolder expects one sub-directory per class
# under `data_path` and yields (image, class_index) pairs.
data_path = 'example_images'
own_dataset = datasets.ImageFolder(root=data_path, transform=data_transforms)

train_loader = DataLoader(own_dataset, batch_size=1, shuffle=False)

# Take only the first batch instead of materializing every batch in a list.
# With shuffle=False this is the same batch the list version returned.
# `image` is [images, labels]: image[0] is the image batch, image[1] the class indices.
image = next(iter(train_loader))


In [None]:
yolo_model = YOLO('yolov8n-pose.pt')


def print_size(module, input, output, depth=0):
    """Forward hook: print the size of every tensor in a module's output.

    Args:
        module: The module whose forward pass just ran.
        input: The module's positional inputs (unused; part of the hook signature).
        output: The module's output: a tensor, or tuples/lists nesting tensors.
        depth: How many tuple/list levels ``output`` is nested in (0 = top level).
            PyTorch calls the hook with three arguments, so the default applies.
    """
    if isinstance(output, (tuple, list)):
        for element in output:
            print_size(module, input, element, depth + 1)
    elif isinstance(output, torch.Tensor):
        print(f"depth: {depth}, {module.__class__.__name__}: {output.size()}")
    else:
        # Non-tensor leaves (e.g. None) have no .size(); report their type instead.
        print(f"depth: {depth}, {module.__class__.__name__}: {type(output).__name__}")


# Hook every submodule, run one image through the model, then remove the hooks
# so they do not keep firing in later cells (summary, manual forward pass) or
# pile up when this cell is re-run.
hook_handles = [layer.register_forward_hook(print_size) for layer in yolo_model.modules()]
try:
    results = yolo_model(image[0])
finally:
    for handle in hook_handles:
        handle.remove()

results

#### Print a summary of the YOLO model

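**Warning:** calling `torchinfo.summary` on the ultralytics `YOLO` wrapper itself starts a training run. torchinfo switches the model's mode with `train()`/`eval()`, and the wrapper's `train()` method launches training. Summarize the underlying torch module, `yolo_model.model`, instead.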

In [None]:
from torchinfo import summary

# Summarize the underlying torch module (`yolo_model.model`), not the ultralytics
# wrapper: torchinfo calls .train()/.eval() on the model it receives, and the
# wrapper's train() launches a training run. torchinfo's input_size must
# include the batch dimension.
summary(yolo_model.model, input_size=(1, 3, 320, 320))

## Alternative:

- Inspect the output just before YOLO's pose head.
- Cut the pose head off the YOLO model.


In [None]:
height = 320
width = 320
stop_layer = 21  # 0-based index of the layer whose output to inspect

# `yolo_model` is the ultralytics wrapper. Iterating its children() yields the
# whole network as a single module, so a per-layer loop over it never reaches
# "layer 21". The individual layers are in `yolo_model.model.model`.
layers = yolo_model.model.model

# Build the random NCHW input on the same device/dtype as the weights, in case
# an earlier predict call moved the model (e.g. to a GPU). Height and width
# should be multiples of the model's maximum stride (32 for YOLOv8).
reference_param = next(layers.parameters())
dummy_input = torch.randn(
    1, 3, height, width, device=reference_param.device, dtype=reference_param.dtype
)

# YOLO is not a plain chain of layers. Each layer's `.f` ("from") attribute says
# where its input comes from: -1 = previous layer; an int or a list = earlier
# layer indices (e.g. Concat layers). Replay that routing and stop once
# `stop_layer` has run.
layer_outputs = []
temp_output = dummy_input
with torch.no_grad():
    for idx, layer in enumerate(layers):
        if layer.f != -1:
            if isinstance(layer.f, int):
                temp_output = layer_outputs[layer.f]
            else:
                temp_output = [temp_output if j == -1 else layer_outputs[j] for j in layer.f]
        temp_output = layer(temp_output)
        layer_outputs.append(temp_output)
        if idx == stop_layer:
            break
    else:
        raise IndexError(f"model has no layer {stop_layer} (it has {len(layers)} layers)")


def _print_sizes(obj, label):
    """Print the shape of every tensor in `obj`, recursing into tuples/lists."""
    if isinstance(obj, torch.Tensor):
        print(f"{label}: {tuple(obj.shape)}")
    elif isinstance(obj, (tuple, list)):
        for item_idx, item in enumerate(obj):
            _print_sizes(item, f"{label}[{item_idx}]")
    else:
        print(f"{label}: {type(obj).__name__}")


output = temp_output
_print_sizes(output, f"Output size of layer {stop_layer}")