In [33]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import gymnasium as gym
from gymnasium import spaces
import numpy as np
from sb3_contrib import TRPO
import torch
from torch import nn
from torch.nn.init import xavier_uniform_
from torch.nn.init import constant_
from torch.nn.init import xavier_normal_
import math
import torch.nn.functional as F
from enum import IntEnum
import numpy as np
import pickle
from akt import AKT
import os
from load_data import DATA, PID_DATA
import csv

In [34]:
class PracticeProblemEnv(gym.Env):
    def __init__(self, params, max_step=10, rew_func="mock", device="cuda"):
        super(PracticeProblemEnv, self).__init__()
        # self.curr_step = None
        self.max_step = max_step
        self.rew_func = rew_func
        # self.curr_q = -1
        # self.curr_pred = -1
        # self.curr_pid = -1

        self.params = params
        self.device = device
        self.kt_model = self._load_model()
        self.p_q_dict = self._load_p_q_dict()

        self.action_space = spaces.Discrete(self.params.n_pid,start=1) #[1,n_pid]
        self.observation_space = spaces.Box(np.array([1,0,1]), np.array([self.params.n_question, 1, self.params.n_pid])) #[1,n_question]/[0,1]/[1,n_pid]
        # self.observation_space = spaces.Tuple([spaces.Discrete(self.params.n_question+1),spaces.Discrete(2),spaces.Discrete(self.params.n_pid+1)])


        self.reset()
    
    def _load_model(self,pretrained_path="_b24_nb1_gn-1_lr1e-05_s224_sl200_do0.05_dm256_ts1_kq1_l21e-05_178"):
        model = AKT(n_question=self.params.n_question, n_pid=self.params.n_pid, n_blocks=self.params.n_block, d_model=self.params.d_model,
                    dropout=self.params.dropout, kq_same=self.params.kq_same, model_type='akt', l2=self.params.l2).to(self.device)
        checkpoint = torch.load(os.path.join( 'model', self.params.model, self.params.save, pretrained_path))
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        return model
    

    def reset(self, seed=None):
        self.history = {'q':[],'target':[],'pid':[]} # Initialize past interactions
        self.curr_step = 0
        self.curr_q = -1
        self.curr_pred = -1
        self.curr_pid = -1
        return self.step(np.random.choice(range(self.params.n_pid)))[0], {}


    def _get_obs(self):
        return np.array([self.curr_q, self.curr_pred, self.curr_pid], dtype=int)
        
    def _rew(self):
        raise NotImplementedError
    
    def step(self, action):#action needs to be a tuple of q and pid like (1,1000)
        self.curr_step += 1
        self.curr_pid = action.item()
        self.curr_q = self.p_q_dict[self.curr_pid]

        correct_prob = self.predict()
        self.curr_pred = np.random.rand() < correct_prob

        if self.rew_func == "mock":
            reward = self._rew()
        elif self.rew_func == "greedy":
            reward = 1 if self.curr_pred else -1
        else:
            raise NotImplementedError
            
        
        # Update history with the action and the correctness

        self.history['q'] += [self.curr_q]
        self.history['target'] += [self.curr_pred]
        self.history['pid'] += [self.curr_pid]
        
        # Recompute the state using the kt_model for each question
        # obs = np.array([self.kt_model.predict(self.history + [(i, 0)]) for i in range(self.num_questions)])
        obs = self._get_obs()

        done = self.curr_step >= self.max_step

        return obs, reward, done, done, {}
    
    def predict(self):
        q = np.array(self.history['q'][-(self.params.seqlen-1):]+[self.curr_q])
        target = np.array(self.history['target'][-(self.params.seqlen-1):]+[0])
        pid = np.array(self.history['pid'][-(self.params.seqlen-1):]+[self.curr_pid])
        qa = q+target*self.params.n_question

        padded_q = np.zeros((1, self.params.seqlen))
        padded_qa = np.zeros((1, self.params.seqlen))
        padded_target = np.full((1,self.params.seqlen),-1)
        padded_pid = np.zeros((1, self.params.seqlen))

        padded_q[0, :len(q)] = q
        padded_qa[0, :len(q)] = qa
        padded_target[0, :len(target)] = target
        padded_pid[0, :len(pid)] = pid

        # target_1 = np.floor(target)
        q = torch.tensor(padded_q).long().to(self.device)
        qa = torch.tensor(padded_qa).long().to(self.device)
        target = torch.tensor(padded_target).long().to(self.device)
        pid = torch.tensor(padded_pid).long().to(self.device)
        
        with torch.no_grad():
            loss, pred, ct = self.kt_model(q,qa,target,pid)

        nopadding_index = np.flatnonzero(padded_target.reshape((-1,)) >= -0.9).tolist()
        pred_nopadding = pred[nopadding_index]
        correct_prob = pred_nopadding[-1].item()
        return correct_prob

    def _load_p_q_dict(self):
        def iterate_over_data(file_path):
            with open(file_path, mode='r') as file:
                reader = csv.reader(file)
                rows = list(reader)

                # Iterate over the rows in groups of 4
                for i in range(0, len(rows), 4):
                    # Extract the question ids and concept ids
                    question_ids = [int(q) for q in rows[i+1] if q]
                    concept_ids = [int(c) for c in rows[i+2] if c]
                    
                    # Build the dictionary mapping question ids to concept ids
                    for question_id, concept_id in zip(question_ids, concept_ids):
                        if question_id not in question_concept_dict:
                            question_concept_dict[question_id] = concept_id

        all_files = os.listdir(self.params.data_dir)

        # Filter the list to include only CSV files
        csv_files = [file for file in all_files if file.endswith('.csv')]

        question_concept_dict = {}
        for f in csv_files:
            old_dict = question_concept_dict.copy()
            iterate_over_data(self.params.data_dir+'/'+ f)
            if old_dict == question_concept_dict:
                break
        
        return question_concept_dict

In [35]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with open('result/akt_pid/assist2009_pid/args.pkl', 'rb') as f:
    params = pickle.load(f)

env = PracticeProblemEnv(params,max_step=1000, rew_func='greedy',device=device)

# Define and train the TRPO model
model = TRPO('MlpPolicy', env, verbose=1)
model.learn(total_timesteps=1000)

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 999      |
|    ep_rew_mean     | -363     |
| time/              |          |
|    fps             | 92       |
|    iterations      | 1        |
|    time_elapsed    | 22       |
|    total_timesteps | 2048     |
---------------------------------


<sb3_contrib.trpo.trpo.TRPO at 0x182412d3700>

In [13]:
model.save("model_1000_1000")

In [9]:
# Test the trained model
obs, _ = env.reset()
obs_list = []
for _ in range(100):
    action, _states = model.predict(obs)
    obs, rewards, terminated, truncated, info = env.step(action)
    obs_list.append(obs)
    if terminated or truncated:
        obs, _ = env.reset()

In [10]:
print([i[0] for i in obs_list])
print([i[2] for i in obs_list])

[81, 70, 44, 70, 2, 84, 101, 87, 51, 21, 30, 87, 99, 52, 30, 97, 87, 87, 55, 46, 53, 87, 10, 51, 22, 34, 89, 88, 55, 30, 68, 103, 2, 70, 52, 21, 69, 92, 55, 78, 7, 39, 30, 93, 54, 48, 19, 13, 48, 10, 48, 67, 42, 30, 5, 67, 89, 78, 7, 52, 87, 87, 14, 21, 88, 42, 48, 86, 56, 25, 68, 47, 87, 107, 93, 26, 2, 107, 34, 44, 56, 46, 33, 55, 38, 93, 48, 58, 26, 81, 70, 7, 68, 82, 4, 47, 71, 70, 68, 50]
[13447, 12057, 8156, 12166, 220, 13632, 16229, 14588, 9836, 4041, 5814, 14541, 16072, 9849, 5479, 16011, 14339, 14122, 10309, 6779, 9976, 14838, 2263, 9826, 4149, 7000, 15610, 15534, 10375, 5632, 11500, 16306, 223, 11867, 9930, 4019, 11774, 15760, 10340, 13108, 1485, 7654, 5609, 15824, 10978, 9086, 3610, 2575, 8998, 2227, 8939, 11307, 7991, 5904, 977, 11253, 15609, 13099, 1783, 9905, 14250, 14517, 2470, 4042, 15452, 7870, 9143, 13960, 8406, 4423, 11385, 8808, 14967, 16526, 15804, 4627, 184, 16647, 6969, 147, 10725, 8557, 6645, 10449, 7497, 10708, 9378, 11007, 4798, 13418, 12042, 1701, 11455, 1352

In [11]:
import numpy as np
from collections import Counter

Counter([i[0] for i in obs_list])

Counter({87: 8,
         70: 5,
         30: 5,
         48: 5,
         55: 4,
         68: 4,
         2: 3,
         21: 3,
         52: 3,
         7: 3,
         93: 3,
         81: 2,
         44: 2,
         51: 2,
         46: 2,
         10: 2,
         34: 2,
         89: 2,
         88: 2,
         78: 2,
         67: 2,
         42: 2,
         56: 2,
         47: 2,
         107: 2,
         26: 2,
         84: 1,
         101: 1,
         99: 1,
         97: 1,
         53: 1,
         22: 1,
         103: 1,
         69: 1,
         92: 1,
         39: 1,
         54: 1,
         19: 1,
         13: 1,
         5: 1,
         14: 1,
         86: 1,
         25: 1,
         33: 1,
         38: 1,
         58: 1,
         82: 1,
         4: 1,
         71: 1,
         50: 1})

In [24]:
k= 25
[i for i in range(50,0,-1)][-k:]

[25,
 24,
 23,
 22,
 21,
 20,
 19,
 18,
 17,
 16,
 15,
 14,
 13,
 12,
 11,
 10,
 9,
 8,
 7,
 6,
 5,
 4,
 3,
 2,
 1]