<a href="https://colab.research.google.com/github/Lexaun-chen/STAT-4830-Group-Project/blob/main/notebooks/GPU_Matrix_Completion_with_SGD_and_NQM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.metrics import mean_squared_error, mean_absolute_error

In [None]:
# PyTorch GPU version of basic MF (vanilla SGD only)
class MF_Torch:
    """Biased matrix factorisation for matrix completion, trained with per-entry SGD.

    The model predicts ``X[i, j] ~ b + b_u[i] + b_v[j] + U[i] . V[j]``, where ``b``
    is the mean of the observed entries. Entries of ``X_np`` that are NaN are
    treated as missing: they are never trained on and are excluded from the loss.

    Args:
        X_np: 2-D array-like; NaN marks a missing entry.
        k: Latent dimension of the factor matrices ``U`` and ``V``.
        alpha: SGD learning rate.
        beta: L2 regularisation coefficient applied to the biases and factors.
        iterations: Number of epochs (full passes over the observed entries).
        device: Torch device; defaults to CUDA when available, otherwise CPU.
    """

    def __init__(self, X_np, k, alpha, beta, iterations, device=None):
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.X = torch.tensor(X_np, dtype=torch.float32, device=self.device)
        self.k = k
        self.alpha = alpha
        self.beta = beta
        self.iterations = iterations

        self.num_users, self.num_items = self.X.shape
        self.mask = ~torch.isnan(self.X)

        # Small random factors, zero biases, global bias = mean of the observed entries.
        self.U = torch.randn(self.num_users, k, device=self.device) * 0.01
        self.V = torch.randn(self.num_items, k, device=self.device) * 0.01
        self.b_u = torch.zeros(self.num_users, device=self.device)
        self.b_v = torch.zeros(self.num_items, device=self.device)
        self.b = torch.nanmean(self.X)

        # (n_observed, 2) tensor of (row, col) indices of the non-NaN entries.
        self.observed_idx = torch.nonzero(self.mask, as_tuple=False)

    def predict(self):
        """Return the dense (num_users x num_items) prediction matrix."""
        return self.b + self.b_u[:, None] + self.b_v[None, :] + torch.matmul(self.U, self.V.T)

    def _predict_entry(self, i, j):
        """Return the prediction for the single entry (i, j) in O(k)."""
        return self.b + self.b_u[i] + self.b_v[j] + torch.dot(self.U[i], self.V[j])

    def train(self):
        """Run ``self.iterations`` SGD epochs and return ``[(epoch, mse), ...]``.

        Each epoch visits the observed entries in a permutation drawn from a private
        generator seeded with the epoch index. Runs are reproducible without
        reseeding (and thereby clobbering) torch's global RNG state.
        """
        history = []
        for it in range(self.iterations):
            generator = torch.Generator().manual_seed(it)
            self._sgd_step(generator)
            loss = self.compute_loss()
            history.append((it, loss))
            if (it + 1) % 20 == 0:
                print(f"Iteration {it+1}/{self.iterations}, Error: {loss:.4f}")
        return history

    def _sgd_step(self, generator=None):
        """Perform one SGD epoch over all observed entries in random order.

        Args:
            generator: Optional ``torch.Generator`` for the visiting order; when
                ``None`` the global torch RNG is used.
        """
        order = torch.randperm(len(self.observed_idx), generator=generator)
        # Convert to Python ints once; per-sample 0-d tensor indexing is slow.
        pairs = self.observed_idx[order.to(self.observed_idx.device)].tolist()
        for i, j in pairs:
            # O(k) single-entry prediction instead of materialising the full matrix.
            e = self.X[i, j] - self._predict_entry(i, j)

            self.b_u[i] += self.alpha * (2 * e - self.beta * self.b_u[i])
            self.b_v[j] += self.alpha * (2 * e - self.beta * self.b_v[j])

            # Both gradients are computed from the pre-update U[i] and V[j].
            grad_U = 2 * e * self.V[j] - self.beta * self.U[i]
            grad_V = 2 * e * self.U[i] - self.beta * self.V[j]

            self.U[i] += self.alpha * grad_U
            self.V[j] += self.alpha * grad_V

    def compute_loss(self):
        """Return the mean squared error over the observed entries as a Python float."""
        pred = self.predict()
        return torch.mean((self.X[self.mask] - pred[self.mask]) ** 2).item()

    def full_matrix(self):
        """Return the dense prediction matrix as a NumPy array on the CPU."""
        return self.predict().detach().cpu().numpy()

    def replace_nan(self, X_hat_np, X_np):
        """Return a copy of ``X_np`` whose NaN entries are filled from ``X_hat_np``."""
        X_complete = np.copy(X_np)
        X_complete[np.isnan(X_complete)] = X_hat_np[np.isnan(X_complete)]
        return X_complete


In [None]:
class MF_NQM_Torch:
    """Biased matrix factorisation trained with noisy, clipped per-entry SGD.

    The model predicts ``X[i, j] ~ b + b_u[i] + b_v[j] + U[i] . V[j]``; NaN entries
    of ``X_np`` are treated as missing. For every observed entry the factor
    gradient ``g`` (including the L2 term) is clipped element-wise to
    ``[-max_grad, max_grad]``. Gaussian noise with variance ``noise_var`` is then
    added directly to each updated factor row. Supported factor update rules:

    * ``'vanilla'``:  ``U[i] += alpha * g``
    * ``'momentum'``: ``m = momentum * m + alpha * g``;  ``U[i] += m``
    * ``'ema'``:      ``a = ema_beta * a + (1 - ema_beta) * g``;  ``U[i] += alpha * a``

    Biases are always updated with plain (unclipped, noise-free) SGD.

    Args:
        X_np: 2-D array-like; NaN marks a missing entry.
        k: Latent dimension of ``U`` and ``V``.
        alpha: Learning rate.
        beta: L2 regularisation coefficient.
        iterations: Number of epochs.
        noise_var: Variance of the Gaussian noise injected into each factor update.
        momentum: Momentum coefficient for ``'momentum'``.
        ema_beta: Decay of the gradient moving average for ``'ema'``.
        update_strategy: One of ``'vanilla'``, ``'momentum'``, ``'ema'``.
        device: Torch device; defaults to CUDA when available, otherwise CPU.
        max_grad: Element-wise clipping bound for the factor gradients.

    Raises:
        ValueError: If ``update_strategy`` is not a supported name.
    """

    _VALID_STRATEGIES = ("vanilla", "momentum", "ema")

    def __init__(self, X_np, k, alpha=1e-4, beta=0.01, iterations=100,
                 noise_var=1e-6, momentum=0.8, ema_beta=0.9, update_strategy='vanilla', device=None,
                 max_grad=1.0):
        # Fail fast: an unknown strategy would otherwise silently skip all factor updates.
        if update_strategy not in self._VALID_STRATEGIES:
            raise ValueError(
                f"update_strategy must be one of {self._VALID_STRATEGIES}, "
                f"got {update_strategy!r}"
            )
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.X = torch.tensor(X_np, dtype=torch.float32, device=self.device)
        self.k = k
        self.alpha = alpha
        self.beta = beta
        self.iterations = iterations
        self.noise_var = noise_var
        self.momentum = momentum
        self.ema_beta = ema_beta
        self.update_strategy = update_strategy
        self.max_grad = max_grad

        self.num_users, self.num_items = self.X.shape
        self.mask = ~torch.isnan(self.X)

        # Small random factors, zero biases, global bias = mean of the observed entries.
        self.U = torch.randn(self.num_users, k, device=self.device) * 0.01
        self.V = torch.randn(self.num_items, k, device=self.device) * 0.01
        self.b_u = torch.zeros(self.num_users, device=self.device)
        self.b_v = torch.zeros(self.num_items, device=self.device)
        self.b = torch.nanmean(self.X)

        # Per-row optimizer state; a row is only touched when one of its entries is sampled.
        self.U_momentum = torch.zeros_like(self.U)
        self.V_momentum = torch.zeros_like(self.V)
        self.U_ema = torch.zeros_like(self.U)
        self.V_ema = torch.zeros_like(self.V)

        # (n_observed, 2) tensor of (row, col) indices of the non-NaN entries.
        self.observed_idx = torch.nonzero(self.mask, as_tuple=False)

    def predict(self):
        """Return the dense (num_users x num_items) prediction matrix."""
        return self.b + self.b_u[:, None] + self.b_v[None, :] + torch.matmul(self.U, self.V.T)

    def _predict_entry(self, i, j):
        """Return the prediction for the single entry (i, j) in O(k)."""
        return self.b + self.b_u[i] + self.b_v[j] + torch.dot(self.U[i], self.V[j])

    def train(self):
        """Run ``self.iterations`` epochs and return ``[(epoch, mse), ...]``."""
        history = []
        for it in range(self.iterations):
            self._sgd_step()
            error = self.compute_loss()
            history.append((it, error))
            if (it + 1) % 10 == 0:
                print(f"Iteration {it+1}/{self.iterations}, MSE: {error:.4f}")
        return history

    def _sgd_step(self):
        """Perform one noisy SGD epoch over all observed entries in random order."""
        noise_std = self.noise_var ** 0.5
        order = torch.randperm(len(self.observed_idx))
        # Convert to Python ints once; per-sample 0-d tensor indexing is slow.
        pairs = self.observed_idx[order.to(self.observed_idx.device)].tolist()
        for i, j in pairs:
            # O(k) single-entry prediction instead of materialising the full matrix.
            e = self.X[i, j] - self._predict_entry(i, j)

            # Gradients use the pre-update U[i] / V[j] and are clipped element-wise.
            grad_U = 2 * e * self.V[j] - self.beta * self.U[i]
            grad_V = 2 * e * self.U[i] - self.beta * self.V[j]
            grad_U = torch.clamp(grad_U, -self.max_grad, self.max_grad)
            grad_V = torch.clamp(grad_V, -self.max_grad, self.max_grad)

            self.b_u[i] += self.alpha * (2 * e - self.beta * self.b_u[i])
            self.b_v[j] += self.alpha * (2 * e - self.beta * self.b_v[j])

            noise_u = torch.randn_like(grad_U) * noise_std
            noise_v = torch.randn_like(grad_V) * noise_std

            if self.update_strategy == 'vanilla':
                step_U = self.alpha * grad_U
                step_V = self.alpha * grad_V
            elif self.update_strategy == 'momentum':
                self.U_momentum[i] = self.momentum * self.U_momentum[i] + self.alpha * grad_U
                self.V_momentum[j] = self.momentum * self.V_momentum[j] + self.alpha * grad_V
                step_U = self.U_momentum[i]
                step_V = self.V_momentum[j]
            elif self.update_strategy == 'ema':
                self.U_ema[i] = self.ema_beta * self.U_ema[i] + (1 - self.ema_beta) * grad_U
                self.V_ema[j] = self.ema_beta * self.V_ema[j] + (1 - self.ema_beta) * grad_V
                # The EMA averages raw gradients, so it still needs the learning rate
                # (the other two strategies already fold alpha into their step).
                step_U = self.alpha * self.U_ema[i]
                step_V = self.alpha * self.V_ema[j]
            else:
                raise ValueError(f"Unknown update_strategy {self.update_strategy!r}")

            self.U[i] += step_U + noise_u
            self.V[j] += step_V + noise_v

    def compute_loss(self):
        """Return the mean squared error over the observed entries as a Python float."""
        pred = self.predict()
        return F.mse_loss(pred[self.mask], self.X[self.mask]).item()

    def full_matrix(self):
        """Return the dense prediction matrix as a NumPy array on the CPU."""
        return self.predict().detach().cpu().numpy()

    def replace_nan(self, X_hat_np, X_np):
        """Return a copy of ``X_np`` whose NaN entries are filled from ``X_hat_np``."""
        X_complete = np.copy(X_np)
        X_complete[np.isnan(X_complete)] = X_hat_np[np.isnan(X_complete)]
        return X_complete


In [None]:
def generate_low_rank_matrix(n1, n2, r, noise_std=0.01, observed_ratio=0.2):
    """Sample a noisy, partially observed low-rank matrix using NumPy's global RNG.

    ``M`` is the product of an (n1, r) and an (r, n2) standard-normal matrix. Its
    rank is at most ``r``. The number of observed entries is
    ``min(6 * r * (n1 + n2 - r), round(observed_ratio * n1 * n2))``. Observed
    positions are drawn without replacement and receive Gaussian noise with
    standard deviation ``noise_std``.

    Returns:
        ``(X, M)``: ``X`` is an (n1, n2) array holding the noisy observed entries
        and NaN elsewhere; ``M`` is the noise-free low-rank matrix.
    """
    M = np.random.randn(n1, r) @ np.random.randn(r, n2)
    # Degrees of freedom of a rank-r n1 x n2 matrix; sample at most 6x that many entries.
    df = r * (n1 + n2 - r)
    m = min(6 * df, round(observed_ratio * n1 * n2))
    # Fraction of entries that are actually observed.
    p = m / (n1 * n2)
    Omega = np.random.choice(n1 * n2, m, replace=False)
    data = M.flatten()[Omega]
    data += noise_std * np.random.randn(*data.shape)

    X = np.full((n1, n2), np.nan)
    X.flat[Omega] = data
    print(f"Generated matrix: {n1}x{n2}, rank {r}, observed {100*p:.1f}%")
    return X, M

def evaluate(X_original, X_hat):
    """Score ``X_hat`` against the observed (non-NaN) entries of ``X_original``.

    Positions where ``X_original`` is NaN are ignored. This therefore measures
    the fit to the observed data, not the error on held-out entries.

    Returns:
        Tuple ``(mse, rmse, mae)``, each rounded to 4 decimal places.
    """
    keep = np.logical_not(np.isnan(X_original))
    truth, estimate = X_original[keep], X_hat[keep]

    squared_error = mean_squared_error(truth, estimate)
    metrics = (
        squared_error,
        np.sqrt(squared_error),
        mean_absolute_error(truth, estimate),
    )
    return tuple(round(value, 4) for value in metrics)

In [None]:
if __name__ == "__main__":
    np.random.seed(42)
    torch.manual_seed(42)

    n1, n2, r = 200, 200, 5
    X, M_true = generate_low_rank_matrix(n1, n2, r)
    # Normalise by the largest observed value; predictions are mapped back with `scale`.
    scale = np.nanmax(X)
    X_scaled = X / scale

    # --------------------- Basic MF Torch ---------------------
    mf = MF_Torch(X_scaled, k=5, alpha=1e-3, beta=0.05, iterations=100)
    sgd_hist = mf.train()
    sgd_hat = mf.full_matrix() * scale
    sgd_comp = mf.replace_nan(sgd_hat, X)
    # Keep the SGD scores under their own names: the strategy loop below reassigns
    # mse/rmse/mae. evaluate() scores only the observed entries (training fit).
    sgd_mse, sgd_rmse, sgd_mae = evaluate(X, sgd_hat)

    print("\n[Vanilla SGD]")
    print(f"MSE: {sgd_mse}, RMSE: {sgd_rmse}, MAE: {sgd_mae}")

    # --------------------- NQM Torch (all strategies) ---------------------
    strategies = ["vanilla", "momentum", "ema"]
    all_histories = {}
    all_metrics = []

    for strat in strategies:
        model = MF_NQM_Torch(X_scaled, k=5, alpha=5e-4, beta=0.05,
                             iterations=100, update_strategy=strat)
        hist = model.train()
        X_hat = model.full_matrix() * scale
        X_comp = model.replace_nan(X_hat, X)
        mse, rmse, mae = evaluate(X, X_hat)

        print(f"\n[{strat.upper()}]")
        print(f"MSE: {mse}, RMSE: {rmse}, MAE: {mae}")

        all_histories[strat] = hist
        all_metrics.append({"Strategy": strat, "MSE": mse, "RMSE": rmse, "MAE": mae})

    # --------------------- Plot Training Curves ---------------------
    plt.figure(figsize=(10, 6))
    sgd_iters, sgd_errors = zip(*sgd_hist)
    plt.plot(sgd_iters, sgd_errors, label="SGD (Basic)")

    for strat, hist in all_histories.items():
        iters, errors = zip(*hist)
        plt.plot(iters, errors, label=strat.capitalize())

    plt.title("Training Error on Large-scale Matrix Completion")
    plt.xlabel("Iteration")
    plt.ylabel("MSE")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()

    # --------------------- Plot Final Errors ---------------------
    df_metrics = pd.DataFrame(all_metrics)
    df_metrics = pd.concat([
        pd.DataFrame([{"Strategy": "sgd", "MSE": sgd_mse, "RMSE": sgd_rmse, "MAE": sgd_mae}]),
        df_metrics
    ], ignore_index=True)

    df_metrics.plot(x='Strategy', y=['MSE', 'RMSE', 'MAE'], kind='bar')
    plt.title("Final Error Comparison")
    plt.ylabel("Error Value")
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    print("\nFinal Metrics:")
    print(df_metrics.round(4).to_string(index=False))
