# KerasModel Example

MNIST 手写数字分类

# 一、环境准备

## 1.导包

In [1]:
import sys 
import numpy as np 
import pandas as pd 
import torchkeras
import torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from matplotlib import pyplot as plt
from argparse import Namespace
from pathlib import Path

## 2.检查 CUDA 状态

In [2]:
# Report whether PyTorch can use CUDA in this environment.
flag = torch.cuda.is_available()
if flag:
    print("CUDA可使用")
else:
    print("CUDA不可用")

# Number of GPUs visible to this process.
ngpu = torch.cuda.device_count()
print("GPU数量：",ngpu)
# Decide which device we want to run on: the first GPU when one is usable,
# otherwise fall back to the CPU.
device = torch.device("cuda:0" if (flag and ngpu > 0) else "cpu")
print("驱动为：",device)
# Querying the device name raises when CUDA is unavailable, so only ask for
# it when a GPU was actually selected above.
if device.type == "cuda":
    print("GPU型号： ",torch.cuda.get_device_name(0))

CUDA可使用
GPU数量： 2
驱动为： cuda:0
GPU型号：  NVIDIA GeForce RTX 3090


## 3.环境超参

In [None]:
# Make modules located in the parent directory importable from this notebook.
sys.path.append("..")

# Hyperparameters of the example, collected in a single namespace.
# NOTE(review): the visible cells only read batch_size and num_workers;
# img_size and lr are not referenced (the optimizer is built with its own
# literal learning rate) -- confirm whether they are still needed.
config = Namespace(
    **{
        "img_size": 128,
        "lr": 1e-4,
        "batch_size": 64,
        "num_workers": 2,
    }
)

# 二、数据准备

## 1.定义数据预处理方式

In [None]:
# Preprocessing pipeline applied to every sample. ToTensor is the only step;
# no augmentation transforms are composed here.
transform = transforms.Compose([transforms.ToTensor()])

## 2.定义 dataset 和 dataloader

In [None]:
 
# Both splits share the same download/cache directory and the same transform.
mnist_root = "../datasets/mnist/"
# Keep every 20th sample of each split (presumably to keep the example fast).
subsample_step = 20

full_train = torchvision.datasets.MNIST(
    root=mnist_root, train=True, download=True, transform=transform
)
train_dataset = torch.utils.data.Subset(
    full_train, range(0, len(full_train), subsample_step)
)

full_val = torchvision.datasets.MNIST(
    root=mnist_root, train=False, download=True, transform=transform
)
val_dataset = torch.utils.data.Subset(
    full_val, range(0, len(full_val), subsample_step)
)

# The loaders differ only in shuffling: the training loader shuffles, the
# validation loader keeps a fixed order.
loader_kwargs = dict(
    batch_size=config.batch_size,
    num_workers=config.num_workers,
)
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_dataloader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

## 3.检查数据

In [None]:
from PIL import Image
from torchkeras.plots import joint_imgs_row
from matplotlib import pyplot as plt 

# Pull a single batch from the training loader for a quick sanity check.
images, labels = next(iter(train_dataloader))

print(images.shape)
print(labels.shape)

# Show up to the first nine samples in a 3x3 grid. Slicing plus zip means a
# batch with fewer than nine samples plots what is available instead of
# raising an IndexError.
fig = plt.figure(figsize=(8,8))
for i, (img, label) in enumerate(zip(images[:9], labels[:9])):
    # Move the first axis last for imshow
    # (assumes each image is laid out as (C, H, W) -- TODO confirm).
    img = img.permute(1,2,0)
    ax = plt.subplot(3,3,i+1)
    ax.imshow(img.numpy())
    ax.set_title("label = %d"%label)
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()

# 三、训练准备

## 1.定义评价指标

In [None]:
class Accuracy(nn.Module):
    """Running classification accuracy accumulated over batches.

    Each call scores one batch and adds its counts to the running totals.
    ``compute`` returns the accuracy over everything seen since the last
    ``reset``.
    """

    def __init__(self):
        super().__init__()

        # Running counts, stored as non-trainable parameters of the module.
        self.correct = nn.Parameter(torch.tensor(0.0),requires_grad=False)
        self.total = nn.Parameter(torch.tensor(0.0),requires_grad=False)

    def forward(self, preds: torch.Tensor, targets: torch.Tensor):
        """Score one batch and update the running counts.

        Args:
            preds: Class scores; the predicted class is the argmax over the
                last dimension.
            targets: Integer class labels with the same shape as ``preds``
                without its last dimension.

        Returns:
            The accuracy of this batch alone (0 for an empty batch).
        """
        preds = preds.argmax(dim=-1)
        m = (preds == targets).sum()
        # Count every label rather than only the first dimension, so that
        # multi-dimensional targets are not under-counted.
        n = targets.numel()
        self.correct += m
        self.total += n

        # max(n, 1) avoids a 0/0 NaN for an empty batch.
        return m / max(n, 1)

    def compute(self):
        """Return the accuracy over all batches seen since the last reset.

        Returns 0 when nothing has been accumulated yet.
        """
        # Clamp the denominator so that an empty state gives 0 instead of NaN.
        return self.correct.float() / self.total.clamp(min=1)

    def reset(self):
        """Zero the running counts in place."""
        self.correct.zero_()
        self.total.zero_()
        

## 2.定义模型

In [None]:
def create_net():
    """Build the small convolutional classifier used in this example.

    The network takes single-channel images (``conv1`` has
    ``in_channels=1``) and produces 10 output scores (``linear2``). The
    adaptive max pool reduces each of the 64 feature maps to a single value,
    so the linear head does not depend on the input resolution.

    NOTE(review): apart from max pooling, the only non-linearity is the ReLU
    between the two linear layers; there is none after the convolutions.

    Returns:
        An ``nn.Sequential`` whose children are named as listed below.
    """
    # (name, module) pairs in forward order. Modules are created in this
    # order, so parameter initialisation order is unchanged.
    named_layers = [
        ("conv1", nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3)),
        ("pool1", nn.MaxPool2d(kernel_size=2, stride=2)),
        ("conv2", nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5)),
        ("pool2", nn.MaxPool2d(kernel_size=2, stride=2)),
        ("dropout", nn.Dropout2d(p=0.1)),
        ("adaptive_pool", nn.AdaptiveMaxPool2d((1, 1))),
        ("flatten", nn.Flatten()),
        ("linear1", nn.Linear(64, 32)),
        ("relu", nn.ReLU()),
        ("linear2", nn.Linear(32, 10)),
    ]

    net = nn.Sequential()
    for layer_name, layer in named_layers:
        net.add_module(layer_name, layer)
    return net

# Build the network, then wrap it with its loss, optimizer and metrics.
net = create_net()

criterion = nn.CrossEntropyLoss()
# NOTE(review): the learning rate is the literal 5e-3 here, not config.lr.
adam = torch.optim.Adam(net.parameters(), lr=5e-3)
metrics = {"acc": Accuracy()}

model = torchkeras.KerasModel(
    net,
    loss_fn=criterion,
    optimizer=adam,
    metrics_dict=metrics,
)

## 3.模型Summary

In [None]:
# Print a summary of the model, using the batch of images fetched earlier in
# the notebook as the example input.
from torchkeras import summary
summary(model,input_data=images)

## 4.可视化训练

In [None]:
# if gpu/mps is available, will auto use it, otherwise cpu will be used.
# Path passed to fit() as ckpt_path; the inference section loads from the same path.
ckpt_path='checkpoint'
# Uncomment the next line to load a trained checkpoint and continue training.
#model.load_ckpt(ckpt_path) #load trained ckpt and continue training
# Train for up to 30 epochs and keep the returned history in dfhistory for the
# plotting cells below.
# NOTE(review): monitor/mode/patience presumably configure early stopping on
# the validation accuracy -- confirm against the torchkeras KerasModel.fit docs.
dfhistory=model.fit(
    train_data=train_dataloader, 
    val_data=val_dataloader, 
    epochs=30, 
    patience=3, 
    monitor="val_acc",
    mode="max",
    ckpt_path=ckpt_path,
    plot=True,
    wandb=False
)

## 5.训练结果查看

In [None]:
import matplotlib.pyplot as plt

def plot_metric(dfhistory, metric):
    """Plot the training and validation curves of one metric against epochs.

    Args:
        dfhistory: Object indexable by the keys ``"train_<metric>"`` and
            ``"val_<metric>"``, each yielding one value per epoch.
        metric: Metric name without the ``train_`` / ``val_`` prefix.
    """
    train_key = "train_" + metric
    val_key = 'val_' + metric

    train_curve = dfhistory[train_key]
    val_curve = dfhistory[val_key]

    # Epochs are numbered from 1 on the x axis.
    epoch_axis = range(1, len(train_curve) + 1)

    # Training curve: blue dashed with markers; validation: red solid.
    for curve, line_style in ((train_curve, 'bo--'), (val_curve, 'ro-')):
        plt.plot(epoch_axis, curve, line_style)

    plt.title('Training and validation ' + metric)
    plt.xlabel("Epochs")
    plt.ylabel(metric)
    plt.legend([train_key, val_key])
    plt.show()
    

In [None]:
# Plot the training and validation loss curves from the fit history.
plot_metric(dfhistory,'loss')

In [None]:
# Plot the training and validation accuracy curves from the fit history.
plot_metric(dfhistory,"acc")

In [None]:
# Evaluate the trained model on the validation dataloader.
# NOTE(review): quiet=False presumably enables progress output -- confirm
# against the torchkeras KerasModel.evaluate docs.
model.evaluate(val_dataloader,quiet=False)

# 四、推理

In [None]:
# Rebuild the architecture and restore the parameters saved during training.
new_net = create_net() 

model_clone = torchkeras.KerasModel(
    new_net,
    loss_fn = nn.CrossEntropyLoss(),
    optimizer= torch.optim.Adam(new_net.parameters(),lr = 0.001),
    metrics_dict={"acc":Accuracy()}
)

model_clone.load_ckpt("checkpoint")


net = model_clone.net
# Switch to evaluation mode so that dropout is disabled.
net.eval()

img, label = train_dataset[1]

# Run a single forward pass without building an autograd graph, then derive
# both the predicted class and its probability from the same logits.
with torch.no_grad():
    logits = net(img[None,...])
y_pred = torch.argmax(logits)
y_prob = torch.softmax(logits,dim=-1).max()

# Move the first axis last for imshow
# (assumes the image is laid out as (C, H, W) -- TODO confirm).
img = img.permute(1,2,0)
plt.imshow(img)
print('y_pred = ', y_pred.item())
print('y_prob = ', y_prob.item())

