In [51]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import gymnasium as gym
from gymnasium import spaces
import numpy as np
from sb3_contrib import TRPO
import torch
from torch import nn
from torch.nn.init import xavier_uniform_
from torch.nn.init import constant_
from torch.nn.init import xavier_normal_
import math
import torch.nn.functional as F
from enum import IntEnum
import numpy as np
import pickle
from akt import AKT
import os
from load_data import DATA, PID_DATA
import csv
import random

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [52]:
class PracticeProblemEnv(gym.Env):
    def __init__(self, params, max_step=10, rew_func="mock", units=None, device="cuda"):
        super(PracticeProblemEnv, self).__init__()
        # self.curr_step = None
        self.max_step = max_step
        self.rew_func = rew_func
        # self.curr_q = -1
        # self.curr_pred = -1
        # self.curr_pid = -1

        self.params = params
        self.device = device
        self.kt_model = self._load_model()
        self.p_q_dict, self.q_p_dict = self._load_pq_qp_dict(units)

        self.actions = [*self.p_q_dict.keys()]
        self.action_space = spaces.Discrete(len(self.actions)) #[0,n_question-1]
        self.observation_space = spaces.Box(np.array([1,0,1]), np.array([self.params.n_question, 1, self.params.n_pid])) #[1,n_question]/[0,1]/[1,n_pid]
   


        self.reset()
    
    def _load_model(self, pretrained_path="_b24_nb1_gn-1_lr1e-05_s224_sl200_do0.05_dm256_ts1_kq1_l21e-05_178"):
        model = AKT(n_question=self.params.n_question, n_pid=self.params.n_pid, n_blocks=self.params.n_block, d_model=self.params.d_model,
                    dropout=self.params.dropout, kq_same=self.params.kq_same, model_type='akt', l2=self.params.l2).to(self.device)
        checkpoint = torch.load(os.path.join( 'model', self.params.model, self.params.save, pretrained_path))
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        return model
    

    def reset(self, seed=None):
        self.history = {'q':[],'target':[],'pid':[]} # Initialize past interactions
        self.curr_step = 0
        self.curr_q = -1
        self.curr_pred = -1
        self.curr_pid = -1

        # return self.step(np.random.choice(range(self.n_problems)))[0], {}
        return self.step(self.action_space.sample())[0], {}

    def _get_obs(self):
        return np.array([self.curr_q, self.curr_pred, self.curr_pid], dtype=int)
        
    def _rew(self, n_problems_per_type=1):
        sampled_concpets = []
        sampled_problems = []
        for question_type, question_ids in self.q_p_dict.items():
            num = min(n_problems_per_type, len(question_ids))
            sampled_problems += random.sample([*question_ids], num)
            sampled_concpets += ([question_type ] * num)

        mean_performance = self.predict(sampled_concpets,sampled_problems)
        return mean_performance
    
    def switch_rew(self, new_rew_func):
        self.rew_func = new_rew_func
    
    def step(self, action):#action is an np int e.g. nparray(3) of the index of the action specified in self.actions
        self.curr_step += 1
        # self.curr_pid = self.actions[action.item]
        self.curr_pid = self.actions[action]
        self.curr_q = self.p_q_dict[self.curr_pid]

        correct_prob = self.predict([self.curr_q],[self.curr_pid])
        # correct_prob = self.predict()
        self.curr_pred = np.random.rand() < correct_prob

        if self.rew_func == "mock":
            reward = self._rew()
        elif self.rew_func == "correct":
            reward = 1 if self.curr_pred else 0
        else:
            raise NotImplementedError
            
        # Update history with the action and the correctness
        self.history['q'] += [self.curr_q]
        self.history['target'] += [self.curr_pred]
        self.history['pid'] += [self.curr_pid]
        
        # Recompute the state using the kt_model for each question
        obs = self._get_obs()

        done = self.curr_step >= self.max_step
        good = (self.rew_func == "mock") and (reward > 0.9)

        return obs, reward, done, good, {}
    
    # def predict(self):
    #     q = torch.tensor(self.history['q'][-(self.params.seqlen-1):]+[self.curr_q])
    #     target = torch.tensor(self.history['target'][-(self.params.seqlen-1):]+[0])
    #     pid = torch.tensor(self.history['pid'][-(self.params.seqlen-1):]+[self.curr_pid])
    #     qa = q+target*self.params.n_question

    #     padded_q = torch.zeros((1, self.params.seqlen))
    #     padded_qa = torch.zeros((1, self.params.seqlen))
    #     padded_target = torch.full((1,self.params.seqlen),-1)
    #     padded_pid = torch.zeros((1, self.params.seqlen))

    #     pred_index = q.shape[0]
    #     padded_q[:, :len(q)] = q
    #     padded_qa[:, :len(q)] = qa
    #     padded_target[:, :len(target)] = target
    #     padded_pid[:, :len(pid)] = pid

    #     # target_1 = np.floor(target)
    #     q = padded_q.long().to(self.device)
    #     qa = padded_qa.long().to(self.device)
    #     target = padded_target.long().to(self.device)
    #     pid = padded_pid.long().to(self.device)
        
    #     with torch.no_grad():
    #         loss, pred, ct = self.kt_model(q,qa,target,pid)

    #     nopadding_index = np.flatnonzero(padded_target.reshape((-1,)) >= -0.9).tolist()
    #     pred_nopadding = pred[nopadding_index]
    #     correct_prob = pred_nopadding[-1].item()
    #     return correct_prob
    
    def predict(self, curr_q, curr_pid):

        assert type(curr_q) == type(curr_pid) == list
        batch_size = len(curr_q)

        q = torch.cat((torch.tensor(self.history['q'][-(self.params.seqlen-1):]).tile((batch_size,1)),(torch.tensor(curr_q).unsqueeze(-1))),1)
        target = torch.tensor(self.history['target'][-(self.params.seqlen-1):]+[0]).tile((batch_size,1))
        pid = torch.cat((torch.tensor(self.history['pid'][-(self.params.seqlen-1):]).tile((batch_size,1)),(torch.tensor(curr_pid).unsqueeze(-1))),1)
        assert pid.shape == target.shape == pid.shape #(test_n_problem,3)
        qa = q+target*self.params.n_question
        
        padded_q = torch.zeros((batch_size, self.params.seqlen))
        padded_qa = torch.zeros((batch_size, self.params.seqlen))
        padded_target = torch.full((batch_size,self.params.seqlen),-1)
        padded_pid = torch.zeros((batch_size, self.params.seqlen))

        pred_index = q.shape[1]
        padded_q[:, :pred_index]= q
        padded_qa[:, :pred_index]= qa
        padded_target[:, :pred_index]= target
        padded_pid[:, :pred_index]= pid

        q = padded_q.long().to(device)
        qa = padded_qa.long().to(device)
        target = padded_target.long().to(device)
        pid = padded_pid.long().to(device)

        with torch.no_grad():
            loss, pred, ct = self.kt_model(q,qa,target,pid)

        nopadding_index = np.flatnonzero(padded_target.reshape((-1,)) >= -0.9).tolist()
        pred_nopadding = pred[nopadding_index]

        test_result = pred_nopadding[(pred_index-1)::pred_index]
        assert test_result.shape == (batch_size,)
        correct_prob = test_result.mean().item()

        return correct_prob

    def _load_pq_qp_dict(self, units=None):
        def iterate_over_data(file_path):
            with open(file_path, mode='r') as file:
                reader = csv.reader(file)
                rows = list(reader)

            for i in range(0, len(rows), 4):
                # Extract the question ids and concept ids
                question_ids = [int(q) for q in rows[i+1] if q]
                concept_ids = [int(c) for c in rows[i+2] if c]
                q_c_ids = [(int(q),int(c)) for (q,c) in zip(rows[i+1],rows[i+2]) if (q and c and ((units is None) or (int(c) in units)))]

                # Build the dictionary mapping question ids to concept ids
                for question_id, concept_id in q_c_ids:
                    if question_id not in question_concept_dict:
                        question_concept_dict[question_id] = concept_id
                    if concept_id not in concept_question_dict:
                        concept_question_dict[concept_id] = {question_id}
                    else:
                        concept_question_dict[concept_id].add(question_id)
                        concept_question_dict[concept_id].add(question_id)

        all_files = os.listdir(self.params.data_dir)

        # Filter the list to include only CSV files
        csv_files = [file for file in all_files if file.endswith('.csv')]

        question_concept_dict = {}
        concept_question_dict = {}
        for f in csv_files:
            old_dict = question_concept_dict.copy()
            iterate_over_data(self.params.data_dir+'/'+ f)
            if old_dict == question_concept_dict:
                break
        
        return question_concept_dict, concept_question_dict
    
    def check_hist(self):
        return self.history

In [53]:
with open('result/akt_pid/assist2009_pid/args.pkl', 'rb') as f:
    params = pickle.load(f)

env = PracticeProblemEnv(params,max_step=10, rew_func='mock', units = [1], device=device)

# Define and train the TRPO model
model = TRPO('MlpPolicy', env, verbose=1)


Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [54]:
model.learn(total_timesteps=100)

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 4.5      |
|    ep_rew_mean     | 2.51     |
| time/              |          |
|    fps             | 46       |
|    iterations      | 1        |
|    time_elapsed    | 43       |
|    total_timesteps | 2048     |
---------------------------------


<sb3_contrib.trpo.trpo.TRPO at 0x19abebafd60>

In [21]:
# model.save("model_1000_100_mock1")
# model.load("model_1000_100_mock1")

### Evaluation Space

In [58]:
# Test the trained model
# model.load("model_1000_100_mock1")

obs, _ = env.reset()
env.switch_rew("mock")
obs_list = []
total_rewards = []
k=1000
for _ in range(k):
    action, _states = model.predict(obs)
    obs, rewards, terminated, truncated, info = env.step(action)
    total_rewards.append(rewards)
    obs_list.append(obs)
    if terminated or truncated:
        obs, _ = env.reset()

In [56]:
print([i[0] for i in obs_list])
print([i[2] for i in obs_list])
# print(total_rewards/k)

[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

In [57]:
total_rewards

[0.2684689462184906,
 0.19569899141788483,
 0.18130241334438324,
 0.21762630343437195,
 0.2262425273656845,
 0.1859651654958725,
 0.21663598716259003,
 0.23775117099285126,
 0.43738433718681335,
 0.899749755859375,
 0.9113047122955322,
 0.9187636375427246,
 0.8796201944351196,
 0.8884856700897217,
 0.9255806803703308,
 0.9033583998680115,
 0.9003592729568481,
 0.9063846468925476,
 0.9144624471664429,
 0.9157002568244934,
 0.2179296463727951,
 0.24461352825164795,
 0.19291038811206818,
 0.18481090664863586,
 0.211447075009346,
 0.19316084682941437,
 0.2020760029554367,
 0.5253486633300781,
 0.4389297664165497,
 0.19872157275676727,
 0.21851469576358795,
 0.6145160794258118,
 0.8044691681861877,
 0.808168888092041,
 0.823475182056427,
 0.8848199248313904,
 0.8527151942253113,
 0.8900446891784668,
 0.880946159362793,
 0.8938531279563904,
 0.9023830890655518,
 0.19684861600399017,
 0.2348404973745346,
 0.6312023997306824,
 0.7603068351745605,
 0.7881522178649902,
 0.821006715297699,
 0.875

In [54]:
print(total_rewards[-1])

0.9127399921417236


In [None]:
import numpy as np
from collections import Counter

Counter([i[0] for i in obs_list])

### Experimental Space

In [271]:
with open('result/akt_pid/assist2009_pid/args.pkl', 'rb') as f:
    params = pickle.load(f)
env = PracticeProblemEnv(params,max_step=1000, rew_func='greedy',device=device)
obs, _ = env.reset()


In [272]:
obs, rewards, terminated, truncated, info = env.step(np.array(75))

In [274]:
env.q_p_dict
sampled_concpets = []
sampled_problems = []
n_problems_per_type = 10
for question_type, question_ids in env.q_p_dict.items():
    num = min(n_problems_per_type, len(question_ids))
    sampled_problems += random.sample([*question_ids], num)
    sampled_concpets += ([question_type ] * num)

In [275]:
# sampled_problems = sampled_problems[:1]
# sampled_concpets = sampled_concpets[:1]

In [276]:
test_n_problem = len(sampled_problems)
q = torch.cat((torch.tensor(env.history['q'][-(env.params.seqlen-1):]).tile((test_n_problem,1)),(torch.tensor(sampled_concpets).unsqueeze(-1))),1)
target = torch.tensor(env.history['target'][-(env.params.seqlen-1):]+[0]).tile((test_n_problem,1))
pid = torch.cat((torch.tensor(env.history['pid'][-(env.params.seqlen-1):]).tile((test_n_problem,1)),(torch.tensor(sampled_problems).unsqueeze(-1))),1)
assert pid.shape == target.shape == pid.shape #(test_n_problem,3)
qa = q+target*env.params.n_question

In [277]:
padded_q = torch.zeros((test_n_problem, env.params.seqlen)) 
padded_qa = torch.zeros((test_n_problem, env.params.seqlen))
padded_target = torch.full((test_n_problem,env.params.seqlen),-1)
padded_pid = torch.zeros((test_n_problem, env.params.seqlen))

In [278]:
pred_index = q.shape[1]
padded_q[:, :pred_index]= q
padded_qa[:, :pred_index]= qa
padded_target[:, :pred_index]= target
padded_pid[:, :pred_index]= pid

In [279]:
q = padded_q.long().to(device)
qa = padded_qa.long().to(device)
target = padded_target.long().to(device)
pid = padded_pid.long().to(device)

In [280]:
with torch.no_grad():
    loss, pred, ct = env.kt_model(q,qa,target,pid)

In [281]:
nopadding_index = np.flatnonzero(padded_target.reshape((-1,)) >= -0.9).tolist()
pred_nopadding = pred[nopadding_index]

In [282]:
test_result = pred_nopadding[(pred_index-1)::pred_index]
assert test_result.shape == (test_n_problem,)
mean_performance = test_result.mean().item()
# return mean_performance

In [283]:
test_result.shape

torch.Size([750])

In [284]:
mean_performance

0.418801873922348

In [288]:
np.array(3).shape

()