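# Init

Setup: working directory, imports, and Matplotlib styling (serif Times New Roman text rendered with LaTeX).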

In [None]:
import os
import re
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np

# The plotting functions below read 'alignment/logs_*' and write 'output/' relative
# to the working directory, so work from the directory that contains 'alignment/'.
# Guarded (instead of an unconditional `%cd ..`) so that re-running this cell does
# not keep climbing the directory tree.
if not os.path.isdir('alignment'):
    os.chdir('..')
print(f"Working directory: {os.getcwd()}")

# Figure styling: serif Times New Roman text rendered through LaTeX
# (text.usetex needs a working LaTeX installation on the machine).
plt.rcParams.update({
    "font.family": "serif",
    "font.serif": ["Times New Roman"],
    "text.usetex": True,
    'font.size': 13,
})

In [None]:
def plot_psnr_analysis(resolution, snr=30, seeds=(42,), types=("linear", "mlp", "conv", "twoconv", "zeroshot")):
    """Plot PSNR versus the number of semantic pilots for several aligner types.

    For every seed, reads the baseline file
    ``alignment/logs_{resolution}/lines_snr_{snr}_seed_{seed}.txt``. Its first three
    lines hold, in their second whitespace-separated field, the PSNR plotted as
    "Unaligned", "No mismatch" and "PFE Full". It then reads one log per aligner type,
    ``aligner_{type}_snr_{snr}_seed_{seed}.txt``, collecting every
    ``"<N> samples got a PSNR of <X>"`` entry. Seeds whose baseline file is missing
    (or has fewer than three lines) are skipped entirely.

    Each aligner curve is the across-seed mean PSNR per sample count with a
    +/- one (population) standard-deviation band on a log x axis. The baselines are
    drawn at their across-seed means. The figure is saved to
    ``output/psnr_vs_pilots_snr_{snr}.pdf`` and ``.png`` and then shown.

    Args:
        resolution: Resolution used in the log directory name.
        snr: SNR value selecting the log files.
        seeds: Seeds to average over.
        types: Aligner types to plot (keys of ``type_name_dict`` below).
    """
    type_name_dict = {
        "linear": "Linear",
        "mlp": "MLP",
        "conv": "1-layer CNN",
        "twoconv": "2-layer CNN",
        "zeroshot": "PFE"
    }

    style_config = {
        "Linear": {"color": "tab:blue", "marker": "o", "linestyle": "-"},
        "MLP": {"color": "tab:orange", "marker": "o", "linestyle": "-"},
        "1-layer CNN": {"color": "tab:green", "marker": "o", "linestyle": "-"},
        "2-layer CNN": {"color": "tab:red", "marker": "o", "linestyle": "-"},
        "PFE": {"color": "tab:purple", "marker": "o", "linestyle": "-"},
        "Unaligned": {"color": "red", "linestyle": "--"},
        "No mismatch": {"color": "tab:cyan", "linestyle": "--"},
        "PFE Full": {"color": "gold", "linestyle": "-"}
    }

    log_dir = f'alignment/logs_{resolution}'
    aligner_data = defaultdict(lambda: defaultdict(list))  # {aligner type: {sample count: [psnr, ...]}}
    unaligned_values = []
    aligned_values = []
    zeroshot_values = []

    for seed in seeds:
        lines_path = f'{log_dir}/lines_snr_{snr}_seed_{seed}.txt'
        if not os.path.exists(lines_path):
            print(f"Warning: {lines_path} not found, skipping seed {seed}")
            continue
        with open(lines_path, 'r') as file:
            lines = file.readlines()
        if len(lines) < 3:
            print(f"Warning: {lines_path} has fewer than 3 lines, skipping seed {seed}")
            continue
        unaligned, aligned, zeroshot = (float(line.split()[1]) for line in lines[:3])
        unaligned_values.append(unaligned)
        aligned_values.append(aligned)
        zeroshot_values.append(zeroshot)

        for aligner_type in types:
            log_path = f'{log_dir}/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt'
            if not os.path.exists(log_path):
                print(f"Warning: {log_path} not found")
                continue
            with open(log_path, 'r', encoding='utf-8') as file:
                content = file.read()
            # \d+(?:\.\d*)? rather than [\d.]+ so a sentence-ending period is not captured.
            for sample_str, psnr_str in re.findall(r'(\d+) samples got a PSNR of (\d+(?:\.\d*)?)', content):
                aligner_data[aligner_type][int(sample_str)].append(float(psnr_str))

    # Baseline means are only used when the corresponding list is non-empty.
    unaligned_mean = np.mean(unaligned_values) if unaligned_values else 0
    aligned_mean = np.mean(aligned_values) if aligned_values else 0
    zeroshot_mean = np.mean(zeroshot_values) if zeroshot_values else 0

    fig, ax = plt.subplots(figsize=(9, 6))

    # One mean curve plus a +/- std band per aligner type.
    for aligner_type in types:
        if aligner_type not in aligner_data:
            continue
        counts = sorted(aligner_data[aligner_type])
        psnrs_per_count = [aligner_data[aligner_type][count] for count in counts]
        sample_counts = np.array(counts)
        mean_psnr = np.array([np.mean(values) for values in psnrs_per_count])
        std_psnr = np.array([np.std(values) for values in psnrs_per_count])

        name = type_name_dict[aligner_type]
        style = style_config[name]
        ax.plot(sample_counts, mean_psnr,
                marker=style["marker"],
                label=name,
                linewidth=3,
                color=style["color"],
                linestyle=style["linestyle"])
        ax.fill_between(sample_counts,
                        mean_psnr - std_psnr,
                        mean_psnr + std_psnr,
                        alpha=0.2,
                        color=style["color"])

    if zeroshot_values:
        style = style_config["PFE Full"]
        # Markers at ~20 log-spaced pilot counts in [1, 10000]; the int cast can create
        # duplicates at the low end, hence np.unique.
        n_points = 20
        pilots_sets = np.unique(np.logspace(0, np.log10(10000), num=n_points, base=10).astype(int))
        y_values = np.full_like(pilots_sets, zeroshot_mean, dtype=float)
        # The `linestyle` kwarg overrides the marker-only 'o' format, so this single call
        # draws both the circles and the connecting horizontal line over [1, 10000].
        # (A former plt.axhline(xmin=1, xmax=10000) call was removed: axhline's xmin/xmax
        # are axes fractions in [0, 1], so it was drawn entirely outside the axes.)
        ax.plot(pilots_sets, y_values, 'o', color=style["color"], linestyle=style["linestyle"],
                linewidth=3, label='PFE Full')

    if unaligned_values:
        style = style_config["Unaligned"]
        ax.axhline(y=unaligned_mean, color=style["color"],
                   linestyle=style["linestyle"], linewidth=2,
                   label='Unaligned')
    if aligned_values:
        style = style_config["No mismatch"]
        ax.axhline(y=aligned_mean, color=style["color"],
                   linestyle=style["linestyle"], linewidth=2,
                   label='No mismatch')

    # ax.set_title(f"SNR {snr} - Resolution {resolution}x{resolution} - Mean across {len(seeds)} seeds")
    ax.set_xlabel("Number of Semantic Pilots", fontsize=18)
    ax.set_ylabel("PSNR (dB)", fontsize=18)
    ax.set_xscale('log')
    ax.tick_params(axis='both', labelsize=18)
    ax.legend(framealpha=1)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    os.makedirs('output', exist_ok=True)
    filename = f"psnr_vs_pilots_snr_{snr}"
    pdf_path = os.path.join('output', f"{filename}.pdf")
    png_path = os.path.join('output', f"{filename}.png")
    fig.savefig(pdf_path, format='pdf', dpi=300, bbox_inches='tight')
    fig.savefig(png_path, format='png', dpi=300, bbox_inches='tight')

    print("Plots saved as:")
    print(f"  PDF: {pdf_path}")
    print(f"  PNG: {png_path}")

    plt.show()

In [None]:
def _snr_sweep_psnrs_for_count(content, target_sample_count):
    """Extract the PSNR values of one aligner log that belong to a single sample count.

    Entries have the form ``"<N> samples got a PSNR of <X>"``, where the
    ``"<N> samples "`` prefix is optional. If no entry carries a sample count, there is
    nothing to filter on and every PSNR is returned. Otherwise only entries logged for
    ``target_sample_count`` samples are kept. When ``target_sample_count`` is falsy
    (``None``/0), entries for the largest logged count are kept instead.
    """
    entries = [
        (int(count_str) if count_str else None, float(psnr_str))
        for count_str, psnr_str in re.findall(r'(?:(\d+) samples )?got a PSNR of (\d+(?:\.\d*)?)', content)
    ]
    logged_counts = [count for count, _ in entries if count is not None]
    if not logged_counts:
        return [psnr for _, psnr in entries]
    wanted = target_sample_count if target_sample_count else max(logged_counts)
    return [psnr for count, psnr in entries if count == wanted]


def _snr_sweep_mean_std(values_by_snr):
    """Return ``(snrs, means, stds)`` arrays sorted by SNR, skipping SNRs with no values."""
    snrs = sorted(snr for snr, values in values_by_snr.items() if values)
    means = np.array([np.mean(values_by_snr[snr]) for snr in snrs])
    stds = np.array([np.std(values_by_snr[snr]) for snr in snrs])
    return np.array(snrs), means, stds


def plot_psnr_vs_snr_analysis(resolution, snr_ae=-10, snr_values=(-20, -10, 0, 10, 20, 30), seeds=(42, 43, 44, 45, 46),
                              types=("linear", "mlp", "conv", "twoconv", "zeroshot"),
                              target_sample_count=10000, noise_types=("AWGN", "Rayleigh")):
    """Plot PSNR versus SNR for several aligners, one line style per noise type.

    For each noise type, SNR in ``snr_values`` and seed, reads the baseline file
    ``alignment/logs_{resolution}_{noise_type}/lines_ae_{snr_ae}_snr_{snr}_seed_{seed}.txt``.
    Its first three lines hold, in their second whitespace-separated field, the
    unaligned, "No mismatch" and "PFE Full" PSNR. It also reads, per aligner type,
    ``aligner_{type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.txt``. When a baseline file is
    missing, that (SNR, seed) is skipped for all aligners. Only aligner PSNRs logged for
    ``target_sample_count`` samples are used; logs whose entries carry no sample count
    contribute all their PSNRs.

    Every curve is the across-seed mean per SNR with a +/- one (population)
    standard-deviation band. AWGN curves are dashed and Rayleigh curves solid. The
    "No mismatch" and "PFE Full" baselines are plotted as curves; the unaligned
    baseline is read but not drawn. Two legends are added: methods and noise types.
    The figure is saved to
    ``output/psnr_vs_snr_{snr_ae}_{"_".join(noise_types)}{samples suffix}.pdf/.png``
    and then shown.

    Args:
        resolution: Resolution used in the log directory name.
        snr_ae: Value in the ``ae_{snr_ae}`` part of the log names. 20 places both
            legends top-left; any other value places them bottom-right.
        snr_values: SNRs (x axis) to read.
        seeds: Seeds averaged per SNR.
        types: Aligner types to plot (keys of ``type_name_dict`` below).
        target_sample_count: Sample count whose aligner PSNRs are plotted. If falsy,
            the largest count in each log is used. Also encoded in the file name.
        noise_types: Log directory suffixes; each must be "AWGN" or "Rayleigh".
    """
    from matplotlib.lines import Line2D

    type_name_dict = {
        "linear": "Linear",
        "mlp": "MLP",
        "conv": "1-layer CNN",
        "twoconv": "2-layer CNN",
        "zeroshot": "PFE"
    }

    style_config = {
        "Linear": {"color": "tab:blue", "marker": "o", "linestyle": "-"},
        "MLP": {"color": "tab:orange", "marker": "o", "linestyle": "-"},
        "1-layer CNN": {"color": "tab:green", "marker": "o", "linestyle": "-"},
        "2-layer CNN": {"color": "tab:red", "marker": "o", "linestyle": "-"},
        "PFE": {"color": "tab:purple", "marker": "o", "linestyle": "-"},
        "Unaligned": {"color": "red", "linestyle": "--"},
        "No mismatch": {"color": "tab:cyan", "linestyle": "--"},
        "PFE Full": {"color": "gold", "linestyle": "-"}
    }

    noise_linestyles = {
        "AWGN": "--",      # dashed
        "Rayleigh": "-"    # solid
    }
    noise_legend_names = {
        "AWGN": "Without Fading",
        "Rayleigh": "With Fading",
    }
    # Baselines drawn as curves; 'unaligned' is collected but intentionally not plotted.
    baseline_names = {
        'aligned': 'No mismatch',
        'zeroshot_baseline': 'PFE Full'
    }

    fig, ax = plt.subplots(figsize=(9, 6))

    # Legend handles/labels per noise type (one method legend is built from one of them).
    legend_handles = {noise_type: [] for noise_type in noise_types}
    legend_labels = {noise_type: [] for noise_type in noise_types}

    for noise_type in noise_types:
        log_dir = f'alignment/logs_{resolution}_{noise_type}'
        linestyle = noise_linestyles[noise_type]
        aligner_data = defaultdict(lambda: defaultdict(list))  # {aligner type: {snr: [psnr, ...]}}
        baseline_data = {key: defaultdict(list) for key in ('unaligned', 'aligned', 'zeroshot_baseline')}

        for snr in snr_values:
            for seed in seeds:
                lines_path = f'{log_dir}/lines_ae_{snr_ae}_snr_{snr}_seed_{seed}.txt'
                if not os.path.exists(lines_path):
                    print(f"Warning: {lines_path} not found, skipping {noise_type} SNR {snr}, seed {seed}")
                    continue
                with open(lines_path, 'r') as file:
                    lines = file.readlines()
                if len(lines) < 3:
                    print(f"Warning: {lines_path} has fewer than 3 lines, skipping {noise_type} SNR {snr}, seed {seed}")
                    continue
                unaligned, aligned, zeroshot = (float(line.split()[1]) for line in lines[:3])
                baseline_data['unaligned'][snr].append(unaligned)
                baseline_data['aligned'][snr].append(aligned)
                baseline_data['zeroshot_baseline'][snr].append(zeroshot)

                for aligner_type in types:
                    log_path = f'{log_dir}/aligner_{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.txt'
                    if not os.path.exists(log_path):
                        print(f"Warning: {log_path} not found")
                        continue
                    with open(log_path, 'r', encoding='utf-8') as file:
                        psnrs = _snr_sweep_psnrs_for_count(file.read(), target_sample_count)
                    if psnrs:
                        aligner_data[aligner_type][snr].extend(psnrs)
                    else:
                        print(f"Warning: no PSNR entries for the requested sample count in {log_path}")

        for baseline_key, baseline_label in baseline_names.items():
            snr_array, mean_psnr, std_psnr = _snr_sweep_mean_std(baseline_data[baseline_key])
            if snr_array.size == 0:
                continue
            style = style_config[baseline_label]
            line = ax.plot(snr_array, mean_psnr,
                           color=style["color"],
                           linestyle=linestyle,
                           linewidth=3,
                           marker='o', markersize=8)[0]
            ax.fill_between(snr_array,
                            mean_psnr - std_psnr,
                            mean_psnr + std_psnr,
                            alpha=0.1,
                            color=style["color"])
            legend_handles[noise_type].append(line)
            legend_labels[noise_type].append(baseline_label)

        for aligner_type in types:
            if aligner_type not in aligner_data:
                continue
            snr_array, mean_psnr, std_psnr = _snr_sweep_mean_std(aligner_data[aligner_type])
            if snr_array.size == 0:
                continue
            label = type_name_dict[aligner_type]
            style = style_config[label]
            line = ax.plot(snr_array, mean_psnr,
                           marker=style["marker"], markersize=8,
                           linestyle=linestyle,
                           linewidth=3,
                           color=style["color"])[0]
            ax.fill_between(snr_array,
                            mean_psnr - std_psnr,
                            mean_psnr + std_psnr,
                            alpha=0.2,
                            color=style["color"])
            legend_handles[noise_type].append(line)
            legend_labels[noise_type].append(label)

    # Legend placement depends only on snr_ae.
    if snr_ae == 20:
        method_loc, method_anchor = "upper left", (0, 1)
        noise_loc, noise_anchor = "upper left", (0.25, 1)
    else:
        method_loc, method_anchor = "lower right", (1, 0)
        noise_loc, noise_anchor = "lower right", (0.75, 0)

    # Method legend: prefer the AWGN handles; otherwise fall back to the first noise
    # type that produced any curve, so the legend does not vanish without AWGN.
    if legend_handles.get("AWGN"):
        legend_source = "AWGN"
    else:
        legend_source = next((nt for nt in noise_types if legend_handles[nt]), None)
    if legend_source is not None:
        method_legend = ax.legend(
            legend_handles[legend_source],
            legend_labels[legend_source],
            loc=method_loc,
            bbox_to_anchor=method_anchor,
            framealpha=1
        )
        # A second ax.legend() call replaces the first unless it is re-added as an artist.
        ax.add_artist(method_legend)

    # Noise legend: one black proxy line per plotted noise type.
    noise_legend_elements = [
        Line2D([0], [0], color='black', linestyle=noise_linestyles[noise_type], linewidth=2,
               label=noise_legend_names[noise_type])
        for noise_type in noise_types
    ]
    ax.legend(
        handles=noise_legend_elements,
        loc=noise_loc,
        bbox_to_anchor=noise_anchor,
        framealpha=1
    )

    ax.set_xlabel("SNR (dB)", fontsize=18)
    ax.set_ylabel("PSNR (dB)", fontsize=18)
    ax.tick_params(axis='both', labelsize=18)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    os.makedirs('output', exist_ok=True)
    sample_suffix = f"_samples_{target_sample_count}" if target_sample_count else "_max_samples"
    noise_suffix = "_".join(noise_types)
    filename = f"psnr_vs_snr_{snr_ae}_{noise_suffix}{sample_suffix}"
    pdf_path = os.path.join('output', f"{filename}.pdf")
    png_path = os.path.join('output', f"{filename}.png")
    fig.savefig(pdf_path, format='pdf', dpi=300, bbox_inches='tight')
    fig.savefig(png_path, format='png', dpi=300, bbox_inches='tight')

    print("Combined plots saved as:")
    print(f"  PDF: {pdf_path}")
    print(f"  PNG: {png_path}")

    plt.show()

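# Exec

Generate the figures. First, PSNR vs. number of semantic pilots at SNR 20 dB and -10 dB (resolution 96, seeds 42–46). Then, PSNR vs. SNR sweeps over the AWGN and Rayleigh logs for `snr_ae` = -10 and 20 (seed 42). Each figure is saved to `output/` as PDF and PNG.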

In [None]:
# PSNR vs. pilots at SNR 20, averaged over seeds 42..46.
plot_psnr_analysis(
    96,
    snr=20,
    seeds=list(range(42, 47)),
)

In [None]:
# PSNR vs. pilots at SNR -10, averaged over seeds 42..46.
plot_psnr_analysis(
    96,
    snr=-10,
    seeds=list(range(42, 47)),
)

In [None]:
# PSNR vs. SNR sweep (-20..30 dB in 10 dB steps) for snr_ae = -10, single seed.
plot_psnr_vs_snr_analysis(
    96,
    snr_ae=-10,
    snr_values=list(range(-20, 31, 10)),
    seeds=[42],
    types=["linear", "conv", "twoconv"],
    target_sample_count=10_000,
)

In [None]:
# PSNR vs. SNR sweep (-20..30 dB in 10 dB steps) for snr_ae = 20, single seed.
plot_psnr_vs_snr_analysis(
    96,
    snr_ae=20,
    snr_values=list(range(-20, 31, 10)),
    seeds=[42],
    types=["linear", "conv", "twoconv"],
    target_sample_count=10_000,
)