In [3]:
import torch
from torch.distributions.normal import Normal
import plotly.graph_objects as go
from tqdm import tqdm  # for progress bar
from datetime import datetime

# Class Definitions

## Blahut-Arimoto Algorithm

In [4]:
class BlahutArimoto:
    """Blahut-Arimoto capacity estimate for a discretized Gaussian-noise channel.

    The input alphabet is ``NX`` evenly spaced points on ``[-A, A]`` and the
    output alphabet is ``NY`` evenly spaced points on
    ``[-A - 4 * sigma, A + 4 * sigma]``.  ``runAlgorithm`` iterates the input
    distribution ``p_x`` and reports the resulting mutual information in bits.

    Shape conventions used by every method:
        x, p_x        -> (NX,)
        y, p_y        -> (NY,)
        p_y_given_x   -> (NY, NX), each column sums to 1
        q_x_given_y   -> (NX, NY), each column sums to 1
    """

    def __init__(self, A=1, sigma=1, max_iter=10000, NX=500, NY=1000, tolerance=1e-6, epsilon=1e-12,
                 printInit=False, earlyStop=True, device=None):
        """Store the configuration; nothing is computed until ``runAlgorithm``.

        Args:
            A: Half-width of the input grid (inputs lie in ``[-A, A]``).
            sigma: Standard deviation of the zero-mean Gaussian noise.
            max_iter: Upper bound on the number of iterations.
            NX: Number of input grid points.
            NY: Number of output grid points.
            tolerance: Early-stop threshold on ``max |p_x_new - p_x|``.
            epsilon: Lower clamp applied before logarithms and normalization.
            printInit: Print the configuration here and again at run start.
            earlyStop: Stop as soon as the tolerance criterion is met.
            device: Torch device; falls back to CUDA when available, else CPU.
        """
        # Default variables
        self.A = A
        self.sigma = sigma
        self.max_iter = max_iter
        self.NX = NX
        self.NY = NY
        self.tolerance = tolerance
        self.epsilon = epsilon
        self.printInit = printInit
        self.earlyStop = earlyStop
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")

        if self.printInit:
            print(f"Device: {self.device}")
            print(f"A = {self.A}, sigma = {self.sigma}, max_iter = {self.max_iter}, NX = {self.NX}, NY = {self.NY}")
            print(f"tolerance = {self.tolerance}, epsilon = {self.epsilon}")

        # To be computed
        self.x = None
        self.y = None
        self.p_x = None
        self.p_y = None
        self.q_x_given_y = None
        self.p_y_given_x = None
        self.p_x_records = []
        self.p_y_records = []
        self.iter = None
        self.capacity = None
        self.theoretical_c = None
        self.mean_power = None
        self.recordFrequency = 50

        self.result = None

    def computeChannelMatrix(self):
        """Build the transition matrix P(y|x) and store it in ``self.p_y_given_x``.

        Each entry is the Gaussian probability mass of the output bin centred
        on ``y`` (half-width ``dy``) given input ``x``.  Entries are clamped to
        ``epsilon`` and every column is renormalized to sum to 1.  Requires
        ``self.x`` and ``self.y`` to be set; returns nothing.
        """
        diff_x_y = (self.x.view(1, -1) - self.y.view(-1, 1))  # (NY, NX) pairwise x - y

        # Half-width of an output bin; a single-point grid falls back to 0.5
        dy = (self.y[1] - self.y[0]) * 0.5 if self.y.size(0) > 1 else 0.5

        # Probability mass of each bin via the zero-mean normal CDF
        normal_dist = Normal(0, self.sigma)
        p_y_given_x = normal_dist.cdf(diff_x_y + dy) - normal_dist.cdf(diff_x_y - dy)

        # Clamp away exact zeros, then normalize columns to sum to 1
        p_y_given_x = torch.clamp(p_y_given_x, min=self.epsilon)
        p_y_given_x /= p_y_given_x.sum(dim=0, keepdim=True)

        self.p_y_given_x = p_y_given_x.to(self.device)

    def computeChannelCapacity(self):
        """Compute the mutual information I(X;Y) in bits for the current state.

        Uses ``self.p_x``, ``self.p_y`` and ``self.p_y_given_x``.  The base-2
        logarithm keeps the value in the same unit as ``theoreticalCapacity``.

        Returns:
            float: The capacity estimate, also stored in ``self.capacity``.
        """
        ratio = torch.clamp(self.p_y_given_x / self.p_y.view(-1, 1), min=self.epsilon)
        per_input = torch.sum(self.p_y_given_x * torch.log2(ratio), dim=0)
        self.capacity = torch.sum(self.p_x.flatten() * per_input).item()
        return self.capacity

    @staticmethod
    def theoreticalCapacity(P=1, _sigma=1):
        """Return ``0.5 * log2(1 + P / _sigma**2)`` as a Python float (bits).

        ``P`` may be a Python number or a 0-d tensor; ``torch.as_tensor`` is
        used so a tensor argument is not copy-constructed.
        """
        snr = torch.as_tensor(P / _sigma**2)
        return 0.5 * torch.log2(1 + snr).item()

    def getMeanPower(self):
        """Return ``sum(p_x * x**2)`` and store it in ``self.mean_power``."""
        # Flatten both operands so a column-shaped p_x can never broadcast to (NX, NX).
        self.mean_power = torch.sum(self.p_x.flatten() * self.x.flatten() ** 2).item()
        return self.mean_power

    def clearRecords(self):
        """Reset the per-run records and scalar results."""
        self.p_x_records = []
        self.p_y_records = []
        self.iter = None
        self.capacity = None
        self.theoretical_c = None
        self.mean_power = None

    def _updateOutputAndPosterior(self):
        """Recompute ``p_y`` and ``q_x_given_y`` from the current ``p_x``."""
        p_x = self.p_x.flatten()

        # Output marginal p(y) = sum_x p(y|x) p(x), renormalized
        self.p_y = torch.matmul(self.p_y_given_x, p_x)
        self.p_y /= torch.sum(self.p_y)

        # Posterior q(x|y) = p(y|x) p(x) / p(y), normalized over x for each y
        joint = self.p_y_given_x * p_x.view(1, -1)  # (NY, NX)
        self.q_x_given_y = joint.T / self.p_y.view(1, -1)  # (NX, NY)
        self.q_x_given_y /= self.q_x_given_y.sum(dim=0, keepdim=True)

    def runAlgorithm(self, recordFrequency=50):
        """Run the iteration for the current ``A`` and return a result dict.

        Args:
            recordFrequency: Snapshot ``p_x`` / ``p_y`` every this many
                iterations (iteration indices 0, f, 2f, ...).  A falsy value
                disables recording.

        Returns:
            dict: Keys ``A``, ``capacity`` (bits), ``theoretical_c``, ``x``,
            ``y``, ``p_x``, ``p_y``, ``q_x_given_y``, ``p_y_given_x``,
            ``p_x_records``, ``p_y_records``, ``iter`` (iterations executed)
            and ``mean_power``.  Tensors are moved to the CPU.  The same dict
            is kept in ``self.result``.
        """
        self.clearRecords()
        self.recordFrequency = recordFrequency

        # Discretize the input and output alphabets
        self.x = torch.linspace(-self.A, self.A, self.NX, device=self.device)
        self.y = torch.linspace(-self.A - 4 * self.sigma, self.A + 4 * self.sigma, self.NY, device=self.device)
        # Step 1 - uniform initial input distribution, kept 1-D like every later iterate
        self.p_x = torch.ones(self.NX, device=self.device) / self.NX

        self.computeChannelMatrix()  # Transition probabilities P(y|x)

        if self.printInit:
            print(f"Running on device: {self.device}")
            print(f"A = {self.A}, sigma = {self.sigma}, max_iter = {self.max_iter}, NX = {self.NX}, NY = {self.NY}")
            print(f"Initial input distribution p(x): max= {self.p_x.max()} min = {self.p_x.min()}")
            print(f"Transition probabilities p(y|x): max = {self.p_y_given_x.max()} min = {self.p_y_given_x.min()}")

        # Blahut-Arimoto iterations
        completed = 0
        for step in tqdm(range(self.max_iter), desc="Iterations", ncols=50):
            completed = step + 1
            self._updateOutputAndPosterior()

            # Update p(x) proportional to exp(sum_y p(y|x) log q(x|y))
            log_q = torch.log(torch.clamp(self.q_x_given_y.T, min=self.epsilon))
            p_x_new = torch.exp(torch.sum(self.p_y_given_x * log_q, dim=0))
            p_x_new /= torch.sum(p_x_new)  # Normalize

            # Check for convergence (the previous iterate is kept on convergence)
            if self.earlyStop and torch.max(torch.abs(p_x_new - self.p_x)) < self.tolerance:
                print(f"Converged in {completed} iterations.")
                break

            self.p_x = p_x_new

            # Record intermediate results
            if recordFrequency and step % recordFrequency == 0:
                self.p_x_records.append(self.p_x.clone())
                self.p_y_records.append(self.p_y.clone())

        # Make p_y and q(x|y) match the final p_x before measuring anything;
        # this also covers max_iter == 0, where the loop body never ran.
        self._updateOutputAndPosterior()

        self.iter = completed
        self.computeChannelCapacity()
        self.getMeanPower()
        # NOTE(review): A is passed as the power P here, although x spans [-A, A];
        # confirm whether A or A**2 is the intended reference power.
        self.theoretical_c = self.theoreticalCapacity(self.A, self.sigma)

        self.result = {
            "A": self.A,
            "capacity": self.capacity,
            "theoretical_c": self.theoretical_c,
            "x": self.x.cpu(),
            "y": self.y.cpu(),
            "p_x": self.p_x.cpu(),
            "p_y": self.p_y.cpu(),
            "q_x_given_y": self.q_x_given_y.cpu(),
            "p_y_given_x": self.p_y_given_x.cpu(),
            "p_x_records": [record.cpu() for record in self.p_x_records],
            "p_y_records": [record.cpu() for record in self.p_y_records],
            "iter": self.iter,
            "mean_power": self.mean_power,
        }
        return self.result


## Visualizer

In [5]:
class VisualizeBA():
    """Collects result dicts and draws plotly figures from them.

    Each dict passed to ``addData`` must provide the keys ``"A"``,
    ``"capacity"`` and ``"theoretical_c"``; the plotting methods additionally
    read ``"x"``, ``"y"``, ``"p_x"``, ``"p_y"``, ``"p_x_records"``, ``"iter"``
    and ``"mean_power"`` from the stored records.
    """

    def __init__(self) -> None:
        self.records = []                 # full result dicts, in insertion order
        self.As = []                      # data["A"] of every record
        self.capacities = []              # data["capacity"] of every record
        self.theoretical_capacitys = []   # data["theoretical_c"] of every record

    def addData(self, data):
        """Append one result dict and its summary values."""
        self.As.append(data["A"])
        self.capacities.append(data["capacity"])
        self.records.append(data)
        self.theoretical_capacitys.append(data["theoretical_c"])

    def clearData(self):
        """Forget every stored record."""
        self.records = []
        self.As = []
        self.capacities = []
        self.theoretical_capacitys = []

    def plot_A_vs_capacity(self):
        """Show capacity as a function of A and print one line per A."""
        fig = go.Figure()

        fig.add_trace(go.Scatter(
            x=self.As,
            y=self.capacities,
            mode='lines',
            name='Channel Capacity'
        ))

        fig.update_layout(
            title="Channel Capacity vs. Peak Power Constraint",
            xaxis_title="A (Peak Power Constraint)",
            yaxis_title="Channel Capacity (bits)",
            template="plotly_dark",
            showlegend=False,
            autosize=True,
        )

        fig.show()

        # Print results
        for A, capacity in zip(self.As, self.capacities):
            print(f"A = {A:.2f} | Channel Capacity = {capacity:.3f}")

    def plot_capcity_vs_theoretical_capacity(self):
        """Show computed and theoretical capacity against A and print both."""
        fig = go.Figure()

        # Add the computed capacity
        fig.add_trace(go.Scatter(
            x=self.As,
            y=self.capacities,
            mode='lines',
            name='Computed Capacity'
        ))

        # Add the theoretical capacity
        fig.add_trace(go.Scatter(
            x=self.As,
            y=self.theoretical_capacitys,
            mode='lines',
            name='Theoretical Capacity'
        ))

        fig.update_layout(
            title="Channel Capacity vs. Peak Power Constraint",
            xaxis_title="A (Peak Power Constraint)",
            yaxis_title="Capacity (bits)",
            template="plotly_dark",
            showlegend=True,
            autosize=True,
        )

        fig.show()

        # Print results
        for i in range(len(self.As)):
            print(f"A = {self.As[i]:.2f} | Channel Capacity = {self.capacities[i]:.3f} Theoretical = {self.theoretical_capacitys[i]:.3f}")

    def plot_A_vs_p_x(self):
        """Overlay the final p(x) of every record in one static figure."""
        fig = go.Figure()

        for output in self.records:
            fig.add_trace(go.Scatter(
                x=output["x"],    # x-axis: the input symbols (discretized x)
                y=output["p_x"],  # y-axis: the probability distribution p(x)
                mode='lines',
                name=f"A = {(output['A']):.2f}",
            ))

        fig.update_layout(
            title="Input Probability Distributions for Different A",
            xaxis_title="x (Input Symbol)",
            yaxis_title="p(x) (Probability Distribution)",
            template="plotly_dark",
            showlegend=True,
            autosize=True,
        )

        fig.show()

    def plot_A_vs_p_x_dynamic(self):
        """Show the final p(x) per record behind a slider over A.

        Returns:
            The plotly figure, so the caller can export it.
        """
        outputs = self.records
        fig = go.Figure()

        # Add traces for each A, initially all hidden
        for output in outputs:
            fig.add_trace(go.Scatter(
                x=output["x"],    # x-axis: the input symbols (discretized x)
                y=output["p_x"],  # y-axis: the probability distribution p(x)
                mode='lines',
                name=f"A = {output['A']:.2f}",
                visible='legendonly'  # Initially, all traces are hidden
            ))

        # Create the slider steps to control the visibility of traces
        steps = []
        for i, output in enumerate(outputs):
            step = dict(
                method="update",
                # Without an explicit label the slider shows a generic step name
                # after the "A = " prefix instead of the value of A.
                label=f"{output['A']:.2f}",
                args=[{"visible": [False] * len(outputs)},  # Hide all traces
                      {"title": f"A = {output['A']:.2f}"}],  # Set the title to the selected A value
            )
            step["args"][0]["visible"][i] = True  # Make the current trace visible
            steps.append(step)

        # Add the slider to control A
        fig.update_layout(
            title="Input Probability Distributions for Different A",
            xaxis_title="x (Input Symbol)",
            yaxis_title="p(x) (Probability Distribution)",
            template="plotly_dark",
            showlegend=True,
            sliders=[dict(
                currentvalue={"prefix": "A = ", "visible": True, "xanchor": "center"},
                steps=steps
            )]
        )

        # Show the dynamic plot
        fig.show()

        return fig

    def plot_A_vs_p_y(self):
        """Overlay the final p(y) of every record in one static figure."""
        fig = go.Figure()

        for i, output in enumerate(self.records):
            fig.add_trace(go.Scatter(
                x=output["y"],  # x-axis: the output symbols (discretized y)
                y=output["p_y"],  # y-axis: the probability distribution p(y)
                mode='lines',
                name=f"A = {self.As[i]:.2f}",
            ))

        fig.update_layout(
            title="Output Probability Distributions for Different A",
            xaxis_title="y (Output Symbol)",
            yaxis_title="p(y) (Probability Distribution)",
            template="plotly_dark",
            showlegend=True,
            autosize=True,
        )

        fig.show()

    def plot_p_x_records(self, idx=-1):
        """Static overlay of the recorded p(x) snapshots of one record.

        Args:
            idx: Index into ``self.records`` (default: the most recent one).
        """
        fig = go.Figure()

        output = self.records[idx]

        for i, p_x in enumerate(output["p_x_records"]):
            fig.add_trace(go.Scatter(
                x=output["x"],    # x-axis: the input symbols (discretized x)
                y=p_x,  # y-axis: the probability distribution p(x)
                mode='lines',
                name=f"Record {(i+1)}",
            ))

        fig.update_layout(
            title="Input Probability Distributions Over Iterations",
            xaxis_title="x (Input Symbol)",
            yaxis_title="p(x) (Probability Distribution)",
            template="plotly_dark",
            showlegend=True,
            autosize=True,
        )

        fig.show()

    def plot_p_x_records_dynamic(self, idx=-1):
        """Slider view of the recorded p(x) snapshots of one record.

        Also prints the iteration count, A, capacity, theoretical capacity and
        mean power of the selected record.

        Args:
            idx: Index into ``self.records`` (default: the most recent one).
        """
        # Create an empty figure
        fig = go.Figure()

        output = self.records[idx]
        snapshots = output["p_x_records"]

        # Add traces for each snapshot (initially all traces hidden)
        for i, p_x in enumerate(snapshots):
            fig.add_trace(go.Scatter(
                x=output["x"],  # x-axis: the input symbols (discretized x)
                y=p_x,  # y-axis: the probability distribution p(x)
                mode='lines',  # Connect the points with lines
                name=f"A = {output['A']:.2f}, Record {(i+1)}",  # Label for the snapshot
                visible='legendonly',  # Initially set to not visible
            ))

        # Create the slider steps (one per snapshot)
        steps = []
        for i in range(len(snapshots)):
            step = dict(
                method="update",
                args=[{"visible": [False] * len(snapshots)},  # Hide all traces
                    {"title": f"Record {(i+1)}"}],  # Update the title
            )
            step["args"][0]["visible"][i] = True  # Make the current snapshot visible
            steps.append(step)

        # Add the slider to the layout
        fig.update_layout(
            title="Input Probability Distributions Over Iterations",
            xaxis_title="x (Input Symbol)",
            yaxis_title="p(x) (Probability Distribution)",
            template="plotly_dark",  # Optional: use a dark theme
            showlegend=True,
            sliders=[dict(
                currentvalue={"prefix": "Iteration: ", "visible": True, "xanchor": "center"},
                steps=steps
            )]
        )

        # Show the dynamic plot
        fig.show()

        # Summary of the selected record
        print("==========")
        print(f"Total Iterations: {(output['iter'])}")
        print(f"A = {output['A']:.2f}")
        print(f"Channel Capacity = {output['capacity']:.3f}")
        print(f"Theoretical Capacity = {output['theoretical_c']:.3f}")
        print(f"Mean power = {output['mean_power']:.3f}")

    def savePlot(self, fig, filename, ext="html"):
        """Write ``fig`` as interactive HTML to ``f"{filename}.{ext}"``.

        The content is always HTML; ``ext`` only changes the file suffix.
        """
        fig.write_html(f"{filename}.{ext}")




## Experiment

In [8]:
class Experiment(BlahutArimoto, VisualizeBA):
    """Runs the algorithm for a sweep of A values and keeps the results for plotting."""

    def __init__(self, NX=500, NY=1000, sigma=1, max_iter=10000, tolerance=1e-6, epsilon=1e-12,
                 printInit=False, earlyStop=True, device=None):
        """Initialize the solver state and the empty result store.

        All arguments are forwarded to ``BlahutArimoto.__init__``; ``A`` keeps
        that constructor's default until ``run`` sets it.
        """
        # Initialize both parents' constructors
        BlahutArimoto.__init__(self, sigma=sigma, max_iter=max_iter, NX=NX, NY=NY, tolerance=tolerance,
                               epsilon=epsilon, printInit=printInit, earlyStop=earlyStop, device=device)
        print("BlahutArimoto initialized")

        VisualizeBA.__init__(self)
        print("Visualizer initialized")

        print("Experiment initialized")

    def run(self, As):
        """Run the algorithm once per value in ``As``, replacing stored results.

        Args:
            As: Sized iterable of numbers or 0-d tensors.  Each value is
                converted to a Python float before it is stored in ``self.A``,
                so the grids and the formatted output never depend on tensor
                inputs being accepted downstream.
        """
        self.clearData()
        print("================")
        print(f"Experiment started for {len(As)} many As")
        print(f"device {self.device}")
        print(f"NX = {self.NX}, NY = {self.NY}")
        print(f"Tolerance = {self.tolerance}, epsilon = {self.epsilon}")
        print(f"Max Iterations = {self.max_iter}")
        print("================")
        with torch.no_grad():
            for Ai in As:
                self.A = float(Ai)
                result = self.runAlgorithm()
                self.addData(result)
        print("================")
        print("Experiment finished")
        print("================")

    def appendRunHistory(self, anotherExperiment: "Experiment"):
        """Append the stored results of another experiment to this one.

        Raises:
            TypeError: If ``anotherExperiment`` is not an ``Experiment``.
        """
        if not isinstance(anotherExperiment, Experiment):
            raise TypeError(f"Expected an instance of 'Experiment', but got {type(anotherExperiment).__name__}.")

        self.records.extend(anotherExperiment.records)
        self.As.extend(anotherExperiment.As)
        self.capacities.extend(anotherExperiment.capacities)
        # Keep the theoretical values aligned with As; the comparison plot
        # indexes both lists with the same position.
        self.theoretical_capacitys.extend(anotherExperiment.theoretical_capacitys)



# Playground

In [None]:
# Quick alternatives for a short run: As = [1]  or  As = [0.1, 1, 10]
# Full sweep: 100 values of A, log-spaced from 10**-1.1 to 10**1.
As = torch.logspace(start=-1.1, end=1, steps=100, base=10)
exp = Experiment(
    NX=1000,
    NY=4000,
    max_iter=10000,
)
exp.run(As)

In [None]:
# Build the slider figure (the call also displays it) and export it as a standalone HTML file.
fig = exp.plot_A_vs_p_x_dynamic()
fig.write_html("plot_A_vs_p_x.html")

In [None]:
# Plot the computed capacity against A and print one summary line per A value.
exp.plot_A_vs_capacity()