# PyTorch Exploration – Dataset & Data Loader

In [1]:
import audiomod
import ptmod
# from pymongo import MongoClient
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import torch
# import torchvision
# from torchvision import transforms, utils
import torch.utils.data as data_utils
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
from collections import OrderedDict, defaultdict
import pickle
import os

from sklearn import metrics

%matplotlib inline

## Creating Dataset Sub-Class

Wrote a class, `SpectroDataset`, in `ptmod` that uses MongoDB to access a particular test round and creates spectrograms on the fly through its `__getitem__()` method.
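
The module code isn't shown here. The sketch below illustrates the general pattern under some assumptions: the audio is already in memory rather than in MongoDB, and the class name `ToySpectroDataset`, the frame size, the hop length and the [-1, 1] rescaling are all hypothetical choices. It also ignores the real class's `scaling` parameter. A `torch.utils.data.Dataset` subclass only needs `__len__` and a `__getitem__` that builds the spectrogram when an index is requested.

```python
import numpy as np
import torch
from torch.utils.data import Dataset


class ToySpectroDataset(Dataset):
    """Hypothetical stand-in for ptmod.SpectroDataset (in-memory audio, no MongoDB)."""

    def __init__(self, clips, labels, chunk_ids, n_fft=126, hop=64):
        self.clips = clips            # list of 1-D numpy arrays (raw audio)
        self.labels = labels          # 0/1 ground truth per clip
        self.chunk_ids = chunk_ids    # string IDs such as "003171"
        self.n_fft = n_fft            # 126-sample frames -> 64 rfft bins
        self.hop = hop
        self.window = np.hanning(n_fft)

    def __len__(self):
        return len(self.clips)

    def __getitem__(self, idx):
        audio = self.clips[idx]
        # slice the signal into overlapping frames
        n_frames = 1 + (len(audio) - self.n_fft) // self.hop
        frames = np.stack([audio[i * self.hop:i * self.hop + self.n_fft]
                           for i in range(n_frames)])
        # windowed magnitude spectrum per frame, log-compressed, transposed to (freq, time)
        spec = np.log1p(np.abs(np.fft.rfft(frames * self.window, axis=1))).T
        # rescale to [-1, 1]; the small epsilon guards against a flat spectrogram
        spec = 2 * (spec - spec.min()) / (spec.max() - spec.min() + 1e-8) - 1
        tensor = torch.from_numpy(np.ascontiguousarray(spec, dtype=np.float32))
        # add a channel axis -> (1, freq, time), the layout a 2-D CNN expects
        return tensor.unsqueeze(0), self.labels[idx], self.chunk_ids[idx]
```

With these settings, a clip of 3,518 samples yields 54 frames of 64 frequency bins. Indexing the dataset therefore returns a `(1, 64, 54)` tensor, plus the label and chunk ID.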

First, create a datagroup in the Mongo collection. Following the sax/no-sax convention, I'll use all available samples with sax in the foreground and an equal number without.
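
The database helper that does this isn't shown. The balancing rule itself is simple: keep every positive and draw an equal-sized random set of negatives. It can be sketched in pandas as below. The function name `make_balanced_group`, the fixed seed and the default label column (`actual`, borrowed from the dataframe further down) are assumptions.

```python
import pandas as pd


def make_balanced_group(df, label_col="actual", seed=0):
    """Hypothetical helper: keep every positive row plus as many random negatives."""
    pos = df[df[label_col] == 1]
    # raises ValueError if there are fewer negatives than positives
    neg = df[df[label_col] == 0].sample(n=len(pos), random_state=seed)
    # concatenate and shuffle so positives and negatives are interleaved
    return pd.concat([pos, neg]).sample(frac=1, random_state=seed)
```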

In [2]:
# only run once!
# komod.create_datagroup_in_db('sax1203', 'sax')

Now that it's labeled in the database, I load the datagroup as a pandas DataFrame, pull a relatively balanced subset, add a train/test column, and create PyTorch dataset objects.

In [3]:
sax1203_datagroup = audiomod.pull_datagroup_from_db('sax1203')

sax1203_datagroup.shape

(920, 2)

In [4]:
# pull a smaller sample for PoC run
sub_datagroup = sax1203_datagroup.sample(100)
sub_datagroup.actual.value_counts()

1    53
0    47
Name: actual, dtype: int64

In [6]:
train_df, test_df = audiomod.tts(sub_datagroup)

# this scaling is pretty tiny, but it'll do the trick for a dry run
train_sub = ptmod.SpectroDataset(train_df, scaling=0.125)
test_sub = ptmod.SpectroDataset(test_df, scaling=0.125)

print("Train set length:", len(train_sub))
print("Test set length:", len(test_sub))

Train set length: 79
Test set length: 21
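
`audiomod.tts` isn't shown either. One way to "add a train/test column" and then split on it is to draw a random number for each row. The function `tts_sketch` below is hypothetical, and its 0.2 test fraction and seed are assumptions. Each row is assigned independently, so the split sizes only approximate the requested fraction.

```python
import numpy as np
import pandas as pd


def tts_sketch(df, test_frac=0.2, seed=0):
    """Hypothetical train/test split: tag each row, then split on the tag."""
    rng = np.random.default_rng(seed)
    df = df.copy()  # don't mutate the caller's dataframe
    df["split"] = np.where(rng.random(len(df)) < test_frac, "test", "train")
    return df[df["split"] == "train"], df[df["split"] == "test"]
```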


Indexing a dataset object should return a tuple of:  

`(sample_spectrogram, sample_ground_truth, sample_ID)`

In [7]:
for i in range(5):
    item = train_sub[i]
    print("\nChunk:", item[2])
    print("Label:", item[1])
    print("---")
    ptmod.tensor_stats(item[0])


Chunk: 003171
Label: 0
---
Min: -1.0
Max: 1.0
Mean: 0.11256619692517654
Std: 0.4029987719865222
Shape: torch.Size([1, 64, 54])

Chunk: 000301
Label: 1
---
Min: -0.870057225227356
Max: 1.0
Mean: 0.26686227109635013
Std: 0.31685899686397717
Shape: torch.Size([1, 64, 54])

Chunk: 012413
Label: 0
---
Min: -1.0
Max: 1.0
Mean: -0.5248126373042867
Std: 0.44975363493523973
Shape: torch.Size([1, 64, 54])

Chunk: 014894
Label: 0
---
Min: -1.0
Max: 1.0
Mean: -0.16581711575324537
Std: 0.3819979149374032
Shape: torch.Size([1, 64, 54])

Chunk: 014699
Label: 0
---
Min: -0.4782467186450958
Max: 1.0
Mean: 0.16387082422865104
Std: 0.2371574806909264
Shape: torch.Size([1, 64, 54])
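
`ptmod.tensor_stats` is defined in the module and not shown. A hypothetical equivalent that prints the same five fields might look like this. The notebook doesn't say whether the original uses the population or the sample standard deviation, so the population version here is an assumption.

```python
import torch


def tensor_stats_sketch(t):
    """Hypothetical helper: print and return basic summary statistics of a tensor."""
    x = t.detach().double()  # compute in float64 for stable statistics
    mean = x.mean()
    stats = {
        "Min": x.min().item(),
        "Max": x.max().item(),
        "Mean": mean.item(),
        # population std (divide by N); ptmod.tensor_stats may use N - 1 instead
        "Std": ((x - mean) ** 2).mean().sqrt().item(),
        "Shape": t.shape,
    }
    for name, value in stats.items():
        print(f"{name}: {value}")
    return stats
```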


Works as it should. The first element of the tuple, the spectrogram, should be a torch tensor:

In [8]:
type(train_sub[2][0])

torch.FloatTensor
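
The `torch.FloatTensor` class above reflects the PyTorch release this notebook ran on. From PyTorch 0.4 onward, `type()` on any tensor returns the single `torch.Tensor` class. The element type is read from `.dtype`, and the legacy type string is still available from `.type()`. The placeholder tensor below only borrows the spectrogram's shape.

```python
import torch

spec = torch.zeros(1, 64, 54)  # placeholder with the spectrogram's shape

print(type(spec))    # <class 'torch.Tensor'>
print(spec.dtype)    # torch.float32
print(spec.type())   # torch.FloatTensor
```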

## Data Loader

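What is it and how does it work? A `DataLoader` wraps a `Dataset` and serves it in mini-batches. The settings below ask for 4 samples per batch (`batch_size=4`), a new random order on every pass through the data (`shuffle=True`), and no final partial batch (`drop_last=True`).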

In [9]:
train_loader = data_utils.DataLoader(train_sub, batch_size=4, shuffle=True, drop_last=True)
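
Here is what those settings imply for the 79-sample training set. With `batch_size=4` and `drop_last=True`, the loader yields 79 // 4 = 19 full batches per epoch and skips the 3 leftover samples. The dataset class below is hypothetical and returns zero tensors shaped like `train_sub`'s spectrograms.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ShapeOnlyDataset(Dataset):
    """Hypothetical dataset returning (spectrogram, label, chunk_id) like train_sub."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        return torch.zeros(1, 64, 54), i % 2, f"{i:06d}"


loader = DataLoader(ShapeOnlyDataset(79), batch_size=4, shuffle=True, drop_last=True)
print(len(loader))  # 19 batches

seen = sum(batch[0].shape[0] for batch in loader)
print(seen)  # 76 samples; 3 are dropped
```

With `drop_last=False` (the default), the same loader would report 20 batches, and the last batch would hold the 3 remaining samples.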

In [10]:
type(train_loader)

torch.utils.data.dataloader.DataLoader

In [11]:
# get an iterator over the loader's batches
train_iter = iter(train_loader)

In [12]:
type(train_iter)

torch.utils.data.dataloader.DataLoaderIter

In [13]:
# save a batch of data as a variable
loader_unit = next(train_iter)
type(loader_unit)

list
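
Calling `iter()` and `next()` by hand is handy for inspecting a single batch. A training loop usually does the same unpacking with a plain `for` loop, which creates a fresh iterator on each pass and, with `shuffle=True`, a fresh ordering. In this sketch, a plain Python list of made-up `(spectrogram, label, chunk_id)` tuples stands in for `train_sub`. A list works because `DataLoader` only needs `__len__` and `__getitem__`.

```python
import torch
from torch.utils.data import DataLoader

# hypothetical stand-in for train_sub: a plain list of (spectrogram, label, chunk_id)
samples = [(torch.randn(1, 64, 54), i % 2, f"{i:06d}") for i in range(8)]
loader = DataLoader(samples, batch_size=4, shuffle=True, drop_last=True)

for epoch in range(2):
    for specs, labels, chunk_ids in loader:  # each pass builds a new iterator
        # specs: (4, 1, 64, 54) floats; labels: 4 int64 values; chunk_ids: 4 strings
        print(epoch, tuple(specs.shape), labels.tolist(), list(chunk_ids))
```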

In [14]:
# the list holds the batch's spectrograms, their labels, and their chunk IDs
len(loader_unit)

3

In [15]:
# each should include four records
for sub in loader_unit:
    print(len(sub))

4
4
4


In [16]:
# and the types?
for sub in loader_unit:
    print(type(sub))

<class 'torch.FloatTensor'>
<class 'torch.LongTensor'>
<class 'tuple'>
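
These three types come from the loader's default collate function. It stacks tensors along a new batch dimension, turns Python ints into a 64-bit integer tensor, and keeps strings together as a plain sequence. Calling it directly on four hand-made samples (the values are made up) reproduces this without a loader.

```python
import torch
from torch.utils.data.dataloader import default_collate

# four (spectrogram, label, chunk_id) tuples, as a Dataset's __getitem__ would return
samples = [(torch.zeros(1, 64, 54), label, cid)
           for label, cid in [(0, "000001"), (1, "000002"), (1, "000003"), (0, "000004")]]

specs, labels, chunk_ids = default_collate(samples)
print(specs.shape)  # torch.Size([4, 1, 64, 54])
print(labels)       # tensor([0, 1, 1, 0])
print(chunk_ids)    # ('000001', '000002', '000003', '000004')
```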


In [17]:
# now some shapes and contents
print(loader_unit[0][0].shape)
print(loader_unit[1][0])  # no .shape here: indexing the LongTensor returns a plain int
print(loader_unit[2][0])  # likewise: the chunk ID is a str

torch.Size([1, 64, 54])
0
012093


In [18]:
loader_unit[1]


 0
 1
 1
 0
[torch.LongTensor of size 4]
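
The missing `.shape` on a single label reflects the PyTorch version used here, where indexing a 1-D `LongTensor` returned a plain Python int. From PyTorch 0.4 onward, the same indexing returns a zero-dimensional tensor, which has an empty `.shape`. Use `.item()` to get the Python number, or `.tolist()` for the whole batch. The label values below are copied from the batch printed above.

```python
import torch

labels = torch.tensor([0, 1, 1, 0])  # values copied from the batch printed above

first = labels[0]
print(first.shape)        # torch.Size([])
print(first.item())       # 0
print(labels.tolist())    # [0, 1, 1, 0]
```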

Into the module with it all and on to CNN design...