# Hallucination Detection

<img src="https://live.staticflickr.com/65535/54208132682_73767c3560_b.jpg" alt="Embedded Photo" width="500">

*Image generated with the DALL-E model.*

## Introduction

Language models help us with everyday tasks such as editing text, writing code and answering questions.
They are also increasingly used in fields such as medicine and education.

But how can we tell whether the answers they generate are correct? Language models do not always have complete knowledge of a given topic, yet they can still produce answers that sound credible but are in fact misleading. Such incorrect answers are called hallucinations.

## Task

In this task, you will detect hallucinations in answers to factual questions generated by large language models (LLMs).
You will analyze a dataset that helps assess whether the answers generated by a language model are actually correct or contain hallucinations.

Each example in the dataset contains:

- **Question**, e.g. "What is the main responsibility of the US Department of Defense?"
- **Language model answer**, e.g. "The main responsibility is defending the country."
- **Tokens** of the generated answer.
- **Four alternative answers** generated by the same model at a higher temperature, in which the answer span is marked with `<answer> ... </answer>` tags.
- **Tokens of the alternative answers**.
- **Probabilities of the alternative answers' tokens** (one value per token).
- **A label (`is_correct`)** indicating whether the main answer is correct according to a trusted source.
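For reference, the fields listed above can be written down as a Python `TypedDict`. This is a sketch inferred from the example below; the class name `Sample` is hypothetical and is not part of the starter code.

```python
from typing import List, TypedDict


# One dataset example; field names follow train.json / valid.json.
class Sample(TypedDict):
    question_id: int
    question: str
    answer: str                                  # main answer of the model
    tokens: List[str]                            # tokens of the main answer
    supporting_answers: List[str]                # 4 answers sampled at a higher temperature
    supporting_tokens: List[List[str]]           # tokens of each supporting answer
    supporting_probabilities: List[List[float]]  # one probability per supporting token
    is_correct: bool                             # label; removed before prediction
```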


Example:
```json
[
    {
        "question_id": 34,
        "question": "What is the name of the low-cost carrier that operates as a wholly owned subsidiary of Singapore Airlines?",
        "answer": "Scoot is the low-cost carrier that operates as a wholly owned subsidiary of Singapore Airlines.",
        "tokens": [" Sco", "ot", " is", ..., " Airlines", ".", "\n"],
        "supporting_answers": [
            "As a wholly owned subsidiary of Singapore Airlines, <answer> Scoot </answer> stands as a low-cost carrier that revolutionized air travel in the region.",
            "Scoot, a subsidiary of <answer> Singapore Airlines </answer> , is the low-cost carrier that operates under the same brand.",
            "<answer> Scoot </answer> is the low-cost carrier that operates as a wholly owned subsidiary of Singapore Airlines.",
            "Singapore Airlines operates a low-cost subsidiary named <answer> Scoot </answer> , offering affordable and efficient air travel options to passengers."
        ],
        "supporting_tokens": [
            [" As", " a", ..., ".", "<answer>"],
            [" Sco", "ot", ..., " brand", ".", "\n"],
            ["<answer>", " Sco", ..., ".", "\n"],
            [" Singapore", " Airlines", ..., ".", "\n"]
        ],
        "supporting_probabilities": [
            [0.0029233775567263365, 0.8621460795402527, ..., 0.018515007570385933],
            [0.42073577642440796, 0.9999748468399048, ..., 0.9166142344474792],
            [0.3258324861526489, 0.9969879984855652, ..., 0.921079695224762],
            [0.11142394691705704, 0.960810661315918, ..., 0.9557166695594788]
        ],
        "is_correct": true
    },
    .
    .
    .
]
```
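Because the supporting answers mark the answer span with `<answer> ... </answer>` tags, those spans can be compared with the main answer as a simple self-consistency signal. The sketch below is only an illustration: the helper names `extract_answer_spans` and `agreement`, and the case-insensitive substring match, are assumptions and not part of the starter code.

```python
import re
from typing import List

# Non-greedy match of the text between <answer> and </answer>
ANSWER_TAG = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)


def extract_answer_spans(supporting_answers: List[str]) -> List[str]:
    """Return the first tagged span of each supporting answer ('' if there is none)."""
    spans = []
    for text in supporting_answers:
        match = ANSWER_TAG.search(text)
        spans.append(match.group(1).strip() if match else "")
    return spans


def agreement(answer: str, supporting_answers: List[str]) -> float:
    """Fraction of non-empty spans that occur (case-insensitively) in the main answer."""
    spans = [s for s in extract_answer_spans(supporting_answers) if s]
    if not spans:
        return 0.0
    answer_lower = answer.lower()
    matches = [s for s in spans if s.lower() in answer_lower]
    return len(matches) / len(spans)
```

For the example above, all four spans ("Scoot", "Singapore Airlines", "Scoot", "Scoot") occur in the main answer, so the agreement is 1.0. Substring matching is crude: the second span names the parent airline rather than the low-cost carrier, and it only matches because that name also appears in the main answer.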

### Data
The data available to you in this task:

* `train.json` - a dataset of 2967 questions and answers.
* `valid.json` - 990 additional questions.


### Evaluation Criterion

ROC AUC (*Receiver Operating Characteristic Area Under Curve*) is a quality measure for binary classifiers. It reflects the model's ability to distinguish between two classes, here a hallucination (false) and a correct answer (true).

- **ROC (Receiver Operating Characteristic)**: a plot of the *True Positive Rate* (sensitivity) against the *False Positive Rate* (1 - specificity) across different decision thresholds.
- **AUC (Area Under Curve)**: the area under the ROC curve, taking values from 0 to 1:
  - **1.0**: a perfect model.
  - **0.5**: a random model (no ability to distinguish the classes).

The higher the AUC, the better the model separates the two classes.
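ROC AUC can also be read as a ranking probability: the chance that a randomly chosen positive example (here `is_correct == True`) receives a higher score than a randomly chosen negative one, with ties counting as one half. The pure-Python sketch below computes it that way; for inputs containing both classes it should agree with `sklearn.metrics.roc_auc_score`, which the evaluation code uses. The function name `pairwise_auc` is hypothetical.

```python
from typing import Sequence


def pairwise_auc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """ROC AUC as P(score of a positive > score of a negative); ties count 0.5.

    O(n_pos * n_neg): fine for illustration, too slow for large datasets.
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


labels = [1, 1, 0, 0]
scores = [0.9, 0.4, 0.6, 0.1]
print(pairwise_auc(labels, scores))                # 0.75
print(pairwise_auc(labels, [-s for s in scores]))  # 0.25
```

The second call shows why the direction of the score matters: reversing the ranking turns an AUC of 0.75 into 0.25, so a detector that separates the classes well but in the wrong direction still scores below 0.5.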

You can earn between 0 and 100 points for this task. The score is scaled linearly with the ROC AUC value:

- **ROC AUC ≤ 0.7**: 0 points.
- **ROC AUC ≥ 0.82**: 100 points.
- **Values between 0.7 and 0.82**: scaled linearly.

Score formula:  
$$
\text{Points} =
\begin{cases}
0 & \text{if } \text{ROC AUC} \leq 0.7 \\
100 \times \frac{\text{ROC AUC} - 0.7}{0.82 - 0.7} & \text{if } 0.7 < \text{ROC AUC} < 0.82 \\
100 & \text{if } \text{ROC AUC} \geq 0.82
\end{cases}
$$
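
For example, a model with ROC AUC = 0.79 on the test set falls in the middle branch of the formula:

$$
\text{Points} = 100 \times \frac{0.79 - 0.7}{0.82 - 0.7} = 100 \times \frac{0.09}{0.12} = 75
$$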


## Constraints
* Your solution will be tested on the Competition Platform without internet access and without a GPU.
* Evaluation of your final solution on the Competition Platform must not take longer than 5 minutes (without a GPU).
* Allowed libraries: `xgboost`, `scikit-learn`, `numpy`, `pandas`, `matplotlib`.


## Submission Files
This notebook, completed with your solution (see the `predict_hallucinations` function).

## Evaluation
Remember that during grading the `FINAL_EVALUATION_MODE` flag will be set to `True`.

You can earn between 0 and 100 points for this task. Your score will be computed on a (secret) test set on the Competition Platform using the formula above and rounded to the nearest integer. If your solution does not meet the criteria above or does not run correctly, you will receive 0 points for this task.


# Starter Code
In this section we set up the environment by importing the required libraries and functions. The provided code will help you work with the data efficiently and build your solution.

In [1]:
######################### DO NOT MODIFY THIS CELL WHEN SUBMITTING ##########################

FINAL_EVALUATION_MODE = False  # When grading your solution, we will set this value to True

import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sklearn as sk
from sklearn.metrics import roc_auc_score
import xgboost as xgb
import shutil

def download_data(train=("1TGEDaxw4GKfSq0fpqSk0wRpUSc8GgZN0", "train.json"),
                  valid=("1qrr7bZk6Uct8DeC-V8Bc1qD5su56ryFd", "valid.json")):
    """Downloads the dataset from Google Drive and saves it in the 'data' folder."""
    import gdown

    # Create or reset the 'data' folder
    if not os.path.exists('data'):
        os.makedirs('data')
    else:
        shutil.rmtree('data')
        os.makedirs('data')

    GDRIVE_DATA = [train, valid]

    for file_id, file_name in GDRIVE_DATA:
        # Download the file from Google Drive and save it in the 'data' folder
        url = f'https://drive.google.com/uc?id={file_id}'
        output = f'data/{file_name}'
        gdown.download(url, output, quiet=False)

        print(f"Downloaded: {file_name}")

# Download the data only when not in FINAL_EVALUATION_MODE
if not FINAL_EVALUATION_MODE:
    download_data()


Downloading...
From: https://drive.google.com/uc?id=1TGEDaxw4GKfSq0fpqSk0wRpUSc8GgZN0
To: /content/data/train.json
100%|██████████| 14.2M/14.2M [00:00<00:00, 41.3MB/s]


Downloaded: train.json


Downloading...
From: https://drive.google.com/uc?id=1qrr7bZk6Uct8DeC-V8Bc1qD5su56ryFd
To: /content/data/valid.json
100%|██████████| 4.77M/4.77M [00:00<00:00, 83.4MB/s]

Downloaded: valid.json





## Loading the Data
The code below loads the data and prepares it for use.

In [2]:
######################### DO NOT MODIFY THIS CELL WHEN SUBMITTING ##########################

def load_data(folder='data'):
    # Load the data from the files
    train_path = os.path.join(folder, 'train.json')
    valid_path = os.path.join(folder, 'valid.json')

    with open(train_path, 'r') as f:
        train = json.load(f)
    with open(valid_path, 'r') as f:
        valid = json.load(f)

    return train, valid

train, valid = load_data("data")

print(json.dumps(train[0], indent=2))

print(f"\nTotal training examples: {len(train)}")
print(f"Total validation examples: {len(valid)}")

{
  "question_id": 2147,
  "question": "What is the name of the American multinational toy manufacturing and entertainment company founded in 1945?",
  "answer": "With a rich history spanning decades, the name of the American multinational toy manufacturing and entertainment company founded in 1945 is Hasbro .",
  "tokens": [
    " With",
    " a",
    " rich",
    " history",
    " spanning",
    " decades",
    ",",
    " the",
    " name",
    " of",
    " the",
    " American",
    " multinational",
    " toy",
    " manufacturing",
    " and",
    " entertainment",
    " company",
    " founded",
    " in",
    " ",
    "1",
    "9",
    "4",
    "5",
    " is",
    " Hasbro",
    ".",
    "\n"
  ],
  "supporting_answers": [
    "The iconic American toy manufacturing and entertainment company, known for its beloved characters, is <answer> Hasbro </answer> .",
    "Mattel, the American multinational toy manufacturing and entertainment company, was founded by <answer> Ruth Handler <

## Evaluation Code

Code similar to the one below will be used to evaluate your solution on the test set.

In [3]:
######################### DO NOT MODIFY THIS CELL WHEN SUBMITTING ##########################

def compute_score(roc_auc: float) -> float:
    """
    Computes the score (in points) from the ROC AUC value.

    :param roc_auc: float value in the range [0.0, 1.0]
    :return: score according to the formula above
    """
    if roc_auc <= 0.7:
        return 0
    elif 0.7 < roc_auc < 0.82:
        return int(round(100 * (roc_auc - 0.7) / (0.82 - 0.7)))
    else:
        return 100


def evaluate_algorithm(dataset, algorithm, verbose=False):
    """
    Evaluates a hallucination detection algorithm on the given dataset.

    Parameters
    ----------
    dataset : list
        Labeled dataset; each element is a dictionary containing the 'is_correct' key.
    algorithm : callable
        Function that takes a single example (dictionary) and returns a score.
        Since the labels below are 1 for `is_correct == True`, a higher score
        should mean the answer is more likely correct (less likely a hallucination).
    verbose : bool
        If True, prints errors raised for individual examples and a summary.

    Returns
    -------
    points : int
        Score computed by `compute_score` from the ROC AUC of the predictions.
    """
    predicted_ys = []  # Predicted scores, one per example
    for i, entry in enumerate(dataset):
        # Shallow-copy the example and remove the label so the algorithm only sees unlabeled input
        sample_unlabeled = dict(entry)
        sample_unlabeled.pop('is_correct', None)

        try:
            # Predict the score for a single example
            pred_prob = algorithm(sample_unlabeled)
            predicted_ys.append(pred_prob)

        except Exception as e:
            # If an error occurs, fall back to a neutral score of 0.5
            predicted_ys.append(0.5)
            if verbose:
                print(f"Sample {i} => Error: {e}")

    predicted_ys = np.array(predicted_ys, dtype=np.float32)
    ys = []
    for entry in dataset:
        ys.append(1 if entry.get('is_correct') else 0)
    ys = np.array(ys, dtype=np.int32)

    # Compute ROC AUC
    roc_auc = roc_auc_score(ys, predicted_ys)

    # Compute the final score from ROC AUC
    points = compute_score(roc_auc)

    if verbose:
        print(f"\nNumber of samples: {len(dataset)}")
        print(f"ROC AUC: {roc_auc:.4f}")
        print(f"Score (points): {points}")

    return points

# Your Solution
Put your solution in this section. Make changes only here!

In [4]:
sample = train[5]
print(sample)
print(sample['is_correct'])
print(len(sample['supporting_tokens'][0]))
print(len(sample['supporting_probabilities'][0]))

{'question_id': 2743, 'question': 'In what year was the University of Iowa founded?', 'answer': 'Established in 1847 , the University of Iowa has a rich history and is a prominent institution of higher learning in the Midwest.', 'tokens': [' Established', ' in', ' ', '1', '8', '4', '7', ',', ' the', ' University', ' of', ' Iowa', ' has', ' a', ' rich', ' history', ' and', ' is', ' a', ' prominent', ' institution', ' of', ' higher', ' learning', ' in', ' the', ' Midwest', '.', '\n'], 'supporting_answers': ['The University of Iowa, established in <answer> 1847 </answer>, is one of the oldest public universities in the United States.', 'Established in <answer> 1847 </answer> , the University of Iowa is one of the oldest public research universities in the United States.', 'Founded in 1847, the University of Iowa holds the distinction of being the oldest state-run university in <answer> Iowa </answer> , USA.', 'With a rich history dating back to 1847, the University of Iowa is the oldest <

In [5]:
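# NOTE: the three helpers below shadow Python's built-in max, min and sum
# for the rest of the notebook; the originals remain reachable via
# `import builtins` (builtins.max, builtins.min, builtins.sum).
def max(arr):
    return np.max(arr)

def min(arr):
    return np.min(arr)

def sum(arr):
    return np.sum(arr)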

def range_stat(arr):
    return np.max(arr) - np.min(arr)

def mean_stat(arr):
    return np.mean(arr)

def median_stat(arr):
    return np.median(arr)

def variance_stat(arr):
    return np.var(arr, ddof=1)

def std_dev_stat(arr):
    return np.std(arr, ddof=1)

def mean_absolute_deviation(arr):
    return np.mean(np.abs(arr - np.mean(arr)))

def root_mean_square(arr):
    return np.sqrt(np.mean(arr**2))

def median_absolute_deviation(arr):
    return np.median(np.abs(arr - np.median(arr)))

def skewness_stat(arr):
    mean = mean_stat(arr)
    std_dev = std_dev_stat(arr)

    return np.mean(((arr - mean) / std_dev) ** 3)

def kurtosis_stat(arr):
    mean = mean_stat(arr)
    std_dev = std_dev_stat(arr)
    return np.mean(((arr - mean) / std_dev) ** 4) - 3

def interquartile_range(arr):
    q75, q25 = np.percentile(arr, [75 ,25])
    return q75 - q25

def geometric_mean(arr):
    return np.exp(np.mean(np.log(arr[arr > 0])))

def harmonic_mean(arr):
    return len(arr) / np.sum(1.0 / arr[arr > 0])

def coefficient_of_variation(arr):
    return std_dev_stat(arr) / mean_stat(arr)

def entropy_stat(arr):
    hist, _ = np.histogram(arr, bins=10, density=True)
    hist = hist[hist > 0]  # Avoid log(0)
    return -np.sum(hist * np.log(hist))

def mode_stat(arr):
    values, counts = np.unique(arr, return_counts=True)
    return values[np.argmax(counts)]

def percentile_90(arr):
    return np.percentile(arr, 90)

def percentile_10(arr):
    return np.percentile(arr, 10)

def energy_stat(arr):
    return np.sum(arr**2)

def signal_to_noise_ratio(arr):
    return mean_stat(arr) / std_dev_stat(arr)

def variation_stat(arr):
    return np.std(arr) / np.mean(arr)

def root_sum_of_squares(arr):
    return np.sqrt(np.sum(arr**2))

def quartile_coeff_dispersion(arr):
    q75, q25 = np.percentile(arr, [75 ,25])
    return (q75 - q25) / (q75 + q25)

def mean_squared_error(arr):
    """Mean Squared Error from the mean."""
    mean = mean_stat(arr)
    return np.mean((arr - mean) ** 2)

def root_median_square(arr):
    """Root Median Square as a robust alternative to RMS."""
    return np.sqrt(np.median(arr**2))

def trimmed_mean(arr, proportion=0.1):
    """Trimmed mean, discarding extreme values."""
    lower = int(proportion * len(arr))
    upper = len(arr) - lower
    return np.mean(np.sort(arr)[lower:upper])

def mean_log_deviation(arr):
    """Mean Log Deviation (MLD) to measure inequality in data."""
    mean = mean_stat(arr)
    if mean <= 0:
      mean *= (-1)
    return np.mean(np.log(arr[arr > 0] / mean))

def gini_coefficient(arr):
    """Gini coefficient as a measure of dispersion."""
    sorted_arr = np.sort(arr)
    n = len(arr)
    return (2 * np.sum((np.arange(1, n+1) * sorted_arr))) / (n * np.sum(sorted_arr)) - (n + 1) / n

def spectral_entropy(arr):
    """Entropy in the frequency domain."""
    fft_vals = np.abs(np.fft.fft(arr)) ** 2
    fft_vals = fft_vals / np.sum(fft_vals)  # Normalize
    fft_vals = fft_vals[fft_vals > 0]  # Avoid log(0)
    return -np.sum(fft_vals * np.log(fft_vals))

def hurst_exponent(arr):
    """Estimates the Hurst exponent for detecting long-range dependencies."""
    lags = range(2, 20)
    tau = [np.std(arr[lag:] - arr[:-lag]) for lag in lags]

    # Ensure no zero values before taking log
    tau = np.array(tau)
    lags = np.array(lags)

    if np.any(tau <= 0):  # If any tau value is zero or negative, return 0
        return 0.0

    return np.polyfit(np.log(lags), np.log(tau), 1)[0]

def mean_power(arr, exponent=3):
    """Mean of powered values to analyze data distribution."""
    return np.mean(np.abs(arr) ** exponent)

def mad_stat(arr):
    median = np.median(arr)
    return np.median(np.abs(arr - median))

def range_coefficient(arr):
    return (np.max(arr) - np.min(arr)) / np.mean(arr)

def root_mean_cube(arr):
    return np.cbrt(np.mean(arr**3))

def harmonic_variance(arr):
    harm_mean = harmonic_mean(arr)
    if harm_mean == 0:
        harm_mean = 0.0001
    # Modifies `arr` in place (zeros -> 0.0001); callers pass a fresh np.array
    arr[arr == 0] = 0.0001
    return np.mean((1 / arr - 1 / harm_mean) ** 2)

def log_std_dev(arr):
    return np.std(np.log(arr[arr > 0]))

def root_mean_log_deviation(arr):
    log_mean = np.mean(np.log(arr[arr > 0]))
    return np.sqrt(np.mean((np.log(arr[arr > 0]) - log_mean) ** 2))

def peak_to_peak_amplitude(arr):
    return np.ptp(arr)

def masd_stat(arr):
    return np.mean(np.abs(arr - np.mean(arr))) / np.mean(arr)

def rms_deviation_from_median(arr):
    median = np.median(arr)
    return np.sqrt(np.mean((arr - median) ** 2))


In [6]:
functions = [
    range_stat,
    mean_stat,
    median_stat,
    variance_stat,
    std_dev_stat,
    mean_absolute_deviation,
    root_mean_square,
    median_absolute_deviation,
    skewness_stat,
    kurtosis_stat,
    interquartile_range,
    geometric_mean,
    harmonic_mean,
    coefficient_of_variation,
    entropy_stat,
    mode_stat,
    percentile_90,
    percentile_10,
    energy_stat,
    signal_to_noise_ratio,
    variation_stat,
    root_sum_of_squares,
    quartile_coeff_dispersion,
    mean_squared_error,
    root_median_square,
    trimmed_mean,
    mean_log_deviation,
    gini_coefficient,
    spectral_entropy,
    #hurst_exponent,
    mean_power,
    mad_stat,
    range_coefficient,
    root_mean_cube,
    harmonic_variance,
    log_std_dev,
    root_mean_log_deviation,
    peak_to_peak_amplitude,
    masd_stat,
    rms_deviation_from_median,
    max,
    min,
    sum,
]

In [7]:
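def extract_seqs(sample):
  """Group the supporting-answer token probabilities into 4 sequences.

  Probabilities are collected per distinct token string. A token's group index is
  is_in_answer + 2 * is_in_ans_tokens, where is_in_answer = 1 if the token is one of
  the whitespace-separated words of the main answer, and is_in_ans_tokens = 1 if the
  token occurred inside an <answer> ... </answer> span of any supporting answer.
  Note: `is_in_ans` is not reset between supporting answers.
  """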
  spl = sample['answer'].split()
  words = {}
  in_ans = []
  is_in_ans = False
  for idx_support, supporting in enumerate(sample['supporting_tokens']):
    for idx_word, word in enumerate(supporting):
      if word == "<answer>":
        is_in_ans = True
        continue
      elif word == "</answer>":
        is_in_ans = False
        continue

      if is_in_ans:
        in_ans.append(word)

      prob = sample['supporting_probabilities'][idx_support][idx_word]
      if word not in words:
        words[word] = [prob]
      else:
        words[word].append(prob)

  seqs = [[],[],[],[]]
  for word, probs in words.items():
    is_in_answer = 1 if word in spl else 0
    is_in_ans_tokens = 1 if word in in_ans else 0

    for prob in probs:
      seqs[is_in_answer + 2 * is_in_ans_tokens].append(prob)

  return seqs

def remove_special_chars(s):
    return ''.join(c if c.isalnum() or c == ' ' else '' for c in s)

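def merge(ans, tokens, probs):
  """Align subword tokens with the whitespace-separated words of `ans`.

  Characters other than letters, digits and spaces are stripped from the answer
  and the tokens; consecutive tokens are concatenated until they spell the next
  word, and their probabilities are multiplied. A token that overshoots a word
  is re-read for the next word. Returns (words, per-word probabilities).
  """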
  cumulative_prob = 1.0
  pointer = 0
  reconstructed = ""

  new_tokens = []
  new_probs = []

  ans_copy = remove_special_chars(ans)
  for word in ans_copy.split():
    while reconstructed != word and word.startswith(reconstructed):
      if pointer >= len(tokens):
        break

      reconstructed += remove_special_chars(tokens[pointer].strip())
      cumulative_prob *= probs[pointer]
      pointer += 1
    if not word.startswith(reconstructed):
      pointer -= 1
    new_probs.append(cumulative_prob)
    new_tokens.append(word)

    cumulative_prob = 1.0
    reconstructed = ""

  return new_tokens, new_probs

def manual_deepcopy(obj):
    """Recursively copy nested dicts and lists (like copy.deepcopy for JSON-style data)."""
    if isinstance(obj, dict):
        return {key: manual_deepcopy(value) for key, value in obj.items()}
    elif isinstance(obj, list):
        return [manual_deepcopy(item) for item in obj]
    else:
        return obj

def merge_in_sample(sample):
  sample['tokens'], _ = merge(sample['answer'], sample['tokens'], [1.0] * len(sample['tokens']))

  for sam in range(len(sample['supporting_answers'])):
    sample['supporting_tokens'][sam], sample['supporting_probabilities'][sam] = merge(sample['supporting_answers'][sam], sample['supporting_tokens'][sam], sample['supporting_probabilities'][sam])

  return sample

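def extract_all_seqs(sample):
  """Return 8 sequences: the 4 groups for word-merged tokens, then the 4 for raw tokens.

  Note: merging strips special characters, so the <answer> tags become the plain
  word "answer" and groups 2 and 3 of the merged part stay empty.
  """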
  sample2 = manual_deepcopy(sample)
  sample2 = merge_in_sample(sample2)

  return extract_seqs(sample2) + extract_seqs(sample)

def to_seqs(dataset):
  X = []
  y = []
  for sample in dataset:
    X.append(extract_all_seqs(sample))
    y.append(sample['is_correct'])

  y = np.array(y)
  return X,y
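
`merge` above aligns tokens with the words of the answer text and multiplies the probabilities of the tokens that form each word. A simpler alternative is sketched below; it is not what `merge` does and relies on an assumption suggested by the example tokens (`" Sco"`, `"ot"`, `" "`, `"1"`, ...), namely that the tokenizer marks the start of a new word with a leading space or a separate space token. The name `group_by_leading_space` is hypothetical.

```python
from typing import List, Sequence, Tuple


def group_by_leading_space(tokens: Sequence[str],
                           probs: Sequence[float]) -> Tuple[List[str], List[float]]:
    """Join subword tokens into words and multiply their probabilities.

    A token starting with a space, or any token following a whitespace-only
    token (e.g. " " before digits), opens a new word; other tokens are appended
    to the current word. Whitespace-only tokens and their probabilities are skipped.
    """
    words: List[str] = []
    word_probs: List[float] = []
    start_new = True
    for token, p in zip(tokens, probs):
        if not token.strip():
            # bare space or newline: the next token starts a new word
            start_new = True
            continue
        if start_new or token.startswith(" "):
            words.append(token.strip())
            word_probs.append(p)
        else:
            words[-1] += token
            word_probs[-1] *= p
        start_new = False
    return words, word_probs
```

For `[" Sco", "ot", " is"]` with probabilities `[0.5, 0.8, 0.9]` this yields the words `["Scoot", "is"]` with probabilities of about `[0.4, 0.9]`. Unlike `merge`, punctuation such as `"."` stays attached to the preceding word.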

In [9]:
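def to_numpy(dataset, feature_list):
  """Build the feature matrix and labels for a whole dataset.

  Every statistic in `functions` is computed on each of the 8 sequences returned
  by `extract_all_seqs` (sequences with <= 1 value give 0). Each entry of
  `feature_list` is [score, [idx1, idx2, idx3, op1, op2]]; the score is not used
  here, and the entry adds 8 columns equal to
  functions[idx1] op1 functions[idx2] op2 functions[idx3], evaluated with normal
  Python operator precedence (e.g. a - b * c == a - (b * c)).
  """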
  seq, y = to_seqs(dataset)

  fun_res = []
  for idx, function in enumerate(functions):
    arr = []
    for s in seq:
      arr2 = []
      for val in s:
        if len(val) <= 1:
          arr2.append(0)
          continue
        arr2.append(function(np.array(val)))
      arr.append(arr2)
    fun_res.append(arr)
  fun_res = np.array(fun_res)

  arr = np.zeros((len(dataset),0))
  for f_list in feature_list:
    list_of_encoded = f_list[1]
    idx1 = list_of_encoded[0]
    idx2 = list_of_encoded[1]
    idx3 = list_of_encoded[2]

    symb1 = list_of_encoded[3]
    symb2 = list_of_encoded[4]

    res = eval("fun_res[idx1]" + symb1 + "fun_res[idx2]" + symb2 + "fun_res[idx3]")

    arr = np.concatenate((arr,res),axis=1)

  return arr, y

def individual_sample(sample, feature_list):
  """Same features as `to_numpy`, computed for a single (unlabeled) sample; returns a 1-D array."""
  seq = extract_all_seqs(sample)

  fun_res = []
  for idx, function in enumerate(functions):
    arr2 = []
    for val in seq:
      if len(val) <= 1:
        arr2.append(0)
        continue
      arr2.append(function(np.array(val)))
    fun_res.append(arr2)
  fun_res = np.array(fun_res)

  arr = np.zeros((0))
  for f_list in feature_list:
    list_of_encoded = f_list[1]
    idx1 = list_of_encoded[0]
    idx2 = list_of_encoded[1]
    idx3 = list_of_encoded[2]

    symb1 = list_of_encoded[3]
    symb2 = list_of_encoded[4]

    res = eval("fun_res[idx1]" + symb1 + "fun_res[idx2]" + symb2 + "fun_res[idx3]")

    arr = np.concatenate((arr,res),axis=0)

  return arr
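
Each feature spec in `features` (next cell) is turned by `eval` into an expression such as `fun_res[24] - fun_res[21] * fun_res[37]`, evaluated with ordinary Python precedence. One way to do the same without `eval` is sketched below, using the standard `operator` module and precedence-aware evaluation of `a op1 b op2 c`; `combine`, `OPS` and `PRECEDENCE` are hypothetical names, and only `+`, `-`, `*` and `/` are supported.

```python
import operator
from typing import Sequence

import numpy as np

OPS = {"+": operator.add, "-": operator.sub, "*": operator.mul, "/": operator.truediv}
PRECEDENCE = {"+": 1, "-": 1, "*": 2, "/": 2}


def combine(fun_res: np.ndarray, spec: Sequence) -> np.ndarray:
    """Evaluate fun_res[i1] op1 fun_res[i2] op2 fun_res[i3] as Python would.

    spec = [i1, i2, i3, op1, op2], e.g. [24, 21, 37, '-', '*'].
    """
    i1, i2, i3, op1, op2 = spec
    a, b, c = fun_res[i1], fun_res[i2], fun_res[i3]
    if PRECEDENCE[op2] > PRECEDENCE[op1]:
        # higher-precedence right operator binds first: a - b * c == a - (b * c)
        return OPS[op1](a, OPS[op2](b, c))
    # equal or lower precedence: left to right, e.g. a * b - c == (a * b) - c
    return OPS[op2](OPS[op1](a, b), c)
```

With `fun_res` shaped `(n_functions, n_samples, 8)` as in `to_numpy`, `combine(fun_res, f_list[1])` should return the same `(n_samples, 8)` block as the `eval` call.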



In [10]:
features = [[0.7849194482073301, [24, 21, 37, '-', '*']], [0.7801946180059154, [3, 21, 27, '+', '*']], [0.7779868848463258, [28, 32, 6, '-', '*']], [0.7775111206941001, [28, 21, 27, '+', '*']], [0.774595600470608, [28, 19, 38, '-', '*']], [0.7744456058610394, [21, 1, 21, '-', '*']], [0.774030777018951, [7, 21, 27, '+', '*']], [0.7734425169095486, [13, 28, 19, '+', '*']], [0.7727628538349403, [4, 32, 28, '-', '*']], [0.7727112931879011, [28, 0, 37, '-', '*']], [0.7723363066639791, [17, 38, 22, '-', '*']], [0.7722566220276459, [17, 30, 38, '-', '*']], [0.7720972527549791, [17, 37, 20, '-', '*']], [0.7712840007312237, [28, 24, 13, '-', '*']], [0.7712816570654493, [17, 13, 37, '-', '*']], [0.7711761921055961, [28, 7, 38, '-', '*']], [0.7711761921055961, [28, 38, 7, '-', '*']], [0.7709301071992726, [27, 21, 20, '+', '*']], [0.7709301071992725, [17, 3, 29, '-', '*']], [0.7707121462822429, [2, 28, 0, '-', '*']], [0.7697371813200463, [24, 25, 28, '-', '*']], [0.7694840654163991, [28, 24, 32, '-', '*']], [0.7693786004565462, [11, 21, 20, '-', '*']], [0.769254386170497, [28, 35, 3, '-', '*']], [0.7689989266010753, [17, 3, 21, '-', '*']], [0.7689122109674184, [37, 3, 21, '+', '*']], [0.7688536193230556, [17, 23, 8, '-', '*']], [0.7679724009918393, [17, 32, 7, '-', '*']], [0.7676231947914371, [6, 1, 28, '-', '*']], [0.7675786651417214, [10, 17, 19, '-', '*']], [0.7675528848182018, [19, 16, 28, '-', '*']], [0.7669575937114761, [17, 23, 15, '-', '*']], [0.7667396327944465, [28, 2, 2, '-', '*']], [0.7666060438452993, [16, 21, 37, '+', '*']], [0.7663388659470051, [17, 20, 12, '+', '*']], [0.7661912150032107, [28, 11, 4, '-', '*']], [0.7661841840058874, [28, 22, 5, '-', '*']], [0.7659568484257597, [37, 2, 28, '+', '*']], [0.7658349778054852, [30, 28, 10, '-', '*']], [0.7654576476157887, [28, 4, 3, '-', '*']], [0.7653920249741025, [28, 2, 3, '+', '*']], [0.765302965674671, [28, 1, 22, '-', '*']], [0.7651107850811611, [6, 5, 28, '-', '*']], [0.7649561031400434, [28, 19, 23, '-', 
'*']], [0.7644709643247196, [38, 21, 22, '+', '*']], [0.7644709643247196, [38, 21, 22, '+', '*']], [0.7642834710627587, [28, 29, 37, '-', '*']], [0.7642201920868469, [32, 37, 29, '-', '*']], [0.7641194144585428, [27, 17, 19, '+', '*']], [0.7639037972072878, [28, 27, 23, '-', '*']], [0.7636459939720918, [24, 26, 21, '+', '*']], [0.7632756947797188, [17, 26, 23, '-', '*']], [0.7632499144561992, [17, 38, 27, '+', '*']], [0.7631866354802874, [28, 22, 37, '-', '*']], [0.7630085168814246, [17, 17, 1, '+', '*']], [0.7629499252370618, [4, 19, 17, '-', '*']], [0.7625585330527185, [10, 16, 11, '-', '*']], [0.7623499467987869, [28, 27, 0, '+', '*']], [0.7623499467987869, [28, 27, 36, '+', '*']], [0.7623452594672379, [28, 2, 0, '+', '*']], [0.7623171354779438, [28, 27, 6, '+', '*']], [0.7623077608148457, [28, 25, 26, '-', '*']], [0.7622491691704829, [28, 4, 26, '-', '*']], [0.7619866786037378, [28, 20, 2, '+', '*']], [0.761918712296277, [12, 32, 22, '-', '*']], [0.7618062163391004, [12, 29, 10, '-', '*']], [0.761731219034316, [11, 19, 19, '-', '*']], [0.7616445034006591, [10, 28, 35, '+', '*']], [0.761510914451512, [17, 3, 4, '-', '*']], [0.7614570101386982, [25, 5, 34, '-', '*']], [0.7614429481440511, [24, 21, 22, '+', '*']], [0.7613726381708158, [29, 31, 29, '-', '*']], [0.7612296745585705, [0, 32, 28, '+', '*']], [0.7610960856094233, [25, 17, 6, '+', '*']], [0.7609414036683056, [29, 35, 10, '-', '*']], [0.7608781246923939, [17, 11, 13, '-', '*']], [0.7608125020507075, [19, 28, 3, '-', '*']], [0.7607023497593055, [26, 21, 23, '-', '*']], [0.7606953187619819, [28, 14, 17, '+', '*']], [0.7605265748262173, [28, 16, 13, '+', '*']], [0.7604726705134035, [10, 19, 25, '-', '*']], [0.760378923882423, [17, 19, 22, '-', '*']], [0.7603062702434131, [22, 35, 4, '+', '*']], [0.7602851772514425, [17, 3, 37, '-', '*']], [0.7602383039359524, [5, 22, 35, '+', '*']], [0.7602195546097562, [1, 22, 24, '-', '*']], [0.7601867432889131, [22, 28, 24, '+', '*']], [0.7601164333156778, [10, 22, 31, 
'+', '*']], [0.7601000276552561, [19, 7, 3, '-', '*']], [0.760071903665962, [24, 6, 17, '+', '*']], [0.7600554980055405, [17, 24, 29, '+', '*']], [0.7600320613477953, [17, 3, 37, '+', '*']], [0.7598820667382267, [32, 37, 0, '-', '*']], [0.7597531651206285, [32, 28, 20, '-', '*']], [0.7597273847971089, [19, 16, 3, '+', '*']], [0.7596805114816186, [19, 13, 19, '+', '*']], [0.759453175901491, [23, 19, 17, '-', '*']], [0.7591484993508046, [28, 1, 17, '-', '*']], [0.7591461556850301, [28, 16, 31, '+', '*']], [0.7590781893775692, [19, 27, 32, '+', '*']], [0.7589844427465888, [28, 3, 7, '-', '*']], [0.7589633497546182, [11, 5, 15, '-', '*']], [0.758930538433775, [10, 12, 4, '-', '*']], [0.7586797661959024, [28, 27, 34, '-', '*']], [0.7586188308857651, [8, 0, 28, '+', '*']], [0.7585414899152061, [19, 1, 4, '+', '*']], [0.7583024360062061, [9, 0, 21, '-', '*']], [0.7582977486746572, [37, 31, 28, '-', '*']], [0.7581805653859314, [1, 4, 16, '-', '*']], [0.758173534388608, [1, 35, 27, '-', '*']], [0.7581407230677647, [1, 38, 37, '-', '*']], [0.7580891624207255, [10, 17, 27, '-', '*']], [0.7579907284581962, [17, 13, 2, '-', '*']], [0.7578524521774997, [28, 15, 26, '+', '*']], [0.7577118322310292, [2, 14, 17, '+', '*']], [0.7575805869476565, [19, 3, 32, '-', '*']], [0.7575735559503332, [37, 20, 2, '+', '*']], [0.7575571502899116, [7, 21, 37, '-', '*']], [0.7575454319610389, [27, 29, 37, '+', '*']], [0.7574165303434408, [11, 20, 7, '-', '*']], [0.7573134090493623, [19, 23, 0, '-', '*']], [0.7572501300734504, [25, 13, 38, '-', '*']], [0.7572173187526074, [19, 20, 17, '+', '*']], [0.7571962257606368, [26, 8, 17, '-', '*']], [0.7571821637659897, [19, 37, 20, '-', '*']], [0.7571188847900779, [11, 0, 23, '-', '*']], [0.7571048227954307, [19, 22, 34, '-', '*']], [0.7570954481323329, [19, 32, 4, '+', '*']], [0.75696420284896, [25, 21, 28, '-', '*']], [0.7568845182126268, [31, 16, 17, '-', '*']], [0.7565774979961658, [19, 3, 31, '-', '*']], [0.7565353120122246, [36, 17, 6, '-', '*']], 
[0.7565048443571559, [19, 12, 30, '+', '*']], [0.7564556273758912, [28, 20, 26, '-', '*']], [0.7563782864053323, [5, 16, 28, '-', '*']], [0.7563501624160383, [37, 2, 11, '-', '*']], [0.7563079764320969, [16, 28, 35, '-', '*']], [0.7562868834401264, [19, 1, 23, '-', '*']], [0.7561673564856263, [19, 26, 27, '-', '*']], [0.7561040775097145, [37, 37, 17, '-', '*']], [0.7559259589108516, [26, 25, 1, '+', '*']], [0.755914240581979, [27, 22, 16, '+', '*']], [0.7558837729269103, [38, 27, 28, '+', '*']], [0.7558064319563516, [22, 27, 35, '+', '*']], [0.7557525276435378, [6, 21, 27, '+', '*']], [0.7557220599884691, [32, 32, 10, '-', '*']], [0.7557150289911457, [22, 36, 28, '+', '*']], [0.7556470626836849, [23, 6, 11, '-', '*']], [0.7555345667265082, [0, 14, 17, '-', '*']], [0.7554220707693319, [31, 13, 3, '-', '*']], [0.7554220707693319, [31, 3, 13, '-', '*']], [0.7553658227907435, [28, 30, 12, '-', '*']], [0.7552533268335669, [22, 23, 11, '+', '*']], [0.755061146240057, [27, 3, 12, '-', '*']], [0.7550541152427335, [28, 6, 30, '+', '*']], [0.7550283349192138, [27, 3, 29, '+', '*']], [0.7550048982614687, [11, 19, 27, '-', '*']], [0.7549392756197826, [37, 25, 11, '-', '*']], [0.7549017769673901, [13, 28, 0, '+', '*']], [0.7548900586385177, [32, 19, 16, '-', '*']], [0.754754126023596, [25, 26, 35, '+', '*']], [0.7546603793926154, [19, 7, 28, '-', '*']], [0.7544705424648801, [31, 19, 13, '-', '*']], [0.7543861704969979, [29, 38, 6, '-', '*']], [0.7543510155103802, [31, 23, 27, '-', '*']], [0.7542432068847527, [30, 37, 28, '-', '*']], [0.7542361758874291, [11, 26, 24, '+', '*']], [0.7541963335692624, [22, 6, 32, '-', '*']], [0.7541611785826446, [10, 32, 26, '-', '*']], [0.754102586938282, [25, 28, 35, '-', '*']], [0.7541025869382819, [4, 19, 32, '+', '*']], [0.7540955559409583, [23, 28, 38, '-', '*']], [0.7540861812778603, [13, 4, 21, '+', '*']], [0.7540135276388504, [12, 28, 35, '-', '*']], [0.7538822823554778, [16, 10, 4, '-', '*']], [0.7538775950239288, [28, 23, 13, '+', 
'*']], [0.75372760041436, [37, 25, 4, '+', '*']], [0.7536854144304189, [38, 32, 4, '-', '*']], [0.7535869804678894, [37, 10, 11, '+', '*']], [0.7535799494705658, [5, 4, 28, '-', '*']], [0.7534885465053598, [26, 19, 24, '+', '*']], [0.7534533915187422, [19, 37, 28, '-', '*']], [0.7533244899011442, [10, 22, 20, '+', '*']], [0.753190900951997, [26, 22, 25, '-', '*']], [0.7531815262888991, [31, 7, 15, '-', '*']], [0.7530596556686245, [7, 1, 19, '+', '*']], [0.7530526246713009, [10, 10, 16, '+', '*']], [0.7529401287141243, [28, 22, 17, '+', '*']], [0.7529073173932813, [6, 36, 26, '+', '*']], [0.7528909117328597, [38, 28, 20, '+', '*']], [0.7528674750751143, [17, 23, 2, '-', '*']], [0.7528299764227222, [11, 13, 11, '-', '*']], [0.7528065397649771, [20, 4, 23, '-', '*']], [0.7527901341045556, [37, 13, 38, '-', '*']], [0.7526870128104771, [4, 19, 2, '-', '*']], [0.7525932661794966, [22, 24, 19, '-', '*']], [0.7525534238613301, [17, 20, 2, '+', '*']], [0.7525042068800651, [25, 2, 3, '-', '*']], [0.7525042068800651, [25, 3, 2, '-', '*']], [0.752447958901477, [37, 26, 24, '-', '*']], [0.7523190572838788, [11, 6, 6, '+', '*']], [0.7522815586314867, [19, 19, 10, '-', '*']], [0.7522674966368397, [27, 34, 27, '+', '*']], [0.7520331300593885, [22, 27, 10, '+', '*']], [0.7520120370674179, [19, 7, 25, '+', '*']], [0.7519932877412219, [16, 19, 19, '-', '*']], [0.7518971974444668, [32, 13, 31, '-', '*']], [0.7516300195461726, [1, 36, 31, '-', '*']], [0.7515550222413883, [10, 1, 24, '-', '*']], [0.7514214332922411, [17, 19, 24, '+', '*']], [0.7514120586291431, [16, 13, 0, '-', '*']], [0.751412058629143, [19, 8, 12, '-', '*']], [0.7513323739928096, [26, 29, 22, '-', '*']], [0.7512948753404174, [11, 17, 22, '-', '*']], [0.7512526893564763, [4, 29, 1, '-', '*']], [0.7511355060677507, [32, 2, 13, '-', '*']], [0.7511167567415546, [1, 13, 35, '-', '*']], [0.7510441031025448, [32, 6, 23, '-', '*']], [0.7509691057977604, [4, 17, 13, '-', '*']], [0.7509292634795938, [25, 25, 3, '-', '*']], 
[0.7508659845036819, [22, 5, 22, '-', '*']], [0.7508284858512897, [16, 34, 37, '+', '*']], [0.7505308402979268, [0, 19, 2, '-', '*']], [0.7504277190038483, [36, 19, 24, '-', '*']], [0.7503738146910346, [35, 24, 17, '-', '*']], [0.7501699157686521, [19, 30, 37, '-', '*']], [0.7501066367927403, [19, 30, 29, '-', '*']], [0.7500433578168284, [10, 23, 16, '+', '*']], [0.7500339831537304, [6, 10, 23, '-', '*']], [0.7498722702152891, [1, 4, 20, '-', '*']], [0.7498535208890931, [19, 38, 35, '-', '*']], [0.7498511772233184, [1, 27, 26, '+', '*']], [0.7497339939345931, [32, 27, 6, '-', '*']], [0.7496941516164263, [37, 37, 24, '+', '*']], [0.7496613402955832, [23, 21, 26, '+', '*']], [0.7495957176538968, [6, 19, 37, '-', '*']], [0.7495558753357301, [11, 27, 12, '-', '*']], [0.7495300950122105, [19, 24, 36, '-', '*']], [0.7495113456860144, [27, 17, 22, '-', '*']], [0.7495066583544654, [13, 29, 27, '+', '*']], [0.7494691597020733, [11, 7, 0, '-', '*']], [0.7493238524240535, [26, 25, 11, '+', '*']], [0.7493121340951809, [38, 19, 2, '-', '*']], [0.7492441677877202, [11, 23, 22, '+', '*']], [0.7492254184615241, [27, 25, 32, '-', '*']], [0.7491832324775829, [29, 22, 20, '-', '*']], [0.7491316718305436, [22, 6, 30, '-', '*']], [0.7490988605097004, [29, 6, 22, '-', '*']], [0.7490824548492788, [23, 32, 25, '-', '*']], [0.7489793335552003, [28, 0, 22, '-', '*']], [0.7489371475712592, [32, 16, 3, '-', '*']], [0.7488480882718278, [37, 11, 27, '-', '*']], [0.748803558622112, [11, 2, 38, '-', '*']], [0.7487965276247885, [25, 13, 19, '+', '*']], [0.748688718999161, [37, 27, 12, '-', '*']], [0.7486676260071903, [37, 16, 4, '+', '*']], [0.7485527863842393, [11, 1, 38, '-', '*']], [0.7484473214243863, [35, 9, 37, '-', '*']], [0.748391073445798, [36, 19, 0, '-', '*']], [0.748386386114249, [6, 4, 37, '-', '*']], [0.7483793551169255, [22, 3, 15, '+', '*']], [0.7483606057907294, [37, 20, 35, '+', '*']], [0.7483465437960823, [16, 23, 13, '-', '*']], [0.7482856084859449, [27, 12, 25, '-', '*']], 
[0.748276233822847, [28, 29, 24, '+', '*']], [0.7482457661677784, [27, 35, 20, '+', '*']], [0.7481848308576411, [22, 22, 7, '-', '*']], [0.7481262392132783, [25, 27, 17, '+', '*']], [0.7481028025555332, [31, 22, 0, '+', '*']], [0.7480981152239841, [26, 26, 6, '-', '*']], [0.7480699912346902, [29, 28, 0, '+', '*']], [0.747945776948641, [5, 26, 6, '-', '*']], [0.7478871853042781, [28, 31, 30, '-', '*']], [0.7478754669754055, [16, 24, 37, '-', '*']], [0.7475637594273956, [27, 23, 7, '-', '*']], [0.7474442324728955, [30, 31, 37, '+', '*']], [0.7473903281600819, [11, 16, 3, '-', '*']], [0.7473762661654347, [19, 13, 35, '+', '*']], [0.7473739224996601, [37, 34, 28, '+', '*']], [0.7472614265424837, [13, 32, 11, '-', '*']], [0.7472309588874151, [37, 34, 26, '-', '*']], [0.7472192405585424, [26, 31, 27, '-', '*']], [0.7472051785638953, [15, 19, 32, '-', '*']], [0.7471817419061503, [10, 25, 38, '+', '*']], [0.7470504966227777, [19, 27, 9, '-', '*']], [0.747024716299258, [19, 32, 7, '+', '*']], [0.7470223726334836, [34, 22, 31, '+', '*']], [0.74701534163616, [23, 2, 10, '+', '*']], [0.7469520626602479, [19, 38, 29, '-', '*']], [0.7469192513394051, [28, 7, 19, '-', '*']], [0.7467880060560323, [23, 12, 2, '-', '*']], [0.7466450424437872, [10, 6, 2, '-', '*']], [0.7466450424437872, [16, 11, 24, '-', '*']], [0.7466052001256205, [37, 0, 13, '+', '*']], [0.7464762985080224, [6, 37, 5, '-', '*']], [0.7463356785615517, [31, 20, 38, '+', '*']], [0.7462583375909928, [34, 31, 37, '+', '*']], [0.7461294359733948, [22, 10, 2, '+', '*']], [0.7460474076712867, [37, 38, 10, '+', '*']], [0.7460145963504436, [6, 20, 1, '-', '*']], [0.7457966354334141, [31, 30, 5, '-', '*']], [0.745777886107218, [28, 9, 4, '-', '*']], [0.7457028888024337, [22, 32, 7, '-', '*']], [0.7456888268077866, [31, 17, 35, '-', '*']], [0.7456513281553944, [34, 34, 1, '-', '*']], [0.745557581524414, [30, 19, 2, '-', '*']], [0.7453630572651295, [27, 3, 20, '-', '*']], [0.7453443079389334, [16, 10, 32, '+', '*']], 
[0.7451286906876784, [5, 5, 11, '+', '*']], [0.7451005666983842, [23, 10, 23, '+', '*']], [0.7450794737064137, [37, 19, 11, '+', '*']], [0.7450771300406391, [29, 11, 10, '-', '*']], [0.7450138510647274, [11, 37, 17, '-', '*']], [0.7449693214150117, [6, 23, 27, '-', '*']], [0.7449646340834626, [25, 27, 23, '+', '*']], [0.7449552594203647, [37, 17, 10, '+', '*']], [0.7448404197974134, [11, 21, 10, '+', '*']], [0.7447443295006585, [20, 31, 20, '+', '*']], [0.7446623011985507, [10, 4, 29, '+', '*']], [0.7445732418991192, [29, 34, 36, '-', '*']], [0.7445474615755996, [20, 25, 20, '+', '*']], [0.7445333995809525, [8, 31, 28, '+', '*']], [0.7444959009285603, [1, 4, 23, '+', '*']], [0.7441232580704131, [37, 7, 27, '+', '*']], [0.7441138834073151, [16, 10, 33, '-', '*']], [0.7440060747816875, [35, 8, 12, '+', '*']], [0.7439404521400013, [5, 25, 20, '+', '*']], [0.7437787392015598, [7, 11, 12, '-', '*']], [0.7437506152122658, [4, 22, 31, '+', '*']], [0.7437388968833933, [31, 1, 38, '+', '*']], [0.7437295222202951, [23, 5, 28, '-', '*']], [0.7436521812497363, [38, 16, 7, '+', '*']], [0.743567809281854, [17, 28, 3, '+', '*']], [0.7435514036214323, [13, 30, 36, '+', '*']], [0.7433404737017264, [6, 20, 4, '-', '*']], [0.7432186030814517, [32, 13, 5, '-', '*']], [0.7431975100894812, [13, 23, 1, '+', '*']], [0.7431295437820203, [37, 23, 22, '-', '*']], [0.7430850141323047, [5, 7, 3, '+', '*']], [0.7430592338087851, [35, 32, 28, '+', '*']], [0.7430404844825889, [36, 13, 20, '+', '*']], [0.7429678308435791, [17, 15, 26, '-', '*']], [0.7429490815173829, [16, 31, 16, '+', '*']], [0.7429022082018928, [17, 25, 8, '-', '*']], [0.7428834588756966, [22, 23, 17, '+', '*']], [0.7427287769345788, [38, 26, 7, '-', '*']], [0.7426162809774024, [28, 15, 35, '+', '*']], [0.7423842580657257, [38, 30, 28, '+', '*']], [0.7419483362316667, [16, 10, 33, '+', '*']], [0.741894431918853, [16, 36, 38, '-', '*']], [0.7418334966087156, [31, 36, 3, '+', '*']], [0.741711625988441, [4, 35, 31, '+', '*']], 
[0.7415616313788723, [38, 6, 3, '-', '*']], [0.7414327297612743, [4, 10, 23, '+', '*']], [0.7411608645314309, [34, 1, 24, '-', '*']], [0.7411608645314309, [10, 17, 26, '-', '*']], [0.7411163348817151, [5, 31, 31, '-', '*']], [0.7409100922935582, [22, 6, 13, '+', '*']], [0.7407671286813131, [30, 1, 27, '+', '*']], [0.7406733820503327, [32, 36, 19, '-', '*']], [0.7406452580610385, [23, 37, 26, '-', '*']], [0.7402983955264109, [32, 1, 32, '+', '*']], [0.7402280855531755, [31, 1, 38, '-', '*']],[0.7868506288055273, [37, 28, 39, '-', '*']], [0.781146146310367, [28, 10, 35, '+', '*']], [0.7793368363324443, [23, 28, 6, '-', '*']], [0.7737659427864311, [28, 6, 32, '+', '*']], [0.7724300532949597, [17, 6, 11, '+', '*']], [0.7716027392765572, [10, 39, 11, '-', '*']], [0.770941825528145, [20, 17, 39, '-', '*']], [0.7707355829399881, [25, 28, 25, '-', '*']], [0.7697371813200464, [17, 0, 10, '-', '*']], [0.7691325155502224, [22, 28, 36, '+', '*']], [0.768539568109271, [28, 17, 34, '-', '*']], [0.7682161422323885, [17, 40, 23, '-', '*']], [0.7681950492404179, [28, 4, 11, '-', '*']], [0.7676700681069275, [28, 30, 37, '-', '*']], [0.767508355168486, [37, 34, 10, '+', '*']], [0.7663576152732011, [23, 34, 37, '+', '*']], [0.7653873376425535, [34, 1, 1, '-', '*']], [0.7652420303645338, [3, 35, 28, '+', '*']], [0.7650756300945436, [28, 32, 2, '-', '*']], [0.7649420411453964, [28, 21, 30, '+', '*']], [0.7642412850788176, [28, 32, 30, '+', '*']], [0.763922546533484, [39, 37, 39, '-', '*']], [0.7638616112233466, [17, 3, 39, '+', '*']], [0.7634913120309738, [32, 29, 20, '-', '*']], [0.7631257001701501, [37, 37, 0, '+', '*']], [0.7627975869617185, [25, 26, 1, '+', '*']], [0.76257493871314, [7, 28, 13, '-', '*']], [0.7621905775261202, [19, 37, 23, '-', '*']], [0.7620194899245809, [9, 21, 27, '-', '*']], [0.7616632527268551, [13, 37, 27, '+', '*']], [0.7615132581172863, [28, 32, 38, '-', '*']], [0.7613187338580021, [11, 27, 32, '+', '*']], [0.7612718605425117, [37, 31, 39, '+', '*']], 
[0.7609648403260508, [31, 3, 25, '+', '*']], [0.7608359387084527, [31, 25, 25, '-', '*']], [0.7607023497593055, [3, 37, 20, '-', '*']], [0.7603320505669329, [19, 17, 23, '-', '*']], [0.7600039373585011, [26, 40, 7, '-', '*']], [0.7598515990831579, [17, 36, 19, '+', '*']], [0.7598117567649914, [29, 13, 31, '-', '*']], [0.7598094130992167, [39, 31, 13, '+', '*']], [0.759621919837256, [20, 16, 25, '-', '*']], [0.7593734912651576, [28, 9, 27, '-', '*']], [0.7592000599978438, [34, 11, 35, '-', '*']], [0.7590453780567261, [17, 17, 35, '+', '*']], [0.7588766341209613, [27, 0, 27, '+', '*']], [0.7588602284605397, [22, 20, 0, '+', '*']], [0.7588485101316672, [5, 37, 38, '+', '*']], [0.7587664818295592, [17, 30, 40, '+', '*']], [0.7586469548750591, [24, 35, 27, '-', '*']], [0.7586422675435102, [34, 29, 34, '-', '*']], [0.7586422675435102, [35, 35, 29, '-', '*']], [0.7584735236077452, [32, 11, 19, '-', '*']], [0.7582368133645198, [20, 17, 28, '-', '*']], [0.7581993147121274, [25, 32, 13, '-', '*']], [0.7581125990784706, [31, 37, 41, '+', '*']], [0.7580915060865001, [24, 17, 2, '+', '*']], [0.7579110438218626, [1, 36, 31, '-', '*']], [0.7577094885652547, [26, 4, 40, '+', '*']], [0.7576227729315976, [3, 3, 21, '+', '*']], [0.7573438767044309, [19, 40, 10, '-', '*']], [0.7565728106646168, [27, 4, 19, '-', '*']], [0.7564603147074402, [37, 9, 23, '-', '*']], [0.7563993793973031, [20, 23, 12, '+', '*']], [0.7563290694240676, [38, 31, 20, '-', '*']], [0.7562236044642145, [31, 26, 38, '-', '*']], [0.7559775195578908, [26, 30, 39, '-', '*']], [0.7558111192879006, [37, 36, 31, '+', '*']], [0.7557548713093124, [24, 5, 16, '-', '*']], [0.7556892486676261, [39, 10, 18, '-', '*']], [0.755574409044675, [6, 24, 28, '-', '*']], [0.7555228483976357, [19, 1, 38, '-', '*']], [0.755473631416371, [13, 28, 17, '-', '*']], [0.7553517607960964, [32, 10, 6, '-', '*']], [0.7552322338415964, [31, 1, 16, '-', '*']], [0.7551502055394885, [23, 20, 16, '+', '*']], [0.7551009885582237, [9, 19, 2, '+', '*']], 
[0.7550963012266747, [10, 20, 31, '+', '*']], [0.7550588025742825, [29, 22, 36, '-', '*']], [0.7547822500128901, [27, 24, 27, '+', '*']], [0.7546744413872626, [31, 26, 7, '+', '*']], [0.7546744413872626, [31, 30, 26, '+', '*']], [0.7545525707669881, [37, 32, 19, '-', '*']], [0.7543955451600959, [8, 19, 24, '-', '*']], [0.7543486718446055, [32, 26, 16, '+', '*']], [0.7537369750774582, [12, 19, 11, '+', '*']], [0.7535635438101442, [6, 12, 16, '+', '*']], [0.7532518362621344, [17, 40, 36, '+', '*']], [0.7532213686070657, [35, 37, 22, '+', '*']], [0.7531674642942519, [26, 20, 25, '-', '*']], [0.7530971543210165, [6, 13, 40, '+', '*']], [0.7530151260189086, [31, 10, 22, '+', '*']], [0.7529471597114479, [20, 6, 6, '-', '*']], [0.7529002863959576, [10, 3, 30, '-', '*']], [0.7528159144280753, [30, 20, 16, '+', '*']], [0.7527362297917417, [1, 38, 10, '-', '*']], [0.7525206125404867, [12, 1, 27, '-', '*']], [0.752443271569928, [13, 20, 4, '+', '*']], [0.7521456260165651, [19, 32, 36, '+', '*']], [0.751934696096859, [20, 30, 31, '+', '*']], [0.751315968332388, [35, 11, 16, '-', '*']], [0.7512128470383096, [19, 25, 1, '+', '*']], [0.7511683173885938, [13, 7, 11, '-', '*']], [0.7509456691400153, [10, 23, 11, '-', '*']], [0.7508472351774858, [38, 13, 7, '+', '*']], [0.7507909871988976, [11, 8, 23, '+', '*']], [0.7506972405679171, [13, 7, 26, '+', '*']], [0.7506573982497504, [23, 13, 39, '+', '*']], [0.7506503672524267, [11, 15, 35, '-', '*']], [0.7504745923193384, [10, 31, 5, '+', '*']], [0.7503199103782208, [26, 26, 36, '+', '*']], [0.750151166442456, [36, 13, 27, '+', '*']], [0.7501253861189363, [3, 27, 10, '+', '*']], [0.7501206987873873, [1, 18, 9, '+', '*']], [0.7499683605120442, [26, 3, 36, '+', '*']], [0.7498722702152891, [5, 3, 34, '+', '*']], [0.7498277405655734, [10, 2, 27, '+', '*']], [0.7497644615896616, [10, 23, 30, '+', '*']], [0.7497129009426224, [22, 29, 25, '-', '*']], [0.7495582190015047, [36, 17, 13, '-', '*']], [0.7489652715605533, [19, 3, 16, '+', '*']], 
[0.7487613726381708, [17, 15, 22, '+', '*']], [0.7484496650901609, [24, 26, 26, '-', '*']], [0.7480418672453959, [0, 8, 2, '+', '*']], [0.7480137432561018, [25, 3, 7, '-', '*']], [0.7475942270824643, [13, 23, 20, '-', '*']], [0.7473809534969837, [39, 2, 37, '+', '*']], [0.7468676906923658, [39, 8, 19, '+', '*']], [0.7465255154892871, [25, 1, 30, '+', '*']], [0.7463989575374635, [3, 19, 36, '+', '*']], [0.7463825518770419, [5, 30, 3, '-', '*']], [0.7460450640055123, [8, 13, 1, '+', '*']], [0.7460169400162181, [40, 27, 41, '-', '*']], [0.7459114750563651, [4, 34, 38, '+', '*']], [0.7459044440590417, [30, 9, 18, '-', '*']], [0.7459044440590417, [7, 9, 18, '-', '*']], [0.7454732095565315, [39, 35, 38, '-', '*']], [0.7454685222249824, [35, 8, 37, '+', '*']], [0.7446740195274233, [9, 21, 16, '-', '*']], [0.7446458955381291, [22, 30, 36, '-', '*']], [0.7446294898777076, [1, 41, 9, '-', '*']], [0.7445873038937664, [10, 4, 38, '+', '*']], [0.7444091852949036, [13, 14, 24, '+', '*']], [0.7439685761292953, [37, 2, 28, '+', '*']], [0.743682648904805, [15, 19, 16, '+', '*']], [0.7436357755893149, [3, 40, 38, '-', '*']], [0.7433264117070792, [34, 3, 26, '-', '*']], [0.7432209467472262, [35, 27, 6, '+', '*']], [0.7431061071242752, [34, 38, 13, '+', '*']], [0.7429397068542849, [31, 9, 6, '-', '*']], [0.7428717405468241, [16, 23, 20, '+', '*']], [0.7428201798997849, [10, 27, 30, '-', '*']], [0.7427639319211965, [19, 9, 11, '+', '*']], [0.742449880707412, [3, 6, 22, '+', '*']], [0.7420514575257452, [8, 22, 19, '+', '*']], [0.7419413052343432, [28, 32, 12, '+', '*']], [0.7418756825926568, [15, 1, 11, '+', '*']], [0.7418499022691373, [10, 9, 3, '-', '*']], [0.7417444373092843, [3, 35, 16, '+', '*']], [0.741587411702392, [22, 14, 8, '-', '*']], [0.7415522567157743, [22, 38, 21, '+', '*']], [0.7408702499753914, [13, 9, 21, '+', '*']], [0.7405679170904795, [26, 34, 39, '-', '*']], [0.7400171556334694, [35, 5, 3, '+', '*']], [0.7398296623715086, [7, 25, 12, '-', '*']], 
[0.7397874763875674, [10, 4, 11, '+', '*']], [0.7396937297565869, [36, 19, 38, '-', '*']], [0.7395718591363123, [29, 2, 37, '-', '*']], [0.7389578187033903, [3, 37, 9, '-', '*']], [0.7387422014521353, [8, 5, 24, '+', '*']], [0.7387375141205863, [30, 14, 2, '-', '*']], [0.7384398685672234, [4, 36, 3, '-', '*']], [0.7383461219362428, [19, 40, 41, '+', '*']], [0.7376641151958603, [7, 34, 13, '-', '*']], [0.7373430329847521, [12, 19, 20, '+', '*']], [0.737204756704056, [35, 2, 23, '-', '*']], [0.7371227284019481, [2, 14, 32, '-', '*']], [0.7370430437656147, [4, 10, 40, '+', '*']], [0.7370383564340657, [38, 0, 3, '+', '*']], [0.7369867957870264, [29, 6, 8, '-', '*']], [0.7367805531988694, [25, 9, 2, '-', '*']], [0.7365813416080359, [38, 40, 14, '+', '*']], [0.7365789979422614, [4, 40, 37, '+', '*']], [0.736532124626771, [24, 32, 1, '+', '*']], [0.7365227499636732, [30, 35, 39, '+', '*']], [0.7364782203139574, [31, 15, 11, '-', '*']], [0.7360118308248298, [34, 12, 38, '+', '*']], [0.735737621929212, [32, 6, 8, '-', '*']], [0.735664968290202, [10, 5, 13, '-', '*']], [0.7351048321700939, [1, 14, 2, '+', '*']], [0.7349899925471429, [29, 31, 7, '-', '*']]]

features.sort(reverse=True, key=lambda x: x[0])
print(features)

[[0.7868506288055273, [37, 28, 39, '-', '*']], [0.7849194482073301, [24, 21, 37, '-', '*']], [0.781146146310367, [28, 10, 35, '+', '*']], [0.7801946180059154, [3, 21, 27, '+', '*']], [0.7793368363324443, [23, 28, 6, '-', '*']], [0.7779868848463258, [28, 32, 6, '-', '*']], [0.7775111206941001, [28, 21, 27, '+', '*']], [0.774595600470608, [28, 19, 38, '-', '*']], [0.7744456058610394, [21, 1, 21, '-', '*']], [0.774030777018951, [7, 21, 27, '+', '*']], [0.7737659427864311, [28, 6, 32, '+', '*']], [0.7734425169095486, [13, 28, 19, '+', '*']], [0.7727628538349403, [4, 32, 28, '-', '*']], [0.7727112931879011, [28, 0, 37, '-', '*']], [0.7724300532949597, [17, 6, 11, '+', '*']], [0.7723363066639791, [17, 38, 22, '-', '*']], [0.7722566220276459, [17, 30, 38, '-', '*']], [0.7720972527549791, [17, 37, 20, '-', '*']], [0.7716027392765572, [10, 39, 11, '-', '*']], [0.7712840007312237, [28, 24, 13, '-', '*']], [0.7712816570654493, [17, 13, 37, '-', '*']], [0.7711761921055961, [28, 7, 38, '-', '*']], 
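This part of the notebook does not say what the scores in `features` mean. Entries such as `[31, 13, 3, '-', '*']` and `[31, 3, 13, '-', '*']` have exactly the same score. That fits a reading (an assumption, not confirmed by the source) in which `[i, j, k, op1, op2]` stands for the column x_i op1 (x_j op2 x_k), and the score is the ROC AUC of that single column. The sketch below illustrates this assumed procedure. The names `OPS`, `pairwise_auc`, `build_feature`, `score_feature`, `rank_candidates` and the direction-agnostic `max(auc, 1 - auc)` are hypothetical and are not the notebook's own code.

```python
import operator

import numpy as np

# Hypothetical mapping from the operator symbols stored in a feature spec.
OPS = {"+": operator.add, "-": operator.sub, "*": operator.mul}


def pairwise_auc(y_true, scores):
    """ROC AUC as P(score of a positive > score of a negative), ties counted as 1/2."""
    y_true = np.asarray(y_true, dtype=bool)
    scores = np.asarray(scores, dtype=float)
    pos = scores[y_true][:, None]   # column of positive-class scores
    neg = scores[~y_true][None, :]  # row of negative-class scores
    return float(np.mean((pos > neg) + 0.5 * (pos == neg)))


def build_feature(X, spec):
    """Assumed reading: [i, j, k, op1, op2] -> X[:, i] op1 (X[:, j] op2 X[:, k])."""
    i, j, k, op1, op2 = spec
    return OPS[op1](X[:, i], OPS[op2](X[:, j], X[:, k]))


def score_feature(X, y, spec):
    """Direction-agnostic single-column AUC (assumption: the sign of the column is ignored)."""
    auc = pairwise_auc(y, build_feature(X, spec))
    return max(auc, 1.0 - auc)


def rank_candidates(X, y, specs):
    """Return [score, spec] pairs sorted best-first, using the same sort as the cell above."""
    scored = [[score_feature(X, y, spec), spec] for spec in specs]
    scored.sort(reverse=True, key=lambda x: x[0])
    return scored
```

`pairwise_auc` compares every positive with every negative. For roughly 3000 training rows this means a few million comparisons, which is still cheap. A rank-based formula scales better on larger data.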

In [11]:
print(len(features))

554
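Some of the entries in `features` come in pairs that differ only in the order of the last two indices and have identical scores, for example `[31, 13, 3, '-', '*']` and `[31, 3, 13, '-', '*']`. Suppose such a spec means x_i op1 (x_j op2 x_k), which is an assumption. Then swapping j and k leaves the column unchanged whenever `op2` is commutative, so each such pair contributes redundant columns to the model input. The hypothetical helpers below keep only the first entry for each canonical form.

```python
COMMUTATIVE = {"+", "*"}  # operators for which x_j op x_k == x_k op x_j


def canonical_key(spec):
    """Hypothetical canonical form: order j and k when op2 is commutative."""
    i, j, k, op1, op2 = spec
    if op2 in COMMUTATIVE:
        j, k = sorted((j, k))
    return (i, j, k, op1, op2)


def deduplicate(features):
    """Keep the first (best-scored, if the list is sorted) entry per canonical feature."""
    seen = set()
    unique = []
    for score, spec in features:
        key = canonical_key(spec)
        if key not in seen:
            seen.add(key)
            unique.append([score, spec])
    return unique


# Entries copied from the `features` list above.
pairs = [
    [0.7554220707693319, [31, 13, 3, '-', '*']],
    [0.7554220707693319, [31, 3, 13, '-', '*']],
    [0.7525042068800651, [25, 2, 3, '-', '*']],
    [0.7525042068800651, [25, 3, 2, '-', '*']],
]
print(len(deduplicate(pairs)))  # 2
```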


In [12]:
X_train, y_train = to_numpy(train, features)
X_eval, y_eval = to_numpy(valid, features)

print(X_train.shape)
print(y_train.shape)


(2967, 4432)
(2967,)
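`X_train` has 4432 columns, while `features` has 554 entries. Because 4432 = 554 × 8, each feature spec appears to expand into eight columns. `to_numpy` is defined earlier in the notebook, so this section does not show what those eight columns contain. A common way to get a fixed number of columns per feature is to summarize a variable-length per-token sequence (for example, per-token probabilities) with fixed statistics. The sketch below illustrates only that general pattern. `STATS`, `summarize` and `to_matrix` are hypothetical, and the choice of eight statistics is arbitrary.

```python
import numpy as np

# Hypothetical choice of eight summary statistics; the notebook's to_numpy may differ.
STATS = (
    np.mean,
    np.std,
    np.min,
    np.max,
    np.median,
    lambda v: v[0],            # value at the first token
    lambda v: v[-1],           # value at the last token
    lambda v: float(len(v)),   # sequence length
)


def summarize(values):
    """Map a variable-length sequence of per-token values to a fixed-length vector."""
    v = np.asarray(values, dtype=float)
    return np.array([stat(v) for stat in STATS], dtype=float)


def to_matrix(samples):
    """Each sample is a list of per-feature sequences.

    The result has n_features * len(STATS) columns, e.g. 554 * 8 = 4432.
    """
    return np.stack([
        np.concatenate([summarize(seq) for seq in sample])
        for sample in samples
    ])
```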


In [13]:
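# 50 boosting rounds of depth-1 trees (decision stumps) with shrinkage 0.3
xgboost = xgb.XGBClassifier(n_estimators=50, max_depth=1, learning_rate=0.3)
xgboost.fit(X_train, y_train)
# Keep only the predicted probability of class 1 for each validation example
y_pred = xgboost.predict_proba(X_eval)[:, 1]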

roc_auc = roc_auc_score(y_eval, y_pred)

print(f'ROC AUC: {roc_auc:.4f}')

ROC AUC: 0.8139
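With `max_depth=1`, each of the 50 trees is a decision stump: one split on one column, with two leaves. The classifier's output is therefore a sigmoid applied to a sum of step functions, and each step function depends on a single column. The sketch below performs that computation with hand-written, hypothetical stumps. In XGBoost, the stored leaf values already include the `learning_rate` shrinkage, and the starting margin comes from `base_score`. The numbers here are for illustration only and are not taken from the fitted model.

```python
import math
from dataclasses import dataclass


@dataclass
class Stump:
    """One depth-1 tree: a single split on one column, two leaf values."""
    column: int
    threshold: float
    left_value: float   # used when x[column] < threshold
    right_value: float  # used otherwise

    def __call__(self, x):
        return self.left_value if x[self.column] < self.threshold else self.right_value


def predict_proba_stumps(x, stumps, base_margin=0.0):
    """Positive-class probability: sigmoid(base_margin + sum of leaf values)."""
    margin = base_margin + sum(stump(x) for stump in stumps)
    return 1.0 / (1.0 + math.exp(-margin))


# Hypothetical ensemble; shrinkage is assumed to be already folded into the leaf values.
stumps = [Stump(0, 0.5, -0.3, 0.3), Stump(2, 1.0, 0.2, -0.2)]
print(predict_proba_stumps([0.9, 0.0, 0.4], stumps))  # sigmoid(0.3 + 0.2) ≈ 0.6225
```

A depth-1 ensemble is additive in its input columns. It cannot model an interaction between two inputs unless that interaction is supplied as an explicit column.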


In [14]:
def predict_hallucinations(sample):
    # TODO: Run your model or algorithm on this dataset.
    # TODO: Return a list of probabilities for each example in the dataset.

    # Build the feature vector for one sample, add a batch dimension of size 1,
    # and return the predicted probability of class 1 (an array of shape (1,)).
    prediction = xgboost.predict_proba(np.expand_dims(individual_sample(sample, features), axis=0))[:, 1]
    return prediction
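`predict_hallucinations` scores one sample per `predict_proba` call. To score a whole split offline, a common alternative is to stack all feature vectors and call `predict_proba` once. The sketch below shows this approach. `predict_batch` is a hypothetical helper. Its `featurize` argument stands in for `individual_sample(sample, features)`, and `model` can be any object with a scikit-learn-style `predict_proba`, such as the fitted `XGBClassifier`.

```python
import numpy as np


def predict_batch(model, featurize, samples):
    """Score many samples with a single predict_proba call.

    `featurize` maps one sample to a 1-D feature vector (hypothetical stand-in
    for individual_sample(sample, features)).
    """
    X = np.stack([featurize(sample) for sample in samples])  # shape (n_samples, n_columns)
    return model.predict_proba(X)[:, 1]  # probability of class 1, as in the cell above


# Hypothetical usage in this notebook, assuming `valid` is an iterable of samples:
# probs = predict_batch(xgboost, lambda s: individual_sample(s, features), valid)
```

For a single sample, `predict_batch(model, featurize, [sample])` returns the same one-element array that `predict_hallucinations` returns.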

# Evaluation

Running the cell below shows how many points your solution would score on the validation data. Before submitting, make sure the whole notebook runs from start to finish without errors and without any user intervention after you choose "Run All".

In [15]:
if not FINAL_EVALUATION_MODE:
    roc_auc = evaluate_algorithm(valid, predict_hallucinations, verbose=True)


Liczba próbek: 990
ROC AUC: 0.8139
Wynik punktowy: 95


During grading, the model will be saved as `your_model.pkl` and evaluated on the test set.
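The submitted artifact is the `predict_hallucinations` function, serialized with `cloudpickle.dump`. You can check locally that the file loads and runs with the standard `pickle` module. This works only if `cloudpickle` and every library the function uses (here `numpy` and `xgboost`) can be imported in the loading environment. The helper name `load_prediction_function` is hypothetical.

```python
import pickle


def load_prediction_function(path):
    """Load a function written with cloudpickle.dump (readable by pickle.load)."""
    with open(path, "rb") as f:
        return pickle.load(f)


# Hypothetical local check after running with FINAL_EVALUATION_MODE = True,
# assuming `valid` is an indexable list of samples:
# predict = load_prediction_function("file_output/your_model.pkl")
# print(predict(valid[0]))
```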

In [16]:
######################### DO NOT CHANGE THIS CELL WHEN SUBMITTING ##########################
if FINAL_EVALUATION_MODE:
    import cloudpickle

    OUTPUT_PATH = "file_output"
    FUNCTION_FILENAME = "your_model.pkl"
    FUNCTION_OUTPUT_PATH = os.path.join(OUTPUT_PATH, FUNCTION_FILENAME)

    if not os.path.exists(OUTPUT_PATH):
        os.makedirs(OUTPUT_PATH)

    with open(FUNCTION_OUTPUT_PATH, "wb") as f:
        cloudpickle.dump(predict_hallucinations, f)