# PyTorch Tutorial: 02 Dataset and Iterator
## Overview 
In this tutorial, we will cover the basics of constructing datasets and iterators so that we can train models using gradient descent.

The best tutorial can be found on the official website (https://pytorch.org/tutorials/beginner/data_loading_tutorial.html).

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader

In [2]:
class MyStupidDataset(Dataset):
    """Toy map-style dataset holding random tensors drawn from N(0, 1).

    All samples are generated once in ``__init__`` and stored in
    ``self.data`` with shape ``(num_samples, *sample_shape)``.

    Args:
        num_samples: number of samples to generate (default: 1024).
        sample_shape: shape of a single sample (default: ``(10, 10)``).
    """
    def __init__(self, num_samples=1024, sample_shape=(10, 10)):
        super(MyStupidDataset, self).__init__()
        self.data = torch.randn(num_samples, *sample_shape)
    def __len__(self):
        # Derive the length from the stored tensor instead of repeating a
        # literal, so it can never disagree with the actual data.
        return self.data.shape[0]
    def __getitem__(self, idx):
        # Index only the leading (sample) dimension; this works for any
        # sample_shape and raises IndexError when idx is out of range.
        return self.data[idx]

In [3]:
# Instantiate the toy dataset; its random data is generated once, in __init__.
my_stupid_dataset = MyStupidDataset()

In [4]:
# Wrap the dataset in a DataLoader that serves shuffled mini-batches of 64 samples.
my_data_loader = DataLoader(my_stupid_dataset, batch_size=64, shuffle=True)

In [5]:
# Iterate over the loader and print every batch. With 1024 samples and
# batch_size=64 this yields 16 batches, each stacked to shape (64, 10, 10).
# NOTE(review): this prints a lot of output; printing i.shape would be enough
# to inspect the batching.
for i in my_data_loader:
    print(i)

tensor([[[-6.8660e-02,  1.6420e+00,  6.9785e-01,  ...,  4.1900e-01,
          -3.3634e-01,  4.1674e-02],
         [-6.2595e-01,  6.5081e-01, -7.5044e-01,  ..., -8.6566e-01,
           5.0536e-01,  7.3824e-01],
         [ 2.3080e-03,  1.5406e+00,  1.5272e-01,  ..., -1.2444e+00,
          -5.2449e-02,  2.5395e-01],
         ...,
         [ 1.7279e+00, -1.1304e+00, -1.4048e+00,  ...,  6.0996e-01,
           4.4543e-01, -5.7414e-02],
         [ 1.3614e+00, -9.0990e-01, -1.0888e+00,  ...,  1.9286e-01,
           1.0120e+00,  1.7502e-01],
         [ 6.7001e-01,  1.4039e-01,  4.6799e-01,  ...,  4.4857e-01,
           2.0794e+00, -1.5829e+00]],

        [[-7.9187e-01, -5.1659e-01, -1.9433e-01,  ..., -1.2633e+00,
           7.4715e-01,  1.1700e+00],
         [-8.6236e-01, -7.4198e-01,  1.1978e+00,  ..., -1.2386e+00,
          -4.5127e-01,  8.9427e-01],
         [ 4.0035e-02, -7.4227e-03,  6.0668e-01,  ...,  1.7652e+00,
           9.3904e-01,  2.4999e-01],
         ...,
         [ 3.7512e-01,  8

In [6]:
# A Dataset can also be iterated directly, without a DataLoader: each step
# returns one un-batched sample via __getitem__ (here a single 2-D tensor).
for i in my_stupid_dataset:
    print(i)

tensor([[-1.0029,  1.0388,  1.5498,  0.7719,  0.6411,  0.9742, -2.2901,  1.0877,
         -0.0750, -1.6873],
        [-0.0642,  1.0232,  0.0577, -1.5024, -0.2303,  1.5363, -0.7649,  0.1474,
          0.2309,  0.5408],
        [ 0.0859,  0.3170,  1.0073, -0.8957, -1.3219,  1.0542, -1.1091,  1.2505,
         -0.8023,  1.0306],
        [-1.4781,  0.2222, -1.9674,  0.6218, -0.9273,  0.8782,  1.3813, -0.0294,
         -1.3999,  0.1902],
        [-0.1277, -1.1634, -1.1453,  0.3033,  0.1244,  1.8072,  0.7136,  0.5365,
         -1.0303, -0.2905],
        [-0.4266, -0.8935, -0.5348, -0.9009, -0.4484,  0.0356,  0.3612, -0.4996,
         -1.2951,  1.9423],
        [ 0.3762,  0.8134, -1.0970, -0.3778,  1.5930, -1.7386,  0.8480, -1.2167,
          0.5684,  0.5052],
        [-0.6821,  0.1144,  0.0762, -0.9141,  0.6413, -2.0644, -0.8863, -0.2927,
         -0.0231,  0.1389],
        [ 0.3650,  0.7560,  0.7245,  0.9168, -0.5839, -0.2121, -1.0272, -0.1392,
          2.2235,  0.1776],
        [-0.4606,  

tensor([[-0.0488, -2.0708,  0.0276,  2.2099,  1.3160, -1.8089,  2.3743, -0.1263,
         -0.3506, -0.9688],
        [-0.9898,  0.0149,  0.0211, -0.5880,  0.5062,  0.8377,  0.9083, -0.0177,
          0.0352,  1.2649],
        [-1.5655,  0.0091, -0.2244, -1.9103,  0.4272,  0.2260, -0.6866, -2.2301,
          0.7695, -1.4659],
        [ 1.0040, -1.7132,  1.1256, -0.0799, -0.0925, -0.0805, -1.2941,  0.9468,
         -0.5598,  1.3769],
        [ 0.6182, -0.2458, -1.9644, -2.1393, -0.8837,  0.7780, -0.9773,  1.7257,
          0.4371,  1.1792],
        [-0.1358, -0.7644,  1.3500, -0.2615,  2.2041, -0.7541, -0.7851, -0.3283,
         -0.3359, -0.6508],
        [ 1.4523,  0.1028,  0.7291,  0.0640, -0.0640, -0.6213,  0.6821, -0.2101,
          0.5273,  0.6606],
        [-0.8760, -0.0878,  0.4402, -1.2537,  0.5269,  0.1990, -1.6183, -0.5695,
         -0.6605, -0.9031],
        [-1.9830, -0.5445, -0.3442,  0.1821,  0.2397, -2.2093,  0.4930,  0.7985,
          0.5360, -0.1596],
        [ 1.3216,  

tensor([[ 1.7723,  2.8859, -1.6202,  0.4025,  0.2724,  1.6781, -1.2096, -0.3011,
          0.1143,  0.4581],
        [-0.6016, -1.3069,  2.5410, -0.8052,  0.8486, -0.1057, -0.5974,  0.1357,
          0.5620, -0.4829],
        [-0.1600, -1.1770,  0.4044, -0.3679,  0.6215, -0.1402,  0.8218,  1.1396,
          0.9218,  0.4341],
        [ 1.0385,  0.4823,  1.2985, -0.8118, -0.7953, -0.6758,  0.9801,  0.0862,
         -0.4523, -0.3325],
        [-0.2427, -0.0745,  0.5546,  1.5552,  1.2166,  0.8044, -0.9229,  1.7301,
         -0.6049, -1.5732],
        [-0.1942,  1.7095,  0.3189, -1.5148, -1.1102, -0.4232,  0.0124,  0.6378,
         -1.9039, -0.2112],
        [-0.8420, -0.2902,  1.0789,  0.3499,  0.9953,  0.6621,  0.0393, -0.4562,
         -1.4678, -0.0909],
        [ 0.4221,  1.2743, -0.3998, -0.3880, -1.0172, -0.4873,  0.3328, -0.6755,
          1.5773,  0.3155],
        [-0.5176, -0.2862, -0.4956, -0.2648, -0.8796,  0.6120,  0.5817, -0.7400,
         -0.0284,  1.4805],
        [ 1.1365, -

tensor([[-0.2408, -0.6680,  0.3397, -1.6189,  1.3225,  1.5102, -0.9196,  1.9810,
         -0.5905, -0.2282],
        [ 0.5512, -0.9554, -0.3730,  1.7836, -0.4424,  0.3947, -1.0520,  0.5514,
          0.1769,  0.2674],
        [-0.2512, -0.6737,  0.1166, -1.2475, -0.4762,  0.0576,  1.3228,  1.0562,
         -2.3675,  0.5103],
        [ 0.5283,  1.4887,  0.3892, -1.5045,  0.2289,  0.1336, -0.6130,  0.4544,
          0.6285, -0.9932],
        [ 0.4528,  0.4900, -0.1704,  0.2646,  0.3681, -1.3397, -0.2891, -0.4314,
          0.6209,  0.6244],
        [ 0.4539,  0.5996, -1.3882, -0.1604,  0.4458,  0.6377, -0.2663,  0.7746,
         -0.3274,  0.2745],
        [ 1.2699, -0.4175,  0.2632,  0.9406,  0.7921,  1.7075, -0.1217,  1.2854,
          1.2852,  0.4642],
        [-0.7367, -0.2216,  0.2919, -0.3361, -0.2509, -0.6458,  0.8703,  0.5867,
          0.0278,  0.9162],
        [-0.2734, -1.9018,  1.0349,  0.2813,  1.1187,  0.2398,  1.6202, -0.8932,
         -0.5405, -0.1442],
        [ 0.2757, -

tensor([[-0.3316, -0.7378,  0.5056, -0.0866, -1.0453, -1.6405,  0.1511,  0.5363,
         -0.6365,  0.3975],
        [-0.0676, -1.2689, -0.0192, -0.3438, -2.0454,  0.7503, -1.4330, -0.6718,
          1.6537, -0.3669],
        [ 0.5587,  1.0196,  0.8792,  0.7290,  0.8771,  0.2743,  1.3736, -1.2732,
          0.6803,  1.9633],
        [-0.5542, -1.9325,  0.9514,  1.5239, -0.8215, -0.2691, -0.9320, -0.6931,
         -1.0493,  0.0908],
        [-1.6882, -0.8838,  0.0351, -1.4553, -0.9361,  0.0897,  0.4443,  0.1114,
          0.7962,  0.9668],
        [-0.8758,  1.9213, -1.2528, -0.5328,  0.6871,  0.1431,  1.2202, -0.5957,
         -1.4697, -0.0473],
        [-0.8685,  0.0791, -1.9608,  0.9768,  1.5362,  0.4618,  1.4199, -1.0010,
         -0.3783,  1.1860],
        [ 0.2622, -0.7161, -0.6915,  2.1293,  0.3562,  1.0214,  0.1295,  0.1026,
         -0.6272,  0.8399],
        [ 1.5401, -0.3859, -0.5898,  1.0596, -0.4846,  0.9756, -0.7146, -0.1514,
          0.2229,  0.3292],
        [ 1.4493,  

tensor([[ 0.3736,  0.2844, -0.7626,  0.1016, -0.7255, -0.0958,  0.1232, -0.2237,
          0.4873,  0.8551],
        [-0.1388,  0.2550,  0.7251,  0.3933,  0.7851, -0.1664, -0.5894,  0.1652,
          0.2434, -0.9349],
        [ 0.7267, -1.2269, -0.8573,  1.0082,  1.9042,  1.2300, -0.5422, -0.9395,
          0.4545,  2.2094],
        [ 0.7280,  1.1810, -0.9139,  0.1250, -0.9225, -0.2424,  2.0934,  2.5585,
          1.1978, -0.3595],
        [ 1.3237,  0.2944, -0.1468,  0.3164, -1.0962,  2.0101,  0.5950,  1.7291,
          0.1291,  3.3665],
        [ 1.2569, -0.7202,  0.2017, -1.1198, -0.1384, -0.1382, -0.2325, -1.2300,
         -0.1646,  0.9498],
        [ 1.1815,  0.8125,  0.3263, -1.0079,  0.4494, -0.7375,  0.8835, -0.2993,
         -0.3766, -0.4691],
        [-0.7037, -0.7730, -1.3959,  0.8470, -1.0604,  1.2187, -0.8813, -1.0622,
         -0.4694, -1.2362],
        [ 0.3710,  1.4220,  0.0202, -0.2752,  0.4125, -0.5670, -1.5005, -0.3512,
         -0.5491, -0.0517],
        [ 1.5787,  

tensor([[ 1.0458e+00,  2.2170e+00, -5.4608e-01, -5.9184e-01, -3.5479e-01,
          4.3420e-01,  5.7998e-01,  8.3754e-02, -1.3103e+00,  7.0191e-01],
        [-8.7959e-01, -1.9324e-01, -1.0545e+00,  6.7171e-02, -1.1387e-01,
         -5.9516e-03, -9.5170e-01, -5.4022e-01, -5.1680e-01, -1.1563e+00],
        [ 2.6391e-01,  1.9833e+00, -1.4963e+00,  2.7347e+00, -9.6633e-02,
         -7.9804e-02, -8.5534e-02, -1.6050e+00,  1.2012e-01, -2.4541e-01],
        [-7.6466e-02, -1.4474e+00, -1.0371e+00,  7.6348e-01, -1.2923e+00,
          1.0129e+00,  8.4598e-01, -2.4071e-01, -2.6844e-01, -7.5305e-01],
        [ 6.2200e-01, -7.5126e-01, -1.1891e+00,  9.4559e-01,  4.1323e-01,
          1.6534e+00,  5.6482e-01,  9.3024e-01,  2.1861e-01, -1.4640e+00],
        [-6.5777e-01, -4.1769e-01, -3.3965e-03, -9.2647e-01,  8.5744e-01,
          1.1917e+00, -1.4408e+00, -1.0512e+00,  4.6186e-01, -2.5838e-01],
        [ 7.9110e-01, -1.5363e+00, -3.9529e-01, -6.0562e-01, -1.2659e+00,
         -4.4701e-01, -7.0882e-0

tensor([[ 0.4659,  1.3164, -0.5458, -0.8925,  0.5235,  0.3867,  0.2631,  0.0760,
          0.5835, -0.5673],
        [-0.0550, -0.4376,  1.3016, -0.0612, -0.9287,  0.9367,  0.0212,  2.2597,
          0.3318, -0.8276],
        [ 0.3128, -1.5009,  0.6969, -1.3758,  0.2855, -0.6986,  0.4244, -1.3656,
         -0.2582, -1.0933],
        [ 0.4289,  0.5075, -0.7466,  0.4242,  1.1482,  0.2683,  0.8355, -2.6489,
         -0.2276, -1.3258],
        [-1.0208,  0.9453, -0.0819, -0.0392, -0.4380, -0.6695,  0.7868,  1.4747,
         -0.5615, -0.1339],
        [-0.5983, -1.0605, -0.4225, -0.3259,  1.5533,  1.3386, -1.1157,  1.8375,
         -1.3629, -1.2625],
        [-0.5016,  1.2159,  0.7675, -0.1146, -0.9247, -1.4577, -0.1479,  0.0390,
          0.6178,  1.5339],
        [-0.9599, -1.4247,  1.2485, -0.4343,  0.2918, -2.1242,  0.8917, -0.2064,
          0.0447,  0.6898],
        [-1.2888,  1.7856,  0.5154,  0.0858,  0.2854,  0.2405,  0.5584,  0.0876,
         -0.3645,  0.1877],
        [ 0.4195, -

tensor([[ 1.4807,  0.9776,  0.6251, -1.3396,  0.6476, -0.3834, -1.0669,  0.6886,
         -0.9065,  2.5718],
        [ 1.1146,  1.2677, -0.3415, -0.1922,  2.0222, -0.6115,  1.3340,  0.5934,
          1.7450, -0.9943],
        [ 0.3885, -0.1716,  0.0313,  0.4877,  0.5402, -0.5676, -1.5679, -0.3804,
          1.5642, -0.2091],
        [-0.3999, -0.8662,  0.6471,  0.6361, -0.7193,  0.6464, -0.4188, -0.1456,
         -0.3254,  0.2798],
        [ 0.1914,  0.1699,  0.1366, -0.3483, -0.0702, -0.6360, -0.1966, -1.7092,
          0.9838, -1.8087],
        [ 0.1906,  0.6504,  0.8676,  0.7396, -0.9676,  1.9229,  1.9606, -0.0701,
         -0.7121, -1.0006],
        [-0.2100, -1.0293, -1.4542,  0.4434,  1.9351, -1.3168,  1.0352,  1.1190,
         -0.3746, -0.1932],
        [-0.4173,  0.8004, -0.5500, -0.4396,  0.2989,  0.4535, -1.3774,  0.0224,
         -0.1342, -2.6073],
        [-0.2080, -0.5203,  0.6708,  0.0837, -1.2746, -0.3973, -1.2307, -1.0969,
         -0.5521, -0.0419],
        [-0.0456, -

tensor([[ 1.6550,  0.9872, -0.6270, -0.0667,  0.0117,  1.5187, -0.7868, -0.1044,
          0.0224,  0.3911],
        [ 1.0327, -0.3405, -0.4294,  1.3128, -0.7282,  1.0523,  0.8364,  1.8514,
         -0.3679,  1.2126],
        [-1.5586, -0.8071,  0.2934,  1.3517,  0.1665, -1.1687, -1.5108,  0.5839,
         -0.4849,  0.0755],
        [ 0.6176,  0.2409,  0.4612, -0.6800, -0.5262,  0.3141,  0.5537, -1.4904,
         -1.3642,  0.3669],
        [-1.0390, -0.2290,  0.9903, -0.8928, -0.5865, -0.3040,  1.1712, -0.2026,
          0.3791, -1.0629],
        [ 0.4243, -0.1072, -0.3525,  0.9698, -0.1104, -0.8399, -1.5829, -0.1410,
          0.4798, -0.2466],
        [ 0.7790, -1.2162, -0.5430, -0.9379, -0.3236,  0.0534, -0.2378, -0.0157,
          0.2698, -1.5290],
        [ 0.5622, -0.2356, -1.3534,  1.0084, -0.0320, -0.9420, -0.2246, -0.5780,
          0.5753,  0.1898],
        [-0.5200, -0.3915,  1.6418,  1.6234,  0.3110, -0.9366, -1.8572,  0.4059,
          1.6960, -0.3338],
        [ 0.2598,  

tensor([[-0.9023, -0.6850, -1.2492,  1.7602,  2.1962, -0.2867, -1.1729,  0.9022,
         -0.2983, -0.6507],
        [-0.1175, -1.6496,  0.4108,  1.5137, -0.0037,  0.1429,  0.6243,  0.5437,
         -0.5560, -0.7633],
        [ 0.6472,  1.0092,  0.2579,  0.0057,  0.3054,  1.2912,  0.8298,  0.5475,
         -1.6087, -1.8016],
        [ 0.1660,  1.7276, -2.5061,  1.3302, -0.1720,  1.2340,  1.6120,  0.8276,
         -0.2416,  1.2300],
        [-1.1758,  0.4661, -1.9773, -1.8903, -0.6625, -1.1991, -0.2803, -0.6409,
          1.3852, -1.5505],
        [ 0.0026,  0.1713,  0.3797, -1.5275, -1.1084,  1.2182,  0.6127,  0.3728,
          2.1592, -0.6521],
        [ 0.6464,  0.8619,  0.0576,  0.9021,  2.3205, -0.7419, -0.8393,  1.0070,
          0.4358,  0.1843],
        [-1.0112, -1.2472, -0.4086,  0.7461, -0.9460,  0.8514, -0.0750,  2.2130,
          0.0570, -1.1306],
        [ 1.2741,  0.9967,  1.4040,  0.4326, -0.7809, -1.0573,  1.5494,  0.4494,
         -1.5840,  0.4484],
        [-0.1492,  

In [7]:
# A demo pattern: return each sample as a dict of named fields.
class MyDictDataset(Dataset):
    """Toy map-style dataset whose samples are ``{'x': ..., 'y': ...}`` dicts.

    ``x`` has shape ``(num_samples, num_features)`` and ``y`` has shape
    ``(num_samples,)``; both are drawn from N(0, 1) once in ``__init__``.

    Args:
        num_samples: number of samples to generate (default: 1024).
        num_features: length of each ``x`` vector (default: 10).
    """
    def __init__(self, num_samples=1024, num_features=10):
        super(MyDictDataset, self).__init__()
        self.x = torch.randn(num_samples, num_features)
        self.y = torch.randn(num_samples)
    def __len__(self):
        # Derive the length from the stored tensor instead of repeating a
        # literal, so it can never disagree with the actual data.
        return self.x.shape[0]
    def __getitem__(self, idx):
        # Returning a dict lets consumers access fields by name.
        return {'x': self.x[idx, :], 'y': self.y[idx]}

In [8]:
# Build the dict-style dataset and a loader over it.
# Note: this rebinds `my_data_loader`, replacing the loader created earlier.
my_dict_dataset = MyDictDataset()
my_data_loader = DataLoader(my_dict_dataset, batch_size=64, shuffle=True)
# The default collate function batches dict samples key by key, so each batch
# is itself a dict: batch['x'] stacks 64 x-vectors, batch['y'] stacks 64 y-values.
for batch in my_data_loader:
    print(batch['x'])    
    print(batch['y'])

tensor([[ 9.7655e-02, -1.0359e-01, -4.5622e-01, -7.2091e-01, -1.3020e+00,
          1.7605e+00,  2.3559e-01,  1.2916e+00,  1.7347e+00, -4.9604e-02],
        [ 8.1948e-01,  2.7509e-01, -1.4539e+00, -1.6238e-01, -2.4651e-01,
         -2.3872e+00,  3.1912e-01,  1.8627e-01,  2.1463e-01,  2.4382e+00],
        [ 6.5929e-01, -9.2071e-01,  3.1380e-01, -2.2800e-02, -1.1983e+00,
         -6.9152e-01,  2.7868e-01,  1.9852e+00, -6.9836e-01,  4.3078e-01],
        [-5.5372e-01,  1.3916e+00, -1.6709e+00, -1.5404e-01, -1.0625e+00,
          8.1883e-01,  6.3342e-01, -7.9390e-01,  1.4681e+00, -3.4782e-01],
        [ 1.2866e+00,  1.0476e+00,  1.1520e+00, -4.8720e-01,  6.7443e-01,
          1.7288e-01,  3.6275e-01,  6.5844e-01,  1.9006e-01,  7.3741e-01],
        [-3.6401e-01, -3.5704e-03, -9.4950e-01,  1.5912e-01,  1.2736e+00,
         -5.7254e-01,  1.0494e+00,  2.2206e-01,  9.8092e-01, -4.1555e-01],
        [-1.0874e-01,  4.6133e-01,  3.4433e-01, -1.4177e-01,  1.4516e-01,
         -1.8285e-01, -1.2600e+0

In [9]:
# TensorDataset wraps tensors that share the same first dimension; sample i is
# the tuple (x[i], y[i]), so no custom Dataset subclass is needed.
from torch.utils.data import TensorDataset
x = torch.randn(10, 100)  # 10 samples, 100 values each
y = torch.randn(10)       # one value per sample
tensor_dataset = TensorDataset(x, y)

In [10]:
# Display the dataset object; the output is the plain default object repr.
tensor_dataset

<torch.utils.data.dataset.TensorDataset at 0x112d7d4e0>

In [11]:
# Display x, the first tensor wrapped by tensor_dataset (shape (10, 100)).
x

tensor([[-0.4600,  0.3029,  0.2831, -1.0707,  0.5190,  0.5600, -0.9993, -0.3429,
          0.4868,  0.8209,  1.5484, -1.9874,  0.1946, -0.2026,  0.0618, -0.1625,
          0.3928, -0.6213,  0.0997,  1.1114,  0.1291,  0.1348,  0.4220, -1.6871,
         -0.5384, -0.0100,  0.4145, -0.7938,  0.1905, -0.1667,  0.5089,  1.4771,
         -0.2883, -0.0441, -0.4409, -1.1810,  0.0622, -0.8203,  0.9669, -0.5637,
          1.0019,  0.4535,  1.0941, -1.8349,  1.0303, -1.3427, -0.2849,  1.0749,
          0.2099, -0.8991,  0.7539,  0.4441, -0.0045, -0.4009, -0.2459,  0.1033,
         -0.8170,  1.1567, -0.6259, -1.0630,  0.4181,  0.9361, -0.1554,  1.5912,
          1.1066, -0.8020, -0.1500, -0.8144,  0.9559, -0.7691,  0.3074,  0.0585,
         -0.6145, -1.7670,  1.7918, -1.3274,  2.2780,  0.3938, -1.1116,  0.4610,
          0.9098,  0.4282,  0.3624, -0.7799, -1.1427, -0.8747, -1.0494,  1.3218,
         -0.3068,  0.8276, -0.6989, -0.2827, -0.9887, -2.2510, -2.1572, -0.1246,
         -1.2499,  2.2809, -

In [12]:
# Display y, the second tensor wrapped by tensor_dataset (shape (10,)).
y

tensor([ 0.7608,  3.1897, -0.3964,  2.1502, -1.1312, -0.4349, -0.0783,  0.4213,
        -1.0338,  0.9264])