In [1]:
import torch

In [2]:
from torch import nn

In [3]:
torch.backends.mps.is_available()

True

In [4]:
import sys

In [5]:
sys.path.append("..")

In [6]:
import utils

In [7]:
device = utils.get_device_for_training()

In [8]:
from fashion_mnist import model

In [9]:
neural_network = model.FashionMnistNetwork(n_hidden_features=512).to(device)

In [10]:
print(neural_network)

FashionMnistNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)


# example

In [11]:
X = torch.rand(1, 28, 28, device=device)
logits = neural_network(X)
pred_probab = nn.Softmax(dim=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")

Predicted class: tensor([7], device='mps:0')


In [12]:
print(utils.report_model_parameters(model=neural_network))

Layer: linear_relu_stack.0.weight | Size: torch.Size([512, 784])
Layer: linear_relu_stack.0.bias | Size: torch.Size([512])
Layer: linear_relu_stack.2.weight | Size: torch.Size([512, 512])
Layer: linear_relu_stack.2.bias | Size: torch.Size([512])
Layer: linear_relu_stack.4.weight | Size: torch.Size([10, 512])
Layer: linear_relu_stack.4.bias | Size: torch.Size([10])



In [13]:
for name, param in neural_network.named_parameters():
    print(name, param.size)

linear_relu_stack.0.weight <built-in method size of Parameter object at 0x107dd5a90>
linear_relu_stack.0.bias <built-in method size of Parameter object at 0x12fd8e620>
linear_relu_stack.2.weight <built-in method size of Parameter object at 0x12fd1b250>
linear_relu_stack.2.bias <built-in method size of Parameter object at 0x13c104690>
linear_relu_stack.4.weight <built-in method size of Parameter object at 0x13c104320>
linear_relu_stack.4.bias <built-in method size of Parameter object at 0x13b504a50>


In [14]:
torch.backends.mps.is_built()

True