# Tinnitus InfoMax quick demo (Python)

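A fast, self-contained notebook that mirrors the MATLAB InfoMax workflow:
- input generation (Gaussian tone bumps plus uniform noise and a constant offset)
- a sigmoidal attenuation profile that scales down the higher-index (higher-frequency) input channels
- a lightweight learning loop that switches from clean to attenuated inputs partway through training
- diagnostic plots (attenuation profile, cost during learning, output-activation histograms)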


In [None]:
import math
import numpy as np
import matplotlib.pyplot as plt

# Shared, seeded random generator: input synthesis (generate_input) and the
# mini-batch sampling in the training loop both draw from it, so a full run is
# reproducible for a fixed seed.
rng = np.random.default_rng(7)


In [None]:
def _normal_pdf(x, mean, sigma):
    """Evaluate the Gaussian density N(mean, sigma**2) elementwise at ``x``.

    ``x`` and ``mean`` may be scalars or broadcast-compatible numpy arrays.
    """
    norm = 1.0 / (sigma * math.sqrt(2 * math.pi))
    z = (x - mean) / sigma
    return norm * np.exp(-0.5 * z ** 2)


def tonotopic_map(inputs, outputs):
    """Build an (outputs, inputs) matrix of Gaussian receptive fields.

    Input channels are spaced evenly on a 20..510 frequency axis and output
    units evenly over the same span. Row i is a Gaussian over the input
    frequencies centred on output i's frequency, with width
    (last - first input frequency) / 49 (i.e. 10 for two or more inputs).
    The matrix is scaled so the mean row sum is 0.01.
    """
    n_grid = 50
    lo = 20
    hi = lo + (n_grid - 1) * 10
    gain = 0.01

    freq_in = np.linspace(lo, hi, inputs)
    width = (freq_in[-1] - freq_in[0]) / (n_grid - 1)
    centres = np.linspace(freq_in[0], freq_in[-1], outputs)

    # Broadcast (outputs, 1) centres against (1, inputs) frequencies so row i
    # is the Gaussian centred on centres[i].
    fields = _normal_pdf(freq_in[None, :], centres[:, None], width)
    return gain * fields / (fields.sum() / outputs)


def sigmoid(x, with_derivative=False):
    """Logistic function ``1 / (1 + exp(-x))``, applied elementwise.

    When ``with_derivative`` is true, return ``(g, g * (1 - g))``, i.e. the
    value together with its first derivative; otherwise return just ``g``.
    """
    value = 1.0 / (1.0 + np.exp(-x))
    if not with_derivative:
        return value
    return value, value * (1 - value)


In [None]:
class Infomax:
    """InfoMax network with sigmoid outputs and optional lateral connections.

    The steady-state output for an input column ``x`` is the fixed point
    ``s = sigmoid(W @ x + K @ s)``. ``W`` (Outputs x Inputs, plus one trailing
    threshold column when ``threshold`` is true) is initialised from
    ``tonotopic_map``; ``K`` (Outputs x Outputs) starts at zero and its
    diagonal never receives updates. With ``threshold`` every input column is
    augmented with a constant -1, so the last column of ``W`` acts as a
    per-output threshold.

    Parameters
    ----------
    inputs, outputs : int
        Number of input channels and output units.
    threshold : bool
        Whether to learn a per-output threshold column.
    ff_eta, rec_eta : float
        Learning rates for ``W`` and ``K``; 0 disables the corresponding update.
    ff_ridge, rec_ridge : float
        Regularisation strengths: ``W`` is shrunk by ``eta * ridge * sign(W)``
        (L1-style), ``K`` by ``eta * ridge * K`` (L2-style weight decay).
    beta, alpha : float
        Stored on the instance; not used by any method of this class.
    niter : int
        Maximum Newton iterations for the recurrent steady state.
    tolfun : float
        Stop the Newton iteration once the mean absolute step is below this.
    """

    def __init__(
        self,
        inputs,
        outputs,
        threshold=True,
        ff_eta=0.1,
        rec_eta=0.001,
        ff_ridge=0.001,
        rec_ridge=0.209,
        beta=0.0,
        niter=400,
        alpha=0.2,
        tolfun=1e-6,
    ):
        self.Inputs = inputs
        self.Outputs = outputs
        self.Threshold = threshold
        self.FF_eta = ff_eta
        self.Rec_eta = rec_eta
        self.FF_ridge = ff_ridge
        self.Rec_ridge = rec_ridge
        self.Beta = beta
        self.niter = niter
        self.alpha = alpha
        self.tolfun = tolfun
        self.reset_connections()

    def reset_connections(self):
        """Reinitialise ``W`` from the tonotopic map (zero thresholds) and zero ``K``."""
        tmp = tonotopic_map(self.Inputs, self.Outputs)
        if self.Threshold:
            self.W = np.zeros((self.Outputs, self.Inputs + 1))
            self.W[:, : self.Inputs] = tmp
        else:
            self.W = tmp
        self.K = np.zeros((self.Outputs, self.Outputs))

    def _gfunc(self, x):
        """Return the logistic function and its first two derivatives at ``x``."""
        g = sigmoid(x)
        gp = g * (1 - g)  # g'  = g (1 - g)
        gpp = gp * (1 - 2 * g)  # g'' = g' (1 - 2 g)
        return g, gp, gpp

    def _gcalc(self, x):
        """Solve for the steady-state outputs of one (already augmented) input column.

        Without lateral weights the outputs are simply ``sigmoid(W @ x)``.
        Otherwise Newton's method is applied to ``F(g) = g - sigmoid(h + K g)``
        from ``sigmoid(0) = 0.5`` for at most ``niter`` steps. Returns
        ``(g, g', g'')`` evaluated at the final total drive ``h + K g``.
        """
        h = self.W @ x
        if np.all(self.K == 0):
            return self._gfunc(h)

        g = sigmoid(np.zeros(self.Outputs))
        for _ in range(self.niter):
            g1, gp1 = sigmoid(h + self.K @ g, with_derivative=True)
            # Jacobian of F is I - diag(g') K; take the full Newton step.
            psi = np.eye(self.Outputs) - np.diag(gp1) @ self.K
            delta = np.linalg.solve(psi, g1 - g)
            g = g + delta
            if np.mean(np.abs(delta)) < self.tolfun:
                break

        return self._gfunc(h + self.K @ g)

    def evaluate(self, x):
        """Return steady-state outputs and their derivatives for each input column.

        ``x`` is an (Inputs,) vector or an (Inputs, n) matrix of raw inputs; the
        -1 threshold row is appended here when enabled. Returns three
        (Outputs, n) arrays: g, g' and g''.
        """
        if x.ndim == 1:
            x = x[:, None]
        if self.Threshold:
            x = np.vstack([x, -np.ones((1, x.shape[1]))])
        g = np.zeros((self.Outputs, x.shape[1]))
        gp = np.zeros_like(g)
        gpp = np.zeros_like(g)
        for i in range(x.shape[1]):
            g[:, i], gp[:, i], gpp[:, i] = self._gcalc(x[:, i])
        return g, gp, gpp

    def learn(self, x):
        """Apply one batch-averaged InfoMax update to ``W`` and/or ``K``.

        ``x`` is an (Inputs,) vector or an (Inputs, n_samples) batch of raw
        inputs. For each sample, ``chi = phi @ W_in`` is the Jacobian of the
        steady-state outputs with respect to the (non-threshold) inputs. The
        per-sample gradients are averaged over the batch, the ridge penalty is
        applied, and the step is taken. An empty batch leaves the weights
        untouched.
        """
        if x.ndim == 1:
            x = x[:, None]
        n_samples = x.shape[1]
        if n_samples == 0:
            # Nothing to average over; return before the ridge step mutates W/K.
            return

        v = self.W
        r = self.K
        # K is only modified after the loop, so this test is loop-invariant.
        has_rec = not np.all(r == 0)
        d_w = np.zeros_like(v)
        d_k = np.zeros_like(r)
        i_n = np.eye(self.Outputs)

        for x1 in x.T:
            s, gp, gpp = self.evaluate(x1)
            s = s[:, 0]
            gp = gp[:, 0]
            gpp = gpp[:, 0]

            # phi = (I - diag(g') K)^-1 diag(g'), so chi = phi W_in = ds/dx.
            phi = np.diag(gp)
            if has_rec:
                phi = np.linalg.solve(i_n - phi @ r, phi)
            chi = phi @ v[:, : self.Inputs]
            chi_pinv = np.linalg.pinv(chi)
            # Overcomplete case: chi @ pinv(chi) projects onto the range of chi.
            xi = i_n if self.Outputs <= self.Inputs else chi @ chi_pinv

            # Units with g' == 0 would divide by zero; their term is zeroed instead.
            with np.errstate(divide="ignore", invalid="ignore"):
                gamma = np.diag(xi @ phi) * gpp / (gp**3)
                gamma = np.nan_to_num(gamma, nan=0.0, posinf=0.0, neginf=0.0)

            if self.FF_eta != 0:
                d_v = phi.T @ (chi_pinv.T + np.outer(gamma, x1[: self.Inputs]))
                if self.Threshold:
                    # The threshold input is the constant -1.
                    d_v = np.concatenate([d_v, -phi.T @ gamma[:, None]], axis=1)
                d_w += d_v

            if self.Rec_eta != 0:
                d_k += phi.T @ (xi + np.outer(gamma, s))

        if self.FF_eta != 0:
            if self.FF_ridge != 0:
                self.W = self.W - (self.FF_eta * self.FF_ridge * np.sign(self.W))
            self.W = self.W + (self.FF_eta / n_samples) * d_w
        if self.Rec_eta != 0:
            if self.Rec_ridge != 0:
                self.K = self.K - (self.Rec_eta * self.Rec_ridge * self.K)
            np.fill_diagonal(d_k, 0)  # no self-connections
            self.K = self.K + (self.Rec_eta / n_samples) * d_k

    def cost(self, x):
        """Return the negative mean log-volume of the input-to-output Jacobian.

        For each sample, the Jacobian ``chi = d s / d x`` of the steady-state
        outputs with respect to the real inputs is formed, and the logs of its
        singular values are summed. The result is negated and averaged over
        samples, so lower is better. The threshold column of ``W`` is excluded
        because its constant -1 input is not a free variable; this keeps
        ``cost`` consistent with the Jacobian ``learn`` optimises.
        """
        if x.ndim == 1:
            x = x[:, None]
        if self.Threshold:
            x = np.vstack([x, -np.ones((1, x.shape[1]))])
        total = 0.0
        i_n = np.eye(self.Outputs)
        no_k = np.all(self.K == 0)
        w_in = self.W[:, : self.Inputs]
        for nsample in range(x.shape[1]):
            _, gp, _ = self._gcalc(x[:, nsample])
            phi = np.diag(gp)
            if not no_k:
                phi = np.linalg.solve(i_n - phi @ self.K, phi)
            chi = phi @ w_in
            svd_vals = np.linalg.svd(chi, compute_uv=False)
            # Small epsilon keeps log finite for (near-)singular Jacobians.
            total += np.sum(np.log(svd_vals + 1e-12))
        return -total / x.shape[1]


In [None]:
def generate_input(
    inputs,
    n_samples,
    n_tones_max=5,
    a_min=7.0,
    a_max=10.0,
    noise_fac=0.5,
    white_noise=0.5,
    generator=None,
):
    """Synthesise a batch of tone-complex input patterns.

    Each sample is a sum of 1..n_tones_max Gaussian bumps over the channels
    1..inputs. Each bump has a random centre in [0, inputs), a random width
    0.5 * inputs * |N(0, 1)| and a height drawn uniformly from [a_min, a_max)
    at its centre. Uniform noise in [-noise_fac, noise_fac) is then added,
    together with ``white_noise``; despite its name, that is a constant offset
    added to every entry. Finally the whole batch is scaled so its largest
    absolute value is 0.5.

    Parameters
    ----------
    inputs : int
        Number of input channels (rows of the result).
    n_samples : int
        Number of patterns (columns of the result); 0 yields an empty batch.
    generator : numpy.random.Generator, optional
        Source of randomness. Defaults to the notebook-level ``rng``, which
        preserves the original behaviour.

    Returns
    -------
    numpy.ndarray of shape (inputs, n_samples)
    """
    gen = rng if generator is None else generator
    x = np.zeros((inputs, n_samples))
    if n_samples == 0:
        # np.max below would fail on an empty array.
        return x
    f_inp = np.arange(1, inputs + 1)

    for sample_count in range(n_samples):
        # Draw order (count, centres, widths, amplitudes) matches the original stream.
        n_tones = gen.integers(1, n_tones_max + 1)
        centres = inputs * gen.random(n_tones)
        widths = 0.5 * inputs * np.abs(gen.standard_normal(n_tones))
        amp = a_min + (a_max - a_min) * gen.random(n_tones)
        # amp * sqrt(2*pi) * sigma * N(f; mu, sigma) reduces to amp * exp(...):
        # a Gaussian bump whose height at its centre is amp.
        z = (f_inp[None, :] - centres[:, None]) / widths[:, None]
        x[:, sample_count] = np.sum(amp[:, None] * np.exp(-0.5 * z**2), axis=0)

    noise = noise_fac * (2 * gen.random(x.shape) - 1)
    x = x + noise + white_noise
    return 0.5 * x / np.max(np.abs(x))


def attenuate_inputs(inputs, beta=10.0, f_0=20, minval=0.0):
    """Scale each input channel by a falling sigmoid of its 1-based index.

    Channel j is multiplied by
    ``minval + (1 - minval) / (1 + exp(-beta * (f_0 - j)))``. The factor is
    close to 1 for j well below ``f_0``, close to ``minval`` well above it,
    and exactly halfway between the two at j == f_0.

    Parameters
    ----------
    inputs : array_like
        Channels along axis 0: a single (n_channels,) pattern or an
        (n_channels, n_samples) batch; extra trailing axes are broadcast too.
    beta : float
        Steepness of the transition.
    f_0 : float
        Channel index at the midpoint of the transition.
    minval : float
        Floor of the scale factor for strongly attenuated channels.

    Returns
    -------
    (scaled, scale)
        ``scaled`` has the same shape as ``inputs``; ``scale`` is the
        (n_channels,) profile.
    """
    inputs = np.asarray(inputs)
    j = np.arange(1, inputs.shape[0] + 1)
    scale = minval + (1 - minval) * (1.0 / (1.0 + np.exp(-beta * (f_0 - j))))
    # Align the profile with axis 0 for any rank. A plain scale[:, None] would
    # broadcast a 1-D pattern into an (n, n) outer product instead of scaling it.
    scaled = inputs * scale.reshape((-1,) + (1,) * (inputs.ndim - 1))
    return scaled, scale


In [None]:
# Experiment configuration.
inputs, outputs = 32, 64
n_samples = 2000
batch_size = 16
steps = 220
attenuate_step = 140  # from this step on, training batches come from the attenuated data

x = generate_input(inputs, n_samples)
x_att, attenuation_curve = attenuate_inputs(x, beta=10.0, f_0=inputs // 2, minval=0.2)

net = Infomax(inputs, outputs)
cost_history = []

for step in range(steps):
    batch_source = x_att if step >= attenuate_step else x
    idx = rng.choice(batch_source.shape[1], size=batch_size, replace=False)
    net.learn(batch_source[:, idx])
    # Every 10 steps, record the cost on the first 100 samples of the current source.
    if not step % 10:
        cost_history.append(net.cost(batch_source[:, :100]))

cost_history = np.asarray(cost_history)


In [None]:
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
ax_att, ax_cost, ax_hist = axes

# Panel 1: per-channel scale factors returned by attenuate_inputs.
ax_att.plot(attenuation_curve, color="tab:purple")
ax_att.set(title="Attenuation profile", xlabel="Frequency index", ylabel="Scale")

# Panel 2: cost trace (sampled every 10 learning steps), with the switch to
# attenuated inputs marked.
ax_cost.plot(np.arange(len(cost_history)) * 10, cost_history, marker="o")
ax_cost.axvline(attenuate_step, color="gray", linestyle="--", label="Attenuation on")
ax_cost.set(title="Cost during learning", xlabel="Learning step", ylabel="Cost")
ax_cost.legend()

# Panel 3: activations of the trained network on clean vs attenuated inputs.
g_pre, _, _ = net.evaluate(x[:, :200])
g_post, _, _ = net.evaluate(x_att[:, :200])
for acts, label in ((g_pre, "Before attenuation"), (g_post, "After attenuation")):
    ax_hist.hist(acts.flatten(), bins=40, alpha=0.6, label=label)
ax_hist.set(title="Output activation distribution", xlabel="Activation", ylabel="Count")
ax_hist.legend()

plt.tight_layout()
plt.show()
