# In [None]:
import torch
import torchvision.transforms as T
from sklearn.cluster import KMeans
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm

# --- Hyperparameters ---
patch_size = 6        # side length, in pixels, of each square patch
channels = 3          # declared here; not referenced by the code below
num_patches = 10000   # upper bound on how many patches are collected
num_clusters = 100    # number of dictionary elements (filters) for k-means

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# --- Dataset: CIFAR-10 training split ---
# ToTensor is the only transform in the pipeline.
transform = T.Compose([T.ToTensor()])
dataset = CIFAR10(
    root='./data',
    train=True,
    download=True,  # fetch the archive into ./data if it is not already there
    transform=transform,
)
# Shuffled loader that yields one image per batch.
dataloader = DataLoader(dataset, shuffle=True, batch_size=1)

# --- Function: Extract Random Patches ---
def extract_random_patches(loader, patch_size, num_patches, patches_per_image=10):
    """Sample random square patches from the images yielded by ``loader``.

    Args:
        loader: Iterable of ``(images, labels)`` pairs. ``images`` is either a
            batch tensor ``[B, C, H, W]`` (any batch size) or a single image
            ``[C, H, W]``. Labels are ignored.
        patch_size: Side length of each square patch. It must not exceed the
            image height or width; ``torch.randint`` rejects the resulting
            empty sampling range otherwise.
        num_patches: Maximum number of patches to collect. Fewer are returned
            if ``loader`` runs out of images first.
        patches_per_image: Maximum number of patches taken from one image.

    Returns:
        Array of shape ``[n, C * patch_size * patch_size]`` with
        ``n <= num_patches``; each row is one flattened patch.

    Raises:
        ValueError: If no patch could be extracted (empty loader or
            ``num_patches <= 0``).
    """
    patches = []
    for batch, _ in tqdm(loader):
        if batch.dim() == 3:
            # A single unbatched image: treat it as a batch of one.
            batch = batch.unsqueeze(0)
        for img in batch:  # img shape: [C, H, W]
            _, h, w = img.shape
            remaining = num_patches - len(patches)
            for _ in range(min(patches_per_image, remaining)):
                # randint's upper bound is exclusive, so add 1 to make the
                # last valid top-left corner (w - patch_size) reachable.
                x = torch.randint(0, w - patch_size + 1, (1,)).item()
                y = torch.randint(0, h - patch_size + 1, (1,)).item()
                patch = img[:, y:y + patch_size, x:x + patch_size].reshape(-1)
                patches.append(patch.numpy())
            if len(patches) >= num_patches:
                break
        if len(patches) >= num_patches:
            break
    if not patches:
        raise ValueError("no patches could be extracted from the loader")
    return np.stack(patches)

# --- Step 1: Collect the raw patch matrix (one flattened patch per row) ---
print("Extracting patches...")
X = extract_random_patches(
    loader=dataloader,
    patch_size=patch_size,
    num_patches=num_patches,
)

# --- Step 2: Per-patch normalization (each row: zero mean, unit variance) ---
# Both operations modify X in place. The std is taken after centering, and the
# small constant keeps constant patches from dividing by zero.
X -= X.mean(axis=1, keepdims=True)
X /= X.std(axis=1, keepdims=True) + 1e-5

# --- Step 3: Learn the dictionary with k-means ---
print("Running k-means...")
# fit() returns the estimator itself, so construction and fitting can be chained.
kmeans = KMeans(
    n_clusters=num_clusters,
    random_state=42,
    n_init='auto',
).fit(X)
# One centroid per row: [num_clusters, patch_dim]
centroids = kmeans.cluster_centers_

# --- Step 4: Encoding Function ---
def encode_patch(patch, centroids):
    """Score one flattened patch against every centroid.

    Args:
        patch: 1-D array of length ``patch_dim``.
        centroids: 2-D array of shape ``[num_clusters, patch_dim]``.

    Returns:
        1-D array of length ``num_clusters`` holding the negated Euclidean
        distance to each centroid, so a larger value means a closer centroid.
    """
    offsets = centroids - patch  # broadcasts the patch across all centroids
    similarity = -np.linalg.norm(offsets, axis=1)
    return similarity

# Demo: encode the first five normalized patches against the learned centroids.
print("Encoding a few sample patches:")
encoded = [
    encode_patch(patch=row, centroids=centroids)
    for row in X[:5]
]
# Stacking the per-patch codes gives shape [5, num_clusters].
print(np.array(encoded).shape)
