In [None]:
class Trainer:
    """Supervised classification trainer with best-checkpointing and early stopping.

    Each epoch of :meth:`fit` runs one pass over the training loader, then
    one pass over the validation loader. Whenever the validation loss reaches
    a new minimum, it writes ``model.state_dict()`` to ``save_path``. It stops
    once the validation loss has failed to improve for
    ``early_stopping_patience`` consecutive epochs.

    Accuracy is computed as ``argmax(outputs, dim=1) == labels``.
    NOTE(review): this assumes outputs are (batch, num_classes) scores and
    labels are class indices — confirm against the caller's data.

    Args:
        model: Module to train; it is moved to ``device`` on construction.
        device: Device that the model and every batch are moved to.
        criterion: Loss function, called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over the model's parameters.
        early_stopping_patience: Number of consecutive non-improving epochs
            after which training stops. A falsy value (``0`` / ``None``)
            disables early stopping.
        save_path: File that the best (lowest validation loss) weights are
            written to.
        wscheduler: Legacy positional name for ``scheduler``. It is kept so
            that existing callers keep working.
        scheduler: Optional learning-rate scheduler, stepped once per epoch.
            A ``ReduceLROnPlateau`` scheduler is stepped with the validation
            loss.
        latest_save_path: File that the latest weights are written to when
            early stopping triggers.
    """

    def __init__(self, model, device, criterion, optimizer, early_stopping_patience,
                 save_path, wscheduler=None, *, scheduler=None,
                 latest_save_path="model/finetuned-model.pth"):
        self.model = model.to(device)
        self.device = device
        self.criterion = criterion
        self.optimizer = optimizer
        # Accept both spellings; ``wscheduler`` is the legacy positional name.
        self.scheduler = scheduler if scheduler is not None else wscheduler
        self.early_stopping_patience = early_stopping_patience
        self.save_path = save_path
        self.latest_save_path = latest_save_path
        # Start at +inf so the first epoch's validation loss always counts as
        # an improvement under the strict ``<`` comparison in ``fit``.
        self.best_val_loss = float("inf")
        self.epochs_no_improve = 0

    def _save_checkpoint(self, path):
        """Write the model's ``state_dict`` to ``path``, creating parent directories."""
        from pathlib import Path

        Path(path).parent.mkdir(parents=True, exist_ok=True)
        torch.save(self.model.state_dict(), path)

    def _step_scheduler(self, val_loss):
        """Advance the LR scheduler by one epoch, if a scheduler was supplied."""
        if self.scheduler is None:
            return
        # ReduceLROnPlateau needs the monitored metric; other schedulers do not.
        if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            self.scheduler.step(val_loss)
        else:
            self.scheduler.step()

    def train_one_epoch(self, train_loader):
        """Run one optimisation pass over ``train_loader``.

        Returns:
            ``(train_loss, train_acc)``: the sample-weighted mean loss, and the
            accuracy in percent over every sample seen.

        Raises:
            ValueError: If ``train_loader`` yields no samples.
        """
        self.model.train()

        running_loss = 0.0
        correct, total = 0, 0

        for images, labels in tqdm(train_loader, desc="Training", leave=False):
            images, labels = images.to(self.device), labels.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(images)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()

            # Weight each batch's loss by its size, so the epoch figure is a
            # per-sample mean even when the last batch is smaller.
            # (Assumes criterion returns a batch mean — confirm.)
            running_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (preds == labels).sum().item()

        if total == 0:
            raise ValueError("train_loader yielded no samples")

        train_acc = 100 * correct / total
        train_loss = running_loss / total
        return train_loss, train_acc

    def validate_per_epoch(self, val_loader):
        """Evaluate the model on ``val_loader`` without updating any weights.

        The model is switched to eval mode, so dropout and batch-norm behave
        deterministically, and gradients are disabled.

        Returns:
            ``(val_loss, val_acc)``: the sample-weighted mean loss, and the
            accuracy in percent.

        Raises:
            ValueError: If ``val_loader`` yields no samples.
        """
        self.model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0

        with torch.no_grad():
            for images, labels in tqdm(val_loader, desc="Validation"):
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)

                val_loss += loss.item() * images.size(0)
                _, preds = torch.max(outputs, 1)
                val_total += labels.size(0)
                val_correct += (preds == labels).sum().item()

        if val_total == 0:
            raise ValueError("val_loader yielded no samples")

        epoch_loss = val_loss / val_total
        epoch_acc = 100 * val_correct / val_total
        return epoch_loss, epoch_acc

    def fit(self, train_loader, val_loader, num_epochs):
        """Train for up to ``num_epochs`` epochs, with checkpointing and early stopping.

        Returns:
            A dict mapping ``"train_loss"``, ``"val_loss"``, ``"train_acc"``
            and ``"val_acc"`` to per-epoch lists. It covers only the epochs
            that actually ran.
        """
        history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}

        for epoch in range(num_epochs):
            print(f"epochs {epoch + 1}/{num_epochs}")
            train_loss, train_acc = self.train_one_epoch(train_loader)
            val_loss, val_acc = self.validate_per_epoch(val_loader)

            history["train_loss"].append(train_loss)
            history["train_acc"].append(train_acc)
            history["val_loss"].append(val_loss)
            history["val_acc"].append(val_acc)

            print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.2f}%"
                  f"\nValidation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.2f}%")

            self._step_scheduler(val_loss)

            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self._save_checkpoint(self.save_path)
                self.epochs_no_improve = 0
                continue

            self.epochs_no_improve += 1
            if (self.early_stopping_patience
                    and self.epochs_no_improve >= self.early_stopping_patience):
                # Save the latest weights to their own file, so the best
                # checkpoint at ``save_path`` is never overwritten.
                self._save_checkpoint(self.latest_save_path)
                print(f"Early stopped latest model saved {self.latest_save_path}")
                break

        return history


