## 1. Setup & Environment

In [None]:
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

from itertools import combinations

from lightgbm import LGBMClassifier

from sklearn.preprocessing import OrdinalEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, precision_score, recall_score
from sklearn.inspection import permutation_importance

cuda_available = torch.cuda.is_available()

# Report the interpreter, the torch build and whether a GPU is usable.
print(f'Python: {sys.executable}')
print(f'PyTorch: {torch.__version__}')
print(f'CUDA built with: {torch.version.cuda}')
print(f'CUDA available: {cuda_available}')

# Tensors and models later in the notebook are moved here with `.to(device)`.
device = torch.device('cuda') if cuda_available else torch.device('cpu')
print(f'Using device: {device}')

# Dengue Classification — Neural Network

Binary classification of dengue notifications (or chikungunya notifications, with `type_disease = 'chikungunya'`) as **confirmed** (1) or **discarded** (0). The model uses clinical and demographic data from Brazil's SINAN surveillance system (2017–2019).

**Dataset:** ~3.3M records across 3 years  
**Target:** `final_classification` → 0 = discarded (class 5), 1 = confirmed (dengue: classes 10, 11, 12; chikungunya: class 13)  
**Features used:** self-reportable symptoms, demographics, comorbidities, hemorrhagic signs


## 2. Data Loading

### 2.1 Load CSVs and merge years

In [None]:
# Which SINAN dataset to model: 'dengue' or 'chikungunya'.
type_disease = 'chikungunya'  # Change to 'dengue' for the Dengue dataset

data_dir = "C:\\Users\\angej\\Documents\\2_Programação\\health_index_project\\data"

# File-name prefix of the yearly CSV exports for each supported disease.
file_prefixes = {'dengue': 'DENGBR', 'chikungunya': 'CHIKBR'}
if type_disease not in file_prefixes:
    # Fail loudly instead of leaving df1/df2/df4 undefined for the next cell.
    raise ValueError(f"Unsupported type_disease {type_disease!r}; expected 'dengue' or 'chikungunya'")
prefix = file_prefixes[type_disease]

# One CSV per notification year (2017, 2018, 2019).
df1 = pd.read_csv(f"{data_dir}\\{prefix}17.csv", low_memory=False)
df4 = pd.read_csv(f"{data_dir}\\{prefix}18.csv", low_memory=False)
df2 = pd.read_csv(f"{data_dir}\\{prefix}19.csv", low_memory=False)

### 2.2 Downcast dtypes to reduce memory usage

In [None]:
df = pd.concat([df1, df2, df4], ignore_index=True)

# Narrow every 64-bit numeric column to the smallest dtype that still holds
# its values; with millions of rows this saves a lot of memory. Integer
# downcasting never yields float64, so both selections can be taken up front.
int_cols = df.select_dtypes("int64").columns
float_cols = df.select_dtypes("float64").columns

for cols, kind in ((int_cols, "integer"), (float_cols, "float")):
    df[cols] = df[cols].apply(pd.to_numeric, downcast=kind)

## 3. Feature Engineering

### 3.1 Rename columns to English

In [None]:
# Map of SINAN (Portuguese) column codes to descriptive English names.
# Codes that are absent from a given CSV are simply ignored by `rename`.
sinan_to_english = {
    # --- Notification info ---
    'TP_NOT': 'notification_type',  # type of notification
    'ID_AGRAVO': 'disease_code',  # SINAN/ICD disease code
    'DT_NOTIFIC': 'notification_date',  # date the case was reported
    'SEM_NOT': 'notification_epi_week',  # epidemiological week of notification
    'NU_ANO': 'notification_year',  # year the case was reported
    'SG_UF_NOT': 'notif_state',  # state (UF) of notification
    'ID_MUNICIP': 'notif_municipality',  # municipality of notification
    'ID_REGIONA': 'notif_health_region',  # health region of notification
    'ID_UNIDADE': 'health_facility',  # facility that filed the notification
    'DT_SIN_PRI': 'symptom_onset_date',  # date of first symptoms
    'SEM_PRI': 'symptom_epi_week',  # epidemiological week of first symptoms

    # --- Patient demographics ---
    'ANO_NASC': 'birth_year',  # year of birth
    'NU_IDADE_N': 'age',  # age, encoded with a unit prefix (days/months/years)
    'CS_SEXO': 'sex',  # M / F / I (ignored)
    'CS_GESTANT': 'pregnancy_status',  # pregnancy trimester or N/A
    'CS_RACA': 'race',  # race / ethnicity
    'CS_ESCOL_N': 'education_level',  # education level
    'ID_OCUPA_N': 'occupation_code',  # occupation (CBO code)

    # --- Patient residence ---
    'SG_UF': 'residence_state',
    'ID_MN_RESI': 'residence_municipality',
    'ID_RG_RESI': 'residence_health_region',
    'ID_PAIS': 'residence_country',

    # --- Symptoms (1=Yes, 2=No, 9=Unknown) ---
    'FEBRE': 'fever',
    'MIALGIA': 'myalgia',
    'CEFALEIA': 'headache',
    'EXANTEMA': 'rash',
    'VOMITO': 'vomiting',
    'NAUSEA': 'nausea',
    'DOR_COSTAS': 'back_pain',
    'CONJUNTVIT': 'conjunctivitis',
    'ARTRITE': 'arthritis',
    'ARTRALGIA': 'joint_pain',
    'PETEQUIA_N': 'petechiae',  # small red/purple skin spots
    'LEUCOPENIA': 'leucopenia',  # low white blood cell count
    'LACO': 'tourniquet_test',  # tourniquet test (prova do laço)
    'DOR_RETRO': 'retro_orbital_pain',  # pain behind the eyes

    # --- Comorbidities (1=Yes, 2=No, 9=Unknown) ---
    'DIABETES': 'diabetes',
    'HEMATOLOG': 'blood_disorder',
    'HEPATOPAT': 'liver_disease',
    'RENAL': 'kidney_disease',
    'HIPERTENSA': 'hypertension',
    'ACIDO_PEPT': 'peptic_ulcer',
    'AUTO_IMUNE': 'autoimmune_disease',

    # --- Chikungunya lab tests ---
    'DT_CHIK_S1': 'chik_test1_date',
    'DT_CHIK_S2': 'chik_test2_date',
    'RES_CHIKS1': 'chik_test1_result',
    'RES_CHIKS2': 'chik_test2_result',
    'DT_PRNT': 'prnt_date',  # plaque reduction neutralization test
    'RESUL_PRNT': 'prnt_result',

    # --- Dengue lab tests ---
    'DT_SORO': 'serology_date',  # serology (IgM)
    'RESUL_SORO': 'serology_result',
    'DT_NS1': 'ns1_test_date',  # NS1 antigen test
    'RESUL_NS1': 'ns1_result',
    'DT_VIRAL': 'viral_isolation_date',
    'RESUL_VI_N': 'viral_isolation_result',
    'DT_PCR': 'pcr_date',  # RT-PCR
    'RESUL_PCR_': 'pcr_result',
    'SOROTIPO': 'serotype',  # DENV-1..4
    'HISTOPA_N': 'histopathology',
    'IMUNOH_N': 'immunohistochemistry',

    # --- Hospitalization ---
    'HOSPITALIZ': 'hospitalized',
    'DT_INTERNA': 'hospitalization_date',
    'UF': 'hospital_state',
    'MUNICIPIO': 'hospital_municipality',

    # --- Infection origin ---
    'TPAUTOCTO': 'autochthonous_case',  # local vs. imported infection
    'COUFINF': 'infection_state',
    'COPAISINF': 'infection_country',
    'COMUNINF': 'infection_municipality',

    # --- Classification & outcome ---
    'CLASSI_FIN': 'final_classification',  # final diagnosis
    'CRITERIO': 'confirmation_criteria',  # lab / clinical / epidemiological
    'DOENCA_TRA': 'work_related',
    'CLINC_CHIK': 'chik_clinical_form',  # acute / subacute / chronic
    'EVOLUCAO': 'case_outcome',
    'DT_OBITO': 'death_date',
    'DT_ENCERRA': 'case_closure_date',

    # --- Dengue alarm signs (1=Yes, 2=No) ---
    'ALRM_HIPOT': 'alarm_hypotension',
    'ALRM_PLAQ': 'alarm_low_platelets',
    'ALRM_VOM': 'alarm_persistent_vomit',
    'ALRM_SANG': 'alarm_bleeding',
    'ALRM_HEMAT': 'alarm_hematocrit_rise',
    'ALRM_ABDOM': 'alarm_abdominal_pain',
    'ALRM_LETAR': 'alarm_lethargy',
    'ALRM_HEPAT': 'alarm_liver_enlarged',
    'ALRM_LIQ': 'alarm_fluid_accumul',
    'DT_ALRM': 'alarm_signs_date',

    # --- Severe dengue signs (1=Yes, 2=No) ---
    'GRAV_PULSO': 'severe_weak_pulse',
    'GRAV_CONV': 'severe_convulsions',
    'GRAV_ENCH': 'severe_cap_refill',
    'GRAV_INSUF': 'severe_resp_distress',
    'GRAV_TAQUI': 'severe_tachycardia',
    'GRAV_EXTRE': 'severe_cold_extremities',
    'GRAV_HIPOT': 'severe_hypotension',
    'GRAV_HEMAT': 'severe_hematemesis',
    'GRAV_MELEN': 'severe_melena',
    'GRAV_METRO': 'severe_metrorrhagia',
    'GRAV_SANG': 'severe_bleeding',
    'GRAV_AST': 'severe_ast_elevated',
    'GRAV_MIOC': 'severe_myocarditis',
    'GRAV_CONSC': 'severe_altered_consc',
    'GRAV_ORGAO': 'severe_organ_damage',
    'DT_GRAV': 'severity_signs_date',

    # --- Hemorrhagic manifestations ---
    'MANI_HEMOR': 'hemorrhagic_manifest',
    'EPISTAXE': 'nosebleed',
    'GENGIVO': 'gum_bleeding',
    'METRO': 'metrorrhagia',
    'PETEQUIAS': 'petechiae_hemorrh',
    'HEMATURA': 'hematuria',
    'SANGRAM': 'other_bleeding',
    'LACO_N': 'tourniquet_test_hemorrh',
    'PLASMATICO': 'plasma_leakage',
    'EVIDENCIA': 'hemorrhagic_evidence',
    'PLAQ_MENOR': 'platelets_below_100k',
    'CON_FHD': 'dengue_hemorrhagic_fever',
    'COMPLICA': 'complications',

    # --- Administrative / system ---
    'DT_INVEST': 'investigation_date',
    'DT_DIGITA': 'data_entry_date',
    'TP_SISTEMA': 'system_type',
    'NDUPLIC_N': 'duplicate_flag',
    'CS_FLXRET': 'return_flow_flag',
    'FLXRECEBI': 'flow_received',
    'MIGRADO_W': 'migrated_from_windows',
    'DT_NASC': 'birth_date',
}

df = df.rename(columns=sinan_to_english)

### 3.2 Drop columns

Drops are split into two groups:
- **`drop_columns`** — administrative fields, post-classification data, label-leaking features (alarm/severity signs), and fields unavailable at diagnosis time (hospitalization, lab test dates)
- **`lab_drop_columns`** — lab results and confirmation evidence not available during clinical triage

In [None]:
# Columns that must not reach the model: administrative fields, anything
# filled after (or because of) the final classification, and information
# that is not available when a patient first reports symptoms.
drop_columns = [
    # === ADMINISTRATIVE (no predictive value) ===
    'investigation_date',       # filled during/after investigation
    'data_entry_date',          # when the record was typed into the system; raw date, not a triage feature
    'duplicate_flag',           # system control field
    'return_flow_flag',         # system control field
    'flow_received',            # system control field
    'system_type',              # system control field
    'migrated_from_windows',    # system control field (legacy-system migration flag)
    'notification_type',        # administrative notification type
    'notification_epi_week',    # notification timing; symptom-onset timing features are used instead

    # === POST-CLASSIFICATION (filled after or because of final_classification) ===
    'confirmation_criteria',    # directly tied to classification (lab, clinical, epidemiological)
    'case_closure_date',        # required when classification is filled
    'case_outcome',             # outcome recorded after classification (cure, death, etc.)
    'death_date',               # post-outcome
    'work_related',             # enabled only if classification=1, cleared if classification=2
    'chik_clinical_form',       # filled only for confirmed Chikungunya (classification=13): leaks the label

    # === INFECTION ORIGIN (filled only when classification=confirmed, cleared on discard) ===
    'autochthonous_case',
    'infection_state',
    'infection_country',
    'infection_municipality',

    # === ALARM SIGNS (filled only when classification=11 or 12, leaks the label directly) ===
    'alarm_hypotension',
    'alarm_low_platelets',
    'alarm_persistent_vomit',
    'alarm_bleeding',
    'alarm_hematocrit_rise',
    'alarm_abdominal_pain',
    'alarm_lethargy',
    'alarm_liver_enlarged',
    'alarm_fluid_accumul',
    'alarm_signs_date',

    # === SEVERITY SIGNS (filled only when classification=12, leaks the label directly) ===
    'severe_weak_pulse',
    'severe_convulsions',
    'severe_cap_refill',
    'severe_resp_distress',
    'severe_tachycardia',
    'severe_cold_extremities',
    'severe_hypotension',
    'severe_hematemesis',
    'severe_melena',
    'severe_metrorrhagia',
    'severe_bleeding',
    'severe_ast_elevated',
    'severe_myocarditis',
    'severe_altered_consc',
    'severe_organ_damage',
    'severity_signs_date',

    # === DHF / COMPLICATIONS (old classification system, directly informs final_classification) ===
    'dengue_hemorrhagic_fever', # confirmed DHF = classification decision
    'complications',            # dengue with complications = classification decision

    # === CHIKUNGUNYA LAB TESTS (lab evidence, not available at triage time) ===
    'chik_test1_date',
    'chik_test2_date',
    'prnt_date',
    'chik_test1_result',
    'chik_test2_result',
    'prnt_result',

    # === HOSPITALIZATION (post-assessment decision, not available at diagnosis time) ===
    'hospitalized',             # decision made after clinical evaluation
    'hospitalization_date',     # only filled if hospitalized
    'hospital_state',           # only filled if hospitalized
    'hospital_municipality',    # only filled if hospitalized

    # === DATE FIELDS (raw dates are not used as model inputs) ===
    'notification_year',        # redundant with notification_date
    'serology_date',            # lab test date
    'ns1_test_date',            # lab test date
    'viral_isolation_date',     # lab test date
    'pcr_date',                 # lab test date

    # === NOT SELF-REPORTABLE (require clinical procedure or lab exam) ===
    'leucopenia',               # blood test
    'tourniquet_test',          # clinical procedure (prova do laço)
    'tourniquet_test_hemorrh',  # clinical procedure (hemorrhagic context)
    'plasma_leakage',           # clinical evaluation
    'platelets_below_100k',     # blood test
    'hemorrhagic_evidence',     # clinical evaluation

    # === GEOGRAPHICAL (not available/useful in a self-reported questionnaire) ===
    'notif_state',
    'notif_municipality',
    'notif_health_region',
    'health_facility',
    'residence_municipality',
    'residence_country',
]

# Lab results / confirmation evidence that would not exist during clinical triage.
lab_drop_columns = [
    'disease_code',
    'serology_result',          # serological test result
    'ns1_result',               # NS1 antigen test result
    'viral_isolation_result',   # viral isolation result
    'pcr_result',               # RT-PCR result
    'serotype',                 # dengue serotype identified
    'histopathology',           # histopathology result
    'immunohistochemistry',     # immunohistochemistry result
    'hemorrhagic_manifest',
]

# errors='ignore' because not every column exists in both the dengue and the
# chikungunya exports.
df = df.drop(columns=drop_columns + lab_drop_columns, errors='ignore')

### 3.3 Date-derived features

- `symptom_month`, `symptom_day` — seasonality signals
- `symptom_month_end`, `symptom_year_end` — boundary flags
- `days_to_notification` — delay between symptom onset and reporting (clipped to 0–90 days)

In [None]:
# Derive features from the notification / symptom-onset dates.
df['notification_date'] = pd.to_datetime(df['notification_date'], errors='coerce')
df['symptom_onset_date'] = pd.to_datetime(df['symptom_onset_date'], errors='coerce')

onset = df['symptom_onset_date'].dt
df['symptom_month'] = onset.month
df['symptom_day'] = onset.day
df['symptom_month_end'] = onset.is_month_end
df['symptom_year_end'] = onset.is_year_end

# Days between symptom onset and notification (critical dengue window: 3-6 days).
# Missing delays get the median, then the value is clipped to 0..90 days.
delay = (df['notification_date'] - df['symptom_onset_date']).dt.days
df['days_to_notification'] = delay.fillna(delay.median()).clip(0, 90)

# Age in years *at notification time*, so it is comparable across the
# 2017-2019 records and between the two disease branches.
if type_disease == 'dengue':
    df['birth_date'] = pd.to_datetime(df['birth_date'], errors='coerce')
    # Previously `2025 - birth_year`, which inflated every age by 6-8 years.
    df['age'] = df['notification_date'].dt.year - df['birth_date'].dt.year

elif type_disease == 'chikungunya':
    # NU_IDADE_N carries a unit prefix in the thousands digit. Per the SINAN
    # data dictionary (confirm): 1=hours, 2=days, 3=months, 4=years; the last
    # three digits are the value. A plain `- 4000` only works for year-coded
    # ages and turns infant ages (e.g. 3006 = 6 months) into large negatives.
    raw_age = df['age']
    unit = raw_age // 1000
    value = raw_age % 1000
    years_per_unit = {1: 1 / (24 * 365.25), 2: 1 / 365.25, 3: 1 / 12, 4: 1.0}
    # Unknown unit prefixes map to NaN; they are imputed with the median later.
    df['age'] = value * unit.map(years_per_unit)

else:
    raise ValueError(f"Unsupported type_disease {type_disease!r}; expected 'dengue' or 'chikungunya'")

# Drop raw dates and the birth fields for both diseases: age now carries that
# information, and birth_year has no NaN imputation downstream.
df = df.drop(
    columns=['birth_date', 'birth_year', 'notification_date', 'symptom_onset_date'],
    errors='ignore',
)

In [None]:
# Boolean-dtype columns (e.g. the symptom_month_end / symptom_year_end flags)
# become 0/1 integers so they can go into the numeric feature tensor.
bools = df.select_dtypes(include=['bool']).columns
for bool_col in bools:
    df[bool_col] = df[bool_col].astype(int)

### 3.4 Encode categorical columns

Uses `OrdinalEncoder` with a `+1` shift, so index `0` is reserved for missing/unknown values. Every value therefore maps to a valid row of its `nn.Embedding` table.

In [None]:
# Columns fed to nn.Embedding rather than used as numeric inputs.
categorical_columns = [
    'sex',
    'pregnancy_status',
    'race',
    'education_level',
    'occupation_code',
    'symptom_month',
    'symptom_day',
    'residence_state',
    'symptom_epi_week',
]

# OrdinalEncoder yields codes 0..K-1. Missing inputs come back as NaN
# (sklearn's default encoded_missing_value). Shifting every code by one frees
# index 0, which then serves as the "unknown/missing" slot in each embedding.
oe = OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value=-1)
shifted_codes = pd.DataFrame(
    oe.fit_transform(df[categorical_columns]) + 1,
    index=df.index,
    columns=categorical_columns,
)
df[categorical_columns] = shifted_codes.fillna(0).astype(int)

### 3.5 Binarize symptom and comorbidity columns

Original encoding: `1=Yes, 2=No, 9=Unknown` → mapped to `1=Yes, 0=No/Unknown/NaN`.

In [None]:
binary_columns = [
    # === SYMPTOMS ===
    'fever', 'myalgia', 'headache', 'rash', 'vomiting', 'nausea',
    'back_pain', 'conjunctivitis', 'arthritis', 'joint_pain',
    'petechiae', 'retro_orbital_pain',

    # === COMORBIDITIES ===
    'diabetes', 'blood_disorder', 'liver_disease', 'kidney_disease',
    'hypertension', 'peptic_ulcer', 'autoimmune_disease',

    # === HEMORRHAGIC MANIFESTATIONS ===
    'nosebleed', 'gum_bleeding', 'metrorrhagia',
    'petechiae_hemorrh', 'hematuria', 'other_bleeding',
]

# SINAN yes/no fields: 1=Yes, 2=No, 9=Unknown. Only an explicit 1 counts as
# present; 2, 9, NaN and any unexpected code all become 0. (The previous
# replace({2: 0, 9: 0}) let any other code through unchanged, so the columns
# were not guaranteed to be binary.)
df[binary_columns] = (df[binary_columns] == 1).astype(int)

# age: fill missing values with the median
df['age'] = df['age'].fillna(df['age'].median())
# NOTE(review): residence_health_region looks like a region identifier, yet it
# is used as a plain numeric feature; consider treating it as categorical.
df['residence_health_region'] = (
    df['residence_health_region'].fillna(df['residence_health_region'].median()).astype(int)
)

### 3.6 Aggregate count features

Summarizes symptom burden, comorbidity burden, and hemorrhagic signs into three scalar counts.

In [None]:
symptom_cols = [
    'fever', 'myalgia', 'headache', 'rash', 'vomiting', 'nausea', 'back_pain', 'conjunctivitis',
    'arthritis', 'joint_pain', 'petechiae', 'retro_orbital_pain',
]
comorbidity_cols = [
    'diabetes', 'blood_disorder', 'liver_disease', 'kidney_disease',
    'hypertension', 'peptic_ulcer', 'autoimmune_disease',
]
hemorrhagic_cols = [
    'nosebleed', 'gum_bleeding', 'metrorrhagia', 'petechiae_hemorrh', 'hematuria', 'other_bleeding',
]

# Per-row burden scores: how many flags of each group are set.
for count_name, member_cols in (
    ('symptom_count', symptom_cols),
    ('comorbidity_count', comorbidity_cols),
    ('hemorrhagic_count', hemorrhagic_cols),
):
    df[count_name] = df[member_cols].sum(axis=1)

### 3.7 Pairwise symptom interactions

Creates binary interaction features for all C(12, 2) = 66 symptom pairs.  
Triplet interactions (C(12, 3) = 220 features) are available but disabled by default.

In [None]:
symptom_columns = [
    'fever', 'myalgia', 'headache', 'rash', 'vomiting', 'nausea', 'back_pain', 'conjunctivitis',
    'arthritis', 'joint_pain', 'petechiae', 'retro_orbital_pain',
]

# Triplets add C(12, 3) = 220 columns. Build them only when requested:
# previously they were materialised on every run even though unused, which
# costs a lot of memory on millions of rows.
use_triplet_interactions = False

# Pairwise co-occurrence flags: 1 when both symptoms are present.
interaction_cols = {
    f'{a}_and_{b}': (df[a] * df[b]).astype(int)
    for a, b in combinations(symptom_columns, 2)
}

interaction_cols_3 = {}
if use_triplet_interactions:
    interaction_cols_3 = {
        f'{a}_{b}_{c}': (df[a] * df[b] * df[c]).astype(int)
        for a, b, c in combinations(symptom_columns, 3)
    }

new_feature_frames = [pd.DataFrame(interaction_cols, index=df.index)]
if interaction_cols_3:
    new_feature_frames.append(pd.DataFrame(interaction_cols_3, index=df.index))
df = pd.concat([df, *new_feature_frames], axis=1)

### 3.8 Drop near-constant columns

Removes any column (excluding categorical) where a single value dominates ≥99% of rows — these carry almost no information.

In [None]:
# Drop non-categorical feature columns where one value covers at least
# `dominance_threshold` of the non-missing rows (near-constant = little signal).
dominance_threshold = 0.99


def _top_value_share(col):
    """Share of the most frequent non-missing value in *col*; 1.0 if the column is entirely missing."""
    shares = col.value_counts(normalize=True)
    # An all-NaN column has no value counts (iloc[0] used to raise); treat it as constant.
    return shares.iloc[0] if len(shares) else 1.0


# Categorical columns are exempt: their mass is expected to be concentrated
# after encoding. The target is never a candidate.
candidate_columns = df.drop(columns=['final_classification'] + categorical_columns)
dominant_ratio = candidate_columns.apply(_top_value_share)
cols_to_drop_low_variance = dominant_ratio[dominant_ratio >= dominance_threshold].index.tolist()

df = df.drop(columns=cols_to_drop_low_variance)

print(f'Colunas removidas (>={dominance_threshold*100:.0f}% mesmo valor): {len(cols_to_drop_low_variance)}')
print(cols_to_drop_low_variance)

### 3.9 Encode target

| Class | Meaning | Label |
|-------|---------|-------|
| 5 | Discarded | 0 |
| 10 | Confirmed dengue | 1 |
| 11 | Confirmed + alarm signs | 1 |
| 12 | Confirmed + severe dengue | 1 |

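All other classes (inconclusive, other diseases, missing) are excluded. For the chikungunya dataset, only class 5 (discarded → 0) and class 13 (confirmed chikungunya → 1) are kept.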

In [None]:
# Tratando a coluna de target
if type_disease == 'dengue':
    df = df[df['final_classification'].isin([5, 10, 11, 12])]
    
if type_disease == 'chikungunya':
    df = df[df['final_classification'].isin([5, 13])]

dengue_mapping = {
    5 : 0,   # Discarded
    10 : 1,  # Confirmed
    11 : 1,  # Confirmed and alarming
    12 : 1,  # Confirmed with complications
}

chik_mapping = {
    5 : 0,   # Discarded
    13 : 1,  # Confirmed Chikungunya
}

df['final_classification'] = df['final_classification'].map(dengue_mapping if type_disease == 'dengue' else chik_mapping).fillna(0).astype(int)
df['final_classification'].value_counts()

## 4. Model Training — Neural Network

### 4.1 Convert to tensors

In [None]:
def _to_device_tensor(frame, dtype):
    """Copy a DataFrame/Series into a new tensor of *dtype* on `device`."""
    return torch.tensor(frame.to_numpy(), dtype=dtype).to(device)


# Feature layout: categorical codes in one tensor, every other non-target
# column (in DataFrame order) as float numeric inputs.
categorical_tensors = _to_device_tensor(df[categorical_columns], torch.long)
numerical_tensors = _to_device_tensor(
    df.drop(columns=categorical_columns + ['final_classification']), torch.float
)
target_tensor = _to_device_tensor(df['final_classification'], torch.long)

print(f'Using device: {device}')

### 4.2 Compute embedding sizes

In [None]:
# Each embedding table needs max(code) + 1 rows (codes start at 1; row 0 is
# the reserved unknown slot). Width is roughly half the cardinality, capped at 50.
unique_cat = []
embedding_sizes = []
for col in categorical_columns:
    cardinality = df[col].max() + 1
    unique_cat.append(cardinality)
    embedding_sizes.append((cardinality, min(50, cardinality // 2 + 1)))

### 4.3 Model architecture — `DengueTabularNN`

Tabular neural network with:
- **Embeddings** for categorical columns (each category gets a dense vector)
- **BatchNorm** for numerical inputs
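- **4 hidden layers** (2048 → 1024 → 512 → 256 as instantiated in the training cell; the constructor default is 600 → 300 → 200 → 100). Each hidden layer is Linear → LeakyReLU → BatchNorm → Dropout.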
- **Single logit output** for BCEWithLogitsLoss

In [None]:
class DengueTabularNN(nn.Module):
    """Feed-forward network for tabular data with categorical embeddings.

    Categorical inputs are looked up in per-column embedding tables (followed
    by dropout); numeric inputs go through a BatchNorm layer. Both are
    concatenated and passed through a stack of
    Linear -> LeakyReLU -> BatchNorm -> Dropout blocks, ending in one logit
    (intended for BCEWithLogitsLoss).

    Args:
        numericals_shape: number of numeric input features.
        embedding_sizes: one ``(num_categories, embedding_dim)`` pair per
            categorical column, in the same order as the columns of
            ``x_categorical``. May be empty.
        hidden_layers: width of each hidden layer.
        probability_dropout: ``(embedding_dropout, hidden_dropout)`` probabilities.
    """

    # Tuple defaults avoid the shared-mutable-default pitfall; lists are still accepted.
    def __init__(self, numericals_shape, embedding_sizes, hidden_layers=(600, 300, 200, 100), probability_dropout=(0.05, 0.3)):
        super().__init__()

        # One embedding table per categorical column.
        self.embeddings = nn.ModuleList(
            [nn.Embedding(num_categories, dim) for num_categories, dim in embedding_sizes]
        )
        self.dropout_embeddings = nn.Dropout(p=probability_dropout[0])

        # Normalise numeric inputs.
        self.normalization = nn.BatchNorm1d(numericals_shape)

        # Input width of the first hidden layer: all embedding dims plus the numeric features.
        in_features = sum(dim for _, dim in embedding_sizes) + numericals_shape

        layers = []
        for width in hidden_layers:
            layers += [
                nn.Linear(in_features, width),
                nn.LeakyReLU(),
                nn.BatchNorm1d(width),
                nn.Dropout(p=probability_dropout[1]),
            ]
            in_features = width

        self.layers = nn.Sequential(*layers)
        self.output_layer = nn.Linear(in_features, 1)

    def _engineering_embeddings(self, x_categorical):
        """Embed each categorical column and concatenate along the feature axis.

        Column ``i`` of ``x_categorical`` is looked up in ``self.embeddings[i]``.
        """
        embedded = [embedding(x_categorical[:, i]) for i, embedding in enumerate(self.embeddings)]
        return torch.cat(embedded, dim=1)

    def forward(self, x_categorical, x_numerical):
        """Return raw logits of shape ``(batch, 1)``."""
        x_numerical = self.normalization(x_numerical)
        if len(self.embeddings) == 0:
            # No categorical columns: torch.cat on an empty list would raise.
            x = x_numerical
        else:
            x_categorical = self.dropout_embeddings(self._engineering_embeddings(x_categorical))
            x = torch.cat([x_categorical, x_numerical], dim=1)
        return self.output_layer(self.layers(x))

### 4.4 Training loop

**Setup:** 90/10 train/test split, batch size 4096  
**Loss:** BCEWithLogitsLoss with `pos_weight` to handle class imbalance  
**Optimizer:** AdamW (lr=1e-4, weight_decay=1e-4)  
**Scheduler:** ReduceLROnPlateau (patience=3, factor=0.5)  
**Early stopping:** patience=8 on validation loss, saves best checkpoint

In [None]:
import os

# Hold out 10% as the final test set. It is only used for reporting later.
x_train_cat, x_test_cat, x_train_num, x_test_num, y_train, y_test = train_test_split(
    categorical_tensors, numerical_tensors, target_tensor, test_size=0.1, shuffle=True, random_state=42
)

# Early stopping, LR scheduling and checkpoint selection get their own
# validation split carved from the training portion. Selecting on the test
# set would bias the metrics reported on it later.
x_fit_cat, x_val_cat, x_fit_num, x_val_num, y_fit, y_val = train_test_split(
    x_train_cat, x_train_num, y_train, test_size=0.1, shuffle=True, random_state=42
)

# drop_last=True: a final training batch of one row would make BatchNorm raise.
train_dataset = TensorDataset(x_fit_cat, x_fit_num, y_fit)
train_loader = DataLoader(train_dataset, batch_size=4096, shuffle=True, drop_last=True)

val_dataset = TensorDataset(x_val_cat, x_val_num, y_val)
val_loader = DataLoader(val_dataset, batch_size=4096, shuffle=False)

test_dataset = TensorDataset(x_test_cat, x_test_num, y_test)
test_loader = DataLoader(test_dataset, batch_size=4096, shuffle=False)

checkpoint_path = f'C:\\Users\\angej\\Documents\\2_Programação\\health_index_project\\models_saved\\best_{type_disease}_model.pth'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

dengue_model = DengueTabularNN(embedding_sizes=embedding_sizes, hidden_layers=[2048, 1024, 512, 256], numericals_shape=x_train_num.shape[1], probability_dropout=[0.1, 0.2]).to(device)

# Up-weight positives by the negative/positive ratio of the fitting data.
pos_weight = (y_fit == 0).sum().float() / (y_fit == 1).sum().float()

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(device))
optimizer = torch.optim.AdamW(params=dengue_model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.5, min_lr=1e-6)
epochs = 150

train_losses = []
val_losses = []

patience = 8
counter = 0
best_val_loss = float('inf')


for epoch in range(epochs):
    # --- training pass ---
    dengue_model.train()
    epoch_train_loss = 0
    seen_train = 0
    for cat, num, target in train_loader:
        cat, num, target = cat.to(device), num.to(device), target.to(device)
        optimizer.zero_grad()
        pred = dengue_model(cat, num)
        loss = criterion(pred, target.unsqueeze(1).float())
        loss.backward()
        optimizer.step()
        epoch_train_loss += loss.item() * len(cat)
        seen_train += len(cat)

    # Divide by the rows actually seen (drop_last may skip a partial batch).
    avg_train_loss = epoch_train_loss / max(seen_train, 1)
    train_losses.append(avg_train_loss)

    # --- validation pass ---
    dengue_model.eval()
    epoch_val_loss = 0
    with torch.no_grad():
        for cat, num, target in val_loader:
            cat, num, target = cat.to(device), num.to(device), target.to(device)
            pred = dengue_model(cat, num)
            loss = criterion(pred, target.unsqueeze(1).float())
            epoch_val_loss += loss.item() * len(cat)

    avg_val_loss = epoch_val_loss / len(val_dataset)
    val_losses.append(avg_val_loss)
    scheduler.step(avg_val_loss)

    # Public way to read the current LR (avoids the private scheduler._last_lr).
    current_lr = optimizer.param_groups[0]['lr']
    print(f'Epoch: {epoch:3d} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | LR: {current_lr:.6f}')

    # Checkpoint on improvement; stop after `patience` epochs without one.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        counter = 0
        torch.save(dengue_model.state_dict(), checkpoint_path)
    else:
        counter += 1
        if counter >= patience:
            print(f'Early stopping at epoch {epoch}')
            break

### 4.5 Loss curves

In [None]:
sns.set_style('whitegrid')

# Epochs are 1-based on the x axis.
plt.figure(figsize=(9, 5), dpi=100)
sns.lineplot(x=range(1, len(train_losses) + 1), y=train_losses, label='Train Loss')
sns.lineplot(x=range(1, len(val_losses) + 1), y=val_losses, label='Validation Loss')

for spine in plt.gca().spines.values():
    spine.set_visible(False)
plt.grid(True, which='both', linestyle='--', linewidth=0.5)
plt.xlabel('Epoch')
# The training criterion is BCEWithLogitsLoss with pos_weight, not MSE.
plt.ylabel('Weighted BCE Loss')
plt.legend()
plt.title('Training and Validation Loss')
plt.show()

## 5. Evaluation

### 5.1 Threshold sweep — Neural Network

Loads the best checkpoint and evaluates Accuracy, Precision, Recall, and F1 across thresholds from 0.30 to 0.60.

In [None]:
checkpoint_path = f'C:\\Users\\angej\\Documents\\2_Programação\\health_index_project\\models_saved\\best_{type_disease}_model.pth'

# Rebuild the architecture used in training and load the best checkpoint.
dengue_model = DengueTabularNN(embedding_sizes=embedding_sizes, hidden_layers=[2048, 1024, 512, 256], numericals_shape=x_train_num.shape[1], probability_dropout=[0.1, 0.2]).to(device)
dengue_model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
dengue_model.eval()

all_probs = []

with torch.no_grad():
    for X_cat_batch, X_num_batch, _ in test_loader:
        logits = dengue_model(X_cat_batch.to(device), X_num_batch.to(device))
        # squeeze(1), not squeeze(): a last batch with a single row would
        # otherwise become a 0-d tensor and torch.cat below would fail.
        all_probs.append(torch.sigmoid(logits).squeeze(1).cpu())

probabilities = torch.cat(all_probs)
y_true = y_test.cpu()

# Search for the best threshold. zero_division=0 keeps thresholds with no
# positive predictions from raising warnings.
print(f'{"Threshold":>10} | {"Accuracy":>10} | {"Precision":>10} | {"Recall":>10} | {"F1":>10}')
print('-' * 60)
for t in [0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6]:
    preds = (probabilities > t).long()
    accuracy = (preds == y_true).float().mean().item()
    precision = precision_score(y_true, preds, zero_division=0)
    recall = recall_score(y_true, preds, zero_division=0)
    f1 = f1_score(y_true, preds, zero_division=0)
    print(f'{t:>10.2f} | {accuracy:>10.4f} | {precision:>10.4f} | {recall:>10.4f} | {f1:>10.4f}')

# Threshold used for the per-row preview below.
decision_threshold = 0.3
predicted_classes = (probabilities > decision_threshold).long()

test = pd.DataFrame({
    'Actual': y_true.numpy(),
    'Prob': probabilities.numpy(),
    'Predicted': predicted_classes.numpy(),
    'Correct': (y_true == predicted_classes).numpy()
})
display(test.head(10))

### 5.2 Permutation feature importance — Neural Network

Wraps the model in a sklearn-compatible interface to compute permutation importance on 2000 test samples (100 repeats).

In [None]:
# Subsample first, then concatenate
idx = np.random.choice(x_test_cat.shape[0], size=2000, replace=False)

X_test = np.concatenate([
    x_test_cat[idx].cpu().numpy(),
    x_test_num[idx].cpu().numpy().astype(np.float32)
], axis=1)
y_test_np = y_test[idx].cpu().numpy().astype(int).flatten()

n_cat = x_test_cat.shape[1]  # define outside wrapper to avoid closure issues

class SklearnWrapper:
    def fit(self, X, y):
        return self

    def predict(self, X):
        cat = torch.tensor(X[:, :n_cat], dtype=torch.long).to(device)
        num = torch.tensor(X[:, n_cat:], dtype=torch.float32).to(device)
        with torch.no_grad():
            out = dengue_model(cat, num)
        return (out.cpu().numpy() > 0.5).astype(int).flatten()

    def score(self, X, y):
        from sklearn.metrics import accuracy_score
        return accuracy_score(y, self.predict(X))

wrapper = SklearnWrapper()
result = permutation_importance(wrapper, X_test, y_test_np, n_repeats=100, random_state=42)

### 5.3 Top-20 feature importance chart

In [None]:
top_n = 20

all_feature_names = categorical_columns + list(df.drop(columns=categorical_columns + ['final_classification']).columns)
sorted_idx = result.importances_mean.argsort()[::-1][:top_n]

plt.figure(figsize=(12, 8))
plt.bar(range(top_n), result.importances_mean[sorted_idx], yerr=result.importances_std[sorted_idx])
plt.grid(True, which='both', linestyle='--', linewidth=0.5)
for spine in plt.gca().spines.values(): spine.set_visible(False)
plt.xticks(range(top_n), [all_feature_names[i] for i in sorted_idx], rotation=90)
plt.title(f'Permutation Feature Importance - Top {top_n}', fontsize=14)
plt.tight_layout()
plt.show()

## 6. Baseline — LightGBM

Trains a gradient-boosted tree model on the same train/test split as a performance baseline.

In [None]:
X_train_lgbm = pd.concat([
    pd.DataFrame(x_train_cat.cpu().numpy(), columns=categorical_columns),
    pd.DataFrame(x_train_num.cpu().numpy(), columns=df.drop(columns=categorical_columns + ['final_classification']).columns)
], axis=1)

X_test_lgbm = pd.concat([
    pd.DataFrame(x_test_cat.cpu().numpy(), columns=categorical_columns),
    pd.DataFrame(x_test_num.cpu().numpy(), columns=df.drop(columns=categorical_columns + ['final_classification']).columns)
], axis=1)

lgbm_model = LGBMClassifier(
    n_estimators=2000,
    learning_rate=0.03,
    subsample=0.8,
    colsample_bytree=0.8,
    device='gpu',
)

lgbm_model.fit(X_train_lgbm, y_train.cpu().numpy())

lgbm_accuracy = (lgbm_model.predict(X_test_lgbm) == y_test.cpu().numpy()).mean()
print(f'LightGBM Accuracy: {lgbm_accuracy:.4f}')

### 6.1 Feature importance — LightGBM

In [None]:
importances = pd.Series(lgbm_model.feature_importances_, index=X_train_lgbm.columns)
importances = importances.sort_values(ascending=False)
    
plt.figure(figsize=(12, 8))
importances.head(30).plot(kind='bar')
plt.grid(True, which='both', linestyle='--', linewidth=0.5)
for spine in plt.gca().spines.values(): spine.set_visible(False)
plt.title('LightGBM - Top 30 Feature Importances')
plt.ylabel('Importance (F-score)')
plt.tight_layout()
plt.show()

print(importances.to_string())

### 6.2 Threshold sweep — LightGBM

In [None]:
# Accuracy, Precision, Recall e F1 para o modelo XGBoost based on thresholds list

for t in [0.1, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6]:
    xgb_probs = lgbm_model.predict_proba(X_test_lgbm)[:, 1]
    xgb_preds = (xgb_probs > t).astype(int)
    print(f'Threshold: {t:.2f} | Accuracy: {(xgb_preds == y_test.cpu().numpy()).mean():.4f} | Precision: {precision_score(y_test.cpu().numpy(), xgb_preds):.4f} | Recall: {recall_score(y_test.cpu().numpy(), xgb_preds):.4f} | F1: {f1_score(y_test.cpu().numpy(), xgb_preds):.4f}')