::: {.callout-caution}
## Availability

This section will be available Tuesday, 27 June 2023.
:::

In [None]:
from abc import ABC, abstractmethod
from numpy import round, mean, inf
from numpy import random
from numpy.random import choice

class SelectionModelTrainerABC(ABC):
    """Abstract trainer that fits a selection model by minibatch Adam.

    Subclasses must implement ``construct_model_parameters`` and
    ``_prior_loss``. This class also reads ``self.model_class`` and
    ``self.loss_function``, neither of which is defined here, so they must
    be provided elsewhere (presumably by the concrete subclass).
    """

    # class used to package the hashed columns and responses for the model
    data_class = SelectionData

    @abstractmethod
    def construct_model_parameters(self, data: pd.DataFrame) -> SelectionModelParameters:
        """Build the parameter specification used to instantiate the model."""
        raise NotImplementedError

    def construct_model_data(self, data: pd.DataFrame) -> SelectionData:
        """Convert a data frame into the model's input container.

        Verb, frame, and participant columns are replaced by their hashed
        indices, and responses are shifted down by one so that they are
        zero-indexed.
        """
        frame_hashed, verb_hashed, subj_hashed = self._construct_hashes(data)

        return self.data_class(
            verb=verb_hashed,
            frame=frame_hashed,
            subj=subj_hashed,
            resp=data.response.astype(int).values - 1,
        )

    def _hash_column(self, series: pd.Series, attr_name: str):
        """Hash ``series``, creating or reusing the map stored at ``attr_name``.

        The first call for a given attribute stores the hash map returned by
        ``hash_series`` on the trainer; later calls pass the stored map back
        in and discard the returned one.
        """
        if hasattr(self, attr_name):
            _, hashed = hash_series(series, getattr(self, attr_name), indexation=0)
        else:
            hash_map, hashed = hash_series(series, indexation=0)
            setattr(self, attr_name, hash_map)

        return hashed

    def _construct_hashes(self, data: pd.DataFrame):
        """Return the hashed frame, verb, and participant columns of ``data``.

        NOTE(review): hash maps persist on the instance once created, so a
        second ``fit`` on different data reuses the maps from the first call;
        confirm that ``hash_series`` copes with unseen values in that case.
        """
        frame_hashed = self._hash_column(data.frame, "frame_hash_map")
        verb_hashed = self._hash_column(data.verb, "verb_hash_map")
        subj_hashed = self._hash_column(data.participant, "subj_hash_map")

        return frame_hashed, verb_hashed, subj_hashed

    def _initialize_model(self, data: pd.DataFrame):
        """Instantiate ``self.model_class`` from parameters derived from ``data``."""
        model_parameters = self.construct_model_parameters(data)

        return self.model_class(model_parameters)

    def _construct_splits(self, data: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
        """Split ``data`` into train and dev frames by held-out verb-frame pairs.

        A tenth (rounded down) of the full verb x frame cross product is
        sampled without replacement; every row whose verb-frame pair was
        sampled goes to the dev split. With fewer than ten pairs in the cross
        product no pair is sampled and the dev split is empty.

        Returns
        -------
        (data_train, data_dev)
        """
        verbs = data.verb.unique()
        frames = data.frame.unique()

        verb_frame_pairs = [v + "_" + f for v in verbs for f in frames]

        n_dev = len(verb_frame_pairs) // 10

        verb_frame_pairs_dev = choice(verb_frame_pairs, n_dev, replace=False)

        dev_indicator = (data.verb + "_" + data.frame).isin(verb_frame_pairs_dev)

        return data[~dev_indicator], data[dev_indicator]

    @staticmethod
    def _expected_value_from_probs(probs):
        """Return the expected response under ``probs``.

        ``probs`` is indexed as (row, response level); levels are numbered
        from 1 up to the number of columns, so the scale length is taken from
        the tensor rather than hard-coded.
        """
        levels = torch.arange(1, probs.shape[1] + 1, device=probs.device)

        return torch.sum(levels[None, :] * probs, axis=1)

    @staticmethod
    def _correlation(x, y):
        """Return the Pearson correlation between two 1-D tensors."""
        return torch.corrcoef(
            torch.cat([x[None, :], y[None, :]], axis=0)
        )[0, 1]

    @staticmethod
    def _report_epoch(epoch, mean_loss, corr_train, corr_dev):
        """Print the progress summary for one epoch."""
        print(f"Epoch:             {epoch}")
        print(f"Mean loss:         {round(mean_loss, 2)}")
        print(f"Mean train corr.:  {round(corr_train, 2)}")
        print(f"Dev corr.:         {round(corr_dev, 2)}")
        print()

    def fit(
        self, data: pd.DataFrame, batch_size: int = 1000, max_epochs: int = 10_000,
        lr: float = 1e-5, patience: int = 0, tolerance: float = 0.05,
        window_size: int = 100, verbosity: int = 100, seed: int = 403928
    ) -> "SelectionModelTrainerABC":
        """Fit the model to ``data`` and return the trainer itself.

        Parameters
        ----------
        data
            Frame with ``verb``, ``frame``, ``participant`` and ``response``
            columns.
        batch_size
            Number of rows per minibatch; the final minibatch of an epoch
            absorbs any remainder.
        max_epochs
            Upper bound on the number of passes over the training split.
        lr
            Learning rate passed to Adam.
        patience
            Early stopping is only considered for epochs strictly greater
            than this value.
        tolerance
            Training stops once the windowed mean of (train correlation -
            dev correlation) exceeds this value.
        window_size
            Number of most recent epochs averaged for the stopping check.
        verbosity
            Print progress every ``verbosity`` epochs; 0 disables printing.
        seed
            Seed applied to both the torch and numpy random generators.

        Raises
        ------
        ValueError
            If the training split contains no rows.
        """
        manual_seed(seed)
        random.seed(seed)

        # necessary for initializing hashes on the full data before splitting
        self._construct_hashes(data)
        data_train, data_dev = self._construct_splits(data)

        n_train = data_train.shape[0]

        if n_train == 0:
            raise ValueError("training split is empty; cannot fit the model")

        self.model = self._initialize_model(data_train)

        # dev responses, used to correlate against the model's expected value
        target_dev = torch.tensor(data_dev.response.values)

        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

        # per-epoch (train correlation - dev correlation), stored as floats
        # so that no autograd graph is kept alive across epochs
        self.corr_diffs = []

        # always run at least one batch, even when the training split is
        # smaller than batch_size
        n_batches = max(1, n_train // batch_size)

        for e in range(max_epochs):
            data_shuffled = data_train.sample(frac=1.).reset_index(drop=True)

            epoch_total_loss = 0.
            correlations_train = []

            for i in range(n_batches):
                lower_bound = i * batch_size
                # the last batch takes whatever rows remain
                is_last = i == n_batches - 1
                upper_bound = n_train if is_last else (i + 1) * batch_size

                data_sub = self.construct_model_data(
                    data_shuffled.iloc[lower_bound:upper_bound]
                )

                target = torch.tensor(data_sub.resp)

                optimizer.zero_grad()

                probs = self.model(data_sub)
                logprobs = torch.log(probs)

                loss = self.loss_function(logprobs, target) + self._prior_loss()

                # monitoring only: keep it out of the autograd graph
                with torch.no_grad():
                    corr_train = self._correlation(
                        self._expected_value_from_probs(probs), target
                    )
                correlations_train.append(corr_train.item())

                loss.backward()
                optimizer.step()

                epoch_total_loss += loss.item()

            # evaluate on the dev split without tracking gradients
            with torch.no_grad():
                corr_dev = self._correlation(
                    self.expected_value(data_dev), target_dev
                ).item()

            correlations_train_mean = mean(correlations_train)
            mean_loss = epoch_total_loss / n_batches

            self.corr_diffs.append(float(correlations_train_mean - corr_dev))

            reported = bool(verbosity) and e % verbosity == 0

            if reported:
                self._report_epoch(e, mean_loss, correlations_train_mean, corr_dev)

            max_window_size = min(len(self.corr_diffs), window_size)
            mean_diff = mean(self.corr_diffs[-max_window_size:])

            # stop once train correlation outruns dev correlation on average
            if e > patience and mean_diff > tolerance:
                # report the final epoch unless it was just printed above
                if verbosity and not reported:
                    self._report_epoch(
                        e, mean_loss, correlations_train_mean, corr_dev
                    )

                break

        return self

    @abstractmethod
    def _prior_loss(self):
        """Return the prior (regularization) term added to each batch loss."""
        raise NotImplementedError

    def expected_value(self, data: pd.DataFrame):
        """Return the model's expected response for each row of ``data``."""
        model_data = self.construct_model_data(data)
        probs = self.model(model_data)

        return self._expected_value_from_probs(probs)

    def likelihood(self, data: pd.DataFrame):
        """Return the probability the model assigns to each observed response.

        For row ``i`` this is ``probs[i, resp[i]]``, where ``resp`` is the
        zero-indexed response built by ``construct_model_data``.
        """
        model_data = self.construct_model_data(data)
        probs = self.model(model_data)

        # pair each row with its own response column; indexing with the
        # responses alone would select whole rows by response value
        resp = torch.as_tensor(
            model_data.resp, dtype=torch.long, device=probs.device
        )
        rows = torch.arange(probs.shape[0], device=probs.device)

        return probs[rows, resp]

    def predict(self, data: pd.DataFrame):
        """Return the probability of each observed response in ``data``.

        NOTE(review): this has always computed the same quantity as
        ``likelihood``; confirm whether a different prediction (e.g. the full
        response distribution) was intended.
        """
        return self.likelihood(data)