In [None]:
import gc
import os
import random
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
import re

import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import pickle
with open('/content/drive/MyDrive/g.pkl', 'rb') as f:
    group = pickle.load(f)

In [None]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(0)

In [None]:
valid_group = group.sample(frac=0.2)
train_group = group.drop(valid_group.index).reset_index(drop=True)
valid_group.reset_index(drop=True, inplace=True)
train_group.shape, valid_group.shape

((23238,), (5810,))

In [None]:
from sklearn.model_selection import train_test_split
train_group, valid_group = train_test_split(group, test_size=0.2)
train_group.shape, valid_group.shape

((23238,), (5810,))

In [None]:
class SAINTDataset(Dataset):
    def __init__(self, user_sequences, num_questions, subset='train', max_seq=100, min_seq=10):
        super(SAINTDataset, self).__init__()
        self.max_seq = max_seq
        self.num_questions = num_questions
        self.user_sequences = user_sequences
        self.subset = subset

        self.user_ids = []
        for user_id in user_sequences.index:
            q, _, _ = user_sequences[user_id]
            # if len(q) < min_seq:
            #     continue
            self.user_ids.append(user_id)

    def __len__(self):
        return len(self.user_ids)

    def __getitem__(self, index):
        user_id = self.user_ids[index]

        q_, qa_, diff_ = self.user_sequences[user_id]
        seq_len = len(q_)

        q = np.zeros(self.max_seq, dtype=int)
        qa = np.zeros(self.max_seq, dtype=int)
        diff = np.zeros(self.max_seq, dtype=int)


        #잘라서 붙이는 걸로 바꾸기
        if seq_len >= self.max_seq:
            q[:] = q_[-self.max_seq:]
            qa[:] = qa_[-self.max_seq:]
            diff[:] = diff_[-self.max_seq:]

        # If not, map our user_sequences to the tail end of q and qa, the start will be padded with zeros
        else:
            q[-seq_len:] = q_
            qa[-seq_len:] = qa_
            diff[-seq_len:] = diff_


        r = np.zeros(self.max_seq, dtype=int)   #shifted qa
        r[1:] = qa[:-1].copy()

        return q, r, qa, diff

In [None]:
class FFN(nn.Module):
    def __init__(self, dim=128):
        super().__init__()
        self.layer1 = nn.Linear(dim, dim)
        self.layer2 = nn.Linear(dim, dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.layer2(   self.relu(   self.layer1(x)))


def future_mask(seq_length):
    future_mask = np.triu(np.ones((seq_length, seq_length)), k=1).astype('bool')
    return torch.from_numpy(future_mask)


class Encoder(nn.Module):
    def __init__(self, n_in, seq_len=100, embed_dim=128, nheads=4):
        super().__init__()

        #e : q seq
        #n_in: total num of input

        self.seq_len = seq_len

        #self.part_embed = nn.Embedding(10, embed_dim)

        self.e_embed = nn.Embedding(n_in, embed_dim)
        self.e_pos_embed = nn.Embedding(seq_len, embed_dim)
        self.e_norm = nn.LayerNorm(embed_dim)

        self.e_multi_att = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=nheads, dropout=0.2)
        self.m_norm = nn.LayerNorm(embed_dim)
        self.ffn = FFN(embed_dim)

    def forward(self, e, first_block=True):

        if first_block:
            e = self.e_embed(e)

        pos = torch.arange(self.seq_len).unsqueeze(0).to(device)
        e_pos = self.e_pos_embed(pos)
        e = e + e_pos
        e = self.e_norm(e)
        e = e.permute(1,0,2) #[bs, s_len, embed] => [s_len, bs, embed]
        n = e.shape[0]

        att_mask = future_mask(n).to(device)
        att_out, _ = self.e_multi_att(e, e, e, attn_mask=att_mask)
        m = e + att_out
        m = m.permute(1,0,2)

        o = m + self.ffn(self.m_norm(m))

        return o

class Decoder(nn.Module):
    def __init__(self, n_in, seq_len=100, embed_dim=128, nheads=4):
        super().__init__()
        self.seq_len = seq_len

        self.r_embed = nn.Embedding(n_in, embed_dim)
        self.r_pos_embed = nn.Embedding(seq_len, embed_dim)
        self.r_norm = nn.LayerNorm(embed_dim)
        self.diff_embed = nn.Embedding(NUM_DIFFS, embed_dim)


        self.r_multi_att1 = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=4, dropout=0.2)
        self.r_multi_att2 = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=4, dropout=0.2)
        self.ffn = FFN(embed_dim)

        self.r_norm1 = nn.LayerNorm(embed_dim)
        self.r_norm2 = nn.LayerNorm(embed_dim)
        self.r_norm3 = nn.LayerNorm(embed_dim)


    def forward(self, r, o, diff,  first_block=True):

        if first_block:
            r = self.r_embed(r)
            diff = self.diff_embed(diff)


            r = r + diff

        pos = torch.arange(self.seq_len).unsqueeze(0).to(device)
        r_pos_embed = self.r_pos_embed(pos)
        r = r + r_pos_embed
        r = self.r_norm1(r)
        r = r.permute(1,0,2)
        n = r.shape[0]

        att_out1, _ = self.r_multi_att1(r, r, r, attn_mask=future_mask(n).to(device))
        m1 = r + att_out1

        o = o.permute(1,0,2)
        o = self.r_norm2(o)
        att_out2, _ = self.r_multi_att2(m1, o, o, attn_mask=future_mask(n).to(device))

        m2 = att_out2 + m1
        m2 = m2.permute(1,0,2)
        m2 = self.r_norm3(m2)

        l = m2 + self.ffn(m2)

        return l


def get_clones(module, N): #모듈 리스트
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

class SAINT(nn.Module):
    def __init__(self, dim_model, num_en, num_de, heads_en, total_ex, total_in, heads_de, seq_len):
        super().__init__()

        self.num_en = num_en
        self.num_de = num_de

        self.encoder = get_clones( Encoder(n_in=total_ex, seq_len=seq_len, embed_dim=dim_model, nheads=heads_en) , num_en)
        self.decoder = get_clones( Decoder(n_in=total_in, seq_len=seq_len, embed_dim=dim_model, nheads=heads_de) , num_de)

        self.out = nn.Linear(in_features= dim_model , out_features=1)

    def forward(self, in_ex, in_in, diff):

        ## pass through each of the encoder blocks in sequence
        first_block = True
        for x in range(self.num_en):
            if x>=1:
                first_block = False
            in_ex = self.encoder[x](in_ex, first_block=first_block)

        ## pass through each decoder blocks in sequence
        first_block = True
        for x in range(self.num_de):
            if x>=1:
                first_block = False
            in_in = self.decoder[x]( in_in , in_ex, diff, first_block=first_block )

        ## Output layer
        in_in = torch.sigmoid( self.out( in_in ) )
        return in_in.squeeze(-1)


In [None]:
def train_epoch(model, train_iterator, optim, criterion, device="cpu"):
    model.train()

    train_loss = []
    num_corrects = 0
    num_total = 0
    labels = []
    outs = []

    tbar = tqdm(train_iterator)
    for item in tbar:
        e = item[0].to(device).long()
        r = item[1].to(device).long()
        label = item[2].to(device).float()
        diff = item[3].to(device).long()


        # Zero the gradients in the optimizer
        optim.zero_grad()
        # The results of one forward pass
        output = model(e, r, diff)
        # Calculate the loss
        loss = criterion(output, torch.sigmoid(label))
        # Calculate the gradients with respect to the loss
        loss.backward()
        # Adjust the parameters to minimize the loss based on these gradients
        optim.step()
        # Add our loss to the list of losses
        train_loss.append(loss.item())

        output = output[:, -1]
        label = label[:, -1]
        pred = (output >= 0.5).long()

        num_corrects += (pred == label).sum().item()
        num_total += len(label)

        labels.extend(label.view(-1).data.cpu().numpy())
        outs.extend(output.view(-1).data.cpu().numpy())

        tbar.set_description('loss - {:.4f}'.format(loss))

    acc = num_corrects / num_total
    auc = roc_auc_score(labels, outs)
    loss = np.mean(train_loss)

    return loss, acc, auc

In [None]:
def valid_epoch(model, valid_iterator, criterion, device="cpu"):
    model.eval()

    valid_loss = []
    num_corrects = 0
    num_total = 0
    labels = []
    outs = []

    #tbar = tqdm(valid_iterator)
    for item in valid_iterator: # tbar:
        e = item[0].to(device).long()
        r = item[1].to(device).long()
        label = item[2].to(device).float()
        diff = item[3].to(device).long()


        with torch.no_grad():
            output = model(e, r, diff)
        loss = criterion(output, torch.sigmoid(label))
        valid_loss.append(loss.item())

#####################


         # (BS, 1)
        label = label
        pred = (output >= 0.5).long()

        print(label.shape)
        print(pred.shape)
        print('-----------------------')

        num_corrects += (pred == label).sum().item()
        num_total += len(label)

        labels.extend(label.view(-1).data.cpu().numpy())
        outs.extend(output.view(-1).data.cpu().numpy())

    acc = num_corrects / num_total
    auc = roc_auc_score(labels, outs)
    loss = np.mean(valid_loss)

    return loss, acc, auc

In [None]:
gc.collect()
NUM_QUESTIONS = 9985
MAX_SEQ = 100
BS = 64
NUM_DIFFS = 10
train_dataset = SAINTDataset(train_group, NUM_QUESTIONS, max_seq=MAX_SEQ)
train_dataloader = DataLoader(train_dataset, batch_size=BS, shuffle=True, num_workers=8)

valid_dataset = SAINTDataset(valid_group, NUM_QUESTIONS, max_seq=MAX_SEQ, subset='valid')
valid_dataloader = DataLoader(valid_dataset, batch_size=BS, shuffle=False, num_workers=8)



In [None]:
import copy
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model2 = SAINT(dim_model=128,
            num_en=2,
            num_de=2,
            heads_en=4,
            heads_de=4,
            total_ex=NUM_QUESTIONS,
            total_in=2,
            seq_len=100
            )

optimizer = torch.optim.Adam(model2.parameters(), lr=1e-3)
criterion = nn.BCELoss()

model2.to(device)
criterion.to(device)

BCELoss()

In [None]:
gc.collect()
epochs = 30
history = []
auc_max = -np.inf

for epoch in range(1, epochs+1):
    train_loss, train_acc, train_auc = train_epoch(model2, train_dataloader, optimizer, criterion, device)
    print(f'Epoch {epoch}, train_loss: {train_loss:5f}, train_acc: {train_acc:5f}, train_auc: {train_auc:5f}')
    valid_loss, valid_acc, valid_auc = valid_epoch(model2, valid_dataloader, criterion, device)
    print(f'Epoch {epoch}, valid_loss: {valid_loss:5f}, valid_acc: {valid_acc:5f}, valid_auc: {valid_auc:5f}')

    lr = optimizer.param_groups[0]['lr']
    history.append({"epoch":epoch, "lr": lr, **{"train_auc": train_auc, "train_acc": train_acc}, **{"valid_auc": valid_auc, "valid_acc": valid_acc}})
    if valid_auc > auc_max:
        print("Epoch#%s, valid loss %.4f, Metric loss improved from %.4f to %.4f, saving model ..." % (epoch, valid_loss, auc_max, valid_auc))
        auc_max = valid_auc

In [None]:
gc.collect()
epochs = 30
history = []
auc_max = -np.inf

for epoch in range(1, epochs+1):
    train_loss, train_acc, train_auc = train_epoch(model2, train_dataloader, optimizer, criterion, device)
    print(f'Epoch {epoch}, train_loss: {train_loss:5f}, train_acc: {train_acc:5f}, train_auc: {train_auc:5f}')
    valid_loss, valid_acc, valid_auc = valid_epoch(model2, valid_dataloader, criterion, device)
    print(f'Epoch {epoch}, valid_loss: {valid_loss:5f}, valid_acc: {valid_acc:5f}, valid_auc: {valid_auc:5f}')

    lr = optimizer.param_groups[0]['lr']
    history.append({"epoch":epoch, "lr": lr, **{"train_auc": train_auc, "train_acc": train_acc}, **{"valid_auc": valid_auc, "valid_acc": valid_acc}})
    if valid_auc > auc_max:
        print("Epoch#%s, valid loss %.4f, Metric loss improved from %.4f to %.4f, saving model ..." % (epoch, valid_loss, auc_max, valid_auc))
        auc_max = valid_auc

loss - 0.6594: 100%|██████████| 361/361 [00:14<00:00, 25.18it/s]


Epoch 1, train_loss: 0.659565, train_acc: 0.716260, train_auc: 0.742201




Epoch 1, valid_loss: 0.658256, valid_acc: 0.716237, valid_auc: 0.784549
Epoch#1, valid loss 0.6583, Metric loss improved from -inf to 0.7845, saving model ...


loss - 0.6546: 100%|██████████| 361/361 [00:14<00:00, 24.43it/s]

Epoch 2, train_loss: 0.657870, train_acc: 0.718685, train_auc: 0.779694





Epoch 2, valid_loss: 0.658252, valid_acc: 0.716237, valid_auc: 0.786600
Epoch#2, valid loss 0.6583, Metric loss improved from 0.7845 to 0.7866, saving model ...


loss - 0.6571: 100%|██████████| 361/361 [00:14<00:00, 25.49it/s]


Epoch 3, train_loss: 0.657817, train_acc: 0.718685, train_auc: 0.781905




Epoch 3, valid_loss: 0.658223, valid_acc: 0.717621, valid_auc: 0.788140
Epoch#3, valid loss 0.6582, Metric loss improved from 0.7866 to 0.7881, saving model ...


loss - 0.6603: 100%|██████████| 361/361 [00:14<00:00, 25.04it/s]


Epoch 4, train_loss: 0.657782, train_acc: 0.718512, train_auc: 0.782859




Epoch 4, valid_loss: 0.658140, valid_acc: 0.716064, valid_auc: 0.788218
Epoch#4, valid loss 0.6581, Metric loss improved from 0.7881 to 0.7882, saving model ...


loss - 0.6623: 100%|██████████| 361/361 [00:14<00:00, 24.91it/s]


Epoch 5, train_loss: 0.657762, train_acc: 0.717559, train_auc: 0.783567




Epoch 5, valid_loss: 0.658169, valid_acc: 0.715891, valid_auc: 0.788994
Epoch#5, valid loss 0.6582, Metric loss improved from 0.7882 to 0.7890, saving model ...


loss - 0.6483: 100%|██████████| 361/361 [00:14<00:00, 25.03it/s]


Epoch 6, train_loss: 0.657690, train_acc: 0.717819, train_auc: 0.787282
Epoch 6, valid_loss: 0.658142, valid_acc: 0.715546, valid_auc: 0.790946
Epoch#6, valid loss 0.6581, Metric loss improved from 0.7890 to 0.7909, saving model ...


loss - 0.6543: 100%|██████████| 361/361 [00:14<00:00, 25.04it/s]


Epoch 7, train_loss: 0.657569, train_acc: 0.717819, train_auc: 0.790883




Epoch 7, valid_loss: 0.657934, valid_acc: 0.717794, valid_auc: 0.810555
Epoch#7, valid loss 0.6579, Metric loss improved from 0.7909 to 0.8106, saving model ...


loss - 0.6569: 100%|██████████| 361/361 [00:14<00:00, 24.54it/s]

Epoch 8, train_loss: 0.657277, train_acc: 0.718512, train_auc: 0.813332





Epoch 8, valid_loss: 0.657631, valid_acc: 0.715373, valid_auc: 0.821301
Epoch#8, valid loss 0.6576, Metric loss improved from 0.8106 to 0.8213, saving model ...


loss - 0.6515: 100%|██████████| 361/361 [00:14<00:00, 24.81it/s]

Epoch 9, train_loss: 0.657030, train_acc: 0.717169, train_auc: 0.824381





Epoch 9, valid_loss: 0.657674, valid_acc: 0.715200, valid_auc: 0.827798
Epoch#9, valid loss 0.6577, Metric loss improved from 0.8213 to 0.8278, saving model ...


loss - 0.6531: 100%|██████████| 361/361 [00:14<00:00, 25.14it/s]


Epoch 10, train_loss: 0.656856, train_acc: 0.717169, train_auc: 0.831677




Epoch 10, valid_loss: 0.657463, valid_acc: 0.715373, valid_auc: 0.831741
Epoch#10, valid loss 0.6575, Metric loss improved from 0.8278 to 0.8317, saving model ...


loss - 0.6588: 100%|██████████| 361/361 [00:14<00:00, 24.67it/s]


Epoch 11, train_loss: 0.656722, train_acc: 0.717039, train_auc: 0.835686
Epoch 11, valid_loss: 0.657453, valid_acc: 0.715546, valid_auc: 0.831284


loss - 0.6533: 100%|██████████| 361/361 [00:14<00:00, 24.74it/s]


Epoch 12, train_loss: 0.656597, train_acc: 0.717429, train_auc: 0.839734




Epoch 12, valid_loss: 0.657470, valid_acc: 0.715200, valid_auc: 0.832005
Epoch#12, valid loss 0.6575, Metric loss improved from 0.8317 to 0.8320, saving model ...


loss - 0.6626: 100%|██████████| 361/361 [00:14<00:00, 24.74it/s]


Epoch 13, train_loss: 0.656486, train_acc: 0.716996, train_auc: 0.844086




Epoch 13, valid_loss: 0.657581, valid_acc: 0.715200, valid_auc: 0.829046


loss - 0.6601: 100%|██████████| 361/361 [00:14<00:00, 24.81it/s]


Epoch 14, train_loss: 0.656356, train_acc: 0.717082, train_auc: 0.848570




Epoch 14, valid_loss: 0.657642, valid_acc: 0.715200, valid_auc: 0.825654


loss - 0.6629: 100%|██████████| 361/361 [00:14<00:00, 24.22it/s]

Epoch 15, train_loss: 0.656216, train_acc: 0.717126, train_auc: 0.852832





Epoch 15, valid_loss: 0.657672, valid_acc: 0.715718, valid_auc: 0.829009


loss - 0.6526: 100%|██████████| 361/361 [00:14<00:00, 24.58it/s]


Epoch 16, train_loss: 0.656071, train_acc: 0.717559, train_auc: 0.854875




Epoch 16, valid_loss: 0.657763, valid_acc: 0.715200, valid_auc: 0.822970


loss - 0.6562: 100%|██████████| 361/361 [00:14<00:00, 25.02it/s]


Epoch 17, train_loss: 0.655935, train_acc: 0.717082, train_auc: 0.860540




Epoch 17, valid_loss: 0.658029, valid_acc: 0.715200, valid_auc: 0.808988


loss - 0.6635: 100%|██████████| 361/361 [00:14<00:00, 24.42it/s]


Epoch 18, train_loss: 0.655805, train_acc: 0.717169, train_auc: 0.864848




Epoch 18, valid_loss: 0.657972, valid_acc: 0.715200, valid_auc: 0.821766


loss - 0.6566: 100%|██████████| 361/361 [00:14<00:00, 24.84it/s]


Epoch 19, train_loss: 0.655671, train_acc: 0.717082, train_auc: 0.869679




Epoch 19, valid_loss: 0.658018, valid_acc: 0.715373, valid_auc: 0.818160


loss - 0.6542: 100%|██████████| 361/361 [00:14<00:00, 24.74it/s]


Epoch 20, train_loss: 0.655531, train_acc: 0.717126, train_auc: 0.873246




Epoch 20, valid_loss: 0.658039, valid_acc: 0.715546, valid_auc: 0.817545


loss - 0.6544: 100%|██████████| 361/361 [00:14<00:00, 24.33it/s]

Epoch 21, train_loss: 0.655403, train_acc: 0.717559, train_auc: 0.876053





Epoch 21, valid_loss: 0.658237, valid_acc: 0.715200, valid_auc: 0.815508


loss - 0.6550: 100%|██████████| 361/361 [00:14<00:00, 24.73it/s]


Epoch 22, train_loss: 0.655255, train_acc: 0.717256, train_auc: 0.879963




Epoch 22, valid_loss: 0.658219, valid_acc: 0.715200, valid_auc: 0.812819


loss - 0.6574: 100%|██████████| 361/361 [00:14<00:00, 24.74it/s]


Epoch 23, train_loss: 0.655142, train_acc: 0.717169, train_auc: 0.882245




Epoch 23, valid_loss: 0.658268, valid_acc: 0.715546, valid_auc: 0.809814


loss - 0.6547: 100%|██████████| 361/361 [00:14<00:00, 24.62it/s]


Epoch 24, train_loss: 0.655022, train_acc: 0.717602, train_auc: 0.887514




Epoch 24, valid_loss: 0.658453, valid_acc: 0.715200, valid_auc: 0.809359


loss - 0.6545: 100%|██████████| 361/361 [00:14<00:00, 24.79it/s]

Epoch 25, train_loss: 0.654904, train_acc: 0.717689, train_auc: 0.890113





Epoch 25, valid_loss: 0.658515, valid_acc: 0.715200, valid_auc: 0.804667


loss - 0.6556: 100%|██████████| 361/361 [00:14<00:00, 24.55it/s]


Epoch 26, train_loss: 0.654789, train_acc: 0.717732, train_auc: 0.891820




Epoch 26, valid_loss: 0.658620, valid_acc: 0.715373, valid_auc: 0.802509


loss - 0.6543: 100%|██████████| 361/361 [00:14<00:00, 24.42it/s]

Epoch 27, train_loss: 0.654685, train_acc: 0.718988, train_auc: 0.895219





Epoch 27, valid_loss: 0.658490, valid_acc: 0.715200, valid_auc: 0.804590


loss - 0.6585: 100%|██████████| 361/361 [00:14<00:00, 24.74it/s]


Epoch 28, train_loss: 0.654595, train_acc: 0.717646, train_auc: 0.898598




Epoch 28, valid_loss: 0.658780, valid_acc: 0.715373, valid_auc: 0.802303


loss - 0.6630: 100%|██████████| 361/361 [00:14<00:00, 24.64it/s]


Epoch 29, train_loss: 0.654488, train_acc: 0.717775, train_auc: 0.900935




Epoch 29, valid_loss: 0.658745, valid_acc: 0.715200, valid_auc: 0.796720


loss - 0.6587: 100%|██████████| 361/361 [00:14<00:00, 24.76it/s]


Epoch 30, train_loss: 0.654396, train_acc: 0.718772, train_auc: 0.903014




Epoch 30, valid_loss: 0.658948, valid_acc: 0.715200, valid_auc: 0.794338


In [None]:
gc.collect()
epochs = 30
history = []
auc_max = -np.inf

for epoch in range(1, epochs+1):
    train_loss, train_acc, train_auc = train_epoch(model, train_dataloader, optimizer, criterion, device)
    print(f'Epoch {epoch}, train_loss: {train_loss:5f}, train_acc: {train_acc:5f}, train_auc: {train_auc:5f}')
    valid_loss, valid_acc, valid_auc = valid_epoch(model, valid_dataloader, criterion, device)
    print(f'Epoch {epoch}, valid_loss: {valid_loss:5f}, valid_acc: {valid_acc:5f}, valid_auc: {valid_auc:5f}')

    lr = optimizer.param_groups[0]['lr']
    history.append({"epoch":epoch, "lr": lr, **{"train_auc": train_auc, "train_acc": train_acc}, **{"valid_auc": valid_auc, "valid_acc": valid_acc}})
    if valid_auc > auc_max:
        print("Epoch#%s, valid loss %.4f, Metric loss improved from %.4f to %.4f, saving model ..." % (epoch, valid_loss, auc_max, valid_auc))
        auc_max = valid_auc

loss - 0.6501: 100%|██████████| 361/361 [00:14<00:00, 24.48it/s]

Epoch 1, train_loss: 0.660059, train_acc: 0.717712, train_auc: 0.747340





Epoch 1, valid_loss: 0.658533, valid_acc: 0.707828, valid_auc: 0.790009
Epoch#1, valid loss 0.6585, Metric loss improved from -inf to 0.7900, saving model ...


loss - 0.6549: 100%|██████████| 361/361 [00:14<00:00, 25.44it/s]


Epoch 2, train_loss: 0.657785, train_acc: 0.719964, train_auc: 0.777030




Epoch 2, valid_loss: 0.658519, valid_acc: 0.709733, valid_auc: 0.793281
Epoch#2, valid loss 0.6585, Metric loss improved from 0.7900 to 0.7933, saving model ...


loss - 0.6607: 100%|██████████| 361/361 [00:14<00:00, 24.50it/s]


Epoch 3, train_loss: 0.657732, train_acc: 0.720310, train_auc: 0.779902




Epoch 3, valid_loss: 0.658444, valid_acc: 0.708001, valid_auc: 0.792895


loss - 0.6489: 100%|██████████| 361/361 [00:14<00:00, 25.00it/s]

Epoch 4, train_loss: 0.657694, train_acc: 0.719877, train_auc: 0.782608





Epoch 4, valid_loss: 0.658430, valid_acc: 0.707828, valid_auc: 0.793850
Epoch#4, valid loss 0.6584, Metric loss improved from 0.7933 to 0.7938, saving model ...


loss - 0.6592: 100%|██████████| 361/361 [00:14<00:00, 25.19it/s]


Epoch 5, train_loss: 0.657667, train_acc: 0.719920, train_auc: 0.783758




Epoch 5, valid_loss: 0.658447, valid_acc: 0.710080, valid_auc: 0.796315
Epoch#5, valid loss 0.6584, Metric loss improved from 0.7938 to 0.7963, saving model ...


loss - 0.6571: 100%|██████████| 361/361 [00:14<00:00, 24.67it/s]


Epoch 6, train_loss: 0.657627, train_acc: 0.719704, train_auc: 0.785411




Epoch 6, valid_loss: 0.658401, valid_acc: 0.708001, valid_auc: 0.798937
Epoch#6, valid loss 0.6584, Metric loss improved from 0.7963 to 0.7989, saving model ...


loss - 0.6572: 100%|██████████| 361/361 [00:14<00:00, 24.62it/s]

Epoch 7, train_loss: 0.657540, train_acc: 0.719964, train_auc: 0.790629





Epoch 7, valid_loss: 0.658219, valid_acc: 0.708001, valid_auc: 0.812512
Epoch#7, valid loss 0.6582, Metric loss improved from 0.7989 to 0.8125, saving model ...


loss - 0.6635: 100%|██████████| 361/361 [00:14<00:00, 24.94it/s]


Epoch 8, train_loss: 0.657317, train_acc: 0.721349, train_auc: 0.804258




Epoch 8, valid_loss: 0.657971, valid_acc: 0.709041, valid_auc: 0.828616
Epoch#8, valid loss 0.6580, Metric loss improved from 0.8125 to 0.8286, saving model ...


loss - 0.6554: 100%|██████████| 361/361 [00:14<00:00, 24.96it/s]


Epoch 9, train_loss: 0.657050, train_acc: 0.720613, train_auc: 0.816371




Epoch 9, valid_loss: 0.657821, valid_acc: 0.707482, valid_auc: 0.836252
Epoch#9, valid loss 0.6578, Metric loss improved from 0.8286 to 0.8363, saving model ...


loss - 0.6518: 100%|██████████| 361/361 [00:14<00:00, 24.72it/s]


Epoch 10, train_loss: 0.656845, train_acc: 0.719271, train_auc: 0.826095




Epoch 10, valid_loss: 0.657739, valid_acc: 0.708001, valid_auc: 0.838126
Epoch#10, valid loss 0.6577, Metric loss improved from 0.8363 to 0.8381, saving model ...


loss - 0.6578: 100%|██████████| 361/361 [00:14<00:00, 25.01it/s]


Epoch 11, train_loss: 0.656714, train_acc: 0.719141, train_auc: 0.831174




Epoch 11, valid_loss: 0.657739, valid_acc: 0.707482, valid_auc: 0.838684
Epoch#11, valid loss 0.6577, Metric loss improved from 0.8381 to 0.8387, saving model ...


loss - 0.6615: 100%|██████████| 361/361 [00:14<00:00, 24.68it/s]


Epoch 12, train_loss: 0.656588, train_acc: 0.718925, train_auc: 0.835280




Epoch 12, valid_loss: 0.657750, valid_acc: 0.707482, valid_auc: 0.838913
Epoch#12, valid loss 0.6577, Metric loss improved from 0.8387 to 0.8389, saving model ...


loss - 0.6542: 100%|██████████| 361/361 [00:14<00:00, 24.72it/s]

Epoch 13, train_loss: 0.656475, train_acc: 0.718881, train_auc: 0.837691





Epoch 13, valid_loss: 0.657928, valid_acc: 0.707482, valid_auc: 0.838244


loss - 0.6528: 100%|██████████| 361/361 [00:14<00:00, 24.83it/s]


Epoch 14, train_loss: 0.656355, train_acc: 0.719098, train_auc: 0.841017




Epoch 14, valid_loss: 0.657795, valid_acc: 0.707482, valid_auc: 0.837618


loss - 0.6533: 100%|██████████| 361/361 [00:14<00:00, 25.09it/s]


Epoch 15, train_loss: 0.656225, train_acc: 0.719011, train_auc: 0.845036




Epoch 15, valid_loss: 0.657852, valid_acc: 0.708348, valid_auc: 0.836491


loss - 0.6590: 100%|██████████| 361/361 [00:14<00:00, 24.94it/s]


Epoch 16, train_loss: 0.656101, train_acc: 0.718881, train_auc: 0.848618




Epoch 16, valid_loss: 0.657965, valid_acc: 0.707482, valid_auc: 0.833324


loss - 0.6546: 100%|██████████| 361/361 [00:14<00:00, 24.96it/s]


Epoch 17, train_loss: 0.655974, train_acc: 0.719141, train_auc: 0.850752




Epoch 17, valid_loss: 0.658027, valid_acc: 0.707482, valid_auc: 0.828560


loss - 0.6471: 100%|██████████| 361/361 [00:14<00:00, 24.86it/s]


Epoch 18, train_loss: 0.655844, train_acc: 0.718838, train_auc: 0.855375




Epoch 18, valid_loss: 0.658090, valid_acc: 0.708348, valid_auc: 0.829843


loss - 0.6570: 100%|██████████| 361/361 [00:14<00:00, 24.80it/s]


Epoch 19, train_loss: 0.655722, train_acc: 0.719141, train_auc: 0.858661




Epoch 19, valid_loss: 0.658169, valid_acc: 0.707482, valid_auc: 0.827045


loss - 0.6617: 100%|██████████| 361/361 [00:14<00:00, 24.59it/s]

Epoch 20, train_loss: 0.655574, train_acc: 0.718925, train_auc: 0.861615





Epoch 20, valid_loss: 0.658202, valid_acc: 0.707482, valid_auc: 0.827223


loss - 0.6497: 100%|██████████| 361/361 [00:14<00:00, 24.49it/s]


Epoch 21, train_loss: 0.655448, train_acc: 0.719054, train_auc: 0.865061




Epoch 21, valid_loss: 0.658285, valid_acc: 0.707482, valid_auc: 0.826157


loss - 0.6548: 100%|██████████| 361/361 [00:14<00:00, 24.77it/s]


Epoch 22, train_loss: 0.655328, train_acc: 0.719054, train_auc: 0.868094




Epoch 22, valid_loss: 0.658600, valid_acc: 0.707482, valid_auc: 0.814354


loss - 0.6573: 100%|██████████| 361/361 [00:14<00:00, 24.25it/s]

Epoch 23, train_loss: 0.655202, train_acc: 0.719098, train_auc: 0.871813





Epoch 23, valid_loss: 0.658552, valid_acc: 0.712504, valid_auc: 0.825801


loss - 0.6462: 100%|██████████| 361/361 [00:14<00:00, 24.57it/s]


Epoch 24, train_loss: 0.655077, train_acc: 0.719184, train_auc: 0.876175




Epoch 24, valid_loss: 0.658531, valid_acc: 0.707828, valid_auc: 0.818900


loss - 0.6465: 100%|██████████| 361/361 [00:14<00:00, 24.35it/s]


Epoch 25, train_loss: 0.654959, train_acc: 0.719444, train_auc: 0.877891




Epoch 25, valid_loss: 0.658718, valid_acc: 0.708001, valid_auc: 0.816058


loss - 0.6551: 100%|██████████| 361/361 [00:18<00:00, 19.35it/s]


Epoch 26, train_loss: 0.654843, train_acc: 0.719141, train_auc: 0.882149




Epoch 26, valid_loss: 0.658784, valid_acc: 0.707482, valid_auc: 0.807330


loss - 0.6507: 100%|██████████| 361/361 [00:18<00:00, 19.00it/s]


Epoch 27, train_loss: 0.654740, train_acc: 0.719487, train_auc: 0.882414




Epoch 27, valid_loss: 0.658789, valid_acc: 0.708694, valid_auc: 0.809889


loss - 0.6574: 100%|██████████| 361/361 [00:15<00:00, 23.38it/s]


Epoch 28, train_loss: 0.654628, train_acc: 0.719661, train_auc: 0.885449




Epoch 28, valid_loss: 0.658729, valid_acc: 0.707655, valid_auc: 0.815323


loss - 0.6579:  33%|███▎      | 119/361 [00:05<00:11, 20.99it/s]


KeyboardInterrupt: ignored

In [None]:
gc.collect()
epochs = 30
history = []
auc_max = -np.inf

for epoch in range(1, epochs+1):
    train_loss, train_acc, train_auc = train_epoch(model, train_dataloader, optimizer, criterion, device)
    print(f'Epoch {epoch}, train_loss: {train_loss:5f}, train_acc: {train_acc:5f}, train_auc: {train_auc:5f}')
    valid_loss, valid_acc, valid_auc = valid_epoch(model, valid_dataloader, criterion, device)
    print(f'Epoch {epoch}, valid_loss: {valid_loss:5f}, valid_acc: {valid_acc:5f}, valid_auc: {valid_auc:5f}')

    lr = optimizer.param_groups[0]['lr']
    history.append({"epoch":epoch, "lr": lr, **{"train_auc": train_auc, "train_acc": train_acc}, **{"valid_auc": valid_auc, "valid_acc": valid_acc}})
    if valid_auc > auc_max:
        print("Epoch#%s, valid loss %.4f, Metric loss improved from %.4f to %.4f, saving model ..." % (epoch, valid_loss, auc_max, valid_auc))
        auc_max = valid_auc

loss - 0.6644: 100%|██████████| 438/438 [00:20<00:00, 21.21it/s]


Epoch 1, train_loss: 0.659503, train_acc: 0.716862, train_auc: 0.755001
Epoch 1, valid_loss: 0.655156, valid_acc: 0.707036, valid_auc: 0.777008
Epoch#1, valid loss 0.6552, Metric loss improved from -inf to 0.7770, saving model ...


loss - 0.6498: 100%|██████████| 438/438 [00:17<00:00, 25.74it/s]


Epoch 2, train_loss: 0.657992, train_acc: 0.718255, train_auc: 0.782300




Epoch 2, valid_loss: 0.655149, valid_acc: 0.705882, valid_auc: 0.782151
Epoch#2, valid loss 0.6551, Metric loss improved from 0.7770 to 0.7822, saving model ...


loss - 0.6480: 100%|██████████| 438/438 [00:16<00:00, 25.78it/s]


Epoch 3, train_loss: 0.657945, train_acc: 0.718397, train_auc: 0.783447




Epoch 3, valid_loss: 0.655146, valid_acc: 0.704729, valid_auc: 0.783513
Epoch#3, valid loss 0.6551, Metric loss improved from 0.7822 to 0.7835, saving model ...


loss - 0.6570: 100%|██████████| 438/438 [00:17<00:00, 24.56it/s]

Epoch 4, train_loss: 0.657885, train_acc: 0.718290, train_auc: 0.786810





Epoch 4, valid_loss: 0.655033, valid_acc: 0.704729, valid_auc: 0.781320


loss - 0.6571: 100%|██████████| 438/438 [00:17<00:00, 25.69it/s]


Epoch 5, train_loss: 0.657769, train_acc: 0.719004, train_auc: 0.796589




Epoch 5, valid_loss: 0.654822, valid_acc: 0.707036, valid_auc: 0.809642
Epoch#5, valid loss 0.6548, Metric loss improved from 0.7835 to 0.8096, saving model ...


loss - 0.6655: 100%|██████████| 438/438 [00:17<00:00, 25.68it/s]


Epoch 6, train_loss: 0.657489, train_acc: 0.719219, train_auc: 0.812744
Epoch 6, valid_loss: 0.654648, valid_acc: 0.704729, valid_auc: 0.827722
Epoch#6, valid loss 0.6546, Metric loss improved from 0.8096 to 0.8277, saving model ...


loss - 0.6524: 100%|██████████| 438/438 [00:17<00:00, 25.76it/s]


Epoch 7, train_loss: 0.657210, train_acc: 0.718005, train_auc: 0.825929




Epoch 7, valid_loss: 0.654304, valid_acc: 0.704729, valid_auc: 0.831776
Epoch#7, valid loss 0.6543, Metric loss improved from 0.8277 to 0.8318, saving model ...


loss - 0.6508: 100%|██████████| 438/438 [00:17<00:00, 25.27it/s]

Epoch 8, train_loss: 0.657043, train_acc: 0.717326, train_auc: 0.832353





Epoch 8, valid_loss: 0.654257, valid_acc: 0.704729, valid_auc: 0.837165
Epoch#8, valid loss 0.6543, Metric loss improved from 0.8318 to 0.8372, saving model ...


loss - 0.6648: 100%|██████████| 438/438 [00:17<00:00, 25.35it/s]


Epoch 9, train_loss: 0.656937, train_acc: 0.717362, train_auc: 0.836767




Epoch 9, valid_loss: 0.654251, valid_acc: 0.704729, valid_auc: 0.837037


loss - 0.6472: 100%|██████████| 438/438 [00:17<00:00, 25.74it/s]


Epoch 10, train_loss: 0.656815, train_acc: 0.717326, train_auc: 0.839654




Epoch 10, valid_loss: 0.654248, valid_acc: 0.704729, valid_auc: 0.838182
Epoch#10, valid loss 0.6542, Metric loss improved from 0.8372 to 0.8382, saving model ...


loss - 0.6492: 100%|██████████| 438/438 [00:17<00:00, 25.71it/s]


Epoch 11, train_loss: 0.656721, train_acc: 0.717219, train_auc: 0.843258




Epoch 11, valid_loss: 0.654309, valid_acc: 0.704729, valid_auc: 0.836072


loss - 0.6543: 100%|██████████| 438/438 [00:17<00:00, 25.31it/s]

Epoch 12, train_loss: 0.656633, train_acc: 0.717290, train_auc: 0.844490





Epoch 12, valid_loss: 0.654367, valid_acc: 0.704729, valid_auc: 0.835327


loss - 0.6642: 100%|██████████| 438/438 [00:17<00:00, 25.38it/s]


Epoch 13, train_loss: 0.656532, train_acc: 0.717040, train_auc: 0.846669




Epoch 13, valid_loss: 0.654432, valid_acc: 0.704729, valid_auc: 0.832089


loss - 0.6511: 100%|██████████| 438/438 [00:17<00:00, 25.67it/s]


Epoch 14, train_loss: 0.656406, train_acc: 0.717255, train_auc: 0.849683




Epoch 14, valid_loss: 0.654508, valid_acc: 0.704729, valid_auc: 0.825855


loss - 0.6534: 100%|██████████| 438/438 [00:17<00:00, 25.64it/s]


Epoch 15, train_loss: 0.656294, train_acc: 0.717290, train_auc: 0.853403




Epoch 15, valid_loss: 0.654459, valid_acc: 0.704729, valid_auc: 0.827831


loss - 0.6607: 100%|██████████| 438/438 [00:18<00:00, 23.85it/s]

Epoch 16, train_loss: 0.656190, train_acc: 0.717219, train_auc: 0.855761





Epoch 16, valid_loss: 0.654670, valid_acc: 0.704729, valid_auc: 0.820965


loss - 0.6653: 100%|██████████| 438/438 [00:18<00:00, 23.41it/s]

Epoch 17, train_loss: 0.656071, train_acc: 0.717469, train_auc: 0.859896





Epoch 17, valid_loss: 0.654596, valid_acc: 0.704729, valid_auc: 0.821329


loss - 0.6512: 100%|██████████| 438/438 [00:17<00:00, 25.51it/s]


Epoch 18, train_loss: 0.655935, train_acc: 0.717255, train_auc: 0.862186




Epoch 18, valid_loss: 0.654752, valid_acc: 0.704729, valid_auc: 0.819027


loss - 0.6532: 100%|██████████| 438/438 [00:17<00:00, 25.27it/s]


Epoch 19, train_loss: 0.655824, train_acc: 0.717112, train_auc: 0.864284




Epoch 19, valid_loss: 0.654845, valid_acc: 0.704729, valid_auc: 0.825325


loss - 0.6555: 100%|██████████| 438/438 [00:18<00:00, 23.45it/s]


Epoch 20, train_loss: 0.655702, train_acc: 0.717255, train_auc: 0.867619




Epoch 20, valid_loss: 0.654916, valid_acc: 0.704729, valid_auc: 0.825996


loss - 0.6586: 100%|██████████| 438/438 [00:17<00:00, 25.45it/s]


Epoch 21, train_loss: 0.655585, train_acc: 0.717683, train_auc: 0.871710




Epoch 21, valid_loss: 0.654934, valid_acc: 0.704729, valid_auc: 0.815236


loss - 0.6575: 100%|██████████| 438/438 [00:17<00:00, 25.36it/s]


Epoch 22, train_loss: 0.655476, train_acc: 0.717219, train_auc: 0.873033
Epoch 22, valid_loss: 0.655050, valid_acc: 0.704729, valid_auc: 0.814865


loss - 0.6575: 100%|██████████| 438/438 [00:17<00:00, 25.19it/s]

Epoch 23, train_loss: 0.655382, train_acc: 0.717540, train_auc: 0.876742





Epoch 23, valid_loss: 0.655178, valid_acc: 0.704729, valid_auc: 0.808881


loss - 0.6553: 100%|██████████| 438/438 [00:17<00:00, 24.47it/s]


Epoch 24, train_loss: 0.655260, train_acc: 0.717897, train_auc: 0.879440




Epoch 24, valid_loss: 0.655178, valid_acc: 0.704729, valid_auc: 0.807091


loss - 0.6585: 100%|██████████| 438/438 [00:17<00:00, 25.21it/s]


Epoch 25, train_loss: 0.655171, train_acc: 0.717505, train_auc: 0.879618




Epoch 25, valid_loss: 0.655134, valid_acc: 0.704729, valid_auc: 0.805717


loss - 0.6459: 100%|██████████| 438/438 [00:17<00:00, 25.65it/s]


Epoch 26, train_loss: 0.655069, train_acc: 0.717540, train_auc: 0.882619




Epoch 26, valid_loss: 0.655218, valid_acc: 0.704729, valid_auc: 0.807155


loss - 0.6558: 100%|██████████| 438/438 [00:17<00:00, 25.30it/s]

Epoch 27, train_loss: 0.654970, train_acc: 0.718040, train_auc: 0.886221





Epoch 27, valid_loss: 0.655350, valid_acc: 0.704729, valid_auc: 0.804911


loss - 0.6567: 100%|██████████| 438/438 [00:17<00:00, 24.81it/s]


Epoch 28, train_loss: 0.654893, train_acc: 0.717755, train_auc: 0.887669




Epoch 28, valid_loss: 0.655356, valid_acc: 0.704729, valid_auc: 0.805576


loss - 0.6625: 100%|██████████| 438/438 [00:17<00:00, 25.07it/s]


Epoch 29, train_loss: 0.654807, train_acc: 0.717719, train_auc: 0.890259




Epoch 29, valid_loss: 0.655563, valid_acc: 0.704729, valid_auc: 0.806414


loss - 0.6551: 100%|██████████| 438/438 [00:17<00:00, 24.85it/s]


Epoch 30, train_loss: 0.654724, train_acc: 0.718147, train_auc: 0.893241




Epoch 30, valid_loss: 0.655556, valid_acc: 0.704729, valid_auc: 0.799886


In [None]:
class SAINT2Dataset(Dataset):
    def __init__(self, user_sequences, num_questions, subset='train', max_seq=100, min_seq=10):
        super(SAINT2Dataset, self).__init__()
        self.max_seq = max_seq
        self.num_questions = num_questions
        self.user_sequences = user_sequences
        self.subset = subset

        self.user_ids = []
        for user_id in user_sequences.index:
            q, _ = user_sequences[user_id]
            if len(q) < min_seq:        #10문제이하 제외 (나중에 없애기)
                continue
            self.user_ids.append(user_id)

    def __len__(self):
        return len(self.user_ids)

    def __getitem__(self, index):
        user_id = self.user_ids[index]

        q_, qa_ = self.user_sequences[user_id]
        seq_len = len(q_)

        q = np.zeros(self.max_seq, dtype=int)
        qa = np.zeros(self.max_seq, dtype=int)
        #diff = np.zeros(self.max_seq, dtype=int)


        # If there are more questions answered than max_seq, take the last max_seq sequences
        if seq_len >= self.max_seq:
            q[:] = q_[-self.max_seq:]
            qa[:] = qa_[-self.max_seq:]
            #diff[:] = diff_[-self.max_seq:]

        # If not, map our user_sequences to the tail end of q and qa, the start will be padded with zeros
        else:
            q[-seq_len:] = q_
            qa[-seq_len:] = qa_
            #diff[-seq_len:] = diff_


        r = np.zeros(self.max_seq, dtype=int)   #shifted qa
        r[1:] = qa[:-1].copy()

        return q, r, qa#, diff

In [None]:
class FFN(nn.Module):
    def __init__(self, dim=128):
        super().__init__()
        self.layer1 = nn.Linear(dim, dim)
        self.layer2 = nn.Linear(dim, dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.layer2(   self.relu(   self.layer1(x)))


def future_mask(seq_length):    #다음 시점 마스킹
    future_mask = np.triu(np.ones((seq_length, seq_length)), k=1).astype('bool')
    return torch.from_numpy(future_mask)


class Encoder(nn.Module):
    def __init__(self, n_in, seq_len=100, embed_dim=128, nheads=4):
        super().__init__()

        #n_in: 입력 총개수

        self.seq_len = seq_len

        #self.part_embed = nn.Embedding(10, embed_dim)

        self.e_embed = nn.Embedding(n_in, embed_dim)
        self.e_pos_embed = nn.Embedding(seq_len, embed_dim)
        self.e_norm = nn.LayerNorm(embed_dim)

        self.e_multi_att = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=nheads, dropout=0.2)
        self.m_norm = nn.LayerNorm(embed_dim)
        self.ffn = FFN(embed_dim)

    def forward(self, e, first_block=True):
        #e : q seq
        if first_block:
            e = self.e_embed(e)

        pos = torch.arange(self.seq_len).unsqueeze(0).to(device)
        e_pos = self.e_pos_embed(pos)
        e = e + e_pos
        e = self.e_norm(e)
        e = e.permute(1,0,2) #[bs, s_len, embed] => [s_len, bs, embed]
        n = e.shape[0]

        att_mask = future_mask(n).to(device)
        att_out, _ = self.e_multi_att(e, e, e, attn_mask=att_mask)
        m = e + att_out
        m = m.permute(1,0,2)

        o = m + self.ffn(self.m_norm(m))

        return o

class Decoder(nn.Module):
    def __init__(self, n_in, seq_len=100, embed_dim=128, nheads=4):
        super().__init__()
        self.seq_len = seq_len

        self.r_embed = nn.Embedding(n_in, embed_dim)    #r: 이전 시점 정답여부
        self.r_pos_embed = nn.Embedding(seq_len, embed_dim)
        self.r_norm = nn.LayerNorm(embed_dim)
        #self.diff_embed = nn.Embedding(NUM_DIFFS, embed_dim)


        self.r_multi_att1 = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=4, dropout=0.2)
        self.r_multi_att2 = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=4, dropout=0.2)
        self.ffn = FFN(embed_dim)

        self.r_norm1 = nn.LayerNorm(embed_dim)
        self.r_norm2 = nn.LayerNorm(embed_dim)
        self.r_norm3 = nn.LayerNorm(embed_dim)


    def forward(self, r, o,  first_block=True):

        if first_block:
            r = self.r_embed(r)
            #diff = self.diff_embed(diff)


            #r = r + diff

        pos = torch.arange(self.seq_len).unsqueeze(0).to(device)
        r_pos_embed = self.r_pos_embed(pos)
        r = r + r_pos_embed
        r = self.r_norm1(r)
        r = r.permute(1,0,2)
        n = r.shape[0]

        att_out1, _ = self.r_multi_att1(r, r, r, attn_mask=future_mask(n).to(device))
        m1 = r + att_out1

        o = o.permute(1,0,2)
        o = self.r_norm2(o)
        att_out2, _ = self.r_multi_att2(m1, o, o, attn_mask=future_mask(n).to(device))

        m2 = att_out2 + m1
        m2 = m2.permute(1,0,2)
        m2 = self.r_norm3(m2)

        l = m2 + self.ffn(m2)

        return l


def get_clones(module, N): #모듈 리스트
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

class SAINT2(nn.Module):
    def __init__(self, dim_model, num_en, num_de, heads_en, total_ex, total_in, heads_de, seq_len):
        super().__init__()

        self.num_en = num_en
        self.num_de = num_de

        self.encoder = get_clones( Encoder(n_in=total_ex, seq_len=seq_len, embed_dim=dim_model, nheads=heads_en) , num_en)
        self.decoder = get_clones( Decoder(n_in=total_in, seq_len=seq_len, embed_dim=dim_model, nheads=heads_de) , num_de)

        self.out = nn.Linear(in_features= dim_model , out_features=1)

    def forward(self, in_ex, in_in):

        ## pass through each of the encoder blocks in sequence
        first_block = True
        for x in range(self.num_en):
            if x>=1:
                first_block = False
            in_ex = self.encoder[x](in_ex, first_block=first_block)

        ## pass through each decoder blocks in sequence
        first_block = True
        for x in range(self.num_de):
            if x>=1:
                first_block = False
            in_in = self.decoder[x]( in_in , in_ex, first_block=first_block )

        ## Output layer
        in_in = torch.sigmoid( self.out( in_in ) )
        return in_in.squeeze(-1)


In [None]:
import pickle
with open('/content/drive/MyDrive/group_nm5.pkl', 'rb') as f:
    group = pickle.load(f)

In [None]:
valid_group = group.sample(frac=0.2)
train_group = group.drop(valid_group.index).reset_index(drop=True)
valid_group.reset_index(drop=True, inplace=True)
train_group.shape, valid_group.shape

((23238,), (5810,))

In [None]:
gc.collect()
NUM_QUESTIONS = 9985
MAX_SEQ = 100
BS = 64
NUM_DIFFS = 10
train_dataset = SAINT2Dataset(train_group, NUM_QUESTIONS, max_seq=MAX_SEQ)
train_dataloader = DataLoader(train_dataset, batch_size=BS, shuffle=True, num_workers=8)

valid_dataset = SAINT2Dataset(valid_group, NUM_QUESTIONS, max_seq=MAX_SEQ, subset='valid')
valid_dataloader = DataLoader(valid_dataset, batch_size=BS, shuffle=False, num_workers=8)



In [None]:
def train_epoch(model, train_iterator, optim, criterion, device="cpu"):
    model.train()

    train_loss = []
    num_corrects = 0
    num_total = 0
    labels = []
    outs = []

    tbar = tqdm(train_iterator)
    for item in tbar:
        e = item[0].to(device).long()
        r = item[1].to(device).long()
        label = item[2].to(device).float()
        #diff = item[3].to(device).long()


        # Zero the gradients in the optimizer
        optim.zero_grad()
        # The results of one forward pass
        output = model(e, r)
        # Calculate the loss
        loss = criterion(output, torch.sigmoid(label))
        # Calculate the gradients with respect to the loss
        loss.backward()
        # Adjust the parameters to minimize the loss based on these gradients
        optim.step()
        # Add our loss to the list of losses
        train_loss.append(loss.item())

        output = output[:, -1]
        label = label[:, -1]
        pred = (output >= 0.5).long()

        num_corrects += (pred == label).sum().item()
        num_total += len(label)

        labels.extend(label.view(-1).data.cpu().numpy())
        outs.extend(output.view(-1).data.cpu().numpy())

        tbar.set_description('loss - {:.4f}'.format(loss))

    acc = num_corrects / num_total
    auc = roc_auc_score(labels, outs)
    loss = np.mean(train_loss)

    return loss, acc, auc

In [None]:
def valid_epoch(model, valid_iterator, criterion, device="cpu"):
    model.eval()

    valid_loss = []
    num_corrects = 0
    num_total = 0
    labels = []
    outs = []

    #tbar = tqdm(valid_iterator)
    for item in valid_iterator: # tbar:
        e = item[0].to(device).long()
        r = item[1].to(device).long()
        label = item[2].to(device).float()
        #diff = item[3].to(device).long()


        with torch.no_grad():
            output = model(e, r)
        loss = criterion(output, torch.sigmoid(label))
        valid_loss.append(loss.item())

        output = output[:, -1] # (BS, 1)
        label = label[:, -1]
        pred = (output >= 0.5).long()

        num_corrects += (pred == label).sum().item()
        num_total += len(label)

        labels.extend(label.view(-1).data.cpu().numpy())
        outs.extend(output.view(-1).data.cpu().numpy())

    acc = num_corrects / num_total
    auc = roc_auc_score(labels, outs)
    loss = np.mean(valid_loss)

    return loss, acc, auc

In [None]:
import copy
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SAINT2(dim_model=128,
            num_en=2,
            num_de=2,
            heads_en=4,
            heads_de=4,
            total_ex=NUM_QUESTIONS,
            total_in=2,
            seq_len=100
            )

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.BCELoss()

model.to(device)
criterion.to(device)

BCELoss()

In [None]:
gc.collect()
epochs = 30
history = []
auc_max = -np.inf

for epoch in range(1, epochs+1):
    train_loss, train_acc, train_auc = train_epoch(model, train_dataloader, optimizer, criterion, device)
    print(f'Epoch {epoch}, train_loss: {train_loss:5f}, train_acc: {train_acc:5f}, train_auc: {train_auc:5f}')
    valid_loss, valid_acc, valid_auc = valid_epoch(model, valid_dataloader, criterion, device)
    print(f'Epoch {epoch}, valid_loss: {valid_loss:5f}, valid_acc: {valid_acc:5f}, valid_auc: {valid_auc:5f}')

    lr = optimizer.param_groups[0]['lr']
    history.append({"epoch":epoch, "lr": lr, **{"train_auc": train_auc, "train_acc": train_acc}, **{"valid_auc": valid_auc, "valid_acc": valid_acc}})
    if valid_auc > auc_max:
        print("Epoch#%s, valid loss %.4f, Metric loss improved from %.4f to %.4f, saving model ..." % (epoch, valid_loss, auc_max, valid_auc))
        auc_max = valid_auc

loss - 0.6575: 100%|██████████| 361/361 [00:13<00:00, 26.89it/s]

Epoch 1, train_loss: 0.660846, train_acc: 0.715752, train_auc: 0.709335





Epoch 1, valid_loss: 0.658191, valid_acc: 0.718264, valid_auc: 0.735711
Epoch#1, valid loss 0.6582, Metric loss improved from -inf to 0.7357, saving model ...


loss - 0.6588: 100%|██████████| 361/361 [00:13<00:00, 26.83it/s]

Epoch 2, train_loss: 0.658537, train_acc: 0.716705, train_auc: 0.744364





Epoch 2, valid_loss: 0.658051, valid_acc: 0.717918, valid_auc: 0.740152
Epoch#2, valid loss 0.6581, Metric loss improved from 0.7357 to 0.7402, saving model ...


loss - 0.6558: 100%|██████████| 361/361 [00:13<00:00, 26.76it/s]


Epoch 3, train_loss: 0.658452, train_acc: 0.717138, train_auc: 0.746978




Epoch 3, valid_loss: 0.658027, valid_acc: 0.718609, valid_auc: 0.741820
Epoch#3, valid loss 0.6580, Metric loss improved from 0.7402 to 0.7418, saving model ...


loss - 0.6589: 100%|██████████| 361/361 [00:13<00:00, 26.74it/s]


Epoch 4, train_loss: 0.658324, train_acc: 0.717614, train_auc: 0.755173




Epoch 4, valid_loss: 0.657708, valid_acc: 0.721031, valid_auc: 0.781726
Epoch#4, valid loss 0.6577, Metric loss improved from 0.7418 to 0.7817, saving model ...


loss - 0.6587: 100%|██████████| 361/361 [00:13<00:00, 26.49it/s]


Epoch 5, train_loss: 0.657897, train_acc: 0.721512, train_auc: 0.795426




Epoch 5, valid_loss: 0.657189, valid_acc: 0.727430, valid_auc: 0.814843
Epoch#5, valid loss 0.6572, Metric loss improved from 0.7817 to 0.8148, saving model ...


loss - 0.6567: 100%|██████████| 361/361 [00:13<00:00, 26.56it/s]


Epoch 6, train_loss: 0.657401, train_acc: 0.720906, train_auc: 0.819355




Epoch 6, valid_loss: 0.656911, valid_acc: 0.718437, valid_auc: 0.825753
Epoch#6, valid loss 0.6569, Metric loss improved from 0.8148 to 0.8258, saving model ...


loss - 0.6548: 100%|██████████| 361/361 [00:13<00:00, 26.40it/s]


Epoch 7, train_loss: 0.657162, train_acc: 0.718177, train_auc: 0.828699




Epoch 7, valid_loss: 0.656832, valid_acc: 0.717918, valid_auc: 0.827948
Epoch#7, valid loss 0.6568, Metric loss improved from 0.8258 to 0.8279, saving model ...


loss - 0.6605: 100%|██████████| 361/361 [00:13<00:00, 25.84it/s]

Epoch 8, train_loss: 0.657020, train_acc: 0.717311, train_auc: 0.832608





Epoch 8, valid_loss: 0.656781, valid_acc: 0.718437, valid_auc: 0.828014
Epoch#8, valid loss 0.6568, Metric loss improved from 0.8279 to 0.8280, saving model ...


loss - 0.6561: 100%|██████████| 361/361 [00:13<00:00, 26.32it/s]


Epoch 9, train_loss: 0.656925, train_acc: 0.717961, train_auc: 0.836333




Epoch 9, valid_loss: 0.656810, valid_acc: 0.717918, valid_auc: 0.826090


loss - 0.6583: 100%|██████████| 361/361 [00:13<00:00, 26.06it/s]

Epoch 10, train_loss: 0.656831, train_acc: 0.717095, train_auc: 0.841416





Epoch 10, valid_loss: 0.656824, valid_acc: 0.718955, valid_auc: 0.826297


loss - 0.6545: 100%|██████████| 361/361 [00:13<00:00, 25.94it/s]

Epoch 11, train_loss: 0.656732, train_acc: 0.717355, train_auc: 0.844239





Epoch 11, valid_loss: 0.656839, valid_acc: 0.717918, valid_auc: 0.823113


loss - 0.6623: 100%|██████████| 361/361 [00:13<00:00, 26.40it/s]

Epoch 12, train_loss: 0.656646, train_acc: 0.717311, train_auc: 0.846466





Epoch 12, valid_loss: 0.656975, valid_acc: 0.717918, valid_auc: 0.823582


loss - 0.6595: 100%|██████████| 361/361 [00:13<00:00, 26.34it/s]


Epoch 13, train_loss: 0.656533, train_acc: 0.716705, train_auc: 0.849248
Epoch 13, valid_loss: 0.657024, valid_acc: 0.719301, valid_auc: 0.820576


loss - 0.6559: 100%|██████████| 361/361 [00:13<00:00, 26.41it/s]

Epoch 14, train_loss: 0.656411, train_acc: 0.717528, train_auc: 0.852764





Epoch 14, valid_loss: 0.656950, valid_acc: 0.718782, valid_auc: 0.821983


loss - 0.6626: 100%|██████████| 361/361 [00:13<00:00, 26.09it/s]


Epoch 15, train_loss: 0.656289, train_acc: 0.716792, train_auc: 0.854818




Epoch 15, valid_loss: 0.657116, valid_acc: 0.718437, valid_auc: 0.817957


loss - 0.6555: 100%|██████████| 361/361 [00:13<00:00, 26.35it/s]


Epoch 16, train_loss: 0.656164, train_acc: 0.717008, train_auc: 0.858775




Epoch 16, valid_loss: 0.657110, valid_acc: 0.718091, valid_auc: 0.816707


loss - 0.6611: 100%|██████████| 361/361 [00:13<00:00, 25.88it/s]


Epoch 17, train_loss: 0.656044, train_acc: 0.717095, train_auc: 0.863186




Epoch 17, valid_loss: 0.657300, valid_acc: 0.719647, valid_auc: 0.812082


loss - 0.6542: 100%|██████████| 361/361 [00:18<00:00, 20.02it/s]

Epoch 18, train_loss: 0.655907, train_acc: 0.717311, train_auc: 0.865654





Epoch 18, valid_loss: 0.657465, valid_acc: 0.718437, valid_auc: 0.809116


loss - 0.6566:  24%|██▍       | 86/361 [00:03<00:12, 22.21it/s]


KeyboardInterrupt: ignored