In [100]:
import pickle
import torch.nn as nn
import torch.optim as optim
import torch
import numpy as np

from sklearn.metrics import precision_score, recall_score

In [2]:
with open("data.pkl",'rb') as fp:
    data = pickle.load(fp)

In [51]:
skills = []
subtests = []
questions = []
answers = []
y = []
for entry in data:
    skills.append(entry[0])
    subtests.append(entry[1])
    questions.append(entry[2])
    answers.append(entry[3])
    y.append(entry[4])

skills = torch.tensor(skills).type(torch.float)
subtests = torch.tensor(subtests)
questions = torch.tensor(questions)
answers = torch.tensor(answers)
y = torch.tensor(y).type(torch.float)


In [57]:
class BertModel(nn.Module):
    def __init__(self, sentence_dim, skill_dim, dropout):
        super().__init__()
        self.fc_test = nn.Linear(768,sentence_dim)
        self.fc_question = nn.Linear(768,sentence_dim)
        self.fc_answer = nn.Linear(768,sentence_dim)
        self.fc_skill = nn.Linear(skill_dim,skill_dim*2)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(3*sentence_dim+skill_dim*2,128)
        self.out = nn.Linear(128,1)
        self.sig = nn.Sigmoid()
    
    def forward(self, skills,test,question,answer):
        x1 = self.fc_skill(skills)
        x2 = self.fc_test(test)
        x3 = self.fc_question(question)
        x4 = self.fc_answer(answer)
        x = torch.cat((x1,x2,x3,x4),dim=1)
        x = self.fc2(self.relu(x))
        pred = self.sig(self.out(x))
        
        return pred

In [13]:
def binary_accuracy(preds, y):
    """
    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8
    """

    #round predictions to the closest integer
    rounded_preds = torch.round(torch.sigmoid(preds))
    correct = (rounded_preds == y).float() #convert into float for division 
    acc = correct.sum() / len(correct)
    return acc

In [97]:

def precision(preds,y):
    rounded_preds = torch.round(torch.sigmoid(preds))
    return precision_score(y,rounded_preds)

def recall(preds,y):
    rounded_preds = torch.round(torch.sigmoid(preds))
    return recall_score(y,rounded_preds)
    

In [95]:
# 3. train model
max_epochs = 1
ep_log_interval = 25
lrn_rate = 0.002
sentence_dim = 128
skill_dim = 9
dropout = 0.1

net = BertModel(sentence_dim, skill_dim, dropout)

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(net.parameters(), lr=lrn_rate)

In [83]:
def train(model, data, optimizer, criterion):
    
    epoch_loss = 0
    epoch_acc = 0
    
    model.train()
    
        
    optimizer.zero_grad()
    
    predictions = model(data[0],data[1],data[2],data[3]).squeeze(1)
    
#     print(predictions)

    loss = criterion(predictions, data[4])
    

    acc = binary_accuracy(predictions, data[4])
    

    loss.backward()

    optimizer.step()

#     epoch_loss += loss.item()
#     epoch_acc += acc.item()
#     print(loss)
#     print(acc)
    return loss,acc 


In [102]:
def evaluate(model, data, criterion):
    
    model.eval()
    
    with torch.no_grad():

        predictions = model(data[0],data[1],data[2],data[3]).squeeze(1)
        
        print(predictions)

        loss = criterion(predictions, data[4])

        acc = binary_accuracy(predictions, data[4])
        
        print('precision: ', precision(predictions, data[4]))
        print('recall: ', recall(predictions, data[4]))

        
    return loss, acc

In [103]:
for epoch in range(max_epochs):
    train_loss, train_acc = train(net,[skills[0:4000],subtests[0:4000],questions[0:4000],answers[0:4000],y[0:4000]],optimizer,criterion)
    valid_loss, valid_acc = evaluate(net,[skills[4000:],subtests[4000:],questions[4000:],answers[4000:],y[4000:]],criterion)
    if epoch % ep_log_interval == 0:
        print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
        print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

tensor([0.0090, 0.0082, 0.0183,  ..., 0.0099, 0.0101, 0.0093])
precision:  0.6587389380530974
recall:  1.0
	Train Loss: 0.692 | Train Acc: 51.25%
	 Val. Loss: 0.690 |  Val. Acc: 65.87%
