In [None]:
import os
import shutil
import numpy as np
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Subset, DataLoader
import wandb
from dotenv import load_dotenv

# ==========================================
# 1. Environment Setup (Standalone)
# ==========================================
# Pull variables from ./.env (resolved against the current working directory)
# into os.environ so the lookups below can see them.
load_dotenv(os.path.join(os.getcwd(), ".env"))

WANDB_API_KEY = os.getenv("WANDB_API_KEY")
PROJECT_NAME = os.getenv("WANDB_PROJECT", "cifar10_mlops_project")

# No usable API key in the environment -> fall back to W&B's interactive login.
if WANDB_API_KEY is None or WANDB_API_KEY == "":
    print("WANDB_API_KEY not found in .env. Please login manually.")
    wandb.login()

# ==========================================
# 2. Shared Code (Inlined from src/dataset.py)
# ==========================================
class Cifar10DataManager:
    """Downloads CIFAR-10, builds torchvision transform pipelines, and creates a
    reproducible Test/Simulation split of the official CIFAR-10 test set.

    Args:
        data_dir: Root directory passed to ``torchvision.datasets.CIFAR10``. Split
            indices are written to ``<data_dir>/processed``.
    """

    def __init__(self, data_dir="./data"):
        self.data_dir = data_dir
        # Per-channel (R, G, B) constants passed to transforms.Normalize.
        self.mean = (0.4914, 0.4822, 0.4465)
        self.std = (0.2023, 0.1994, 0.2010)

    def get_transforms(self, architecture_option='standard', upsample_size=224):
        """Build the (train_transform, eval_transform) pipelines.

        Args:
            architecture_option: 'upsample' additionally resizes images to
                ``upsample_size``; any other value (default 'standard') keeps the
                native image size.
            upsample_size: Size passed to ``transforms.Resize`` when upsampling.

        Returns:
            Tuple ``(train, eval)`` of ``transforms.Compose`` objects. The train
            pipeline adds random horizontal flip + padded 32px random crop.
        """
        # Tensor conversion + normalization shared by both pipelines.
        tail = [
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std),
        ]
        # Augmentations are defined for the native 32x32 image.
        augment = [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
        ]

        resize = []
        if architecture_option == 'upsample':
            # Resize AFTER augmentation: resizing first made RandomCrop(32) cut a
            # 32x32 patch out of the upsampled image, so train inputs never had
            # the upsampled resolution that eval inputs have.
            resize = [transforms.Resize(upsample_size)]

        train_transform = transforms.Compose(augment + resize + tail)
        eval_transform = transforms.Compose(resize + tail)
        return train_transform, eval_transform

    def prepare_initial_split(self, test_size=8000, seed=42):
        """Download CIFAR-10 and split the official test set.

        The test set (10k images for CIFAR-10) is shuffled with a fixed seed and
        divided into:
        - Test (first ``test_size`` shuffled indices, 8k by default): model evaluation.
        - Simulation (the remaining indices, 2k by default): holdout for live
          traffic simulation.
        The training set is returned unsplit.

        Both index lists are saved as ``test_indices.npy`` / ``sim_indices.npy``
        under ``<data_dir>/processed`` so later runs can load the same split.

        Args:
            test_size: Number of shuffled test-set indices assigned to Test.
            seed: Seed for the shuffle.

        Returns:
            ``(train_set, test_set, test_indices, sim_indices)``; the index
            collections are Python lists.

        Raises:
            ValueError: If ``test_size`` is outside ``[0, len(test_set)]``.
        """
        print(f"Downloading/Loading data in {self.data_dir}...")
        # Download raw data
        train_set = torchvision.datasets.CIFAR10(root=self.data_dir, train=True, download=True)
        test_set = torchvision.datasets.CIFAR10(root=self.data_dir, train=False, download=True)

        n_total = len(test_set)
        if not 0 <= test_size <= n_total:
            raise ValueError(f"test_size must be in [0, {n_total}], got {test_size}")

        indices = list(range(n_total))
        # A private RandomState(seed) produces exactly the same permutation as the
        # legacy np.random.seed(seed) + np.random.shuffle, so previously saved
        # splits are reproduced, but NumPy's global RNG is left untouched.
        np.random.RandomState(seed).shuffle(indices)

        test_indices = indices[:test_size]
        sim_indices = indices[test_size:]

        # Save indices to disk to ensure we load the same split later
        processed_dir = os.path.join(self.data_dir, "processed")
        os.makedirs(processed_dir, exist_ok=True)
        np.save(os.path.join(processed_dir, "test_indices.npy"), test_indices)
        np.save(os.path.join(processed_dir, "sim_indices.npy"), sim_indices)
        print("Data split indices created.")

        return train_set, test_set, test_indices, sim_indices

# ==========================================
# 3. Execution Main
# ==========================================

# Initialize Data Manager
dm = Cifar10DataManager(data_dir="./data")

# 1. Download & Prepare Initial Split
# Downloads CIFAR-10 via torchvision, then splits the official test set into
# 8k Test / 2k Simulation indices saved under <data_dir>/processed.
# The training set itself is not split here.
print("Downloading and splitting data...")
dm.prepare_initial_split()

# 2. Versioning with W&B
# Log the prepared data directory as the "source of truth" dataset artifact.
# The context manager finishes the run even if building or logging the artifact raises.
with wandb.init(project=PROJECT_NAME, job_type="data_preparation", name="cifar10_v1") as run:
    # One artifact with the entire data directory (raw data + split indices)
    dataset_artifact = wandb.Artifact(
        name="cifar10_dataset",
        type="dataset",
        description="CIFAR-10 Raw Data + Split Indices (Train/Test/Sim)",
    )

    # Upload the same directory the split was written to: 'cifar-10-batches-py',
    # 'processed', and anything else torchvision left there.
    # NOTE(review): this likely includes the downloaded archive as well - confirm
    # whether it should be versioned.
    dataset_artifact.add_dir(dm.data_dir)

    run.log_artifact(dataset_artifact)

print("Step 1 Complete: Dataset v1 logged to W&B.")

Downloading and splitting data...
Downloading/Loading data in ./data...


100%|██████████| 170M/170M [00:02<00:00, 74.8MB/s] 


Data split indices created.


<IPython.core.display.Javascript object>

# Data Preparation and Versioning

This notebook downloads the raw CIFAR-10 train and test sets with torchvision and versions them as a single Weights & Biases dataset Artifact.


In [None]:
import wandb
import torchvision
import os

# Project Configuration
# Read from environment variables (same WANDB_PROJECT lookup and default as the
# environment-setup cell) so both notebooks target the same W&B project.
PROJECT_NAME = os.getenv("WANDB_PROJECT", "cifar10_mlops_project")
# W&B user/team; stays None when WANDB_ENTITY is unset, leaving W&B to infer it.
ENTITY = os.getenv("WANDB_ENTITY")
ARTIFACT_NAME = "cifar10-raw-data"
# Relative to the notebook's current working directory.
DATA_DIR = "../data/raw"

# Create data directory if it doesn't exist
os.makedirs(DATA_DIR, exist_ok=True)

In [None]:
# Initialize W&B Run for Data Preparation
# Authenticate BEFORE starting the run: init() needs credentials, so calling
# login() afterwards (as before) came too late to help. login() reuses existing
# credentials (e.g. WANDB_API_KEY) when present and prompts otherwise.
wandb.login()
# entity=ENTITY (None by default) lets the configured entity actually take effect.
run = wandb.init(project=PROJECT_NAME, entity=ENTITY, job_type="data-preparation")

In [None]:
# Download CIFAR-10 Dataset
print("Downloading CIFAR-10 dataset...")
# torchvision extracts the batches into a 'cifar-10-batches-py' folder under DATA_DIR.
# The training split is fetched first, then the test split.
dataset, test_dataset = (
    torchvision.datasets.CIFAR10(root=DATA_DIR, train=is_train, download=True)
    for is_train in (True, False)
)
print("Download complete.")

In [None]:
# Create a W&B Artifact
artifact = wandb.Artifact(name=ARTIFACT_NAME, type="dataset", description="Raw CIFAR-10 dataset from torchvision")

# torchvision extracts CIFAR-10 into a subfolder of DATA_DIR; adding the whole
# DATA_DIR captures that folder plus any other downloaded files.
artifact.add_dir(DATA_DIR)

# Log the artifact to W&B. log_artifact() uploads asynchronously, so block until
# the upload has finished before reporting success.
print("Logging artifact to W&B...")
run.log_artifact(artifact)
artifact.wait()
print("Artifact logged successfully.")

In [None]:
# Close the currently active W&B run (the one started with wandb.init above in this
# notebook); W&B completes any remaining uploads as part of finishing the run.
wandb.finish()