In [1]:
import pandas as pd

df = pd.read_csv('data.csv')

In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

In [28]:
class KTDataset(Dataset):
    def __init__(self, features, questions, answers, seq_len):
        super(KTDataset, self).__init__()
        self.features = features
        self.questions = questions
        self.answers = answers
        self.max_length = seq_len

        # Flatten and split long sequences
        self.data = []
        for feat, qst, ans in zip(features, questions, answers):
            for i in range(0, len(feat), self.max_length):
                self.data.append((
                    feat[i:i+self.max_length],
                    qst[i:i+self.max_length],
                    ans[i:i+self.max_length]
                ))

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


problem_ids = [torch.tensor(u_df["problem"].values, dtype=torch.long)
            for _, u_df in df.groupby("user_id")]
interaction_ids = [torch.tensor(u_df["problem_with_answer"].values, dtype=torch.long)
             for _, u_df in df.groupby("user_id")]
answer = [torch.tensor(u_df["correct"].values, dtype=torch.long)
          for _, u_df in df.groupby("user_id")]

interaction_ids = [torch.cat((torch.zeros(1, dtype=torch.long), s))[:-1] for s in interaction_ids]


kt_dataset = KTDataset(problem_ids, interaction_ids, answer, 200)
    
def pad_collate(batch):
    (problem_ids, interaction_ids, answer) = zip(*batch)
    problem_ids = pad_sequence(problem_ids, batch_first=True, padding_value= 0)
    interaction_ids = pad_sequence(interaction_ids, batch_first=True, padding_value=0)
    answer = pad_sequence(answer, batch_first=True, padding_value=-1)
    return problem_ids, interaction_ids, answer
    
# Get the total size of the dataset after splitting sequences
total_size = len(kt_dataset)

train_ratio = 0.7
val_ratio = 0.2

# Calculate the sizes for train, validation, and test splits
train_size = int(train_ratio * total_size)
val_size = int(val_ratio * total_size)
test_size = total_size - (train_size + val_size)  # Ensure the sizes sum up to total_size

train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(kt_dataset, [train_size, val_size, test_size])

train_data_loader = DataLoader(train_dataset, batch_size = 128, shuffle = True, collate_fn=pad_collate)
val_data_loader = DataLoader(val_dataset, batch_size = 128, shuffle = False, collate_fn=pad_collate)
test_data_loader = DataLoader(test_dataset, batch_size = 128, shuffle = False, collate_fn=pad_collate)


In [39]:
len(kt_dataset)

5431

In [40]:
len(problem_ids)

4047

In [41]:
problem_ids[0]

tensor([12668, 12692, 12685, 12704, 12705, 12700, 12708,  2993,  3182,  2977,
         3173,  3168, 12032, 11732, 11712, 11715, 12242, 12231, 12213, 12668,
        12692, 12685, 12704, 12705, 12700, 12708,  3182,  3173,  3168, 12708,
         2993,  3182,  2977,  3173,  3168, 12668, 12692, 12685, 12704, 12705,
        12700, 12708,  3168,  2993,  2977])

In [42]:
problem_ids[1]

tensor([  530,   540,   538,   549,  2801,  2811,  2765,   530,   540,   538,
          549,  2801,  2811,  2765, 11354, 11371, 11323, 11350, 11372, 11163,
        11134, 11104, 11207, 11159])

In [43]:
interaction_ids[0]

tensor([    0, 12668, 30440, 12685, 12704, 12705, 12700, 12708,  2993, 20930,
         2977, 20921,  3168, 12032, 11732, 11712, 11715, 29990, 29979, 29961,
        12668, 30440, 12685, 12704, 12705, 12700, 12708, 20930, 20921,  3168,
        12708,  2993, 20930,  2977, 20921,  3168, 12668, 30440, 12685, 12704,
        12705, 12700, 12708,  3168,  2993])

In [44]:
interaction_ids[1]

tensor([    0,   530, 18288, 18286, 18297,  2801, 20559,  2765,   530, 18288,
        18286, 18297,  2801, 20559,  2765, 29102, 29119, 29071, 29098, 29120,
        28911, 28882, 28852, 28955])

In [51]:
len(train_dataset)

3801

In [52]:
len(test_dataset)

544

In [53]:
len(val_dataset)

1086

In [None]:
for batch in train_data_loader:
    problem_id, interaction_id, answer = batch
    print('problem_shape', problem_id.shape)