In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import accuracy_score

def train_model(model, train_loader, val_loader, epochs=10, lr=1e-3,
                log_dir="../logs/exp1_mlp"):
    """Train ``model`` with Adam and cross-entropy loss, logging to TensorBoard.

    Each epoch puts the model in training mode, runs one optimization step
    per batch of ``train_loader``, and then reports the sample-weighted mean
    loss and the accuracy of the arg-max predictions. Both metrics are
    written to TensorBoard (``Train/Loss``, ``Train/Accuracy``) and printed.

    Args:
        model: Module called as ``model(x)``. Its output is passed to
            ``nn.CrossEntropyLoss`` together with ``y`` and arg-maxed over
            dim 1 to obtain the predicted labels.
        train_loader: Iterable yielding ``(x, y)`` tensor batches.
        val_loader: Accepted as part of the signature but not read by this
            function. NOTE(review): no validation pass is run yet.
        epochs: Number of passes over ``train_loader``.
        lr: Learning rate handed to ``optim.Adam``.
        log_dir: Directory in which the TensorBoard event files are written.

    Raises:
        ValueError: If ``train_loader`` yields no samples during an epoch.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # The parameter device does not change between epochs, so report it once
    # instead of on every iteration of the epoch loop.
    print(next(model.parameters()).device)

    writer = SummaryWriter(log_dir=log_dir)
    try:
        for epoch in range(epochs):
            # ===== Train =====
            model.train()
            loss_sum = 0.0
            seen = 0
            y_true, y_pred = [], []

            for x, y in train_loader:
                x, y = x.to(device), y.to(device)

                logits = model(x)
                loss = criterion(logits, y)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # CrossEntropyLoss returns the batch mean; weight it by the
                # batch size so a smaller final batch is averaged correctly.
                batch_size = x.size(0)
                loss_sum += loss.item() * batch_size
                seen += batch_size

                y_true.extend(y.cpu().tolist())
                y_pred.extend(logits.argmax(dim=1).cpu().tolist())

            if seen == 0:
                raise ValueError("train_loader yielded no samples")

            # Normalize by the samples actually processed rather than by
            # len(train_loader.dataset): the two differ when the loader drops
            # the last batch or draws from a sampler, and the accuracy below
            # is computed over the processed samples as well.
            train_loss = loss_sum / seen
            train_acc = accuracy_score(y_true, y_pred)

            writer.add_scalar("Train/Loss", train_loss, epoch)
            writer.add_scalar("Train/Accuracy", train_acc, epoch)

            print(f"Epoch {epoch+1} | Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f}")
    finally:
        # Flush pending events and release the event file even when training
        # is interrupted by an exception.
        writer.close()
