# Daily-Dose-of-Data-Science

[Daily Dose of Data Science](https://avichawla.substack.com) is a publication on Substack that brings together intriguing frameworks, libraries, technologies, and tips that make the life cycle of a Data Science project effortless. 

Author: Avi Chawla

[Medium](https://medium.com/@avi_chawla) | [LinkedIn](https://www.linkedin.com/in/avi-chawla/)

# Skorch: Use Scikit-learn API on PyTorch Models

Post Link: [Substack](https://avichawla.substack.com/p/skorch-use-scikit-learn-api-on-pytorch)

LinkedIn Post: [LinkedIn](https://www.linkedin.com/posts/avi-chawla_python-sklearn-pytorch-activity-7017074093598420992-4UON?utm_source=share&utm_medium=member_desktop)

```bash
pip install skorch
```
```python
import torch
import numpy as np
from skorch import NeuralNetClassifier
```

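```python
# Generate some synthetic data: 10 samples, 3 features, random binary labels
X = np.random.randn(10, 3).astype(np.float32)
y = np.random.randint(0, 2, (10,)).astype(np.int64)

# Convert to tensors: float32 features and int64 class indices,
# which is the target dtype CrossEntropyLoss expects
X = torch.from_numpy(X)
y = torch.from_numpy(y)
```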

```python
# Define a simple PyTorch model: 3 input features -> 20 hidden units -> 2 class logits
class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(3, 20)
        self.fc2 = torch.nn.Linear(20, 2)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.nn.functional.relu(x)
        x = self.fc2(x)
        return x
```
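Before handing `MyModel` to skorch, a quick forward pass can confirm that it maps a `(batch, 3)` input to `(batch, 2)` raw logits. This is what `torch.nn.CrossEntropyLoss` consumes, together with int64 class indices as targets, which is why the labels above are cast to int64. A minimal sketch, using random inputs and targets purely for illustration:

```python
import torch


class MyModel(torch.nn.Module):
    # Same two-layer network as above: 3 inputs -> 20 hidden units -> 2 logits
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(3, 20)
        self.fc2 = torch.nn.Linear(20, 2)

    def forward(self, x):
        return self.fc2(torch.nn.functional.relu(self.fc1(x)))


torch.manual_seed(0)
net = MyModel()
x = torch.randn(10, 3)      # a batch of 10 samples with 3 features each
logits = net(x)             # raw, unnormalized scores, one column per class
print(logits.shape)         # torch.Size([10, 2])

# CrossEntropyLoss takes raw logits plus int64 class indices
targets = torch.randint(0, 2, (10,))  # torch.randint defaults to int64
loss = torch.nn.CrossEntropyLoss()(logits, targets)
print(targets.dtype, loss.dim())      # torch.int64 0
```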

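```python
# Create a NeuralNetClassifier using skorch
model = NeuralNetClassifier(
    MyModel,  # pass the class; skorch instantiates it when fitting
    max_epochs=10,
    lr=0.1,
    # MyModel returns raw logits, so use CrossEntropyLoss instead of
    # NeuralNetClassifier's default NLLLoss (which expects log-probabilities)
    criterion=torch.nn.CrossEntropyLoss,
)
```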

```python
# Use the familiar scikit-learn API to fit and predict
model.fit(X, y)
predictions = model.predict(X)
```

```text
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        0.6895       0.5000        0.6860  0.0573
      2        0.6542       0.5000        0.7084  0.0020
      3        0.6283       0.5000        0.7278  0.0015
      4        0.6084       0.5000        0.7445  0.0016
      5        0.5924       0.5000        0.7595  0.0014
      6        0.5793       0.5000        0.7729  0.0014
      7        0.5683       0.5000        0.7856  0.0011
      8        0.5590       0.5000        0.7970  0.0013
      9        0.5506       0.5000        0.8073  0.0015
     10        0.5430       0.5000        0.8170  0.0019
```


