In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import binom
from scipy.ndimage import gaussian_filter

def generate_blob(grid_size, blob_radius, jitter):
    """Return a (grid_size, grid_size) float image containing one filled disk.

    The disk centre is a random integer position drawn from
    ``[blob_radius, grid_size - blob_radius)`` on each axis, then shifted by
    ``jitter``. Pixels whose distance to the centre is at most ``blob_radius``
    are set to 1.0; every other pixel is 0.0.
    """
    row_centre, col_centre = np.random.randint(blob_radius, grid_size - blob_radius, size=2) + jitter
    rows, cols = np.ogrid[:grid_size, :grid_size]
    inside_disk = (cols - col_centre) ** 2 + (rows - row_centre) ** 2 <= blob_radius ** 2
    return inside_disk.astype(float)

def generate_map(grid_size, max_blobs, jitter, blob_prob=0.5, radius_range=(5, 15)):
    """Generate a spatial map as the sum of randomly placed circular blobs.

    The number of blobs is drawn from Binomial(``max_blobs``, ``blob_prob``), so a
    map holds between 0 and ``max_blobs`` blobs. Each blob gets its own radius,
    drawn uniformly from the integers in ``[radius_range[0], radius_range[1])``.

    Parameters
    ----------
    grid_size : int
        Side length of the square map.
    max_blobs : int
        Maximum number of blobs (number of Binomial trials).
    jitter : array-like of length 2
        Offset added to every blob centre (forwarded to ``generate_blob``).
    blob_prob : float, default 0.5
        Success probability of each Binomial trial.
    radius_range : (int, int), default (5, 15)
        Half-open range of blob radii.

    Returns
    -------
    ndarray of shape (grid_size, grid_size)
        Overlapping blobs add up, so values can exceed 1.
    """
    spatial_map = np.zeros((grid_size, grid_size))
    num_blobs = binom.rvs(n=max_blobs, p=blob_prob, size=1)[0]
    blob_radii = np.random.randint(radius_range[0], radius_range[1], size=num_blobs)

    # Exactly one blob per sampled radius: the sampled count is what ends up in the map.
    for radius in blob_radii:
        spatial_map += generate_blob(grid_size, radius, jitter)

    return spatial_map

def generate_synthetic_fmri_data(subjects, maps_per_subject, time_points, grid_size,
                                 max_blobs=3, noise_sigma=3, jitter_std=1.0):
    """Generate synthetic fMRI data for multiple subjects.

    For every subject a single 2-D jitter offset is drawn from N(0, jitter_std^2) and
    shared by all of that subject's maps. Each map comes from ``generate_map``. At each
    time point the map is multiplied by uniform random values and spatially smoothed
    Gaussian noise is added.

    Parameters
    ----------
    subjects, maps_per_subject, time_points, grid_size : int
        Output dimensions.
    max_blobs : int, default 3
        Forwarded to ``generate_map``.
    noise_sigma : float, default 3
        Standard deviation of the Gaussian kernel used to smooth the noise.
    jitter_std : float, default 1.0
        Standard deviation of the per-subject jitter.

    Returns
    -------
    ndarray of shape (subjects, maps_per_subject, grid_size, grid_size, time_points)
    """
    data = np.zeros((subjects, maps_per_subject, grid_size, grid_size, time_points))

    for subject in range(subjects):
        jitter = np.random.normal(0, jitter_std, 2)  # Gaussian-distributed jitter, one per subject
        for map_index in range(maps_per_subject):
            spatial_map = generate_map(grid_size, max_blobs=max_blobs, jitter=jitter)
            for time_point in range(time_points):
                # NOTE(review): this draws an independent value per voxel and time point,
                # not one time course per map; confirm this is the intended signal model.
                time_series = np.random.rand(grid_size, grid_size)
                noise = gaussian_filter(np.random.randn(grid_size, grid_size), sigma=noise_sigma)
                data[subject, map_index, :, :, time_point] = spatial_map * time_series + noise

    return data

# Parameters
subjects = 12
maps_per_subject = 5
time_points = 150
grid_size = 50
seed = 0  # fixed seed so that "Restart & Run All" regenerates the same data

# The generators above draw from NumPy's global RNG (scipy's binom.rvs uses it too
# by default), so seeding it once here makes the synthetic dataset reproducible.
np.random.seed(seed)

# Generate synthetic fMRI data: axes are (subject, map, y, x, time)
synthetic_data = generate_synthetic_fmri_data(subjects, maps_per_subject, time_points, grid_size)
assert synthetic_data.shape == (subjects, maps_per_subject, grid_size, grid_size, time_points)

# Visualize one of the generated maps: subject 0, map 3, first time point
fig, ax = plt.subplots()
image = ax.imshow(synthetic_data[0, 3, :, :, 0], cmap='gray')
ax.set_title("Sample Synthetic fMRI Map")
ax.set_xlabel("x (voxel)")
ax.set_ylabel("y (voxel)")
fig.colorbar(image, ax=ax)
plt.show()


In [None]:
import statistics

def compute_energy(Us, Vs, V, mu, lambda_r, regularization_fun, Ys=None):
    """Energy minimised by multi-subject dictionary learning.

        E = sum_s 0.5 * (||Y_s - U_s V_s^T||_F^2 + mu * ||V_s - V||_F^2)
            + lambda_r * regularization_fun(V)

    Parameters
    ----------
    Us : sequence of (n_s, k) arrays
        Time series of each subject.
    Vs : sequence of (p, k) arrays, or a single (p, k) array
        Spatial maps of each subject. A single 2-D array is used for every subject,
        which is the only form the previous implementation accepted.
    V : (p, k) array
        Group-level spatial maps.
    mu : float
        Weight of the subject-to-group coupling term.
    lambda_r : float
        Weight of the regulariser.
    regularization_fun : callable
        Maps ``V`` to a scalar penalty.
    Ys : sequence of (n_s, p) arrays, optional
        Data of each subject. When omitted, a notebook-global ``Ys`` is used, which
        is what the previous implementation always (implicitly) did.

    Returns
    -------
    float

    Raises
    ------
    ValueError
        If no data is available, or ``Ys``/``Us``/``Vs`` disagree on the number of subjects.
    """
    if Ys is None:
        Ys = globals().get("Ys")
        if Ys is None:
            raise ValueError("compute_energy needs the subject data: pass Ys=[Y_1, ..., Y_S]")
    n_subjects = len(Us)
    if isinstance(Vs, np.ndarray) and Vs.ndim == 2:
        Vs = [Vs] * n_subjects
    if len(Ys) != n_subjects or len(Vs) != n_subjects:
        raise ValueError("Ys, Us and Vs must contain one entry per subject")

    energy = 0.0
    for Y_s, U_s, V_s in zip(Ys, Us, Vs):
        data_fit = np.linalg.norm(Y_s - U_s @ V_s.T, 'fro') ** 2
        coupling = mu * np.linalg.norm(V_s - V, 'fro') ** 2
        energy += 0.5 * (data_fit + coupling)
    return energy + lambda_r * regularization_fun(V)


def prox(V, lambda_r=0.0):
    """Proximal operator of ``lambda_r * ||V||_1``: element-wise soft-thresholding.

    Computes ``sign(V) * max(|V| - lambda_r, 0)``. With the default ``lambda_r=0``
    the operator is the identity and ``V`` itself is returned, as the original
    placeholder did.

    Raises
    ------
    ValueError
        If ``lambda_r`` is negative.
    """
    if lambda_r == 0:
        return V
    if lambda_r < 0:
        raise ValueError("lambda_r must be non-negative")
    V = np.asarray(V, dtype=float)
    return np.sign(V) * np.maximum(np.abs(V) - lambda_r, 0.0)


def update_vs(V, Vs, Us, Ys, mu):
    """Update one subject's spatial maps by ridge regression (closed form).

    Returns the minimiser of ``0.5*||Ys - Us X^T||_F^2 + 0.5*mu*||X - V||_F^2``:

        X = (Ys^T Us + mu V) (Us^T Us + mu I)^{-1}
          = V + (Ys - Us V^T)^T Us (Us^T Us + mu I)^{-1}

    Note that the residual in the second form is taken against the group maps ``V``.

    Parameters
    ----------
    V : (p, k) array
        Group-level maps.
    Vs : (p, k) array
        Current subject maps. Not needed by the closed form; kept so that existing
        calls keep working.
    Us : (n, k) array
        Subject time series.
    Ys : (n, p) array
        Subject data.
    mu : float
        Coupling weight.

    Returns
    -------
    (p, k) array

    Raises
    ------
    numpy.linalg.LinAlgError
        If ``Us^T Us + mu I`` is singular (only possible when ``mu <= 0``).
    """
    n_components = Us.shape[1]
    gram = Us.T @ Us + mu * np.identity(n_components)
    rhs = Ys.T @ Us + mu * V
    # X @ gram = rhs with gram symmetric  <=>  gram @ X^T = rhs^T
    return np.linalg.solve(gram, rhs.T).T


def update_us(Ys, Vs, u_l, Us=None, l=0):
    """Block-coordinate update of component ``l`` of one subject's time series.

    Minimises ``0.5*||Ys - Us Vs^T||_F^2`` over column ``l`` of ``Us`` with the other
    columns fixed, then projects onto the unit l2 ball (constraint ``||u_l||_2 <= 1``):

        u_l <- u_l + (Ys v_l - Us Vs^T v_l) / ||v_l||^2
        u_l <- u_l / max(||u_l||_2, 1)

    Parameters
    ----------
    Ys : (n, p) array
        Subject data.
    Vs : (p, k) array
        Subject maps; a 1-D (p,) map is treated as a single component.
    u_l : (n,) array
        Current column ``l`` of ``Us``.
    Us : (n, k) array, optional
        Current time series; column ``l`` should equal ``u_l``. When omitted, ``u_l``
        is treated as the only component (the other columns are taken as zero).
    l : int, default 0
        Index of the component being updated.

    Returns
    -------
    (n,) array
        A new array; the inputs are not modified.
    """
    Vs = np.asarray(Vs, dtype=float)
    if Vs.ndim == 1:
        Vs = Vs[:, None]
    u_l = np.asarray(u_l, dtype=float)
    v_l = Vs[:, l]
    v_norm_sq = v_l @ v_l
    if v_norm_sq > 0:
        # Part of Ys v_l already explained by the current model Us Vs^T.
        fitted = u_l * v_norm_sq if Us is None else Us @ (Vs.T @ v_l)
        u_l_new = u_l + (Ys @ v_l - fitted) / v_norm_sq
    else:
        # v_l == 0: the data term does not depend on u_l, so only project it.
        u_l_new = u_l
    return u_l_new / max(np.linalg.norm(u_l_new, 2), 1.0)


def algorithm(Ys, V, k, mu, lambda_r, max_iter=1000, tolerance=1e-4,
              regularization_fun=None, prox_fun=None):
    """Multi-subject dictionary learning by alternate (block-coordinate) minimisation.

    Minimises ``compute_energy`` over the per-subject time series ``Us``, the
    per-subject maps ``Vs`` and the group maps ``V`` by repeating:

      1. each column of ``U_s`` (``update_us``: exact update + unit-ball projection),
      2. ``V_s`` (``update_vs``: closed-form ridge update towards ``V``),
      3. ``V = prox_fun(mean_s V_s, lambda_r / (mu * S))``, the exact minimiser of
         ``sum_s 0.5*mu*||V_s - V||^2 + lambda_r * Omega(V)``.

    Each step minimises the energy over its own block. The time series are
    initialised inside the unit ball, so the energy never increases.

    Parameters
    ----------
    Ys : sequence of (n_s, p) arrays, or an (S, n, p) array
        Data of each subject (time points x voxels).
    V : (p, k) array or None
        Initial group maps; ``None`` draws them with ``np.random.rand``.
    k : int
        Number of components.
    mu : float > 0
        Strength of the subject-to-group coupling.
    lambda_r : float >= 0
        Regularisation strength.
    max_iter : int, default 1000
        Maximum number of sweeps.
    tolerance : float, default 1e-4
        Stop when ``|E_old - E_new| <= tolerance * |E_old|``.
    regularization_fun : callable, optional
        ``V -> float`` penalty; defaults to the l1 norm.
    prox_fun : callable, optional
        ``(V, threshold) -> V``, the proximal operator of ``regularization_fun``.
        Defaults to ``prox`` (soft-thresholding). It must match ``regularization_fun``.

    Returns
    -------
    V : (p, k) array
    Vs : list of (p, k) arrays
    Us : list of (n_s, k) arrays

    Raises
    ------
    ValueError
        On empty or inconsistent ``Ys``, a wrongly shaped ``V``, or ``mu <= 0``.
    """
    Ys = [np.asarray(Y, dtype=float) for Y in Ys]
    if not Ys:
        raise ValueError("Ys must contain at least one subject")
    if any(Y.ndim != 2 for Y in Ys) or len({Y.shape[1] for Y in Ys}) != 1:
        raise ValueError("each Ys[s] must be a 2-D (time points, voxels) array with the same number of voxels")
    if mu <= 0:
        raise ValueError("mu must be positive")
    S = len(Ys)      # number of subjects
    p = Ys[0].shape[1]  # number of voxels

    if regularization_fun is None:
        regularization_fun = lambda M: float(np.abs(M).sum())  # l1 norm, matches `prox`
    if prox_fun is None:
        prox_fun = prox

    if V is None:
        V = np.random.rand(p, k)  # group level spatial maps
    else:
        V = np.array(V, dtype=float)  # copy: never modify the caller's array
        if V.shape != (p, k):
            raise ValueError(f"V must have shape {(p, k)}, got {V.shape}")
    Vs = [V.copy() for _ in range(S)]  # subject specific spatial maps start at the group maps
    Us = []  # time series, started inside the feasible set ||u_l||_2 <= 1
    for Y in Ys:
        U = np.random.rand(Y.shape[0], k)
        Us.append(U / np.maximum(np.linalg.norm(U, axis=0), 1.0))

    E_old = compute_energy(Us, Vs, V, mu, lambda_r, regularization_fun, Ys=Ys)
    for _ in range(max_iter):
        for s in range(S):
            for l in range(k):
                Us[s][:, l] = update_us(Ys[s], Vs[s], Us[s][:, l], Us=Us[s], l=l)
            Vs[s] = update_vs(V, Vs[s], Us[s], Ys[s], mu)

        # Apply the proximal operator to the mean of the subject spatial maps.
        V = prox_fun(np.mean(Vs, axis=0), lambda_r / (mu * S))

        # Convergence is tested after each sweep against a finite previous energy
        # (with E_old = inf, the relative test could never pass and no sweep would run).
        E_new = compute_energy(Us, Vs, V, mu, lambda_r, regularization_fun, Ys=Ys)
        converged = abs(E_old - E_new) <= tolerance * abs(E_old)
        E_old = E_new
        if converged:
            break

    return V, Vs, Us