# Initialize


In [None]:
%load_ext autoreload
%autoreload 2

#––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––

import torch
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import nn

import sys
sys.path.append('/Users/enricofrausin/Programmazione/PythonProjects/Tesi/Autoencoders')

from AE.models import AE_0, ProgressiveAE
from AE.datasets import MNISTDigit2Dataset, MNISTDigit2OnlyDataset, FEMNISTDataset

from AE.depth_utils import calc_hfm_kld_with_optimal_g, compute_bottleneck_neurons_activ_freq, compute_emp_states_dict_gauged, compute_bottleneck_neurons_activ_freq_gauged, compute_dataset_klds_gs_dict_with_optimal_threshold_, compute_dataset_klds_gs_dict_from_sampled_binarized_vectors_
from AE.plotter_functions import visualize_bottleneck_neurons, plot_KLs_vs_hidden_layers, datasets_dicts_comparison
from AE.plotter_functions import datasets_dicts_comparison_colored

#––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––

# The original cell probed for MPS/CUDA, printed which accelerator it would
# use, and then unconditionally overwrote `device` with the CPU, so the printed
# message could contradict the device actually in use. The CPU override is now
# an explicit switch: with FORCE_CPU = True the resulting `device` is the same
# as before (CPU) and the status message is accurate. Set it to False to let
# the probe below choose an accelerator.
FORCE_CPU = True

if FORCE_CPU:
    device = torch.device("cpu")
    print("Utilizzo la CPU")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Utilizzo Apple Silicon GPU (MPS)")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print("Utilizzo NVIDIA GPU (CUDA)")
else:
    device = torch.device("cpu")
    print("Utilizzo la CPU")

# Seed PyTorch's global RNG so that runs are repeatable.
SEED = 42
torch.manual_seed(SEED)




# Datasets


In [None]:

# Mini-batch size used by the loaders created in this cell.
batch_size = 64

## MNIST
# Training split: images are converted with ToTensor() and the sample order
# is reshuffled by the loader.
train_loader_MNIST = DataLoader(
    datasets.MNIST(
        root='/Users/enricofrausin/Programmazione/PythonProjects/Fisica/data',
        train=True,
        transform=transforms.ToTensor(),
        download=True,
    ),
    shuffle=True,
    batch_size=batch_size,
)

# Held-out split (train=False): same transform, fixed sample order.
val_loader_MNIST = DataLoader(
    datasets.MNIST(
        root='/Users/enricofrausin/Programmazione/PythonProjects/Fisica/data',
        train=False,
        transform=transforms.ToTensor(),
        download=True,
    ),
    shuffle=False,
    batch_size=batch_size,
)



## ExtendedMNIST

# Training split of EMNIST, using the 'balanced' split; shuffled by the loader.
train_loader_EMNIST = DataLoader(
    datasets.EMNIST(
        root='/Users/enricofrausin/Programmazione/PythonProjects/Fisica/data',
        split='balanced',
        train=True,
        transform=transforms.ToTensor(),
        download=True,
    ),
    shuffle=True,
    batch_size=batch_size,
)

# Held-out part (train=False) of the same 'balanced' split; fixed order.
val_loader_EMNIST = DataLoader(
    datasets.EMNIST(
        root='/Users/enricofrausin/Programmazione/PythonProjects/Fisica/data',
        split='balanced',
        train=False,
        transform=transforms.ToTensor(),
        download=True,
    ),
    shuffle=False,
    batch_size=batch_size,
)


## 2MNIST

# --- training split -------------------------------------------------------
dataset_2MNIST_train = MNISTDigit2Dataset(train=True, download=True, target_size=60000)
print(f"Dataset size: {len(dataset_2MNIST_train)}")
print(f"Image shape: {dataset_2MNIST_train[0][0].shape}")
print(f"Label: {dataset_2MNIST_train[0][1]}")
train_loader_2MNIST = DataLoader(dataset_2MNIST_train, batch_size=batch_size, shuffle=True)

# Sanity check on one training batch.
batch_images, batch_labels = next(iter(train_loader_2MNIST))
print(f"Batch images shape: {batch_images.shape}")
print(f"Batch labels shape: {batch_labels.shape}")
print(f"All labels are 2: {torch.all(batch_labels == 2)}")

print("\n––––––––––––––––––––––––––––––––––––––––––––––––––––––\n")

# --- validation split -----------------------------------------------------
# The original cell printed the statistics of the *training* dataset and of the
# stale training batch here; the checks below inspect the validation data.
dataset_2MNIST_val = MNISTDigit2Dataset(train=False, download=True, target_size=10000)
print(f"Dataset size: {len(dataset_2MNIST_val)}")
print(f"Image shape: {dataset_2MNIST_val[0][0].shape}")
print(f"Label: {dataset_2MNIST_val[0][1]}")
val_loader_2MNIST = DataLoader(dataset_2MNIST_val, batch_size=batch_size, shuffle=True)

# Separate names keep `batch_images` / `batch_labels` bound to the training
# batch, as they were after the original cell ran.
val_batch_images, val_batch_labels = next(iter(val_loader_2MNIST))
print(f"Batch images shape: {val_batch_images.shape}")
print(f"Batch labels shape: {val_batch_labels.shape}")
print(f"All labels are 2: {torch.all(val_batch_labels == 2)}")





# 2MNISTonly: both splits are served by shuffling loaders.
dataset_2MNISTonly_train = MNISTDigit2OnlyDataset(train=True, download=True)
train_loader_2MNISTonly = DataLoader(dataset_2MNISTonly_train, shuffle=True, batch_size=batch_size)

dataset_2MNISTonly_val = MNISTDigit2OnlyDataset(train=False, download=True)
val_loader_2MNISTonly = DataLoader(dataset_2MNISTonly_val, shuffle=True, batch_size=batch_size)


#-------------------------------------------------------------------


# Name -> loader registries for the datasets built in this cell.
# NOTE(review): from this line on `datasets` is a list of names and no longer
# the torchvision module imported at the top of the file.
datasets = ["MNIST", "EMNIST", "2MNIST", "2MNISTonly"]
train_loaders = dict(zip(
    datasets,
    (train_loader_MNIST, train_loader_EMNIST, train_loader_2MNIST, train_loader_2MNISTonly),
))
val_loaders = dict(zip(
    datasets,
    (val_loader_MNIST, val_loader_EMNIST, val_loader_2MNIST, val_loader_2MNISTonly),
))


In [None]:

# An earlier cell rebinds the name `datasets` to a list of dataset names, which
# hides the torchvision module imported at the top of the file and makes
# `datasets.MNIST(...)` fail when this cell runs afterwards. The module is
# therefore imported here under its own alias, so the cell does not depend on
# what `datasets` currently refers to.
from torchvision import datasets as tv_datasets

batch_size = 64

# Root folder handed to every torchvision dataset in this cell.
_data_root = '/Users/enricofrausin/Programmazione/PythonProjects/Fisica/data'

# MNIST: shuffled training loader, fixed-order validation loader.
train_loader_MNIST = DataLoader(
    tv_datasets.MNIST(
        _data_root,
        train=True,
        download=True,
        transform=transforms.ToTensor(),
    ),
    batch_size=batch_size,
    shuffle=True,
)

val_loader_MNIST = DataLoader(
    tv_datasets.MNIST(
        _data_root,
        train=False,
        download=True,
        transform=transforms.ToTensor(),
    ),
    batch_size=batch_size,
    shuffle=False,
)

# EMNIST ('balanced' split): shuffled training loader, fixed-order validation loader.
train_loader_EMNIST = DataLoader(
    tv_datasets.EMNIST(
        _data_root,
        split='balanced',
        train=True,
        download=True,
        transform=transforms.ToTensor(),
    ),
    batch_size=batch_size,
    shuffle=True,
)

val_loader_EMNIST = DataLoader(
    tv_datasets.EMNIST(
        _data_root,
        split='balanced',
        train=False,
        download=True,
        transform=transforms.ToTensor(),
    ),
    batch_size=batch_size,
    shuffle=False,
)

# 2MNISTonly: both splits are served by shuffling loaders.
dataset_2MNISTonly_train = MNISTDigit2OnlyDataset(train=True, download=True)
train_loader_2MNISTonly = DataLoader(dataset_2MNISTonly_train, batch_size=batch_size, shuffle=True)

dataset_2MNISTonly_val = MNISTDigit2OnlyDataset(train=False, download=True)
val_loader_2MNISTonly = DataLoader(dataset_2MNISTonly_val, batch_size=batch_size, shuffle=True)

# Name -> loader registries. This variant has no "2MNIST" entry.
# NOTE(review): the list keeps its original name `datasets` for backward
# compatibility, which means it still shadows the torchvision module.
datasets = ["MNIST", "EMNIST", "2MNISTonly"]
train_loaders = {
    "MNIST": train_loader_MNIST,
    "EMNIST": train_loader_EMNIST,
    "2MNISTonly": train_loader_2MNISTonly,
}
val_loaders = {
    "MNIST": val_loader_MNIST,
    "EMNIST": val_loader_EMNIST,
    "2MNISTonly": val_loader_2MNISTonly,
}


In [None]:
# FEMNIST: build both splits, wrap them in shuffling loaders and add them to
# the existing name -> loader registries.
dataset_FEMNIST_train = FEMNISTDataset(train=True, download=True)
train_loader_FEMNIST = DataLoader(
    dataset_FEMNIST_train,
    shuffle=True,
    batch_size=batch_size,
)
dataset_FEMNIST_val = FEMNISTDataset(train=False, download=True)
val_loader_FEMNIST = DataLoader(
    dataset_FEMNIST_val,
    shuffle=True,
    batch_size=batch_size,
)

train_loaders.update({"FEMNIST": train_loader_FEMNIST})
val_loaders.update({"FEMNIST": val_loader_FEMNIST})

# Report the number of training samples.
print(len(train_loader_FEMNIST.dataset))



## FashionMNIST


In [None]:

# Earlier cells rebind the name `datasets` to a list of dataset names, so
# `datasets.FashionMNIST` would raise AttributeError at this point. The
# torchvision module is imported under its own alias to keep this cell
# independent of that rebinding.
from torchvision import datasets as tv_datasets

# Training split: ToTensor() transform, shuffled by the loader.
train_loader_FashionMNIST = DataLoader(
    tv_datasets.FashionMNIST(
        '/Users/enricofrausin/Programmazione/PythonProjects/Fisica/data',
        train=True,
        download=True,
        transform=transforms.ToTensor(),
    ),
    batch_size=batch_size,
    shuffle=True,
)

# Held-out split (train=False): same transform, fixed sample order.
val_loader_FashionMNIST = DataLoader(
    tv_datasets.FashionMNIST(
        '/Users/enricofrausin/Programmazione/PythonProjects/Fisica/data',
        train=False,
        download=True,
        transform=transforms.ToTensor(),
    ),
    batch_size=batch_size,
    shuffle=False,
)





## OTHERS


In [None]:

from AE.datasets import Dataset_HFM, Dataset_pureHFM

batch_size = 64

# Every HFM dataset below is constructed with this same root_dir.
_hfm_root_dir = '/Users/enricofrausin/Programmazione/PythonProjects/Fisica/data/pureHFM'

# NOTE: `dataset_HFM_train` / `dataset_HFM_val` are rebound in every section;
# after this cell they refer to the last (expandedHFM 32-1024) datasets.

## train over pureHFM
dataset_HFM_train = Dataset_pureHFM(
    csv_file='/Users/enricofrausin/Programmazione/PythonProjects/Fisica/data/pureHFM/512features/glog2_train60000.pt',
    root_dir=_hfm_root_dir,
)
train_loader_pureHFM = DataLoader(dataset_HFM_train, shuffle=True, batch_size=batch_size)

dataset_HFM_val = Dataset_pureHFM(
    csv_file='/Users/enricofrausin/Programmazione/PythonProjects/Fisica/data/pureHFM/512features/glog2_validation10000.pt',
    root_dir=_hfm_root_dir,
)
# Important: the validation loader must wrap dataset_HFM_val, not the training set.
val_loader_pureHFM = DataLoader(dataset_HFM_val, shuffle=False, batch_size=batch_size)

## train over expandedHFM
dataset_HFM_train = Dataset_HFM(
    csv_file='/Users/enricofrausin/Programmazione/PythonProjects/Fisica/data/expandedHFM/16_1024features/2hl_glog2_train60000.pt',
    root_dir=_hfm_root_dir,
)
train_loader_expandedHFM = DataLoader(dataset_HFM_train, shuffle=True, batch_size=batch_size)

dataset_HFM_val = Dataset_HFM(
    csv_file='/Users/enricofrausin/Programmazione/PythonProjects/Fisica/data/expandedHFM/16_1024features/2hl_glog2_validation10000.pt',
    root_dir=_hfm_root_dir,
)
# Important: the validation loader must wrap dataset_HFM_val, not the training set.
val_loader_expandedHFM = DataLoader(dataset_HFM_val, shuffle=False, batch_size=batch_size)

## train over expandedHFM 32-1024
dataset_HFM_train = Dataset_HFM(
    csv_file='/Users/enricofrausin/Programmazione/PythonProjects/Fisica/data/expandedHFM/32_1024features/2hl_glog2_train60000.pt',
    root_dir=_hfm_root_dir,
)
train_loader_expandedHFM_32_1024 = DataLoader(dataset_HFM_train, shuffle=True, batch_size=batch_size)

dataset_HFM_val = Dataset_HFM(
    csv_file='/Users/enricofrausin/Programmazione/PythonProjects/Fisica/data/expandedHFM/32_1024features/2hl_glog2_validation10000.pt',
    root_dir=_hfm_root_dir,
)
# Important: the validation loader must wrap dataset_HFM_val, not the training set.
val_loader_expandedHFM_32_1024 = DataLoader(dataset_HFM_val, shuffle=False, batch_size=batch_size)


# Loss

In [None]:
from AE.overlaps import load_model


In [None]:

def calc_MSE_loss(model, dataset_loader, l1_lambda=0.0, device=None):
    """
    Evaluate the reconstruction loss of `model` over `dataset_loader`.

    For every batch the value ``MSE(model(data), data) + l1_lambda * sum(|params|)``
    is computed (``nn.MSELoss`` with its default mean reduction) and the
    per-batch values are summed. The sum is then divided by the number of
    samples, NOT by the number of batches. According to the original author
    this mirrors the normalisation of the project's ``train(...)`` function,
    which is not visible from this file.

    Args:
        model: module called as ``model(data)``; it is switched to eval mode
            and is not switched back afterwards.
        dataset_loader: iterable yielding ``(data, target)`` pairs; the target
            is ignored.
        l1_lambda: weight of the L1 penalty on all model parameters.
        device: device the batches are moved to. When None, ``model.device``
            is used if present, otherwise the device of the first parameter,
            otherwise the CPU (parameter-free models).

    Returns:
        float: summed per-batch loss divided by ``len(dataset_loader.dataset)``.
        If that length is unavailable, the number of samples actually iterated
        is used instead. ``nan`` is returned when there are no samples.
    """
    import torch
    import torch.nn as nn

    model.eval()
    if device is None:
        device = getattr(model, "device", None)
        if device is None:
            # A parameter-free module would make a bare next(...) raise
            # StopIteration, so fall back to the CPU in that case.
            first_param = next(model.parameters(), None)
            device = first_param.device if first_param is not None else torch.device("cpu")

    mse = nn.MSELoss()
    total_loss = 0.0
    samples_seen = 0
    with torch.no_grad():
        # Parameters do not change during evaluation, so the L1 penalty is
        # loop-invariant: compute it once instead of once per batch.
        l1_term = l1_lambda * sum(p.abs().sum() for p in model.parameters())
        for data, _ in dataset_loader:
            data = data.to(device)
            output = model(data)
            total_loss += (mse(output, data) + l1_term).item()
            # Counted on the fly so the fallback denominator below does not
            # require a second pass over the loader.
            samples_seen += data.size(0) if data.dim() > 0 else 1

    try:
        denom = len(dataset_loader.dataset)
    except (AttributeError, TypeError):
        # No `.dataset` attribute, or a dataset without __len__.
        denom = samples_seen

    return total_loss / denom if denom > 0 else float("nan")



In [None]:

def dataset_mse(model, data_loader, device=None):
    """
    Sum the per-batch mean-squared reconstruction error and divide by the
    number of samples.

    Note that the returned value is NOT the per-element MSE of the dataset:
    each batch contributes ``nn.MSELoss(reduction='mean')`` and the sum of
    those batch means is divided by the sample count. According to the
    original author this matches the normalisation of the project's
    ``train(...)`` function, which is not visible from this file.

    Args:
        model: module called as ``model(x)``; may return the reconstruction or
            a list/tuple whose first element is the reconstruction. It is
            switched to eval mode and is not switched back afterwards.
        data_loader: iterable of batches; a batch is either a tensor or a
            list/tuple whose first element is the input tensor.
        device: device the inputs are moved to. When None, ``model.device`` is
            used if present, otherwise the device of the first parameter,
            otherwise the CPU.

    Returns:
        float: summed batch loss divided by ``len(data_loader.dataset)``. If
        that length is unavailable (no ``dataset`` attribute, or a dataset
        without ``__len__``), the number of samples iterated is used instead.
        ``nan`` is returned when there are no samples.
    """
    import torch
    import torch.nn as nn

    model.eval()
    if device is None:
        device = getattr(model, "device", None)
        if device is None:
            first_param = next(model.parameters(), None)
            device = first_param.device if first_param is not None else torch.device("cpu")

    loss_fn = nn.MSELoss(reduction='mean')
    total_loss = 0.0
    samples_seen = 0

    with torch.no_grad():
        for batch in data_loader:
            x = batch[0] if isinstance(batch, (list, tuple)) else batch
            x = x.to(device)
            out = model(x)
            recon = out[0] if isinstance(out, (list, tuple)) else out
            # A flat (batch, features) reconstruction of a higher-rank input
            # is reshaped to the input's shape before the comparison.
            if recon.shape != x.shape and recon.dim() == 2 and x.dim() > 2 and recon.size(0) == x.size(0):
                recon = recon.view_as(x)
            total_loss += loss_fn(recon, x).item()
            samples_seen += x.size(0) if x.dim() > 0 else 1

    try:
        dataset_size = len(data_loader.dataset)
    except (AttributeError, TypeError):
        # Plain iterables of batches and length-less datasets end up here.
        dataset_size = samples_seen

    return total_loss / dataset_size if dataset_size > 0 else float("nan")


In [None]:

import pickle

latent_dim = 10

# Datasets to evaluate, in processing order; also the keys of each per-run dict.
_loss_dataset_names = ("2MNISTonly", "2MNIST", "MNIST", "EMNIST", "FEMNIST")
num_hidden_layers_range = range(1,8)

# Layout: {train_num: {dataset name: [loss for each depth in num_hidden_layers_range]}}
rep_dataset_train_loss_dict = {
    run: {name: [] for name in _loss_dataset_names} for run in range(6)
}
rep_dataset_val_loss_dict = {
    run: {name: [] for name in _loss_dataset_names} for run in range(6)
}

for dataset in _loss_dataset_names:
    # Constructor arguments passed to load_model; 'hidden_layers' is filled in per depth.
    model_kwargs = {
        'input_dim': 28*28,
        'latent_dim': latent_dim,
        'decrease_rate': 0.6,
        'device': device,
        'output_activation_encoder': nn.Sigmoid
    }
    # Descriptors passed to load_model to locate the saved model;
    # 'num_hidden_layers' and 'train_num' are filled in per iteration.
    model_path_kwargs = {
        'output_activation_encoder': 'sigmoid output',
        'train_type': 'simultaneous train',
        'latent_dim': f"{model_kwargs['latent_dim']}ld",
        'dataset': dataset,
        'decrease_rate': '06',
        'train_num': 0
    }

    for train_num in range(6):
        for num_hidden_layers in num_hidden_layers_range:
            model_path_kwargs['num_hidden_layers'] = num_hidden_layers
            model_kwargs['hidden_layers'] = num_hidden_layers
            model_path_kwargs['train_num'] = train_num

            model = load_model(model_path_kwargs, model_kwargs)

            train_loss = calc_MSE_loss(model, train_loaders[dataset], device=device)
            rep_dataset_train_loss_dict[train_num][dataset].append(train_loss)
            # Validation losses are currently not computed, so the validation
            # dict saved below only contains empty lists:
            # rep_dataset_val_loss_dict[train_num][dataset].append(calc_MSE_loss(model, val_loaders[dataset], device=device))

# Persist both dictionaries next to each other.
_loss_dir = f'../savings/losses/sigmoid decoder output/{latent_dim}ld'

with open(f'{_loss_dir}/rep_dataset_train_loss_dict.pkl', 'wb') as f:
    pickle.dump(rep_dataset_train_loss_dict, f)

with open(f'{_loss_dir}/rep_dataset_val_loss_dict.pkl', 'wb') as f:
    pickle.dump(rep_dataset_val_loss_dict, f)
# Plotting: mean loss versus network depth, averaged over training repetitions
import numpy as np
import matplotlib.pyplot as plt

def plot_mean_loss_over_depth(rep_dataset_train_loss_dict,
                              datasets=None,
                              num_hidden_layers_range=range(1,8),
                              show_std=True,
                              figsize=(8,5),
                              title="Mean loss vs # hidden layers",
                              xlabel="Number of hidden layers",
                              ylabel="MSE",
                              save_path=None,
                              ax=None):
    """
    Average the losses over training repetitions and plot them against depth.

    Args:
        rep_dataset_train_loss_dict: dict keyed by train_num, each value being
            a dict ``dataset name -> sequence of losses`` (one per depth).
        datasets: dataset names to process; defaults to the keys of the inner
            dict of the smallest train_num.
        num_hidden_layers_range: depth values used on the x axis (default 1..7).
        show_std: if True, shade the band mean ± 1 standard deviation.
        figsize, title, xlabel, ylabel: plot parameters (figsize is only used
            when a new figure is created).
        save_path: if given, the figure that owns the axes is saved there.
        ax: optional matplotlib Axes to draw into; a new figure is created
            when None.

    Returns:
        (mean_dict, std_dict): dicts ``dataset name -> list`` with the mean
        and the population standard deviation across repetitions. Repetitions
        of different length are truncated to the shortest one; a dataset
        without data maps to an empty list.

    Raises:
        ValueError: if rep_dataset_train_loss_dict is empty.
    """
    train_nums = sorted(rep_dataset_train_loss_dict)
    if not train_nums:
        raise ValueError("rep_dataset_train_loss_dict is empty")

    if datasets is None:
        datasets = list(rep_dataset_train_loss_dict[train_nums[0]])

    depths = list(num_hidden_layers_range)
    mean_dict = {}
    std_dict = {}

    for ds in datasets:
        # One row per repetition that has an entry for this dataset.
        # atleast_1d lets a single scalar loss behave like a length-1 sequence.
        rows = [
            np.atleast_1d(rep_dataset_train_loss_dict[tn][ds])
            for tn in train_nums
            if rep_dataset_train_loss_dict[tn].get(ds) is not None
        ]

        if not rows:
            mean_dict[ds] = []
            std_dict[ds] = []
            continue

        # Align repetitions on the shortest one before stacking.
        min_len = min(r.shape[0] for r in rows)
        stacked = np.vstack([r[:min_len] for r in rows])
        mean_dict[ds] = stacked.mean(axis=0).tolist()
        std_dict[ds] = stacked.std(axis=0).tolist()

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    plotted_any = False
    for ds, mean in mean_dict.items():
        # Plot only as many points as there are depth values, so that x and y
        # always have the same length.
        n_points = min(len(mean), len(depths))
        if n_points == 0:
            continue
        x_plot = depths[:n_points]
        mean_vals = mean[:n_points]
        ax.plot(x_plot, mean_vals, marker='o', label=ds)
        plotted_any = True
        if show_std:
            mean_arr = np.asarray(mean_vals)
            std_arr = np.asarray(std_dict[ds][:n_points])
            ax.fill_between(x_plot, mean_arr - std_arr, mean_arr + std_arr, alpha=0.25)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_xticks(depths)
    ax.grid(True, linestyle='--', alpha=0.4)
    if plotted_any:
        # Without any labelled line a legend would only trigger a warning.
        ax.legend()

    if save_path:
        # Save the figure that owns `ax`; plt.savefig would save whatever
        # figure is current, which may differ when the caller supplies `ax`.
        ax.figure.savefig(save_path, bbox_inches='tight')

    return mean_dict, std_dict
# Usage: plot the mean training loss stored in rep_dataset_train_loss_dict
# Average the training losses over the repetitions and draw them versus depth
# (standard-deviation band disabled, figure not saved).
mean_losses, std_losses = plot_mean_loss_over_depth(
    rep_dataset_train_loss_dict,
    datasets=['2MNISTonly','2MNIST','MNIST','EMNIST','FEMNIST'],
    num_hidden_layers_range=range(1,8),
    show_std=False,
    title="Mean training MSE vs depth",
    save_path=None,
)
plt.show()

# NOTE(review): this call is an exact repeat of the one above (same input
# dict, same arguments), so it draws the same figure a second time.
mean_losses, std_losses = plot_mean_loss_over_depth(
    rep_dataset_train_loss_dict,
    datasets=['2MNISTonly','2MNIST','MNIST','EMNIST','FEMNIST'],
    num_hidden_layers_range=range(1,8),
    show_std=False,
    title="Mean training MSE vs depth",
    save_path=None,
)
plt.show()