In [None]:
"""I Target： Here we learn how to build model and load data to model

II Definition: Build your model inherit from (nn.Module), which is Base class for all neural network modules.
            Your models should also subclass this class.

            Modules can also contain other Modules, allowing to nest them in
            a tree structure. You can assign the submodules as regular attributes::

III Instances:
  3.0-3.2 build own model
  3.3 load inner model
  3.4-3.5 change_state_dict and load_pth_weight(from local or online )

  3.0.0 torch.nn.Module
  3.0.1 super().__init__()
  3.1 __call__  [magic method]
  3.2 (B, C, H ,W)
  3.3 torchvison.models
  3.4.0 model.dict_state()
  3.4.1 torch.save(model.dict_state(), f)
  3.4.2 model.load_state_dict()
  3.4.3 strict=False
  3.5 torch.utils.model_zoo.load_url()


IV Compare 2 then Generalize

V Test in New instance 
"""

In [2]:
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


from torch import nn


In [4]:
"""Inherit from Dataset class or load inner torchvision.dataset"""

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5), std = (0.5))  # normalize to [-1,1]
])

from torchvision.datasets import MNIST
train_dataset = MNIST(root="/home/hpczeji1/hpc-work/Codebase/Datasets/mnist_data",
                      train=True,
                      transform=transform,
                      target_transform=None,  # Eg1.2.1 : <class 'int'>
                      download=False)

train_loader = DataLoader(dataset=train_dataset, 
                          batch_size=10000,
                          shuffle=True)


In [13]:
def use_SimpleModel():
    """Build a 3 layer Nets inherits from nn.Module
        3.0.0 torch.nn.Module
        3.0.1 super().__init__()
        3.1 __call__  [magic method]
        3.2 (B, C, H ,W)
        
        """

    from torch import nn
    class SimpleModel(nn.Module):
        def __init__(self):
            super(SimpleModel, self).__init__()
            self.conv1 = nn.Conv2d(in_channels=1,out_channels=64,kernel_size=(3,3),stride=3) # in_channel is te image channel number
            self.conv2 = nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size=(3,3),stride =3)
            self.conv3 = nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size=(3,3),stride =3)
            self.relu = nn.ReLU(inplace = True)
            self.Flatten = nn.Flatten(start_dim=1, end_dim=-1) # (B,C,H,W), USE C,H,W
            self.Linear = nn.Linear(in_features=64*1*1, out_features=10,bias= False) # ins_feature num = flatten
        
        def forward(self, x):
            x = self.conv1(x)
            x = self.relu(x)
            x = self.conv2(x)
            x = self.relu(x)
            x = self.conv3(x)
            x = self.relu(x)
            print("[before flatten] x.shape: {}".format(x.shape))  # torch.Size
            x = self.Flatten(x)
            print("[after flatten] x.shape: {}".format(x.shape))  # torch.Size([1, 3920])
            x = self.Linear(x)

            return self.relu(x)
    
    model = SimpleModel()
    print(f"len(dataset):{len(train_dataset)}")
    x = train_dataset[0][0]  # torch.Size([1, 28, 28])
    x = x[None,...]
    print(f"shape:{x.shape}")
    read_out = model(x)
    print(f"read_out:{read_out}")

In [14]:
use_SimpleModel()

len(dataset):60000
shape:torch.Size([1, 1, 28, 28])
[before flatten] x.shape: torch.Size([1, 64, 1, 1])
[after flatten] x.shape: torch.Size([1, 64])
read_out:tensor([[0.0000, 0.0705, 0.0000, 0.0000, 0.0000, 0.0000, 0.0574, 0.0000, 0.0296,
         0.0000]], grad_fn=<ReluBackward0>)


In [15]:
def load_innerModel():
    """load built Model from torchvision.models
        3.3 torchvison.models
        
        """

    from torchvision import models
    model_alexnet = models.alexnet()
    print(f"model_alenxt:{model_alexnet}")

    model_resnet50 = models.resnet50()
    print(f"model_resnet50 :{model_resnet50}")


In [16]:
load_innerModel()

model_vgg16 :AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096,

In [17]:
def change_model_state_dict():
    """change_model_state_dict
          3.4.0 model.dict_state()
            3.4.1 torch.save(model.dict_state(), f)
            3.4.2 model.load_state_dict()
            3.4.3 strict=False
        
        """

    from torchvision import models
    model_alexnet = models.alexnet()
    print(f"model_alexnet.state_dict() :{model_alexnet.state_dict()}")
    print(f"type of model_alexnet.state_dict() :{type(model_alexnet.state_dict())}")

In [18]:
change_model_state_dict()

model_alexnet.state_dict() :OrderedDict([('features.0.weight', tensor([[[[ 1.0313e-02,  3.7982e-02, -3.0447e-03,  ..., -3.5254e-02,
            1.6232e-02,  3.8825e-02],
          [-1.1292e-02,  1.5655e-02,  1.7395e-02,  ...,  1.9284e-02,
            3.9204e-02, -3.5651e-02],
          [ 3.1970e-03,  1.3737e-02,  4.5383e-02,  ..., -5.8107e-03,
            2.5844e-02,  3.4569e-02],
          ...,
          [ 2.3277e-03, -1.5019e-02,  2.4953e-02,  ..., -3.3582e-02,
           -1.3999e-02,  4.1907e-02],
          [ 6.7409e-03, -8.8145e-03,  2.2970e-04,  ...,  1.7949e-02,
           -2.5192e-02,  2.9727e-03],
          [-2.0027e-02, -2.0344e-02,  3.9156e-02,  ...,  5.0799e-02,
            3.6141e-02,  1.6484e-02]],

         [[ 1.5543e-02, -2.3919e-02, -2.3632e-02,  ..., -4.5157e-02,
            3.1590e-02, -2.3541e-02],
          [-2.3125e-02,  4.4084e-02,  6.3867e-03,  ...,  2.0587e-02,
            2.9080e-03, -4.9455e-02],
          [-1.2382e-02,  2.5002e-02,  6.4726e-03,  ..., -4.5664e

In [4]:
def save_pth_weight():
    """change_model_state_dict

            3.4.1 torch.save(model.dict_state(), f)
            3.4.2 model.load_state_dict()
            3.4.3 strict=False
            3.5 : torch.utils.model_zoo.load_url()
        
        """

    from torchvision import models
    model_alexnet = models.alexnet()
    torch.save(model_alexnet.state_dict(),"./alexnet.pth")

def load_pth_weight():
    """3.4.2 model.load_state_dict()
    """
    from torchvision import models

    model_alexnet = models.alexnet()
    state_dict = torch.load("./alexnet.pth", map_location = "cpu")
    missing_keys,unexpected_keys=  model_alexnet.load_state_dict(state_dict, strict = True)
    print("missing_keys: {}".format(missing_keys))
    print("unexpected_keys: {}".format(unexpected_keys))



def change_state_dict():
    """3.4.3 strict=False"""
    from torchvision import models

    model_alexnet = models.alexnet()
    state_dict = torch.load("./alexnet.pth", map_location = "cpu")
    for key in list(state_dict.keys()):
        if ".bias" in key:
            del state_dict[key]
    

    missing_keys,unexpected_keys =  model_alexnet.load_state_dict(state_dict, strict = False)
    print("missing_keys: {}".format(missing_keys))
    print("unexpected_keys: {}".format(unexpected_keys))

def download_pth_weight_from_url():
    """3.5 : torch.utils.model_zoo.load_url()"""
		
    from torch.utils import model_zoo
    from torchvision import models

    model_alexnet = models.alexnet()
    state_dict = model_zoo.load_url('http://download.pytorch.org/models/alexnet-owt-7be5be79.pth')
    model_alexnet.load_state_dict(state_dict)

In [5]:
save_pth_weight()
# load_pth_weight()
# change_state_dict()
# download_pth_weight_from_url()

NameError: name 'torch' is not defined