In [2]:
import torch.nn as nn
import torch
# Create a Linear Regression model classl
class LinearRegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Use nn.Linear() instead of creating our own parameters
        self.linear_layer = nn.Linear(in_features=1, 
                                      out_features=1)
    
    # Define the forward computation (input data x flows through nn.Linear())
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear_layer(x)

# Set the manual seed when creating the model (this isn't always need but is used for demonstrative purposes, try commenting it out and seeing what happens)
regression_1 = LinearRegressionModel()
regression_1, regression_1.state_dict()

(LinearRegressionModel(
   (linear_layer): Linear(in_features=1, out_features=1, bias=True)
 ),
 OrderedDict([('linear_layer.weight', tensor([[-0.6645]])),
              ('linear_layer.bias', tensor([-0.3926]))]))

In [3]:
from pathlib import Path

# 1. Create models directory 
MODEL_PATH = Path("models")
MODEL_PATH.mkdir(parents=True, exist_ok=True)

# 2. Create model save path 
MODEL_NAME = "regression_model_1.pt"
MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME

In [5]:
# loading saved models.
# Instantiate a fresh instance of LinearRegressionModelV2
loaded_model_1 = LinearRegressionModel()

# Load model state dict 
loaded_model_1.load_state_dict(torch.load(MODEL_SAVE_PATH))

# Put model to target device (if your data is on GPU, model will have to be on GPU to make predictions)
#loaded_model_1.to(device)

print(f"Loaded model:\n{loaded_model_1}")
#print(f"Model on device:\n{next(loaded_model_1.parameters()).device}")

Loaded model:
LinearRegressionModel(
  (linear_layer): Linear(in_features=1, out_features=1, bias=True)
)


In [9]:
with torch.inference_mode():
    # 1. Forward pass on test data
    pred = regression_1(torch.zeros(1))
print(pred)

tensor([-0.3926])


In [10]:
with torch.inference_mode():
    # 1. Forward pass on test data
    pred = regression_1(torch.ones(1))
print(pred)

tensor([-1.0570])
