# PyTorch Model Slicing:

In [1]:
from torch.utils.data import random_split, DataLoader, Subset, TensorDataset
from torchvision import datasets, transforms
from ultralytics import YOLO
import torch.nn as nn
import torch
from torchinfo import summary

### Finetune YOLOv8 pose:

In [2]:
# Preprocessing pipeline applied to every image loaded from disk.
data_transforms = transforms.Compose([
    transforms.Resize((640, 480)), # Resize images to 640x480 (height x width)
    transforms.ToTensor() # Convert to tensor
])

In [8]:
data_path = 'example_images'
# ImageFolder expects one sub-directory per class below `data_path`.
own_dataset = datasets.ImageFolder(root=data_path, transform=data_transforms)

train_loader = DataLoader(own_dataset, batch_size=1, shuffle=False)

# Image tensor of the first batch only. `next(iter(...))` stops after one batch
# instead of loading and transforming the whole dataset just to index into it.
image = next(iter(train_loader))[0]

In [4]:
# Build the ultralytics YOLO wrapper from the "yolov8n-pose.pt" checkpoint.
yolo_model=YOLO("yolov8n-pose.pt")

In [21]:
num_classes = 3  # number of scores produced by the classifier head
# Length of the flattened YOLO output for one example image; matches the
# torch.Size([1083600]) reported by the flattening cells in this notebook.
num_features = 1083600
h1 = 1024  # width of the first hidden layer
h2 = 512  # width of the second hidden layer (the third uses h2 // 2)

In [67]:
class FinetunedYOLO(nn.Module):
    """MLP classification head stacked on top of a pretrained YOLO network.

    The input is run through the wrapped network, every tensor found in its
    (possibly nested) output is flattened and concatenated into one 1-D
    feature vector, and an MLP maps that vector to ``num_classes`` scores.

    NOTE: flattening merges the batch dimension into the feature vector, so
    all samples passed in one call end up in a single vector whose length
    must equal ``input_features``.

    NOTE(review): the element count presumably depends on the input
    resolution and on whether the wrapped network is in train or eval mode --
    confirm before training.

    Args:
        yolo_model: Object whose ``.model`` attribute is the ``nn.Module``
            used as feature extractor (e.g. an ultralytics ``YOLO``).
        input_features: Length of the flattened feature vector.
        num_classes: Number of output scores.
        h1: Width of the first hidden layer.
        h2: Width of the second hidden layer; the third uses ``h2 // 2``.
    """

    def __init__(self, yolo_model, input_features, num_classes, h1, h2):
        super().__init__()
        # Despite the attribute name this wraps the *entire* network: the
        # previous slice ``list(yolo_model.model.modules())[0:1]`` picked the
        # module itself, which ``modules()`` yields first. The name is kept
        # so attribute access and state-dict keys stay valid.
        self.first_yolo_block = nn.Sequential(yolo_model.model)
        self.classifier = nn.Sequential(
            nn.Linear(input_features, h1),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(h1, h2),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(h2, h2 // 2),
            nn.ReLU(),
            nn.Linear(h2 // 2, num_classes)
        )

    def recursive_flatten(self, tensor_struct):
        """Return every tensor in a nested list/tuple as a list of 1-D tensors.

        Lists and tuples are walked depth-first in order; anything that is
        neither a tensor nor a list/tuple is ignored.
        """
        if isinstance(tensor_struct, (list, tuple)):
            tensors = []
            for item in tensor_struct:
                tensors.extend(self.recursive_flatten(item))
            return tensors
        if isinstance(tensor_struct, torch.Tensor):
            return [torch.flatten(tensor_struct)]
        return []

    def flatten_tensors(self, tensor_struct):
        """Concatenate all tensors of a nested structure into one 1-D tensor."""
        # The pieces are already 1-D, so their concatenation is 1-D as well and
        # needs no further flattening.
        return torch.cat(self.recursive_flatten(tensor_struct))

    def forward(self, x):
        """Return the ``num_classes`` scores for input ``x``."""
        x = self.first_yolo_block(x)
        flattened = self.flatten_tensors(x)
        # Debug aid: report the feature-vector length so it can be compared
        # with the ``input_features`` the classifier was built for.
        print(flattened.shape)
        x = self.classifier(flattened)

        return x
        


In [68]:
# Build the classifier on top of the pretrained YOLO model with the sizes defined above.
newmodel = FinetunedYOLO(yolo_model, num_features, num_classes, h1, h2)

In [69]:
# Single forward pass on the example image; the model prints the flattened feature size.
out = newmodel(image)

torch.Size([1083600])


In [70]:
# Show the raw class scores returned for the example image.
out

tensor([ 6.6704,  5.2526, -5.9985], grad_fn=<ViewBackward0>)

### Data preparation

In [None]:
# The first child of the wrapped model is the indexable container holding its layers.
layer_stack = list(yolo_model.model.children())[0]

# Layers 0 to 21, i.e. everything before the pose head (layer 22).
yolo_backbone = nn.Sequential(*layer_stack[:22])
pose_head = nn.Sequential(*layer_stack[22:23])

# Every child of the head module except its last one
# (presumably the head's cv2, cv3 and dfl sub-modules -- TODO confirm).
head_children = list(pose_head[0].children())
pose_stem = nn.Sequential(*head_children[:-1])

# Backbone layers followed by the truncated head, as one flat Sequential.
yolo_pose = nn.Sequential(*yolo_backbone, pose_stem)

# Summaries of the full model and of the sliced variant, for comparison.
display(summary(yolo_model))
summary(yolo_pose)

In [None]:
# The first child of the wrapped model is the indexable container holding its layers.
yolo_layers = list(yolo_model.model.children())[0]

# Layers 0 to 21 serve as the feature extractor.
yolo_features = nn.Sequential(*yolo_layers[:22])

# Everything from layer 22 onwards.
yolo_pose_layer = nn.Sequential(*yolo_layers[22:])

# All children of the first remaining layer except its last one.
pose_layer_children = list(yolo_pose_layer[0].children())
yolo_pose_stem = nn.Sequential(*pose_layer_children[:-1])

Print the output size of every layer of the model:

In [None]:
def print_size(module, input, output):
    """Forward hook that prints the size of every tensor a module outputs.

    Meant to be passed to ``nn.Module.register_forward_hook``. Nested outputs
    are walked recursively: entering a tuple adds one nesting level, while a
    list is traversed at the level it was found on. The nesting level is
    threaded through the recursion instead of a global counter, so all
    elements of a tuple are reported at the same depth and nothing leaks from
    one hook call into the next.

    Args:
        module: Module whose forward pass just ran; its class name is printed.
        input: Positional inputs of the forward call (unused).
        output: Output of the forward call; a tensor or nested tuples/lists.

    Returns:
        None, so the module's output is passed on unchanged.
    """
    _print_nested_sizes(module, output, 0)


def _print_nested_sizes(module, output, level):
    """Print one line per leaf of ``output``, tagged with its nesting level."""
    if isinstance(output, tuple):
        for element in output:
            _print_nested_sizes(module, element, level + 1)
    elif isinstance(output, list):
        for element in output:
            _print_nested_sizes(module, element, level)
    elif isinstance(output, torch.Tensor):
        print(f"depth: {level}, {module.__class__.__name__}: {output.size()}")
    else:
        # Non-tensor leaves (e.g. None) have no size; report their type
        # instead of failing inside the hook.
        print(f"depth: {level}, {module.__class__.__name__}: {type(output).__name__}")

# Attach the size-printing hook to every direct child of `yolo_features`.
# The returned handles are kept so the hooks can be detached again with
# `handle.remove()`; re-running this cell without removing them registers the
# hook a second time and duplicates every printed line.
hook_handles = [
    layer.register_forward_hook(print_size) for layer in yolo_features.children()
]

#### Example: flattening the tensors
(this works)

In [41]:
def flatten_tensors(tensor_struct):
    """Return all tensors of a nested list/tuple as a flat list of 1-D tensors.

    Containers are walked depth-first in order. A tensor becomes a one-item
    list holding its flattened form; any other object contributes nothing.
    """
    if isinstance(tensor_struct, torch.Tensor):
        return [torch.flatten(tensor_struct)]
    if not isinstance(tensor_struct, (list, tuple)):
        # Neither a tensor nor a container: nothing to collect.
        return []
    return [
        flat_piece
        for member in tensor_struct
        for flat_piece in flatten_tensors(member)
    ]

In [52]:
# `modules()` yields the module itself first, so this one-element slice wraps
# the whole model rather than only its first layer.
first_yolo_block = nn.Sequential(*list(yolo_model.model.modules())[:1])
output = first_yolo_block(image)

# Join every tensor of the nested output into a single 1-D vector and show its length.
flattened = torch.cat(flatten_tensors(output)).flatten()
flattened.size()

torch.Size([1083600])