# Numerical Simulation for the Generalized Exponential Function (GEF)
from the GES paper https://abdullahamdi.com/ges/



In [None]:
# !pip install open_clip_torch celluloid matplotlib tqdm

## plotting the GEF functions family

### $f_{\beta}(x) = Ae^{-\left(\frac{|x - \mu|}{\alpha}\right)^\beta}$

* when $\beta=2$, GEF reduces to Gaussians

In [None]:
import matplotlib.cm as cm
import numpy as np
import matplotlib.pyplot as plt

# Parameters of the generalized exponential function
# f_beta(x) = A * exp(-(|x - mu| / alpha) ** beta)
mu = 0  # location
alpha = 1  # scale
beta_values = [0.5, 1, 1.5, 2, 3, 10]  # shape parameters

# Define the x range
x = np.linspace(-5, 5, 1000)

# One viridis colour per beta, so the curves are visually ordered by shape
colors = cm.viridis(np.linspace(0, 1, len(beta_values)))

plt.figure(figsize=(12, 8))

for beta, color in zip(beta_values, colors):
    # The amplitude is fixed to 1 (unnormalised function) so every curve peaks
    # at the same height; the commented expression is the normalising constant.
    A = 1.0  # (beta / (2 * alpha * gamma(1 / beta)))
    pdf = A * np.exp(-np.abs((x - mu) / alpha) ** beta)
    # Pass the colour explicitly: it was computed above but previously unused.
    plt.plot(x, pdf, label=f'$\\beta$ = {beta}', color=color, linewidth=2)

# Customize the plot
# plt.title('Generalized Exponential Function for Different $\\beta$ Values', fontsize=20)
plt.xlabel('x', fontsize=22)
plt.ylabel("f(x | $\\beta$)", fontsize=22)
plt.legend(fontsize=20)
plt.grid(True)
plt.xticks(fontsize=18)
plt.yticks(fontsize=18)

# Show plot
plt.show()


## 1D Simulation 

### models 

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
import os
import math
class GaussianMixture(nn.Module):
    """Sum of N unnormalised 1D Gaussian bumps with learnable parameters.

    Each component is ``w * exp(-(x - mean)^2 / (2 * variance + eps))``.

    Args:
        N: number of mixture components.
        positive_weights: when True, the raw weights are passed through
            softplus in ``forward`` so every component contributes positively.
    """

    def __init__(self, N, positive_weights=False):
        super().__init__()
        self.positive_weights = positive_weights

        # Centres and amplitudes start from a standard normal; variances start
        # close to 1 and are made non-negative at initialisation only.
        self.means = nn.Parameter(torch.randn(N))
        initial_variances = (0.1 * torch.randn(N) + 1).abs()
        self.variances = nn.Parameter(initial_variances)
        self.weights = nn.Parameter(torch.randn(N))

    def forward(self, x):
        """Evaluate the mixture at ``x``; the result has the same shape as ``x``."""
        eps = 1e-8  # keeps the denominator away from exactly zero

        amplitudes = self.weights
        if self.positive_weights:
            amplitudes = F.softplus(amplitudes)

        # A trailing axis lets every sample broadcast against the N components.
        offsets = x.unsqueeze(-1) - self.means
        denominators = 2 * self.variances + eps
        bumps = amplitudes * torch.exp(-offsets ** 2 / denominators)

        return bumps.sum(dim=-1)
class DoGMixture(nn.Module):
    """Sum of N 1D Difference-of-Gaussians components with learnable parameters.

    Each component is ``w * (exp(-d^2 / (2*var1)) - exp(-d^2 / (2*var2)))``
    with ``d = x - mean``, ``var1 = scale^2`` and ``var2 = var1 / ratio``.

    Args:
        N: number of mixture components.
        positive_weights: when True, raw weights go through softplus in ``forward``.
        ratio: fixed ratio between the wide and the narrow variance.
    """

    def __init__(self, N, positive_weights=False, ratio=4.0):
        super().__init__()
        self.ratio = ratio
        self.positive_weights = positive_weights

        self.means = nn.Parameter(torch.randn(N))
        # A single learnable scale per component drives both variances.
        initial_scales = (0.1 * torch.randn(N) + 1).abs()
        self.scales = nn.Parameter(initial_scales)
        self.weights = nn.Parameter(torch.randn(N))

    def forward(self, x):
        """Evaluate the mixture at ``x``; the result has the same shape as ``x``."""
        eps = 1e-8

        wide_var = self.scales ** 2
        narrow_var = wide_var / self.ratio  # fixed ratio between the two variances

        amplitudes = F.softplus(self.weights) if self.positive_weights else self.weights

        # Negative squared distance of every sample to every component centre.
        neg_sq_dist = -(x.unsqueeze(-1) - self.means) ** 2
        wide = torch.exp(neg_sq_dist / (2 * wide_var + eps))
        narrow = torch.exp(neg_sq_dist / (2 * narrow_var + eps))

        return (amplitudes * (wide - narrow)).sum(dim=-1)
    
class LoGMixture(nn.Module):
    """Sum of N 1D Laplacian-of-Gaussian-shaped components with learnable parameters.

    Each component is ``w * (1 - d^2 / scale^2) * exp(-d^2 / (2 * scale^2 + eps))``
    with ``d = x - mean``.

    Args:
        N: number of mixture components.
        positive_weights: when True, raw weights go through softplus in ``forward``.
    """

    def __init__(self, N, positive_weights=False):
        super().__init__()

        self.means = nn.Parameter(torch.randn(N))
        self.scales = nn.Parameter(torch.abs(0.1 * torch.randn(N) + 1))  # Scale parameter, equivalent to sigma
        self.weights = nn.Parameter(torch.randn(N))

        self.positive_weights = positive_weights

    def forward(self, x):
        """Evaluate the mixture at ``x``; the result has the same shape as ``x``."""
        x = x.unsqueeze(-1)  # trailing axis broadcasts against the N components
        epsilon = 1e-8

        weights = F.softplus(self.weights) if self.positive_weights else self.weights
        sq_dist = (x - self.means) ** 2
        sq_scale = self.scales ** 2
        # Note: only the exponential's denominator is guarded by epsilon.
        log_components = weights * (-sq_dist / sq_scale + 1) * torch.exp(-sq_dist / (2 * sq_scale + epsilon))

        # The former second `return dog_components...` was unreachable and
        # referenced an undefined name; it has been removed.
        return log_components.sum(dim=-1)
class GeneralMixture(nn.Module):
    """Sum of N generalized exponential (GEF) components.

    Each component is ``w * exp(-|x - mean|^beta / (2 * variance + eps))``;
    with ``beta = 2`` this has the same form as a Gaussian component.

    Args:
        N: number of mixture components.
        positive_weights: when True, raw weights go through softplus in ``forward``.
        learn_beta: when True, ``beta`` is a trainable parameter; otherwise it
            is a constant buffer filled with ``fixed_beta``.
        fixed_beta: exponent used for every component when ``learn_beta`` is False.
    """

    def __init__(self, N, positive_weights=False, learn_beta=False, fixed_beta=2.0):
        super().__init__()

        # Location, spread and amplitude of every component.
        self.means = nn.Parameter(torch.randn(N))
        self.variances = nn.Parameter(torch.abs(0.1 * torch.randn(N) + 1))
        self.weights = nn.Parameter(torch.randn(N))

        if not learn_beta:
            # Constant exponent: stored as a buffer so it receives no gradient.
            self.register_buffer('beta', torch.full((N,), fixed_beta))
        else:
            # Raw (pre-softplus) exponent drawn uniformly from [0.5, 1.0).
            # The exponent actually used is softplus(raw) + 0.5, see forward().
            self.beta = nn.Parameter(torch.rand(N) * 0.5 + 0.5)

        self.positive_weights = positive_weights
        self.learn_beta = learn_beta  # forward() needs to know which case applies

    def forward(self, x):
        """Evaluate the mixture at ``x``; the result has the same shape as ``x``."""
        eps = 1e-8  # keeps the denominator away from exactly zero

        amplitudes = self.weights
        if self.positive_weights:
            amplitudes = F.softplus(amplitudes)

        # A learnable exponent is mapped through softplus(.) + 0.5 so that it
        # always stays above 0.5; a fixed exponent is used as-is.
        exponent = self.beta
        if self.learn_beta:
            exponent = F.softplus(exponent) + 0.5

        # A trailing axis lets every sample broadcast against the N components.
        distances = (x.unsqueeze(-1) - self.means).abs()
        components = amplitudes * torch.exp(-(distances ** exponent) / (2 * self.variances + eps))

        return components.sum(dim=-1)
def train_model(model, optimizer, loss_func, x, y, epochs):
    """Fit ``model`` to ``(x, y)`` with full-batch gradient steps.

    Args:
        model: callable module mapping ``x`` to predictions.
        optimizer: optimizer already bound to the model's parameters.
        loss_func: callable ``loss_func(y_pred, y)`` returning a scalar tensor.
        x: input samples.
        y: target values.
        epochs: number of optimisation steps. With ``epochs <= 0`` no step is
            taken and the untrained model is simply evaluated once.

    Returns:
        Tuple ``(model, y_pred, loss)`` where ``y_pred`` and the float ``loss``
        come from the last forward pass, i.e. they are computed *before* the
        final optimizer step is applied.
    """
    if epochs <= 0:
        # Previously this path crashed with UnboundLocalError because the
        # loop body never ran; evaluate once so callers still get a result.
        y_pred = model(x)
        loss = loss_func(y_pred, y)
        return model, y_pred, loss.item()

    for _ in range(epochs):
        optimizer.zero_grad()
        y_pred = model(x)
        loss = loss_func(y_pred, y)
        loss.backward()
        optimizer.step()
    return model, y_pred, loss.item()

### ops

In [None]:
def plot_results_gaussian(model, x, y, y_pred, N, loss_rec, positive_weights, show_fig=True, signal_type="square"):
    """Plot a fitted Gaussian mixture against its target and save it as a PDF.

    Draws the target signal, the mixture prediction and every individual
    component, then writes
    ``comparisons/gaussians/<signal_type>_<P|N><N>.pdf``.

    Args:
        model: fitted mixture exposing ``means``, ``variances``, ``weights``
            and ``positive_weights``.
        x, y, y_pred: CPU tensors with the sample positions, the target and
            the model prediction (``.numpy()`` is called on them).
        N: number of components to draw.
        loss_rec: final loss value; shown in the title multiplied by 100.
        positive_weights: selects the ``P``/``N`` filename suffix.
        show_fig: display the figure; when False the figure is closed after
            saving so that long sweeps do not accumulate open figures.
        signal_type: name of the target signal, used in labels and filename.
    """
    fig = plt.figure(figsize=(12, 8))

    plt.plot(x.numpy(), y.numpy(), 'r-', label='True ' + signal_type, linewidth=2)
    plt.plot(x.numpy(), y_pred.detach().numpy(), 'b--', label='Gaussian Mixture', linewidth=2)

    # The effective weights are the same for every component: compute once.
    weights = F.softplus(model.weights) if model.positive_weights else model.weights
    for i in range(N):
        component = weights[i].item() * torch.exp(-(x - model.means[i].item())**2 / (2 * model.variances[i].item()))
        plt.plot(x.numpy(), component.numpy(), 'g-.', alpha=0.5, linewidth=2)

    plt.legend(fontsize=20, loc='upper right')
    plt.title(f"Overfitting Gaussian Mixture to a {signal_type} Function, N={N}, loss={100*loss_rec:.2f}", fontsize=22)
    plt.xlabel("x", fontsize=22)
    plt.ylabel("y", fontsize=22)
    plt.xticks(fontsize=18)
    plt.yticks(fontsize=18)
    plt.grid(True)

    out_dir = os.path.join("comparisons", "gaussians")
    os.makedirs(out_dir, exist_ok=True)

    suffix = "P" if positive_weights else "N"
    plt.savefig(os.path.join(out_dir, signal_type + "_" + suffix + str(N) + ".pdf"), bbox_inches="tight")
    if show_fig:
        plt.show()
    else:
        # Release the figure; otherwise every call leaks one open figure.
        plt.close(fig)
def plot_results_dog(model, x, y, y_pred, N, loss_rec, positive_weights, show_fig=True, signal_type="square"):
    """Plot a fitted Difference-of-Gaussians mixture and save it as a PDF.

    Draws the target signal, the mixture prediction and every individual
    component, then writes ``comparisons/dogs/<signal_type>_<DP|DN><N>.pdf``.

    Args:
        model: fitted mixture exposing ``means``, ``scales``, ``weights``,
            ``ratio`` and ``positive_weights``.
        x, y, y_pred: CPU tensors with the sample positions, the target and
            the model prediction (``.numpy()`` is called on them).
        N: number of components to draw.
        loss_rec: final loss value; shown in the title multiplied by 100.
        positive_weights: selects the ``DP``/``DN`` filename suffix.
        show_fig: display the figure; when False the figure is closed after
            saving so that long sweeps do not accumulate open figures.
        signal_type: name of the target signal, used in labels and filename.
    """
    fig = plt.figure(figsize=(12, 8))

    plt.plot(x.numpy(), y.numpy(), 'r-', label='True ' + signal_type, linewidth=2)
    plt.plot(x.numpy(), y_pred.detach().numpy(), 'b--', label='DoG Mixture', linewidth=2)

    # The effective weights are the same for every component: compute once.
    weights = F.softplus(model.weights) if model.positive_weights else model.weights
    for i in range(N):
        mean = model.means[i].item()
        var1 = model.scales[i].item()**2
        var2 = var1 / model.ratio  # narrow Gaussian, fixed ratio to the wide one

        sq_dist = (x - mean)**2
        component = weights[i].item() * (torch.exp(-sq_dist / (2 * var1)) -
                                         torch.exp(-sq_dist / (2 * var2)))
        plt.plot(x.numpy(), component.numpy(), 'g-.', alpha=0.5, linewidth=2)

    plt.legend(fontsize=20)
    plt.title(f"Overfitting DoG Mixture to a {signal_type} Function, N={N}, loss={100*loss_rec:.2f}", fontsize=22)
    plt.xlabel("x", fontsize=22)
    plt.ylabel("y", fontsize=22)
    plt.xticks(fontsize=18)
    plt.yticks(fontsize=18)
    plt.grid(True)

    out_dir = os.path.join("comparisons", "dogs")
    os.makedirs(out_dir, exist_ok=True)

    suffix = "DP" if positive_weights else "DN"
    plt.savefig(os.path.join(out_dir, signal_type + "_" + suffix + str(N) + ".pdf"), bbox_inches="tight")
    if show_fig:
        plt.show()
    else:
        # Release the figure; otherwise every call leaks one open figure.
        plt.close(fig)
def plot_results_log(model, x, y, y_pred, N, loss_rec, positive_weights, show_fig=True, signal_type="square"):
    """Plot a fitted Laplacian-of-Gaussian mixture and save it as a PDF.

    Draws the target signal, the mixture prediction and every individual
    component, then writes ``comparisons/logs/<signal_type>_<LP|LN><N>.pdf``.

    Args:
        model: fitted mixture exposing ``means``, ``scales``, ``weights`` and
            ``positive_weights``.
        x, y, y_pred: CPU tensors with the sample positions, the target and
            the model prediction (``.numpy()`` is called on them).
        N: number of components to draw.
        loss_rec: final loss value; shown in the title multiplied by 100.
        positive_weights: selects the ``LP``/``LN`` filename suffix.
        show_fig: display the figure; when False the figure is closed after
            saving so that long sweeps do not accumulate open figures.
        signal_type: name of the target signal, used in labels and filename.
    """
    fig = plt.figure(figsize=(12, 8))

    plt.plot(x.numpy(), y.numpy(), 'r-', label='True ' + signal_type, linewidth=2)
    plt.plot(x.numpy(), y_pred.detach().numpy(), 'b--', label='LoG Mixture', linewidth=2)

    # The effective weights are the same for every component: compute once.
    weights = F.softplus(model.weights) if model.positive_weights else model.weights
    for i in range(N):
        sq_dist = (x - model.means[i].item())**2
        sq_scale = model.scales[i].item()**2
        component = weights[i].item() * (-sq_dist / sq_scale + 1) * torch.exp(-sq_dist / (2 * sq_scale))
        plt.plot(x.numpy(), component.numpy(), 'g-.', alpha=0.5, linewidth=2)

    plt.legend(fontsize=20)
    plt.title(f"Overfitting LoG Mixture to a {signal_type} Function, N={N}, loss={100*loss_rec:.2f}", fontsize=22)
    plt.xlabel("x", fontsize=22)
    plt.ylabel("y", fontsize=22)
    plt.xticks(fontsize=18)
    plt.yticks(fontsize=18)
    plt.grid(True)

    out_dir = os.path.join("comparisons", "logs")
    os.makedirs(out_dir, exist_ok=True)

    suffix = "LP" if positive_weights else "LN"
    plt.savefig(os.path.join(out_dir, signal_type + "_" + suffix + str(N) + ".pdf"), bbox_inches="tight")
    if show_fig:
        plt.show()
    else:
        # Release the figure; otherwise every call leaks one open figure.
        plt.close(fig)
def plot_results_general(model, x, y, y_pred, N, loss_rec, positive_weights, show_fig=True, signal_type="square"):
    """Plot a fitted GEF mixture against its target and save it as a PDF.

    Draws the target signal, the mixture prediction and every individual
    component, then writes
    ``comparisons/general_mixtures/<signal_type>_<P|N><N>.pdf``.

    Args:
        model: fitted mixture exposing ``means``, ``variances``, ``weights``,
            ``beta``, ``learn_beta`` and ``positive_weights``.
        x, y, y_pred: CPU tensors with the sample positions, the target and
            the model prediction (``.numpy()`` is called on them).
        N: number of components to draw.
        loss_rec: final loss value; shown in the title multiplied by 100.
        positive_weights: selects the ``P``/``N`` filename suffix.
        show_fig: display the figure; when False the figure is closed after
            saving so that long sweeps do not accumulate open figures.
        signal_type: name of the target signal, used in labels and filename.
    """
    fig = plt.figure(figsize=(12, 8))

    plt.plot(x.numpy(), y.numpy(), 'r-', label='True ' + signal_type, linewidth=2)
    plt.plot(x.numpy(), y_pred.detach().numpy(), 'b--', label='GEF Mixture', linewidth=2)

    # The effective weights are the same for every component: compute once.
    weights = F.softplus(model.weights) if model.positive_weights else model.weights
    # Use the exponent the model really applies: a learnable beta is mapped
    # through softplus(.) + 0.5 in GeneralMixture.forward, so plotting the raw
    # parameter would draw components that do not sum to the prediction.
    betas = (F.softplus(model.beta) + 0.5) if model.learn_beta else model.beta
    for i in range(N):
        distance = (x - model.means[i].item()).abs()
        component = weights[i].item() * torch.exp(-(distance ** betas[i].item()) / (2 * model.variances[i].item()))
        plt.plot(x.numpy(), component.detach().numpy(), 'g-.', alpha=0.5, linewidth=2)

    plt.legend(fontsize=20)
    plt.title(f"Overfitting General Mixture to a {signal_type} Function, N={N}, loss={100*loss_rec:.2f}", fontsize=22)
    plt.xlabel("x", fontsize=22)
    plt.ylabel("y", fontsize=22)
    plt.xticks(fontsize=18)
    plt.yticks(fontsize=18)
    plt.grid(True)

    out_dir = os.path.join("comparisons", "general_mixtures")
    os.makedirs(out_dir, exist_ok=True)

    suffix = "P" if positive_weights else "N"
    plt.savefig(os.path.join(out_dir, signal_type + "_" + suffix + str(N) + ".pdf"), bbox_inches="tight")
    if show_fig:
        plt.show()
    else:
        # Release the figure; otherwise every call leaks one open figure.
        plt.close(fig)
def triangle_wave(x, width=2):
    """Periodic triangle wave with period ``width``, peaking at 1 on multiples of ``width``.

    The value falls linearly to 0 halfway between two peaks.
    """
    half_width = width / 2.0
    # Wrap x into [-half_width, half_width) around the nearest peak.
    wrapped = (x + half_width) % width - half_width
    rising = 1 + wrapped / half_width
    falling = 1 - wrapped / half_width
    return torch.where(wrapped < 0, rising, falling)
def get_signal(x, signal_type, width=2):
    """Sample one of the target signals at the positions ``x``.

    All signals except ``"gaussian"`` are non-zero only on the open interval
    ``(-width/2, width/2)``; the Gaussian has standard deviation ``width/3``
    and is not truncated.

    Args:
        x: tensor of sample positions.
        signal_type: one of ``"square"``, ``"triangle"``, ``"parabolic"``,
            ``"half_sinusoid"``, ``"exponential"`` or ``"gaussian"``.
        width: width of the support of the signal.

    Raises:
        ValueError: if ``signal_type`` is not one of the names above.
    """
    if signal_type == "gaussian":
        sigma = width / 3
        return torch.exp(-x.pow(2) / (2 * sigma ** 2))

    half_width = width / 2
    if signal_type == "square":
        inside = torch.ones_like(x)
    elif signal_type == "triangle":
        inside = half_width - torch.abs(x)
    elif signal_type == "parabolic":
        inside = half_width ** 2 - x ** 2
    elif signal_type == "half_sinusoid":
        # Half a sine period stretched over the support.
        inside = torch.sin((x + half_width) * (math.pi / width))
    elif signal_type == "exponential":
        inside = torch.exp(-torch.abs(x))
    else:
        raise ValueError(f"Unknown signal_type: {signal_type}")

    on_support = (x > -half_width) & (x < half_width)
    return torch.where(on_support, inside, torch.zeros_like(x))


### single fit 

In [None]:
# Single-fit experiment: train one mixture on one target signal and plot it.
N = 5
positive_weights = True
epochs = 10000
data_size  = 1000
data_extent = 10
signal_width = 6
learn_beta = True # only for general mixture  
fixed_beta = 2.0 # only for general mixture  
signal_types = ["square", "triangle", "parabolic", "half_sinusoid", "gaussian","exponential"]
signal_type = "square" 

# Fall back to the CPU when no GPU is present instead of crashing on .cuda()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x = torch.linspace(-data_extent, data_extent, data_size).unsqueeze(-1).to(device)
y = get_signal(x, signal_type, signal_width)  # already on `device`, like x

# Pick exactly one model here and the matching plot function below.
model = GaussianMixture(N, positive_weights=positive_weights).to(device)
# model = DoGMixture(N,positive_weights=positive_weights).to(device)
# model = LoGMixture(N,positive_weights=positive_weights).to(device)
# model = GeneralMixture(N,positive_weights=positive_weights,learn_beta=learn_beta,fixed_beta=fixed_beta ).to(device)


optimizer = optim.Adam(model.parameters(), lr=0.01)
loss_func = nn.MSELoss()
# Training
model, y_pred, loss_rec = train_model(model, optimizer, loss_func, x, y, epochs=epochs)

plot_results_gaussian(model, x.cpu(), y.cpu(), y_pred.cpu(), N=N, loss_rec=loss_rec, positive_weights=positive_weights,show_fig=False,signal_type=signal_type)
# plot_results_dog(model, x.cpu(), y.cpu(), y_pred.cpu(), N=N, loss_rec=loss_rec, positive_weights=positive_weights,show_fig=False,signal_type=signal_type)
# plot_results_log(model, x.cpu(), y.cpu(), y_pred.cpu(), N=N, loss_rec=loss_rec, positive_weights=positive_weights,show_fig=False,signal_type=signal_type)
# plot_results_general(model, x.cpu(), y.cpu(), y_pred.cpu(), N=N, loss_rec=loss_rec, positive_weights=positive_weights,show_fig=False,signal_type=signal_type)
print("loss:   ",loss_rec)

### multiple fits 

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
# Configuration of the sweep: every signal x N x weight constraint x model
# combination is trained `runs_per_config` times.
signal_types = ["square", "triangle", "parabolic", "half_sinusoid", "gaussian","exponential"]
Ns =  [2,5,8, 10, 15,20,50,100]
positive_weights_list = [True, False]
model_types = ['gaussian', 'dog', 'log','general']
epochs = 10000
runs_per_config = 20  # run each configuration multiple times to account for variance
data_size  = 1000
data_extent = 10
signal_width = 6
learn_beta = True # only used when the general mixture is selected
# Create a directory to store the results (no-op if it already exists;
# exist_ok avoids the check-then-create race of a separate exists() test)
os.makedirs("comparisons", exist_ok=True)

In [None]:
results_per_signal = {}  # Dictionary to hold results for each signal type

# Fall back to the CPU when no GPU is present instead of crashing on .cuda()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dispatch tables are built once. An unknown model type now raises KeyError
# instead of silently reusing the model left over from a previous iteration.
model_factories = {
    'gaussian': lambda n, pw: GaussianMixture(n, positive_weights=pw),
    'dog': lambda n, pw: DoGMixture(n, positive_weights=pw),
    'log': lambda n, pw: LoGMixture(n, positive_weights=pw),
    'general': lambda n, pw: GeneralMixture(n, positive_weights=pw, learn_beta=learn_beta),
}
plot_funcs = {
    'gaussian': plot_results_gaussian,
    'dog': plot_results_dog,
    'log': plot_results_log,
    'general': plot_results_general,
}

# The sample positions do not depend on the signal type: create them once.
x = torch.linspace(-data_extent, data_extent, data_size).unsqueeze(-1).to(device)

for signal_type in signal_types:
    # Data
    y = get_signal(x, signal_type, signal_width)
    results = []  # To store the average loss and nan_counts per configuration

    for N in Ns:
        for positive_weights in positive_weights_list:
            for model_type in model_types:

                # Keeping track of how many times the training encounters 'nan' loss
                nan_counts = 0
                total_loss = 0  # Sum of the non-NaN losses for the current configuration

                for run in range(runs_per_config):
                    model = model_factories[model_type](N, positive_weights).to(device)

                    optimizer = optim.Adam(model.parameters(), lr=0.01)
                    loss_func = nn.MSELoss()

                    model, y_pred, loss_rec = train_model(model, optimizer, loss_func, x, y, epochs=epochs)

                    if np.isnan(loss_rec):
                        nan_counts += 1
                    else:
                        total_loss += loss_rec

                    # Plot and save. Every run of a configuration writes to the
                    # same file name, so only the last run's figure is kept.
                    plot_funcs[model_type](model, x.cpu(), y.cpu(), y_pred.cpu(), N=N, loss_rec=loss_rec,
                                           positive_weights=positive_weights, show_fig=False, signal_type=signal_type)

                # Average over the runs that did not diverge; NaN if all did.
                valid_runs = runs_per_config - nan_counts
                average_loss = total_loss / valid_runs if valid_runs != 0 else np.nan
                results.append({
                    'Model': model_type,
                    'N': N,
                    'Positive Weights': positive_weights,
                    'NaN Loss Counts': nan_counts,
                    'Average Loss': average_loss
                })

    for r in results:
        print(f"Signal: {signal_type}, Model: {r['Model']}, N: {r['N']}, Positive Weights: {r['Positive Weights']}, NaN Loss Counts: {r['NaN Loss Counts']}/{runs_per_config}, Average Loss: {r['Average Loss']:.2f}")

    results_per_signal[signal_type] = results

In [None]:
for signal_type, results in results_per_signal.items():

    # Nested lookup table: model -> weight constraint -> N -> metrics,
    # zero-initialised for every combination.
    results_dict = {
        model_name: {
            weight_name: {N_val: {'stability': 0, 'loss': 0} for N_val in Ns}
            for weight_name in ('positive', 'non-positive')
        }
        for model_name in ('gaussian', 'dog', 'log', 'general')
    }

    # Fill the table from the flat list of per-configuration results.
    for result in results:
        model_type = result['Model'].lower()  # keys of results_dict are lower-case
        weight_type = 'positive' if result['Positive Weights'] else 'non-positive'
        N_val = result['N']
        # Stability = percentage of runs that did not end with a NaN loss.
        stability = 100 * (runs_per_config - result['NaN Loss Counts']) / runs_per_config
        entry = results_dict[model_type][weight_type][N_val]
        entry['stability'] = stability
        entry['loss'] = result['Average Loss']

    def extract_values(results_dict, model_type, weight_type, Ns):
        """Return (stabilities, losses) of one curve, ordered like ``Ns``."""
        per_n = results_dict[model_type][weight_type]
        stabilities = [per_n[N_val]['stability'] for N_val in Ns]
        losses = [per_n[N_val]['loss'] for N_val in Ns]
        return stabilities, losses

    # Colour encodes the model family.
    color_scheme = {
        'gaussian': 'b',
        'dog': 'g',
        'log': 'r',
        'general': 'm'
    }

    # Line style and marker encode the weight constraint.
    line_styles = {
        'positive': '-',
        'non-positive': '--'
    }
    point_markers = {'positive': 'o', 'non-positive': 'x'}
    legend_models = {'gaussian': 'Gaussian', 'dog': 'DoG', 'log': 'LoG', 'general': 'GEF'}
    legend_weights = {'positive': 'Positive', 'non-positive': 'Real'}

    # One entry per curve, in the order the curves are drawn and listed in the legend.
    curves = []
    for curve_model in ('gaussian', 'dog', 'log', 'general'):
        for curve_weight in ('positive', 'non-positive'):
            curve_stability, curve_loss = extract_values(results_dict, curve_model, curve_weight, Ns)
            curve_fmt = color_scheme[curve_model] + line_styles[curve_weight] + point_markers[curve_weight]
            curve_label = legend_models[curve_model] + ' (' + legend_weights[curve_weight] + ')'
            curves.append((curve_fmt, curve_label, curve_stability, curve_loss))

    stats_dir = os.path.join("comparisons", "stats")
    if not os.path.exists(stats_dir):
        os.makedirs(stats_dir)

    # Stability vs Number of Components
    plt.figure(figsize=(12, 8))
    for curve_fmt, curve_label, curve_stability, _ in curves:
        plt.plot(Ns, curve_stability, curve_fmt, label=curve_label, linewidth=2)

    plt.title("Stability vs Number of Components on {} signal".format(signal_type), fontsize=22)
    plt.xlabel("Number of Components (N)", fontsize=22)
    plt.ylabel("Stability (%)", fontsize=22)
    plt.xticks(fontsize=18)
    plt.yticks(fontsize=18)
    plt.legend(fontsize=20)
    plt.grid(True)
    plt.savefig(os.path.join(stats_dir, "stability_{}.pdf".format(signal_type)), bbox_inches="tight")
    plt.show()

    # Loss Values vs Number of Components
    plt.figure(figsize=(12, 8))
    for curve_fmt, curve_label, _, curve_loss in curves:
        plt.plot(Ns, curve_loss, curve_fmt, label=curve_label, linewidth=2)

    plt.title("Loss Value vs Number of Components on {} signal".format(signal_type), fontsize=22)
    plt.xlabel("Number of Components (N)", fontsize=22)
    plt.ylabel("Average Loss (log scale)", fontsize=22)
    plt.yscale('log')
    plt.xticks(fontsize=18)
    plt.yticks(fontsize=18)
    plt.legend(fontsize=20)
    plt.grid(True)
    plt.savefig(os.path.join(stats_dir, "loss_{}.pdf".format(signal_type)), bbox_inches="tight")
    plt.show()

# Miscellaneous visualizations 

## Frequency-based image filter $\mathbf{M}_{\omega}$ for target freq $\omega$

from the GES paper https://abdullahamdi.com/ges/

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import numpy as np
import matplotlib
import torch
import os
import torch.nn.functional as F
def tensor_info(tensor):
    """
    Print a short summary of a PyTorch tensor, one property per line:
    shape, dtype, device, requires_grad, min, max, mean and standard deviation.

    Args:
        tensor (torch.Tensor): Input tensor
    """
    # Each value is read lazily, right before it is printed, so a statistic
    # that fails for this tensor does not suppress the lines above it.
    report = (
        ("\nShape:", lambda: tensor.shape),
        ("Datatype:", lambda: tensor.dtype),
        ("Device:", lambda: tensor.device),
        ("Requires grad:", lambda: tensor.requires_grad),
        ("Min value:", lambda: tensor.min().item()),
        ("Max value:", lambda: tensor.max().item()),
        ("Mean value:", lambda: tensor.mean().item()),
        ("Standard deviation:", lambda: tensor.std().item()),
    )
    for label, read_value in report:
        print(label, read_value())


def show(img, c_path=None):
    """Display a channel-first image tensor after min-max normalisation.

    Args:
        img: tensor laid out as (C, H, W); it is transposed to (H, W, C)
            for ``plt.imshow``.
        c_path: optional file path; when given, the figure is also saved
            there (the parent directory is created if needed).
    """
    img = img - img.min()  # Shift to [0, max]
    peak = img.max()
    if peak > 0:  # a constant image would otherwise become 0/0 = NaN
        img = img / peak  # Normalize to [0, 1]
    npimg = img.detach().cpu().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')
    if c_path is not None:
        out_dir = os.path.dirname(c_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        # Save before plt.show(), which may clear the current figure.
        plt.savefig(c_path, bbox_inches="tight")
    plt.show()
def show_mask(img, c_path=None):
    """Display a single-channel mask as an 'inferno' heat map with a colour bar.

    Args:
        img: mask tensor; singleton dimensions are squeezed away before
            plotting.
        c_path: optional file path; when given, the figure is also saved
            there (the parent directory is created if needed).
    """
    img = img - img.min()  # Shift to [0, max]
    peak = img.max()
    if peak > 0:  # a constant mask would otherwise become 0/0 = NaN
        img = img / peak  # Normalize to [0, 1]

    # Detach tensor from computation graph and move to CPU
    npimg = img.detach().cpu().squeeze().numpy()  # Squeeze is used to remove any singleton dimensions

    plt.imshow(npimg, interpolation='nearest', cmap='inferno')
    plt.colorbar()
    if c_path is not None:
        out_dir = os.path.dirname(c_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        # Save before plt.show(), which may clear the current figure.
        plt.savefig(c_path, bbox_inches="tight")
    plt.show()
def show_with_masks(image, masks, mask_names=None, alpha=0.5, c_path=None):
    """
    Overlay masks on an image.

    Args:
    - image: PyTorch tensor of shape (C, H, W)
    - masks: List of PyTorch tensors, each of shape (H, W)
    - mask_names: List of names for each mask for the legend
    - alpha: Transparency level for masks
    - c_path: Optional file path; when given, the figure is also saved there
      (the parent directory is created if needed)
    """
    # Min-max normalise the image; a constant image is left at zero instead
    # of turning into 0/0 = NaN.
    image = image - image.min()
    image_peak = image.max()
    if image_peak > 0:
        image = image / image_peak
    npimg = image.detach().cpu().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')

    colors = ['red', 'green', 'blue', 'yellow', 'purple', 'cyan']  # cycled when there are more masks

    for idx, mask in enumerate(masks):
        mask = mask - mask.min()
        mask_peak = mask.max()
        if mask_peak > 0:  # an all-equal mask stays at zero instead of NaN
            mask = mask / mask_peak
        npmask = mask.detach().cpu().numpy()

        # Tint the mask: scale the (H, W) mask by the RGB triple of its colour
        # to obtain an (H, W, 3) overlay.
        rgb = np.asarray(matplotlib.colors.to_rgb(colors[idx % len(colors)]))
        mask_colored = npmask[..., None] * rgb

        plt.imshow(mask_colored, interpolation='nearest', alpha=alpha)

    if mask_names:
        patches = [plt.Rectangle((0, 0), 1, 1, color=colors[i % len(colors)]) for i in range(len(masks))]
        plt.legend(patches, mask_names, loc='upper left')
    if c_path is not None:
        out_dir = os.path.dirname(c_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        # Save before plt.show(), which may clear the current figure.
        plt.savefig(c_path, bbox_inches="tight")
    plt.show()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoGFilter(nn.Module):
    """Difference-of-Gaussians image filter with a fixed 1:2 sigma ratio.

    Two normalised Gaussian kernels (``sigma1`` and ``2 * sigma1``) are built
    at construction time; ``forward`` returns the narrow blur minus the wide
    blur, computed independently per channel.

    Args:
        channels: number of image channels (used as the convolution groups).
        sigma1: standard deviation of the narrow Gaussian.
    """

    def __init__(self, channels, sigma1):
        super().__init__()
        self.channels = channels
        self.sigma1 = sigma1
        self.sigma2 = 2 * sigma1  # wide Gaussian: twice the narrow sigma

        # Odd kernel sizes covering about three sigmas on each side, with the
        # padding that keeps the spatial size unchanged.
        self.kernel_size1 = int(2 * round(3 * self.sigma1) + 1)
        self.kernel_size2 = int(2 * round(3 * self.sigma2) + 1)
        self.padding1 = (self.kernel_size1 - 1) // 2
        self.padding2 = (self.kernel_size2 - 1) // 2

        self.weight1 = self.get_gaussian_kernel(self.kernel_size1, self.sigma1)
        self.weight2 = self.get_gaussian_kernel(self.kernel_size2, self.sigma2)

    def get_gaussian_kernel(self, kernel_size, sigma):
        """Return a normalised Gaussian kernel of shape (channels, 1, k, k)."""
        center = (kernel_size - 1) / 2.
        variance = sigma**2.

        # Squared distance of every kernel cell to the kernel centre.
        sq_offsets = (torch.arange(kernel_size).float() - center).pow(2)
        sq_dist = sq_offsets.unsqueeze(0) + sq_offsets.unsqueeze(1)

        kernel = torch.exp(-sq_dist / (2 * variance))
        kernel = kernel / kernel.sum()  # weights sum to one
        return kernel.repeat(self.channels, 1, 1, 1)

    @torch.no_grad()
    def forward(self, x):
        """Return blur(sigma1) - blur(2 * sigma1) of ``x``, same spatial size as ``x``."""
        narrow_kernel = self.weight1.to(x.device)
        wide_kernel = self.weight2.to(x.device)
        narrow_blur = F.conv2d(x, narrow_kernel, bias=None, stride=1, padding=self.padding1, groups=self.channels)
        wide_blur = F.conv2d(x, wide_kernel, bias=None, stride=1, padding=self.padding2, groups=self.channels)
        return narrow_blur - wide_blur
def apply_dog_filter(batch, freq=50, scale_factor=0.5):
    """
    Build a binary frequency mask from a Difference-of-Gaussians response.

    Args:
        batch: torch.Tensor, shape (B, C, H, W). A 3-channel batch is averaged
               to a single grayscale channel first.
        freq: Control variable, intended range 0 to 100. It sets the narrow
              sigma of the DoG filter:
              - freq >= 50: sigma1 = 0.1 + (100 - freq) * 0.1, response kept as-is
              - freq <  50: sigma1 = 0.1 + freq * 0.1, normalised response inverted
        scale_factor: Factor by which the image is rescaled before applying DoG.

    Returns:
        torch.Tensor: float mask of shape (B, H, W) holding 0.0 / 1.0, taken
        from channel 0 of the thresholded (>= 0.5) normalised response.
    """
    keep_response = freq >= 50

    # Collapse colour images to grayscale; any other channel count is kept.
    gray = batch.mean(dim=1, keepdim=True) if batch.size(1) == 3 else batch

    # Filter on a rescaled copy of the image.
    small = F.interpolate(gray, scale_factor=scale_factor, mode='bilinear', align_corners=False)

    # Narrow sigma from the freq setting; DoGFilter uses 2 * sigma1 for the wide one.
    if keep_response:
        sigma1 = 0.1 + (100 - freq) * 0.1
    else:
        sigma1 = 0.1 + freq * 0.1

    response = DoGFilter(small.size(1), sigma1)(small)

    # Bring the response back to the input resolution and min-max normalise it.
    response = F.interpolate(response, size=gray.shape[-2:], mode='bilinear', align_corners=False)
    response = response - response.min()
    normalized = response / response.max()
    if not keep_response:
        normalized = 1.0 - normalized

    binary = (normalized >= 0.5).to(torch.float)
    return binary[:, 0, ...]
    
class LoGFilter(nn.Module):
    """Laplacian-of-Gaussian image filter applied independently per channel.

    Args:
        channels: number of image channels (used as the convolution groups).
        sigma: standard deviation of the underlying Gaussian.
        max_kernel_size: optional upper bound on the kernel size.
    """

    def __init__(self, channels, sigma, max_kernel_size=None):
        super().__init__()
        self.channels = channels
        self.sigma = sigma
        # Odd kernel size covering about three sigmas on each side.
        self.kernel_size = int(2 * round(3 * sigma) + 1)
        if max_kernel_size:
            self.kernel_size = min(self.kernel_size, max_kernel_size)
        # Fixed: this used the undefined bare name `kernel_size` (NameError).
        self.padding = (self.kernel_size - 1) // 2
        self.weight = self.get_log_kernel()

    def get_log_kernel(self):
        """Return a zero-mean LoG kernel of shape (channels, 1, k, k)."""
        import math

        x_cord = torch.arange(self.kernel_size)
        x_grid = x_cord.repeat(self.kernel_size).view(self.kernel_size, self.kernel_size)
        y_grid = x_grid.t()
        xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()

        mean = (self.kernel_size - 1) / 2.
        variance = self.sigma**2.

        # Squared distance of every kernel cell to the kernel centre.
        sq_dist = (xy_grid - mean).pow(2).sum(dim=-1)

        # Laplacian of Gaussian:
        #   -1 / (pi * sigma^4) * (1 - r^2 / (2 sigma^2)) * exp(-r^2 / (2 sigma^2))
        # sigma^4 is variance**2 (the previous variance**4 was sigma^8).
        kernel = (-1 / (math.pi * variance**2)) * (1 - sq_dist / (2 * variance))
        kernel = kernel * torch.exp(-sq_dist / (2 * variance))

        # Remove the mean so that flat image regions give a zero response.
        kernel = kernel - kernel.mean()
        kernel = kernel.repeat(self.channels, 1, 1, 1)

        return kernel

    @torch.no_grad()
    def forward(self, x):
        """Convolve ``x`` with the LoG kernel; the spatial size is preserved for odd kernels."""
        # Move the kernel to the input's device, as DoGFilter does.
        return F.conv2d(x, self.weight.to(x.device), bias=None, stride=1, padding=self.padding, groups=self.channels)

# Global cache of LoG kernels, keyed by (channels, kernel_size, sigma)
kernels = {}

def apply_log_filter(batch, sigma=1.0):
    """
    Apply a Laplacian of Gaussian filter to a batch of images to highlight high-frequency areas.
    Args:
        batch: torch.Tensor, shape (B, C, H, W). A 3-channel batch is averaged
               to a single grayscale channel first.
        sigma: control variable that determines the frequency band of the highlighted region.
    Returns:
        torch.Tensor: mask of shape (B, H, W), min-max normalised to [0, 1]
        and taken from channel 0 of the filter response.
    """
    # Convert to grayscale if it's a color image
    if batch.size(1) == 3:
        batch = torch.mean(batch, dim=1, keepdim=True)

    channels = batch.size(1)
    kernel_size = int(2 * round(3 * sigma) + 1)  # Ensure kernel size is odd

    # Reuse a cached kernel when one exists for this configuration.
    cache_key = (channels, kernel_size, sigma)
    weight = kernels.get(cache_key)
    if weight is None:
        # LoGFilter's signature is (channels, sigma, max_kernel_size=None);
        # the previous call passed kernel_size as sigma and sigma as the size cap.
        weight = LoGFilter(channels, sigma).weight
        kernels[cache_key] = weight

    padding = (kernel_size - 1) // 2
    # The cached kernel may live on another device than the batch.
    mask = F.conv2d(batch, weight.to(batch.device), bias=None, stride=1, padding=padding, groups=channels)

    # Normalize the mask to [0, 1]; a flat response stays at zero instead of NaN.
    mask = mask - mask.min()
    peak = mask.max()
    if peak > 0:
        mask = mask / peak

    return mask[:, 0, ...]

In [None]:
# Example usage: sweep the `freq` control of apply_dog_filter over 5%..100%
# and save an overlay of the resulting mask for every step.
import torchvision

scale_factor = 0.2
# Use the GPU when one is available instead of failing on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"
# batch_images = torch.randn(16, 3, 224, 224)  # Example batch of images
batch_images = torchvision.io.read_image(os.path.join(".","assets","example.png"))[None,...].to(torch.float).to(device)
tensor_info(batch_images)

# The overlays below are written into this directory; make sure it exists.
os.makedirs("output", exist_ok=True)

# mask = apply_log_filter(batch_images, sigma = sigma)
for freq in range(5,101,5):
    mask = apply_dog_filter(batch_images, freq = freq,scale_factor=scale_factor)
    tensor_info(mask)
    show(batch_images[0])
    show_mask(mask[0])
    show_with_masks(batch_images[0], [mask[0]], mask_names=["freq:{}%".format(int(freq))], alpha=0.5,c_path=os.path.join("output",f"mask_{freq}.png"))


## Studying the Fourier domain of different signals

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Time-domain signal definitions
def get_signal(x, signal_type, width=2):
    """Sample one of the named time-domain test signals at positions ``x``.

    Every signal except ``"gaussian"`` is zero outside the open interval
    ``(-width/2, width/2)``; the Gaussian is not truncated and uses a
    standard deviation of ``width / 3``.

    Args:
        x: array of sample positions.
        signal_type: one of ``"square"``, ``"triangle"``, ``"parabolic"``,
            ``"half_sinusoid"``, ``"exponential"`` or ``"gaussian"``.
        width: support width of the compact signals.

    Returns:
        Array of signal values with the same shape as ``x``.

    Raises:
        ValueError: if ``signal_type`` is not one of the supported names
            (previously this silently returned ``None``).
    """
    if signal_type == "gaussian":
        return np.exp(-x**2 / (2 * (width/3)**2))

    half_width = width / 2
    # Shape of each compact signal inside its support; evaluated lazily so
    # only the requested one is computed.
    inside_values = {
        "square": lambda: np.ones_like(x),
        "triangle": lambda: half_width - np.abs(x),
        "parabolic": lambda: half_width**2 - x**2,
        "half_sinusoid": lambda: np.sin((x + half_width) * (np.pi / width)),
        "exponential": lambda: np.exp(-np.abs(x)),
    }
    if signal_type not in inside_values:
        supported = sorted(inside_values) + ["gaussian"]
        raise ValueError(f"Unknown signal_type {signal_type!r}; expected one of {supported}")

    in_support = (x > -half_width) & (x < half_width)
    return np.where(in_support, inside_values[signal_type](), np.zeros_like(x))

# Fourier transform definitions
def fourier_square_wave(x, width):
    """Closed-form spectrum used for the "square" signal.

    Evaluates ``np.sinc(x * width / pi)``, i.e. ``sin(x * width) / (x * width)``
    with the value 1 at ``x == 0``.

    NOTE(review): amplitude and frequency scaling relative to the textbook
    transform of a rectangular pulse have not been verified here.
    """
    sinc_argument = x * width / np.pi
    return np.sinc(sinc_argument)

def fourier_triangle_wave(x, width):
    """Closed-form spectrum used for the "triangle" signal.

    Evaluates ``np.sinc(x * width / (2 * pi)) ** 2``, i.e. the square of
    ``sin(x * width / 2) / (x * width / 2)``, which equals 1 at ``x == 0``.

    NOTE(review): scaling convention relative to the textbook transform of a
    triangular pulse has not been verified here.
    """
    main_lobe = np.sinc(x * width / (2 * np.pi))
    return main_lobe ** 2

def fourier_parabolic_wave(x, width):
    """Closed-form spectrum used for the "parabolic" signal.

    Evaluates ``3 * np.sinc(x * width / (2 * pi)) ** 2 / (pi ** 2 * x ** 2)``.

    NOTE(review): ``x == 0`` is not special-cased, so the division by
    ``x ** 2`` produces a non-finite value there; callers should avoid
    sampling exactly at zero. Whether this expression matches the analytic
    transform of the parabolic pulse has not been verified.
    """
    squared_sinc = (np.sinc(x * width / (2 * np.pi))) ** 2
    denominator = np.pi**2 * x**2
    return (3 * squared_sinc) / denominator

def fourier_half_sinusoid(x, width):
    """Closed-form spectrum used for the "half_sinusoid" signal.

    Returns ``width / 2`` where ``x == 0`` and
    ``width * sin(pi * x * width) / (pi ** 2 * x ** 2)`` everywhere else.

    The division is performed on a copy of ``x`` whose zeros are replaced by
    1, so evaluating at ``x == 0`` no longer raises divide-by-zero / invalid
    floating-point warnings; the returned values are unchanged.

    NOTE(review): the off-origin expression grows like ``1 / x`` as ``x``
    approaches 0, so the hard-coded ``width / 2`` is not its limit; confirm
    the intended formula.
    """
    x = np.asarray(x)
    at_origin = x == 0
    # Placeholder denominator at the origin; those entries are overwritten below.
    safe_x = np.where(at_origin, 1, x)
    off_origin = width * np.sin(np.pi * safe_x * width) / (np.pi**2 * safe_x**2)
    return np.where(at_origin, width / 2, off_origin)

def fourier_exponential(x, width):
    """Closed-form spectrum used for the "exponential" signal.

    Evaluates the Lorentzian-shaped curve ``width / (x ** 2 + (width / 2) ** 2)``,
    whose peak value at ``x == 0`` is ``4 / width``.
    """
    half_width = width / 2
    denominator = x**2 + half_width**2
    return width / denominator

def fourier_gaussian(x, width):
    """Closed-form spectrum used for the "gaussian" signal.

    With ``sigma = width / 3`` this evaluates
    ``sqrt(2 * pi) * sigma * exp(-2 * pi ** 2 * sigma ** 2 * x ** 2)``,
    a Gaussian in ``x`` with peak ``sqrt(2 * pi) * sigma`` at ``x == 0``.
    """
    sigma = width / 3
    peak = np.sqrt(2 * np.pi) * sigma
    exponent = -2 * (np.pi**2) * sigma**2 * x**2
    return peak * np.exp(exponent)

In [None]:
# Signal types and their Fourier transform functions
signal_types = ["square", "triangle", "parabolic", "half_sinusoid", "exponential", "gaussian"]
# signal_types = ["square", "triangle","gaussian"]

# Maps each signal name to the function giving its closed-form spectrum.
fourier_functions = {
    "square": fourier_square_wave,
    "triangle": fourier_triangle_wave,
    "parabolic": fourier_parabolic_wave,
    "half_sinusoid": fourier_half_sinusoid,
    "exponential": fourier_exponential,
    "gaussian": fourier_gaussian,
}

# Time and frequency domain range
x_time = np.linspace(-10, 10, 1000)
x_freq = np.linspace(-10, 10, 1000)
width = 2

# Create the plots: one row per signal, time domain on the left and
# frequency domain on the right.
plt.figure(figsize=(20, 12))
# plt.figure(figsize=(20, 6))


for i, signal_type in enumerate(signal_types):
    # Time-domain signal
    signal = get_signal(x_time, signal_type, width)

    # Fourier transform
    fourier_transform = fourier_functions[signal_type](x_freq, width)

    # (x values, y values, legend label, title, x-axis label, y-axis label)
    panels = (
        (x_time, signal, f"{signal_type} Signal",
         f"{signal_type.capitalize()} Signal", "Time", "Amplitude"),
        (x_freq, fourier_transform, f"Fourier of {signal_type}",
         f"Fourier Transform of {signal_type.capitalize()}", "Frequency", "Magnitude"),
    )
    for column, (xs, ys, label, title, xlabel, ylabel) in enumerate(panels, start=1):
        plt.subplot(len(signal_types), 2, 2*i + column)
        plt.plot(xs, ys, label=label, linewidth=2)
        plt.title(title, fontsize=18)
        plt.xlabel(xlabel, fontsize=16)
        plt.ylabel(ylabel, fontsize=16)
        plt.grid(True, which="both", ls="--")
        # plt.legend(fontsize=16)
        plt.tick_params(axis='both', which='major', labelsize=15)
        plt.tick_params(axis='both', which='minor', labelsize=15)

plt.tight_layout()
# savefig does not create missing directories, so make sure the target exists.
os.makedirs("output", exist_ok=True)
plt.savefig("output/part_signals_and_fourier_transforms.pdf")
plt.show()