In [None]:
!pip install omegaconf pandas scikit-learn torch fvcore

In [None]:
import random
import time
import os

import numpy as np
import omegaconf
import pandas as pd
import torch
import torch.nn as nn
import torch.optim
from sklearn.metrics import roc_auc_score
from torch.utils.data.dataloader import DataLoader
from tqdm.notebook import tqdm  # Use notebook-friendly tqdm

# NOTE: this import only resolves if the repository root that contains the `minigpt4` package is already on sys.path (e.g. added with sys.path.append in an earlier cell).
from minigpt4.models.rec_model import MatrixFactorization

def uAUC_me(user, predict, label):
    """Compute the user-averaged AUC (uAUC) of a set of predictions.

    Samples are grouped by user id, ROC AUC is computed separately for each
    user, and the per-user scores are averaged with equal weight.  Users with
    a single interaction, or whose labels contain only one class, have no
    defined AUC and are left out of the average.

    Args:
        user: Sequence of user ids, one per sample.
        predict: Predicted scores, one per sample (list or array; singleton
            dimensions are squeezed away).
        label: Ground-truth labels aligned with ``predict``.

    Returns:
        A tuple ``(uauc, computed_u, auc_for_user)``: ``uauc`` is the mean
        per-user AUC (``nan`` when no user could be scored), ``computed_u`` is
        the list of user ids that were scored, and ``auc_for_user`` is an
        array holding the AUC of each of those users, in the same order.
    """
    predict = np.asarray(predict).squeeze()
    label = np.asarray(label).squeeze()

    # np.unique returns the ids sorted ascending; `inverse` maps every sample
    # to the position of its user in that sorted list.
    unique_users, inverse, counts = np.unique(user, return_inverse=True, return_counts=True)
    # Sorting by group index lays the samples of each user out contiguously.
    order = np.argsort(np.ravel(inverse), kind="stable")
    ends = np.cumsum(counts)
    starts = ends - counts

    computed_u = []
    auc = []
    for u_i, start, end in zip(unique_users, starts, ends):
        if end - start < 2:
            # A single interaction cannot be ranked against anything.
            continue
        rows = order[start:end]
        label_i = label[rows]
        if np.unique(label_i).size < 2:
            # AUC is undefined with one class.  Check explicitly instead of
            # relying on roc_auc_score raising, so that a library version
            # which returns nan here cannot poison the mean.
            continue
        try:
            ui_auc = roc_auc_score(label_i, predict[rows])
        except ValueError:
            # Best effort: skip users whose data sklearn rejects.
            continue
        auc.append(ui_auc)
        computed_u.append(u_i)

    auc_for_user = np.array(auc)
    # Avoid numpy's "mean of empty slice" warning when nobody could be scored.
    uauc = auc_for_user.mean() if auc_for_user.size else float("nan")
    return uauc, computed_u, auc_for_user

class early_stoper(object):
    """Track the best value of a reference metric and signal when to stop.

    Each call to ``update`` receives a dict of metrics.  The entry named by
    ``ref_metric`` is compared with the best one seen so far; a strict
    improvement replaces the stored best dict and resets the stall counter,
    anything else increments the counter.  ``is_stop`` reports whether the
    counter has reached ``patience``.

    Args:
        ref_metric: Key of the metric used for comparison.
        incerase: When truthy a larger value is better, otherwise a smaller
            one is.  (The spelling is part of the existing interface.)
        patience: Number of non-improving updates tolerated before stopping.
    """

    def __init__(self, ref_metric='valid_auc', incerase=True, patience=20) -> None:
        self.ref_metric = ref_metric
        self.increase = incerase
        self.patience = patience
        # Best metrics dict seen so far; None until the first update.
        self.best_metric = None
        # Consecutive updates that failed to improve on the best.
        self.reach_count = 0

    def _registry(self, metrics):
        """Store ``metrics`` as the current best."""
        self.best_metric = metrics

    def update(self, metrics):
        """Record ``metrics``; return True if they became the new best."""
        if self.best_metric is None:
            # The very first observation is the best by definition.
            self._registry(metrics)
            return True

        candidate = metrics[self.ref_metric]
        incumbent = self.best_metric[self.ref_metric]
        if self.increase:
            improved = candidate > incumbent
        else:
            improved = candidate < incumbent

        if not improved:
            self.reach_count += 1
            return False

        self.best_metric = metrics
        self.reach_count = 0
        return True

    def is_stop(self):
        """Return True once the stall counter has reached ``patience``."""
        return bool(self.reach_count >= self.patience)

def _evaluate_split(model, data_loader, device, desc):
    """Score every batch of ``data_loader`` with ``model`` and return ``(auc, uauc)``."""
    users, scores, labels = [], [], []
    model.eval()
    with torch.no_grad():
        for batch_data in tqdm(data_loader, desc=desc, leave=False):
            batch_data = batch_data.to(device)
            ui_matching = model(batch_data[:, 0].long(), batch_data[:, 1].long())
            # reshape(-1) keeps every piece 1-D so the pieces can be concatenated.
            users.append(batch_data[:, 0].cpu().numpy().reshape(-1))
            scores.append(ui_matching.cpu().numpy().reshape(-1))
            labels.append(batch_data[:, -1].cpu().numpy().reshape(-1))
    users = np.concatenate(users)
    scores = np.concatenate(scores)
    labels = np.concatenate(labels)
    auc = roc_auc_score(labels, scores)
    uauc, _, _ = uAUC_me(users, scores, labels)
    return auc, uauc


def run_a_trail(train_config, save_mode=False, save_file=None, need_train=True, warm_or_cold=None,
                data_dir="/content/CoRA/dataset/ml-1m/"):
    """Train (or only evaluate) a matrix-factorization model on the ``*_ood2`` splits.

    The run is seeded with a fixed seed.  When training, the model is
    evaluated on the validation and test splits every ``eval_epoch`` epochs;
    validation AUC drives early stopping and, if ``save_mode`` is set, the
    weights are written to ``save_file`` each time validation AUC improves.

    Args:
        train_config: Mapping with the keys ``lr``, ``wd``, ``embedding_size``,
            ``epoch``, ``eval_epoch``, ``patience`` and ``batch_size``.
        save_mode: Whether to save the weights whenever validation AUC improves.
        save_file: Checkpoint path; written when ``save_mode`` is true and
            read when ``need_train`` is false.
        need_train: When false, load ``save_file`` and only report metrics.
        warm_or_cold: Accepted for interface compatibility; currently unused.
        data_dir: Directory holding ``train_ood2.pkl``, ``valid_ood2.pkl`` and
            ``test_ood2.pkl``.

    Returns:
        None.  Progress and results are printed.

    Raises:
        ValueError: If a checkpoint path is required but ``save_file`` is None.
    """
    # Fail fast instead of erroring inside torch.save/torch.load after the data is loaded.
    if save_file is None and (save_mode or not need_train):
        raise ValueError("save_file must be provided when save_mode=True or need_train=False")

    seed = 2023
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    if device.type == "cuda":
        torch.cuda.manual_seed_all(seed)

    columns = ['uid', 'iid', 'label']
    # NOTE(review): read_pickle can execute arbitrary code while unpickling;
    # only point data_dir at trusted files.
    train_data = pd.read_pickle(os.path.join(data_dir, "train_ood2.pkl"))[columns].values
    valid_data = pd.read_pickle(os.path.join(data_dir, "valid_ood2.pkl"))[columns].values
    test_data = pd.read_pickle(os.path.join(data_dir, "test_ood2.pkl"))[columns].values

    # Size the embedding tables from the largest id seen in any split.
    user_num = max(train_data[:, 0].max(), valid_data[:, 0].max(), test_data[:, 0].max()) + 1
    item_num = max(train_data[:, 1].max(), valid_data[:, 1].max(), test_data[:, 1].max()) + 1

    print(f"user nums: {user_num}, item nums: {item_num}")

    mf_config = {
        "user_num": int(user_num),
        "item_num": int(item_num),
        "embedding_size": int(train_config['embedding_size'])
    }
    mf_config = omegaconf.OmegaConf.create(mf_config)

    train_data_loader = DataLoader(train_data, batch_size=train_config['batch_size'], shuffle=True)
    valid_data_loader = DataLoader(valid_data, batch_size=train_config['batch_size'], shuffle=False)
    test_data_loader = DataLoader(test_data, batch_size=train_config['batch_size'], shuffle=False)

    model = MatrixFactorization(mf_config).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=train_config['lr'], weight_decay=train_config['wd'])
    early_stop = early_stoper(ref_metric='valid_auc', incerase=True, patience=train_config['patience'])
    criterion = nn.BCEWithLogitsLoss()

    if not need_train:
        # Evaluation-only mode: restore the checkpoint and report its metrics.
        model.load_state_dict(torch.load(save_file, map_location=device))
        valid_auc, valid_uauc = _evaluate_split(model, valid_data_loader, device, "Validation")
        test_auc, test_uauc = _evaluate_split(model, test_data_loader, device, "Testing")
        print(f"Valid AUC: {valid_auc:.4f}, Valid uAUC: {valid_uauc:.4f}, Test AUC: {test_auc:.4f}, Test uAUC: {test_uauc:.4f}")
        return

    for epoch in tqdm(range(train_config['epoch']), desc="Epochs"):
        model.train()
        for batch_data in tqdm(train_data_loader, desc=f"Epoch {epoch+1} Training", leave=False):
            batch_data = batch_data.to(device)
            ui_matching = model(batch_data[:, 0].long(), batch_data[:, 1].long())
            loss = criterion(ui_matching, batch_data[:, -1].float())
            opt.zero_grad()
            loss.backward()
            opt.step()

        if epoch % train_config['eval_epoch'] != 0:
            continue

        valid_auc, valid_uauc = _evaluate_split(
            model, valid_data_loader, device, f"Epoch {epoch+1} Validation")
        test_auc, test_uauc = _evaluate_split(
            model, test_data_loader, device, f"Epoch {epoch+1} Testing")

        # Test metrics are recorded with the best validation result; only
        # valid_auc decides whether this epoch is an improvement.
        updated = early_stop.update(
            {'valid_auc': valid_auc, 'valid_uauc': valid_uauc, 'test_auc': test_auc, 'test_uauc': test_uauc, 'epoch': epoch}
        )
        if updated and save_mode:
            torch.save(model.state_dict(), save_file)

        print(f"Epoch: {epoch+1}, Valid AUC: {valid_auc:.4f}, Valid uAUC: {valid_uauc:.4f}, Test AUC: {test_auc:.4f}, Test uAUC: {test_uauc:.4f}, Early Stop: {early_stop.reach_count}")

        if early_stop.is_stop():
            print("Early stopping reached!")
            break

    print("Training finished.")
    print(f"Best result: {early_stop.best_metric}")

# --- Execution Block ---
# Hyper-parameters consumed by run_a_trail.
train_config = {
    'lr': 1e-3,
    'wd': 1e-4,
    'embedding_size': 64,
    "epoch": 100,
    "eval_epoch": 1,
    "patience": 10,
    "batch_size": 1024
}

save_dir = "/content/CoRA/pretrained/mf/"
# exist_ok avoids the check-then-create race and is a no-op when the directory exists.
os.makedirs(save_dir, exist_ok=True)

save_file = os.path.join(save_dir, "mf_movielens_best.pth")

print("Starting training with config:", train_config)
print(f"Model will be saved to: {save_file}")

# Train from scratch and checkpoint the weights with the best validation AUC.
run_a_trail(
    train_config=train_config,
    save_mode=True,
    save_file=save_file,
    need_train=True
)
