In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import loadmat

import numpy as np

def extracting_shots(data, header, shot_num, p=0):
    """
    Extract a single seismic shot gather from a multi-shot data matrix.

    ``header`` holds one entry per column (trace) of ``data``; each entry must
    support ``h['fldr'][0][0]``, ``h['dt'][0][0]`` and ``h['offset'][0][0]``.
    The nested ``[0][0]`` indexing suggests header structs loaded with
    ``scipy.io.loadmat`` (NOTE(review): confirm against the loader).

    Parameters:
        data (ndarray): Complete seismic data matrix, shape (nt, ntraces);
            rows are time samples, columns are traces.
        header (list): Per-trace header entries (see above).
        shot_num (int): Field record (``fldr``) number of the shot to extract.
        p (int): 0 (default) to return offsets, 1 to return trace numbers
            1..n instead.

    Returns:
        Dshot (ndarray): Columns of ``data`` belonging to the shot, shape (nt, n).
        dt (float): Time sampling interval in seconds (header ``dt`` / 1e6).
        dx (float): Difference between the first two offsets of the shot,
            or 1 if the shot contains a single trace.
        t (ndarray): Time vector with exactly ``nt`` samples, starting at 0.
        offset (ndarray): Offsets of the shot's traces, or trace numbers if p == 1.

    Raises:
        ValueError: If ``shot_num`` does not occur in the header.
    """
    nt, _ = data.shape
    dt = header[0]['dt'][0][0] / 1e6  # Convert dt to seconds

    # Build the time axis from integer sample indices. np.arange(0, nt*dt, dt)
    # can yield nt + 1 samples when nt*dt rounds slightly above its true value.
    t = np.arange(nt) * dt

    # Column indices of every trace whose field record number matches.
    indices = [i for i, h in enumerate(header) if h['fldr'][0][0] == shot_num]
    if not indices:
        raise ValueError(f"Shot number {shot_num} not found in the header.")

    offset = np.array([header[i]['offset'][0][0] for i in indices])
    Dshot = data[:, indices]

    # Spatial sampling from the first two offsets; fall back to 1 for one trace.
    dx = offset[1] - offset[0] if len(offset) > 1 else 1

    # Use trace numbers instead of offsets if p == 1
    if p == 1:
        offset = np.arange(1, Dshot.shape[1] + 1)

    return Dshot, dt, dx, t, offset

# Define the mwigb function
def mwigb(data, scale, x, z):
    """
    Plot seismic data as a variable-area wiggle display in a new figure.

    All traces share a single global normalization, so relative amplitudes
    between traces are preserved and ``scale`` controls the wiggle size. With
    ``scale=1``, the largest absolute sample in ``data`` swings half a trace
    spacing.

    Parameters:
        data (2D ndarray): Data matrix, shape (len(z), len(x)); one trace per column.
        scale (float): Amplitude multiplier for the wiggles.
        x (1D array): Horizontal position of each trace (offset or trace number).
        z (1D array): Vertical-axis values for the rows (labelled as time in s).
    """
    data = np.asarray(data, dtype=float)
    x = np.asarray(x)
    nx = data.shape[1]

    # Trace spacing. A single trace has no spacing, so fall back to 1.
    # abs() keeps positive lobes pointing right even if x is decreasing.
    dx = abs(np.median(np.diff(x))) if len(x) > 1 else 1.0

    # Normalize globally rather than per trace. Per-trace normalization would
    # make `scale` meaningless, hide amplitude differences between traces,
    # and divide by zero on dead (all-zero) traces.
    dmax = np.max(np.abs(data))
    if dmax > 0:
        data = data / dmax * (0.5 * dx * scale)

    plt.figure(figsize=(12, 6))
    for i in range(nx):
        trace = data[:, i]
        trace_offset = x[i]
        # Shade the positive lobes, then draw the full wiggle on top.
        plt.fill_betweenx(z, trace_offset, trace_offset + trace, where=(trace > 0), color='black', alpha=0.5)
        plt.plot(trace_offset + trace, z, color='black', linewidth=0.5)
    plt.xlabel('Offset (ft)' if len(x) > 1 else 'Trace Number', fontsize=14)
    plt.ylabel('Time (s)', fontsize=14)
    plt.gca().invert_yaxis()
    plt.grid(which='both', linestyle='--', alpha=0.5)
    

import numpy as np
import matplotlib.pyplot as plt
from scipy.io import loadmat

# Define the Independent Amplitude Correction (IAC) Function
def iac(data, time, power, T):
    """
    Apply a time-variant gain (amplitude correction) to seismic data.

    Every row (time sample) of ``data`` is multiplied by a gain that depends
    only on that sample's time ``t``:

        T == 0:  g(t) = t ** power          (power-of-time gain)
        T == 1:  g(t) = exp(power * t)      (exponential gain)

    Parameters:
        data (2D ndarray): Seismic data, shape (nt, ntraces); rows are time samples.
        time (1D array): Time vector of length nt.
        power (float): Exponent (T == 0) or exponential rate (T == 1).
        T (int): Gain type: 0 for power-of-time, 1 for exponential.

    Returns:
        corrected_data (2D ndarray): ``data`` multiplied row-wise by the gain.

    Raises:
        ValueError: If ``data`` is not 2-D or ``T`` is neither 0 nor 1.
    """
    if np.ndim(data) != 2:
        raise ValueError("data must be a 2-D array (time samples x traces)")

    # Column vector so the gain broadcasts across all traces. Cast to float so
    # integer time vectors work with negative or fractional powers.
    t_col = np.asarray(time, dtype=float)[:, None]

    if T == 0:  # Power correction
        gain = t_col ** power
    elif T == 1:  # Exponential correction
        gain = np.exp(power * t_col)
    else:
        # Reject unknown modes instead of silently returning zeroed data.
        raise ValueError(f"T must be 0 (power) or 1 (exponential), got {T!r}")

    corrected_data = data * gain
    return corrected_data

# Define the Envelope Plotting Function
def seis_env_dB(original_data, processed_data, time, trace_num):
    """
    Compare one trace before and after gain correction on a decibel scale.

    For column ``trace_num`` of each input, the absolute sample values are
    converted to 20*log10(|a|) and both curves are drawn on one new figure,
    which is then shown. Note that this uses |a| per sample (not a Hilbert
    envelope), and samples equal to zero map to -inf dB.

    Parameters:
        original_data (2D ndarray): Data before gain correction.
        processed_data (2D ndarray): Data after gain correction.
        time (1D ndarray): Time axis for the rows.
        trace_num (int): Column index of the trace to plot.
    """
    abs_before = np.abs(original_data[:, trace_num])
    abs_after = np.abs(processed_data[:, trace_num])

    plt.figure(figsize=(10, 6))
    curves = ((abs_before, "Original", "blue"), (abs_after, "Processed", "red"))
    for magnitude, label, colour in curves:
        plt.plot(time, 20 * np.log10(magnitude), label=label, color=colour)

    plt.xlabel("Time (s)", fontsize=14)
    plt.ylabel("Amplitude (dB)", fontsize=14)
    plt.title(f"Amplitude Envelope - Trace {trace_num}", fontsize=16)
    plt.legend()
    plt.grid()
    plt.show()




# Define Frequency-Space Representation (f-x)
def fx(data, dt):
    """
    Compute the frequency-space (f-x) representation of seismic data.

    Parameters:
        data (2D ndarray): Seismic data, shape (nt, nx); rows are time samples.
        dt (float): Time sampling interval (s).

    Returns:
        Data_f (2D complex ndarray): FFT of ``data`` along the time axis
            (axis 0), same shape as ``data``.
        f (1D ndarray): Frequency (Hz) of each row of ``Data_f``, in
            ``np.fft.fftfreq`` order (zero, positive, then negative).
    """
    nt, _ = data.shape
    Data_f = np.fft.fft(data, axis=0)  # FFT along time (rows)
    f = np.fft.fftfreq(nt, d=dt)  # Frequency vector
    return Data_f, f

# Define Frequency-Wavenumber Representation (f-kx)
def fk(data, dt, dx):
    """
    Compute the frequency-wavenumber (f-kx) representation of seismic data.

    Parameters:
        data (2D ndarray): Seismic data, shape (nt, nx); rows are time samples,
            columns are traces.
        dt (float): Time sampling interval (s).
        dx (float): Spatial sampling interval (e.g. ft).

    Returns:
        Data_fk (2D complex ndarray): 2-D FFT of ``data`` (time along axis 0,
            space along axis 1), same shape as ``data``.
        f (1D ndarray): Frequency (Hz) of each row, in ``np.fft.fftfreq`` order.
        kx (1D ndarray): Wavenumber of each column in cycles per unit of ``dx``
            (1/ft when ``dx`` is in ft), in ``np.fft.fftfreq`` order.
    """
    nt, nx = data.shape

    # Transform along time (axis 0) and space (axis 1) in one call.
    Data_fk = np.fft.fft2(data, axes=(0, 1))

    f = np.fft.fftfreq(nt, d=dt)  # Frequency vector
    kx = np.fft.fftfreq(nx, d=dx)  # Wavenumber vector

    return Data_fk, f, kx