In [None]:
class PolyaTreeTrainer:
    """Train, evaluate, predict with, and record history for a soft Polya-tree MoE.

    The model is optimised with Adam, optional gradient-norm clipping and an
    optional per-epoch temperature schedule.

    Attributes:
        history: dict of parallel lists with one entry appended per ``fit``
            epoch: ``'epoch'`` (1-based), ``'train_loss'`` (sample-weighted
            mean of cross-entropy + ``model.regularization_loss()``) and
            ``'train_acc'``.
    """

    def __init__(self,
                 model: "SoftPolyaTreeMoE",
                 lr: float = 1e-3,
                 temperature_scheduler: "TemperatureScheduler | None" = None,
                 max_grad_norm: "float | None" = 5.0,
                 device='cpu',
                 default_temp: float = 0.5,
                 eval_temp: float = 0.1):
        """
        Args:
            model: module called as ``model(x, temperature=t)`` that returns
                class logits and exposes ``regularization_loss()``.
            lr: Adam learning rate.
            temperature_scheduler: object whose ``get_temp(epoch)`` (0-based
                epoch index) supplies the training temperature. When ``None``,
                ``default_temp`` is used for every epoch.
            max_grad_norm: total gradient-norm clipping threshold. ``None`` or
                a non-positive value disables clipping.
            device: device the model and every batch are moved to.
            default_temp: training temperature used when no scheduler is given.
            eval_temp: temperature used by ``evaluate`` and ``predict`` unless
                overridden per call.
        """
        self.model = model
        self.lr = lr
        self.temperature_scheduler = temperature_scheduler
        self.max_grad_norm = max_grad_norm
        self.device = device
        self.default_temp = default_temp
        self.eval_temp = eval_temp

        # Move the model *before* building the optimizer, as the torch.optim
        # docs recommend, so the optimizer is built over the device-resident
        # parameters.
        self.model.to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)

        self.history = {
            'epoch': [],
            'train_loss': [],
            'train_acc': []
        }

    def _train_temperature(self, epoch):
        """Return the training temperature for the 0-based ``epoch``."""
        if self.temperature_scheduler is not None:
            return self.temperature_scheduler.get_temp(epoch)
        return self.default_temp

    def _train_epoch(self, train_loader, temperature):
        """Run one optimisation pass over ``train_loader``.

        Returns:
            ``(mean_loss, accuracy)`` averaged over all samples seen.

        Raises:
            ValueError: if the loader yields no samples.
        """
        self.model.train()
        total_loss = 0.0
        total_correct = 0
        total_samples = 0

        for x_batch, y_batch in train_loader:
            x_batch = x_batch.to(self.device)
            y_batch = y_batch.to(self.device)

            self.optimizer.zero_grad()
            logits = self.model(x_batch, temperature=temperature)
            ce_loss = F.cross_entropy(logits, y_batch)
            reg_loss = self.model.regularization_loss()
            loss = ce_loss + reg_loss

            loss.backward()
            if self.max_grad_norm is not None and self.max_grad_norm > 0:
                nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
            self.optimizer.step()

            batch_size = x_batch.size(0)
            # cross_entropy defaults to reduction='mean', so re-weight by batch
            # size to get a sample-weighted average over the epoch.
            total_loss += loss.item() * batch_size
            total_correct += (logits.argmax(dim=1) == y_batch).sum().item()
            total_samples += batch_size

        if total_samples == 0:
            raise ValueError("train_loader yielded no samples")
        return total_loss / total_samples, total_correct / total_samples

    def fit(self, train_loader, num_epochs=10):
        """Train for ``num_epochs`` epochs, recording and printing per-epoch metrics.

        Args:
            train_loader: re-iterable source of ``(x_batch, y_batch)`` pairs,
                e.g. a ``DataLoader`` or list. It is iterated once per epoch.
            num_epochs: number of passes over ``train_loader``.

        Raises:
            ValueError: if an epoch yields no samples. Nothing is appended to
                ``history`` for that epoch.
        """
        for epoch in range(num_epochs):
            current_temp = self._train_temperature(epoch)
            avg_loss, avg_acc = self._train_epoch(train_loader, current_temp)

            self.history['epoch'].append(epoch + 1)
            self.history['train_loss'].append(avg_loss)
            self.history['train_acc'].append(avg_acc)

            print(f"[Epoch {epoch+1}/{num_epochs}] "
                  f"Temp={current_temp:.3f} loss={avg_loss:.4f}, acc={avg_acc:.4f}")

    def evaluate(self, data_loader, temperature=None):
        """Compute mean cross-entropy and accuracy over ``data_loader``.

        The returned loss is cross-entropy only. Unlike ``history['train_loss']``,
        it does not include ``regularization_loss()``. Leaves the model in
        eval mode.

        Args:
            data_loader: iterable of ``(x_batch, y_batch)`` pairs.
            temperature: routing temperature; defaults to ``self.eval_temp``.

        Returns:
            ``(mean_ce_loss, accuracy)`` over all samples.

        Raises:
            ValueError: if the loader yields no samples.
        """
        temp = self.eval_temp if temperature is None else temperature
        self.model.eval()
        total_loss = 0.0
        total_correct = 0
        total_samples = 0

        with torch.no_grad():
            for x_batch, y_batch in data_loader:
                x_batch = x_batch.to(self.device)
                y_batch = y_batch.to(self.device)
                logits = self.model(x_batch, temperature=temp)
                # Sum (not mean) per batch so the final division gives an exact
                # per-sample mean even with a ragged last batch.
                total_loss += F.cross_entropy(logits, y_batch, reduction='sum').item()
                total_correct += (logits.argmax(dim=1) == y_batch).sum().item()
                total_samples += x_batch.size(0)

        if total_samples == 0:
            raise ValueError("data_loader yielded no samples")
        return total_loss / total_samples, total_correct / total_samples

    def predict(self, x, temperature=None):
        """Return argmax class predictions for ``x`` as a CPU tensor.

        Args:
            x: input batch accepted by the model.
            temperature: routing temperature; defaults to ``self.eval_temp``.

        Leaves the model in eval mode.
        """
        temp = self.eval_temp if temperature is None else temperature
        self.model.eval()
        with torch.no_grad():
            x = x.to(self.device)
            logits = self.model(x, temperature=temp)
            preds = torch.argmax(logits, dim=1)
        return preds.cpu()
