## Training an IRT model with py-irt

This notebook walks through training an item response theory (IRT) model with the py-irt package.

First, make sure that py-irt is installed.

In [1]:
!pip install py-irt

Collecting py-irt
  Using cached py_irt-0.4.4-py3-none-any.whl (26 kB)
Collecting scikit-learn<0.25.0,>=0.24.2
  Downloading scikit_learn-0.24.2-cp39-cp39-manylinux2010_x86_64.whl (23.8 MB)
Collecting typer<0.4.0,>=0.3.2
  Using cached typer-0.3.2-py3-none-any.whl (21 kB)
Collecting toml<0.11.0,>=0.10.2
  Using cached toml-0.10.2-py2.py3-none-any.whl (16 kB)
Collecting pydantic<2.0.0,>=1.8.2
  Downloading pydantic-1.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.4 MB)
Collecting rich<10.0.0,>=9.3.0
  Using cached rich-9.13.0-py3-none-any.whl (197 kB)
Collecting pyro-ppl<2.0.0,>=1.6.0
  Using cached pyro_ppl-1.8.1-py3-none-any.whl (718 kB)
Collecting opt-einsum>=2.3.2
  Using cached opt_einsum-3.3.0-py3-none-any.whl (65 kB)
Collecting tqdm>=4.36
  Using cached tqdm-4.64.0-py2.py3-none-any.whl (78 kB)
Collecting pyr

Next, download the sample data file, which is stored in JSON Lines format (one JSON object per line).

In [2]:
!wget https://raw.githubusercontent.com/nd-ball/py-irt/d2a27dd55a84459782a5514e752ee48d9a63626e/test_fixtures/minitest.jsonlines

--2022-06-20 14:41:44--  https://raw.githubusercontent.com/nd-ball/py-irt/d2a27dd55a84459782a5514e752ee48d9a63626e/test_fixtures/minitest.jsonlines
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.110.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 311 [text/plain]
Saving to: ‘minitest.jsonlines’


2022-06-20 14:41:46 (6.49 MB/s) - ‘minitest.jsonlines’ saved [311/311]



In [3]:
!cat minitest.jsonlines

{"subject_id": "pedro",    "responses": {"q1": 1, "q2": 0, "q3": 1, "q4": 0}}
{"subject_id": "pinguino", "responses": {"q1": 1, "q2": 1, "q3": 0, "q4": 0}}
{"subject_id": "ken",      "responses": {"q1": 1, "q2": 1, "q3": 1, "q4": 1}}
{"subject_id": "burt",     "responses": {"q1": 0, "q2": 0, "q3": 0, "q4": 0}}
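Each line describes one subject: `subject_id` names the subject and `responses` maps item IDs to binary outcomes (following the usual IRT convention, 1 for a correct response and 0 for an incorrect one). As a sanity check before training, the plain-Python sketch below (not part of py-irt; `parse_responses` and `proportion_correct` are hypothetical helpers) builds a subject-by-item matrix and computes the share of subjects answering each item correctly. `SAMPLE` copies the four records shown above.

```python
import json
from typing import Dict, Iterable, List, Tuple

# The four records from minitest.jsonlines, copied from the output above.
SAMPLE = """\
{"subject_id": "pedro",    "responses": {"q1": 1, "q2": 0, "q3": 1, "q4": 0}}
{"subject_id": "pinguino", "responses": {"q1": 1, "q2": 1, "q3": 0, "q4": 0}}
{"subject_id": "ken",      "responses": {"q1": 1, "q2": 1, "q3": 1, "q4": 1}}
{"subject_id": "burt",     "responses": {"q1": 0, "q2": 0, "q3": 0, "q4": 0}}
"""


def parse_responses(lines: Iterable[str]) -> Tuple[List[str], List[str], List[List[int]]]:
    """Parse jsonlines records into (subject_ids, item_ids, matrix).

    Each matrix row is one subject; each column is one item (sorted by ID).
    A missing response is stored as -1 in this sketch.
    """
    subjects: List[str] = []
    rows: List[Dict[str, int]] = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # ignore blank lines
        record = json.loads(line)
        subjects.append(record["subject_id"])
        rows.append(record["responses"])
    # Union of item IDs across all subjects, in sorted order.
    items = sorted({item for row in rows for item in row})
    matrix = [[row.get(item, -1) for item in items] for row in rows]
    return subjects, items, matrix


def proportion_correct(items: List[str], matrix: List[List[int]]) -> Dict[str, float]:
    """Fraction of observed (non-missing) responses equal to 1, per item."""
    result: Dict[str, float] = {}
    for j, item in enumerate(items):
        observed = [row[j] for row in matrix if row[j] != -1]
        result[item] = sum(observed) / len(observed) if observed else float("nan")
    return result


if __name__ == "__main__":
    subjects, items, matrix = parse_responses(SAMPLE.splitlines())
    print(proportion_correct(items, matrix))
    # → {'q1': 0.75, 'q2': 0.5, 'q3': 0.5, 'q4': 0.25}
```

To read the downloaded file instead, pass an open file handle, e.g. `with open("minitest.jsonlines") as f: parse_responses(f)`. Items that fewer subjects answer correctly tend to receive higher difficulty estimates once the model is fitted.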

First, let's look at the options that the py-irt **train** command accepts.
Because this dataset is so small, we'll use a smaller learning rate and train for fewer epochs than the defaults specify.

In [5]:
!py-irt train --help

Usage: py-irt train [OPTIONS] MODEL_TYPE DATA_PATH OUTPUT_DIR

Arguments:
  MODEL_TYPE  [required]
  DATA_PATH   [required]
  OUTPUT_DIR  [required]

Options:
  --epochs INTEGER
  --priors TEXT
  --dims INTEGER
  --lr FLOAT
  --lr-decay FLOAT
  --device TEXT        [default: cuda]
  --initializers TEXT
  --config-path TEXT
  --dropout FLOAT      [default: 0.5]
  --hidden INTEGER     [default: 100]
  --help               Show this message and exit.
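Because `--device` defaults to `cuda`, running on a machine without a GPU will likely require overriding it. The sketch below assembles the training command from Python using only options listed in the help text above, which can be handy when trying several settings. `build_train_command` is a hypothetical helper, and passing `cpu` as the device assumes py-irt hands the value to PyTorch unchanged.

```python
import shlex
from typing import List, Optional


def build_train_command(
    model_type: str,
    data_path: str,
    output_dir: str,
    epochs: Optional[int] = None,
    lr: Optional[float] = None,
    device: Optional[str] = None,
) -> List[str]:
    """Assemble a `py-irt train` argument list (hypothetical helper).

    Options left as None are omitted, so the CLI falls back to its own defaults.
    """
    cmd = ["py-irt", "train", model_type, data_path, output_dir]
    if epochs is not None:
        cmd += ["--epochs", str(epochs)]
    if lr is not None:
        cmd += ["--lr", str(lr)]
    if device is not None:
        # Assumption: py-irt forwards this string to PyTorch, so "cpu" is accepted.
        cmd += ["--device", device]
    return cmd


if __name__ == "__main__":
    cmd = build_train_command("1pl", "minitest.jsonlines", "test-1pl/", epochs=100, lr=0.02)
    print(shlex.join(cmd))
    # → py-irt train 1pl minitest.jsonlines test-1pl/ --epochs 100 --lr 0.02
    # To run it (requires py-irt on PATH), pass cmd to subprocess.run(cmd, check=True).
```

Passing a list rather than a single shell string avoids quoting problems if a path contains spaces.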


We then run py-irt to fit a one-parameter logistic (1PL) model and save its parameters to the output directory `test-1pl/`.
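For reference, in the standard 1PL (Rasch) model each subject has an ability θ and each item a difficulty b, and the probability of a correct response is sigmoid(θ - b) = 1 / (1 + exp(-(θ - b))). The sketch below computes that probability; it assumes py-irt's `1pl` model follows this common parameterization, in which case the `ability` and `diff` values it writes out play the roles of θ and b. `p_correct_1pl` is a hypothetical helper.

```python
import math


def p_correct_1pl(ability: float, difficulty: float) -> float:
    """P(correct) under the standard 1PL (Rasch) model: sigmoid(ability - difficulty).

    Assumption: py-irt's `1pl` model uses this parameterization; check its
    source before relying on the exact form.
    """
    logit = ability - difficulty
    # Branch on the sign so math.exp never overflows for large |logit|.
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    z = math.exp(logit)
    return z / (1.0 + z)


if __name__ == "__main__":
    print(p_correct_1pl(0.0, 0.0))   # → 0.5 (ability equals difficulty)
    print(p_correct_1pl(1.0, -1.0))  # about 0.88: ability exceeds difficulty by 2
```

Only the difference between ability and difficulty enters the formula, so shifting every ability and every difficulty by the same constant leaves the likelihood unchanged. This is one reason 1PL estimates are usually interpreted relative to one another rather than on an absolute scale.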

In [6]:
!py-irt train 1pl minitest.jsonlines test-1pl/ --lr 0.02 --epochs 100

[14:43:39] config: model_type='1pl' epochs=100 priors=None                 cli.py:66
           initializers=[] dims=None lr=0.02 lr_decay=0.9999
           dropout=0.5 hidden=100 vocab_size=None
           data_path: minitest.jsonlines                                   cli.py:68
           output directory: test-1pl/

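We can inspect both the best-fitting parameters (`best_parameters.json`) and the parameters from the final training epoch (`parameters.json`).
To pretty-print the JSON, we can use jq. Note that the `jq` package on PyPI provides Python bindings for jq; the `!jq` commands below call the standalone jq command-line tool, which may need to be installed separately (for example, through your system package manager).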

In [8]:
!pip install jq

Collecting jq
  Downloading jq-1.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (591 kB)
Installing collected packages: jq
Successfully installed jq-1.2.2


In [9]:
!jq . test-1pl/best_parameters.json

{
  "ability": [
    0.202667236328125,
    -0.95220547914505,
    0.9865020513534546,
    0.21425879001617432
  ],
  "diff": [
    0.5277286767959595,
    -0.19620391726493835,
    -0.7283022403717041,
    -0.011020521633327007
  ],
  "irt_model": "1pl",
  "item_ids": {
    "0": "q4",
    "1": "q3",
    "2": "q1",
    "3": "q2"
  },
  "subject_ids": {
    "0": "pedro",
    "1"

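If you would rather stay in Python than use jq, the standard `json` module is enough. In the output above, `ability` and `diff` are lists indexed by position, while `subject_ids` and `item_ids` map those positions (stored as string keys) back to the original IDs. The sketch below relies on that observed layout rather than a documented schema; `label_parameters`, `rank_items`, and `load_labeled` are hypothetical helpers, and treating a larger `diff` as a harder item assumes the usual 1PL sign convention.

```python
import json
from typing import Dict, List, Tuple


def label_parameters(params: dict) -> Tuple[Dict[str, float], Dict[str, float]]:
    """Attach positional `ability`/`diff` estimates to subject and item IDs.

    Assumes the layout seen in py-irt's output: `subject_ids` and `item_ids`
    map stringified positions ("0", "1", ...) to the original IDs.
    """
    abilities = {
        params["subject_ids"][str(i)]: value
        for i, value in enumerate(params["ability"])
    }
    difficulties = {
        params["item_ids"][str(i)]: value
        for i, value in enumerate(params["diff"])
    }
    return abilities, difficulties


def rank_items(difficulties: Dict[str, float]) -> List[str]:
    """Item IDs ordered from highest to lowest difficulty estimate."""
    return sorted(difficulties, key=lambda item: difficulties[item], reverse=True)


def load_labeled(path: str) -> Tuple[Dict[str, float], Dict[str, float]]:
    """Read a parameters file such as best_parameters.json and label it."""
    with open(path) as f:
        return label_parameters(json.load(f))
```

For example, `abilities, difficulties = load_labeled("test-1pl/best_parameters.json")` followed by `rank_items(difficulties)` orders the difficulties shown above as q4, q2, q3, q1, from hardest to easiest. That agrees with the raw data, where q4 was answered correctly by one of the four subjects and q1 by three.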
In [11]:
!jq . test-1pl/parameters.json

{
  "ability": [
    0.22013945877552032,
    -1.2920401096343994,
    1.3400511741638184,
    0.042264051735401154
  ],
  "diff": [
    0.8419020771980286,
    -0.23540474474430084,
    -1.0303592681884766,
    -0.04312527924776077
  ],
  "irt_model": "1pl",
  "item_ids": {
    "0": "q4",
    "1": "q3",
    "2": "q1",
    "3": "q2"
  },
  "subject_ids": {
    "0": "pedro",
    "1"