<a href="https://colab.research.google.com/github/IzhanAli08/Knowledge_Tracing/blob/main/SAINT_test1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch.utils.data
import torch.nn.utils
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional

In [2]:
class SAKTDataset(torch.utils.data.Dataset):
    def __init__(self, df, n_skill, max_len=200):
        super(SAKTDataset, self).__init__()
        self.df = df
        self.n_skill = n_skill
        self.max_len = max_len

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        qids = self.df[0][idx].split(",")
        qids = [q.strip('"') for q in qids]

        correct = str(self.df[1][idx]).split(",")
        correct = [c.strip('"') for c in correct if c]

        if len(qids) > self.max_len:
            qids = qids[-self.max_len :]
            correct = correct[-self.max_len :]

        #qids = np.array(list(map(int, qids)))
        #correct = np.array(list(map(int, correct)))

        # Filter out empty strings and 'nan' before converting to integers
        qids = np.array(list(map(int, [q for q in qids if q and q != 'nan'])))
        correct = np.array(list(map(int, [c for c in correct if c and c != 'nan'])))


        qa = qids + correct * self.n_skill

        q = np.ones(self.max_len, dtype=int) * self.n_skill
        qa2 = np.ones(self.max_len, dtype=int) * (self.n_skill * 2 + 1)
        correct2 = np.ones(self.max_len, dtype=int) * -1
        mask = np.zeros(self.max_len, dtype=int)

        q[: len(qids)] = qids
        qa2[: len(qa)] = qa
        correct2[: len(correct)] = correct
        mask[: len(qa)] = np.ones(len(qa), dtype=int)

        return (
            torch.cat(
                (torch.LongTensor([2 * self.n_skill]), torch.LongTensor(qa2[:-1]))
            ),
            torch.LongTensor(q),
            torch.LongTensor(correct2),
            torch.LongTensor(mask),
        )


def collate_fn(data, n_skill):
    qa = [x[0] for x in data]
    qid = [x[1] for x in data]
    qc = [x[2] for x in data]
    mask = [x[3] for x in data]
    qa = torch.nn.utils.rnn.pad_sequence(
        qa, batch_first=True, padding_value=n_skill * 2
    )
    qid = torch.nn.utils.rnn.pad_sequence(qid, batch_first=True, padding_value=n_skill)
    qc = torch.nn.utils.rnn.pad_sequence(qc, batch_first=True, padding_value=-1)
    mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True, padding_value=0)

    return qa, qid, qc, mask

In [3]:
class FFN(nn.Module):
    def __init__(self, state_size=200, dropout=0.2):
        super(FFN, self).__init__()
        self.state_size = state_size
        self.dropout = dropout
        self.lr1 = nn.Linear(self.state_size, self.state_size)
        self.relu = nn.ReLU()
        self.lr2 = nn.Linear(self.state_size, self.state_size)
        self.dropout = nn.Dropout(self.dropout)

    def forward(self, x):
        x = self.lr1(x)
        x = self.relu(x)
        x = self.lr2(x)
        return self.dropout(x)

In [4]:
class SAKTLoss(nn.Module):
    def __init__(self, reduce="mean"):
        super(SAKTLoss, self).__init__()
        self.reduce = reduce

    def forward(self, logits, targets, qid, mask, device="cpu"):

        mask = mask.gt(0).view(-1)
        targets = targets.view(-1)

        logits = torch.masked_select(logits.view(-1), mask)
        targets = torch.masked_select(targets, mask)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(
            logits, targets.float(), reduction=self.reduce
        )
        return loss
def dkt_predict(logits, qid):
    preds = torch.sigmoid(logits)
    preds = torch.gather(preds, dim=2, index=qid)
    preds = torch.squeeze(preds)
    binary_preds = torch.round(preds)
    return (
        preds.view(preds.size()[0], preds.size()[1]),
        binary_preds.view(preds.size()[0], preds.size()[1]),
    )

In [5]:
from sklearn.metrics import (
    roc_auc_score,
    precision_recall_fscore_support,
    accuracy_score,
)
import numpy as np
import os
import random


def seed_everything(seed=42):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


def train_epoch(model, train_iterator, optim, criterion, device="cpu"):
    model.train()

    for i, (qa, qid, labels, mask) in enumerate(train_iterator):
        qa, qid, labels, mask = (
            qa.to(device),
            qid.to(device),
            labels.to(device),
            mask.to(device),
        )

        optim.zero_grad()
        logits, _ = model(qid, qa)
        loss = criterion(logits, labels, qid, mask, device=device)
        loss.backward()
        optim.step()


def eval_epoch(model, test_iterator, criterion, eval_func, device="cpu"):
    model.eval()

    eval_loss = []
    preds, binary_preds, targets = [], [], []
    for i, (qa, qid, labels, mask) in enumerate(test_iterator):
        qa, qid, labels, mask = (
            qa.to(device),
            qid.to(device),
            labels.to(device),
            mask.to(device),
        )

        with torch.no_grad():
            logits, _ = model(qid, qa)

        loss = criterion(logits, labels, qid, mask, device=device)
        eval_loss.append(loss.detach().item())

        mask = mask.eq(1)

        # pred, binary_pred = deepkt.loss.dkt_predict(logits, qid)
        # pred = torch.masked_select(pred, mask).detach().numpy()
        # binary_pred = torch.masked_select(binary_pred, mask).detach().numpy()
        # target = torch.masked_select(labels, mask).detach().numpy()
        # pred = pred.cpu().detach().numpy().reshape(-1)
        # binary_pred = binary_pred.cpu().detach().numpy().reshape(-1)
        pred, binary_pred, target = eval_func(logits, qid, labels, mask)
        preds.append(pred)
        binary_preds.append(binary_pred)
        targets.append(target)

    preds = np.concatenate(preds)
    binary_preds = np.concatenate(binary_preds)
    targets = np.concatenate(targets)

    auc_value = roc_auc_score(targets, preds)
    accuracy = accuracy_score(targets, binary_preds)
    precision, recall, f_score, _ = precision_recall_fscore_support(
        targets, binary_preds,zero_division=0
    )
    pos_rate = np.sum(targets) / float(len(targets))
    print(
        "auc={0}, accuracy={1}, precision={2}, recall={3}, fscore={4}, pos_rate={5}".format(
            auc_value, accuracy, precision, recall, f_score, pos_rate
        )
    )


def dkt_eval(logits, qid, targets, mask):
    pred, binary_pred = dkt_predict(logits, qid)
    pred = torch.masked_select(pred, mask).detach().numpy()
    binary_pred = torch.masked_select(binary_pred, mask).detach().numpy()
    target = torch.masked_select(targets, mask).detach().numpy()
    return pred, binary_pred, target


def deepirt_eval(logits, qid, targets, mask):
    mask = mask.gt(0).view(-1)
    targets = targets.view(-1)

    logits = torch.masked_select(logits, mask)

    pred = torch.sigmoid(logits).detach().numpy()
    binary_pred = pred.round()
    target = torch.masked_select(targets, mask).detach().numpy()
    return pred, binary_pred, target


def sakt_eval(logits, qid, targets, mask):
    mask = mask.gt(0).view(-1)
    targets = targets.view(-1)

    logits = torch.masked_select(logits.view(-1), mask)

    pred = torch.sigmoid(logits).cpu().detach().numpy()
    binary_pred = pred.round()
    #target = torch.masked_select(targets, mask).detach().numpy()
    target = torch.masked_select(targets, mask).cpu().detach().numpy() # Moved targets to CPU as well for consistency
    return pred, binary_pred, target


def future_mask(seq_length):
    mask = np.triu(np.ones((seq_length, seq_length)), k=1).astype("bool")
    return torch.from_numpy(mask)


In [6]:
class SaintEncoder(nn.Module):
    def __init__(self, embed_dim, dropout=0.3, num_heads=4):
        super(SaintEncoder, self).__init__()
        self.embed_dim = embed_dim
        self.dropout = dropout
        self.num_heads = num_heads
        self.attn = nn.MultiheadAttention(
            embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout
        )
        self.layer_norm1 = nn.LayerNorm(self.embed_dim)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim)
        self.ffn = FFN(self.embed_dim)
        self.dropout = nn.Dropout(self.dropout)

    def forward(self, x):
        device = x.device
        x = self.layer_norm1(x)
        encoder, _ = self.attn(
            x, x, x, attn_mask=future_mask(x.size(0)).to(device)
        )
        encoder = encoder + x

        encoder = encoder.permute(1, 0, 2)
        encoder_out = self.layer_norm2(encoder)

        return self.ffn(encoder_out) + encoder_out


class SaintDecoder(nn.Module):
    def __init__(self, embed_dim, dropout=0.3, num_heads=4):
        super(SaintDecoder, self).__init__()
        self.embed_dim = embed_dim
        self.dropout = dropout
        self.num_heads = num_heads

        self.attn1 = nn.MultiheadAttention(
            embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout
        )
        self.attn2 = nn.MultiheadAttention(
            embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout
        )

        self.ffn = FFN(self.embed_dim)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim)
        self.layer_norm3 = nn.LayerNorm(self.embed_dim)

    def forward(self, decoder_in, encoder_out):
        device = decoder_in.device
        x = self.layer_norm1(decoder_in)
        decoder, _ = self.attn1(
            x, x, x, attn_mask=future_mask(x.size(0)).to(device)
        )
        decoder = decoder + x

        # Reshape encoder_out before passing to attn2
        encoder_out = encoder_out.permute(1, 0, 2) # shape: [batch_size, seq_len, embed_dim]
        encoder_out = self.layer_norm2(encoder_out)

        # Reshape decoder to match encoder_out's shape before passing to attn2
        decoder = decoder.permute(1, 0, 2) # shape: [batch_size, seq_len, embed_dim]

        encoder_out2 = self.attn2(decoder, encoder_out, encoder_out)[0] # Pass reshaped decoder and encoder_out

        decoder_out = decoder + encoder_out2

        decoder_out = decoder_out.permute(1, 0, 2)
        decoder_out = self.layer_norm3(decoder_out)
        return decoder_out + self.ffn(decoder_out)

        #encoder_out = encoder_out.permute(1, 0, 2)
        #encoder_out = self.layer_norm2(encoder_out)
        #encoder_out2 = self.attn2(decoder, encoder_out, encoder_out)[0]
        #decoder_out = decoder + encoder_out2

        #decoder_out = decoder_out.permute(1, 0, 2)
        #decoder_out = self.layer_norm3(decoder_out)
        #return decoder_out + self.ffn(decoder_out)


class SaintModel(nn.Module):
    def __init__(
        self,
        n_skill,
        embed_dim,
        dropout,
        num_heads=4,
        num_enc=4,
        max_len=64,
        device="cpu",
    ):
        super(SaintModel, self).__init__()
        self.n_skill = n_skill
        self.embed_dim = embed_dim
        self.dropout = dropout
        self.num_heads = num_heads
        self.num_enc = num_enc
        self.max_len = max_len
        self.device = device

        self.q_embedding = nn.Embedding(
            n_skill + 1, self.embed_dim, padding_idx=n_skill
        )
        self.qa_embedding = nn.Embedding(
            2 * n_skill + 2, self.embed_dim, padding_idx=2 * n_skill + 1
        )
        self.pos_embedding = nn.Embedding(self.max_len, self.embed_dim)

        self.encoders = nn.ModuleList(
            [
                SaintEncoder(self.embed_dim, self.dropout, self.num_heads)
                for x in range(self.num_enc)
            ]
        )
        self.decoders = nn.ModuleList(
            [
                SaintDecoder(self.embed_dim, self.dropout, self.num_heads)
                for x in range(self.num_enc)
            ]
        )

        self.fc = nn.Linear(self.embed_dim, 1)

    def forward(self, q, qa):
        qa = self.qa_embedding(qa)
        pos_id = torch.arange(qa.size(1)).unsqueeze(0).to(self.device)
        pos_x = self.pos_embedding(pos_id)
        qa = qa + pos_x
        q = self.q_embedding(q)

        q = q.permute(1, 0, 2)
        qa = qa.permute(1, 0, 2)

        for x in range(self.num_enc):
            q = self.encoders[x](q)

        for x in range(self.num_enc):
            qa = self.decoders[x](qa, q)

        logits = self.fc(qa)
        return logits, None


In [7]:
import sys

sys.path.insert(0, "..")

import argparse
import torch
import torch.optim
from torch.utils.data import DataLoader
import pandas as pd
import logging
# Import the collate_fn
from __main__ import collate_fn
logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO)


def run(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_df = pd.read_csv("/content/assist2009_train.csv",
                           header=None,
                           sep=',')
    test_df = pd.read_csv("/content/assist2009_test.csv", header=None, sep=',')

    train = SAKTDataset(train_df, args.num_skill, max_len=100)
    test = SAKTDataset(test_df, args.num_skill, max_len=100)
    train_dataloader = DataLoader(train,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_worker,
                                  shuffle=True)
    test_dataloader = DataLoader(test,
                                 batch_size=args.batch_size * 2,
                                 num_workers=args.num_worker,
                                 shuffle=False)

    saint = SaintModel(args.num_skill, args.embed_dim, args.dropout, args.num_heads,
                       args.num_enc, device=device, max_len=100)

    optimizer = torch.optim.Adam(saint.parameters(), lr=args.learning_rate)
    loss_func = SAKTLoss()

    saint.to(device)
    loss_func.to(device)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
    for epoch in range(args.epoch):
        train_epoch(saint, train_dataloader, optimizer, loss_func,
                                 device)
        eval_epoch(saint, test_dataloader, loss_func, sakt_eval, device)
        scheduler.step()


if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(description="train deep IRT model")
    # Provide default values for the arguments
    arg_parser.add_argument("--learning_rate", dest="learning_rate", default=0.001, type=float)
    arg_parser.add_argument("--batch_size", dest="batch_size", default=64, type=int)
    arg_parser.add_argument("--num_skill", dest="num_skill", default=150, type=int)
    arg_parser.add_argument("--embed_dim", dest="embed_dim", default=200, type=int)
    arg_parser.add_argument("--dropout", dest="dropout", default=0.2, type=float)
    arg_parser.add_argument("--num_heads", dest="num_heads", default=5, type=int)
    arg_parser.add_argument("--epoch", dest="epoch", default=10, type=int)
    arg_parser.add_argument("--num_worker", dest="num_worker", default=0, type=int)
    # Define the num_enc argument with a default value
    arg_parser.add_argument("--num_enc", dest="num_enc", default=4, type=int) # This line is added to define the num_enc argument

    args = arg_parser.parse_args([]) #remove any other code that is being passed to parse_args such as empty lists.
    run(args)

auc=0.5018910084607004, accuracy=0.6307840616966581, precision=[0.         0.63078406], recall=[0. 1.], fscore=[0.         0.77359606], pos_rate=0.6307840616966581
auc=0.4970294973990096, accuracy=0.6307840616966581, precision=[0.         0.63078406], recall=[0. 1.], fscore=[0.         0.77359606], pos_rate=0.6307840616966581
auc=0.49744180968456037, accuracy=0.6307840616966581, precision=[0.         0.63078406], recall=[0. 1.], fscore=[0.         0.77359606], pos_rate=0.6307840616966581
auc=0.4912536760353751, accuracy=0.6307840616966581, precision=[0.         0.63078406], recall=[0. 1.], fscore=[0.         0.77359606], pos_rate=0.6307840616966581
auc=0.4906633223778279, accuracy=0.6307840616966581, precision=[0.         0.63078406], recall=[0. 1.], fscore=[0.         0.77359606], pos_rate=0.6307840616966581
auc=0.4900040146540415, accuracy=0.6307840616966581, precision=[0.         0.63078406], recall=[0. 1.], fscore=[0.         0.77359606], pos_rate=0.6307840616966581
auc=0.490994091

In [8]:
import torch
import gc
import torch.nn as nn

def measure_memory(model, input_q, input_qa):
    gc.collect()
    torch.cuda.empty_cache()
    model.cuda()
    input_q = input_q.cuda()
    input_qa = input_qa.cuda()

    torch.cuda.reset_peak_memory_stats()

    with torch.no_grad():
        _ = model(input_q, input_qa)

    max_mem = torch.cuda.max_memory_allocated() / (1024 ** 2)  # in MB
    print(f"Max memory used: {max_mem:.2f} MB")
    return max_mem

# Define parameters for the SaintModel instance
n_skill = 150  # Example value, adjust based on your data
embed_dim = 100
dropout = 0.001
num_heads = 10
max_len = 128 # Max sequence length
batch_size = 64 # Use a small batch size for memory measurement

# Create an instance of SaintModel
# You also need to provide device if not default
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sakt_model_instance = SaintModel(n_skill, embed_dim, dropout, num_heads, max_len=max_len, device=device)

# Create dummy input tensors with appropriate dimensions (batch_size, sequence_length)
# The SAKTModel forward method expects q and qa as inputs.
# q is typically the question sequence, qa is the question-answer sequence.
# The values should be within the range of your embedding layers' vocabulary size.
# For q, values should be < n_skill + 1.
# For qa, values should be < 2 * n_skill + 2.
dummy_input_q = torch.randint(0, n_skill, (batch_size, max_len), dtype=torch.long)
dummy_input_qa = torch.randint(0, 2 * n_skill + 1, (batch_size, max_len), dtype=torch.long)


# Pass the instance and dummy inputs to the measure_memory function
mem_dot = measure_memory(sakt_model_instance, dummy_input_q, dummy_input_qa)

# Correct the typo in print
print(mem_dot)

Max memory used: 131.85 MB
131.8505859375
