In [12]:
## 用以下的代码可以证明，mps大概能提速三倍。

In [1]:
import torch 
from torch import nn 
import torchvision 
from torchvision import transforms 
import torch.nn.functional as F 


import os,sys,time
import numpy as np
import pandas as pd
import datetime 
from tqdm import tqdm 
from copy import deepcopy
from torchmetrics import Accuracy


def printlog(info):
    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print("\n"+"=========="*8 + "%s"%nowtime)
    print(str(info)+"\n")

In [14]:
!rm -rf mnist/

In [3]:
!ls data

[34mMNIST[m[m               [34mcifar-10-batches-py[m[m


In [2]:
transform = transforms.Compose([transforms.ToTensor()])

ds_train = torchvision.datasets.MNIST(root="mnist/",train=True,download=True,transform=transform)
ds_val = torchvision.datasets.MNIST(root="mnist/",train=False,download=True,transform=transform)

dl_train =  torch.utils.data.DataLoader(ds_train, batch_size=128, shuffle=True, num_workers=2)
dl_val =  torch.utils.data.DataLoader(ds_val, batch_size=128, shuffle=False, num_workers=2)

In [3]:
def create_net():
    net = nn.Sequential()
    net.add_module("conv1",nn.Conv2d(in_channels=1,out_channels=64,kernel_size = 3))
    net.add_module("pool1",nn.MaxPool2d(kernel_size = 2,stride = 2))
    net.add_module("conv2",nn.Conv2d(in_channels=64,out_channels=512,kernel_size = 3))
    net.add_module("pool2",nn.MaxPool2d(kernel_size = 2,stride = 2))
    net.add_module("dropout",nn.Dropout2d(p = 0.1))
    net.add_module("adaptive_pool",nn.AdaptiveMaxPool2d((1,1)))
    net.add_module("flatten",nn.Flatten())
    net.add_module("linear1",nn.Linear(512,1024))
    net.add_module("relu",nn.ReLU())
    net.add_module("linear2",nn.Linear(1024,10))
    return net

net = create_net()
print(net)

# 评估指标
class Accuracy(nn.Module):
    def __init__(self):
        super().__init__()

        self.correct = nn.Parameter(torch.tensor(0.0),requires_grad=False)
        self.total = nn.Parameter(torch.tensor(0.0),requires_grad=False)

    def forward(self, preds: torch.Tensor, targets: torch.Tensor):
        preds = preds.argmax(dim=-1)
        m = (preds == targets).sum()
        n = targets.shape[0] 
        self.correct += m 
        self.total += n
        
        return m/n

    def compute(self):
        return self.correct.float() / self.total 
    
    def reset(self):
        self.correct -= self.correct
        self.total -= self.total

Sequential(
  (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(64, 512, kernel_size=(3, 3), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (dropout): Dropout2d(p=0.1, inplace=False)
  (adaptive_pool): AdaptiveMaxPool2d(output_size=(1, 1))
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear1): Linear(in_features=512, out_features=1024, bias=True)
  (relu): ReLU()
  (linear2): Linear(in_features=1024, out_features=10, bias=True)
)


In [4]:
loss_fn = nn.CrossEntropyLoss()
optimizer= torch.optim.Adam(net.parameters(),lr = 0.01)   
metrics_dict = nn.ModuleDict({"acc":Accuracy()})

In [11]:
torch.backends.gpu.is_available()

AttributeError: module 'torch.backends' has no attribute 'gpu'

In [5]:
# # =========================移动模型到mps上==============================
# device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# net.to(device)
# loss_fn.to(device)
# metrics_dict.to(device)
# # ====================================================================


epochs = 20 
ckpt_path='checkpoint.pt'

#early_stopping相关设置
monitor="val_acc"
patience=5
mode="max"

history = {}

for epoch in range(1, epochs+1):
    printlog("Epoch {0} / {1}".format(epoch, epochs))

    # 1，train -------------------------------------------------  
    net.train()
    
    total_loss,step = 0,0
    
    loop = tqdm(enumerate(dl_train), total =len(dl_train),ncols=100)
    train_metrics_dict = deepcopy(metrics_dict) 
    
    for i, batch in loop: 
        
        features,labels = batch
        
#         # =========================移动数据到mps上==============================
#         features = features.to(device)
#         labels = labels.to(device)
#         # ====================================================================
        
        #forward
        preds = net(features)
        loss = loss_fn(preds,labels)
        
        #backward
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
            
        #metrics
        step_metrics = {"train_"+name:metric_fn(preds, labels).item() 
                        for name,metric_fn in train_metrics_dict.items()}
        
        step_log = dict({"train_loss":loss.item()},**step_metrics)

        total_loss += loss.item()
        
        step+=1
        if i!=len(dl_train)-1:
            loop.set_postfix(**step_log)
        else:
            epoch_loss = total_loss/step
            epoch_metrics = {"train_"+name:metric_fn.compute().item() 
                             for name,metric_fn in train_metrics_dict.items()}
            epoch_log = dict({"train_loss":epoch_loss},**epoch_metrics)
            loop.set_postfix(**epoch_log)

            for name,metric_fn in train_metrics_dict.items():
                metric_fn.reset()
                
    for name, metric in epoch_log.items():
        history[name] = history.get(name, []) + [metric]
        

    # 2，validate -------------------------------------------------
    net.eval()
    
    total_loss,step = 0,0
    loop = tqdm(enumerate(dl_val), total =len(dl_val),ncols=100)
    
    val_metrics_dict = deepcopy(metrics_dict) 
    
    with torch.no_grad():
        for i, batch in loop: 

            features,labels = batch
            
#             # =========================移动数据到mps上==============================
#             features = features.to(device)
#             labels = labels.to(device)
#             # ====================================================================
            
            #forward
            preds = net(features)
            loss = loss_fn(preds,labels)

            #metrics
            step_metrics = {"val_"+name:metric_fn(preds, labels).item() 
                            for name,metric_fn in val_metrics_dict.items()}

            step_log = dict({"val_loss":loss.item()},**step_metrics)

            total_loss += loss.item()
            step+=1
            if i!=len(dl_val)-1:
                loop.set_postfix(**step_log)
            else:
                epoch_loss = (total_loss/step)
                epoch_metrics = {"val_"+name:metric_fn.compute().item() 
                                 for name,metric_fn in val_metrics_dict.items()}
                epoch_log = dict({"val_loss":epoch_loss},**epoch_metrics)
                loop.set_postfix(**epoch_log)

                for name,metric_fn in val_metrics_dict.items():
                    metric_fn.reset()
                    
    epoch_log["epoch"] = epoch           
    for name, metric in epoch_log.items():
        history[name] = history.get(name, []) + [metric]

    # 3，early-stopping -------------------------------------------------
    arr_scores = history[monitor]
    best_score_idx = np.argmax(arr_scores) if mode=="max" else np.argmin(arr_scores)
    if best_score_idx==len(arr_scores)-1:
        torch.save(net.state_dict(),ckpt_path)
        print("<<<<<< reach best {0} : {1} >>>>>>".format(monitor,
             arr_scores[best_score_idx]),file=sys.stderr)
    if len(arr_scores)-best_score_idx>patience:
        print("<<<<<< {} without improvement in {} epoch, early stopping >>>>>>".format(
            monitor,patience),file=sys.stderr)
        break 
    net.load_state_dict(torch.load(ckpt_path))
    
dfhistory = pd.DataFrame(history)


Epoch 1 / 20



100%|██████████████████████████| 469/469 [00:40<00:00, 11.45it/s, train_acc=0.808, train_loss=0.705]
100%|████████████████████████████████| 79/79 [00:03<00:00, 24.42it/s, val_acc=0.962, val_loss=0.126]


Epoch 2 / 20




<<<<<< reach best val_acc : 0.9617000222206116 >>>>>>
100%|██████████████████████████| 469/469 [00:39<00:00, 11.95it/s, train_acc=0.958, train_loss=0.137]
100%|███████████████████████████████| 79/79 [00:03<00:00, 23.85it/s, val_acc=0.977, val_loss=0.0772]


Epoch 3 / 20




<<<<<< reach best val_acc : 0.9769999980926514 >>>>>>
100%|██████████████████████████| 469/469 [00:37<00:00, 12.40it/s, train_acc=0.967, train_loss=0.113]
100%|█████████████████████████████████| 79/79 [00:03<00:00, 24.06it/s, val_acc=0.97, val_loss=0.101]


Epoch 4 / 20




100%|███████████████████████████| 469/469 [00:42<00:00, 11.08it/s, train_acc=0.959, train_loss=0.14]
100%|████████████████████████████████| 79/79 [00:03<00:00, 24.13it/s, val_acc=0.967, val_loss=0.107]


Epoch 5 / 20




100%|██████████████████████████| 469/469 [00:40<00:00, 11.71it/s, train_acc=0.963, train_loss=0.126]
100%|███████████████████████████████| 79/79 [00:03<00:00, 23.55it/s, val_acc=0.977, val_loss=0.0812]


Epoch 6 / 20




<<<<<< reach best val_acc : 0.9771000146865845 >>>>>>
100%|███████████████████████████| 469/469 [00:38<00:00, 12.13it/s, train_acc=0.97, train_loss=0.102]
100%|███████████████████████████████| 79/79 [00:03<00:00, 23.12it/s, val_acc=0.979, val_loss=0.0713]
<<<<<< reach best val_acc : 0.9790999889373779 >>>>>>



Epoch 7 / 20



100%|██████████████████████████| 469/469 [00:41<00:00, 11.25it/s, train_acc=0.952, train_loss=0.181]
100%|████████████████████████████████| 79/79 [00:03<00:00, 22.60it/s, val_acc=0.968, val_loss=0.136]


Epoch 8 / 20




100%|█████████████████████████| 469/469 [00:40<00:00, 11.55it/s, train_acc=0.974, train_loss=0.0874]
100%|████████████████████████████████| 79/79 [00:03<00:00, 24.02it/s, val_acc=0.979, val_loss=0.083]


Epoch 9 / 20




100%|██████████████████████████| 469/469 [00:40<00:00, 11.48it/s, train_acc=0.954, train_loss=0.179]
100%|████████████████████████████████| 79/79 [00:03<00:00, 23.90it/s, val_acc=0.977, val_loss=0.093]



Epoch 10 / 20



100%|█████████████████████████| 469/469 [00:40<00:00, 11.59it/s, train_acc=0.971, train_loss=0.0985]
100%|███████████████████████████████| 79/79 [00:03<00:00, 23.59it/s, val_acc=0.979, val_loss=0.0826]


Epoch 11 / 20




100%|██████████████████████████| 469/469 [00:40<00:00, 11.68it/s, train_acc=0.967, train_loss=0.115]
100%|█████████████████████████████████| 79/79 [00:03<00:00, 23.46it/s, val_acc=0.971, val_loss=0.12]
<<<<<< val_acc without improvement in 5 epoch, early stopping >>>>>>
