# Load Libraries

1. Run the code in the cells below first, then add your own code.
2. PyTorch has two primitives for working with data: ```torch.utils.data.DataLoader``` and ```torch.utils.data.Dataset```. ```Dataset``` stores the samples and their corresponding labels, and ```DataLoader``` wraps an iterable around the ```Dataset```.
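
To make the relationship between the two primitives concrete, here is a minimal sketch. It uses `TensorDataset` (a ready-made `Dataset` shipped with PyTorch) so that nothing has to be downloaded. The toy `features` and `labels` tensors and the batch size are assumptions chosen for illustration only; they are not part of the tutorial.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy data: 10 samples with 3 features each, one integer label per sample.
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10) % 2

# TensorDataset is a Dataset that pairs each sample with its label.
dataset = TensorDataset(features, labels)

# DataLoader wraps the Dataset in an iterable that yields batches.
loader = DataLoader(dataset, batch_size=4, shuffle=False)

# Collect the shape of every batch: two full batches of 4 and one remainder of 2.
batch_shapes = [(x.shape, y.shape) for x, y in loader]
for x_shape, y_shape in batch_shapes:
    print(x_shape, y_shape)
```

Indexing the `Dataset` directly (`dataset[0]`) returns a single `(sample, label)` pair, while iterating over the `DataLoader` returns stacked batches.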


In [None]:
import torch

In [None]:
from torch import nn
import torch.nn.functional as F  # needed for F.relu in the model's forward pass
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import torch.optim as optim
import argparse

This is the `state_dict` example from the simple model used in the [Training a classifier](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py) tutorial.
Running it requires `import torch.optim as optim`, because the example creates an optimizer. On its own, the example does not fully show how to load a saved model; the sections after it cover saving and loading.

In [None]:
# Define model
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize model
model = TheModelClass()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

# Save Model

A common way to save a model is to serialize the internal state dictionary (containing the model parameters).

In [None]:
PATH = 'model.pth'
torch.save(model.state_dict(), PATH)

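You can also save the entire model rather than only its state dictionary. This pickles the whole module, so the class definition must be available when the file is loaded. The code for this is below.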

In [None]:
#torch.save(model, PATH)

# Load Model
The code reads the saved state dictionary with `torch.load` and then loads it into a new instance of `TheModelClass` with `load_state_dict`. Calling `eval()` switches the model to evaluation mode.

In [None]:
loaded_model = TheModelClass()
loaded_model.load_state_dict(torch.load(PATH))
loaded_model.eval()
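
One way to confirm that a save and load round trip worked is to compare the tensors in both state dictionaries. The sketch below assumes a small stand-in model (`nn.Linear(4, 2)`) and a temporary directory instead of `TheModelClass` and `PATH`, so it runs on its own; the names `original`, `restored`, and `weights_match` are hypothetical.

```python
import os
import tempfile

import torch
from torch import nn

# Hypothetical stand-in model; any nn.Module is saved and loaded the same way.
original = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.pth")
    torch.save(original.state_dict(), path)

    # Build a fresh instance with the same architecture, then load the weights.
    restored = nn.Linear(4, 2)
    restored.load_state_dict(torch.load(path))

# Switch to evaluation mode, as in the cell above.
restored.eval()

# Every tensor in the restored state_dict should equal the saved one.
weights_match = all(
    torch.equal(original.state_dict()[name], restored.state_dict()[name])
    for name in original.state_dict()
)
print("Weights match:", weights_match)
```

The same comparison should work for `model` and `loaded_model` from the cells above, since both expose `state_dict()`.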

The model example above comes from the PyTorch [Saving and Loading Models](https://pytorch.org/tutorials/beginner/saving_loading_models.html) tutorial.