In [2]:
# === Imports ===
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np 

# === Device ===
# Use a CUDA GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

# === Data ===
# MNIST train/test splits stored under ./data (download=True fetches the files
# when they are missing); every sample is converted with transforms.ToTensor().
transform = transforms.ToTensor()
train_ds = datasets.MNIST("./data", train=True, transform=transform, download=True)
test_ds = datasets.MNIST("./data", train=False, transform=transform, download=True)

# Shuffled mini-batches of 64 for training; larger, fixed-order batches for evaluation.
train_loader = DataLoader(dataset=train_ds, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_ds, batch_size=512, shuffle=False)

# === Model ===
class Model(nn.Module):
    """Fully connected classifier: 28*28 inputs -> 256 -> 128 -> 10 outputs.

    Every sample is flattened from dim 1 onward, so each sample must hold
    28*28 = 784 values (presumably (1, 28, 28) MNIST images here — TODO
    confirm for other inputs). ``forward`` returns raw logits with no
    softmax applied, which is what ``nn.CrossEntropyLoss`` consumes.
    """

    def __init__(self):
        super().__init__()
        # The attribute names fc1/fc2/fc3 are the state_dict keys; keep them
        # stable so existing checkpoints still load.
        self.fc1 = nn.Linear(in_features=28 * 28, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        flat = torch.flatten(x, start_dim=1)  # one row per batch element
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        return self.fc3(hidden)  # unnormalized class scores

# Build the classifier, then place its parameters on the selected device.
model = Model()
model = model.to(device)

# === Single-image inference (optional) ===
def inference(image_tensor: torch.Tensor, *, net: "nn.Module | None" = None) -> int:
    """Predict the class index of a single, unbatched input.

    Args:
        image_tensor: one sample without a batch axis; a batch axis of size 1
            is prepended before the forward pass. For this notebook's MNIST
            data that is presumably ``(1, 28, 28)`` — TODO confirm for other
            inputs.
        net: model to run. Defaults to the module-level ``model``.

    Returns:
        The argmax over dim 1 of the model output, as a Python int.

    The model runs in eval mode under ``torch.no_grad()``. Its previous
    train/eval mode is restored before returning, so calling this in the
    middle of training leaves no lasting side effect.
    """
    if net is None:
        net = model  # looked up at call time, so a rebound global is honoured

    # Put the input on the same device as the model's weights (if it has
    # any) rather than relying on a separate global device setting.
    first_param = next(net.parameters(), None)
    batch = image_tensor.unsqueeze(0)
    if first_param is not None:
        batch = batch.to(first_param.device)

    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            logits = net(batch)
        return logits.argmax(1).item()
    finally:
        net.train(was_training)  # undo the eval() switch

# === Batch validation ===
def validate(model, loader):
    """Print and return the top-1 accuracy of ``model`` over ``loader``.

    Args:
        model: classifier whose output holds class scores along dim 1.
        loader: iterable of ``(images, labels)`` batches, with labels given
            as class indices.

    Returns:
        Fraction of samples (a float in [0, 1]) whose argmax prediction
        equals the label.

    Raises:
        ValueError: if ``loader`` yields no samples, since accuracy is then
            undefined.

    The pass runs in eval mode under ``torch.no_grad()``; the model's
    previous train/eval mode is restored afterwards.
    """
    # Move batches to the device holding the model's weights, instead of a
    # global device that may not match the model passed in. A model with no
    # parameters gets its inputs unchanged.
    first_param = next(model.parameters(), None)
    target = None if first_param is None else first_param.device

    was_training = model.training
    model.eval()
    correct = total = 0
    try:
        with torch.no_grad():
            for images, labels in loader:
                if target is not None:
                    images, labels = images.to(target), labels.to(target)
                preds = model(images).argmax(1)
                total += labels.size(0)
                correct += (preds == labels).sum().item()
    finally:
        model.train(was_training)  # undo the eval() switch

    if total == 0:
        raise ValueError("validate(): loader yielded no samples")
    acc = correct / total
    print(f"Accuracy: {acc:.3f}")
    return acc

# === Training (calls the model directly, not inference()) ===
# CrossEntropyLoss takes raw logits plus integer class-index targets (0..9 here).
criterion = nn.CrossEntropyLoss()
# SGD with momentum over all parameters of the global model.
optimizer = optim.SGD(params=model.parameters(), lr=0.01, momentum=0.9)

def train_loop(epochs=5):
    """Train the global ``model`` for ``epochs`` passes over ``train_loader``.

    After each epoch, prints the per-sample mean training loss and evaluates
    on ``test_loader`` via ``validate``.

    Args:
        epochs: number of full passes over the training set.

    Returns:
        List of the mean training loss of each epoch, in order.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    history = []
    for epoch in range(1, epochs + 1):
        # Evaluation may have switched the model to eval mode; re-enable
        # training mode at the start of every epoch.
        model.train()
        loss_sum = 0.0
        n_seen = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            logits = model(images)  # forward pass with autograd enabled
            loss = criterion(logits, labels)  # mean loss over this batch
            loss.backward()
            optimizer.step()

            # Weight each batch mean by its size so a smaller final batch
            # does not skew the epoch average.
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            n_seen += batch_size

        if n_seen == 0:
            raise ValueError("train_loop(): train_loader yielded no samples")
        epoch_loss = loss_sum / n_seen
        history.append(epoch_loss)
        print(f"Epoch {epoch}/{epochs} - loss: {epoch_loss:.4f}")
        validate(model, test_loader)
    return history

# === Run ===
# Train for five epochs; each epoch ends with a test-set evaluation.
train_loop(epochs=5)

# Quick sanity check: predict the first test image and show its true label.
img, y = test_ds[0]
print("Pred:", inference(img), "Real:", y)

Device: cpu
Epoch 1/5 - loss: 0.5213
Accuracy: 0.934
Epoch 2/5 - loss: 0.1876
Accuracy: 0.956
Epoch 3/5 - loss: 0.1278
Accuracy: 0.965
Epoch 4/5 - loss: 0.0962
Accuracy: 0.971
Epoch 5/5 - loss: 0.0750
Accuracy: 0.973
Pred: 7 Real: 7
