In [None]:
import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.optim import Adam
from torchvision.models import resnet18

In [None]:
def accuracy_fn(y_true, y_pred):
    """Return the percentage (0-100) of positions where ``y_true`` equals ``y_pred``.

    Args:
        y_true: ground-truth class labels (a tensor or any sequence accepted by
            ``torch.as_tensor``).
        y_pred: predicted class labels, comparable element-wise with ``y_true``.

    Returns:
        Accuracy as a Python float in percent. An empty input returns 0.0
        instead of raising ``ZeroDivisionError``.
    """
    # as_tensor is a no-op for tensors (device is preserved) and lets callers
    # pass plain lists as well.
    y_true = torch.as_tensor(y_true)
    y_pred = torch.as_tensor(y_pred)

    total = y_pred.numel()
    if total == 0:
        return 0.0
    return (y_true == y_pred).sum().item() / total * 100


In [None]:
# Seed PyTorch's global RNG so the DataLoader shuffling and the initialization
# of any freshly created layers are reproducible across runs.
torch.manual_seed(42)

# Use the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Usando dispositivo: {device}")

Transformaciones y datasets

In [None]:
# The pretrained ResNet-18 weights loaded below expect ImageNet-normalized
# inputs; without this normalization the frozen backbone sees input
# statistics it was never trained on.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # fixed input size for ResNet
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

train_data = datasets.ImageFolder(root="poker/data/train", transform=transform)
# NOTE: the "val" split on disk is used as the held-out evaluation set.
test_data = datasets.ImageFolder(root="poker/data/val", transform=transform)

train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32)


Modelo y config

In [None]:
# Pretrained ResNet-18 used as a frozen feature extractor: every existing
# parameter is frozen so only the new classification head is trained.
model = resnet18(weights="DEFAULT")
for param in model.parameters():
    param.requires_grad = False

# Replace the final layer with a head sized to the dataset's classes. Reading
# in_features from the existing layer avoids hard-coding 512. This layer is
# created after the freeze loop, so its parameters remain trainable.
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=len(train_data.classes))
model = model.to(device)

loss_fn = nn.CrossEntropyLoss()
# Give the optimizer only the parameters that will actually receive gradients.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = Adam(trainable_params, lr=1e-3)


Entrenamiento

In [None]:
epochs = 20
for epoch in range(epochs):
    # ---- training pass ----
    model.train()
    # NOTE(review): the backbone parameters are frozen, but model.train() still
    # puts any BatchNorm layers in training mode, so their running statistics
    # keep updating from this dataset. If that is not intended, set those
    # backbone modules to eval() after this call — confirm desired behavior.
    train_loss, train_acc = 0.0, 0.0

    for X, y in train_loader:
        X, y = X.to(device), y.to(device)
        y_logits = model(X)
        y_pred = torch.argmax(y_logits, dim=1)
        loss = loss_fn(y_logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # loss_fn (default reduction) gives a per-batch mean and accuracy_fn a
        # per-batch percentage; weight both by batch size so a smaller final
        # batch does not skew the epoch average.
        batch_size = y.size(0)
        train_loss += loss.item() * batch_size
        train_acc += accuracy_fn(y, y_pred) * batch_size

    # ---- evaluation pass ----
    model.eval()
    test_loss, test_acc = 0.0, 0.0
    with torch.inference_mode():
        for X, y in test_loader:
            X, y = X.to(device), y.to(device)
            y_logits = model(X)
            y_pred = torch.argmax(y_logits, dim=1)
            batch_size = y.size(0)
            test_loss += loss_fn(y_logits, y).item() * batch_size
            test_acc += accuracy_fn(y, y_pred) * batch_size

    # Divide by the number of samples (not batches) to get true per-sample means.
    n_train = len(train_loader.dataset)
    n_test = len(test_loader.dataset)
    print(f"Epoch {epoch}: Train loss {train_loss/n_train:.4f}, "
          f"Train acc {train_acc/n_train:.2f}%, "
          f"Test loss {test_loss/n_test:.4f}, "
          f"Test acc {test_acc/n_test:.2f}%")
