In [8]:
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

# Build train/validation DataLoaders from an image-folder dataset rooted at
# 'data', then print the class names and the shape of every training batch as
# a sanity check.

# Tunable settings, gathered here so they are changed in one place.
IMAGE_SIZE = (32, 32)   # every image is resized to this (height, width)
TRAIN_FRACTION = 0.8    # share of samples assigned to the training split
BATCH_SIZE = 32
NUM_WORKERS = 4         # loader worker processes used by each DataLoader
SPLIT_SEED = 42         # fixed seed so the train/val split is reproducible

# Preprocessing applied to every image when it is loaded.
data_transforms = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),  # Resize the images to a fixed size
    transforms.ToTensor(),  # Convert images to PyTorch tensors
    # Per-channel normalization. NOTE(review): these values look like the
    # commonly used ImageNet channel statistics -- confirm they are
    # appropriate for this dataset.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the dataset
dataset = datasets.ImageFolder(root='data', transform=data_transforms)

# Define the size of the train and validation sets; the validation set takes
# whatever remains so the two sizes always sum to len(dataset).
train_size = int(TRAIN_FRACTION * len(dataset))
val_size = len(dataset) - train_size

# Split the dataset. A dedicated, seeded generator keeps the membership of the
# two subsets identical across runs; without it every re-run draws a different
# split, so validation results are not comparable between runs.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)

# Create dataloaders for the datasets. Only the training loader shuffles.
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
)
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
)

# Access the class names
class_names = dataset.classes
print('Class names:', class_names)

# Iterate through the training data and report the size of each batch.
for inputs, labels in train_loader:
    print('Inputs:', inputs.shape)
    print('Labels:', len(labels))

Class names: ['Amphibolite', 'Andesite', 'Anthracite', 'Basalt', 'Blueschist', 'Breccia', 'Carbonatite', 'Chalk', 'Chert', 'Coal', 'Conglomerate', 'Diamictite', 'Dolomite', 'Eclogite', 'Evaporite', 'Flint', 'Gabbro', 'Gneiss', 'Granite', 'Granulite', 'Greenschist', 'Greywacke', 'Hornfels', 'Komatiite', 'Limestone', 'Marble', 'Migmatite', 'Mudstone', 'Obsidian', 'Oil_shale', 'Oolite', 'Pegmatite', 'Phyllite', 'Porphyry', 'Pumice', 'Pyroxenite', 'Quartz_diorite', 'Quartz_monzonite', 'Quartzite', 'Quartzolite', 'Rhyolite', 'Sandstone', 'Scoria', 'Serpentinite', 'Shale', 'Siltstone', 'Slate', 'Talc_carbonate', 'Tephrite', 'Travertine', 'Tuff', 'Turbidite', 'Wackestone']
Inputs: torch.Size([32, 3, 32, 32])
Labels: 32
Inputs: torch.Size([32, 3, 32, 32])
Labels: 32
Inputs: torch.Size([32, 3, 32, 32])
Labels: 32
Inputs: torch.Size([32, 3, 32, 32])
Labels: 32
Inputs: torch.Size([32, 3, 32, 32])
Labels: 32
Inputs: torch.Size([32, 3, 32, 32])
Labels: 32
Inputs: torch.Size([32, 3, 32, 32])
Labels:

Inputs: tensor([[[[-0.1314,  0.7591,  0.9132,  ..., -0.1486,  0.3309,  0.6563],
          [-0.6794,  0.6221,  0.8618,  ..., -0.1657,  0.3138,  0.5193],
          [-0.8678,  0.3309,  1.0331,  ...,  0.0056,  0.2111,  0.1426],
          ...,
          [ 0.6392,  0.3994,  0.3138,  ...,  0.4679,  0.6563,  0.8447],
          [ 0.8618,  0.5022,  0.4851,  ...,  0.7933,  0.8789,  0.8961],
          [ 0.6392,  0.3823,  0.6734,  ...,  0.8447,  0.8961,  1.0159]],

         [[-0.6176, -0.0574,  0.1527,  ..., -0.1800,  0.4853,  0.7479],
          [-0.9153, -0.1275,  0.0826,  ..., -0.0924,  0.5028,  0.6604],
          [-0.8978, -0.3025,  0.2227,  ...,  0.1352,  0.3627,  0.2927],
          ...,
          [ 0.8179,  0.5728,  0.4853,  ...,  0.6254,  0.8704,  1.0805],
          [ 1.0630,  0.7129,  0.6429,  ...,  0.9755,  1.0630,  1.0805],
          [ 0.8704,  0.5728,  0.8354,  ...,  1.0455,  1.0805,  1.2031]],

         [[-0.7587, -0.3404, -0.0964,  ..., -0.0790,  0.6356,  0.8099],
          [-0.9504, -0