In [None]:
import torch
from torch.utils.data import TensorDataset, DataLoader

# Load the pre-normalised splits; each file is unpacked into (features, labels).
# NOTE(review): torch.load unpickles the file contents — only load files produced by this project.
X_train, y_train = torch.load("train_normalizado.pt")
X_val, y_val = torch.load("val_normalizado.pt")
X_test, y_test = torch.load("test_normalizado.pt")

# Wrap each split as an in-memory dataset of (x, y) pairs.
train_dataset, val_dataset, test_dataset = (
    TensorDataset(X_train, y_train),
    TensorDataset(X_val, y_val),
    TensorDataset(X_test, y_test),
)

# Worker / pinning options shared by every loader; only batch size and shuffling differ.
_LOADER_KWARGS = dict(num_workers=4, pin_memory=True)
# Training batches are small and shuffled; evaluation uses larger, ordered batches.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, **_LOADER_KWARGS)
val_loader = DataLoader(val_dataset, batch_size=512, shuffle=False, **_LOADER_KWARGS)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False, **_LOADER_KWARGS)

In [None]:
#!pip install lightning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18, ResNet18_Weights
from lightning import LightningModule

class LitResNet(LightningModule):
    """ResNet-18 transfer-learning classifier wrapped as a Lightning module.

    The ImageNet-pretrained torchvision backbone is frozen except for its
    ``layer4`` block. Its ``fc`` layer is replaced by ``nn.Identity`` and a small
    MLP head (Dropout -> Linear -> ReLU -> Dropout -> Linear) produces
    ``num_classes`` logits.

    Per-batch histories are collected in ``train_losses``, ``val_losses`` and
    ``val_accuracies`` (plain Python floats) for plotting. Validation values
    produced during Lightning's pre-fit sanity check are not recorded.

    Args:
        num_classes: number of output classes.
        lr: Adam learning rate (default is the previously hard-coded 3e-5).
        dropout: dropout probability used by both dropout layers of the head.
        hidden_dim: width of the hidden layer of the head.
    """

    def __init__(self, num_classes=3, lr=3e-5, dropout=0.3, hidden_dim=256):
        super().__init__()
        # Records every __init__ argument in self.hparams (and in checkpoints).
        self.save_hyperparameters()

        resnet = resnet18(weights=ResNet18_Weights.DEFAULT)
        # Fine-tune only layer4; every other backbone parameter is frozen.
        # NOTE(review): requires_grad=False does not stop BatchNorm layers in the
        # frozen blocks from updating their running mean/var while the module is
        # in train mode — confirm this is intended.
        for name, param in resnet.named_parameters():
            param.requires_grad = "layer4" in name

        in_features = resnet.fc.in_features
        # Drop the ImageNet classifier; the backbone now outputs the features fc used to consume.
        resnet.fc = nn.Identity()

        self.model = nn.Sequential(
            resnet,
            nn.Flatten(),
            nn.Dropout(dropout),
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )
        self.loss_fn = nn.CrossEntropyLoss()
        self.train_losses = []
        self.val_losses = []
        self.val_accuracies = []

    def forward(self, x):
        """Return raw class logits for a batch of inputs."""
        return self.model(x)

    def _shared_step(self, batch):
        """Compute (cross-entropy loss, accuracy) for one (inputs, labels) batch."""
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        return loss, acc

    def _is_sanity_checking(self):
        """True while Lightning runs its pre-fit validation sanity check."""
        try:
            return bool(getattr(self.trainer, "sanity_checking", False))
        except RuntimeError:
            # Lightning 2.x raises when the module is not attached to a Trainer.
            return False

    def training_step(self, batch, batch_idx):
        """Log train loss/accuracy, record the loss history and return the loss."""
        loss, acc = self._shared_step(batch)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)
        self.train_losses.append(loss.item())
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss/accuracy and record them (outside the sanity check)."""
        loss, acc = self._shared_step(batch)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        # Lightning runs a few validation batches before training starts
        # (num_sanity_val_steps, 2 by default); keep them out of the history.
        if not self._is_sanity_checking():
            self.val_losses.append(loss.item())
            self.val_accuracies.append(acc.item())

    def configure_optimizers(self):
        """Adam over the trainable parameters only (layer4 + classification head)."""
        trainable = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.Adam(trainable, lr=self.hparams.lr)

In [None]:
from lightning import Trainer, seed_everything
from lightning.pytorch.callbacks import EarlyStopping

# Fixed seed so the head initialisation and train-loader shuffling are reproducible.
SEED = 42
seed_everything(SEED, workers=True)

# Early stopping on val_loss (patience of 5 validation runs, i.e. 5 epochs here).
# It is DISABLED by default, matching the original run; set to True to enable it.
USE_EARLY_STOPPING = False
early_stop = EarlyStopping(monitor="val_loss", patience=5, mode="min")

# Build the model.
model = LitResNet(num_classes=3)

# Trainer setup: GPU when available, no logger and no model summary.
trainer = Trainer(
    max_epochs=40,
    callbacks=[early_stop] if USE_EARLY_STOPPING else [],
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    enable_model_summary=False,
    logger=False,
)

# Train, validating once per epoch on val_loader.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [None]:
# Predict a class index for every test sample. test_loader is built with
# shuffle=False, so the predictions follow the test dataset order.
model.eval()
batch_preds = []
with torch.no_grad():  # one gradient-free context for the whole pass
    for x, _ in test_loader:
        x = x.to(model.device)
        logits = model(x)
        preds = logits.argmax(dim=1)
        batch_preds.append(preds.cpu())

# Single 1-D tensor of predicted class indices.
all_preds = torch.cat(batch_preds)

In [None]:
import matplotlib.pyplot as plt

# The histories have different lengths: train_losses holds one value per
# training batch (batch_size=64), while val_losses / val_accuracies hold one
# value per validation batch (batch_size=512). Plotted against their raw index,
# the validation curves would be squashed to the left. Each series is therefore
# spread over a shared 0–1 "fraction of training" axis (approximate alignment).
# Loss and accuracy are on different scales, so they get separate panels.
def _training_progress(series):
    """Evenly spaced x positions in [0, 1) for a per-batch history."""
    return [i / len(series) for i in range(len(series))]


fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))

ax_loss.plot(_training_progress(model.train_losses), model.train_losses, label='Train Loss')
ax_loss.plot(_training_progress(model.val_losses), model.val_losses, label='Val Loss')
ax_loss.set(xlabel="Fracción del entrenamiento (aprox.)", ylabel="Loss", title="Pérdida")

ax_acc.plot(_training_progress(model.val_accuracies), model.val_accuracies, label='Val Acc', color="tab:green")
ax_acc.set(xlabel="Fracción del entrenamiento (aprox.)", ylabel="Accuracy", title="Accuracy de validación")

for ax in (ax_loss, ax_acc):
    ax.legend()
    ax.grid(True)

fig.suptitle("Curvas de entrenamiento")
plt.show()

In [None]:
# Summarise the per-step training-loss history instead of dumping every value
# (one entry per training batch, so the full list grows with epochs × batches).
print(f"{len(model.train_losses)} pasos de entrenamiento registrados; últimas 10 pérdidas: {model.train_losses[-10:]}")

In [None]:
from sklearn.metrics import recall_score, f1_score

def compute_enfermo_vs_sano_recall(y_true, y_pred, normal_label='Normal'):
    """Recall of the binary "sick vs. healthy" task.

    Every label equal to ``normal_label`` is mapped to 0 (healthy) and every
    other label to 1 (sick). Returns scikit-learn's ``recall_score`` on that
    binarisation, i.e. the fraction of truly sick samples predicted as sick.

    ``normal_label`` defaults to the string ``'Normal'``. If the labels are class
    indices (e.g. ``argmax`` outputs), pass the index of the normal class instead:
    no integer equals ``'Normal'``, so every sample would count as sick and the
    recall would be trivially 1.0.
    """
    y_true_bin = [0 if x == normal_label else 1 for x in y_true]
    y_pred_bin = [0 if x == normal_label else 1 for x in y_pred]
    return recall_score(y_true_bin, y_pred_bin)


def compute_enfermedades_f1(y_true, y_pred, normal_label='Normal'):
    """Macro F1 over the samples whose true label is a disease.

    Samples whose true label equals ``normal_label`` are dropped. Predictions of
    ``normal_label`` on diseased samples are kept, so scikit-learn includes
    ``normal_label`` as an extra class (with F1 = 0) in the macro average.
    """
    # Single pass over the pairs, so one-shot iterables (e.g. generators) work too.
    disease_pairs = [(t, p) for t, p in zip(y_true, y_pred) if t != normal_label]
    y_true_sub = [t for t, _ in disease_pairs]
    y_pred_sub = [p for _, p in disease_pairs]
    return f1_score(y_true_sub, y_pred_sub, average='macro')


def diagnostic_score(y_true, y_pred, alpha=0.7, beta=0.3, normal_label='Normal'):
    """Weighted score: ``alpha`` * sick-vs-healthy recall + ``beta`` * disease macro F1.

    Args:
        y_true: true labels.
        y_pred: predicted labels, same encoding as ``y_true``.
        alpha: weight of the sick-vs-healthy recall.
        beta: weight of the macro F1 among diseased samples.
        normal_label: label denoting a healthy sample ('Normal' by default; pass
            the class index when the labels are integers).
    """
    # Materialise once: both metrics iterate the labels, which would exhaust a generator.
    y_true = list(y_true)
    y_pred = list(y_pred)
    recall = compute_enfermo_vs_sano_recall(y_true, y_pred, normal_label=normal_label)
    f1_enf = compute_enfermedades_f1(y_true, y_pred, normal_label=normal_label)
    return alpha * recall + beta * f1_enf