In [1]:
import torch
from adapt.dataset import Flame2DataModule
from torchvision import transforms

# Define the transform
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images to a standard size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize for pretrained models
])

# Initialize the Flame2DataModule
flame2_dm = Flame2DataModule(image_dir='data/Flame2/RGB', 
                             batch_size=32, 
                             transform=transform)

# Setup the data module
flame2_dm.setup()

# Get the dataloaders
train_loader = flame2_dm.train_dataloader()
val_loader = flame2_dm.val_dataloader()
test_loader = flame2_dm.test_dataloader()

# Print dataset sizes
print(f"Train set size: {len(train_loader.dataset)}")
print(f"Validation set size: {len(val_loader.dataset)}")
print(f"Test set size: {len(test_loader.dataset)}")

# Check a batch from the train loader
images, labels = next(iter(train_loader))
print(f"Batch shape: {images.shape}")
print(f"Labels shape: {labels.shape}")
print(f"Unique labels: {torch.unique(labels)}")

# Verify label distribution
label_counts = torch.bincount(labels)
print("Label distribution:")
for label, count in enumerate(label_counts):
    print(f"Label {label}: {count.item()} samples")

# Check for data leakage
train_labels = torch.cat([labels for _, labels in train_loader])
val_labels = torch.cat([labels for _, labels in val_loader])
test_labels = torch.cat([labels for _, labels in test_loader])

print("\nChecking for data leakage:")
print(f"Overlap between train and val: {len(set(train_labels.tolist()) & set(val_labels.tolist()))}")
print(f"Overlap between train and test: {len(set(train_labels.tolist()) & set(test_labels.tolist()))}")
print(f"Overlap between val and test: {len(set(val_labels.tolist()) & set(test_labels.tolist()))}")


Train set size: 37415
Validation set size: 8018
Test set size: 8018
Batch shape: torch.Size([32, 3, 224, 224])
Labels shape: torch.Size([32])
Unique labels: tensor([0, 1, 2])
Label distribution:
Label 0: 11 samples
Label 1: 9 samples
Label 2: 12 samples

Checking for data leakage:
Overlap between train and val: 3
Overlap between train and test: 3
Overlap between val and test: 3


In [5]:
for _, labels in train_loader:
    print(_)
    break

tensor([[[[-0.2513, -0.0458, -0.2513,  ...,  0.3994,  0.3994,  0.1768],
          [-0.2513,  0.0741, -0.0629,  ...,  0.3823,  0.3994,  0.2111],
          [-0.4226,  0.0398,  0.1083,  ...,  0.3823,  0.4166,  0.2282],
          ...,
          [-0.5938, -0.5767, -0.5938,  ..., -0.1828, -0.1999, -0.3198],
          [-0.5082, -0.6452, -0.7137,  ..., -0.1828, -0.1999, -0.3369],
          [-0.6794, -0.7137, -0.7308,  ..., -0.1999, -0.1999, -0.3369]],

         [[-0.2850, -0.0749, -0.2675,  ...,  0.8179,  0.8354,  0.6078],
          [-0.2150,  0.0651, -0.0924,  ...,  0.8179,  0.8354,  0.6078],
          [-0.4076, -0.0049,  0.0301,  ...,  0.8179,  0.8179,  0.6078],
          ...,
          [-0.6352, -0.6176, -0.6176,  ...,  0.1001,  0.0826, -0.0399],
          [-0.5651, -0.7052, -0.7752,  ...,  0.1001,  0.0826, -0.0574],
          [-0.7927, -0.7927, -0.7927,  ...,  0.0826,  0.0826, -0.0574]],

         [[-0.6890, -0.6541, -0.6890,  ...,  1.2282,  1.2282,  1.0017],
          [-0.7761, -0.6018, -