<a href="https://colab.research.google.com/github/IzhanAli08/Knowledge_Tracing/blob/main/LinSAKT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch.utils.data
import torch.nn.utils
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional

In [None]:
class SAKTDataset(torch.utils.data.Dataset):
    """Turn rows of comma-separated (question ids, correctness) into SAKT inputs.

    ``df`` is read positionally: column ``0`` holds a comma-separated string of
    question ids and column ``1`` the matching correctness flags. Surrounding
    double quotes, empty fields and ``'nan'`` tokens are ignored. Each item is a
    tuple of four LongTensors of length ``max_len``:

    * ``qa`` - interaction ids ``qid + correct * n_skill``, shifted right by one
      step with ``2 * n_skill`` prepended as a start token; padding is
      ``2 * n_skill + 1``;
    * ``q`` - question ids, padded with ``n_skill``;
    * ``correct`` - labels, padded with ``-1``;
    * ``mask`` - 1 on real steps, 0 on padding.

    Sequences longer than ``max_len`` keep only their most recent ``max_len``
    steps.
    """

    def __init__(self, df, n_skill, max_len=200):
        super(SAKTDataset, self).__init__()
        self.df = df
        self.n_skill = n_skill
        self.max_len = max_len

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _parse_ids(raw):
        """Parse one comma-separated cell into an int array, skipping blanks and 'nan'."""
        tokens = (token.strip('"') for token in str(raw).split(","))
        return np.array([int(t) for t in tokens if t and t != "nan"], dtype=int)

    def __getitem__(self, idx):
        qids = self._parse_ids(self.df[0][idx])
        correct = self._parse_ids(self.df[1][idx])

        # Filter and align BEFORE keeping the most recent max_len steps.
        # Truncating the raw tokens first let a trailing empty/'nan' token in
        # one column shift the question/label pairing by one position. When the
        # two columns still differ in length, the trailing extras are dropped.
        n_steps = min(len(qids), len(correct))
        qids, correct = qids[:n_steps], correct[:n_steps]
        if n_steps > self.max_len:
            qids = qids[-self.max_len:]
            correct = correct[-self.max_len:]
            n_steps = self.max_len

        qa = qids + correct * self.n_skill

        # Validate explicitly: ``assert`` statements vanish under ``python -O``.
        if not (np.all(qids >= 0) and np.all(qids < self.n_skill)):
            raise ValueError(
                f"Invalid qid index found at index {idx}. Values must be between 0 and {self.n_skill - 1}. Found: {qids}"
            )
        # SAKTModel's qa_embedding has 2*n_skill + 2 rows with padding_idx
        # 2*n_skill + 1, so real interactions must lie in [0, 2*n_skill].
        valid_qa_max = 2 * self.n_skill
        if not (np.all(qa >= 0) and np.all(qa <= valid_qa_max)):
            raise ValueError(
                f"Invalid qa index found at index {idx}. Values must be between 0 and {valid_qa_max}. Found: {qa}"
            )

        q = np.full(self.max_len, self.n_skill, dtype=int)
        qa2 = np.full(self.max_len, 2 * self.n_skill + 1, dtype=int)
        correct2 = np.full(self.max_len, -1, dtype=int)
        mask = np.zeros(self.max_len, dtype=int)
        q[:n_steps] = qids
        qa2[:n_steps] = qa
        correct2[:n_steps] = correct
        mask[:n_steps] = 1

        # qa is shifted right by one with a 2*n_skill start token, presumably so
        # the causal attention at step t only sees interactions from earlier steps.
        return (
            torch.cat(
                (torch.LongTensor([2 * self.n_skill]), torch.LongTensor(qa2[:-1]))
            ),
            torch.LongTensor(q),
            torch.LongTensor(correct2),
            torch.LongTensor(mask),
        )



def collate_fn(data, n_skill):
    """Pad a list of ``(qa, qid, correct, mask)`` samples into batch tensors.

    Padding values match the ones ``SAKTDataset`` uses and the model's embedding
    ``padding_idx``:

    * ``qa`` is padded with ``2 * n_skill + 1``. ``2 * n_skill`` is the start
      token, so padding with it would feed spurious start tokens to the model.
    * ``qid`` with ``n_skill``, ``correct`` with ``-1``, ``mask`` with ``0``.

    Args:
        data: sequence of 4-tuples of 1-D LongTensors (lengths may differ).
        n_skill: number of skills, used to derive the padding ids.

    Returns:
        ``(qa, qid, qc, mask)``, each of shape ``(batch, longest_length)``.
    """
    qa = [x[0] for x in data]
    qid = [x[1] for x in data]
    qc = [x[2] for x in data]
    mask = [x[3] for x in data]
    qa = torch.nn.utils.rnn.pad_sequence(
        qa, batch_first=True, padding_value=n_skill * 2 + 1
    )
    qid = torch.nn.utils.rnn.pad_sequence(qid, batch_first=True, padding_value=n_skill)
    qc = torch.nn.utils.rnn.pad_sequence(qc, batch_first=True, padding_value=-1)
    mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True, padding_value=0)

    return qa, qid, qc, mask

In [None]:
class FFN(nn.Module):
    """Position-wise feed-forward layer: Linear -> ReLU -> Linear -> Dropout.

    Args:
        state_size: input and output width of both linear layers.
        dropout: dropout probability applied to the output.
    """

    def __init__(self, state_size=200, dropout=0.2):
        super().__init__()
        self.state_size = state_size
        # Registration order (lr1, relu, lr2, dropout) fixes state_dict keys
        # and the order in which the linear layers draw their initial weights.
        self.lr1 = nn.Linear(state_size, state_size)
        self.relu = nn.ReLU()
        self.lr2 = nn.Linear(state_size, state_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.relu(self.lr1(x))
        return self.dropout(self.lr2(hidden))

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalLinearAttention(nn.Module):
    """Causal kernelised ("linear") attention with an ``elu(x) + 1`` feature map.

    ``query``, ``key`` and ``value`` are ``(batch, seq, embed_dim)`` and must
    share the same ``seq`` length. Position ``t`` attends to positions ``<= t``
    via prefix sums over the sequence axis.

    Args:
        embed_dim: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        exact: selects how the key/value state is accumulated.
            ``False`` (default, the original behaviour) accumulates
            ``cumsum(phi(K) * V)`` element-wise. This is a per-channel
            simplification: output channel ``d`` only mixes channel ``d`` of K
            and V, so it is not a convex combination of value vectors.
            ``True`` uses the full per-head outer-product state
            ``cumsum(phi(K) V^T)`` of standard causal linear attention. That
            costs ``O(seq * head_dim**2)`` memory per head.
    """

    def __init__(self, embed_dim, num_heads, exact=False):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.exact = exact

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def _feature_map(self, x):
        return F.elu(x) + 1  # Positive kernel, so the normaliser below is > 0

    def forward(self, query, key, value):
        B, T, _ = query.shape
        H = self.num_heads
        D = self.head_dim

        # Per-head projections: (B, T, H, D).
        Q = self._feature_map(self.q_proj(query).view(B, T, H, D))
        K = self._feature_map(self.k_proj(key).view(B, T, H, D))
        V = self.v_proj(value).view(B, T, H, D)

        # Normaliser: phi(Q_t) . sum_{s<=t} phi(K_s), per head -> (B, T, H).
        K_cum = torch.cumsum(K, dim=1)
        Z = 1 / ((Q * K_cum).sum(dim=-1) + 1e-6)

        # getattr keeps instances pickled before `exact` existed working.
        if getattr(self, "exact", False):
            KV_cum = torch.cumsum(torch.einsum("bthd,bthe->bthde", K, V), dim=1)
            context = torch.einsum("bthd,bthde->bthe", Q, KV_cum)
        else:
            KV_cum = torch.cumsum(K * V, dim=1)  # B, T, H, D
            context = Q * KV_cum
        output = context * Z.unsqueeze(-1)

        output = output.reshape(B, T, self.embed_dim)
        return self.out_proj(output)


In [None]:
class CausalLinearAttention_new(nn.Module):
    """Causal linear attention restricted to a sliding window of recent keys.

    Position ``t`` attends to positions in ``(t - window_size, t]``. Keys are
    accumulated element-wise (``phi(K) * V``), as in ``CausalLinearAttention``'s
    default path. The window is implemented with prefix-sum differences, so
    every position gets output. The previous version processed only the last
    ``window_size`` positions and zero-filled the rest; inside that tail,
    queries also ignored keys before the tail.

    Args:
        embed_dim: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        window_size: number of most recent positions (including ``t``) each
            query attends to; must be >= 1.
    """

    def __init__(self, embed_dim, num_heads, window_size=32):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        if window_size < 1:
            raise ValueError("window_size must be >= 1")
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.window_size = window_size
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def _feature_map(self, x):
        # relu(x) + 1 keeps every feature >= 1, so the normaliser is positive.
        # NOTE(review): an earlier comment here read "Revert to original ELU",
        # while the code used ReLU; CausalLinearAttention uses elu(x) + 1.
        # Confirm which kernel is intended.
        return F.relu(x) + 1

    def forward(self, query, key, value):
        B, T, _ = query.shape
        H, D = self.num_heads, self.head_dim
        Q = self._feature_map(self.q_proj(query).view(B, T, H, D))
        K = self._feature_map(self.k_proj(key).view(B, T, H, D))
        V = self.v_proj(value).view(B, T, H, D)

        K_cum = torch.cumsum(K, dim=1)
        KV_cum = torch.cumsum(K * V, dim=1)
        # getattr keeps instances pickled before `window_size` existed working.
        window = getattr(self, "window_size", 32)
        if window < T:
            # Subtract the prefix sum that ended `window` steps earlier, so
            # position t only sums keys in (t - window, t]. F.pad's tuple runs
            # from the last dim: (D: 0, 0), (H: 0, 0), (T: window before, 0 after).
            K_cum = K_cum - F.pad(K_cum[:, :-window], (0, 0, 0, 0, window, 0))
            KV_cum = KV_cum - F.pad(KV_cum[:, :-window], (0, 0, 0, 0, window, 0))

        output = Q * KV_cum / ((Q * K_cum).sum(dim=-1, keepdim=True) + 1e-6)
        return self.out_proj(output.reshape(B, T, self.embed_dim))

In [None]:
class SAKTLoss(nn.Module):
    """Binary cross-entropy with logits, restricted to positions where ``mask > 0``.

    Args:
        reduce: forwarded as ``reduction`` to
            ``binary_cross_entropy_with_logits`` (e.g. ``"mean"``, ``"sum"``).
    """

    def __init__(self, reduce="mean"):
        super().__init__()
        self.reduce = reduce

    def forward(self, logits, targets, qid, mask, device="cpu"):
        # `qid` and `device` are accepted for interface compatibility but unused.
        keep = mask.gt(0).view(-1)
        flat_targets = targets.view(-1)
        selected_logits = logits.view(-1)[keep]
        selected_targets = flat_targets[keep].float()
        return torch.nn.functional.binary_cross_entropy_with_logits(
            selected_logits, selected_targets, reduction=self.reduce
        )
def dkt_predict(logits, qid):
    """Select each step's predicted probability for the question asked at that step.

    Args:
        logits: per-skill logits; the gather is along dim 2, so presumably
            ``(batch, seq, n_skill)``. TODO confirm against the DKT caller.
        qid: LongTensor of question indices, either ``(batch, seq)`` or
            ``(batch, seq, 1)``.

    Returns:
        ``(probabilities, rounded 0/1 predictions)``, each ``(batch, seq)``.
    """
    probs = torch.sigmoid(logits)
    # torch.gather needs an index with the same rank as its input.
    index = qid.unsqueeze(-1) if qid.dim() == probs.dim() - 1 else qid
    # Squeeze only the gathered axis: a bare squeeze() also dropped batch/seq
    # dims of size 1, which made the view below fail.
    probs = torch.gather(probs, dim=2, index=index).squeeze(-1)
    batch, seq = probs.size(0), probs.size(1)
    binary_preds = torch.round(probs)
    return probs.view(batch, seq), binary_preds.view(batch, seq)

In [None]:
from sklearn.metrics import (
    roc_auc_score,
    precision_recall_fscore_support,
    accuracy_score,
)
import numpy as np
import os
import random


def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    ``cudnn.benchmark`` is switched off as well. In benchmark mode cuDNN may
    pick different algorithms from run to run, which defeats
    ``cudnn.deterministic``.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    # Only affects interpreters started after this point; the current
    # process's hash seed is fixed at start-up.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # every visible GPU, not only the current one
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def train_epoch(model, train_iterator, optim, criterion, device="cpu"):
    """Run one optimisation pass over ``train_iterator``.

    Args:
        model: called as ``model(qid, qa)``; must return ``(logits, _)``.
        train_iterator: iterable of ``(qa, qid, labels, mask)`` batches.
        optim: optimizer stepped once per batch.
        criterion: called as ``criterion(logits, labels, qid, mask, device=device)``.
        device: device every batch tensor is moved to.
    """
    model.train()

    for batch in train_iterator:
        qa, qid, labels, mask = (tensor.to(device) for tensor in batch)

        optim.zero_grad()
        logits, _ = model(qid, qa)
        batch_loss = criterion(logits, labels, qid, mask, device=device)
        batch_loss.backward()
        optim.step()


def eval_epoch(model, test_iterator, criterion, eval_func, device="cpu"):
    """Evaluate ``model`` over ``test_iterator``, print metrics and return them.

    Args:
        model: called as ``model(qid, qa)``; must return ``(logits, _)``.
        test_iterator: iterable of ``(qa, qid, labels, mask)`` batches.
        criterion: called as ``criterion(logits, labels, qid, mask, device=device)``.
        eval_func: maps ``(logits, qid, labels, bool_mask)`` to numpy arrays
            ``(pred, binary_pred, target)`` for the unmasked steps.
        device: device every batch tensor is moved to.

    Returns:
        dict with ``loss`` (mean per-batch criterion value), ``auc``,
        ``accuracy``, ``precision``/``recall``/``fscore`` (per-class arrays from
        ``precision_recall_fscore_support``) and ``pos_rate``.
    """
    model.eval()

    eval_loss = []
    preds, binary_preds, targets = [], [], []
    # The whole loop runs without autograd; the loss needs no graph either.
    with torch.no_grad():
        for qa, qid, labels, mask in test_iterator:
            qa, qid, labels, mask = (
                qa.to(device),
                qid.to(device),
                labels.to(device),
                mask.to(device),
            )

            logits, _ = model(qid, qa)
            loss = criterion(logits, labels, qid, mask, device=device)
            eval_loss.append(loss.item())

            pred, binary_pred, target = eval_func(logits, qid, labels, mask.eq(1))
            preds.append(pred)
            binary_preds.append(binary_pred)
            targets.append(target)

    preds = np.concatenate(preds)
    binary_preds = np.concatenate(binary_preds)
    targets = np.concatenate(targets)

    auc_value = roc_auc_score(targets, preds)
    accuracy = accuracy_score(targets, binary_preds)
    precision, recall, f_score, _ = precision_recall_fscore_support(
        targets, binary_preds
    )
    pos_rate = np.sum(targets) / float(len(targets))
    print(
        "auc={0}, accuracy={1}, precision={2}, recall={3}, fscore={4}, pos_rate={5}".format(
            auc_value, accuracy, precision, recall, f_score, pos_rate
        )
    )
    return {
        "loss": float(np.mean(eval_loss)),
        "auc": auc_value,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "fscore": f_score,
        "pos_rate": pos_rate,
    }


def dkt_eval(logits, qid, targets, mask):
    """Return numpy ``(pred, binary_pred, target)`` for the steps selected by ``mask``.

    Args:
        logits: per-skill logits passed to ``dkt_predict``.
        qid: question indices passed to ``dkt_predict``.
        targets: labels with the same shape as the predictions.
        mask: boolean tensor selecting the steps to score.
    """
    pred, binary_pred = dkt_predict(logits, qid)
    # .cpu() before .numpy(): numpy conversion fails on CUDA tensors.
    pred = torch.masked_select(pred, mask).detach().cpu().numpy()
    binary_pred = torch.masked_select(binary_pred, mask).detach().cpu().numpy()
    target = torch.masked_select(targets, mask).detach().cpu().numpy()
    return pred, binary_pred, target


def deepirt_eval(logits, qid, targets, mask):
    """Return numpy ``(pred, binary_pred, target)`` for positions where ``mask > 0``.

    Args:
        logits: one logit per position; any shape with ``mask.numel()`` elements.
        qid: unused; kept for the shared eval-function signature.
        targets: labels, same number of elements as ``mask``.
        mask: positions to keep (values > 0).
    """
    keep = mask.gt(0).view(-1)
    flat_targets = targets.view(-1)

    # Flatten logits like the mask: masking 2-D logits with a 1-D mask is a
    # broadcast error.
    flat_logits = torch.masked_select(logits.reshape(-1), keep)

    pred = torch.sigmoid(flat_logits).detach().cpu().numpy()
    binary_pred = pred.round()
    target = torch.masked_select(flat_targets, keep).detach().cpu().numpy()
    return pred, binary_pred, target


def sakt_eval(logits, qid, targets, mask):
    """Return numpy ``(pred, binary_pred, target)`` for positions where ``mask > 0``.

    Args:
        logits: one logit per position (flattened internally).
        qid: unused; kept for the shared eval-function signature.
        targets: labels with as many elements as ``mask``.
        mask: positions to keep (values > 0).

    Returns:
        sigmoid probabilities, their rounded 0/1 values, and the matching
        targets, all as 1-D numpy arrays on the CPU.
    """
    keep = mask.gt(0).view(-1)
    flat_targets = targets.view(-1)

    probabilities = torch.sigmoid(logits.view(-1)[keep])
    pred = probabilities.cpu().detach().numpy()
    target = flat_targets[keep].cpu().detach().numpy()
    return pred, pred.round(), target


def future_mask(seq_length):
    """Return a ``(seq_length, seq_length)`` bool tensor, True strictly above the diagonal.

    Entry ``[i, j]`` is True when ``j > i``, i.e. the positions a step at ``i``
    must not attend to.
    """
    upper = torch.ones(seq_length, seq_length).triu(diagonal=1)
    return upper.to(torch.bool)


In [None]:
class SAKTModel(nn.Module):
    """SAKT-style knowledge-tracing model built on causal linear attention.

    Args:
        n_skill: number of question/skill ids. ``q`` ids index an embedding of
            ``n_skill + 1`` rows (``n_skill`` is the padding index); ``qa`` ids
            index one of ``2 * n_skill + 2`` rows (``2 * n_skill + 1`` is padding).
        embed_dim: width of all embeddings and hidden states.
        dropout: dropout rate for the feed-forward sub-layer.
        num_heads: attention heads; must divide ``embed_dim``.
        max_len: rows in the positional-embedding table; longer sequences
            cannot be embedded.
        device: stored for backward compatibility. ``forward`` builds position
            ids on the inputs' device, so the model can be moved with
            ``.to()``/``.cuda()`` after construction.
    """

    def __init__(
        self, n_skill, embed_dim, dropout, num_heads=4, max_len=64, device="cpu"
    ):
        super(SAKTModel, self).__init__()
        self.n_skill = n_skill
        self.q_embed_dim = embed_dim
        self.qa_embed_dim = embed_dim
        self.pos_embed_dim = embed_dim
        self.embed_dim = embed_dim
        self.dropout = dropout
        self.num_heads = num_heads
        self.max_len = max_len
        self.device = device

        self.q_embedding = nn.Embedding(
            n_skill + 1, self.q_embed_dim, padding_idx=n_skill
        )
        self.qa_embedding = nn.Embedding(
            2 * n_skill + 2, self.qa_embed_dim, padding_idx=2 * n_skill + 1
        )
        self.pos_embedding = nn.Embedding(self.max_len, self.pos_embed_dim)

        self.linear_attention = CausalLinearAttention(embed_dim=embed_dim, num_heads=num_heads)

        self.layer_norm1 = nn.LayerNorm(self.embed_dim)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim)
        self.dropout_layer = nn.Dropout(self.dropout)
        # Forward the model's dropout rate. Previously FFN silently fell back
        # to its own default (0.2) whatever `dropout` was.
        self.ffn = FFN(self.embed_dim, dropout=self.dropout)
        self.pred = nn.Linear(self.embed_dim, 1, bias=True)

    def forward(self, q, qa):
        """Return ``(logits, None)``; logits have one entry per input position.

        ``q`` and ``qa`` are index tensors; the attention unpacks
        ``(batch, seq, dim)`` after embedding, so both are presumably
        ``(batch, seq)``.
        """
        qa = self.qa_embedding(qa)
        # Build positions where the data lives, not on the construction-time device.
        pos_id = torch.arange(qa.size(1), device=qa.device).unsqueeze(0)
        qa = qa + self.pos_embedding(pos_id)
        q = self.q_embedding(q)

        attention_out = self.linear_attention(q, qa, qa)
        attention_out = self.layer_norm1(attention_out + q)

        # NOTE: FFN already ends in its own dropout; dropout_layer applies a second one.
        x = self.ffn(attention_out)
        x = self.dropout_layer(x)
        x = self.layer_norm2(x + attention_out)
        x = self.pred(x)

        return x.squeeze(-1), None


In [None]:
import sys
import time

# Put the parent directory first on the import path.
# NOTE(review): none of the imports in this cell come from there. This is
# presumably left over from running the code as a script inside a repository
# checkout; confirm before removing.
sys.path.insert(0, "..")

import argparse
import torch
import torch.optim
from torch.utils.data import DataLoader
import pandas as pd
import logging

# Send INFO-and-above records from the root logger out as "LEVEL:message".
# Training metrics in this notebook are print()ed rather than logged.
logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO)


def run(args):
    """Train and evaluate a SAKTModel on the ASSISTments-2015 train/test split.

    ``args`` must provide ``num_skill``, ``embed_dim``, ``dropout``,
    ``num_heads``, ``batch_size``, ``num_worker``, ``learning_rate`` and
    ``epoch``. The data paths (tab-separated, no header) and the 64-step
    sequence length are fixed below. Evaluation metrics are printed after every
    epoch, and the learning rate decays by 0.9 per epoch.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    max_len = 64

    train_df = pd.read_csv("/content/assist2015_train.csv",
                           header=None,
                           sep='\t')
    test_df = pd.read_csv("/content/assist2015_test.csv", header=None, sep='\t')

    train = SAKTDataset(train_df, args.num_skill, max_len=max_len)
    test = SAKTDataset(test_df, args.num_skill, max_len=max_len)
    train_dataloader = DataLoader(train,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_worker,
                                  shuffle=True)
    test_dataloader = DataLoader(test,
                                 batch_size=args.batch_size * 2,
                                 num_workers=args.num_worker,
                                 shuffle=False)

    sakt = SAKTModel(args.num_skill, args.embed_dim, args.dropout, args.num_heads,
                     device=device, max_len=max_len)

    optimizer = torch.optim.Adam(sakt.parameters(), lr=args.learning_rate)
    loss_func = SAKTLoss()

    sakt.to(device)
    loss_func.to(device)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
    start_time = time.time()
    for _ in range(args.epoch):
        train_epoch(sakt, train_dataloader, optimizer, loss_func, device)
        eval_epoch(sakt, test_dataloader, loss_func, sakt_eval, device)
        scheduler.step()

    # Read the clock once so the total and the per-epoch average agree.
    elapsed = time.time() - start_time
    print("total time: ", elapsed, "seconds")
    if args.epoch > 0:  # avoid ZeroDivisionError when no epochs are run
        print('average time per epoch', elapsed / args.epoch)

if __name__ == "__main__":
    # The description previously said "deep IRT"; this entry point trains SAKT.
    arg_parser = argparse.ArgumentParser(description="train SAKT model")
    arg_parser.add_argument("--learning_rate", dest="learning_rate", default=0.001, type=float)
    arg_parser.add_argument("--batch_size", dest="batch_size", default=64, type=int)
    arg_parser.add_argument("--num_skill", dest="num_skill", default=230, type=int)
    arg_parser.add_argument("--embed_dim", dest="embed_dim", default=256, type=int)
    arg_parser.add_argument("--dropout", dest="dropout", default=0.2, type=float)
    arg_parser.add_argument("--num_heads", dest="num_heads", default=4, type=int)
    arg_parser.add_argument("--epoch", dest="epoch", default=10, type=int)
    arg_parser.add_argument("--num_worker", dest="num_worker", default=0, type=int)

    # parse_args([]) ignores sys.argv and uses the defaults above. Inside a
    # notebook kernel, sys.argv presumably holds the launcher's own flags,
    # which argparse would reject. Pass a real list here to override defaults.
    args = arg_parser.parse_args([])
    run(args)

auc=0.7081011881813667, accuracy=0.7581930499048187, precision=[0.58038924 0.77266718], recall=[0.17206909 0.95766316], fscore=[0.26544207 0.85527587], pos_rate=0.7460901048859692
auc=0.7174931992088027, accuracy=0.7593688178866037, precision=[0.57366187 0.77776524], recall=[0.20363837 0.94849538], fscore=[0.30057772 0.85468756], pos_rate=0.7460901048859692
auc=0.7208239491474595, accuracy=0.7603206300623344, precision=[0.58456249 0.77646795], recall=[0.19371555 0.95314806], fscore=[0.29099843 0.85578408], pos_rate=0.7460901048859692
auc=0.7235542647147063, accuracy=0.7600873427643612, precision=[0.57597245 0.77876907], recall=[0.20896729 0.9476449 ], fscore=[0.3066717  0.85494736], pos_rate=0.7460901048859692
auc=0.722532045185445, accuracy=0.7602553096189019, precision=[0.57949309 0.77793599], recall=[0.20334436 0.94978363], fscore=[0.30105011 0.8553134 ], pos_rate=0.7460901048859692
auc=0.722784855472345, accuracy=0.7600406853047665, precision=[0.58131187 0.77681151], recall=[0.1963

In [None]:
import torch
import gc
import torch.nn as nn

def measure_memory(model, input_q, input_qa):
    """Return the peak CUDA memory (MiB) allocated during one no-grad forward pass.

    Side effects: moves ``model`` to the current CUDA device in place and
    prints the measured value.

    Args:
        model: module called as ``model(input_q, input_qa)``.
        input_q, input_qa: input tensors; copied to CUDA before the pass.

    Raises:
        RuntimeError: if CUDA is not available. Without this check, CPU-only
            PyTorch builds failed inside ``model.cuda()`` with an opaque error.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("measure_memory requires a CUDA device")

    gc.collect()
    torch.cuda.empty_cache()
    model.cuda()
    input_q = input_q.cuda()
    input_qa = input_qa.cuda()

    # Reset after weights and inputs are resident. The reported peak therefore
    # presumably still includes them as the starting allocation.
    torch.cuda.reset_peak_memory_stats()

    with torch.no_grad():
        _ = model(input_q, input_qa)

    max_mem = torch.cuda.max_memory_allocated() / (1024 ** 2)  # bytes -> MiB
    print(f"Max memory used: {max_mem:.2f} MB")
    return max_mem

# Define parameters for the SAKTModel instance
# Configuration for the memory benchmark below.
n_skill = 230  # number of skills; matches the --num_skill default used for training
embed_dim = 256
dropout = 0.001
num_heads = 4
max_len = 512 # Max sequence length (also the size of the positional-embedding table)
batch_size = 64

# Build the model. `device` is forwarded to SAKTModel; measure_memory moves the
# model to CUDA regardless.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sakt_model_instance = SAKTModel(n_skill, embed_dim, dropout, num_heads, max_len=max_len, device=device)

# Random (batch_size, max_len) index tensors. SAKTModel.forward takes (q, qa):
# q are question ids and qa are question/answer interaction ids.
# torch.randint's upper bound is exclusive, so q falls in [0, n_skill - 1] and
# qa in [0, 2 * n_skill]. Neither hits its embedding's padding index
# (n_skill for q, 2 * n_skill + 1 for qa).
dummy_input_q = torch.randint(0, n_skill, (batch_size, max_len), dtype=torch.long)
dummy_input_qa = torch.randint(0, 2 * n_skill + 1, (batch_size, max_len), dtype=torch.long)


# Peak CUDA memory (MiB) for one no-grad forward pass over the dummy batch.
mem_linear = measure_memory(sakt_model_instance, dummy_input_q, dummy_input_qa)

print(mem_linear)

Max memory used: 340.94 MB
340.94189453125
