# Outlier Detection
## Based on Action Rate

Classify subjects as outliers based on:
1) Rate of Bad-Action Trials
2) Rate of No-Action Trials

Classify specific trials as outliers based on:
1) Rate of Bad Actions (compared to subject's average)
2) Rate of No Actions (compared to subject's average)

In [58]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
import plotly.io as pio

import config as cnfg
from analysis.hyper_parameter_selection.on_target_threshold import column_titles
from data_models.SearchArray import SearchArray

pio.renderers.default = "notebook"
# pio.renderers.default = "browser"

### Read data

In [15]:
from analysis.pipeline.full_pipeline import read_saved_data

_targets, _actions, metadata, _idents, _fixations, _visits = read_saved_data()
is_bad = metadata[[cnfg.SUBJECT_STR, cnfg.TRIAL_STR, cnfg.TRIAL_CATEGORY_STR, "bad_actions", "no_actions"]]
is_bad["bad_trial"] = is_bad["bad_actions"] | is_bad["no_actions"]



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



### Find Outlier Subjects

In [95]:
STD_COEFF = 1  # Coefficient for standard deviation to define outliers

subject_counts = (
    is_bad
    .groupby(cnfg.SUBJECT_STR)[["bad_actions", "no_actions"]]
    .sum()
    .sort_index(ascending=True)
)

sum_per_subject = subject_counts[["bad_actions"]].sum(axis=1)
mean, std = sum_per_subject.mean(), sum_per_subject.std()
to_exclude = sum_per_subject[sum_per_subject > mean + STD_COEFF * std].index
print(f"Subjects to exclude: {to_exclude.tolist()}")

fig = make_subplots(
    rows=1, cols=2, column_titles=["Bad/No Actions", "Bad Actions Only"], shared_yaxes=True,
)
for c in range(2):
    cols = [["bad_actions", "no_actions"], ["bad_actions"]][c]
    sums = subject_counts[cols].sum(axis=1)
    mean, std = sums.mean(), sums.std()
    for col in cols:
        fig.add_trace(
            row=1, col=c + 1, trace=go.Bar(
                x=subject_counts.index.map(lambda s: f"Subject {s:02d}"),
                y=subject_counts[col],
                name=col, legendgroup=col, showlegend=c == 0,
                marker_color="red" if col == "bad_actions" else "blue",
            ),
        )
    for i in [-1, 0, 1]:
        name = "mean" if i == 0 else f"mean Â± {abs(i)} std"
        fig.add_hline(
            y=mean + i * std,
            line=dict(
                dash="dash" if i == 0 else "dot",
                color="black" if i == 0 else "gray"
            ),
            name=name, legendgroup=name,
            # showlegend=c == 0 and i <= 0,
            row=1, col=c + 1,
        )

fig.update_layout(
    width=900, height=500,
    title="Number of Trials with Bad Actions or No Actions",
    xaxis_title=cnfg.SUBJECT_STR,
    yaxis_title="Count",
    barmode='stack',
    template="plotly_white",
)
fig.show()

Subjects to exclude: [19]


In [101]:
print(f"Subjects to exclude: {to_exclude.tolist()}")
trials_to_exclude = list(
    metadata
    .loc[~np.isin(metadata[cnfg.SUBJECT_STR], to_exclude.tolist())]
    .loc[metadata["bad_actions"], [cnfg.SUBJECT_STR, cnfg.TRIAL_STR]]
    .itertuples(index=False, name=None)
)

print(f"Trials to exclude: {trials_to_exclude}")

Subjects to exclude: [19]
Trials to exclude: [(2, 32), (2, 41), (2, 42), (2, 47), (2, 48), (2, 49), (2, 55), (2, 56), (3, 47), (12, 2), (12, 8), (12, 15), (12, 23), (12, 28), (12, 44), (12, 58), (12, 59), (13, 9), (13, 12), (13, 33), (13, 35), (13, 39), (13, 55), (13, 60), (14, 5), (14, 8), (14, 9), (14, 39), (14, 44), (15, 24), (15, 25), (15, 42), (15, 52), (15, 54), (16, 3), (16, 21), (16, 57), (17, 12), (17, 14), (17, 15), (17, 26), (17, 34), (17, 36), (17, 46), (17, 47), (18, 13), (18, 27), (18, 28), (18, 30), (18, 32), (18, 34), (18, 35), (18, 38), (18, 39), (18, 51), (18, 54), (18, 60), (20, 38), (20, 58), (21, 50), (21, 54), (22, 17), (22, 24)]
