# In [None]:
# Tutorial 1: Dictionary learning assignment
# ======================================================
#
# This tutorial guides students through learning a dictionary from acoustic data.
# They will extract patches, train a dictionary, reconstruct held-out patches, and analyze sparsity.

# Section 0: Setup
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import DictionaryLearning, MiniBatchDictionaryLearning, PCA
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from skimage.util import view_as_windows
from scipy.signal import detrend

# Higher default resolution for every figure below, and a fixed seed for
# NumPy's global RNG (stack_patches draws its random subset from it).
plt.rcParams.update({"figure.dpi": 130})
np.random.seed(42)

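# Section 1: Load your data
# Define `data` here: either a single 2D array, or a 3D array whose first
# axis indexes 2D arrays (each one is tiled into patches in Section 2).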


# Section 2: Patch Extraction
def extract_patches_2d_st(data2d, patch_size=(16, 32), stride=(8, 16)):
    """Tile a 2D array into (possibly overlapping) rectangular patches.

    Parameters
    ----------
    data2d : ndarray
        2D input array.
    patch_size : tuple of int
        Patch extent ``(Ps, Pt)`` along the two axes.
    stride : tuple of int
        Step ``(Ss, St)`` between consecutive patch origins along each axis.

    Returns
    -------
    ndarray of shape (n_patches, Ps, Pt)
        Patches ordered row-major by the position of their origin.
    """
    rows, cols = patch_size
    row_step, col_step = stride
    grid = view_as_windows(data2d, (rows, cols), step=(row_step, col_step))
    # Collapse the two window-position axes into a single patch axis.
    return grid.reshape(-1, rows, cols)

def stack_patches(data, patch_size=(16, 32), stride=(8, 16), max_patches=None):
    """Extract patches from one 2D array or from a stack of 2D arrays.

    Parameters
    ----------
    data : ndarray
        Either a single 2D array, or a 3D array whose entries along the first
        axis are 2D arrays. Each 2D array is tiled with
        ``extract_patches_2d_st`` and the results are concatenated.
    patch_size, stride : tuple of int
        Forwarded unchanged to ``extract_patches_2d_st``.
    max_patches : int or None
        If not None and more patches than this were extracted, keep a random
        subset of exactly ``max_patches`` patches (drawn without replacement
        from NumPy's global random state).

    Returns
    -------
    ndarray of shape (n_patches, Ps, Pt)

    Raises
    ------
    ValueError
        If ``data`` is neither 2D nor 3D, or is a 3D array with no 2D slices.
    """
    if data.ndim == 2:
        patches = extract_patches_2d_st(data, patch_size, stride)
    elif data.ndim == 3:
        if data.shape[0] == 0:
            # np.concatenate([]) would otherwise fail with an unrelated message.
            raise ValueError("data is an empty stack: no 2D arrays to extract patches from")
        patches = np.concatenate(
            [extract_patches_2d_st(d, patch_size, stride) for d in data], axis=0
        )
    else:
        raise ValueError(f"data must be a 2D array or a 3D stack of 2D arrays, got ndim={data.ndim}")

    # Compare against None explicitly so max_patches=0 is honored instead of
    # being read as "no limit" through truthiness.
    if max_patches is not None and patches.shape[0] > max_patches:
        idx = np.random.choice(patches.shape[0], max_patches, replace=False)
        patches = patches[idx]
    return patches

# Patch extent (Ps, Pt), step between patch origins, and an upper bound on
# the number of randomly kept patches. `data` must be defined in Section 1.
PATCH_SIZE = (16, 32)
STRIDE = (8, 16)
MAX_PATCHES = 20_000

patches = stack_patches(
    data,
    patch_size=PATCH_SIZE,
    stride=STRIDE,
    max_patches=MAX_PATCHES,
)
print("Patches:", patches.shape)

# Section 3: Preprocessing
Ps, Pt = PATCH_SIZE

# Remove the mean along each patch's last axis (length Pt), for all patches in
# one vectorized call, then flatten every patch into a row of Ps * Pt features.
X = (
    detrend(patches, axis=-1, type='constant')
    .reshape(-1, Ps * Pt)
    .astype(np.float32, copy=False)
)

# Split BEFORE fitting the scaler: fitting on the full set would leak the
# validation rows' per-feature mean/std into the training normalization.
# The split depends only on the row count and random_state, so the same
# rows land in train/val as when splitting after scaling.
X_train_raw, X_val_raw = train_test_split(X, test_size=0.1, random_state=42)

scaler = StandardScaler(with_mean=True, with_std=True)
X_train = scaler.fit_transform(X_train_raw)
X_val = scaler.transform(X_val_raw)
# Full set standardized with the training-set statistics.
X_scaled = scaler.transform(X)
print("Train/Val:", X_train.shape, X_val.shape)

# Section 4: Dictionary Learning (students tune hyperparameters)
N_COMPONENTS = 128  # number of dictionary atoms to learn
ALPHA = 1.0         # sparsity penalty; reused as transform_alpha below
MAX_ITER = 100      # upper bound on fitting iterations

learner = DictionaryLearning(
    n_components=N_COMPONENTS,
    alpha=ALPHA,
    transform_alpha=ALPHA,
    max_iter=MAX_ITER,
    fit_algorithm='cd',
    transform_algorithm='lasso_lars',
    random_state=42,
    verbose=True,
)

# Fit on the training patches; keep both the training codes and the atoms.
D_codes = learner.fit_transform(X_train)
D_atoms = learner.components_
print("Atoms:", D_atoms.shape)

# Section 5: Reconstruction
# Sparse-code the held-out patches, rebuild them from the atoms, and compare
# against the originals after undoing the standardization on both sides.
C_val = learner.transform(X_val)
X_val_rec = scaler.inverse_transform(np.dot(C_val, learner.components_))
X_val_unscaled = scaler.inverse_transform(X_val)

mse = np.square(X_val_unscaled - X_val_rec).mean()
print(f"Reconstruction MSE: {mse:.4e}")

# Section 6: Visualization
def show_atoms(atoms, patch_size, n_rows=6, n_cols=8, title="Dictionary Atoms"):
    """Plot dictionary atoms as a grid of small images.

    Parameters
    ----------
    atoms : ndarray
        One flattened atom per row (e.g. ``learner.components_``); each row
        must reshape to ``patch_size``.
    patch_size : tuple of int
        ``(Ps, Pt)`` used to reshape each atom back into a 2D patch.
    n_rows, n_cols : int
        Grid size. Only the first ``n_rows * n_cols`` atoms are drawn, and
        grid cells beyond the number of atoms are blanked.
    title : str
        Figure title.
    """
    Ps, Pt = patch_size
    # squeeze=False keeps `axes` a 2D array even for a 1x1 grid; otherwise
    # plt.subplots returns a bare Axes, which has no `.flat`.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(1.2 * n_cols, 1.2 * n_rows), squeeze=False
    )
    for i, ax in enumerate(axes.flat):
        if i >= atoms.shape[0]:
            ax.axis('off')
            continue
        ax.imshow(atoms[i].reshape(Ps, Pt), aspect='auto', origin='lower', cmap='viridis')
        ax.set_xticks([])
        ax.set_yticks([])
    fig.suptitle(title)
    fig.tight_layout()
    plt.show()

# With the default 6x8 grid, only the first 48 atoms are drawn.
show_atoms(atoms=D_atoms, patch_size=PATCH_SIZE)

