In [1]:
import os

import numpy as np
import matplotlib.pyplot as plt
import torch
from sklearn.datasets import load_digits
from sklearn import datasets
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

from pytorch_model_summary import summary

In [2]:
class Digits(Dataset):
    """Scikit-Learn Digits dataset, split into fixed train/val/test row ranges.

    Only the feature rows (``digits.data``) are exposed; class labels are not
    returned. Rows are cast to ``float32``.

    Args:
        mode: which split to load — ``"train"`` (rows 0-999), ``"val"``
            (rows 1000-1349) or ``"test"`` (rows 1350 onward).
        transforms: optional callable applied to each sample in
            ``__getitem__``.

    Raises:
        ValueError: if ``mode`` is not one of the known split names. Unknown
            modes are rejected rather than silently mapped to the test split.
    """

    # Row ranges of ``load_digits().data`` that make up each split.
    _SPLITS = {
        "train": slice(0, 1000),
        "val": slice(1000, 1350),
        "test": slice(1350, None),
    }

    def __init__(self, mode="train", transforms=None):
        # Validate before loading so a bad mode fails fast and cheaply.
        if mode not in self._SPLITS:
            raise ValueError(
                f"mode must be one of {sorted(self._SPLITS)}, got {mode!r}"
            )
        digits = load_digits()
        self.data = digits.data[self._SPLITS[mode]].astype(np.float32)
        self.transforms = transforms

    def __len__(self):
        """Number of rows in the selected split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return row ``idx`` (a float32 feature vector), optionally transformed."""
        sample = self.data[idx]
        if self.transforms is not None:
            sample = self.transforms(sample)
        return sample

In [3]:
# Training split: rows 0-999 of the scikit-learn digits feature matrix.
data = Digits(mode="train")

In [4]:
data[0]  # first training sample: a float32 array of 64 feature values (see output below)

array([ 0.,  0.,  5., 13.,  9.,  1.,  0.,  0.,  0.,  0., 13., 15., 10.,
       15.,  5.,  0.,  0.,  3., 15.,  2.,  0., 11.,  8.,  0.,  0.,  4.,
       12.,  0.,  0.,  8.,  8.,  0.,  0.,  5.,  8.,  0.,  0.,  9.,  8.,
        0.,  0.,  4., 11.,  0.,  1., 12.,  7.,  0.,  0.,  2., 14.,  5.,
       10., 12.,  0.,  0.,  0.,  0.,  6., 13., 10.,  0.,  0.,  0.],
      dtype=float32)

In [5]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [6]:
class STNet(nn.Module):
    """Two-hidden-layer ReLU MLP that returns a pair of ``dim``-sized outputs.

    The input's last axis (size ``dim``) is mapped to ``2 * dim`` outputs,
    which are split in half along that axis into ``(s, t)``.
    NOTE(review): the names suggest scale/translation terms for an affine
    coupling layer; the consumer is not visible here — confirm.

    Args:
        dim: size of the input's last axis and of each returned tensor's last axis.
        hidden: width of both hidden layers.
    """

    def __init__(self, dim: int, hidden: int = 128):
        super().__init__()
        # Layers are created in the same order as before so seeded
        # initialisation and state_dict keys ("net.0", "net.2", "net.4") match.
        layers = [
            nn.Linear(dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 2 * dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        """Return ``(s, t)``: the first and second halves of ``net(x)``'s last axis."""
        first_half, second_half = torch.chunk(self.net(x), chunks=2, dim=-1)
        return first_half, second_half