# Imports

In [1]:
import os
import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader, random_split
from typing import List, Dict, Tuple

# Macros

In [2]:
DEVICE = 'cuda'
DATASET_PATH = 'dataset/gpt_ahh/dataset.pt'

# Dataset

In [3]:
class ModSeekDataset(Dataset):
    def __init__(self, qs: pd.DataFrame, sols = pd.DataFrame):
        self.qs_cols = qs.columns
        self.sols_cols = sols.columns
        
        self.num_qs = len(self.qs_cols)
        self.num_sols = len(self.sols_cols)
        
        self.qs_inst_tensor = torch.tensor(qs.to_numpy())
        self.sols_inst_tesor = torch.tensor(sols.to_numpy())
        
    def __getitem__(self, index):
        return self.qs_inst_tensor[index], self.sols_inst_tesor[index]
    
    def __len__(self):
        return len(self.qs_inst_tensor)

In [4]:
dataset: ModSeekDataset = torch.load(DATASET_PATH)
DATASET_SIZE = len(dataset)
TRAIN_RATIO = 0.9
TEST_RATIO = 0.1

train_len = int(DATASET_SIZE * TRAIN_RATIO)
test_len = DATASET_SIZE - train_len

torch.manual_seed(196)
train_dataset, test_dataset = random_split(dataset, (train_len, test_len))
train_loader = DataLoader(train_dataset, shuffle=True)
test_loader = DataLoader(test_dataset)

print(len(train_loader))
print(len(test_loader))

90
10


  dataset: ModSeekDataset = torch.load(DATASET_PATH)


# Model

In [5]:
class ModSeek(torch.nn.Module):
    def __init__(self, num_qs: int, num_sols: int):
        super().__init__()
        
        self.num_qs = num_qs
        self.num_sols = num_sols
        self.W_yes = torch.nn.Parameter(torch.randn(num_qs, num_sols))
        self.W_no = torch.nn.Parameter(torch.randn(num_qs, num_sols))
    
    def forward(self, answered_qs):
        init_prob = 1 / self.num_sols
        
        p = torch.full((self.num_sols,), init_prob)
        z = torch.log(p)
        for i, answered in enumerate(answered_qs):
            if answered == 1:
                z = z + self.W_yes[i]
            else:
                z = z + self.W_no[i]
        
        return z

In [6]:
mod_seek = ModSeek(dataset.num_qs, dataset.num_sols)
optimizer = torch.optim.Adam(mod_seek.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()
num_epochs = 1000

In [9]:
mod_seek.train()
for epoch in range(num_epochs):
    total_loss = 0.0
    for answered_qs, sols in train_loader:
        optimizer.zero_grad()
        probs = torch.sigmoid(mod_seek(answered_qs.reshape(-1)))
        loss = criterion(probs, sols)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    print(f"Epoch {epoch}: loss = {total_loss / len(train_loader)}")


  return F.mse_loss(input, target, reduction=self.reduction)


Epoch 0: loss = 0.2137380445169078
Epoch 1: loss = 0.19496541544795037
Epoch 2: loss = 0.1787091467115614
Epoch 3: loss = 0.16385896876454353
Epoch 4: loss = 0.14939884444077808
Epoch 5: loss = 0.13467568527493212
Epoch 6: loss = 0.11929753377205796
Epoch 7: loss = 0.1030414113154014
Epoch 8: loss = 0.09105011297182905
Epoch 9: loss = 0.08272532114966048
Epoch 10: loss = 0.07586129429853625
Epoch 11: loss = 0.07013700993524657
Epoch 12: loss = 0.06524649070989755
Epoch 13: loss = 0.06089902946518527
Epoch 14: loss = 0.056853337265137166
Epoch 15: loss = 0.05282041314575407
Epoch 16: loss = 0.04858647585949964
Epoch 17: loss = 0.044395491542915506
Epoch 18: loss = 0.038755148607823586
Epoch 19: loss = 0.03350563529464934
Epoch 20: loss = 0.030136385218550762
Epoch 21: loss = 0.027875844627204868
Epoch 22: loss = 0.026152350412060817
Epoch 23: loss = 0.024712773614252607
Epoch 24: loss = 0.023523932938567468
Epoch 25: loss = 0.02238569227564666
Epoch 26: loss = 0.02080211976232628
Epoch 

# Saving & Loading Model

Save

In [15]:
torch.save(mod_seek, 'mod_seek.pt')

Load

In [17]:
mod_seek = ModSeek(dataset.num_qs, dataset.num_sols)
mod_seek = torch.load('mod_seek.pt')
mod_seek.eval()

  mod_seek = torch.load('mod_seek.pt')


ModSeek()

# Testing

In [14]:
TP = 0
FP = 0
FN = 0
TN = 0

with torch.no_grad():
    for answered_qs, sols in test_loader:
        probs = torch.sigmoid(mod_seek(answered_qs.reshape(-1)))
        sols_flattened = sols.reshape(-1)
        for i, prob in enumerate(probs):
            if prob >= 0.5 and sols_flattened[i] == 1:
                TP += 1
            elif prob >= 0.5 and sols_flattened[i] != 1:
                FP += 1
            elif prob < 0.5 and sols_flattened[i] == 1:
                FN += 1
            elif prob < 0.5 and sols_flattened[i] != 1:
                TN += 1

precision = TP / (TP + FP)
recall = TP / (TP + FN)

print(f'Precision: {precision}')
print(f'Recall: {recall}')

Precision: 1.0
Recall: 1.0
