# Reporte de avances del modelo Sentan-I

In [None]:
from Src.pre_processing import load_data, split, to_tensor, to_labels 
from Src.Sentan_Model import Sentan_simple
from Src.Dias_Model import Dias_Model
from Src.pre_processing import split, to_tensor

import torch
from sklearn.model_selection import train_test_split
from torch.optim import Adam
from torch import nn

from torchmetrics import F1Score, Recall, Precision, Accuracy
import wandb


# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

### Preparación de los datos

Los datos se cargan desde una versión del conjunto de datos RAVDESS; cada muestra se representa mediante un conjunto de 193 variables.

In [None]:
# Location of the pickled dataset consumed by load_data.
data_path = 'Data/data.pkl'
X, y, labels = load_data(data_path)

# Hold out 20% of the samples for validation; the fixed seed keeps the
# split reproducible across notebook runs.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=0,
)

### Configuración del ambiente sweep

In [None]:
import wandb

# Random-search sweep definition: each trial draws an epoch budget, a learning
# rate, a weight-decay coefficient and one of the two model architectures.
# NOTE(review): with a uniform distribution over [1e-7, 1e-1], roughly 99% of
# the samples fall above 1e-3; consider a log-uniform distribution if small
# learning rates / decays are meant to be explored — confirm intent.
sweep_config = {
    "name": "Sweep Analisis Sentimientos",
    "method": "random",
    "parameters": {
        "epochs": {"values": [100, 200, 700]},
        "learning_rate": {"distribution": "uniform", "min": 1e-7, "max": 1e-1},
        "decay": {"distribution": "uniform", "min": 1e-7, "max": 1e-1},
        "model": {"values": ["Dias", "Sentan"]},
    },
}

# Register the sweep with Weights & Biases; the returned id is what the agent
# uses to request trials.
sweep_id = wandb.sweep(sweep_config)


### Ciclo de entrenamiento

In [None]:
def train():
    """Run one sweep trial: train the selected model and log validation metrics.

    Reads the notebook-level ``X_train``, ``y_train``, ``X_test``, ``y_test``
    and ``device``. The hyperparameters (``model``, ``learning_rate``,
    ``decay``, ``epochs``) come from ``wandb.config`` of the run opened here.

    Each epoch performs a single full-batch optimisation step on the training
    tensors, then evaluates on the held-out tensors and logs
    ``Validation Accuracy``, ``F1 Score``, ``Recall``, ``Precision`` and
    ``epoch`` to Weights & Biases.
    """
    # Metrics are configured for the 8 target classes and live on the same
    # device as the predictions they receive.
    f1 = F1Score(num_classes=8).to(device)
    recall = Recall(average='macro', num_classes=8).to(device)
    precision = Precision(average='macro', num_classes=8).to(device)
    accuracy = Accuracy().to(device)

    with wandb.init():

        train_x, train_y = to_tensor(X_train, y_train)
        val_x, val_y = to_tensor(X_test, y_test)

        train_x, train_y = train_x.to(device), train_y.to(device)
        val_x, val_y = val_x.to(device), val_y.to(device)

        config = wandb.config
        if config["model"] == "Dias":
            model = Dias_Model().to(device)
        else:
            model = Sentan_simple().to(device)

        loss_fn = nn.CrossEntropyLoss()
        optimizer = Adam(
            model.parameters(),
            lr=config["learning_rate"],
            weight_decay=config["decay"],
        )

        for epoch in range(config["epochs"]):

            # Training step. Re-enable training mode every epoch because the
            # validation pass below switches the model to evaluation mode.
            model.train()
            optimizer.zero_grad()
            train_pred = model(train_x)
            loss = loss_fn(train_pred, train_y)
            loss.backward()
            optimizer.step()

            # Validation. Evaluation mode makes mode-dependent layers (e.g.
            # dropout, batch norm), if the model has any, behave
            # deterministically; no_grad avoids building the autograd graph.
            model.eval()
            with torch.no_grad():
                val_pred = model(val_x)

                # Log the results to Weights & Biases.
                wandb.log({
                    'Validation Accuracy': accuracy(val_pred, val_y),
                    'F1 Score': f1(val_pred, val_y),
                    'Recall': recall(val_pred, val_y),
                    'Precision': precision(val_pred, val_y),
                    "epoch": epoch,
                })
            

### Ejecución del agente sweep

In [None]:
# Number of trials this agent will pull from the sweep before stopping.
count = 100

# Start the agent: every trial invokes train() with a sampled configuration.
wandb.agent(
    sweep_id,
    function=train,
    count=count,
)
