# CNN Hyperparameter Optimization with Reinforcement Learning

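This notebook runs the CNN-with-RL project in Google Colab. A PPO reinforcement-learning agent (Stable-Baselines3) tunes the hyperparameters of a CNN image classifier, and metrics are logged to Weights & Biases.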

## Setup

In [None]:
# Install required packages
!pip install wandb gymnasium stable-baselines3 rich tqdm

# Clone the repository
!git clone https://github.com/YOUR_USERNAME/CNN-with-RL.git
!cd CNN-with-RL

In [None]:
# Make files stored in Google Drive visible to this Colab runtime.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

# Locations inside the mounted Drive. Replace these placeholders with your own.
DATASET_PATH = "/content/drive/MyDrive/your_dataset_folder"  # passed as images_dir to the data loaders
CSV_PATH = "/content/drive/MyDrive/your_csv_file.csv"  # passed as csv_path to the data loaders

In [None]:
# Authenticate with Weights & Biases (may prompt for an API key in Colab).
import os

import wandb

wandb.login()

# Project and run names for this experiment.
WANDB_PROJECT = "cnn-with-rl"
EXPERIMENT_NAME = "colab-run-1"

# WANDB_PROJECT was previously defined but never passed to wandb, so runs went to
# the account's default project. wandb.init() uses the WANDB_PROJECT environment
# variable as its default project, so exporting it routes runs started later
# (presumably by the trainer when use_wandb=True) into this project, unless
# they pass an explicit project= argument.
os.environ["WANDB_PROJECT"] = WANDB_PROJECT

## Run Training

In [None]:
import sys

# Make the cloned repository's `src` package importable.
sys.path.append('/content/CNN-with-RL')

from src.data.fundus_dataset import get_fundus_data_loaders
from src.models.cnn_model import FlexibleCNN
from src.training.trainer import ModelTrainer
from src.rl.hpo_env import HPOEnvironment
from stable_baselines3 import PPO
import torch
import numpy as np

# Shared settings. The CNN's output head and the HPO environment must agree on
# the number of classes, so keep a single source of truth instead of repeating 8.
NUM_CLASSES = 8
BATCH_SIZE = 32

# Build the train/validation DataLoaders from the CSV and image folder on Drive.
train_loader, val_loader = get_fundus_data_loaders(
    csv_path=CSV_PATH,
    images_dir=DATASET_PATH,
    batch_size=BATCH_SIZE,
)

# Use the GPU when the runtime has one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = FlexibleCNN(num_classes=NUM_CLASSES).to(device)

# NOTE(review): use_wandb=True presumably makes the trainer log to Weights & Biases
# — confirm against ModelTrainer.
trainer = ModelTrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    use_wandb=True,
)

# Create the Gymnasium-style HPO environment that the RL agent will interact with.
env = HPOEnvironment(
    trainer=trainer,
    train_loader=train_loader,
    val_loader=val_loader,
    num_classes=NUM_CLASSES,
    experiment_name=EXPERIMENT_NAME,
    dtype=np.float32,
)

# Entropy-coefficient schedule for the PPO agent: linear decay with a floor.
def exploration_schedule(progress):
    """Return the entropy coefficient for a given point in training.

    The value is ``0.5 * (1 - progress)``, never allowed to drop below 0.05.
    ``progress`` is intended as the fraction of training elapsed: 0.0 gives 0.5,
    and the coefficient shrinks toward the 0.05 floor as it grows.

    Note: Stable-Baselines3 calls its own schedules with progress *remaining*
    (1 -> 0). Passing this function directly as an SB3 schedule would therefore
    run the decay backwards.
    """
    decayed = 0.5 * (1 - progress)
    return decayed if decayed > 0.05 else 0.05

from stable_baselines3.common.callbacks import BaseCallback

TOTAL_TIMESTEPS = 40960


class EntropyScheduleCallback(BaseCallback):
    """Refresh PPO's entropy coefficient from a schedule before each update.

    SB3's PPO stores ``ent_coef`` as a plain float and computes
    ``ent_coef * entropy_loss`` in ``train()``. Passing a callable as
    ``ent_coef`` therefore fails at the first gradient step. This callback
    keeps ``model.ent_coef`` a float and overwrites it at the end of every
    rollout, which is right before ``train()`` runs.

    Args:
        schedule: callable mapping fraction-of-training-elapsed (0 -> 1) to a
            float coefficient.
        total_timesteps: the ``total_timesteps`` passed to ``learn()``.
    """

    def __init__(self, schedule, total_timesteps):
        super().__init__()
        self.schedule = schedule
        self.total_timesteps = total_timesteps

    def _on_rollout_end(self) -> None:
        # Elapsed fraction (0 -> 1), matching exploration_schedule's argument.
        # SB3's own schedules receive progress *remaining* instead.
        progress = min(1.0, self.model.num_timesteps / self.total_timesteps)
        self.model.ent_coef = float(self.schedule(progress))

    def _on_step(self) -> bool:
        # Never stop training early.
        return True


rl_model = PPO(
    "MlpPolicy",
    env,
    verbose=1,
    device=device,  # reuse the device chosen above instead of hard-coding 'cuda'
    n_steps=2048,
    batch_size=64,
    n_epochs=10,
    learning_rate=3e-4,
    # Must be a float; the initial value comes from the schedule and the
    # callback decays it after each rollout.
    ent_coef=float(exploration_schedule(0.0)),
    clip_range=0.2,
    policy_kwargs=dict(
        net_arch=dict(pi=[128, 128], vf=[128, 128]),
        log_std_init=-2.0,
        ortho_init=True,
    ),
    use_sde=True,
    sde_sample_freq=4,
)

try:
    # Train for 40960 steps
    rl_model.learn(
        total_timesteps=TOTAL_TIMESTEPS,
        callback=EntropyScheduleCallback(exploration_schedule, TOTAL_TIMESTEPS),
    )

    # Save the trained model (SB3 appends ".zip")
    rl_model.save("best_hpo_model")
finally:
    # Release environment resources even if training fails part-way.
    env.close()

## Analyze Results

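Check the Weights & Biases (wandb) dashboard for training metrics and visualizations. The trained PPO agent is saved as `best_hpo_model.zip` in the notebook's working directory. Colab's local storage is wiped when the runtime resets, so copy the file to Google Drive if you want to keep it.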