## Topics Covered
1. Data (Preparing and Loading)  
2. Building a Model  
3. Fitting the model to data (Training)  
4. Making predictions and evaluating the model (Inference)  
5. Saving and loading the model  
6. Putting it all together

In [2]:
import torch
from torch import nn
import matplotlib.pyplot as plt
torch.__version__

'2.0.1'

## Data (Preparing and Loading)
Machine learning is a game of two parts:  
1. Get data into a numerical representation  
2. Build a model to learn patterns in that numerical representation

In [3]:
# Create synthetic data for a basic linear regression problem (known weight and bias)

start = 0
end = 1
step = 0.02

weight = 0.7
bias = 0.3

X = torch.arange(start, end, step).unsqueeze(dim = 1)
y = weight * X + bias # Basic Regression Equation i.e., y = wx + b

X[:10], y[:10]

(tensor([[0.0000],
         [0.0200],
         [0.0400],
         [0.0600],
         [0.0800],
         [0.1000],
         [0.1200],
         [0.1400],
         [0.1600],
         [0.1800]]),
 tensor([[0.3000],
         [0.3140],
         [0.3280],
         [0.3420],
         [0.3560],
         [0.3700],
         [0.3840],
         [0.3980],
         [0.4120],
         [0.4260]]))
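The call to `unsqueeze(dim = 1)` in the cell above is easy to overlook. The following is a minimal sketch of what it does to the tensor shape; the names `X_flat` and `X_col` are illustrative only and are not used elsewhere in this notebook.

```python
import torch

# Same range as the cell above, kept under separate (hypothetical) names.
X_flat = torch.arange(0, 1, 0.02)   # 1-D tensor of 50 values
X_col = X_flat.unsqueeze(dim=1)     # add a trailing dimension: one feature per row

print(X_flat.shape)  # torch.Size([50])
print(X_col.shape)   # torch.Size([50, 1])

# The regression equation is applied element-wise, so y keeps the same shape as X.
y_col = 0.7 * X_col + 0.3
print(y_col.shape)   # torch.Size([50, 1])
```

The column shape `[50, 1]` reads as "50 samples, 1 feature each", which is the layout layers such as `nn.Linear` expect for their input.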

### Splitting the data into training and test sets

In [4]:
train_split = int(0.8 * len(X))
X_train, y_train = X[:train_split], y[:train_split]
X_test, y_test = X[train_split:], y[train_split:]

len(X_train), len(X_test), len(y_train), len(y_test)

(40, 10, 40, 10)
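The split above is sequential: the first 80% of the points go to training and the last 20% to testing, so the test set lies entirely beyond the training range. As a hedged alternative (not what this notebook uses), the sketch below shuffles the indices with `torch.randperm` before splitting. The seed value and the `*_shuf` names are assumptions made for illustration.

```python
import torch

torch.manual_seed(42)  # assumed seed, only to make the shuffle reproducible

# Rebuild the same synthetic data so the sketch is self-contained.
X = torch.arange(0, 1, 0.02).unsqueeze(dim=1)
y = 0.7 * X + 0.3

# Shuffle the sample indices, then cut at the 80% mark.
indices = torch.randperm(len(X))
train_split = int(0.8 * len(X))
train_idx, test_idx = indices[:train_split], indices[train_split:]

# Hypothetical names, chosen so they do not overwrite the notebook's variables.
X_train_shuf, y_train_shuf = X[train_idx], y[train_idx]
X_test_shuf, y_test_shuf = X[test_idx], y[test_idx]

print(len(X_train_shuf), len(X_test_shuf))  # 40 10
```

Which split is preferable depends on the goal: the sequential split tests how well the model extrapolates past the training range, while a shuffled split tests interpolation within it.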

In [5]:
## Plot Predictions

def plot_predictions(train_data = X_train, train_labels = y_train, test_data = X_test, test_labels = y_test, predictions = None):
    """
    Plot training data and test data, and compare predictions if provided.
    """
    plt.figure(figsize=(10,7))

    # Plot training data
    plt.scatter(train_data, train_labels, c = "b", s=4, label = "Training Data")

    # Plot testing data
    plt.scatter(test_data, test_labels, c="g", s=4, label = "Testing Data")

    if predictions is not None:
        plt.scatter(test_data, predictions, c = "r", s = 4, label = "Predictions")

    # Show the legend
    plt.legend(prop = {"size": 14})


In [6]:
plot_predictions();
