In [153]:
import torch
from torch import nn

In [154]:
import torchvision
from torchvision import datasets

In [155]:
from torchvision import transforms

## Fashion MNIST

In [156]:
# Build both Fashion-MNIST splits under the "data" directory (downloaded when
# requested via download=True). Images go through ToTensor; labels are left
# untransformed.
train_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
    target_transform=None,
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=transforms.ToTensor(),
    target_transform=None,
)

In [157]:
# Sample counts of the two splits, as a (train, test) tuple.
tuple(map(len, (train_data, test_data)))

(60000, 10000)

In [158]:
# Unpack the first training sample into its image and label for inspection.
image, label = train_data[0]

In [159]:
# Class names as exposed by the dataset object; the bare expression on the
# last line displays them as the cell output.
class_names = train_data.classes
class_names

['T-shirt/top',
 'Trouser',
 'Pullover',
 'Dress',
 'Coat',
 'Sandal',
 'Shirt',
 'Sneaker',
 'Bag',
 'Ankle boot']

In [160]:
# Dataset-provided mapping (class name -> integer index, going by the
# attribute name).
class_to_idx = train_data.class_to_idx

## Visualizing Data

In [161]:
import plotly.express as px
import numpy as np

In [162]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

In [163]:
from typing import Union

In [164]:
def showImage(image: Union[torch.Tensor, np.ndarray], label: int) -> None:
    """Show a single Fashion-MNIST image with its class name as the x-axis title.

    :param image: Image data. A 3-D input (e.g. ``(1, H, W)``) has its size-1
        dimensions squeezed away before plotting. Torch tensors are detached
        and moved to CPU first.
    :type image: Union[torch.Tensor, np.ndarray]

    :param label: Label id, an index into the ten Fashion-MNIST class names
    :type label: int

    :raises IndexError: If ``label`` is outside ``0..9``. Negative values are
        rejected too, because Python's wrap-around indexing would otherwise
        display the wrong class name without any error.
    """

    class_nm = ['T-shirt/top','Trouser','Pullover','Dress',
                'Coat','Sandal','Shirt','Sneaker','Bag','Ankle boot']

    # Validate before doing any plotting work so a bad label fails fast.
    if not 0 <= label < len(class_nm):
        raise IndexError(
            f"label must be in [0, {len(class_nm) - 1}], got {label}"
        )

    # Plotly converts its input with NumPy, which fails for tensors that
    # require grad or live on another device; convert explicitly instead.
    if isinstance(image, torch.Tensor):
        image = image.detach().cpu().numpy()

    if image.ndim == 3:
        image = image.squeeze()

    fig = px.imshow(image, color_continuous_scale='gray', labels=dict(x=class_nm[label]))
    fig.update_layout(coloraxis_showscale=False)
    fig.update_xaxes(showticklabels=False)
    fig.update_yaxes(showticklabels=False)

    fig.show()

In [165]:
# Render the sample unpacked earlier, titled with its class name.
showImage(image=image, label=label)

In [166]:
def plotRandomImg(train_data: torchvision.datasets.mnist.FashionMNIST, grid_len=5, seed=42):
    """Build a ``grid_len`` x ``grid_len`` grid of randomly chosen dataset images.

    Each subplot is a heatmap of one sample, with the sample's class name as
    the x-axis title.

    :param train_data: Total Training Data; indexing it must yield an
        ``(image, label)`` pair
    :type train_data: torchvision.datasets.mnist.FashionMNIST

    :param grid_len: Number of grid, defaults to 5. Tot images = grid_len * grid_len
    :type grid_len: int, optional

    :param seed: Seed used to pick the random samples, defaults to 42
    :type seed: int, optional

    :return: The assembled figure; it is returned, not shown
    :rtype: plotly.graph_objects.Figure
    """
    # NOTE: this reseeds torch's *global* RNG, so torch random operations that
    # run after this call are affected by ``seed`` as well.
    torch.manual_seed(seed=seed)

    class_nm = ['T-shirt/top','Trouser','Pullover','Dress',
                'Coat','Sandal','Shirt','Sneaker','Bag','Ankle boot']

    # Plain Python ints are unambiguous dataset indices.
    rand_indexs = torch.randint(0, len(train_data), size=[grid_len*grid_len]).tolist()

    fig = make_subplots(grid_len, grid_len)

    for samp_ind, data_ind in enumerate(rand_indexs):
        # Row-major fill; plotly subplot coordinates are 1-based.
        row_ind, col_ind = divmod(samp_ind, grid_len)
        row_ind += 1
        col_ind += 1

        img_data, label = train_data[data_ind]
        # Heatmap traces are not tied to the layout's shared coloraxis, so the
        # colorbar has to be hidden per trace; otherwise every subplot draws
        # its own colorbar.
        fig.add_trace(go.Heatmap(z=img_data.squeeze(), showscale=False), row_ind, col_ind)
        fig.update_xaxes(title_text=class_nm[label], row=row_ind, col=col_ind)

    # Heatmaps put row 0 at the bottom by default; flip every y-axis once so
    # the images appear upright.
    fig.update_yaxes(autorange="reversed")
    fig.update_layout(height=1000, width=1200)

    return fig

In [167]:
# 5 x 5 grid of random training images; the returned figure is the cell output.
plotRandomImg(train_data=train_data, grid_len=5, seed=212)

In [168]:
from torch.utils.data import DataLoader

In [169]:
# Number of samples per mini-batch, passed as batch_size to the DataLoaders.
BATCH_SIZE = 32

In [171]:
# Wrap each split in a DataLoader so it can be iterated in mini-batches of
# BATCH_SIZE samples. The training split is reshuffled every epoch.
train_dataloader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

# The test split keeps a fixed order; shuffling is not needed for evaluation.
test_dataloader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False)


In [172]:
# Pull a single batch from the training loader and unpack it into features
# and labels for inspection.
train_features_batch, train_labels_batch = next(iter(train_dataloader))

In [177]:
# First sample of the inspected training batch.
x = train_features_batch[0]

In [179]:
# Shape of the sample once collapsed into a single dimension.
torch.flatten(x).shape

torch.Size([784])