In [1]:
import torch 
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

In [8]:
# One-hot encode the integer class label: start from a length-10 zero vector
# and let scatter_ write value=1 at the position given by the label.
one_hot_target = Lambda(
    lambda label: torch.zeros(10, dtype=torch.float).scatter_(
        0, torch.tensor(label), value=1
    )
)

# Training split of FashionMNIST, downloaded into ./data when not already there.
# Images go through ToTensor(); labels go through the one-hot transform above.
ds = datasets.FashionMNIST(
    root='data',
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=one_hot_target,
)

###  `scatter_` writes `value=1` at the index given by the label `y`, turning the integer label into a one-hot vector of length 10

In [9]:
# Bare last expression: the notebook echoes the dataset summary
# (size, root, split and the configured transforms).
ds

Dataset FashionMNIST
    Number of datapoints: 60000
    Root location: data
    Split: Train
    StandardTransform
Transform: ToTensor()
Target transform: Lambda()

In [10]:
from torch.utils.data import DataLoader
# Wrap the dataset in a loader that yields mini-batches of two samples.
data = DataLoader(dataset=ds, batch_size=2)

In [12]:
# Look at the first batch only: print the image and label batch shapes, then stop.
for X, y in data:
    print(X.shape, y.shape)
    break

torch.Size([2, 1, 28, 28]) torch.Size([2, 10])


## Scriptable transforms
- 通过 torch.jit.script 将变换流水线编译为 TorchScript 脚本，从而实现图像变换

In [13]:
from torchvision import transforms

In [14]:
# Build the pipeline as an nn.Sequential of tensor transforms so that the
# whole thing can be compiled by torch.jit.script.
transform = torch.nn.Sequential(
    transforms.CenterCrop(10),
    # Per-channel mean / std for a 3-channel input.
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
)

# Compile the pipeline to TorchScript.
scripted_transforms = torch.jit.script(transform)

In [15]:
# Bare last expression: the notebook echoes the scripted module's structure.
scripted_transforms

RecursiveScriptModule(
  original_name=Sequential
  (0): RecursiveScriptModule(original_name=CenterCrop)
  (1): RecursiveScriptModule(original_name=Normalize)
)


- Sequential中自定义的transform必须是scriptable的，即transform只能作用于torch.Tensor，不能依赖PIL.Image或lambda函数
- Sequential中自定义的transform必须继承自torch.nn.Module

## Functional Transforms
- 可更细致地控制transform；函数式变换本身不带随机参数生成器，因此调用时需要自行指定（或生成）所有参数

In [16]:
import torchvision.transforms.functional as TF
import random

In [17]:
class MyRotationTransform:
    """Rotate an input by one of a fixed set of angles, chosen at random per call.

    Args:
        angles: Non-empty sequence of candidate rotation angles. The selected
            angle is passed unchanged to ``TF.rotate`` (which interprets it
            in degrees).

    Raises:
        ValueError: If ``angles`` is empty.
    """

    def __init__(self, angles):
        # Fail at construction time with a clear message instead of letting
        # random.choice raise IndexError on the first call.
        if len(angles) == 0:
            raise ValueError("angles must contain at least one rotation angle")
        self.angles = angles

    def __call__(self, x):
        """Return ``x`` rotated by a randomly selected entry of ``self.angles``."""
        angle = random.choice(self.angles)
        return TF.rotate(x, angle)

In [18]:
# Candidate angles handed to the transform; one of them is picked on each call.
candidate_angles = [-30, -15]
rotation_transform = MyRotationTransform(angles=candidate_angles)

In [19]:
# Bare last expression: the notebook echoes the transform object. The class
# defines no __repr__, so the default object repr is what gets shown.
rotation_transform

<__main__.MyRotationTransform at 0x7fbe82b148b0>