In [None]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class Cifar10LT(Dataset):
    """Map-style dataset over in-memory CIFAR-10-LT image and label arrays.

    Every image is reshaped to ``(3, 32, 32)`` and then transposed to
    ``(32, 32, 3)`` (height, width, channel) before the optional transform
    is applied.

    Args:
        images: Indexable collection of images. Each item must contain
            exactly 3 * 32 * 32 values so it can be reshaped to
            ``(3, 32, 32)``.
        labels: Indexable collection of labels aligned with ``images``;
            its length defines the length of the dataset.
        transform: Optional callable applied to the ``(32, 32, 3)`` image
            array. ``None`` means the array is returned as-is.
    """

    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples, taken from the label collection."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``."""
        img = self.images[idx]
        label = self.labels[idx]
        # (3, 32, 32) -> (32, 32, 3): move the leading channel axis last,
        # i.e. produce H,W,C (the ndarray layout torchvision's ToTensor
        # expects), not C,H,W.
        img = np.transpose(np.reshape(img, (3, 32, 32)), (1, 2, 0))
        # Compare against None explicitly so that a callable transform which
        # happens to evaluate falsy (e.g. defines __len__ == 0) still runs.
        if self.transform is not None:
            img = self.transform(img)
        return img, label

# Directory that holds the pre-generated CIFAR-10-LT .npy files
save_dir = './cifar10lt_10'


def _load_array(filename):
    """Load a single .npy array stored under ``save_dir``."""
    return np.load(os.path.join(save_dir, filename))


# Training split
train_images = _load_array('train_images.npy')
train_labels = _load_array('train_labels.npy')

# Test split (assumes the test arrays were saved alongside the training ones)
test_images = _load_array('test_images.npy')
test_labels = _load_array('test_labels.npy')

# Preprocessing shared by both splits: convert to a tensor, then normalise
# each of the three channels with mean 0.5 and std 0.5. No random
# augmentation is applied.
_preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
transform = transforms.Compose(_preprocessing_steps)

# Wrap the raw arrays in datasets
train_dataset = Cifar10LT(train_images, train_labels, transform)
test_dataset = Cifar10LT(test_images, test_labels, transform)

# Batch both splits; only the training split is shuffled
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# Training split: total size, then number of samples for each label value
train_size = len(train_dataset)
print(f"Training set samples: {train_size}")
unique_train_labels, train_counts = np.unique(
    train_dataset.labels, return_counts=True
)
train_class_counts = dict(zip(unique_train_labels, train_counts))
print("Samples per class in training set:", train_class_counts)

# Test split: same two figures
test_size = len(test_dataset)
print(f"Test set samples: {test_size}")
unique_test_labels, test_counts = np.unique(
    test_dataset.labels, return_counts=True
)
test_class_counts = dict(zip(unique_test_labels, test_counts))
print("Samples per class in test set:", test_class_counts)

# Example: pull only the first batch from the training loader and show its
# shapes. The None default keeps this a no-op when the loader is empty.
first_train_batch = next(iter(train_loader), None)
if first_train_batch is not None:
    images, labels = first_train_batch
    print("Training batch image shape:", images.shape)
    print("Training batch label shape:", labels.shape)

# Example: same preview for the test loader
first_test_batch = next(iter(test_loader), None)
if first_test_batch is not None:
    images, labels = first_test_batch
    print("Test batch image shape:", images.shape)
    print("Test batch label shape:", labels.shape)
