# Imports


# In [1]:
# general imports
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

# In [2]:
# torch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchinfo import summary

# In [3]:
# utility imports
import os
from tqdm import tqdm

# In [4]:
import warnings

# Silence warnings of every category for the rest of the session. This is a
# process-wide side effect: it also hides deprecation notices raised by
# numpy / pandas / torch.
warnings.simplefilter("ignore")

# Configuration


# In [5]:
# Use the first CUDA GPU when torch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Dataset


# In [6]:
class MnistDataset(Dataset):
    """Image/label dataset read from a CSV file that has a ``label`` column.

    Every column other than ``label`` is treated as one pixel value. Samples
    are scaled by 1/255 and returned as float32 tensors together with their
    integer label.

    Args:
        img_file: Path to the CSV file; passed straight to ``pd.read_csv``.
        device: Device each image tensor is moved to in ``__getitem__``.
        transform: Optional callable applied to each scaled image tensor
            before it is moved to ``device``. ``None`` means no transform.
    """

    def __init__(self, img_file: str, device: "str | torch.device" = DEVICE, transform=None):
        super().__init__()
        self.images = pd.read_csv(img_file)
        # pop() removes the label column and returns it in one step, leaving
        # only the pixel columns in self.images.
        self.labels = self.images.pop("label")
        # NOTE(review): moving samples to a CUDA device inside __getitem__
        # means DataLoader worker processes (num_workers > 0) would have to
        # initialise CUDA; consider moving whole batches in the training loop.
        self.device = device
        self.transform = transform

    def __len__(self) -> int:
        return len(self.labels)

    def _pixels(self, index: int) -> np.ndarray:
        """Return the raw pixel values of row ``index`` (positional) as float32."""
        # iloc is positional, so __getitem__ and show() agree even if the
        # DataFrame index is not a plain 0..n-1 range.
        return self.images.iloc[index].to_numpy(dtype=np.float32)

    def __getitem__(self, index):
        """Return ``(image, label)`` for row ``index``.

        ``image`` is a flat float32 tensor with values in [0, 1] (or whatever
        ``transform`` returns), placed on ``self.device``; ``label`` is an int.
        """
        img = torch.from_numpy(self._pixels(index)) / 255
        if self.transform is not None:
            img = self.transform(img)
        img = img.to(self.device)
        return img, int(self.labels.iloc[index])

    def show(self, index):
        """Plot row ``index`` as a 28x28 grayscale image titled with its label.

        Assumes the CSV holds exactly 28 * 28 = 784 pixel columns.
        """
        img = self._pixels(index).reshape(28, 28)
        label = self.labels.iloc[index]
        plt.imshow(img, cmap="gray")
        plt.xticks([], [])
        plt.yticks([], [])
        plt.title(f"Label: {label}")
        plt.tight_layout()
        plt.show()