In this notebook, we use the PyTorch `DataLoader` to load a custom dataset (the MNIST digits dataset in this case) from a CSV file.

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    """Map-style dataset backed by a comma-separated numeric file.

    Every row of the file is one sample: column 0 holds the label and the
    remaining columns hold the feature values.

    NOTE: the file is parsed with ``np.genfromtxt`` defaults, so any
    non-numeric field (for example a header row) is read as NaN.
    """

    def __init__(self, csv_file):
        """Parse *csv_file* and split it into labels and features.

        Args:
            csv_file: Source handed straight to ``np.genfromtxt``
                (typically a path to a ``.csv`` file).
        """
        # genfromtxt collapses a single-row file to a 1-D array; force 2-D so
        # the column slicing below works for one sample as well as many.
        self.data = np.atleast_2d(np.genfromtxt(csv_file, delimiter=','))
        self.labels = self.data[:, 0]  # first column: labels
        self.features = self.data[:, 1:]  # remaining columns: features
        print(self.data.shape)

    def __len__(self):
        """Return the number of samples (rows) in the dataset."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return the sample at *index* as a ``(features, label)`` tensor pair.

        Features are returned as float32 and the label as int64. The parsed
        array is float64, which does not match the default float32 parameters
        of PyTorch layers, and classification losses expect integer class
        indices rather than floats.
        """
        features = torch.tensor(self.features[index], dtype=torch.float32)
        # Build the float tensor first, then cast, so the conversion is a
        # plain tensor cast of the parsed value.
        label = torch.tensor(self.labels[index]).long()

        return features, label


# Load the training CSV into the dataset wrapper defined above
dataset = CustomDataset(csv_file='mnist_train.csv')

# Serve the samples in shuffled mini-batches of 64
dataloader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)




In [None]:
# Pull a single batch from the dataloader
batch = next(iter(dataloader))

# Split the batch into its feature and label tensors
features_batch, labels_batch = batch

# Position of the sample to display within the batch (0-based, so index 40
# is the 41st sample). Kept in one place so image and title cannot drift apart.
sample_index = 40

# Each sample is a flat row of values; reshape it to 28x28 for display
image = features_batch[sample_index].numpy().reshape(28, 28)
plt.imshow(image, cmap='gray')
plt.title(f"Label: {labels_batch[sample_index].item()}")
plt.axis('off')
plt.show()