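The Engine class wraps an update step into a training engine. Calling its run method starts training, so you do not have to hand-write a complex training loop.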

In [26]:
from ignite.engine import Engine
import torch

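def update_fn(engine, batch):
    # Toy update step: the batch is only printed, and the "loss" is
    # computed from random data rather than from a real model.
    data = torch.randn(10)
    loss = torch.mean(data ** 2)
    print(batch)
    # The returned value is stored in engine.state.output
    return loss.item()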

trainer = Engine(update_fn)

max_epochs = 5

trainer.run([12],max_epochs=max_epochs)

12
12
12
12
12


State:
	iteration: 5
	epoch: 5
	epoch_length: 1
	max_epochs: 5
	output: 0.4799571633338928
	batch: 12
	metrics: <class 'dict'>
	dataloader: <class 'list'>
	seed: <class 'NoneType'>
	times: <class 'dict'>

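Ignite's event-driven mechanism lets you run custom actions at specific stages of training, such as logging, saving the model, or evaluating the agent's performance.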

In [27]:
from ignite.engine import Events

@trainer.on(Events.EPOCH_COMPLETED)
def log_epoch_results(engine):
    print(f"Epoch {engine.state.epoch} completed. Loss: {engine.state.output}")

trainer.run([16], max_epochs=max_epochs)

16
Epoch 1 completed. Loss: 1.4189667701721191
16
Epoch 2 completed. Loss: 1.685107946395874
16
Epoch 3 completed. Loss: 1.067115068435669
16
Epoch 4 completed. Loss: 1.373170256614685
16
Epoch 5 completed. Loss: 0.4603961408138275


State:
	iteration: 5
	epoch: 5
	epoch_length: 1
	max_epochs: 5
	output: 0.4603961408138275
	batch: 16
	metrics: <class 'dict'>
	dataloader: <class 'list'>
	seed: <class 'NoneType'>
	times: <class 'dict'>

## Evaluation

In [28]:
from ignite.engine import create_supervised_evaluator
from ignite.metrics import Accuracy

def eval_fn(engine, batch):
    y_pred = torch.randint(0,2,(10,))
    y = torch.randint(0,2,(10,))
    return y_pred, y

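# Note: create_supervised_evaluator expects a torch.nn.Module as its first
# argument; passing a plain function leads to the AttributeError in the output.
evaluator = create_supervised_evaluator(eval_fn, metrics={'accuracy': Accuracy()})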

# Run evaluation at the end of each epoch
@trainer.on(Events.EPOCH_COMPLETED)
def run_evaluation(engine):
    evaluator.run([[1]])
    metrics = evaluator.state.metrics
    print(f"Accuracy:{metrics['accuracy']}")

trainer.run([1],max_epochs=max_epochs)

Current run is terminating due to exception: 'function' object has no attribute 'eval'
Engine run is terminating due to exception: 'function' object has no attribute 'eval'
Engine run is terminating due to exception: 'function' object has no attribute 'eval'


1
Epoch 1 completed. Loss: 0.6212106943130493


AttributeError: 'function' object has no attribute 'eval'