<a href="https://colab.research.google.com/github/Lukas4319/Animal_classification/blob/hhayouni/hhayouni.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Train & Test

In [None]:
class Trainer:
    def __init__(self, model, train_loader, val_loader, criterion, optimizer, device=DEVICE):
        """Bind the model and its training components to this trainer.

        Args:
            model: Network to train. It is moved to ``device``, and the module
                returned by ``.to(device)`` is what the trainer keeps.
            train_loader: Iterable of ``(inputs, targets)`` training batches.
            val_loader: Iterable of ``(inputs, targets)`` validation batches.
            criterion: Loss function, called as ``criterion(y_pred, y_true)``.
            optimizer: Optimizer stepped after each training batch.
            device: Target device for the model and for each batch. The default
                is the module-level ``DEVICE`` value captured when the class
                was defined.
        """
        self.device = device
        self.model = model.to(device)
        self.train_loader, self.val_loader = train_loader, val_loader
        self.criterion = criterion
        self.optimizer = optimizer
        # No learning-rate schedule until set_scheduler() is called.
        self.scheduler = None
        # Metric series keyed by name; presumably appended to by an epoch loop
        # outside this view (TODO confirm).
        self.history = {key: [] for key in ("train_loss", "val_loss", "train_acc", "val_acc")}
        # Starts at +inf so the first real loss compares as an improvement;
        # presumably tracks the best validation loss (TODO confirm).
        self.best_loss = float("inf")

    def set_scheduler(self, step_size, gamma=0.1):
        """Attach a ``StepLR`` learning-rate schedule to ``self.optimizer``.

        Per PyTorch's ``StepLR``, each parameter group's learning rate is
        multiplied by ``gamma`` once every ``step_size`` scheduler steps.
        Any previously attached scheduler is replaced.

        Args:
            step_size: Number of scheduler steps between decays.
            gamma: Multiplicative decay factor (default 0.1).
        """
        scheduler = StepLR(self.optimizer, step_size=step_size, gamma=gamma)
        self.scheduler = scheduler

    def _run_epoch(self, loader, is_train=True):
        """Run one pass over ``loader`` in training or evaluation mode.

        In training mode the optimizer steps after every batch. In evaluation
        mode autograd is disabled and no weights change.

        Args:
            loader: Iterable yielding ``(inputs, targets)`` batches; both are
                moved to ``self.device`` before use.
            is_train: If True, put the model in train mode and update weights;
                otherwise put it in eval mode and only measure.

        Returns:
            Tuple ``(avg_loss, accuracy)``. ``avg_loss`` is the loss averaged
            over samples, weighted by batch size. ``accuracy`` is the
            percentage of samples whose arg-max over dim 1 of the model output
            equals the target.

        Raises:
            ValueError: If ``loader`` yields no samples (previously this
                surfaced as an opaque ``ZeroDivisionError``).
        """
        mode = "train" if is_train else "val"
        # Module.train(False) is the same as Module.eval().
        self.model.train(is_train)

        running_loss = 0.0
        correct = 0
        total = 0

        for x_batch, y_batch in tqdm(loader, desc=f"{mode} Epoch", leave=False):
            x_batch, y_batch = x_batch.to(self.device), y_batch.to(self.device)
            batch_size = x_batch.size(0)

            with torch.set_grad_enabled(is_train):
                y_pred = self.model(x_batch)
                loss = self.criterion(y_pred, y_batch)

                if is_train:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

            # Scale by batch size so a short final batch is weighted correctly.
            # Assumes the criterion returns a per-batch mean (TODO confirm
            # reduction="mean").
            running_loss += loss.item() * batch_size
            correct += (y_pred.argmax(1) == y_batch).sum().item()
            total += batch_size

        if total == 0:
            raise ValueError(
                f"{mode} loader yielded no samples; cannot compute loss/accuracy"
            )

        avg_loss = running_loss / total
        accuracy = correct / total * 100
        return avg_loss, accuracy
