In [1]:
from datasets import load_dataset #type: ignore
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#code for the model, the architecture of the model is defined here
#input features: q_type, wpm, idle, latency, mouse_movement
#hidden layer 1: 10 nodes
#hidden layer 2: 10 nodes
#hidden layer 3: 5 nodes
#output logits layer: 3 nodes, where classes are low risk, mild risk, high risk

In [6]:
g = torch.Generator().manual_seed(2147)   #setting the seed for reproducibility  

In [16]:
#input will be a tensor of shape (batch_size, 5): X
X = torch.randn((80, 5) , generator=g) #random input tensor, for now
y = torch.randint(0, 3, (80,)) #random target tensor, for now
#output will be a tensor of shape (batch_size, 3): y

In [26]:
print(y)

tensor([1, 2, 0, 0, 2, 1, 2, 1, 1, 0, 2, 1, 0, 2, 2, 0, 0, 1, 2, 2, 2, 2, 1, 0,
        0, 2, 1, 0, 1, 2, 0, 0, 0, 1, 2, 0, 1, 2, 0, 2, 1, 1, 2, 0, 2, 1, 1, 1,
        0, 2, 0, 1, 1, 2, 2, 0, 0, 0, 2, 1, 1, 0, 1, 2, 1, 2, 0, 2, 0, 1, 1, 2,
        2, 0, 2, 1, 0, 1, 2, 2])


In [None]:
#random initialization of weights and biases
W1 = torch.randn((5, 10), generator=g, requires_grad=True)
b1 = torch.randn((10, ), generator=g, requires_grad=True)

W2 = torch.randn((10, 10), generator=g, requires_grad=True)
b2 = torch.randn((10, ), generator=g, requires_grad=True)

W3 = torch.randn((10, 5), generator=g, requires_grad=True)
b3 = torch.randn((5, ), generator=g, requires_grad=True)

W4 = torch.randn((5, 3), generator=g, requires_grad=True)
b4 = torch.randn((3, ), generator=g, requires_grad=True)


In [8]:
parameters = [W1, b1, W2, b2, W3, b3]

In [9]:
#total num of params for bookkeeping
sum(p.nelement() for p in parameters) 

225

In [19]:
def train(iters, alpha=0.01):
    """
    function to train the model, uses stochastic gradient descent
    inputs: iters - number of iterations to train the model
    """

    for step in range(iters):
        #forward pass
        l1 = X@W1 + b1  #shape: (batch_size, 10)
        a1 = F.relu(l1) #shape: (batch_size, 10)

        l2 = a1@W2 + b2  #shape: (batch_size, 10)
        a2 = F.relu(l2)

        l3 = a2@W3 + b3  #shape: (batch_size, 5)
        a3 = F.relu(l3)

        logits = a3@W4 + b4  #shape: (batch_size, 3)
        neg_log_loss = F.cross_entropy(logits, y)

        print(f"loss at iter {step+1}: {neg_log_loss.item()}")

        #backward pass
        for p in parameters:
            p.grad = None  #resetting the gradients, to avoid gradient accumulation
        neg_log_loss.backward()

        #update params
        for p in parameters:
            p.data -= alpha*p.grad  #SGD update        

In [50]:
train(10)

loss at iter 1: 1.3485989570617676
loss at iter 2: 1.3458372354507446
loss at iter 3: 1.3430922031402588
loss at iter 4: 1.340301752090454
loss at iter 5: 1.3374783992767334
loss at iter 6: 1.3344026803970337
loss at iter 7: 1.3313614130020142
loss at iter 8: 1.3283437490463257
loss at iter 9: 1.3253614902496338
loss at iter 10: 1.3223849534988403


In [51]:
@torch.no_grad()
def out_logits(X):
        l1 = X@W1 + b1
        a1 = F.relu(l1)

        l2 = a1@W2 + b2  
        a2 = F.relu(l2)

        l3 = a2@W3 + b3 
        a3 = F.relu(l3)

        logits = a3@W4 + b4  
        return logits

In [52]:
logits_1 = out_logits(X)
probs = F.softmax(logits_1, dim=1)  #final probabilities of the classes

max_probs, max_indices = torch.max(probs, dim=1)  #get the max prob class
print(max_indices)

tensor([1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
        1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1])


In [43]:
print(y)

tensor([1, 2, 0, 0, 2, 1, 2, 1, 1, 0, 2, 1, 0, 2, 2, 0, 0, 1, 2, 2, 2, 2, 1, 0,
        0, 2, 1, 0, 1, 2, 0, 0, 0, 1, 2, 0, 1, 2, 0, 2, 1, 1, 2, 0, 2, 1, 1, 1,
        0, 2, 0, 1, 1, 2, 2, 0, 0, 0, 2, 1, 1, 0, 1, 2, 1, 2, 0, 2, 0, 1, 1, 2,
        2, 0, 2, 1, 0, 1, 2, 2])
