In [None]:
"""I Target： Here we learn how to chose optimizer tp optimize model

II Definition: Strategy to optimize model(update weights) by calculating gradient

III Instances:

  4.0 torch.optim
  4.1 params
  4.2 zero_grad(), step()



IV Compare 2 then Generalize

V Test in New instance 
"""

In [1]:
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


from torch import nn


In [2]:
"""Inherit from Dataset class or load inner torchvision.dataset"""

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5), std = (0.5))  # normalize to [-1,1]
])

from torchvision.datasets import MNIST
train_dataset = MNIST(root="/home/hpczeji1/hpc-work/Codebase/Datasets/mnist_data",
                      train=True,
                      transform=transform,
                      target_transform=None,  # Eg1.2.1 : <class 'int'>
                      download=False)

train_loader = DataLoader(dataset=train_dataset, 
                          batch_size=10000,
                          shuffle=True)


from torch import nn
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1,out_channels=64,kernel_size=(3,3),stride=3) # in_channel is te image channel number
        self.conv2 = nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size=(3,3),stride =3)
        self.conv3 = nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size=(3,3),stride =3)
        self.relu = nn.ReLU(inplace = True)
        self.Flatten = nn.Flatten(start_dim=1, end_dim=-1) # (B,C,H,W), USE C,H,W
        self.Linear = nn.Linear(in_features=64*1*1, out_features=10,bias= False) # ins_feature num = flatten
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.relu(x)
        print("[before flatten] x.shape: {}".format(x.shape))  # torch.Size
        x = self.Flatten(x)
        print("[after flatten] x.shape: {}".format(x.shape))  # torch.Size([1, 3920])
        x = self.Linear(x)

        return self.relu(x)

model = SimpleModel()


In [7]:
def update_all_pars():
    """specific parts of params to update
        4.0 torch.optim
        4.1 params
  """

    from torch import optim
    # params = [param for name, param in model.named_parameters() if ".bias" in name]
    optimizer = optim.SGD(params = model.parameters(), lr= 0.0001, momentum= 0.9)
    print(f"optimzer.state_dict():{optimizer.state_dict()} ") 

In [5]:
def update_part_of_pars():
    """specific parts of params to update
        4.0 torch.optim
        4.1 params
  """

    from torch import optim
    params = [param for name, param in model.named_parameters() if ".bias" in name]
    optimizer = optim.SGD(params = params, lr= 0.0001, momentum= 0.9)
    print(f"optimzer.state_dict():{optimizer.state_dict()} ") 
    

In [8]:
update_all_pars()
update_part_of_pars()

optimzer.state_dict():{'state': {}, 'param_groups': [{'lr': 0.0001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0, 1, 2, 3, 4, 5, 6]}]} 
optimzer.state_dict():{'state': {}, 'param_groups': [{'lr': 0.0001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0, 1, 2]}]} 


In [10]:
def entire_optim_process():
    """whole optim process"""

    from torch import optim
    from tqdm import tqdm
    optimizer = optim.SGD(params = model.parameters(), lr= 0.0001, momentum= 0.9)
    loss_fn = nn.CrossEntropyLoss()

    epoch_num = 2
    for epoch in range(epoch_num):
        with tqdm(train_loader) as train_bar:
            for x,y in train_bar:
                optimizer.zero_grad()
                loss = loss_fn(model(x),y)
                loss.backward()
                optimizer.step()
        print(f"epoch:{epoch}, loss{loss}")

In [11]:
entire_optim_process()

  0%|          | 0/6 [00:00<?, ?it/s]

[before flatten] x.shape: torch.Size([10000, 64, 1, 1])
[after flatten] x.shape: torch.Size([10000, 64])


 17%|█▋        | 1/6 [00:04<00:21,  4.34s/it]

[before flatten] x.shape: torch.Size([10000, 64, 1, 1])
[after flatten] x.shape: torch.Size([10000, 64])


 33%|███▎      | 2/6 [00:07<00:14,  3.50s/it]

[before flatten] x.shape: torch.Size([10000, 64, 1, 1])
[after flatten] x.shape: torch.Size([10000, 64])


 50%|█████     | 3/6 [00:10<00:09,  3.21s/it]

[before flatten] x.shape: torch.Size([10000, 64, 1, 1])
[after flatten] x.shape: torch.Size([10000, 64])


 67%|██████▋   | 4/6 [00:12<00:06,  3.03s/it]

[before flatten] x.shape: torch.Size([10000, 64, 1, 1])
[after flatten] x.shape: torch.Size([10000, 64])


 83%|████████▎ | 5/6 [00:15<00:03,  3.04s/it]

[before flatten] x.shape: torch.Size([10000, 64, 1, 1])
[after flatten] x.shape: torch.Size([10000, 64])


100%|██████████| 6/6 [00:18<00:00,  3.10s/it]


epoch:0, loss2.30184268951416


  0%|          | 0/6 [00:00<?, ?it/s]

[before flatten] x.shape: torch.Size([10000, 64, 1, 1])
[after flatten] x.shape: torch.Size([10000, 64])


 17%|█▋        | 1/6 [00:02<00:14,  2.92s/it]

[before flatten] x.shape: torch.Size([10000, 64, 1, 1])
[after flatten] x.shape: torch.Size([10000, 64])


 33%|███▎      | 2/6 [00:05<00:11,  2.83s/it]

[before flatten] x.shape: torch.Size([10000, 64, 1, 1])
[after flatten] x.shape: torch.Size([10000, 64])


 50%|█████     | 3/6 [00:08<00:08,  2.81s/it]

[before flatten] x.shape: torch.Size([10000, 64, 1, 1])
[after flatten] x.shape: torch.Size([10000, 64])


 67%|██████▋   | 4/6 [00:11<00:05,  2.77s/it]

[before flatten] x.shape: torch.Size([10000, 64, 1, 1])
[after flatten] x.shape: torch.Size([10000, 64])


 83%|████████▎ | 5/6 [00:13<00:02,  2.77s/it]

[before flatten] x.shape: torch.Size([10000, 64, 1, 1])
[after flatten] x.shape: torch.Size([10000, 64])


100%|██████████| 6/6 [00:16<00:00,  2.78s/it]

epoch:1, loss2.301483631134033



