In [12]:
import torch
import torch.nn as nn

# Tensors

In [2]:
x = torch.arange(4, dtype=torch.float32)
torch.save(x, 'tensor_file')

In [5]:
x_cp = torch.load('tensor_file')

In [6]:
x_cp, x == x_cp

(tensor([0., 1., 2., 3.]), tensor([True, True, True, True]))

## list of tensor

In [7]:
y = torch.zeros(4)
torch.save([x, y], 'xy_file')

In [10]:
x2, y2 = torch.load('xy_file')
x2, y2, y2 == y, y

(tensor([0., 1., 2., 3.]),
 tensor([0., 0., 0., 0.]),
 tensor([True, True, True, True]),
 tensor([0., 0., 0., 0.]))

## tensor dict

In [11]:
dic = {'x': x, 'y': y}
torch.save(dic, 'dict_file')
torch.load('dict_file')

{'x': tensor([0., 1., 2., 3.]), 'y': tensor([0., 0., 0., 0.])}

# Torch Model Parameters

In [13]:
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)
        self.output = nn.Linear(256, 10)
        self.relu = nn.ReLU()
    def forward(self, x):
        H_1 = self.relu(self.hidden(x))
        out = self.output(H_1)
        return out

net = MLP()
x = torch.randn(size=(2, 20))
y = net(x)

In [14]:
y

tensor([[ 0.0450,  0.0093,  0.2300, -0.2451,  0.5008,  0.0085,  0.0680,  0.6359,
          0.2762, -0.0976],
        [-0.1314, -0.2633,  0.2501, -0.2915,  0.5085, -0.0314, -0.0043,  0.4260,
          0.1069, -0.3342]], grad_fn=<AddmmBackward>)

In [15]:
torch.save(net.state_dict(), 'mlp_params')

In [16]:
params = torch.load('mlp_params')
params

OrderedDict([('hidden.weight',
              tensor([[-0.1825, -0.1946, -0.1661,  ...,  0.0833, -0.1339,  0.0139],
                      [ 0.0015,  0.1536, -0.1240,  ...,  0.1829, -0.0567,  0.0674],
                      [ 0.1413,  0.0682, -0.0787,  ...,  0.0832, -0.0452,  0.2068],
                      ...,
                      [-0.0479,  0.0089, -0.0342,  ...,  0.0622,  0.1368,  0.1644],
                      [-0.2088, -0.2182,  0.0491,  ..., -0.0166, -0.1923, -0.0050],
                      [-0.0231, -0.1908, -0.0692,  ...,  0.1930,  0.0547, -0.1499]])),
             ('hidden.bias',
              tensor([-0.1597, -0.0453, -0.0149, -0.1622, -0.1095,  0.2111,  0.1203, -0.2092,
                      -0.1502,  0.1776,  0.1334,  0.1405, -0.0854,  0.0398,  0.2105, -0.0225,
                      -0.2029, -0.1218, -0.0329, -0.0736, -0.1598,  0.2056, -0.0119,  0.1702,
                       0.1926,  0.1654,  0.2227,  0.0925,  0.0844, -0.1869,  0.2206,  0.0753,
                       0.0894,

In [17]:
net_cp = MLP()
net_cp.load_state_dict(torch.load('mlp_params'))
net_cp.eval()

MLP(
  (hidden): Linear(in_features=20, out_features=256, bias=True)
  (output): Linear(in_features=256, out_features=10, bias=True)
  (relu): ReLU()
)

In [18]:
net_cp(x) == y

tensor([[True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True]])