In [1]:
# Data packages
import pandas as pd 
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support, average_precision_score
from sklearn.model_selection import KFold


from model.rnn import GRUDecoder
from model.encoders import CustomExcelFormer
from data_processing.pipeline import encoding_pipeline, get_generic_name


import matplotlib.pyplot as plt
from model.utils import get_device
from model.dataset import PretrainingDataset
from model.dataset import FinetuningDataset
import pickle



In [2]:
# Torch device used by all later cells (get_device is project-local; the cell
# output below shows it selected MPS on the machine this was last run on).
device = get_device()

Using MPS (Metal Performance Shaders) device


In [3]:
class DataClass:
    """Load the PreFer data and build pretraining / fine-tuning datasets.

    Call ``make_sequences`` first; ``make_pretraining``, ``make_finetuning`` and
    ``make_finetuning_cv`` all read ``self.sequences``.
    """

    # Subtracted from each wave key to obtain the year index fed to the model.
    # NOTE(review): presumably the first survey year — confirm against the data.
    BASE_YEAR = 2007

    def __init__(self,
                 data_path: str = "data/training_data/PreFer_train_data.csv",
                 targets_path: str = 'data/training_data/PreFer_train_outcome.csv',
                 codebook_path: str = 'data/codebooks/PreFer_codebook.csv',
                 importance_path: str = 'features_importance_all.csv') -> None:
        """Read the raw data, the outcomes, the codebook and the feature-importance CSVs."""
        self.data = pd.read_csv(data_path, low_memory=False)
        self.targets = pd.read_csv(targets_path)
        self.codebook = pd.read_csv(codebook_path)
        self.col_importance = pd.read_csv(importance_path)

    def make_sequences(self, n_cols: int, use_codebook: bool = True):
        """Encode the data into per-person sequences using the first ``n_cols`` questions.

        Questions are taken in the row order of ``self.col_importance.feature``,
        mapped through ``get_generic_name`` and de-duplicated. The result of
        ``encoding_pipeline`` is stored in ``self.sequences``.
        """
        custom_pairs = self.col_importance.feature.map(get_generic_name).unique()[:n_cols]
        self.sequences = encoding_pipeline(self.data, self.codebook,
                                           custom_pairs=custom_pairs,
                                           importance=self.col_importance,
                                           use_codebook=use_codebook)

    def make_pretraining(self):
        """Build the pretraining dataset and record its sequence length and vocabulary size."""
        self.pretrain_dataset = PretrainingDataset(self.sequences)
        self.seq_len = self.pretrain_dataset.get_seq_len()
        self.vocab_size = self.pretrain_dataset.get_vocab_size()

    def make_finetuning(self, batch_size: int, test_size: float = 0.2, val_size: float = 0.2):
        """Random train / validation / test split of the labelled people.

        ``test_size`` is a fraction of all labelled people; ``val_size`` is a
        fraction of the remaining (non-test) people. Both splits use
        ``random_state=42``. Sets ``*_dataset`` and ``*_dataloader`` attributes.

        Raises:
            RuntimeError: if ``make_sequences`` has not been called yet.
        """
        self._require_sequences()
        targets = self._labelled_targets()
        train_ids, test_ids = train_test_split(targets['nomem_encr'], test_size=test_size, random_state=42)
        train_ids, val_ids = train_test_split(train_ids, test_size=val_size, random_state=42)
        self._build_finetuning_loaders(batch_size, targets, train_ids, val_ids, test_ids)

    def make_finetuning_cv(self, batch_size: int, split_id: int, n_splits: int = 5, test_size: float = 0.2):
        """Simple deterministic K-fold cross-validation split.

        A fixed test set (``random_state=42``) is held out first. The remaining
        people are assigned round-robin in the (shuffled) order returned by
        ``train_test_split``: position ``i`` is in the validation fold when
        ``i % n_splits == split_id``, otherwise in training.

        Raises:
            ValueError: if ``split_id`` is not in ``[0, n_splits)``.
            RuntimeError: if ``make_sequences`` has not been called yet.
        """
        # Explicit check instead of ``assert`` so it also holds under ``python -O``.
        if not 0 <= split_id < n_splits:
            raise ValueError(f"split_id must be in [0, {n_splits}), got {split_id}")
        self._require_sequences()
        targets = self._labelled_targets()
        train_ids, test_ids = train_test_split(targets['nomem_encr'], test_size=test_size, random_state=42)
        train_ids, val_ids = self._kfold_partition(train_ids, split_id, n_splits)
        self._build_finetuning_loaders(batch_size, targets, train_ids, val_ids, test_ids)

    # ---- private helpers -------------------------------------------------

    def _require_sequences(self):
        """Fail early with a clear message when ``make_sequences`` was not run."""
        if not hasattr(self, "sequences"):
            raise RuntimeError("call make_sequences() before building fine-tuning datasets")

    def _labelled_targets(self):
        """Rows of ``self.targets`` whose ``new_child`` outcome is present."""
        return self.targets[self.targets.new_child.notna()]

    @staticmethod
    def _kfold_partition(ids, split_id, n_splits):
        """Split ``ids`` round-robin: every ``n_splits``-th item (offset ``split_id``) is validation.

        Returns:
            (train_ids, val_ids) as lists, each preserving the input order.
        """
        train_ids, val_ids = [], []
        for position, person_id in enumerate(ids):
            (val_ids if position % n_splits == split_id else train_ids).append(person_id)
        return train_ids, val_ids

    @classmethod
    def _to_tensors(cls, wave_responses, target_device):
        """Convert one person's ``{year: response}`` mapping to (year-index, response) tensors."""
        pairs = list(wave_responses.items())
        years = torch.tensor([year - cls.BASE_YEAR for year, _ in pairs])
        responses = torch.tensor([response for _, response in pairs])
        return years.to(target_device), responses.to(target_device)

    def _build_finetuning_loaders(self, batch_size, targets, train_ids, val_ids, test_ids):
        """Create the fine-tuning datasets and dataloaders for the given id splits."""
        def subset(person_ids):
            # Only convert the people actually used in this split.
            return {pid: self._to_tensors(self.sequences[pid], device) for pid in person_ids}

        self.train_dataset = FinetuningDataset(subset(train_ids), targets=targets)
        self.val_dataset = FinetuningDataset(subset(val_ids), targets=targets)
        self.test_dataset = FinetuningDataset(subset(test_ids), targets=targets)
        self.train_dataloader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)
        self.val_dataloader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False)
        self.test_dataloader = DataLoader(self.test_dataset, batch_size=batch_size, shuffle=False)

# What is the effect of increasing the number of questions?

## Pretraining
I pretrain on all the data. Currently, I only use the attention-based autoencoder, as it seems to train the fastest.

## Finetuning
We perform 5-fold cross-validation for the fine-tuning (FT).

### Pretraining

In [4]:
# Hyper-parameters for pretraining / fine-tuning
BATCH_SIZE = 8
HIDDEN_SIZE = 64
ENCODING_SIZE = 64
NUM_HEADS = 8
NUM_LAYERS = 3
NUM_EPOCHS = 10
LEARNING_RATE = 1e-3
DETECT_ANOMALY = False  # set True to let autograd report the op that produced NaN/inf
SEED = 42
assert HIDDEN_SIZE % NUM_HEADS == 0, "Check that the hidden size is divisible"

# Reproducibility: seed the RNGs used by model init and DataLoader shuffling.
torch.manual_seed(SEED)
np.random.seed(SEED)
# Apply the debugging flag (it was previously defined but never used).
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)

# Number of most important questions encoded in each saturation run.
n_questions = [25, 50, 100, 150, 250, 500, 1000, 5000]
MODEL_PATH = "saturation_test_ENC_nquestions"

In [5]:
class PreFerPredictor(nn.Module):
    """CustomExcelFormer encoder followed by a GRUDecoder.

    Each person contributes ``n_years`` rows of tokens. The rows of a batch are
    flattened to ``(batch * n_years, -1)`` and passed to the encoder. The encodings
    are reshaped back to ``(batch, n_years, -1)`` and passed to the decoder, together
    with a mask that drops empty years. ``forward`` returns the flattened decoder
    output (used as logits with ``BCEWithLogitsLoss`` in the training cell).
    """

    # A year is treated as empty when the count of this token in its row equals
    # ``seq_len``. NOTE(review): presumably the padding / missing-answer token
    # produced by encoding_pipeline — confirm.
    EMPTY_TOKEN = 101

    def __init__(self, vocab_size: int, seq_len: int, *args, n_years: int = 14, **kwargs) -> None:
        """
        Args:
            vocab_size: token vocabulary size (from the pretraining dataset).
            seq_len: tokens per year row (from the pretraining dataset).
            n_years: yearly time steps per person. Used as the encoder's ``n_years``,
                the decoder's ``max_seq_len`` and the reshape in ``forward``
                (previously a hard-coded 14 in all three places).
        """
        super().__init__(*args, **kwargs)
        self.seq_len = seq_len
        self.n_years = n_years
        self.encoder = CustomExcelFormer(vocab_size=vocab_size,
                                         hidden_size=HIDDEN_SIZE,
                                         out_size=ENCODING_SIZE,
                                         n_years=n_years,
                                         num_heads=NUM_HEADS,
                                         num_layers=NUM_LAYERS,
                                         num_classes=2,
                                         sequence_len=seq_len,
                                         aium_dropout=0.3,
                                         diam_dropout=0.2,
                                         residual_dropout=0.2,
                                         embedding_dropout=0.3,
                                         mixup=None,
                                         beta=0.2)
        self.decoder = GRUDecoder(
            input_size=ENCODING_SIZE,
            hidden_size=HIDDEN_SIZE,
            num_layers=3,
            max_seq_len=n_years,
            dropout=0.3,
            bidirectional=True,
            with_attention=True,
        )

    def _param_device(self, fallback):
        """Device holding the model's parameters, or ``fallback`` if the model has none."""
        first_param = next(self.parameters(), None)
        return fallback if first_param is None else first_param.device

    def forward(self, input_year, input_seq, labels):
        """Encode every (person, year) row and decode one output per person.

        Args:
            input_year: year indices; flattened to ``batch * n_years`` values.
            input_seq: token ids; reshaped to ``(batch * n_years, -1)``.
            labels: only its first dimension (the batch size) is used.

        Returns:
            The decoder output, flattened.
        """
        batch_size, n_years = labels.size(0), self.n_years
        # Follow the model's own device instead of a notebook-global one.
        target_device = self._param_device(fallback=input_seq.device)
        input_year = input_year.reshape(-1).to(target_device)
        input_seq = input_seq.reshape(batch_size * n_years, -1).to(target_device)

        encodings, _ = self.encoder(input_year, input_seq)
        encodings = encodings.view(batch_size, n_years, -1)

        # True for years whose row consists of seq_len EMPTY_TOKENs; the decoder sees ~that.
        empty_year = (input_seq == self.EMPTY_TOKEN).sum(-1) == self.seq_len
        mask = ~empty_year.view(batch_size, n_years)

        return self.decoder(encodings, mask=mask).flatten()

In [6]:
# Saturation experiment: for each number of questions, re-encode the data and run
# NUM_FOLDS-fold cross-validated fine-tuning, collecting validation metrics.
NUM_FOLDS = 5
POS_WEIGHT = 1 / 0.5       # positive-class weight for BCEWithLogitsLoss (= 2.0)
DECISION_THRESHOLD = 0.5   # probability cut-off for precision / recall / F1

all_train_loss = []   # for plotting: per-fold list of mean training loss per epoch
n_cols_list = []      # number of questions for each entry of all_train_loss
metric_per_run = {}


def train_fold(model, dataloader, loss_fn, optimizer, scheduler, num_epochs):
    """Train ``model`` for ``num_epochs`` epochs and return the mean loss of each epoch.

    ``scheduler.step()`` is called once at the end of every epoch.
    """
    model.train()
    loss_per_epoch = []
    for epoch in range(num_epochs):
        loss_per_step = []
        progress = tqdm(dataloader, desc=f"Epochs {epoch}")
        for (input_year, input_seq), labels in progress:
            optimizer.zero_grad()
            labels = labels.to(torch.float).to(device)
            output = model(input_year=input_year, input_seq=input_seq, labels=labels)
            loss = loss_fn(output, labels)
            loss_per_step.append(loss.item())
            progress.set_postfix_str("mean loss: %.3f" % np.mean(loss_per_step[-100:]))
            loss.backward()
            optimizer.step()
        scheduler.step()
        loss_per_epoch.append(np.mean(loss_per_step))
    return loss_per_epoch


def evaluate(model, dataloader, loss_fn):
    """Score ``dataloader`` in eval mode without building autograd graphs.

    Returns:
        (probabilities, true labels, mean loss). The first two are flat numpy arrays;
        the mean loss is NaN for an empty dataloader. The model's previous
        train/eval mode is restored afterwards.
    """
    was_training = model.training
    model.eval()
    losses, probs, labels_seen = [], [], []
    with torch.no_grad():  # evaluation needs no gradients; saves memory and time
        for (input_year, input_seq), labels in dataloader:
            labels = labels.to(torch.float).to(device)
            output = model(input_year=input_year, input_seq=input_seq, labels=labels)
            losses.append(loss_fn(output, labels).item())
            probs.extend(torch.sigmoid(output).flatten().cpu().tolist())
            labels_seen.extend(labels.cpu().tolist())
    model.train(was_training)
    mean_loss = float(np.mean(losses)) if losses else float("nan")
    return np.asarray(probs).ravel(), np.asarray(labels_seen).ravel(), mean_loss


data = DataClass()
for n_quest in n_questions:
    model_name = MODEL_PATH + '-' + str(n_quest)
    data.make_sequences(n_cols=n_quest)
    data.make_pretraining()
    # Depend only on the encoded sequences, so they are fixed across folds.
    SEQ_LEN = data.seq_len
    VOCAB_SIZE = data.vocab_size
    metric_per_run[n_quest] = {
        "f1": list(),
        "mAP": list(),
        "precision": list(),
        "recall": list()
    }
    run_metrics = metric_per_run[n_quest]

    for fold_id in range(NUM_FOLDS):
        data.make_finetuning_cv(batch_size=BATCH_SIZE, split_id=fold_id, n_splits=NUM_FOLDS)

        model = PreFerPredictor(vocab_size=VOCAB_SIZE, seq_len=SEQ_LEN).to(device)
        loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([POS_WEIGHT], device=device))
        optimizer = torch.optim.RAdam(model.parameters(), lr=LEARNING_RATE,
                                      weight_decay=1e-2, decoupled_weight_decay=True)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS,
                                                         eta_min=1e-5, last_epoch=-1)

        loss_per_epoch = train_fold(model, data.train_dataloader, loss_fn,
                                    optimizer, scheduler, NUM_EPOCHS)
        all_train_loss.append(loss_per_epoch)
        n_cols_list.append(n_quest)

        #### Validation
        yhat, ytrue, mean_val_loss = evaluate(model, data.val_dataloader, loss_fn)
        precision, recall, f1, _ = precision_recall_fscore_support(
            ytrue, yhat > DECISION_THRESHOLD, average='binary')
        map_roc = average_precision_score(ytrue, yhat)
        run_metrics["f1"].append(f1)
        run_metrics["precision"].append(precision)
        run_metrics["recall"].append(recall)
        run_metrics["mAP"].append(map_roc)
        print(run_metrics)
        _f1 = np.median(run_metrics["f1"])
        _map_roc = np.median(run_metrics["mAP"])
        print(f"-- {n_quest} fold {fold_id}: val loss {mean_val_loss:.4f} "
              f"-- median mAP: {_map_roc:.4f} -- median f1-score: {_f1:.3f}")

        # Rewritten after every fold so completed folds survive an interrupted run.
        with open("metric_%s.pkl" % n_quest, "wb") as f:
            pickle.dump(metric_per_run, f)


KeyboardInterrupt: 