In [None]:
from colorviz.birds_dataset.data import ImageDataset
from colorviz.conv_color.config_objects import ImageDatasetCfg
import keras
import numpy as np
import tensorflow as tf
import pickle
import torch
import torch.nn as nn
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
import torchinfo

from colorviz.conv_color.visualizations import *


%load_ext autoreload
# Mode 1: before running code, reload only the modules registered with %aimport.
%autoreload 1
# Register the project modules that are being edited alongside this notebook.
%aimport colorviz.conv_color.config_objects,colorviz.birds_dataset.data,colorviz.conv_color.visualizations
# With no arguments, %aimport lists which modules are / are not auto-reloaded.
%aimport

In [None]:
# Load the trained Keras EfficientNetB0 whose weights are ported to PyTorch below.
# compile=False: only the architecture and weights are needed, so the saved
# training configuration (optimizer and the custom `F1_score` metric the file
# presumably references) is not deserialised. This replaces the placeholder
# custom_objects={'F1_score': 'F1_score'}, which mapped the metric name to a
# string rather than to an actual metric implementation.
model = keras.models.load_model(
    "bird_data/EfficientNetB0-525-(224 X 224)- 98.97.h5",
    compile=False,
)

In [None]:
# torchvision EfficientNet-B0 initialised with the IMAGENET1K_V1 pretrained
# weights; its Conv2d/Linear/BatchNorm2d modules receive the Keras weights.
net = efficientnet_b0(
    weights=EfficientNet_B0_Weights.IMAGENET1K_V1,
    progress=True,
)

In [None]:
# Layer table of the torchvision model, traced with one 3-channel 224x224 input.
torchinfo.summary(model=net, input_size=(1, 3, 224, 224))

In [None]:
# Show every instance attribute of Keras layer 8 (vars() returns its __dict__).
vars(model.layers[8])

In [None]:
# Print the Keras layer table; line_length=150 widens the printed table.
model.summary(line_length=150)

In [None]:
# Port the Keras weights in `model` into the torchvision `net`.
#
# Both models are walked in definition order. Each Keras layer is paired with
# the first *not yet used* PyTorch module of the same kind whose weight shape
# matches after converting the Keras layout to the PyTorch layout. Marking
# modules as used stops every same-shaped Keras layer from overwriting one and
# the same PyTorch module.


def _keras_to_torch_layout(weight, depthwise=False):
    """Return `weight` (a numpy array from `get_weights()`) in PyTorch layout.

    * Conv2D kernel    (H, W, in, out)  -> (out, in, H, W)
    * Depthwise kernel (H, W, C, mult)  -> (C * mult, 1, H, W)
    * Dense kernel     (in, out)        -> (out, in)
    * anything else (bias, BatchNorm vectors) is returned unchanged.
    """
    if weight.ndim == 4:
        if depthwise:
            h, w, channels, multiplier = weight.shape
            # Fold (C, mult) into the grouped-conv output axis PyTorch uses.
            return weight.reshape(h, w, channels * multiplier, 1).transpose(2, 3, 0, 1)
        # A plain np.transpose reverses *all* axes and would swap H and W.
        return weight.transpose(3, 2, 0, 1)
    if weight.ndim == 2:
        return weight.T
    return weight


pytorch_layers = [
    module
    for module in net.modules()
    if isinstance(module, (nn.Conv2d, nn.Linear, nn.BatchNorm2d))
]
keras_layers = model.layers
used_pytorch_layers = set()  # indices into pytorch_layers already assigned

# In-place copies into parameters that require grad must run under no_grad.
with torch.no_grad():
    for i, keras_layer in enumerate(keras_layers):
        keras_weights = keras_layer.get_weights()
        if not keras_weights:
            continue
        is_depthwise = hasattr(keras_layer, "depthwise_kernel")
        kernel = _keras_to_torch_layout(keras_weights[0], depthwise=is_depthwise)
        # Keras BatchNormalization stores [gamma, beta, moving_mean, moving_variance].
        is_batchnorm = kernel.ndim == 1 and len(keras_weights) == 4
        wanted_types = nn.BatchNorm2d if is_batchnorm else (nn.Conv2d, nn.Linear)
        has_bias = len(keras_weights) > 1

        for j, pytorch_layer in enumerate(pytorch_layers):
            if j in used_pytorch_layers or not isinstance(pytorch_layer, wanted_types):
                continue
            if pytorch_layer.weight is None or tuple(pytorch_layer.weight.shape) != kernel.shape:
                continue
            if is_batchnorm:
                pytorch_layer.weight.copy_(torch.from_numpy(keras_weights[0]))
                pytorch_layer.bias.copy_(torch.from_numpy(keras_weights[1]))
                pytorch_layer.running_mean.copy_(torch.from_numpy(keras_weights[2]))
                pytorch_layer.running_var.copy_(torch.from_numpy(keras_weights[3]))
                # Use the Keras layer's epsilon so the normalisation matches.
                pytorch_layer.eps = float(getattr(keras_layer, "epsilon", pytorch_layer.eps))
            else:
                if has_bias != (pytorch_layer.bias is not None):
                    continue  # bias presence must agree; keep searching
                pytorch_layer.weight.copy_(torch.from_numpy(kernel))
                if has_bias:
                    pytorch_layer.bias.copy_(torch.from_numpy(keras_weights[1]))
            used_pytorch_layers.add(j)
            break
        else:
            print("Failed to find Pytorch match on ", keras_layer, i)

In [None]:
# Value of the first weight variable of Keras layer 4, as a numpy array.
model.layers[4].weights[0].numpy()

In [None]:
# Map Keras variable attribute names to the matching PyTorch parameter /
# buffer attribute names on nn.Conv2d, nn.Linear and nn.BatchNorm2d.
# NOTE: "kernel" and "depthwise_kernel" need a layout change as well as a
# rename: Keras conv kernels are (H, W, in, out), PyTorch expects (out, in, H, W).
name_conversion = {"kernel": "weight",
                   "moving_mean": "running_mean",
                   # nn.BatchNorm2d's buffer is `running_var`, not `running_variance`.
                   "moving_variance": "running_var",
                   "gamma": "weight",
                   "beta": "bias",
                   "depthwise_kernel": "weight",
                   "bias": "bias"
                   }

In [None]:
# Collect the attribute names under which Keras layers directly hold
# tf.Variable objects, printing each name with the variable's shape.
all_names = set()
n_params = net.named_parameters()  # generator; not consumed in this cell
for layer in model.layers:
    for attr_name, attr_value in vars(layer).items():
        if not isinstance(attr_value, tf.Variable):
            continue
        print(attr_name, attr_value.shape)
        all_names.add(attr_name)

In [None]:
# Print the dotted name of every parameter in the torchvision model, in order.
for param_name, _param in net.named_parameters():
    print(param_name)

In [None]:
# Shape of Keras layer 20's kernel as a tuple of ints.
tuple(model.layers[20].kernel.shape)

In [None]:
# Print the Keras layer table at the default width for the model currently
# bound to `model`.
model.summary()

In [None]:
# Display the torchvision module tree (its repr).
net

In [None]:
# Rebinds `model` to the copy saved without the F1 metric.
# NOTE(review): "without_f1.h5" is written by the `model.save("without_f1.h5")`
# cell further down, so on a fresh kernel this load only succeeds if that file
# already exists from an earlier session.
model = keras.models.load_model(filepath="without_f1.h5")

In [None]:
# Show every instance attribute of Keras layer 19 (vars() returns its __dict__).
vars(model.layers[19])

In [None]:
# Inspect the metric objects attached to the currently loaded model.
model.metrics

In [None]:
# Drop the compiled metrics before re-saving the model.
# NOTE(review): `_metrics_in_order` is a private Keras attribute (leading
# underscore); it may be absent or behave differently in other Keras versions.
# Confirm against the installed version. Loading with compile=False avoids
# needing this at all.
model.compiled_metrics._metrics_in_order = []

In [None]:
# Write the metric-stripped model to disk in HDF5 format (chosen by the .h5 suffix).
model.save(filepath="without_f1.h5")

In [None]:
# One project ImageDataset per split, each with its own config (data under
# bird_data/, device="cuda").
dsets = {
    split: ImageDataset(split, ImageDatasetCfg(data_dir="bird_data", device="cuda"))
    for split in ("train", "valid", "test")
}

In [None]:
# Batched (images, labels) pipeline over bird_data/train, resized to 224x224,
# 16 images per batch. The dataset is shuffled by default. A fixed seed makes
# that shuffling reproducible across runs, so anything sampled from the
# leading batches (e.g. for PCA) is identical after a kernel restart.
train_dset = tf.keras.utils.image_dataset_from_directory("bird_data/train",
                                                         image_size=(224,224),
                                                         batch_size=16,
                                                         seed=0)

In [None]:
# Materialise the images from the first 16384 // 16 batches of train_dset
# (up to 16384 images at batch_size=16) as one float16 array; labels are
# discarded. Each batch is cast to float16 *before* concatenation, so the
# full float32 copy (~9.9 GB for 16384 images of 224x224x3) is never built.
# `take` also stops the pipeline at the last needed batch, whereas
# zip(dataset, range(n)) pulls one extra batch before stopping.
samp = np.concatenate(
    [images.numpy().astype(np.float16) for images, _labels in train_dset.take(16384 // 16)]
)

In [None]:
# Scales passed to find_pca_directions / visualize_pca_directions
# (presumably patch or kernel sizes in pixels. TODO confirm in the helper).
default_scales = [3,5,7,9,13,15]
# Compute PCA directions from the training pipeline.
# NOTE(review): 8192 looks like the number of sampled images (the
# visualisation cells are annotated "sample size 8192"/"4096"). The meaning of
# the trailing positional 1 is not visible here, so confirm it in find_pca_directions.
pca_dirs = find_pca_directions(train_dset, 8192, default_scales, 1)

In [None]:
# Persist the current PCA directions; the filename marks them as the
# large-sample run. (Re-importing pickle is a no-op if already imported.)
import pickle

with open("bird_data/big_sample_pca_dirs.pkl", mode="wb") as pca_file:
    pickle.dump(pca_dirs, pca_file)

In [None]:
# Persist the current PCA directions to bird_data/pca_dirs.pkl.
# NOTE(review): this dumps whatever `pca_dirs` holds *now*. Under Restart &
# Run All, that is the result of find_pca_directions(train_dset, 8192, ...),
# the same object also written to big_sample_pca_dirs.pkl, so any earlier
# pca_dirs.pkl is overwritten. Run selectively, or give each sample size its
# own filename. The `import pickle` here repeats the top-of-notebook import.
import pickle
with open("bird_data/pca_dirs.pkl", "wb") as p:
    pickle.dump(pca_dirs, p)

In [None]:
# Plot the PCA directions for each scale in default_scales; "test" presumably
# selects the dataset split (TODO confirm in visualize_pca_directions).
# The trailing "sample size" comment is a hand-written note about which run
# produced the saved output; it is not derived from `pca_dirs`.
visualize_pca_directions(pca_dirs, "test", default_scales, lines=False)  # sample size 8192

In [None]:
# Same arguments as the cell annotated "sample size 8192".
# NOTE(review): the 4096 annotation only holds if pca_dirs was recomputed with
# 4096 samples before this cell ran. The visible find_pca_directions call uses
# 8192, so under Restart & Run All this annotation is stale.
visualize_pca_directions(pca_dirs, "test", default_scales, lines=False)  # sample size 4096

In [None]:
# NOTE(review): exact duplicate of the previous cell (same call and the same
# "sample size 4096" annotation); it re-plots identical inputs. Consider deleting.
visualize_pca_directions(pca_dirs, "test", default_scales, lines=False)  # sample size 4096