In [None]:
import numpy as np

def _parametric_hrf(t, params):
    """Evaluate a difference-of-gammas HRF and minus its time derivative at ``t``.

    ``params`` holds ``(peak1, fwhm1, peak2, fwhm2, dip)``.  Each gamma term is
    ``(t / peak) ** alpha * exp(-(t - peak) / beta)`` with ``alpha`` and
    ``beta`` derived from ``peak`` and ``fwhm``.  When ``peak`` or ``fwhm`` is
    not positive the term degenerates to a unit spike at the sample(s) of
    ``t`` closest to ``peak``, with zero derivative.

    Returns ``(hrf, d_hrf)`` where ``hrf = gamma1 - dip * gamma2`` and
    ``d_hrf = d_gamma1 - dip * d_gamma2``, each ``d_gamma`` being
    ``-d(gamma)/dt``.
    """
    peak1, fwhm1, peak2, fwhm2, dip = params[:5]
    t = np.asarray(t, dtype=float)
    # 1/t where t > 0 and 0 elsewhere, so the derivative stays finite at t == 0.
    tinv = (t > 0) / (t + (t <= 0))

    terms = []
    for peak, fwhm in ((peak1, fwhm1), (peak2, fwhm2)):
        if peak > 0 and fwhm > 0:
            alpha = peak ** 2 / fwhm ** 2 * 8 * np.log(2)
            beta = fwhm ** 2 / peak / 8 / np.log(2)
            gamma = (t / peak) ** alpha * np.exp(-(t - peak) / beta)
            d_gamma = -(alpha * tinv - 1 / beta) * gamma
        else:
            dist = np.abs(t - peak)
            gamma = (dist == np.min(dist)).astype(float)
            d_gamma = np.zeros(t.shape)
        terms.append((gamma, d_gamma))

    (gamma1, d_gamma1), (gamma2, d_gamma2) = terms
    return gamma1 - dip * gamma2, d_gamma1 - dip * d_gamma2


def _convolve_columns(signal, kernels):
    """Return the full convolution of ``signal`` with each column of ``kernels``.

    Equivalent to MATLAB ``conv2(signal(:), kernels)``: the result has
    ``len(signal) + kernels.shape[0] - 1`` rows and one column per kernel.
    """
    return np.column_stack(
        [np.convolve(signal, kernels[:, c]) for c in range(kernels.shape[1])]
    )


def fmridesign(frametimes, slicetimes=None, events=None, S=None,
               hrf_parameters=None, shift=None):
    """Build HRF-convolved design regressors for an fMRI run.

    Python port of the MATLAB fmristat ``fmridesign``.  Events and frame-wise
    stimuli are laid out on a fine time grid (step ``dt = 0.02``), convolved
    with two 2-column HRF bases, and sampled at every frame/slice time.

    Parameters
    ----------
    frametimes : array_like, shape (n,)
        Start time of each frame.
    slicetimes : array_like, optional
        Offset of each slice's acquisition relative to the frame time.
        Default ``[0]``.
    events : array_like, shape (m, 2..4) or (2..4,), optional
        Rows ``[id, onset, duration, height]``; ``id`` is a 1-based event
        type, ``duration`` defaults to 0 (an impulse) and ``height`` to 1.
        Default ``[[1, 0]]``.  Pass an empty array for no events.
    S : array_like, shape (n, p) or (n,), optional
        Frame-wise stimuli; a non-zero ``S[i, j]`` is applied from
        ``frametimes[i]`` until the next frame (or the end of the grid).
    hrf_parameters : array_like of shape (5,) or (r, 5), or dict, optional
        Rows ``[peak1, fwhm1, peak2, fwhm2, dip]`` of a difference-of-gammas
        HRF (default ``[5.4, 5.2, 10.8, 7.35, 0.35]``), or a dict with sample
        times ``'T'`` and values ``'H'`` (1-D, or one row per response) that
        is linearly interpolated, zero outside ``T`` (``T`` should be
        increasing).  A single row is reused for every response.  The
        caller's dict is not modified.
    shift : array_like of shape (2,) or (r, 2), optional
        ``[Delta1, Delta2]``, the range of HRF delays used for the delay
        bases.  Default ``[-4.5, 4.5]``; for parametric HRFs it is scaled by
        ``max(max(fwhm1) / 5.2, 1)``.

    Returns
    -------
    dict
        ``'X'``: array (n, r, 4, numslices).  Columns ``0:2`` of the third
        axis use the HRF and minus its time derivative; columns ``2:4`` use
        the first two SVD components of the delayed HRFs.
        ``'W'``: array (41, r, 5).  Per delay: Taylor weights (``0:2``), SVD
        weights (``2:4``) and the delay itself (``4``).
        ``'TR'``: mean spacing of ``frametimes``.

        ``r`` is the number of event types plus the number of columns of S.

    Raises
    ------
    ValueError
        If an event id is smaller than 1.
    """
    dt = 0.02  # step of the fine time grid on which responses are built
    nd = 41    # number of HRF delays sampled between shift[k, 0] and shift[k, 1]

    # ---- Normalise inputs (lists, scalars and single 1-D rows accepted) ----
    frametimes = np.asarray(frametimes, dtype=float).ravel()
    slicetimes = np.asarray(0.0 if slicetimes is None else slicetimes,
                            dtype=float).ravel()
    events = np.atleast_2d(np.asarray([[1, 0]] if events is None else events,
                                      dtype=float))
    S = np.asarray(np.empty((0, 0)) if S is None else S, dtype=float)
    if S.ndim == 1:
        S = S.reshape(-1, 1)

    is_table = isinstance(hrf_parameters, dict)
    if is_table:
        # Local 2-D arrays; tiling below rebinds these names, so the
        # caller's dict and its arrays are never modified.
        hrf_T = np.atleast_2d(np.asarray(hrf_parameters['T'], dtype=float))
        hrf_H = np.atleast_2d(np.asarray(hrf_parameters['H'], dtype=float))
    else:
        if hrf_parameters is None:
            hrf_parameters = [5.4, 5.2, 10.8, 7.35, 0.35]
        hrf_parameters = np.atleast_2d(np.asarray(hrf_parameters, dtype=float))

    if shift is None:
        shift = np.array([-4.5, 4.5])
        if not is_table:
            # Widen the delay range when the first gamma is wider than 5.2.
            shift = shift * max(np.max(hrf_parameters[:, 1]) / 5.2, 1)
    shift = np.atleast_2d(np.asarray(shift, dtype=float))

    n = frametimes.size
    numslices = slicetimes.size

    # ---- Parse events: columns are id, onset, [duration], [height] ----
    has_events = events.size > 0
    if has_events:
        numevents = events.shape[0]
        eventid = events[:, 0].astype(int)  # 1-based event-type ids
        if np.any(eventid < 1):
            raise ValueError('event ids must be >= 1 (they are 1-based)')
        numeventypes = int(np.max(eventid))
        eventime = events[:, 1]
        duration = events[:, 2] if events.shape[1] >= 3 else np.zeros(numevents)
        height = events[:, 3] if events.shape[1] >= 4 else np.ones(numevents)
        mineventime = np.min(eventime)
        maxeventime = np.max(eventime + duration)
    else:
        numeventypes = 0
        mineventime, maxeventime = np.inf, -np.inf

    numcolS = S.shape[1] if S.size > 0 else 0

    # ---- Fine time grid covering all events and all slice acquisitions ----
    startime = min(mineventime, np.min(frametimes) + min(np.min(slicetimes), 0))
    finishtime = max(maxeventime, np.max(frametimes) + max(np.max(slicetimes), 0))
    numtimes = int(np.ceil((finishtime - startime) / dt)) + 1
    numresponses = numeventypes + numcolS
    response = np.zeros((numtimes, numresponses))

    if has_events:
        # A zero-duration event becomes a spike of height / dt in one sample.
        height = height / (1 + (duration == 0) * (dt - 1))
        for k in range(numevents):
            col = eventid[k] - 1
            n1 = int(np.ceil((eventime[k] - startime) / dt)) + 1
            n2 = (int(np.ceil((eventime[k] + duration[k] - startime) / dt))
                  + int(duration[k] == 0))
            if n2 >= n1:
                response[n1 - 1:n2, col] += height[k]

    # Each non-zero S[i, j] is held from frame i until frame i + 1 (or the end).
    for j in range(numcolS):
        for i in np.flatnonzero(S[:, j]):
            n1 = int(np.ceil((frametimes[i] - startime) / dt)) + 1
            if i < n - 1:
                n2 = int(np.ceil((frametimes[i + 1] - startime) / dt))
            else:
                n2 = numtimes
            if n2 >= n1:
                response[n1 - 1:n2, numeventypes + j] += S[i, j]

    # ---- One HRF row and one shift row per response ----
    if is_table:
        if hrf_T.shape[0] == 1:
            hrf_T = np.tile(hrf_T, (numresponses, 1))
        if hrf_H.shape[0] == 1:
            hrf_H = np.tile(hrf_H, (numresponses, 1))
    elif hrf_parameters.shape[0] == 1:
        hrf_parameters = np.tile(hrf_parameters, (numresponses, 1))
    if shift.shape[0] == 1:
        shift = np.tile(shift, (numresponses, 1))

    eventmatrix = np.zeros((numtimes, numresponses, 4))
    X_cache_W = np.zeros((nd, numresponses, 5))

    for k in range(numresponses):
        Delta1, Delta2 = shift[k, 0], shift[k, 1]
        if is_table:
            support = np.max(hrf_T[k, :])
        else:
            peak1, fwhm1, peak2, fwhm2 = hrf_parameters[k, :4]
            support = max(peak1 + 3 * fwhm1, peak2 + 3 * fwhm2)
        numlags = min(int(np.ceil((support + Delta2 - Delta1) / dt)) + 1, numtimes)
        time = np.arange(numlags) * dt

        # Unshifted HRF and minus its derivative, normalised so the HRF sums to 1.
        if is_table:
            hrf = np.interp(time, hrf_T[k, :], hrf_H[k, :], left=0, right=0)
            d_hrf = -np.gradient(hrf, dt)
        else:
            hrf, d_hrf = _parametric_hrf(time, hrf_parameters[k, :])
        HS = np.column_stack((hrf, d_hrf)) / np.sum(hrf)  # numlags x 2
        eventmatrix[:, k, 0:2] = _convolve_columns(response[:, k], HS)[:numtimes]

        # nd copies of the HRF delayed by each delta in [Delta1, Delta2];
        # row i of H corresponds to time i * dt + Delta1.
        delta = np.linspace(Delta1, Delta2, nd)
        H = np.zeros((numlags, nd))
        for col, d in enumerate(delta):
            if is_table:
                hrf = np.interp(time + Delta1 - d, hrf_T[k, :], hrf_H[k, :],
                                left=0, right=0)
            else:
                t = (time + Delta1 - d) * ((time + Delta1) > d)
                hrf = _parametric_hrf(t, hrf_parameters[k, :])[0]
            H[:, col] = hrf / np.sum(hrf)

        # Taylor weights: least-squares fit of every delayed HRF by the
        # (HRF, derivative) basis, delayed by `origin` samples to line up
        # with H's Delta1-based time axis.
        origin = -int(np.round(Delta1 / dt))
        HS0 = np.vstack((np.zeros((origin, 2)), HS[:numlags - origin, :]))
        X_cache_W[:, k, 0:2] = (np.linalg.pinv(HS0) @ H).T

        # SVD basis: first two left singular vectors of H, scaled so the
        # first one sums to 1; WS holds the matching per-delay weights.
        U, SS, Vt = np.linalg.svd(H, full_matrices=False)
        sumU = np.sum(U[:, 0])
        US = U[:, :2] / sumU
        WS = Vt[:2, :].T * SS[:2] * sumU
        # Resolve the SVD sign ambiguity so that delta @ WS[:, 1] >= 0.
        if delta @ WS[:, 1] < 0:
            US[:, 1] = -US[:, 1]
            WS[:, 1] = -WS[:, 1]

        # Convolve, then skip `origin` samples to undo H's Delta1 offset.
        conv = _convolve_columns(response[:, k], US)
        eventmatrix[:, k, 2:4] = conv[np.arange(numtimes) + origin]
        X_cache_W[:, k, 2:4] = WS
        X_cache_W[:, k, 4] = delta

        if not np.all(WS[:, 0] > 0):
            print(f'Warning: use only for magnitudes, not delays. First coef not positive for stimulus {k+1}')

        cubic_coef = np.linalg.pinv(np.vstack((delta, delta ** 3)).T) @ (WS[:, 1] / WS[:, 0])
        if np.prod(cubic_coef) < 0:
            print(f'\nWarning: use only for magnitudes, not delays. SVD ratio not invertible for stimulus {k+1}')

    # ---- Sample the fine-grid regressors at each slice's acquisition times ----
    X_cache_X = np.zeros((n, numresponses, 4, numslices))
    for s, slicetime in enumerate(slicetimes):
        subtime = np.ceil((frametimes + slicetime - startime) / dt).astype(int)
        X_cache_X[:, :, :, s] = eventmatrix[subtime]

    X_cache_TR = (np.max(frametimes) - np.min(frametimes)) / (n - 1)

    return {'X': X_cache_X, 'W': X_cache_W, 'TR': X_cache_TR}