This code is heavily based on EduCDM:
@misc{bigdata2021educdm, title={EduCDM}, author={bigdata-ustc}, publisher = {GitHub}, journal = {GitHub repository}, year = {2021}, howpublished = {\url{https://github.com/bigdata-ustc/EduCDM}}, }

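Specifically, it follows the presentation of the multidimensional item response theory (MIRT) model in Reckase, Mark D. "18 Multidimensional Item Response Theory." _Handbook of Statistics_ 26 (2006): 607-642. The model below additionally learns a per-item guessing parameter $c \in (0, 1)$, so the predicted probability of a correct response is $P = c + \frac{1 - c}{1 + \exp(-\mathbf{a} \cdot \boldsymbol{\theta} + b)}$.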

# Package Links
- EduData: https://pypi.org/project/EduData/
- EduCDM: https://pypi.org/project/EduCDM/

In [1]:
!pip --quiet install EduData
!pip --quiet install EduCDM

In [2]:
# Fetch the "cdbd-a0910" dataset into ../data using EduData's download helper.
# Files that already exist locally are skipped (see the log output below);
# the CSVs read in the next section live under ../data/a0910/.
from EduData import get_data
get_data("cdbd-a0910", "../data")

downloader, INFO http://base.ustc.edu.cn/data/cdbd/a0910/item.csv is saved as ..\data\a0910\item.csv
downloader, INFO file existed, skipped


'..\\data'

# Data Wrangling

In [3]:
# Read the pre-split train / valid / test interaction logs into DataFrames.
import pandas as pd

train_data, valid_data, test_data = (
    pd.read_csv(csv_path)
    for csv_path in (
        "../data/a0910/train.csv",
        "../data/a0910/valid.csv",
        "../data/a0910/test.csv",
    )
)

train_data.head(5)

Unnamed: 0,user_id,item_id,score
0,1615,12977,1
1,782,13124,0
2,1084,16475,0
3,593,8690,0
4,127,14225,1


In [4]:
tuple(len(split) for split in (train_data, valid_data, test_data))

(186049, 25606, 55760)

In [5]:
# Transform data to torch DataLoaders (i.e., batchify).
# batch_size is set to 256 and is shared by the train / valid / test loaders.

import torch
from torch.utils.data import TensorDataset, DataLoader

batch_size = 256
def transform(x, y, z, batch_size, **params):
    """Bundle three aligned columns into a batched ``DataLoader``.

    ``x`` and ``y`` (used here for user ids and item ids) become int64
    tensors; ``z`` (the response score) becomes a float tensor. Any extra
    keyword arguments are forwarded unchanged to ``DataLoader``.
    """
    column_specs = (
        (x, torch.int64),
        (y, torch.int64),
        (z, torch.float),
    )
    tensors = [torch.tensor(values, dtype=dtype) for values, dtype in column_specs]
    return DataLoader(TensorDataset(*tensors), batch_size=batch_size, **params)

# One DataLoader per split, all with the shared batch_size.
train, valid, test = (
    transform(split["user_id"], split["item_id"], split["score"], batch_size)
    for split in (train_data, valid_data, test_data)
)
train, valid, test

(<torch.utils.data.dataloader.DataLoader at 0x2287caea520>,
 <torch.utils.data.dataloader.DataLoader at 0x2287caeaca0>,
 <torch.utils.data.dataloader.DataLoader at 0x2287caf64c0>)

# Building the Model

In [6]:
import logging

# Lower the root logger's threshold so INFO records (e.g. the model's
# save/load messages) are emitted.
logging.root.setLevel(logging.INFO)

In [11]:
import logging
import numpy as np
import torch
from EduCDM import CDM
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, accuracy_score, mean_squared_error

In [34]:
# Minimal local copy of EduCDM's abstract base class; the MIRT model itself is
# defined later in this cell instead of being imported from EduCDM.
# NOTE: this definition shadows the `CDM` imported from EduCDM in the imports cell.
class CDM:
    """Abstract interface for cognitive diagnosis models.

    Subclasses override ``train``, ``eval``, ``save`` and ``load``; on this
    base class each of them raises ``NotImplementedError``.
    """

    def __init__(self, *args, **kwargs) -> ...:
        """Accept and ignore any constructor arguments."""

    def train(self, *args, **kwargs) -> ...:
        """Fit the model (must be overridden)."""
        raise NotImplementedError

    def eval(self, *args, **kwargs) -> ...:
        """Evaluate the model (must be overridden)."""
        raise NotImplementedError

    def save(self, *args, **kwargs) -> ...:
        """Persist model parameters (must be overridden)."""
        raise NotImplementedError

    def load(self, *args, **kwargs) -> ...:
        """Restore model parameters (must be overridden)."""
        raise NotImplementedError
        
def irt2pl(theta, a, b, c=0, *, F=np):
    """
    Multidimensional logistic item response function with an optional
    lower asymptote (guessing parameter).

    Computes ``c + (1 - c) / (1 + exp(-sum(a * theta, axis=-1) + b))``.
    With ``c = 0`` (the default) this is the multidimensional two-parameter
    logistic (2PL) model; a non-zero ``c`` gives the 3PL-style variant.

    Parameters
    ----------
    theta
        Latent ability vector(s); the last axis is the latent dimension.
    a
        Discrimination vector(s), multiplied elementwise with ``theta``.
    b
        Intercept / difficulty term(s), one per row after summing over the
        last axis.
    c : default 0
        Guessing parameter: the probability's lower asymptote.
    F
        Array backend providing ``exp``, ``sum`` and ``multiply``
        (``numpy`` by default; e.g. ``torch`` for tensors).

    Returns
    -------
    Probability of a correct response, one value per row of ``theta``.

    Examples
    --------
    >>> theta = [1, 0.5, 0.3]
    >>> a = [-3, 1, 3]
    >>> b = 0.5
    >>> irt2pl(theta, a, b) # doctest: +ELLIPSIS
    0.109...
    >>> irt2pl(theta, a, b, c=0.25) # doctest: +ELLIPSIS
    0.331...
    >>> theta = [[1, 0.5, 0.3], [2, 1, 0]]
    >>> a = [[-3, 1, 3], [-3, 1, 3]]
    >>> b = [0.5, 0.5]
    >>> irt2pl(theta, a, b) # doctest: +ELLIPSIS
    array([0.109..., 0.004...])
    """
    # 2PL core, rescaled into [c, 1] so a guess can still score a correct answer.
    return c + (1 - c) / (1 + F.exp(- F.sum(F.multiply(a, theta), axis=-1) + b))


class MIRTNet(nn.Module):
    """Embedding-based network computing MIRT response probabilities.

    Each user has a ``latent_dim``-dimensional ability ``theta``; each item has
    a discrimination vector ``a`` of the same width, a scalar intercept ``b``
    and a scalar guessing parameter ``c``. ``forward`` returns the probability
    of a correct response for each (user, item) pair via ``irt2pl``.

    Parameters
    ----------
    user_num, item_num : int
        Number of rows in the user / item embedding tables.
    latent_dim : int
        Width of the ability and discrimination vectors.
    a_range : float or None
        If given, ``a`` is squashed into ``(0, a_range)`` with a scaled
        sigmoid; otherwise ``a`` is made positive with softplus.
    irf_kwargs : dict or None
        Extra keyword arguments forwarded to ``irf`` (which currently
        ignores them).
    """

    def __init__(self, user_num, item_num, latent_dim, a_range, irf_kwargs=None):
        super(MIRTNet, self).__init__()
        self.user_num = user_num
        self.item_num = item_num
        self.irf_kwargs = irf_kwargs if irf_kwargs is not None else {}
        self.theta = nn.Embedding(self.user_num, latent_dim)
        self.a = nn.Embedding(self.item_num, latent_dim)
        self.b = nn.Embedding(self.item_num, 1)
        self.c = nn.Embedding(self.item_num, 1)
        self.a_range = a_range

    def forward(self, user, item):
        """Return P(correct) for each (user, item) pair.

        Raises
        ------
        ValueError
            If any of theta, a, b or c contains NaN.
        """
        # theta and a keep their trailing latent axis: squeezing it would
        # collapse them to 1-D when latent_dim == 1, and irt2pl would then
        # sum over the whole batch instead of per pair.
        theta = self.theta(user)
        a = self.a(item)
        if self.a_range is not None:  # bound a to (0, a_range)
            a = self.a_range * torch.sigmoid(a)
        else:  # keep a positive
            a = F.softplus(a)
        # b and c are (..., 1) lookups; drop that axis to get per-pair scalars.
        b = torch.squeeze(self.b(item), dim=-1)
        c = torch.sigmoid(torch.squeeze(self.c(item), dim=-1))  # sigmoid keeps c in (0, 1)
        # isnan(...).any() also works on empty batches, unlike torch.max(x != x).
        if any(torch.isnan(param).any() for param in (theta, a, b, c)):  # pragma: no cover
            raise ValueError('ValueError:theta,a,b,c may contain nan!  The a_range is too large.')
        return self.irf(theta, a, b, c, **self.irf_kwargs)

    @classmethod
    def irf(cls, theta, a, b, c, **kwargs):
        """Item response function evaluated with the torch backend; ``kwargs`` are unused."""
        return irt2pl(theta, a, b, c, F=torch)


class MIRT(CDM):
    """Multidimensional IRT cognitive diagnosis model trained with BCE loss.

    Parameters
    ----------
    user_num, item_num : int
        Sizes of the user / item embedding tables.
    latent_dim : int
        Dimensionality of the latent ability space.
    a_range : float or None
        Optional upper bound for discrimination values (see ``MIRTNet``).
    """

    def __init__(self, user_num, item_num, latent_dim, a_range=None):
        super(MIRT, self).__init__()
        self.irt_net = MIRTNet(user_num, item_num, latent_dim, a_range)

    def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...:
        """Fit the network with Adam and binary cross-entropy.

        ``train_data`` (and ``test_data`` if given) yield
        ``(user_id, item_id, response)`` tensor batches. After each epoch the
        mean training loss is printed and, if ``test_data`` is given, its
        auc / accuracy / rmse from ``eval``.
        """
        self.irt_net = self.irt_net.to(device)
        self.irt_net.train()
        loss_function = nn.BCELoss()
        trainer = torch.optim.Adam(self.irt_net.parameters(), lr)

        for e in range(epoch):
            losses = []
            for user_id, item_id, response in tqdm(train_data, "Epoch %s" % e):
                user_id = user_id.to(device)
                item_id = item_id.to(device)
                response = response.to(device)
                predicted_response = self.irt_net(user_id, item_id)
                loss = loss_function(predicted_response, response)

                # back propagation
                trainer.zero_grad()
                loss.backward()
                trainer.step()

                losses.append(loss.item())  # BCELoss already reduces to a scalar
            print("[Epoch %d] LogisticLoss: %.6f" % (e, float(np.mean(losses))))

            if test_data is not None:
                auc, accuracy, rmse = self.eval(test_data, device=device)
                print("[Epoch %d] auc: %.6f, accuracy: %.6f, rmse: %.6f" % (e, auc, accuracy, rmse))

    def eval(self, test_data, device="cpu") -> tuple:
        """Score ``test_data`` and return ``(auc, accuracy, rmse)``.

        Accuracy thresholds the predicted probability at 0.5. The network is
        returned to training mode afterwards, even if prediction fails.
        """
        self.irt_net = self.irt_net.to(device)
        self.irt_net.eval()
        y_pred = []
        y_true = []
        try:
            # Evaluation needs no gradients; skip building autograd graphs.
            with torch.no_grad():
                for user_id, item_id, response in tqdm(test_data, "evaluating"):
                    pred = self.irt_net(user_id.to(device), item_id.to(device))
                    y_pred.extend(pred.tolist())
                    y_true.extend(response.tolist())
        finally:
            self.irt_net.train()

        rmse = np.sqrt(mean_squared_error(y_true, y_pred))
        auc = roc_auc_score(y_true, y_pred)
        accuracy = accuracy_score(y_true, np.array(y_pred) >= 0.5)
        return auc, accuracy, rmse

    def save(self, filepath):
        """Write the network's state_dict to ``filepath``."""
        torch.save(self.irt_net.state_dict(), filepath)
        logging.info("save parameters to %s", filepath)

    def load(self, filepath, map_location="cpu"):
        """Load a state_dict from ``filepath`` into the network.

        Tensors are first materialised on ``map_location`` (CPU by default) so
        a checkpoint saved on GPU can be loaded on a CPU-only machine;
        ``load_state_dict`` then copies them into the existing parameters.
        """
        # NOTE(review): torch.load unpickles the file -- only load trusted
        # checkpoints (recent torch versions also accept weights_only=True).
        state_dict = torch.load(filepath, map_location=map_location)
        self.irt_net.load_state_dict(state_dict)
        logging.info("load parameters from %s", filepath)


In [35]:
# Seed torch so the random embedding initialisation -- and therefore the
# metrics printed below -- are reproducible on Restart & Run All.
torch.manual_seed(0)

# Embedding table sizes (users, items) and a 123-dimensional latent space.
cdm = MIRT(4164, 17747, 123)

cdm.train(train, valid, epoch=2)
cdm.save("mirt.params")

Epoch 0: 100%|██████████████████████████████████████████████████████████████████████| 727/727 [00:04<00:00, 172.98it/s]


[Epoch 0] LogisticLoss: 4.740408


evaluating: 100%|███████████████████████████████████████████████████████████████████| 101/101 [00:00<00:00, 273.57it/s]


[Epoch 0] auc: 0.497733, accuracy: 0.573420, rmse: 0.559327


Epoch 1: 100%|██████████████████████████████████████████████████████████████████████| 727/727 [00:03<00:00, 187.06it/s]


[Epoch 1] LogisticLoss: 1.925835


evaluating: 100%|███████████████████████████████████████████████████████████████████| 101/101 [00:00<00:00, 288.08it/s]
INFO:root:save parameters to mirt.params


[Epoch 1] auc: 0.497243, accuracy: 0.568968, rmse: 0.556275


In [36]:
# Reload the checkpoint written after training, then score the held-out test split.
cdm.load("mirt.params")
auc, accuracy, rmse = cdm.eval(test)
print(
    "auc: %.6f, accuracy: %.6f, rmse: %.6f"
    % (auc, accuracy, rmse)
)

INFO:root:load parameters from mirt.params
evaluating: 100%|███████████████████████████████████████████████████████████████████| 218/218 [00:00<00:00, 267.54it/s]


auc: 0.497638, accuracy: 0.567055, rmse: 0.559500
