diff --git a/lib/irt_ruby/rasch_model.rb b/lib/irt_ruby/rasch_model.rb index 8d0e22e..1095af7 100644 --- a/lib/irt_ruby/rasch_model.rb +++ b/lib/irt_ruby/rasch_model.rb @@ -5,33 +5,33 @@ module IrtRuby # A class representing the Rasch model for Item Response Theory. class RaschModel - def initialize(data, max_iter: 1000, tolerance: 1e-6) + def initialize(data, max_iter: 1000, tolerance: 1e-6, learning_rate: 0.01) @data = data @abilities = Array.new(data.row_count) { rand } @difficulties = Array.new(data.column_count) { rand } @max_iter = max_iter @tolerance = tolerance + @learning_rate = learning_rate end + # Sigmoid function to calculate probability def sigmoid(x) 1.0 / (1.0 + Math.exp(-x)) end + # Calculate the log-likelihood of the data given the current parameters def likelihood likelihood = 0 @data.row_vectors.each_with_index do |row, i| row.to_a.each_with_index do |response, j| prob = sigmoid(@abilities[i] - @difficulties[j]) - if response == 1 - likelihood += Math.log(prob) - elsif response.zero? - likelihood += Math.log(1 - prob) - end + likelihood += response == 1 ? Math.log(prob) : Math.log(1 - prob) end end likelihood end + # Update parameters using gradient ascent def update_parameters last_likelihood = likelihood @max_iter.times do |_iter| @@ -39,8 +39,8 @@ def update_parameters row.to_a.each_with_index do |response, j| prob = sigmoid(@abilities[i] - @difficulties[j]) error = response - prob - @abilities[i] += 0.01 * error - @difficulties[j] -= 0.01 * error + @abilities[i] += @learning_rate * error + @difficulties[j] -= @learning_rate * error end end current_likelihood = likelihood @@ -50,6 +50,7 @@ def update_parameters end end + # Fit the model to the data def fit update_parameters { abilities: @abilities, difficulties: @difficulties } diff --git a/spec/irt_ruby/rasch_model_spec.rb b/spec/irt_ruby/rasch_model_spec.rb index bd1c199..323636d 100644 --- a/spec/irt_ruby/rasch_model_spec.rb +++ b/spec/irt_ruby/rasch_model_spec.rb @@ -4,26 +4,31 @@ RSpec.describe IrtRuby::RaschModel do let(:data) { Matrix[[1, 0, 1], [0, 1, 0], [1, 1, 1]] } - let(:irt_model) { IrtRuby::RaschModel.new(data, max_iter: 2000) } + let(:model) { IrtRuby::RaschModel.new(data, max_iter: 2000) } + + describe "#initialize" do + it "initializes with data" do + expect(model.instance_variable_get(:@data)).to eq(data) + end + end describe "#sigmoid" do - it "calculates the sigmoid of a value" do - expect(irt_model.sigmoid(0)).to be_within(0.01).of(0.5) - expect(irt_model.sigmoid(2)).to be_within(0.01).of(0.88) + it "calculates the sigmoid function" do + expect(model.sigmoid(0)).to eq(0.5) end end describe "#likelihood" do it "calculates the likelihood of the data" do - expect(irt_model.likelihood).to be_a(Float) + expect(model.likelihood).to be_a(Float) end end describe "#fit" do it "fits the model and returns abilities and difficulties" do - results = irt_model.fit - expect(results[:abilities].size).to eq(3) - expect(results[:difficulties].size).to eq(3) + result = model.fit + expect(result[:abilities].size).to eq(data.row_count) + expect(result[:difficulties].size).to eq(data.column_count) end end end