## PyTorch：在 sklearn Pipeline 中运行 PyTorch 训练循环

- 使用sklearn pipeline封装PyTorch模型

- reference: [Pytorch 在sklearn流程中的PyTorch训练循环](https://geek-docs.com/pytorch/pytorch-questions/53_pytorch_pytorch_training_loop_within_a_sklearn_pipeline.html)

In [None]:
import torch
from torch import nn
from torch.optim import Adam
from sklearn.pipeline import Pipeline
from sklearn.base import BaseEstimator

class PyTorchModel(BaseEstimator):
    """Small PyTorch MLP classifier exposed through the sklearn estimator API.

    Following the sklearn estimator contract, ``__init__`` only stores the
    hyperparameters (so ``get_params``/``clone`` work) and all learned state
    is created in :meth:`fit` under trailing-underscore names, which is what
    sklearn's ``check_is_fitted`` looks for. Each call to ``fit`` therefore
    trains a fresh network from scratch.

    Parameters
    ----------
    hidden_size : int, default=20
        Width of the single hidden layer.
    n_classes : int, default=2
        Number of output units; targets must be integers in
        ``range(n_classes)``.
    epochs : int, default=10
        Number of full-batch optimisation steps performed by :meth:`fit`.
    lr : float, default=1e-3
        Adam learning rate (1e-3 is also Adam's own default).
    """

    def __init__(self, hidden_size=20, n_classes=2, epochs=10, lr=1e-3):
        self.hidden_size = hidden_size
        self.n_classes = n_classes
        self.epochs = epochs
        self.lr = lr

    def fit(self, X, y):
        """Train a new network on ``X`` (n_samples x n_features) and labels ``y``.

        The input width is taken from ``X.shape[1]`` instead of being
        hard-coded.

        Returns
        -------
        self
            The fitted estimator, per sklearn convention.

        Raises
        ------
        ValueError
            If ``X`` is not two-dimensional.
        """
        X = torch.as_tensor(X, dtype=torch.float32)
        y = torch.as_tensor(y, dtype=torch.long)
        if X.ndim != 2:
            raise ValueError(
                f"X must be 2-D (n_samples, n_features), got shape {tuple(X.shape)}"
            )

        self.n_features_in_ = X.shape[1]
        # No Softmax layer at the end: nn.CrossEntropyLoss applies
        # log-softmax itself and expects raw, unnormalised logits.
        self.model_ = nn.Sequential(
            nn.Linear(self.n_features_in_, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.n_classes),
        )
        self.loss_fn_ = nn.CrossEntropyLoss()
        self.optimizer_ = Adam(self.model_.parameters(), lr=self.lr)

        self.model_.train()
        for _ in range(self.epochs):
            logits = self.model_(X)
            loss = self.loss_fn_(logits, y)

            self.optimizer_.zero_grad()
            loss.backward()
            self.optimizer_.step()
        return self

    def predict(self, X):
        """Return the predicted class index for each row of ``X`` as a NumPy array."""
        X = torch.as_tensor(X, dtype=torch.float32)
        # eval() is a no-op for Linear/ReLU but keeps predict correct if
        # dropout or batch-norm layers are added later.
        self.model_.eval()
        with torch.no_grad():
            logits = self.model_(X)
        # argmax over logits equals argmax over softmax probabilities.
        return logits.argmax(dim=1).numpy()

# Toy data: three samples whose 10 features all share one value, with binary labels.
X_train = [[value] * 10 for value in (0.1, 0.2, 0.3)]
y_train = [0, 1, 0]

# A single-step pipeline: the wrapped PyTorch model is the final estimator.
pipeline = Pipeline(steps=[("model", PyTorchModel())])

# Pipeline.fit returns the pipeline itself, so predict can be chained.
y_pred = pipeline.fit(X_train, y_train).predict(X_train)
print(y_pred)


In [None]:
import torch
from torch import nn
from torch.optim import Adam
from sklearn.pipeline import Pipeline
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.preprocessing import StandardScaler

class PyTorchModel(BaseEstimator):
    """Small PyTorch MLP classifier exposed through the sklearn estimator API.

    Written out in full so this cell runs on its own; a class body that
    contains only a comment is a syntax error.

    ``__init__`` only stores hyperparameters (sklearn's ``get_params``/
    ``clone`` contract); the network, loss and optimizer are created in
    :meth:`fit` under trailing-underscore names, so every ``fit`` call trains
    from scratch and sklearn's ``check_is_fitted`` recognises the fitted state.

    Parameters
    ----------
    hidden_size : int, default=20
        Width of the single hidden layer.
    n_classes : int, default=2
        Number of output units; targets must be integers in
        ``range(n_classes)``.
    epochs : int, default=10
        Number of full-batch optimisation steps performed by :meth:`fit`.
    lr : float, default=1e-3
        Adam learning rate.
    """

    def __init__(self, hidden_size=20, n_classes=2, epochs=10, lr=1e-3):
        self.hidden_size = hidden_size
        self.n_classes = n_classes
        self.epochs = epochs
        self.lr = lr

    def fit(self, X, y):
        """Train a new network on ``X`` (n_samples x n_features) and labels ``y``.

        Returns
        -------
        self
            The fitted estimator, per sklearn convention.

        Raises
        ------
        ValueError
            If ``X`` is not two-dimensional.
        """
        # as_tensor accepts both Python lists and the NumPy arrays produced
        # by an upstream scaler step.
        X = torch.as_tensor(X, dtype=torch.float32)
        y = torch.as_tensor(y, dtype=torch.long)
        if X.ndim != 2:
            raise ValueError(
                f"X must be 2-D (n_samples, n_features), got shape {tuple(X.shape)}"
            )

        self.n_features_in_ = X.shape[1]
        # No Softmax layer: nn.CrossEntropyLoss expects raw logits.
        self.model_ = nn.Sequential(
            nn.Linear(self.n_features_in_, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.n_classes),
        )
        self.loss_fn_ = nn.CrossEntropyLoss()
        self.optimizer_ = Adam(self.model_.parameters(), lr=self.lr)

        self.model_.train()
        for _ in range(self.epochs):
            logits = self.model_(X)
            loss = self.loss_fn_(logits, y)

            self.optimizer_.zero_grad()
            loss.backward()
            self.optimizer_.step()
        return self

    def predict(self, X):
        """Return the predicted class index for each row of ``X`` as a NumPy array."""
        X = torch.as_tensor(X, dtype=torch.float32)
        self.model_.eval()
        with torch.no_grad():
            logits = self.model_(X)
        return logits.argmax(dim=1).numpy()

class CustomScaler(TransformerMixin, BaseEstimator):
    """Pipeline step that standardises features with sklearn's ``StandardScaler``.

    ``BaseEstimator`` supplies ``get_params``/``set_params`` (and therefore
    ``clone`` support), which sklearn expects from every pipeline step; by
    sklearn convention mixins are listed to the left of ``BaseEstimator``.
    ``TransformerMixin`` provides ``fit_transform`` in terms of :meth:`fit`
    and :meth:`transform`.
    """

    def fit(self, X, y=None):
        """Learn per-feature mean and standard deviation from ``X``.

        ``y`` is accepted for pipeline compatibility and ignored.

        Returns
        -------
        self
            The fitted transformer.
        """
        # Fitted state is created in fit (not __init__) and carries a
        # trailing underscore, the convention sklearn's check_is_fitted uses.
        self.scaler_ = StandardScaler().fit(X)
        return self

    def transform(self, X):
        """Standardise ``X`` using the statistics learned in :meth:`fit`."""
        return self.scaler_.transform(X)

# Same toy data as before: 3 samples x 10 constant features, binary labels.
X_train = [[value] * 10 for value in (0.1, 0.2, 0.3)]
y_train = [0, 1, 0]

# Standardise the features first, then hand them to the PyTorch model.
pipeline = Pipeline(
    steps=[
        ("scaler", CustomScaler()),
        ("model", PyTorchModel()),
    ]
)

pipeline.fit(X_train, y_train)
