<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Pytorch-中的-layers" data-toc-modified-id="Pytorch-中的-layers-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Layers in PyTorch</a></span></li><li><span><a href="#Pytorch-中-forward" data-toc-modified-id="Pytorch-中-forward-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>forward in PyTorch</a></span></li><li><span><a href="#Pytorch-中-Dataset" data-toc-modified-id="Pytorch-中-Dataset-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Dataset in PyTorch</a></span></li><li><span><a href="#Pytorch-中-DataLoader" data-toc-modified-id="Pytorch-中-DataLoader-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>DataLoader in PyTorch</a></span></li></ul></div>

PyTorch offers an easy-to-use API that builds on many native Python features. This note walks through those features in detail.

author: Amit Chaudhary

link : https://amitness.com/2020/03/python-magic-behind-pytorch/

# Layers in PyTorch

In [1]:
import torch
import torch.nn as nn

x = torch.rand(1, 784)
layer = nn.Linear(784, 10)
output = layer(x)

We call nn.Linear() to build a layer, then pass the tensor x to that layer.

Judging by the call syntax alone, one might expect nn.Linear() to return a function, but in fact:

In [6]:
print(type(layer))

<class 'torch.nn.modules.linear.Linear'>


As we can see, nn.Linear is actually a class, and layer is an instance (object) of it.

So why can we call an object?

Python provides a way to make objects callable.


In [7]:
class Double(object):
    def __call__(self, x):
        return 2*x

The class above defines a \_\_call__ method.

Any object created from it can then be called directly, like a function.

In [8]:
d = Double()
d(2)

4

In [9]:
# or, equivalently
Double()(2)

4
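
One practical reason for making a layer a class rather than a plain function is that a callable object can carry state between calls. The sketch below uses a hypothetical Affine class (an illustration only, not part of PyTorch) that stores a weight and a bias in \_\_init__ and applies them in \_\_call__, loosely mirroring how nn.Linear keeps its parameters:

```python
class Affine:
    """Hypothetical toy layer: computes w * x + b for plain numbers."""

    def __init__(self, w, b):
        # State is stored once, at construction time.
        self.w = w
        self.b = b

    def __call__(self, x):
        # Calling the instance uses the stored state.
        return self.w * x + self.b


layer = Affine(w=3, b=1)
print(layer(2))   # 3 * 2 + 1
print(layer(10))  # 3 * 10 + 1
```

The same instance can be called repeatedly with different inputs while its stored values stay fixed.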

In fact, everything in Python is an object.

Functions are objects too, and calling one invokes its \_\_call__ method behind the scenes.

In [10]:
def double(x):
    return 2*x

In [12]:
print(double(2))
print(double.__call__(2))

4
4
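
To check whether an object can be called, Python provides the built-in callable(). The short sketch below compares an instance of a class with \_\_call__, an ordinary function, and an instance of a class without \_\_call__ (the Plain class is a hypothetical name used only for this illustration):

```python
class Double:
    def __call__(self, x):
        return 2 * x


class Plain:
    pass


def double(x):
    return 2 * x


print(callable(Double()))           # instance of a class that defines __call__
print(callable(double))             # ordinary function
print(callable(Plain()))            # instance of a class without __call__
print(hasattr(double, "__call__"))  # functions expose __call__ as an attribute
```

Only the Plain instance is reported as not callable; the other three checks succeed.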


# forward in PyTorch

Apply a single fully connected layer to MNIST-sized inputs (784 features, 10 classes).

In [13]:
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 10)        

    def forward(self, x):
        return self.fc1(x)

In [14]:
x = torch.rand(10, 784)
model = Model()
output = model(x)

We know that calling model ends up executing its forward() method.

Why is that?

Because the class we created inherits from nn.Module, and nn.Module defines a \_\_call__() method that in turn calls forward().

In [15]:
# nn.Module (simplified sketch)
class Module(object):
    def __call__(self, *args, **kwargs):
        # Simplified: the actual implementation also runs any
        # registered hooks around forward().
        return self.forward(*args, **kwargs)
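
To make the dispatch concrete, here is a minimal pure-Python sketch. TinyModule and Doubler are hypothetical names invented for this note, and the hook handling is only an assumption-level imitation of what nn.Module does, not its real implementation. It shows that going through \_\_call__ can do extra work that a direct forward() call skips:

```python
class TinyModule:
    """Hypothetical stand-in for nn.Module; not the real implementation."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        # Store a function to run after every forward pass.
        self._forward_hooks.append(hook)

    def __call__(self, *args, **kwargs):
        output = self.forward(*args, **kwargs)
        for hook in self._forward_hooks:
            hook(self, args, output)
        return output

    def forward(self, *args, **kwargs):
        raise NotImplementedError


class Doubler(TinyModule):
    def forward(self, x):
        return 2 * x


seen = []
model = Doubler()
model.register_forward_hook(lambda module, inputs, output: seen.append(output))

model(3)          # goes through __call__, so the hook fires
model.forward(4)  # bypasses __call__, so the hook does not fire
print(seen)
```

Only the result of model(3) is recorded, which is the usual argument for writing model(x) rather than model.forward(x).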

# Dataset in PyTorch

Subclass Dataset to wrap and preprocess your data.

In [16]:
from torch.utils.data import Dataset

class Numbers(Dataset):
    def __init__(self, x, y):
        self.data = x
        self.labels = y

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return (self.data[i], self.labels[i])

In [17]:
dataset = Numbers([1, 2, 3], [0, 1, 0])
print(len(dataset))
print(dataset[0])

3
(1, 0)


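The methods defined above, \_\_len__ and \_\_getitem__, are standard Python special methods: the built-in len() calls the first, and the indexing syntax obj[i] calls the second.

Python lets any user-defined class implement these methods, so custom objects can behave like built-in containers.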
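
The mapping between syntax and special methods can be checked without torch. The sketch below uses a hypothetical Pairs class that mirrors Numbers but does not inherit from Dataset; it also shows that Python falls back to \_\_getitem__ for iteration when a class defines no \_\_iter__:

```python
class Pairs:
    """Hypothetical container that mimics the Numbers dataset without torch."""

    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        # Called by len(obj)
        return len(self.data)

    def __getitem__(self, i):
        # Called by obj[i]; list indexing raises IndexError past the end
        return (self.data[i], self.labels[i])


pairs = Pairs([1, 2, 3], [0, 1, 0])
print(len(pairs))     # same as pairs.__len__()
print(pairs[1])       # same as pairs.__getitem__(1)
print(pairs[-1])      # negative indices are forwarded to the lists
print(list(pairs))    # iteration calls __getitem__ with 0, 1, 2, ... until IndexError
```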

# DataLoader in PyTorch

In [24]:
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

trainset = MNIST(root='mnist', 
                 download=True, 
                 train=True, 
                 transform=transforms.ToTensor())
trainloader = DataLoader(trainset, batch_size=32, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to mnist/MNIST/raw/train-images-idx3-ubyte.gz
Extracting mnist/MNIST/raw/train-images-idx3-ubyte.gz to mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to mnist/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting mnist/MNIST/raw/train-labels-idx1-ubyte.gz to mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to mnist/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to mnist/MNIST/raw
Processing...
Done!

In [25]:
trainloader[0]

TypeError: 'DataLoader' object does not support indexing
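
The error fits a simple rule: indexing with obj[i] needs \_\_getitem__, while iteration only needs \_\_iter__. The sketch below uses a hypothetical OnlyIterable class (not PyTorch code) that, like the DataLoader in the cell above, can be iterated but not indexed. The exact wording of the TypeError message depends on the Python version, so it is printed rather than assumed:

```python
class OnlyIterable:
    """Hypothetical class that defines __iter__ but not __getitem__."""

    def __init__(self, items):
        self.items = items

    def __iter__(self):
        # Hand back a fresh iterator over the stored items.
        return iter(self.items)


stream = OnlyIterable([10, 20, 30])

try:
    stream[0]
except TypeError as err:
    print("TypeError:", err)

first = next(iter(stream))  # iteration still works
print(first)
```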

In [26]:
images, labels =  next(iter(trainloader))

In [20]:
x = [1, 2, 3]
y = iter(x)

In [21]:
next(y)

1

In [23]:
next(y)

2

In [27]:
next(y)

3

In [28]:
next(y)

StopIteration: 

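Each call to next() returns the following element until the iterator is exhausted, at which point it raises StopIteration.

This mirrors how batches of data are streamed during machine-learning training.

That is why PyTorch adopts the iterator protocol for DataLoader.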
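
A for loop is built on the same protocol: it calls iter() once and then next() repeatedly until StopIteration is raised. The sketch below writes that out by hand over a made-up list of batches and compares it with the ordinary for loop:

```python
batches = [[1, 2], [3, 4], [5]]

# What a for loop does, written out by hand
seen_manual = []
iterator = iter(batches)          # calls batches.__iter__()
while True:
    try:
        batch = next(iterator)    # calls iterator.__next__()
    except StopIteration:
        break                     # a for loop ends silently at this point
    seen_manual.append(batch)

# The equivalent for loop
seen_for = []
for batch in batches:
    seen_for.append(batch)

print(seen_manual == seen_for)
```

This is why a training loop can be written as a plain for loop over a loader object.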



In [36]:
class ExampleLoader(object):
    def __init__(self, data):
        self.data = iter(data)

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.data)

In [37]:
l = ExampleLoader([1, 2, 3])

In [38]:
next(iter(l))

1

In [39]:
next(iter(l))

2
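
The second call returns 2 rather than 1 because \_\_iter__ returns self, so every iter(l) shares one underlying iterator and the loader can be consumed only once. The sketch below repeats ExampleLoader and contrasts it with a hypothetical ReusableLoader (a name invented here) whose \_\_iter__ builds a fresh iterator on each call:

```python
class ExampleLoader:
    """Same idea as the class above: the loader is its own iterator."""

    def __init__(self, data):
        self.data = iter(data)

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.data)


class ReusableLoader:
    """Hypothetical variant: __iter__ builds a fresh iterator on every call."""

    def __init__(self, data):
        self.data = data

    def __iter__(self):
        return iter(self.data)


single = ExampleLoader([1, 2, 3])
print(list(single))  # first pass consumes everything
print(list(single))  # second pass finds nothing left

reusable = ReusableLoader([1, 2, 3])
print(list(reusable))
print(list(reusable))  # a new iterator each time, so the data is seen again
```

For training over several epochs the second behavior is the one you want, since each epoch needs to walk the data again.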

How DataLoader is implemented in PyTorch (simplified):

In [None]:
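# Simplified excerpt for illustration; not runnable on its own.
class DataLoader(object):
    def __iter__(self):
        # Pick an iterator class depending on the number of worker processes
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            return _MultiProcessingDataLoaderIter(self)

class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __next__(self):
        # logic to return the next batch from the dataset
        ...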

In [41]:
type(iter(trainloader))

torch.utils.data.dataloader._SingleProcessDataLoaderIter
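
The split between a loader and a separate iterator object can be sketched in plain Python. ToyLoader and _ToyLoaderIter below are hypothetical names, and the batching logic is an assumption made for illustration; it leaves out shuffling, collation into tensors, and worker processes, so it should not be read as PyTorch's implementation:

```python
class ToyLoader:
    """Hypothetical, heavily simplified loader; not PyTorch's implementation."""

    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size

    def __iter__(self):
        # A fresh iterator object for every pass over the data
        return _ToyLoaderIter(self)


class _ToyLoaderIter:
    def __init__(self, loader):
        self.loader = loader
        self.position = 0

    def __iter__(self):
        return self

    def __next__(self):
        dataset = self.loader.dataset
        if self.position >= len(dataset):
            raise StopIteration
        # Slice out the next batch; the last one may be shorter
        end = min(self.position + self.loader.batch_size, len(dataset))
        batch = [dataset[i] for i in range(self.position, end)]
        self.position = end
        return batch


loader = ToyLoader(dataset=[10, 20, 30, 40, 50], batch_size=2)
print(type(iter(loader)).__name__)
for batch in loader:
    print(batch)
```

Because the position counter lives in the iterator rather than in the loader, every new for loop over loader starts again from the first batch.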