In [None]:
import os, yaml, sys
import numpy as np
from sklearn.decomposition import IncrementalPCA
from torchvision.models.feature_extraction import (
    create_feature_extractor,
    get_graph_node_names,
)
import torch 

# Pick the config section for the current environment (env var MY_ENV, defaulting to "dev").
ENV = os.getenv("MY_ENV", "dev")
# Relative path: resolved against the current working directory of the kernel.
with open("../../config.yaml", "r") as f:
    config = yaml.safe_load(f)
# Environment-specific paths (e.g. "src_path", "results_path") used throughout the notebook.
paths = config[ENV]["paths"]
# Make the project-local packages importable; this must run before the
# general_utils / image_processing imports that follow.
sys.path.append(paths["src_path"])
from general_utils.utils import print_wise
from image_processing.utils import read_video

## Steps
(run in parallel over the model's layers)
1. Load videos until you have roughly 2–3× the batch size in frames (possibly more). For very long videos (e.g., some of the Arcaro ones — to be checked), take only the first ~20 seconds. Open question: can we estimate the optimal loading order from `get_video_dimensions`? Then shuffle the frames, split them evenly, and pass them to iPCA.
2. Normalize, format, and shuffle the frames.
    - If gaze-dependent: load the gaze data, upsample the video in time, and extract the gaze-dependent spatial window.
3. Fit the incremental PCA (iPCA).
4. Save the eigenvectors and eigenvalues.


In [None]:
def get_layer_output_shape(feature_extractor, layer_name, input_size=(3, 224, 224), device=None):
    """
    get_layer_output_shape
    Computes the output shape (excluding batch size) of a specific layer
    from a given PyTorch feature extractor by running it on a single random
    dummy image of shape (1, *input_size) -- (1, 3, 224, 224) by default.
    The forward pass is done in eval mode (so BatchNorm running statistics are
    not updated by the random probe); the module's previous train/eval mode is
    restored afterwards, even if the forward pass raises.
    INPUT:
    - feature_extractor: torch.nn.Module -> A PyTorch model (typically a feature extractor created via
                                            torchvision.models.feature_extraction.create_feature_extractor)
                                            which outputs a dictionary of intermediate activations.
    - layer_name: str -> The name of the layer for which the output shape is desired. This must be one
                         of the keys returned by the feature_extractor.
    - input_size: Tuple(Int) -> Shape of the dummy input excluding the batch dimension.
                                Defaults to (3, 224, 224).
    - device: torch.device | str | None -> Device on which the dummy input is created. If None, the
                                           device of the extractor's first parameter (or, failing that,
                                           first buffer) is used, falling back to CPU for stateless modules.

    OUTPUT:
    - tmp_shape: torch.Size -> (a tuple subclass) The shape of the output tensor from the specified layer,
                               excluding the batch dimension. For example, (512, 7, 7) for a convolutional
                               layer or (768,) for a transformer block.

    Example Usage:
        >>> from torchvision.models import resnet18
        >>> from torchvision.models.feature_extraction import create_feature_extractor
        >>> model = resnet18(weights="DEFAULT").eval()
        >>> feat_ext = create_feature_extractor(model, return_nodes=["layer1.0.relu_1"])
        >>> shape = get_layer_output_shape(feat_ext, "layer1.0.relu_1")
        >>> print(tuple(shape))
        (64, 56, 56)
    """
    if device is None:
        # Run where the module's weights live, so a CPU extractor on a CUDA
        # machine does not receive a CUDA input (device-mismatch error).
        first_tensor = next(feature_extractor.parameters(), None)
        if first_tensor is None:
            first_tensor = next(feature_extractor.buffers(), None)
        device = first_tensor.device if first_tensor is not None else torch.device("cpu")

    was_training = feature_extractor.training
    # Eval mode: the probe must not shift BatchNorm running stats with noise, and
    # train-mode BatchNorm can fail on a batch of one.
    feature_extractor.eval()
    try:
        with torch.no_grad():
            in_proxy = torch.randn(1, *input_size, device=device)
            tmp_shape = feature_extractor(in_proxy)[layer_name].shape[1:]
    finally:
        feature_extractor.train(was_training)
    return tmp_shape


def _row_chunks(arrays, min_rows):
    """
    _row_chunks
    Regroups a stream of 2-D arrays into row-chunks of at least `min_rows` rows.
    Arrays are concatenated along axis 0 until the running row count reaches `min_rows`.
    The last complete chunk is held back one step, so a short remainder at the end of the
    stream is merged into it instead of being emitted on its own. Only when the whole
    stream has fewer than `min_rows` rows is the (single) yielded chunk short.
    INPUT:
    - arrays: Iterable[np.ndarray] -> 2-D arrays with the same number of columns.
    - min_rows: int -> minimum number of rows per yielded chunk.
    OUTPUT:
    - generator of np.ndarray chunks; yields nothing if `arrays` is empty.
    """
    ready = None       # last complete chunk, yielded once we know it is not the final one
    pending = []       # arrays collected since the last complete chunk
    pending_rows = 0
    for arr in arrays:
        pending.append(arr)
        pending_rows += arr.shape[0]
        if pending_rows >= min_rows:
            if ready is not None:
                yield ready
            ready = np.concatenate(pending, axis=0)
            pending, pending_rows = [], 0
    tail = ([] if ready is None else [ready]) + pending
    if tail:
        yield np.concatenate(tail, axis=0)


def ipca_core(paths, rank, layer_name, model_name, n_components, model, loader, device, save_prefix="imagenet_val"):
    """
    ipca_core
    Fits an sklearn IncrementalPCA on the flattened activations of one layer of `model`,
    streaming the inputs from `loader`, and pickles the fitted estimator into
    paths["results_path"]. If the target file already exists, nothing is computed.
    INPUT:
    - paths: dict -> must contain "results_path", the directory where the PCA model is saved.
    - rank: int -> worker identifier, forwarded to print_wise for logging.
    - layer_name: str -> graph node name passed to create_feature_extractor.
    - model_name: str -> only used to build the output file name.
    - n_components: int -> requested number of PCs; capped at the layer's number of features.
                           The file name always uses this requested value.
    - model: torch.nn.Module -> network whose layer is probed.
                                NOTE(review): presumably already in eval mode; confirm at the call site.
    - loader: iterable of (inputs, labels) pairs -> labels are ignored; inputs are moved to `device`.
    - device: torch.device | str -> device for the feature extractor and the inputs.
    - save_prefix: str -> file-name prefix; defaults to the previously hard-coded "imagenet_val".
    OUTPUT:
    - None. Side effect: writes
      "<save_prefix>_<model_name>_<layer_name>_pca_model_<n_components>_PCs.pkl" (a pickle).
    RAISES:
    - ValueError if the loader yields no samples, or fewer samples in total than the
      number of components to extract.
    """
    import pickle  # stdlib; the original called joblib.dump without ever importing joblib

    save_name = f"{save_prefix}_{model_name}_{layer_name}_pca_model_{n_components}_PCs.pkl"
    path = os.path.join(paths["results_path"], save_name)
    if os.path.exists(path):
        print_wise(f"{path} already exists", rank=rank)
        return

    print_wise(f"Fitting PCA for layer: {layer_name}", rank=rank)
    feature_extractor = create_feature_extractor(
        model, return_nodes=[layer_name]
    ).to(device)
    tmp_shape = get_layer_output_shape(feature_extractor, layer_name)
    n_features = int(np.prod(tmp_shape))  # [C, H, W] -> C*H*W
    # IncrementalPCA cannot extract more components than there are features.
    n_components_layer = min(n_features, n_components)
    pca = IncrementalPCA(n_components=n_components_layer)

    def _batch_features():
        # One (batch_size, n_features) array of flattened activations per loader batch.
        for counter, (inputs, _) in enumerate(loader, start=1):
            print_wise(f"starting batch {counter}", rank=rank)
            with torch.no_grad():
                feats = feature_extractor(inputs.to(device))[layer_name]
                # reshape (not view): some layers emit non-contiguous activations.
                yield feats.reshape(feats.size(0), -1).cpu().numpy()

    n_seen = 0
    # partial_fit rejects any call with fewer samples than n_components, so
    # regroup batches (including a short final batch) into large-enough chunks.
    for chunk in _row_chunks(_batch_features(), n_components_layer):
        if chunk.shape[0] < n_components_layer:
            raise ValueError(
                f"only {chunk.shape[0]} samples available for layer {layer_name}; "
                f"IncrementalPCA needs at least n_components={n_components_layer}"
            )
        pca.partial_fit(chunk)
        n_seen += chunk.shape[0]
    if n_seen == 0:
        raise ValueError(f"loader yielded no samples for layer {layer_name}")

    # Write to a temp file and swap it in atomically: the exists-check above
    # must never mistake a half-written file from a crashed run for a result.
    tmp_path = path + ".tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(pca, f, protocol=pickle.HIGHEST_PROTOCOL)
    os.replace(tmp_path, path)
    print_wise(f"Saved PCA for {layer_name} at {path}", rank=rank)

