In [None]:
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff035939",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision\n",
    "import torch\n",
    "import torch.utils.data.dataloader as Data\n",
    "from torch.autograd import Variable\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbae3dfe",
   "metadata": {},
   "outputs": [],
   "source": [
    "#残差块\n",
    "if_use_gpu=0\n",
    "class ResidualBlock(nn.Module):\n",
    "    def __init__(self, inchannel, outchannel, stride=1):\n",
    "        super(ResidualBlock, self).__init__()\n",
    "        self.left = nn.Sequential(\n",
    "            nn.Conv2d(inchannel,outchannel,kernel_size=3,padding=1,stride=stride,bias=False),\n",
    "            nn.BatchNorm2d(outchannel),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(outchannel, outchannel, kernel_size=3, padding=1, stride=stride, bias=False),\n",
    "            nn.BatchNorm2d(outchannel)\n",
    "        )\n",
    "        self.right = nn.Sequential()\n",
    "        #输入输出信道数不一样，把残差块的信道卷积到和输出一样\n",
    "        if(inchannel != outchannel):\n",
    "            self.right = nn.Sequential(\n",
    "                nn.Conv2d(inchannel, outchannel, kernel_size=3, padding=1, stride=stride, bias=False),\n",
    "                nn.BatchNorm2d(outchannel),\n",
    "            )\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.left(x)\n",
    "        out += self.right(x)\n",
    "        out =F.relu(out)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41e299b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ResNet(nn.Module):\n",
    "    def __init__(self, ResidualBlock, num_classes=10):\n",
    "        super(ResNet, self).__init__()\n",
    "        self.inchannel = 64\n",
    "        self.conv1 = nn.Sequential(\n",
    "            nn.Conv2d(1,64,kernel_size=3,stride=1,padding=1,bias=False),\n",
    "            nn.BatchNorm2d(64),\n",
    "            nn.ReLU(),\n",
    "        )\n",
    "        self.layer1 = self.make_layer(ResidualBlock, 64,  2, stride=1)\n",
    "        self.layer2 = self.make_layer(ResidualBlock, 128, 2, stride=1)\n",
    "        self.conv2 = nn.Conv2d(128,128,3,stride=2)\n",
    "        self.layer3 = self.make_layer(ResidualBlock, 256, 2, stride=1)\n",
    "        self.conv3 = nn.Conv2d(256, 256, 3, stride=2)\n",
    "        #self.layer4 = self.make_layer(ResidualBlock, 512, 2, stride=1)\n",
    "        self.conv4 = nn.Conv2d(256,256,6)\n",
    "        self.fc = nn.Linear(256, num_classes)\n",
    "\n",
    "    def make_layer(self, block, channels, num_blocks, stride):\n",
    "        layer = []\n",
    "        for i in range(num_blocks):\n",
    "            layer.append(block(self.inchannel,channels,stride))\n",
    "            self.inchannel = channels\n",
    "        #对layer拆包\n",
    "        return nn.Sequential(*layer)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.conv1(x)\n",
    "        out = self.layer1(out)\n",
    "        out = self.layer2(out)\n",
    "        out = self.conv2(out)\n",
    "        out = self.layer3(out)\n",
    "        out = self.conv3(out)\n",
    "        #out = self.layer4(out)\n",
    "        out = self.conv4(out)\n",
    "        #out = F.avg_pool2d(out,4)\n",
    "        out = out.view(out.size(0), -1)\n",
    "        out = self.fc(out)\n",
    "        return out\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3179977",
   "metadata": {},
   "outputs": [],
   "source": [
    "def ResNet18():\n",
    "\n",
    "    return ResNet(ResidualBlock)\n",
    "\n",
    "train_data = torchvision.datasets.MNIST(\n",
    "    './mnist', train=True,transform=torchvision.transforms.Compose([\n",
    "    torchvision.transforms.ToTensor(),\n",
    "    ]), download=True\n",
    ")\n",
    "train_data.data = train_data.data[:10000]\n",
    "train_data.targets = train_data.targets[:10000]\n",
    "test_data = torchvision.datasets.MNIST(\n",
    "    './mnist', train=False, transform=torchvision.transforms.Compose([\n",
    "    torchvision.transforms.ToTensor(),\n",
    "    ])\n",
    ")\n",
    "print(\"train_data:\", train_data.train_data.size())\n",
    "print(\"train_labels:\", train_data.train_labels.size())\n",
    "print(\"test_data:\", test_data.test_data.size())\n",
    "\n",
    "train_loader = Data.DataLoader(dataset=train_data, batch_size=32, shuffle=True)\n",
    "test_loader = Data.DataLoader(dataset=test_data, batch_size=32)\n",
    "\n",
    "model = ResNet18()\n",
    "if if_use_gpu:\n",
    "    model = model.cuda()\n",
    "\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee33c32c",
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "loss_func = torch.nn.CrossEntropyLoss()\n",
    "for epoch in range(1):\n",
    "    print('epoch {}'.format(epoch + 1))\n",
    "    for i, data in enumerate(train_loader, 0):\n",
    "        # get the inputs\n",
    "        inputs, labels = data\n",
    "        batch_x, batch_y = Variable(inputs), Variable(labels)\n",
    "        if if_use_gpu:\n",
    "            batch_x = batch_x.cuda()\n",
    "            batch_y = batch_y.cuda()\n",
    "        out = model(batch_x)\n",
    "        batch_y = batch_y.long()\n",
    "        loss = loss_func(out, batch_y)\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # 返回每行元素最大值\n",
    "        pred = torch.max(out, 1)[1]\n",
    "        train_correct = (pred == batch_y).sum()\n",
    "        train_correct = train_correct.item()\n",
    "        train_loss = loss.item()\n",
    "        print('batch:{},Train Loss: {:.6f}, Acc: {:.6f}'.format(i+1,train_loss , train_correct /32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8f11afc",
   "metadata": {},
   "outputs": [],
   "source": [
    "    # evaluation--------------------------------\n",
    "model.eval()\n",
    "eval_loss = 0.\n",
    "eval_acc = 0.\n",
    "for batch_x, batch_y in test_loader:\n",
    "    batch_x, batch_y = Variable(batch_x, requires_grad=False), Variable(batch_y,requires_grad=False)\n",
    "    if if_use_gpu:\n",
    "        batch_x = batch_x.cuda()\n",
    "        batch_y = batch_y.cuda()\n",
    "    out = model(batch_x)\n",
    "    loss = loss_func(out, batch_y)\n",
    "    eval_loss += loss.item()\n",
    "    pred = torch.max(out, 1)[1]\n",
    "    num_correct = (pred == batch_y).sum()\n",
    "    eval_acc += num_correct.item()\n",
    "print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(\n",
    "    test_data)), eval_acc / (len(test_data))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "645ef449",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b90731b7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59a5e57d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}

请点击[此处](https://ai.baidu.com/docs#/AIStudio_Project_Notebook/a38e5576)查看本环境基本用法.  <br>
Please click [here](https://ai.baidu.com/docs#/AIStudio_Project_Notebook/a38e5576) for more detailed instructions.