In [3]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [2]:
# Enable IPython's autoreload extension; mode 2 re-imports all loaded modules
# before each cell executes, so edits to local .py files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [6]:
# Run on the GPU when a usable CUDA runtime is present, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'using device: {device}')

using device: cuda


In [16]:
class NNWork(nn.Module):
    """Fully connected classifier: Flatten -> Linear/ReLU -> Linear/ReLU -> Linear.

    The defaults reproduce the original fixed network (784 -> 512 -> 512 -> 10).

    Args:
        in_features: number of values per sample once every dimension after
            the first has been flattened. Defaults to ``28*28``.
        hidden_features: width of both hidden layers. Defaults to ``512``.
        num_classes: number of logits produced per sample. Defaults to ``10``.
    """

    def __init__(self, in_features=28*28, hidden_features=512, num_classes=10):
        super().__init__()
        # Submodules are registered in the same order as before (basic_mod,
        # then flaten) so parameter names and the printed repr are unchanged.
        self.basic_mod = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, num_classes),
        )
        # NOTE: the attribute name "flaten" (sic) is kept as-is because it is
        # part of the module's visible interface (repr / attribute access).
        self.flaten = nn.Flatten()

    def forward(self, x):
        """Flatten each sample and return the raw, un-normalised logits.

        Args:
            x: tensor whose dimensions after the first multiply out to
                ``in_features``.

        Returns:
            Tensor of shape ``(x.shape[0], num_classes)``; no softmax is applied.
        """
        x = self.flaten(x)
        logits = self.basic_mod(x)
        return logits

In [17]:
# Build the network first, then move its parameters to the selected device.
# nn.Module.to() returns the module itself, so `model` is the same object.
model = NNWork()
model = model.to(device)
print(model)

NNWork(
  (basic_mod): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
  (flaten): Flatten(start_dim=1, end_dim=-1)
)


In [25]:
# Forward pass on one random input, created directly on the model's device.
x = torch.rand(1, 28, 28, device=device)
print(x.shape)

# Calling the module runs forward() and yields one row of raw logits.
logits = model(x)
print(logits.shape)

# Softmax over dim=1 turns each row of logits into probabilities.
softmax_obj = nn.Softmax(dim=1)
probs = softmax_obj(logits)
print(probs)

# The predicted class is the index of the largest probability in each row.
y_pred = probs.argmax(dim=1)
print(f'Predicted class: {y_pred}')

torch.Size([1, 28, 28])
torch.Size([1, 10])
tensor([[0.0948, 0.0922, 0.1011, 0.0928, 0.1025, 0.0947, 0.1057, 0.0976, 0.1110,
         0.1075]], device='cuda:0', grad_fn=<SoftmaxBackward0>)
Predicted class: tensor([8], device='cuda:0')


In [32]:
# Apply each layer type on its own to a small random batch and print the
# tensor shape after every step.
x = torch.rand(3, 1, 28, 28)
print(x.shape)

# nn.Flatten() with default arguments keeps dim 0 and flattens the rest.
flatten_mod = nn.Flatten()
flat_x = flatten_mod(x)
print(flat_x.shape)

# A linear layer maps the 784 flattened values of each sample to 20 features.
linear_mod = nn.Linear(in_features=28 * 28, out_features=20)
linear_x = linear_mod(flat_x)
print(linear_x.shape)

# ReLU is element-wise, so the shape is unchanged.
relu_mod = nn.ReLU()
relu_x = relu_mod(linear_x)
print(relu_x.shape)

torch.Size([3, 1, 28, 28])
torch.Size([3, 784])
torch.Size([3, 20])
torch.Size([3, 20])


In [28]:
# model parameters
print(f"Model structure: {model}\n\n")
# named_parameters() returns a one-shot generator. Materialise it into a real
# dict so `para_dict` matches its name and is still usable after the loop
# (the bare generator would be exhausted once iterated).
para_dict = dict(model.named_parameters())
for name, para in para_dict.items():
    print(f'para: {name}, shape: {para.shape}')

Model structure: NNWork(
  (basic_mod): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
  (flaten): Flatten(start_dim=1, end_dim=-1)
)


para: basic_mod.0.weight, shape: torch.Size([512, 784])
para: basic_mod.0.bias, shape: torch.Size([512])
para: basic_mod.2.weight, shape: torch.Size([512, 512])
para: basic_mod.2.bias, shape: torch.Size([512])
para: basic_mod.4.weight, shape: torch.Size([10, 512])
para: basic_mod.4.bias, shape: torch.Size([10])
