# Tutorial

In [1]:
%matplotlib inline

In [4]:
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

In [17]:
# Label transform: map an integer class id y to a length-10 float vector that
# is 1 at position y and 0 everywhere else (one-hot), built with scatter_.
one_hot_target = Lambda(
    lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)
)

# FashionMNIST training split rooted at "data" (download=True requests a fetch
# when needed); images go through ToTensor(), labels through the one-hot
# transform defined above.
ds = datasets.FashionMNIST(
    "data",
    train=True,
    transform=ToTensor(),
    target_transform=one_hot_target,
    download=True,
)

In [18]:
# Display the first sample. The result is a tuple whose first element is the
# image tensor; the second element should be the one-hot target produced by
# target_transform (the captured output is truncated before it appears).
ds[0]

(tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000, 0.0510,
           0.2863, 0.0000, 0.0000, 0.0039, 

# torch.Tensor.scatter_
For a 3-D tensor, self is updated as:

self [ index[i][j][k] ] [j] [k] = src [i] [j] [k]  # if dim == 0\
self [i] [ index[i][j][k] ] [k] = src [i] [j] [k]  # if dim == 1\
self [i] [j] [ index[i][j][k] ] = src [i] [j] [k]  # if dim == 2

In [10]:
# Integers 1..10 laid out as a 2x5 matrix; this is the source tensor whose
# values get scattered in the examples that follow.
src = torch.arange(1, 11).view(2, 5)
src

tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])

## scatter_(int dim, Tensor index, Tensor src)
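For each value in `src`,\
its output index is specified by its index in `src` for dimension != `dim`\
and by the corresponding value in `index` for dimension = `dim`.\
Only positions covered by `index` are written: below, `index` has shape (1, 4), so just `src[0, :4]` is scattered and the rest of the output stays 0.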

In [13]:
# A single row of destination-row indices (shape 1x4): with dim=0, the value
# src[0][c] is written to row index[0][c], column c of the zero matrix.
index = torch.tensor([0, 1, 2, 0]).unsqueeze(0)
torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)

tensor([[1, 0, 0, 4, 0],
        [0, 2, 0, 0, 0],
        [0, 0, 3, 0, 0]])

## scatter_(int dim, Tensor index, Number value)
Writes the scalar `value` (here `-1`) to every position selected by `index` along `dim`; no `src` tensor is needed.

In [16]:
# Same index pattern as the tensor-src example, but every selected position
# receives the constant -1 instead of a value taken from src.
torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, value=-1)

tensor([[-1,  0,  0, -1,  0],
        [ 0, -1,  0,  0,  0],
        [ 0,  0, -1,  0,  0]])