In [None]:
import numpy as np
from matplotlib import pyplot as plt
from scipy.signal import spectrogram

In [None]:
# Shared settings for the first demos: one period T = 2*pi, an Euler step dt,
# and n oscillators that all share angular frequency 2 with phases spread over [0, pi].
n = 4
T = 2 * np.pi
dt = 0.01
frequencies = np.full(n, 2.0)
phases = np.linspace(0, np.pi, num=n)

In [None]:
# Closed-form reference curves: sin(frequency * t + phase) sampled every dt over one period T.
for frequency, phase in zip(frequencies, phases):
    plt.plot(np.sin(np.arange(0, T, dt) * frequency + phase))

In [None]:
class SimpleOscillator:
    """Scalar sine oscillator whose phase is advanced by a fixed step per call.

    Each call returns ``sin`` of the current phase and then adds
    ``frequency * dt`` to the phase, so ``frequency`` acts as a phase rate.
    """

    def __init__(self, frequency, phase, dt):
        # The starting phase is remembered separately from the evolving one.
        self.initial_phase = phase
        self.current_phase = phase
        self.frequency = frequency
        self.dt = dt

    def __call__(self):
        """Return sin(current phase), then step the phase forward by frequency * dt."""
        sample = np.sin(self.current_phase)
        step = self.frequency * self.dt
        self.current_phase += step
        return sample

# Step-by-step version of the reference curves: one stateful oscillator per (frequency, phase)
# pair, sampled int(T / dt) times.
for frequency, phase in zip(frequencies, phases):
    oscillator = SimpleOscillator(frequency=frequency, phase=phase, dt=dt)
    plt.plot([oscillator() for _ in range(int(T / dt))])

In [None]:
class MultipleSimpleOscillators:
    """Vectorised bank of independent sine oscillators advanced by fixed phase steps.

    Each call returns ``sin(current_phases)`` and then advances every phase by
    ``frequencies * dt``.

    Args:
        frequencies: per-oscillator phase rates, array-like of shape (n,).
        phases: initial phases, array-like of shape (n,). Copied; the caller's
            object is never modified.
        dt: time step.
    """

    def __init__(self, frequencies, phases, dt):
        self.frequencies = np.asarray(frequencies, dtype=float)  # (n,)
        self.dt = dt
        # Private float copies: the in-place ``+=`` in __call__ must not reach
        # ``initial_phases`` or the caller's array (otherwise re-running a cell
        # starts from already-advanced phases), and a float dtype lets integer or
        # list input accept the float increment.
        self.initial_phases = np.array(phases, dtype=float)  # (n,)
        self.current_phases = self.initial_phases.copy()  # (n,)

    def __call__(self):
        """Return sin(current_phases), then advance all phases by frequencies * dt."""
        value = np.sin(self.current_phases)
        self.current_phases += self.frequencies * self.dt

        return value

# Vectorised run: each call yields all n samples for one time step, so rows of
# `waves` are time steps and columns are oscillators.
oscillators = MultipleSimpleOscillators(frequencies, phases, dt)
waves = np.array([oscillators() for _ in range(int(T / dt))])

for column in waves.T:
    plt.plot(column)

In [None]:
class SimpleKuramotoOscillators:
    """All-to-all coupled phase oscillators (Kuramoto model), integrated with Euler steps.

    Each call returns ``sin(phases)`` for the current state and then advances
    every phase by one Euler step of

        dtheta_i/dt = omega_i + (K / N) * sum_j sin(theta_j - theta_i)

    Args:
        frequencies: natural frequencies omega_i, array-like of shape (n,).
        phases: initial phases theta_i, array-like of shape (n,). Copied; the
            caller's object is never modified.
        dt: Euler time step.
        coupling_strength: global coupling constant K.
    """

    def __init__(self, frequencies, phases, dt, coupling_strength):
        self.frequencies = np.asarray(frequencies, dtype=float)  # (n,)
        self.dt = dt
        self.coupling_strength = coupling_strength
        # Private float array: lists would break the broadcasting below and
        # integer arrays would reject the in-place float ``+=``.
        self.phases = np.array(phases, dtype=float)  # (n,)

    def __call__(self):
        """Return sin(phases), then apply one Euler step of the Kuramoto dynamics."""
        value = np.sin(self.phases)

        # pairwise[i, j] = theta_j - theta_i, so summing over axis 1 sums over j.
        pairwise = self.phases - self.phases[:, np.newaxis]
        # max(..., 1) only matters for an empty population, where K / 0 would raise.
        scale = self.coupling_strength / max(len(self.frequencies), 1)
        d_phase = self.frequencies + (scale * np.sin(pairwise)).sum(axis=1)

        self.phases += d_phase * self.dt

        return value

# Four oscillators with natural frequencies spread over 20..40 and coupling K = 15.9,
# simulated for 4 * int(T / dt) Euler steps (roughly 4 * T time units).
oscillators = SimpleKuramotoOscillators(
    frequencies=np.linspace(20, 40, num=n),
    phases=np.linspace(0, 2 * np.pi, num=n),
    dt=0.01,
    coupling_strength=15.9,
)
waves = np.array([oscillators() for _ in range(int(T / oscillators.dt) * 4)])

# Opening 1000 samples: the transient before any locking sets in.
plt.figure(figsize=(19, 2))
for idx, trace in enumerate(waves[:1000].T):
    plt.plot(trace, label=idx)
plt.legend()

In [None]:
# Closing 1000 samples of the same run, to compare with the opening window.
plt.figure(figsize=(19, 2))
for idx, trace in enumerate(waves[-1000:].T):
    plt.plot(trace, label=idx)
plt.legend()

In [None]:
# Short-time spectra of every oscillator. `fs` is left at scipy's default (1.0),
# so frequency bin k is k / nfft cycles per sample.
nperseg = 630
noverlap = nperseg // 4
nfft = 4000

for i, wave in enumerate(waves.T):
    freq, segments, sxx = spectrogram(wave, nperseg=nperseg, noverlap=noverlap, nfft=nfft)
    if not i:
        # Show rows 50..149 (~0.0125-0.0375 cycles/sample) of the first oscillator once.
        print(freq.shape, segments.shape, sxx.shape)
        plt.imshow(sxx[50:150], interpolation='nearest', aspect='auto')
        plt.show()
        plt.close()
    # Strongest spectral power in each time segment, overlaid for all oscillators.
    plt.plot(np.max(sxx, axis=0))

In [None]:
def get_frequencies(wave, dt, max_frequency=20.0):
    """Amplitude spectrum of ``wave`` on a non-negative angular-frequency axis.

    Args:
        wave: 1-D signal sampled every ``dt`` time units.
        dt: sampling interval.
        max_frequency: keep only bins with ``0 <= freq < max_frequency``. The
            default of 20 reproduces the previous hard-coded cutoff; pass
            ``None`` to keep every non-negative bin.

    Returns:
        (power, freq): ``|FFT(wave)|`` (magnitude, not squared) and the matching
        frequencies, both restricted to the selected bins in FFT order.
    """
    power = np.abs(np.fft.fft(wave))
    # Using d = dt / (2*pi) turns cycles-per-time into radians-per-time, the same
    # unit as the oscillators' `frequency` (their phase advances by frequency * dt).
    freq = np.fft.fftfreq(len(power), dt / (np.pi * 2))

    keep = freq >= 0
    if max_frequency is not None:
        keep &= freq < max_frequency

    return power[keep], freq[keep]

In [None]:
# Amplitude spectrum of each Kuramoto oscillator, one labelled curve per oscillator.
# NOTE(review): get_frequencies cuts the spectrum at 20, while the natural
# frequencies of this run span 20-40 — confirm the visible range is intended.
for i, wave in enumerate(waves.T):
    power, freq = get_frequencies(wave, oscillators.dt)
    plt.plot(freq, power, label=i)
plt.legend()

In [None]:
# Same sweep with longer segments (finer frequency, coarser time resolution);
# this time the whole spectrogram of the first oscillator is shown.
nperseg = 1000
noverlap = nperseg // 4
nfft = 4000

for i, wave in enumerate(waves.T):
    freq, segments, sxx = spectrogram(wave, nperseg=nperseg, noverlap=noverlap, nfft=nfft)
    is_first = i == 0
    if is_first:
        print(freq.shape, segments.shape, sxx.shape)
        plt.imshow(sxx, interpolation='nearest', aspect='auto')
        plt.show()
        plt.close()
    # Per-segment peak power; `freq`, `segments`, `sxx` keep the last oscillator's values.
    plt.plot(np.max(sxx, axis=0))

In [None]:
# Frequency of bin 266 from the most recent `spectrogram` call above (nfft=4000).
# `fs` was left at scipy's default of 1.0, so this is in cycles per sample
# (bin k = k / nfft); multiply by 2*pi / dt to compare with the oscillators' frequencies.
# NOTE(review): 266 looks hand-picked from an earlier plot — confirm what it marks.
freq[266]

In [None]:
# Peak power per time segment of the spectrogram currently held in `sxx`.
plt.plot(np.max(sxx, axis=0))

In [None]:
# Finer time resolution (nperseg=200, scipy's default noverlap = nperseg // 8) for
# oscillator 1; only the lowest 100 of the nfft // 2 + 1 frequency rows are drawn.
freq, segments, sxx2 = spectrogram(waves[:, 1], nfft=4000, nperseg=200)
print(freq.shape, segments.shape, sxx2.shape)
plt.imshow(sxx2[:100], interpolation='nearest', aspect='auto')

In [None]:
# Overlay the low-frequency part of two oscillators' spectrograms.
# `sxx` at this point comes from the nperseg=1000 sweep (last oscillator) and has a
# different number of time segments than `sxx2` (nperseg=200). Adding them raised a
# broadcasting ValueError, so recompute the last oscillator's spectrogram with
# exactly `sxx2`'s settings; both arrays then share one time/frequency grid.
_, _, sxx_last = spectrogram(waves[:, n - 1], nperseg=200, nfft=4000)
plt.imshow(sxx2[:100, :] + sxx_last[:100, :], interpolation='nearest', aspect='auto')