直接使用torchvision中的模型

In [1]:
import torch 
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

  from .autonotebook import tqdm as notebook_tqdm


## 超参数

In [None]:
# Hyperparameters for this notebook:
#   EPOCHS    - number of passes over the training set
#   LR        - learning rate
#   BATCHSIZE - samples per mini-batch (passed to the DataLoaders in the data cells)
EPOCHS, LR, BATCHSIZE = 10, 1e-3, 64

## 使用MNIST数据集

In [None]:
# MNIST digits: single-channel images of shape [1, 28, 28].
# ToTensor scales pixel values to [0, 1]; Normalize(mean=0.5, std=0.5) then
# maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

train_data = torchvision.datasets.MNIST(root="./data/",
                                        transform=transform,
                                        train=True,
                                        download=True)
# download=True lets the test split be loaded on its own, matching the
# CIFAR-10 cell; torchvision skips the download when the files already exist.
test_data = torchvision.datasets.MNIST(root="./data/",
                                       transform=transform,
                                       train=False,
                                       download=True)

# Shuffle only the training set. Evaluation does not need a random order, and
# a fixed order keeps per-batch test outputs reproducible between runs.
train_loader = DataLoader(train_data, batch_size=BATCHSIZE, shuffle=True, num_workers=4)
test_loader = DataLoader(test_data, batch_size=BATCHSIZE, shuffle=False, num_workers=4)

## 使用CIFAR-10数据集

In [None]:
# CIFAR-10: 3-channel (RGB) images. Both splits currently use the same
# preprocessing: ToTensor scales to [0, 1], then Normalize maps each channel
# to [-1, 1]. Training-only augmentation would go into transform_train.
# NOTE: this cell rebinds train_data / test_data, replacing the MNIST datasets.
# NOTE: a ResNet18 whose conv1 was switched to 1 input channel (see the Models
# cell) cannot consume these 3-channel images.
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
transforms_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

train_data = torchvision.datasets.CIFAR10(root="./data/", train=True,
                                          transform=transform_train,
                                          download=True)
test_data = torchvision.datasets.CIFAR10(root="./data/", train=False,
                                         transform=transforms_test,
                                         download=True)
print("train_size:", len(train_data))
print("test_size: ", len(test_data))

# Shuffle only the training set; a fixed test order keeps evaluation reproducible.
train_dataloader = DataLoader(dataset=train_data, batch_size=BATCHSIZE, shuffle=True)
test_dataloader = DataLoader(dataset=test_data, batch_size=BATCHSIZE, shuffle=False)

## 加载模型（Models）

In [None]:
# 1. Build torchvision's ResNet-18 without pretrained weights.
ResNet18 = torchvision.models.resnet18(pretrained=False)
# 1.1 MNIST images are [1, 28, 28] (one channel), so the first convolution must
#     accept 1 input channel. Only in_channels changes: output channels, kernel
#     size, stride, padding and bias are copied from the original conv1, so the
#     rest of the network sees the same downsampling as the stock model.
_old_conv1 = ResNet18.conv1
ResNet18.conv1 = nn.Conv2d(1, _old_conv1.out_channels,
                           kernel_size=_old_conv1.kernel_size,
                           stride=_old_conv1.stride,
                           padding=_old_conv1.padding,
                           bias=_old_conv1.bias is not None)
del _old_conv1
# 2. Number of input features of the existing fully connected (fc) head.
fc_features = ResNet18.fc.in_features
# 3. Replace the head so the model outputs 10 classes.
ResNet18.fc = nn.Linear(fc_features, 10)
# Snapshot of the adjusted model's parameters/buffers. Nothing is loaded into
# the model here; presumably this is the starting point for partially loading
# pretrained weights later. TODO: confirm against later cells.
model_dict = ResNet18.state_dict()

In [None]:
# Smoke test: run a random batch through the model and check the output shape.
# The channel count is read from conv1 so the dummy batch matches the stem:
# 1 channel after the MNIST change, 3 for an unmodified RGB model.
# A hard-coded 3-channel batch fails against a 1-channel conv1.
x = torch.randn(2, ResNet18.conv1.in_channels, 28, 28)  # two random 28x28 images
# Use eval mode and no autograd so this dummy batch neither updates BatchNorm
# running statistics nor builds a graph; restore the previous mode afterwards.
was_training = ResNet18.training
ResNet18.eval()
try:
    with torch.no_grad():
        out = ResNet18(x)
finally:
    ResNet18.train(was_training)
print(x.shape)
print(out.shape)  # expected: batch size x 10 classes
# print("打印模型：")
# print(ResNet18)