In [1]:
import torch
from torchvision import datasets, transforms
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Preprocessing pipeline: ToTensor() followed by a per-channel Normalize with
# mean 0.5 and std 0.5 on each of the three channels.
# NOTE(review): nothing in the visible cells passes `transform` to a dataset or
# a loader -- confirm whether it is used further down.
channel_mean = (0.5, 0.5, 0.5)
channel_std = (0.5, 0.5, 0.5)
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(channel_mean, channel_std)]
)

# We then define a function load_batch to load each batch file.
def load_batch(file):
    """Unpickle a single batch file and return the object stored in it.

    Args:
        file: Path of the pickled batch file; it is opened in binary mode.

    Returns:
        The unpickled object. Because ``encoding='bytes'`` is used, any
        string keys written by Python 2 come back as ``bytes`` (the rest of
        this notebook indexes the result with ``b'data'`` / ``b'labels'``).
    """
    # NOTE: unpickling can execute arbitrary code, so only point this at
    # batch files obtained from a trusted source.
    with open(file, 'rb') as handle:
        return pickle.load(handle, encoding='bytes')

"""We load the training data from the files data_batch_1 to data_batch_5, concatenate them, 
and create a DataLoader for the training data"""

# NOTE: TensorDataset has no transform hook, so the `transform` pipeline built
# earlier in this cell is never applied to the tensors below. The equivalent
# scaling is therefore done explicitly, and identically for train and test.
def _to_normalized_tensor(raw_pixels):
    """Reshape flat image rows to (N, 3, 32, 32) floats scaled to [-1, 1].

    Assumes pixel values lie in 0..255 (the visualization cell of this
    notebook divides the same data by 255.0) -- TODO confirm.
    """
    images = torch.tensor(raw_pixels.reshape(-1, 3, 32, 32), dtype=torch.float)
    # / 255 maps 0..255 onto [0, 1] (the ToTensor step); (x - 0.5) / 0.5 then
    # maps [0, 1] onto [-1, 1] (the Normalize(0.5, 0.5) step).
    return (images / 255.0 - 0.5) / 0.5


# Training split: gather the five batch files, then convert once.
train_chunks = []
train_labels = []
for i in range(1, 6):
    batch = load_batch(f'data_batch_{i}')
    train_chunks.append(batch[b'data'])
    train_labels += batch[b'labels']
train_data = _to_normalized_tensor(np.concatenate(train_chunks))
train_labels = torch.tensor(train_labels, dtype=torch.long)

# training DataLoader (reshuffled every epoch)
trainset = torch.utils.data.TensorDataset(train_data, train_labels)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# We load the test data from the file test_batch and create a DataLoader for the test data.
test_batch = load_batch('test_batch')
test_data = _to_normalized_tensor(test_batch[b'data'])
test_labels = torch.tensor(test_batch[b'labels'], dtype=torch.long)

# test DataLoader: no shuffling, so evaluation visits samples in a fixed,
# reproducible order.
testset = torch.utils.data.TensorDataset(test_data, test_labels)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)


In [None]:
# Show one randomly chosen example image per class from the first training batch.
batch1 = load_batch('data_batch_1')

# Separate the data and labels. The rows are reshaped to (N, 3, 32, 32) and then
# moved to channels-last (N, 32, 32, 3), the layout imshow expects for RGB.
data = batch1[b'data'].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
labels = batch1[b'labels']

classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Build the label array once instead of on every loop iteration.
label_array = np.asarray(labels)

# 2 x 5 grid: one axes per class.
fig, axes = plt.subplots(2, 5, figsize=(10, 10))

for i, (ax, class_name) in enumerate(zip(axes.flat, classes)):
    ax.set_title(class_name)
    ax.axis('off')

    # Positions of every example labelled with class index i.
    indices = np.where(label_array == i)[0]
    if indices.size == 0:
        # np.random.choice raises on an empty array; leave this panel blank
        # if the batch happens to contain no example of the class.
        continue

    # Pick one example at random and scale it to [0, 1] for imshow.
    index = np.random.choice(indices)
    img = data[index] / 255.0
    ax.imshow(img)

plt.show()
