# Init: working directory, imports and global plot style

In [None]:
%cd ..

import matplotlib.pyplot as plt
import re
import numpy as np
from collections import defaultdict
import os

# Global figure style: LaTeX-rendered text in a Times serif face at 13 pt.
# Note: "text.usetex" requires a working LaTeX installation to render.
plt.rcParams["font.family"] = "serif"
plt.rcParams["font.serif"] = ["Times New Roman"]
plt.rcParams["text.usetex"] = True
plt.rcParams["font.size"] = 13


def plot_psnr_analysis(resolution, snr=30, seeds=(42,),
                       types=("linear", "mlp", "conv", "twoconv", "zeroshot")):
    """Plot PSNR against the number of semantic pilots for each aligner type.

    For every seed this reads two kinds of files under
    ``alignment/logs_{resolution}/``:

    * ``lines_snr_{snr}_seed_{seed}.txt`` -- its first three lines hold the
      unaligned, aligned ("No mismatch") and zero-shot ("PFE Full") baseline
      PSNRs as the second whitespace-separated token. If this file is missing
      the whole seed is skipped, including its aligner logs.
    * ``aligner_{type}_snr_{snr}_seed_{seed}.txt`` -- one per aligner type,
      containing entries of the form "<N> samples got a PSNR of <value>".

    Each aligner type becomes a curve of the mean PSNR per sample count
    (pooled over seeds), with a +/- one population-standard-deviation band.
    Each baseline becomes a dashed horizontal line at its mean across the
    seeds read. The x axis is logarithmic. The figure is saved as
    ``output/psnr_vs_pilots_snr_{snr}.pdf`` and ``.png``, then shown.

    Args:
        resolution: Image resolution; selects the ``logs_{resolution}`` folder.
        snr: Channel SNR used in the log file names and the output file name.
        seeds: Random seeds to aggregate over.
        types: Aligner types to plot, in legend order. Types without a display
            name in ``type_name_dict`` are labelled with their raw name.

    Raises:
        ValueError: If a baseline file has fewer than three lines.
    """
    type_name_dict = {
        "linear": "Linear",
        "mlp": "MLP",
        "conv": "1-layer CNN",
        "twoconv": "2-layer CNN",
        "zeroshot": "PFE",
    }
    # `\d*\.?\d+` rather than `[\d.]+` so that a sentence-ending period is not
    # captured as part of the number (float("25.3.") would raise).
    psnr_pattern = re.compile(r'(\d+) samples got a PSNR of (\d*\.?\d+)')
    log_dir = f'alignment/logs_{resolution}'

    # {aligner_type: {sample_count: [psnr values pooled over seeds]}}
    aligner_data = defaultdict(lambda: defaultdict(list))
    unaligned_values = []
    aligned_values = []
    zeroshot_values = []

    for seed in seeds:
        lines_path = f'{log_dir}/lines_snr_{snr}_seed_{seed}.txt'
        if not os.path.exists(lines_path):
            print(f"Warning: {lines_path} not found, skipping seed {seed}")
            continue

        with open(lines_path, 'r', encoding='utf-8') as file:
            lines = file.readlines()
        if len(lines) < 3:
            raise ValueError(
                f"{lines_path}: expected 3 baseline lines, found {len(lines)}")
        # Parse all three values before appending any, so a malformed line
        # cannot leave the three baseline lists with different lengths.
        unaligned, aligned, zeroshot = (float(line.split()[1]) for line in lines[:3])
        unaligned_values.append(unaligned)
        aligned_values.append(aligned)
        zeroshot_values.append(zeroshot)

        for aligner_type in types:
            log_path = f'{log_dir}/aligner_{aligner_type}_snr_{snr}_seed_{seed}.txt'
            if not os.path.exists(log_path):
                print(f"Warning: {log_path} not found")
                continue
            with open(log_path, 'r', encoding='utf-8') as file:
                content = file.read()
            for sample_str, psnr_str in psnr_pattern.findall(content):
                aligner_data[aligner_type][int(sample_str)].append(float(psnr_str))

    plt.figure(figsize=(8, 6))

    # One mean curve (with a std band) per aligner type that produced data.
    for aligner_type in types:
        if aligner_type not in aligner_data:
            continue
        per_count = aligner_data[aligner_type]
        sample_counts = sorted(per_count)
        mean_psnr = np.array([np.mean(per_count[n]) for n in sample_counts])
        std_psnr = np.array([np.std(per_count[n]) for n in sample_counts])
        sample_counts = np.array(sample_counts)

        line = plt.plot(sample_counts, mean_psnr, marker='o',
                        label=type_name_dict.get(aligner_type, aligner_type),
                        linewidth=3)[0]
        plt.fill_between(sample_counts,
                         mean_psnr - std_psnr,
                         mean_psnr + std_psnr,
                         alpha=0.2,
                         color=line.get_color())

    # Baselines do not depend on the number of pilots: horizontal lines at
    # their mean over the seeds that were read.
    baselines = (
        (unaligned_values, 'red', 'Unaligned'),
        (aligned_values, 'green', 'No mismatch'),
        (zeroshot_values, 'blue', 'PFE Full'),
    )
    for values, color, label in baselines:
        if values:
            plt.axhline(y=np.mean(values), color=color, linestyle='--',
                        linewidth=2, label=label)

    plt.xlabel("Number of Semantic Pilots", fontsize=18)
    plt.ylabel("PSNR (dB)", fontsize=18)
    plt.xscale('log')
    plt.tick_params(axis='both', labelsize=18)
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    os.makedirs('output', exist_ok=True)

    # NOTE: the file name does not include `resolution`, so runs at different
    # resolutions with the same SNR overwrite each other's output.
    filename = f"psnr_vs_pilots_snr_{snr}"
    pdf_path = os.path.join('output', f"{filename}.pdf")
    png_path = os.path.join('output', f"{filename}.png")

    plt.savefig(pdf_path, format='pdf', dpi=300, bbox_inches='tight')
    plt.savefig(png_path, format='png', dpi=300, bbox_inches='tight')

    print("Plots saved as:")
    print(f"  PDF: {pdf_path}")
    print(f"  PNG: {png_path}")

    plt.show()

def _psnr_values_for_sample_count(content, target_sample_count):
    """Return the PSNR values logged for one sample count in an aligner log.

    Entries look like "<N> samples got a PSNR of <value>". If
    ``target_sample_count`` is truthy, only entries whose N equals it are
    kept; otherwise only the entries with the largest N found in ``content``.
    Entries that lack an "<N> samples" prefix cannot be filtered and are
    always kept. Values are returned in log order; the list may be empty.
    """
    # `\d*\.?\d+` so a sentence-ending period is not captured as part of the
    # number. The sample-count group is optional; findall yields '' when absent.
    entries = [
        (int(count) if count else None, float(value))
        for count, value in re.findall(
            r'(?:(\d+) samples )?got a PSNR of (\d*\.?\d+)', content)
    ]
    counts = [count for count, _ in entries if count is not None]
    wanted = target_sample_count or (max(counts) if counts else None)
    return [value for count, value in entries if count is None or count == wanted]


def _plot_mean_with_std_band(values_by_x, fill_alpha, **plot_kwargs):
    """Plot the per-x mean of ``values_by_x`` with a +/- one std band.

    ``values_by_x`` maps an x value to the list of samples collected for it.
    x values are plotted in sorted order, and those with no samples are
    skipped. ``plot_kwargs`` are forwarded to ``plt.plot``; the band reuses
    the line's color so the two always match. Returns the plotted line, or
    None if there was nothing to plot.
    """
    xs = [x for x in sorted(values_by_x) if values_by_x[x]]
    if not xs:
        return None
    x_array = np.array(xs)
    mean = np.array([np.mean(values_by_x[x]) for x in xs])
    std = np.array([np.std(values_by_x[x]) for x in xs])
    line = plt.plot(x_array, mean, **plot_kwargs)[0]
    plt.fill_between(x_array, mean - std, mean + std,
                     alpha=fill_alpha, color=line.get_color())
    return line


def plot_psnr_vs_snr_analysis(resolution, snr_ae=-10, snr_values=(-20, -10, 0, 10, 20, 30),
                              seeds=(42, 43, 44, 45, 46),
                              types=("linear", "mlp", "conv", "twoconv", "zeroshot"),
                              target_sample_count=10000, noise_types=("AWGN", "Rayleigh")):
    """Plot PSNR against channel SNR for several noise types on a single figure.

    For each noise type, SNR value and seed this reads, under
    ``alignment/logs_{resolution}_{noise_type}/``:

    * ``lines_ae_{snr_ae}_snr_{snr}_seed_{seed}.txt`` -- its first three lines
      hold the unaligned, aligned and zero-shot baseline PSNRs as the second
      whitespace-separated token. If it is missing, that (SNR, seed) pair is
      skipped, including its aligner logs.
    * ``aligner_{type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.txt`` -- entries
      "<N> samples got a PSNR of <value>"; only the entries for
      ``target_sample_count`` are used.

    Each curve shows the mean PSNR across seeds per SNR with a +/- one std
    band. Baselines are dashed and aligners are solid. Noise types are
    distinguished by marker (AWGN circles, Rayleigh triangles). The unaligned
    baseline is read but not plotted. The figure is saved as PDF and PNG under
    ``output/`` and then shown.

    Args:
        resolution: Image resolution; part of the log folder name.
        snr_ae: SNR for the autoencoder; part of the log and output file names.
        snr_values: Channel SNR values to analyze (x axis).
        seeds: Random seeds to aggregate over.
        types: Aligner types to plot, in legend order. Types without a display
            name are labelled with their raw name.
        target_sample_count: Sample count whose PSNR entries are taken from
            each aligner log. If falsy, the largest sample count in each log
            is used. Also tags the output file name.
        noise_types: Noise types to plot (e.g. ("AWGN", "Rayleigh")).

    Raises:
        ValueError: If a baseline file has fewer than three lines.
    """
    type_name_dict = {
        "linear": "Linear",
        "mlp": "MLP",
        "conv": "1-layer CNN",
        "twoconv": "2-layer CNN",
        "zeroshot": "PFE",
    }

    # Markers for different noise types
    noise_markers = {
        "AWGN": "o",      # circles
        "Rayleigh": "^",  # triangles
    }

    baseline_names = ('unaligned', 'aligned', 'zeroshot_baseline')
    baseline_labels = {
        'unaligned': 'Unaligned',
        'aligned': 'No Mismatch',
        'zeroshot_baseline': 'PFE Full',
    }
    # Only "No Mismatch" has a fixed colour; other curves use the colour cycle
    # so that the same baseline under different noise types is distinguishable.
    baseline_colors = {'aligned': 'green'}
    # The unaligned baseline is collected but deliberately not plotted.
    plotted_baselines = ('aligned', 'zeroshot_baseline')

    plt.figure(figsize=(8, 6))

    for noise_type in noise_types:
        log_dir = f'alignment/logs_{resolution}_{noise_type}'
        marker = noise_markers.get(noise_type, "o")

        # {aligner_type: {snr: [psnr values]}} and {baseline: {snr: [psnr values]}}
        aligner_data = defaultdict(lambda: defaultdict(list))
        baseline_data = {name: defaultdict(list) for name in baseline_names}

        for snr in snr_values:
            for seed in seeds:
                lines_path = f'{log_dir}/lines_ae_{snr_ae}_snr_{snr}_seed_{seed}.txt'
                if not os.path.exists(lines_path):
                    print(f"Warning: {lines_path} not found, skipping {noise_type} SNR {snr}, seed {seed}")
                    continue

                with open(lines_path, 'r', encoding='utf-8') as file:
                    lines = file.readlines()
                if len(lines) < 3:
                    raise ValueError(
                        f"{lines_path}: expected 3 baseline lines, found {len(lines)}")
                # Parse all three before appending so the lists stay aligned.
                values = [float(line.split()[1]) for line in lines[:3]]
                for name, value in zip(baseline_names, values):
                    baseline_data[name][snr].append(value)

                for aligner_type in types:
                    log_path = f'{log_dir}/aligner_{aligner_type}_ae_{snr_ae}_snr_{snr}_seed_{seed}.txt'
                    if not os.path.exists(log_path):
                        print(f"Warning: {log_path} not found")
                        continue
                    with open(log_path, 'r', encoding='utf-8') as file:
                        content = file.read()

                    psnr_values = _psnr_values_for_sample_count(content, target_sample_count)
                    if psnr_values:
                        aligner_data[aligner_type][snr].extend(psnr_values)
                    else:
                        print(f"Warning: no PSNR entries for the requested sample count in {log_path}")

        # Baselines: dashed lines.
        for name in plotted_baselines:
            style = dict(marker=marker, markersize=8, linestyle='--', linewidth=3,
                         label=f"{baseline_labels[name]} ({noise_type})")
            if name in baseline_colors:
                style['color'] = baseline_colors[name]
            _plot_mean_with_std_band(baseline_data[name], fill_alpha=0.1, **style)

        # Aligners: solid lines.
        for aligner_type in types:
            if aligner_type not in aligner_data:
                continue
            label = f"{type_name_dict.get(aligner_type, aligner_type)} ({noise_type})"
            _plot_mean_with_std_band(aligner_data[aligner_type], fill_alpha=0.2,
                                     marker=marker, markersize=8, label=label,
                                     linestyle='-', linewidth=3)

    plt.xlabel("SNR (dB)", fontsize=18)
    plt.ylabel("PSNR (dB)", fontsize=18)
    plt.tick_params(axis='both', labelsize=18)
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    os.makedirs('output', exist_ok=True)

    sample_suffix = f"_samples_{target_sample_count}" if target_sample_count else "_max_samples"
    noise_suffix = "_".join(noise_types)
    filename = f"psnr_vs_snr_{snr_ae}_{noise_suffix}_combined{sample_suffix}"

    pdf_path = os.path.join('output', f"{filename}.pdf")
    png_path = os.path.join('output', f"{filename}.png")

    plt.savefig(pdf_path, format='pdf', dpi=300, bbox_inches='tight')
    plt.savefig(png_path, format='png', dpi=300, bbox_inches='tight')

    print("Combined plots saved as:")
    print(f"  PDF: {pdf_path}")
    print(f"  PNG: {png_path}")

    plt.show()

# Exec: generate and save the figures

In [None]:
# PSNR vs. number of semantic pilots at channel SNR 20, seeds 42-46.
plot_psnr_analysis(96, snr=20, seeds=list(range(42, 47)))

In [None]:
# PSNR vs. number of semantic pilots at channel SNR -10, seeds 42-46.
plot_psnr_analysis(96, snr=-10, seeds=list(range(42, 47)))

In [None]:
# PSNR vs. channel SNR with snr_ae=-10, single seed (42).
plot_psnr_vs_snr_analysis(
    96,
    snr_ae=-10,
    snr_values=list(range(-20, 31, 10)),
    seeds=[42],
    types=["linear", "conv", "twoconv"],
    target_sample_count=10000,
)

In [None]:
# PSNR vs. channel SNR with snr_ae=20, single seed (42).
plot_psnr_vs_snr_analysis(
    96,
    snr_ae=20,
    snr_values=list(range(-20, 31, 10)),
    seeds=[42],
    types=["linear", "conv", "twoconv"],
    target_sample_count=10000,
)