In [2]:
import torch
from adapt.dataset import Flame2DataModule
from torchvision import transforms

# Define the transform
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images to a standard size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize for pretrained models
])

# Initialize the Flame2DataModule
flame2_dm = Flame2DataModule(image_dir='data/Flame2/RGB', 
                             batch_size=32, 
                             transform=transform)

# Setup the data module
flame2_dm.setup()

# Get the dataloaders
train_loader = flame2_dm.train_dataloader()
val_loader = flame2_dm.val_dataloader()
test_loader = flame2_dm.test_dataloader()

# Print dataset sizes
print(f"Train set size: {len(train_loader.dataset)}")
print(f"Validation set size: {len(val_loader.dataset)}")
print(f"Test set size: {len(test_loader.dataset)}")

# Check a batch from the train loader
images, labels = next(iter(train_loader))
print(f"Batch shape: {images.shape}")
print(f"Labels shape: {labels.shape}")
print(f"Unique labels: {torch.unique(labels)}")

# Verify label distribution
label_counts = torch.bincount(labels)
print("Label distribution:")
for label, count in enumerate(label_counts):
    print(f"Label {label}: {count.item()} samples")

# Check for data leakage
train_labels = torch.cat([labels for _, labels in train_loader])
val_labels = torch.cat([labels for _, labels in val_loader])
test_labels = torch.cat([labels for _, labels in test_loader])

print("\nChecking for data leakage:")
print(f"Overlap between train and val: {len(set(train_labels.tolist()) & set(val_labels.tolist()))}")
print(f"Overlap between train and test: {len(set(train_labels.tolist()) & set(test_labels.tolist()))}")
print(f"Overlap between val and test: {len(set(val_labels.tolist()) & set(test_labels.tolist()))}")


Train set size: 37415
Validation set size: 8017
Test set size: 8019
Batch shape: torch.Size([32, 3, 224, 224])
Labels shape: torch.Size([32])
Unique labels: tensor([0, 1, 2])
Label distribution:
Label 0: 3 samples
Label 1: 21 samples
Label 2: 8 samples

Checking for data leakage:
Overlap between train and val: 3
Overlap between train and test: 3
Overlap between val and test: 3
