In [None]:
# import os
# import time
# import requests
# import torch
# import torch.nn as nn
# import torch.optim as optim
# from torch.utils.data import DataLoader
# from torchvision import transforms
# from sklearn.model_selection import train_test_split
# from tqdm import tqdm
# import wandb
# from PIL import Image
# from dotenv import load_dotenv

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import wandb
import requests
from dotenv import load_dotenv
import openai
import random
from tqdm import tqdm

In [None]:
# Load environment variables for API keys.
# load_dotenv() copies entries from a .env file (if one is found) into os.environ,
# so later cells can read credentials such as OPENAI_API_KEY with os.getenv().
load_dotenv()

In [None]:
# Fail fast: every real (non-mock) DALL-E cell below needs this key.
OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY')
if OPENAI_API_KEY in (None, ''):
    raise ValueError("OpenAI API key not found in environment variables")

In [None]:
def test_dalle_download():
    """Smoke-test DALL-E end to end: generate an image, download it, save it, verify it.

    The image is written to ``test/dalle_test.png``.

    Returns:
        bool: True if an image URL was produced, downloaded with HTTP 200, and
        accepted by PIL's integrity check; False otherwise (the reason is printed).
    """
    try:
        # Test prompt
        test_prompt = "A simple test image of a blue circle"

        print("Testing DALL-E image generation...")
        image_url = generate_dalle_image(test_prompt, OPENAI_API_KEY)
        if not image_url:
            # The rate-limited generate_dalle_image reports failure by returning None;
            # passing None on to requests.get would only produce a confusing error.
            print("Test failed with error: no image URL was returned")
            return False

        # Create test directory if it doesn't exist
        os.makedirs('test', exist_ok=True)

        # Download the image (bounded wait so a stalled connection cannot hang the cell)
        print("Downloading test image...")
        response = requests.get(image_url, timeout=60)
        if response.status_code != 200:
            print(f"Failed to download image. Status code: {response.status_code}")
            return False

        # Save the image
        test_path = os.path.join('test', 'dalle_test.png')
        with open(test_path, 'wb') as f:
            f.write(response.content)
        print(f"Test successful! Image saved to {test_path}")

        # verify() actually checks the file's integrity (open() alone only reads the
        # header), and the context manager releases the file handle.
        try:
            with Image.open(test_path) as img:
                img.verify()
            print("Image format verified successfully!")
            return True
        except Exception as e:
            print(f"Error verifying image format: {e}")
            return False

    except Exception as e:
        print(f"Test failed with error: {e}")
        return False

In [None]:
# Authenticate with Weights & Biases using a key from the environment (.env is loaded
# above). Credentials must never be hardcoded in a notebook; the key that used to be
# written here should be considered leaked and revoked.
wandb.login(key=os.getenv("WANDB_API_KEY"))



In [None]:
# from openai import OpenAI

# def generate_dalle_image(prompt, api_key):
#     """
#     Generate an image using DALL-E based on the given prompt
    
#     Args:
#         prompt (str): The text prompt to generate the image from
#         api_key (str): OpenAI API key
        
#     Returns:
#         str: URL of the generated image
#     """
#     try:
#         client = OpenAI(api_key=api_key)
        
#         response = client.images.generate(
#             model="dall-e-3",
#             prompt=prompt,
#             size="1024x1024",
#             quality="standard",
#             n=1,
#         )
        
#         return response.data[0].url
#     except Exception as e:
#         print(f"Error generating DALL-E image: {str(e)}")
#         raise



# #USED FOR DALL-E-3 IMAGE GENERATION

In [None]:
import time
from ratelimit import limits, sleep_and_retry

@sleep_and_retry
@limits(calls=50, period=3600)  # Limiting to 50 calls per hour
def generate_dalle_image(prompt, api_key):
    """
    Generate an image using DALL-E based on the given prompt, with rate limiting.

    When the 50-calls-per-hour budget is exhausted, ``sleep_and_retry`` blocks the
    caller until the window resets instead of raising.

    Args:
        prompt (str): The text prompt to generate the image from
        api_key (str): OpenAI API key

    Returns:
        str: URL of the generated image, or None if generation fails
    """
    try:
        # `from openai import OpenAI` is commented out elsewhere in the notebook, so
        # import it here. Without this, the NameError was swallowed by the except
        # below and every call silently returned None. Importing inside the try keeps
        # the "None on failure" contract if an old openai SDK lacks the client class.
        from openai import OpenAI

        client = OpenAI(api_key=api_key)

        response = client.images.generate(
            model="dall-e-2",  # Changed to DALL-E 2 for lower cost
            prompt=prompt,
            size="1024x1024",
            n=1,
        )

        return response.data[0].url
    except Exception as e:
        print(f"Error generating DALL-E image: {str(e)}")
        return None  # Return None instead of raising an exception for graceful failure


In [None]:
def mock_dalle_generation(num_images=5, save_dir='test_images'):
    """
    Development-friendly function that creates and returns paths to test images
    instead of generating new ones through the API.

    Existing files are reused as-is; missing ones are created as plain white
    1024x1024 RGB images.

    Args:
        num_images (int): Number of test images to create/use
        save_dir (str): Directory holding the mock images. Defaults to
            'test_images', the previously hard-coded location.

    Returns:
        list: Paths to test images, ordered by index
    """
    os.makedirs(save_dir, exist_ok=True)
    test_images = []

    for i in range(num_images):
        test_path = os.path.join(save_dir, f'test_image_{i}.jpg')
        if not os.path.exists(test_path):
            # Create a simple placeholder image; PIL infers JPEG from the extension.
            img = Image.new('RGB', (1024, 1024), color='white')
            img.save(test_path)
        test_images.append(test_path)

    return test_images

In [None]:
# # For development with mock data:
# dalle_paths = download_dalle_images(save_dir=dalle_dir, num_images=100, use_mock=True)

In [None]:
# # For production with real DALL-E
# dalle_paths = download_dalle_images(save_dir=dalle_dir, num_images=100, use_mock=False)

In [None]:
def download_dalle_images(save_dir, num_images, use_mock=True, max_failures=3):
    """
    Download or generate test images for the dataset.

    Args:
        save_dir (str): Directory to save images
        num_images (int): Number of images to generate/create
        use_mock (bool): Whether to use mock images for development
        max_failures (int): Give up after this many generations that return no URL,
            so a bad key or exhausted quota cannot spin the loop forever.

    Returns:
        list: Paths to downloaded/generated images. In real mode this may be shorter
        than ``num_images`` if generation keeps failing or a download errors.
    """
    os.makedirs(save_dir, exist_ok=True)

    if use_mock:
        # Use mock generation for development.
        # NOTE(review): mock images are written to mock_dalle_generation's own
        # directory, not to save_dir.
        return mock_dalle_generation(num_images)

    api_key = os.getenv('OPENAI_API_KEY')
    downloaded_paths = []
    failures = 0

    while len(downloaded_paths) < num_images:
        try:
            # Use the rate-limited function
            img_url = generate_dalle_image("stock photo", api_key)
            if img_url is None:
                # A None URL means generation failed. The old bare `continue` never
                # made progress and could retry forever.
                failures += 1
                if failures >= max_failures:
                    print(f"Giving up after {failures} failed image generations")
                    break
                continue

            img_response = requests.get(img_url, timeout=60)
            # Never write an HTTP error body to disk as if it were an image.
            img_response.raise_for_status()
            img_name = os.path.join(save_dir, f"dalle_{len(downloaded_paths)}.jpg")

            with open(img_name, 'wb') as f:
                f.write(img_response.content)

            downloaded_paths.append(img_name)
            time.sleep(1)  # Basic rate limiting

        except Exception as e:
            print(f"Error during image generation/download: {e}")
            break

    return downloaded_paths


In [None]:
# # Add this code block after your generate_dalle_image function definition
# def test_dalle_generation():
#     try:
#         # Load API key from environment variables
#         api_key = os.getenv('OPENAI_API_KEY')
#         if not api_key:
#             raise ValueError("OPENAI_API_KEY not found in environment variables")
            
#         # Test with a simple prompt
#         prompt = "A simple landscape photo"
#         image_url = generate_dalle_image(prompt, api_key)
#         print("DALL-E test successful! Image URL:", image_url)
#         return True
#     except Exception as e:
#         print(f"Testing DALL-E image generation...\nTest failed with error: {str(e)}")
#         print("DALL-E test failed. Please check your API key and connection.")
#         return False

# # Test the function
# if __name__ == "__main__":
#     test_dalle_generation()

In [None]:
# Test DALL-E functionality: this performs a real API request (costs API credits),
# unlike the mock path used for development further below.
print(
    "DALL-E test completed successfully!"
    if test_dalle_download()
    else "DALL-E test failed. Please check your API key and connection."
)

In [None]:
def _openai_exception_classes(name):
    """Return a tuple of exception classes called ``name`` from either openai SDK layout."""
    found = []
    # openai>=1.0 exposes exceptions at top level; openai<1.0 kept them in openai.error.
    # Missing names are skipped so the isinstance() checks never raise AttributeError.
    for namespace in (openai, getattr(openai, "error", None)):
        cls = getattr(namespace, name, None)
        if isinstance(cls, type) and issubclass(cls, BaseException):
            found.append(cls)
    return tuple(found)


def handle_dalle_api_error(error):
    """Classify an exception raised by a DALL-E call and print a user-facing hint.

    Works with both openai>=1.0 and openai<1.0. The former lookup of
    ``openai.error.InsufficientQuotaError`` is not an exception class in either SDK,
    so it failed with AttributeError instead of classifying the error. Quota
    exhaustion is reported as a ``RateLimitError`` whose code is
    ``"insufficient_quota"``.

    Args:
        error (Exception): the exception to classify.

    Returns:
        str: one of "auth_error", "rate_limit", "quota_error", "unknown_error".
    """
    if isinstance(error, _openai_exception_classes("AuthenticationError")):
        print("Error: Invalid API key or authentication failed")
        return "auth_error"

    if isinstance(error, _openai_exception_classes("RateLimitError")):
        is_quota = (getattr(error, "code", None) == "insufficient_quota"
                    or "insufficient_quota" in str(error))
        if is_quota:
            print("Error: Insufficient quota or payment required")
            return "quota_error"
        print("Error: Rate limit exceeded. Please wait before making more requests")
        return "rate_limit"

    print(f"Unexpected error occurred: {str(error)}")
    return "unknown_error"

def generate_dalle_image(prompt, api_key):
    """Generate one 1024x1024 DALL-E 2 image for ``prompt`` and return its URL.

    This redefines the rate-limited ``generate_dalle_image`` from an earlier cell;
    cells run after this one resolve the name to this version, which raises on
    failure instead of returning None.

    Supports the ``openai>=1.0`` client API (``OpenAI().images.generate``) and
    falls back to the legacy module-level ``openai.Image.create``, which was
    removed in 1.0.

    Args:
        prompt (str): text prompt for the image.
        api_key (str): OpenAI API key.

    Returns:
        str: URL of the generated image.

    Raises:
        ValueError: on authentication, quota or rate-limit errors, with a hint.
        Exception: any other error is re-raised unchanged.
    """
    try:
        if hasattr(openai, "OpenAI"):
            client = openai.OpenAI(api_key=api_key)
            response = client.images.generate(
                model="dall-e-2",
                prompt=prompt,
                n=1,
                size="1024x1024"
            )
            return response.data[0].url

        # Legacy (openai<1.0) module-level API.
        openai.api_key = api_key
        response = openai.Image.create(
            prompt=prompt,
            n=1,
            size="1024x1024"
        )
        return response['data'][0]['url']
    except Exception as e:
        error_type = handle_dalle_api_error(e)
        hints = {
            "auth_error": "Please check your OpenAI API key",
            "quota_error": "Please check your OpenAI account billing status",
            "rate_limit": "Rate limit exceeded. Please try again later",
        }
        if error_type in hints:
            # Chain the original exception so the API's own message is not lost.
            raise ValueError(hints[error_type]) from e
        raise

In [None]:
def set_seed(seed):
    """Seed every RNG this notebook draws from, for reproducibility.

    Seeds Python's ``random`` module (used by ``create_datasets`` to shuffle the
    train/validation split) as well as PyTorch's CPU and CUDA generators, and puts
    cuDNN into deterministic mode.

    Args:
        seed (int): seed value applied to all generators.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic kernels and no autotuning: slower, but reproducible on GPU.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


In [None]:
# Device configuration: use the GPU when PyTorch can see one, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

In [None]:
os.makedirs(name='models', exist_ok=True)  # checkpoint directory used by training; no error if it already exists

In [None]:
# Hyperparameters
batch_size = 32
learning_rate = 0.001
num_epochs = 50

# Initialize wandb run before model training. Logging is optional: any failure is
# reported and training continues without it. (The previous version had a missing
# comma in the config and an `except` without a `try`, so the cell did not parse.)
try:
    # Key comes from the environment (.env); never hardcode credentials in a notebook.
    wandb.login(key=os.getenv("WANDB_API_KEY"))
    wandb.init(
        project="real",
        config={
            "learning_rate": learning_rate,
            "architecture": "CustomCNN",
            "batch_size": batch_size,
            "epochs": num_epochs,
            "optimizer": "adam",
            # Training uses nn.CrossEntropyLoss over two logits, not binary cross-entropy.
            "loss_function": "cross_entropy",
            "device": str(device),  # Track which device is used
        },
    )
except Exception as e:
    print(f"Warning: Could not initialize wandb: {e}")
    print("Training will continue without logging")

In [None]:
# Cell 5: Data Preprocessing
# Training pipeline: resize, apply random flip/rotation/colour/crop augmentation,
# then convert to a tensor normalised with the ImageNet channel statistics.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(degrees=10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
# Deterministic evaluation pipeline: same resize and normalisation as training,
# without any random augmentation.
val_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
# # WandB setup
# wandb.init(project="stock-photo-detector", name="training_run_v1")

In [None]:
# Cell 6: Dataset Class
class ImageDataset(Dataset):
    """Map-style dataset pairing image files with labels.

    Args:
        image_paths (sequence of str): paths to image files.
        labels (sequence): one label per path (0/1 elsewhere in this notebook).
        transform (callable, optional): applied to the RGB PIL image.

    Raises:
        ValueError: if ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, transform=None):
        # Catch misaligned inputs up front instead of failing mid-epoch on an index.
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths ({len(image_paths)}) and labels ({len(labels)}) "
                "must have the same length"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply ``transform``, and return ``(image, label)``."""
        image_path = self.image_paths[idx]
        try:
            if not os.path.exists(image_path):
                raise FileNotFoundError(f"Image not found: {image_path}")

            # The context manager closes the file as soon as the pixels are decoded.
            # convert() returns an independent image, so it stays valid afterwards.
            with Image.open(image_path) as img:
                image = img.convert('RGB')
            if self.transform:
                image = self.transform(image)
            return image, self.labels[idx]
        except Exception as e:
            print(f"Error loading image {image_path}: {str(e)}")
            raise

In [None]:
# Cell 7: Create Datasets Function
def create_datasets(transform, eval_transform=None, seed=None):
    """Create training and validation datasets from the real/AI image folders.

    Images in ``./data/real_images`` get label 0 and images in ``./data/ai_images``
    get label 1. The combined list is shuffled and split 80% train / 20% validation.

    Args:
        transform: augmentation pipeline applied to the training split.
        eval_transform: deterministic pipeline for the validation split. Defaults to
            the notebook-level ``val_transform`` so validation metrics are not computed
            on randomly augmented images (the training ``transform`` was reused before).
        seed (int, optional): seed for the shuffle. ``None`` uses the module-level
            ``random`` state, as before.

    Returns:
        tuple: ``(train_dataset, val_dataset)`` as ``ImageDataset`` instances.

    Raises:
        FileNotFoundError: if either image directory does not exist.
    """
    if eval_transform is None:
        eval_transform = val_transform

    # Paths to your image directories
    real_dir = './data/real_images'
    ai_dir = './data/ai_images'
    extensions = ('.jpg', '.png', '.jpeg')

    def _collect(directory, label):
        # Pair each image file in `directory` with `label`. sorted() removes the
        # arbitrary os.listdir order, so a fixed seed gives the same split on every
        # machine. lower() also accepts upper-case extensions such as '.JPG'.
        return [
            (os.path.join(directory, name), label)
            for name in sorted(os.listdir(directory))
            if name.lower().endswith(extensions)
        ]

    # Combine and shuffle
    all_images = _collect(real_dir, 0) + _collect(ai_dir, 1)
    if seed is None:
        random.shuffle(all_images)
    else:
        random.Random(seed).shuffle(all_images)

    # Split into train and validation
    split_idx = int(len(all_images) * 0.8)  # 80% train, 20% validation
    train_data = all_images[:split_idx]
    val_data = all_images[split_idx:]

    train_dataset = ImageDataset(
        image_paths=[path for path, _ in train_data],
        labels=[label for _, label in train_data],
        transform=transform
    )

    val_dataset = ImageDataset(
        image_paths=[path for path, _ in val_data],
        labels=[label for _, label in val_data],
        transform=eval_transform
    )

    return train_dataset, val_dataset

In [None]:
# Cell 8: Create DataLoaders
# At most 4 loader worker processes; os.cpu_count() can return None, hence `or 1`.
# NOTE(review): where workers are spawned (Windows/macOS), classes defined inside a
# notebook may not be importable by the workers; use num_workers=0 if loading fails.
num_workers = min(4, os.cpu_count() or 1)

train_dataset, val_dataset = create_datasets(transform)

# Training batches are reshuffled every epoch; validation keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

In [None]:
# Cell 9: Model Definition
class CustomCNN(nn.Module):
    """Three-stage CNN that classifies 224x224 RGB images into two classes.

    Each stage is a 3x3 conv (padding 1) -> ReLU -> 2x2 max-pool, so a 224x224
    input shrinks to 28x28 with 64 channels. That is flattened into the
    ``64 * 28 * 28`` inputs of ``fc1``.
    """

    def __init__(self):
        super(CustomCNN, self).__init__()
        # Network expects 224x224 images due to the transform
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 28 * 28, 512)
        self.fc2 = nn.Linear(512, 2)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Map a ``(batch, 3, 224, 224)`` tensor to ``(batch, 2)`` logits."""
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.pool(self.relu(self.conv3(x)))
        # Flatten per sample, keeping the batch dimension. The old
        # view(-1, 64*28*28) could silently merge several wrongly-sized images into
        # one row (e.g. four 112x112 images became one sample). Now a size mismatch
        # fails loudly in fc1.
        x = torch.flatten(x, 1)
        x = self.dropout(self.relu(self.fc1(x)))
        x = self.fc2(x)
        return x

In [None]:
# --- Image Download Functions ---
def download_unsplash_photos(client_id, save_dir, num_photos, timeout=30):
    """Download random landscape "stock photo" images from the Unsplash API.

    Args:
        client_id (str): Unsplash access key, sent as ``Client-ID``.
        save_dir (str): directory the ``<photo id>.jpg`` files are written to.
        num_photos (int): number of photos to request.
        timeout (float): per-request timeout in seconds, so a stalled connection
            cannot hang the notebook.

    Returns:
        list: paths of the files written. This may hold fewer than ``num_photos`` if
        a request fails, the API returns an empty page, or photos repeat.
    """
    os.makedirs(save_dir, exist_ok=True)
    url = "https://api.unsplash.com/photos/random"
    headers = {"Authorization": f"Client-ID {client_id}", "Accept-Version": "v1"}

    downloaded_paths = []
    seen_ids = set()
    photos_to_download = num_photos

    while photos_to_download > 0:
        try:
            # Request at most 30 photos per call (the cap this code has always used).
            count = min(30, photos_to_download)
            params = {"count": count, "query": "stock photo", "orientation": "landscape"}
            response = requests.get(url, headers=headers, params=params, timeout=timeout)
            response.raise_for_status()
            photos = response.json()
            if not photos:
                # An empty page would leave photos_to_download unchanged and loop forever.
                print("Unsplash returned no photos; stopping early")
                break

            for photo in tqdm(photos, desc="Downloading photos"):
                if photo['id'] in seen_ids:
                    # Random results may repeat across requests. A duplicate would
                    # overwrite the same file and appear twice in the returned list
                    # (and could land in both the train and test splits).
                    continue
                img_url = photo['urls']['regular']
                img_response = requests.get(img_url, timeout=timeout)
                img_response.raise_for_status()

                img_name = os.path.join(save_dir, f"{photo['id']}.jpg")
                with open(img_name, 'wb') as f:
                    f.write(img_response.content)

                seen_ids.add(photo['id'])
                downloaded_paths.append(img_name)

            # Count the whole batch, duplicates included, so the loop always terminates.
            photos_to_download -= len(photos)
            time.sleep(1)  # Respect API rate limits

        except requests.RequestException as e:
            print(f"Error downloading photo: {e}")
            break

    return downloaded_paths


In [None]:
# def download_dalle_images(save_dir, num_images):
#     os.makedirs(save_dir, exist_ok=True)
#     downloaded_paths = []
#     images_to_download = num_images
    
#     client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))
    
#     while images_to_download > 0:
#         try:
#             response = client.images.generate(
#                 model="dall-e-2", prompt="stock photo", n=1, size="1024x1024"
#             )
#             img_url = response.data[0].url
#             img_data = requests.get(img_url).content
#             img_name = os.path.join(save_dir, f"dalle_{len(downloaded_paths)}.jpg")
#             with open(img_name, 'wb') as f:
#                 f.write(img_data)
#             downloaded_paths.append(img_name)
#             images_to_download -= 1
#             time.sleep(1)
#         except Exception as e:
#             print(f"Error during image generation/download: {e}")
#             break
#     return downloaded_paths

In [None]:
# --- Model Architecture ---
class StockPhotoDetector(nn.Module):
    """Binary image classifier: real stock photo (0) vs AI-generated (1).

    Three conv(3x3, padding 1) -> ReLU -> max-pool(2) stages (3->16->32->64
    channels) feed a two-layer MLP head with dropout. The head's input size
    (64 * 28 * 28) assumes 224x224 inputs.
    """

    def __init__(self):
        super().__init__()

        # Build the convolutional trunk stage by stage. Module order, and therefore
        # the state_dict keys conv_layers.0/.3/.6, matches a hand-written Sequential.
        trunk = []
        for in_ch, out_ch in ((3, 16), (16, 32), (32, 64)):
            trunk.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ])
        self.conv_layers = nn.Sequential(*trunk)

        # Classification head: 2 logits (Real vs AI-generated)
        self.fc_layers = nn.Sequential(
            nn.Linear(64 * 28 * 28, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 2),
        )

    def forward(self, x):
        """Return two-class logits for a batch of images."""
        features = self.conv_layers(x)
        flat = features.view(features.size(0), -1)
        return self.fc_layers(flat)

In [25]:
# Cell 10: Training Functions
def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Puts the model in eval mode and shows a transient tqdm bar with running
    loss and accuracy.

    Returns:
        tuple: (mean of per-batch losses, accuracy in percent).
    """
    model.eval()
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    progress = tqdm(val_loader, leave=False, desc="Validation")

    with torch.no_grad():
        for batch_inputs, batch_labels in progress:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_inputs)
            batch_loss = criterion(logits, batch_labels)

            loss_sum += batch_loss.item()
            predictions = logits.max(1).indices
            n_seen += batch_labels.size(0)
            n_correct += (predictions == batch_labels).sum().item()

            progress.set_postfix({
                'loss': f'{batch_loss.item():.4f}',
                'acc': f'{100.*n_correct/n_seen:.2f}%'
            })

    return loss_sum / len(val_loader), 100. * n_correct / n_seen

def _train_one_epoch(model, train_loader, criterion, optimizer, device, epoch):
    """Run one optimisation pass over ``train_loader``; return (mean batch loss, accuracy %)."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    # Per-batch progress bar; leave=False so only the epoch-level bar remains.
    batch_pbar = tqdm(train_loader, leave=False, desc=f"Epoch {epoch+1}")

    for inputs, labels in batch_pbar:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        batch_pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{100.*correct/total:.2f}%'
        })

    if total == 0:
        raise ValueError("train_loader yielded no samples")

    return running_loss / len(train_loader), 100. * correct / total


def train_model(model, train_loader, val_loader, criterion=None, optimizer=None,
                num_epochs=10, device=None, patience=5,
                checkpoint_path='models/best_model.pth'):
    """Train ``model`` with per-epoch validation, early stopping and wandb logging.

    Args:
        model: classifier producing one logit per class along dim 1.
        train_loader: DataLoader of ``(inputs, labels)`` batches for training.
        val_loader: DataLoader evaluated with ``validate`` after every epoch.
        criterion: loss function. Defaults to ``nn.CrossEntropyLoss()``.
        optimizer: optimizer over ``model``'s parameters. Defaults to Adam with
            lr=1e-3, created after the model has been moved to ``device``.
        num_epochs: maximum number of epochs.
        device: training device. Defaults to CUDA when available, else CPU. The
            model is moved there before training.
        patience: stop after this many consecutive epochs without a new lowest
            validation loss.
        checkpoint_path: where the best checkpoint (lowest validation loss) is
            written. The checkpoint includes the ``early_stopping_counter`` key
            that ``load_checkpoint`` reads.

    Returns:
        The trained model. Its weights are from the last epoch that ran, not
        necessarily the best checkpoint.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Module.to() moves parameters in place, so an optimizer the caller built
    # beforehand still references the same tensors.
    model = model.to(device)
    if criterion is None:
        criterion = nn.CrossEntropyLoss()
    if optimizer is None:
        optimizer = optim.Adam(model.parameters(), lr=1e-3)

    early_stopping_counter = 0
    best_val_loss = float('inf')

    # Reuse a run the caller already started (Cell 11 does this). Calling wandb.init
    # unconditionally would replace that run and split the logs across runs.
    if wandb.run is None:
        wandb.init(
            project="real",
            config={
                "learning_rate": optimizer.param_groups[0]['lr'],
                "architecture": type(model).__name__,
                "batch_size": train_loader.batch_size,
                "epochs": num_epochs,
                "optimizer": optimizer.__class__.__name__,
                "loss_function": criterion.__class__.__name__
            }
        )

    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    epoch_pbar = tqdm(range(num_epochs), desc="Training Progress")

    for epoch in epoch_pbar:
        train_loss, train_acc = _train_one_epoch(
            model, train_loader, criterion, optimizer, device, epoch
        )
        val_loss, val_acc = validate(model, val_loader, criterion, device)

        wandb.log({
            "epoch": epoch,
            "train_loss": train_loss,
            "train_accuracy": train_acc,
            "val_loss": val_loss,
            "val_accuracy": val_acc,
            "learning_rate": optimizer.param_groups[0]['lr']
        })

        # Update before the early-stopping break so the final epoch is displayed too.
        epoch_pbar.set_postfix({
            'train_loss': f'{train_loss:.4f}',
            'train_acc': f'{train_acc:.2f}%',
            'val_loss': f'{val_loss:.4f}',
            'val_acc': f'{val_acc:.2f}%'
        })

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            early_stopping_counter = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'early_stopping_counter': early_stopping_counter,
            }, checkpoint_path)
            wandb.save(checkpoint_path)
        else:
            early_stopping_counter += 1

        if early_stopping_counter >= patience:
            print(f'\nEarly stopping triggered after {epoch + 1} epochs')
            break

    # The wandb run is left open on purpose: callers finish it (see Cell 11).
    return model

    
# Cell 11: Training Execution
# Set seed for reproducibility
# NOTE(review): create_datasets already shuffled and split the data in Cell 8, before
# this call, so the train/validation split itself is not governed by this seed.
set_seed(42)

# Clear CUDA cache if available
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Initialize model, criterion, optimizer
model = CustomCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

try:
    # Starts a run in project "real-vs-ai-classifier". NOTE(review): the
    # hyperparameter cell may already have started a run in project "real", and
    # train_model calls wandb.init itself. Confirm which run should get the metrics.
    wandb.init(project="real-vs-ai-classifier", config={
        "epochs": num_epochs,
        "batch_size": batch_size,
        "learning_rate": learning_rate,
        "architecture": "CustomCNN",
        "dataset": "Unsplash+DALLE"
    })
    
    # Train the model
    model = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device)
    
    # Save the final model (last-epoch weights) to the working directory. The best
    # checkpoint is written by train_model under models/.
    torch.save(model.state_dict(), 'final_model.pth')
    
finally:
    # Always close the wandb run, even if training raises.
    wandb.finish()

print("Training completed!")

# Optional Cell 12: Load Best Model
def load_best_model(checkpoint_path='models/best_model.pth'):
    """Rebuild a ``CustomCNN`` from the best checkpoint saved during training.

    ``train_model`` writes its best checkpoint to ``models/best_model.pth``. The old
    hard-coded ``best_model.pth`` is not written anywhere in this notebook.

    Args:
        checkpoint_path (str): path of a checkpoint dict containing ``model_state_dict``.

    Returns:
        CustomCNN: model on ``device`` with the checkpoint's weights loaded.
    """
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model = CustomCNN().to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    return model

# Test data loading and model evaluation
sample_batch, sample_labels = next(iter(train_loader))
print(f"Batch shape: {sample_batch.shape}")
print(f"Labels shape: {sample_labels.shape}")

# Test model forward pass. eval() disables dropout, and no_grad() avoids building an
# autograd graph (and holding its activations) for this throwaway check.
model.eval()
with torch.no_grad():
    sample_output = model(sample_batch.to(device))
print(f"Output shape: {sample_output.shape}")

# Check save directory
print(f"Save directory contents: {os.listdir('models')}")

# Finish wandb run (no-op if no run is active)
wandb.finish()

In [None]:
def load_checkpoint(checkpoint_path, model, optimizer, map_location=None):
    """Restore model and optimizer state from a training checkpoint.

    Args:
        checkpoint_path (str): file written with ``torch.save`` containing
            ``epoch``, ``model_state_dict`` and ``optimizer_state_dict``.
        model: module whose weights are overwritten in place.
        optimizer: optimizer whose state is overwritten in place.
        map_location: forwarded to ``torch.load``, e.g. ``'cpu'`` to load a
            GPU-saved checkpoint on a CPU-only machine.

    Returns:
        tuple: ``(epoch, early_stopping_counter)``. The counter defaults to 0 when
        the checkpoint does not record it. ``train_model`` saves only on a new best
        validation loss, which is exactly when it resets the counter to 0.
    """
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'], checkpoint.get('early_stopping_counter', 0)

In [None]:
def evaluate_model(model, test_loader, device=None):
    """Compute, print and return the classification accuracy of ``model``.

    Args:
        model: classifier producing one logit per class along dim 1.
        test_loader: iterable of ``(images, labels)`` batches.
        device: where each batch is moved. Defaults to the device of the model's
            parameters, so inputs always sit next to the weights.

    Returns:
        float: accuracy as a fraction in [0, 1].

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("test_loader yielded no samples")

    accuracy = correct / total
    print(f"Test Accuracy: {accuracy:.4f}")
    return accuracy

In [None]:
# # --- Run the download and train pipeline ---
# unsplash_dir = 'unsplash_photos'
# dalle_dir = 'dalle_images'
# unsplash_paths = download_unsplash_photos(client_id="YOUR_UNSPLASH_CLIENT_ID", save_dir=unsplash_dir, num_photos=100)
# dalle_paths = download_dalle_images(save_dir=dalle_dir, num_images=100)


In [None]:
# --- Run the download and train pipeline ---
unsplash_dir = 'unsplash_photos'
dalle_dir = 'dalle_images'

# Download real photos from Unsplash. The access key comes from the environment
# (.env) rather than the old hardcoded placeholder, which could never authenticate.
unsplash_client_id = os.getenv('UNSPLASH_ACCESS_KEY')
if not unsplash_client_id:
    raise ValueError("Unsplash access key not found in environment variables (UNSPLASH_ACCESS_KEY)")
unsplash_paths = download_unsplash_photos(client_id=unsplash_client_id, save_dir=unsplash_dir, num_photos=100)

# For development phase: Use mock DALL-E images to avoid API costs
dalle_paths = download_dalle_images(save_dir=dalle_dir, num_images=100, use_mock=True)

# Later, when ready for production, you can switch to:
# dalle_paths = download_dalle_images(save_dir=dalle_dir, num_images=100, use_mock=False)

In [None]:
# Combine paths and create labels: 0 = Unsplash (real photo), 1 = DALL-E (AI-generated)
image_paths = [*unsplash_paths, *dalle_paths]
labels = [0 for _ in unsplash_paths] + [1 for _ in dalle_paths]  # 0: Unsplash, 1: DALL-E

In [None]:
# Split into train/val/test sets: roughly 72% train, 8% validation, 20% test.
# train_test_split is only imported in the commented-out cell at the top of the
# notebook, so import it here; otherwise this cell raises NameError.
from sklearn.model_selection import train_test_split

train_paths, test_paths, train_labels, test_labels = train_test_split(image_paths, labels, test_size=0.2, random_state=42)
train_paths, val_paths, train_labels, val_labels = train_test_split(train_paths, train_labels, test_size=0.1, random_state=42)

In [None]:
# Create datasets and dataloaders.
# Only the training split is augmented. Validation and test use the deterministic
# val_transform so their metrics are not measured on randomly perturbed images.
# (The previously referenced `transform_augment` is not defined in this notebook.)
train_dataset = ImageDataset(train_paths, train_labels, transform=transform)
val_dataset = ImageDataset(val_paths, val_labels, transform=val_transform)
test_dataset = ImageDataset(test_paths, test_labels, transform=val_transform)

# Data loaders. os.cpu_count() may return None, so fall back to a single worker.
num_workers = min(4, os.cpu_count() or 1)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=num_workers)

In [None]:
# NOTE(review): these loaders replace the train/val loaders built in the previous
# cell and drop its num_workers setting (single-process loading). Confirm which
# configuration is intended. test_loader is defined only here.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
test_loader = DataLoader(test_dataset, batch_size=32)

In [None]:
# Initialize and train model.
# Pass criterion, optimizer and device explicitly so the call matches train_model's
# full signature (num_epochs alone leaves required arguments missing), and move the
# model to the training device first.
model = StockPhotoDetector().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
model = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10, device=device)
# Evaluate on test set
evaluate_model(model, test_loader)

In [None]:
# Test data loading
sample_batch, sample_labels = next(iter(train_loader))
print(f"Batch shape: {sample_batch.shape}")
print(f"Labels shape: {sample_labels.shape}")

# Test model forward pass in inference mode: eval() disables dropout, and
# no_grad() skips building an autograd graph for this throwaway check.
model.eval()
with torch.no_grad():
    sample_output = model(sample_batch.to(device))
print(f"Output shape: {sample_output.shape}")

# Check save directory
print(f"Save directory contents: {os.listdir('models')}")

In [None]:
# WandB setup
# NOTE(review): this starts a new run after training and evaluation have already
# finished. Nothing after it in the visible notebook logs to it or calls
# wandb.finish(), and an identical call is commented out in the preprocessing
# section. Confirm whether this cell is still needed.
wandb.init(project="stock-photo-detector", name="training_run_v1")