# 第三题 (Problem 3)

# In [None]:
import numpy as np
import pandas as pd
import os

def load_data(folder):
    """Load the nine colour-channel CSV tables stored in *folder*.

    For each name in R_R, R_G, R_B, G_R, G_G, G_B, B_R, B_G, B_B (in that
    order) the file ``<name>.csv`` is read with pandas as a headerless table
    and its values are stored as a NumPy array under the key ``<name>``.

    Args:
        folder: directory containing the nine CSV files.

    Returns:
        dict mapping each table name to its 2-D NumPy array.

    Raises:
        Whatever ``pd.read_csv`` raises for a missing or unreadable file; files
        are read in the order listed above, so the first failure stops loading.
    """
    table_names = ("R_R", "R_G", "R_B", "G_R", "G_G", "G_B", "B_R", "B_G", "B_B")
    tables = {}
    for table_name in table_names:
        csv_path = os.path.join(folder, table_name + ".csv")
        tables[table_name] = pd.read_csv(csv_path, header=None).values
    return tables

def compute_correction_matrices(data, input_level=220):
    """Compute a per-pixel 3x3 inverse colour-mixing matrix.

    For every pixel (i, j) a mixing matrix M is fitted so that ``M @ A = B``.
    The columns of ``A`` are the three calibration stimuli: pure R, G and B,
    each at ``input_level``. Column k of ``B`` holds the three readings taken
    for stimulus k. The stored result is the inverse of M, which maps a reading
    back to the stimulus that produced it.

    As assembled below, ``data["X_Y"]`` is used as the channel-Y reading for
    the pure-X stimulus. NOTE(review): this is inferred from the indexing;
    confirm it against the data source.

    Args:
        data: mapping with keys "R_R" ... "B_B", each a 2-D array; the shape of
            ``data["R_R"]`` sets (H, W).
        input_level: stimulus value of each calibration input (default 220,
            the value previously hard-coded).

    Returns:
        float64 array of shape (H, W, 3, 3). Entry [i, j] is the inverse of M
        for pixel (i, j). A numerically rank-deficient M is pseudo-inverted
        instead. A pixel whose fit fails or contains non-finite values gets
        the identity matrix.
    """
    H, W = data["R_R"].shape
    M_inv_all = np.zeros((H, W, 3, 3))

    # Columns of A are the calibration stimuli; A is diagonal, so A.T == A.
    A = input_level * np.eye(3)
    channels = ("R", "G", "B")

    for i in range(H):
        for j in range(W):
            # Row = reading channel, column = stimulus channel.
            B = np.array(
                [[data[f"{stim}_{read}"][i, j] for stim in channels] for read in channels],
                dtype=float,
            )
            M_inv_all[i, j] = _invert_mixing(A, B)

    return M_inv_all


def _invert_mixing(A, B):
    """Return the (pseudo-)inverse of the M solving ``M @ A = B``, or the identity on failure."""
    try:
        # A.T @ M.T = B.T is the same linear system as M @ A = B.
        M = np.linalg.lstsq(A.T, B.T, rcond=None)[0].T
        if not np.isfinite(M).all():
            return np.eye(3)
        # Use the numerical rank (an SVD with a tolerance), not det(M) == 0.
        # The determinant of a singular float matrix is rarely exactly zero,
        # and inv() then returns huge, meaningless entries.
        if np.linalg.matrix_rank(M) < M.shape[0]:
            return np.linalg.pinv(M)
        return np.linalg.inv(M)
    except np.linalg.LinAlgError:
        return np.eye(3)

def apply_color_correction(img_data, M_inv_all):
    """Apply a per-pixel 3x3 correction matrix to a 3-channel image.

    Args:
        img_data: array of shape (H, W, 3).
        M_inv_all: array of shape at least (H, W, 3, 3). ``M_inv_all[i, j]``
            multiplies the colour vector at pixel (i, j), e.g. the output of
            ``compute_correction_matrices``.

    Returns:
        uint8 array of shape (H, W, 3). Corrected values are clipped to
        [0, 255] and rounded to the nearest integer.
    """
    H, W, _ = img_data.shape
    # out[h, w, i] = sum_j M[h, w, i, j] * img[h, w, j]. This is the per-pixel
    # `M_inv @ rgb` done in one vectorised call. M_inv_all is cropped to the
    # image size so a larger matrix array is still accepted.
    corrected = np.einsum(
        "hwij,hwj->hwi",
        M_inv_all[:H, :W],
        np.asarray(img_data, dtype=np.float64),
    )
    # NaN has no uint8 value, so map it to 0 explicitly instead of relying on
    # an undefined float->int cast. +/-inf become large finite values and are
    # then clipped.
    corrected = np.nan_to_num(corrected, nan=0.0)
    # Round before casting: astype() truncates toward zero, which would turn,
    # e.g., 199.9999 into 199 and bias every pixel downward.
    return np.rint(np.clip(corrected, 0, 255)).astype(np.uint8)
