# Colab training notebook

This notebook runs the repository's existing training code on Google Colab (GPU runtime). It installs the required packages, optionally mounts Google Drive for the dataset and checkpoints, and calls the training entrypoint in `src/train.py` directly, so the training code itself stays unchanged.

Notes:
- Make sure your dataset folder contains `dr_labels.csv` and a `DR_images/` subfolder.
- You can either upload the `data/` folder to Colab session storage, or mount Google Drive and point `DATA_DIR` (the `--data-dir` option) at a folder on Drive.
- If you prefer to run from your own GitHub repo, push this workspace to a public GitHub repository and update the URL in the `git clone` cell below.
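The layout in the first note can be checked before any training starts. The helper below is a minimal sketch: `check_dataset_layout` is hypothetical (not part of the repository), and it only tests for the two names listed above. It does not open `dr_labels.csv`, because the CSV's column names are not documented here.

```python
from pathlib import Path


def check_dataset_layout(data_dir):
    """Return a list of layout problems in data_dir; an empty list means it looks OK.

    Hypothetical helper: checks only for dr_labels.csv and a non-empty DR_images/ folder.
    """
    root = Path(data_dir)
    if not root.is_dir():
        return [f"data directory not found: {root}"]

    problems = []
    if not (root / "dr_labels.csv").is_file():
        problems.append("missing dr_labels.csv")

    images = root / "DR_images"
    if not images.is_dir():
        problems.append("missing DR_images/ subfolder")
    elif not any(images.iterdir()):
        problems.append("DR_images/ is empty")
    return problems
```

For example, after mounting Drive: `print(check_dataset_layout('/content/drive/MyDrive/DR dataset') or 'layout OK')`.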

In [None]:
# Install required packages (run once)
# Runtime: select 'Runtime' -> 'Change runtime type' -> Hardware accelerator: GPU
# Colab GPU images usually ship with a CUDA-enabled PyTorch preinstalled; in that case pip reports
# "Requirement already satisfied" and this line changes nothing. Note that --extra-index-url adds the
# CUDA 11.7 (cu117) wheel index alongside PyPI rather than replacing it, so pip may still pick a newer PyPI build.
!pip install -q torch torchvision --extra-index-url https://download.pytorch.org/whl/cu117
# requirements_colab.txt lives in the repository, so it is installed in the next cell, after cloning.

In [None]:
# Option 1 (recommended): clone the GitHub repo into the Colab session.
# This clones https://github.com/Ojasvsakhi/Diabetic-Retinopathy into /content/Diabetic-Retinopathy
!rm -rf /content/Diabetic-Retinopathy  # remove any previous copy
!git clone https://github.com/Ojasvsakhi/Diabetic-Retinopathy.git /content/Diabetic-Retinopathy
%cd /content/Diabetic-Retinopathy
# Install the remaining requirements now that requirements_colab.txt is available
!pip install -r requirements_colab.txt --quiet

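# Option 2: upload the repository as a zip instead (comment out Option 1, then %cd into the extracted folder).
# from google.colab import files
# uploaded = files.upload()
# !unzip uploaded_repo.zip -d .  # replace uploaded_repo.zip with the name of the file you uploaded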

# Option 3: mount Google Drive to read the dataset and persist checkpoints
# (the training cell below uses the DATA_DIR defined here)
import os
from google.colab import drive
drive.mount('/content/drive')
# Assumes the dataset sits in My Drive in a folder named exactly 'DR dataset'
DATA_DIR = '/content/drive/MyDrive/DR dataset'
if not os.path.exists(DATA_DIR):
    print('Warning: expected DATA_DIR not found:', DATA_DIR)
else:
    print('Using DATA_DIR =', DATA_DIR)
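# Create a Drive-backed checkpoints folder and symlink it into the repo so saved models persist
DRIVE_CKPT_DIR = os.path.join('/content/drive/MyDrive', 'Diabetic-Retinopathy-checkpoints')
os.makedirs(DRIVE_CKPT_DIR, exist_ok=True)
repo_ckpt = os.path.join(os.getcwd(), 'checkpoints')  # assumes the cwd is the repo root (see %cd above)
if os.path.islink(repo_ckpt):
    # Re-running the cell: the link from a previous run is still there
    print('checkpoints already linked ->', os.readlink(repo_ckpt))
elif os.path.exists(repo_ckpt):
    print('checkpoints/ already exists as a regular folder; files saved there will NOT persist to Drive:', repo_ckpt)
else:
    try:
        os.symlink(DRIVE_CKPT_DIR, repo_ckpt)
        print('Created symlink for checkpoints ->', DRIVE_CKPT_DIR)
    except OSError as e:
        print('Symlink failed; creating a local checkpoints folder (not persisted to Drive). Error:', e)
        os.makedirs(repo_ckpt, exist_ok=True)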

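Colab session storage is wiped when the runtime resets, so before a long run it is worth confirming that `checkpoints` really resolves to the Drive folder. A minimal sketch, assuming the training code writes its checkpoints under `checkpoints/` as the cell above implies; `persists_to` is a hypothetical helper, not part of the repository.

```python
import os


def persists_to(link_path, target_dir):
    """True if link_path resolves (following symlinks) to target_dir or a path inside it."""
    resolved = os.path.realpath(link_path)
    target = os.path.realpath(target_dir)
    # commonpath compares whole path components, so '/a/bc' is not treated as inside '/a/b'
    return os.path.commonpath([resolved, target]) == target
```

In the notebook this would be called as `persists_to(os.path.join(os.getcwd(), 'checkpoints'), DRIVE_CKPT_DIR)`; a `False` result means checkpoints would stay in session storage.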
In [None]:
# Run training through the same entrypoint used locally. Set data_dir (the --data-dir option) to your data location.
# Examples:
# Data on Drive under MyDrive/datasets/DR: data_dir='/content/drive/MyDrive/datasets/DR'
# Data uploaded to the Colab session under ./data: data_dir='data'

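import os
import sys

# Put the repo root and src/ on sys.path so that `src.train` and any sibling
# imports inside src/ both resolve (assumes the cwd is the repo root)
repo_root = os.getcwd()
repo_src = os.path.join(repo_root, 'src')
for p in (repo_src, repo_root):
    if p not in sys.path:
        sys.path.insert(0, p)

# Build the arguments that `python -m src.train` would otherwise parse from the command line.
# The Namespace must define every attribute train() reads; add any extra options src/train.py expects.
from argparse import Namespace
from src.train import train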

args = Namespace(
    data_dir=DATA_DIR,  # using your Drive folder 'DR dataset'
    epochs=5,
    batch_size=16,
    img_size=224,
    lr=1e-4,
    num_workers=2
)

# Kick off training
train(args)
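Building the `Namespace` by hand skips argument parsing, which is why it must carry every attribute `train()` reads. The sketch below shows the correspondence, assuming the command-line interface uses dashed flags such as `--data-dir` (as the comments in the training cell suggest): `argparse` turns `--data-dir` into the attribute `data_dir`, so parsing a flag list yields an equivalent object. `build_parser` and its defaults are hypothetical; they mirror the values used above, not necessarily those defined in `src/train.py`.

```python
import argparse


def build_parser():
    # Hypothetical parser: flag names and defaults mirror the Namespace in the training cell.
    p = argparse.ArgumentParser(description="DR training (sketch)")
    p.add_argument("--data-dir", default="data")  # stored as args.data_dir
    p.add_argument("--epochs", type=int, default=5)
    p.add_argument("--batch-size", type=int, default=16)
    p.add_argument("--img-size", type=int, default=224)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--num-workers", type=int, default=2)
    return p


# Equivalent of: python -m src.train --data-dir "/content/drive/MyDrive/DR dataset" --epochs 10
args = build_parser().parse_args(["--data-dir", "/content/drive/MyDrive/DR dataset", "--epochs", "10"])
print(args.data_dir, args.epochs, args.batch_size)  # → /content/drive/MyDrive/DR dataset 10 16
```

If the real flags match, the parsed object can be passed to `train(args)` in place of the hand-built `Namespace`.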