# 01. Dataset Preparation

This notebook handles the downloading and preparation of the **4Weed Dataset** for YOLOv5 training.

## Objectives
1. Download the dataset from OSF.
2. Organize the data into the structure required by YOLOv5.
3. Split the data into Training and Validation sets.

In [None]:
!pip install osfclient pandas scikit-learn tqdm matplotlib opencv-python

In [None]:
import os
import shutil
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import glob

## 1. Download Dataset
We use `osfclient` to download the dataset from [OSF 4Weed Project](https://osf.io/w9v3j/).

**Note**: If the automated download fails because of authentication or API changes, download the `4Weed Dataset` zip manually from the link above and extract it into `data/raw_download/`. The class folders (`cocklebur`, `foxtail`, etc.) will then be found by the recursive image search in Section 3.
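If you downloaded the archive manually, it can be unpacked with the standard library. This is a minimal sketch: the archive name `data/4Weed_Dataset.zip` is a hypothetical placeholder (use whatever file name OSF gives you), and `extract_manual_download` is an illustrative helper, not part of `osfclient`.

```python
import os
import zipfile

def extract_manual_download(zip_path, dest_dir='data/raw_download'):
    """Unpack a manually downloaded archive into dest_dir.

    Returns the number of image files found under dest_dir afterwards,
    so an empty or wrong archive is noticed immediately.
    """
    os.makedirs(dest_dir, exist_ok=True)
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(dest_dir)
    image_exts = ('.jpg', '.jpeg', '.png', '.bmp')
    count = 0
    for _root, _dirs, files in os.walk(dest_dir):
        count += sum(f.lower().endswith(image_exts) for f in files)
    return count

# Hypothetical archive name: replace with the file you actually downloaded
zip_path = 'data/4Weed_Dataset.zip'
if os.path.exists(zip_path):
    print("Images extracted:", extract_manual_download(zip_path))
```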

In [None]:
# Create data directory
os.makedirs('data', exist_ok=True)

# Attempt to download using osfclient cli
# Project ID: w9v3j
print("Downloading dataset... this may take a while.")
!osf -p w9v3j clone data/raw_download

print("Download step finished. Check the output above for errors.")

## 2. Organize Structure

We need to organize the files into:
```
datasets/4weed/
├── images/
│   ├── train/
│   └── val/
└── labels/
    ├── train/
    └── val/
```
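The two trees must mirror each other because YOLOv5 locates each image's label by replacing the `images` directory component with `labels` and the file extension with `.txt`. The sketch below re-implements that convention for illustration (it is modeled on YOLOv5's dataloader behavior, not imported from it; `image_to_label_path` is a hypothetical name):

```python
import os

def image_to_label_path(img_path):
    """Map .../images/<split>/<name>.<ext> to .../labels/<split>/<name>.txt."""
    sa = f'{os.sep}images{os.sep}'
    sb = f'{os.sep}labels{os.sep}'
    # Replace only the last 'images' directory component, then swap the extension
    head, sep, tail = img_path.rpartition(sa)
    if not sep:
        raise ValueError(f"no {sa!r} component in {img_path!r}")
    return os.path.splitext(head + sb + tail)[0] + '.txt'

example = os.path.join('datasets', '4weed', 'images', 'train', 'foxtail_001.jpg')
print(image_to_label_path(example))
```

On Linux or macOS this prints `datasets/4weed/labels/train/foxtail_001.txt`.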

The 4Weed download appears to organize images into one folder per class. YOLOv5 detection training, however, needs one `.txt` label file per image describing its bounding boxes.
**Important**: Based on the paper's description, the 4Weed dataset should include bounding-box annotations. If the download contains only images in a classification-style folder structure, the images will have to be annotated before they can be used for detection.

*Detection datasets usually ship with Pascal VOC (`.xml`) or YOLO (`.txt`) labels. If the OSF download turns out to be only folders of images, then either the annotations are stored in a separate file, or the images are crops better suited to classification (YOLOv5 can also train a classifier, but the paper discusses detection).*

*Assumption for this script*: We organize the images into the structure above. Any YOLO `.txt` label with the same name as an image is copied alongside it. Pascal VOC `.xml` labels are **not** converted by the cells below and would need a separate conversion step.
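If the labels turn out to be Pascal VOC `.xml` files, each box must be converted from absolute corner coordinates (`xmin`, `ymin`, `xmax`, `ymax`) to YOLO's normalized `class x_center y_center width height`. A minimal standard-library sketch, assuming the usual VOC tags (`size/width`, `size/height`, `object/name`, `object/bndbox`) and object names that match `CLASSES`; `voc_xml_to_yolo_lines` is a hypothetical helper:

```python
import xml.etree.ElementTree as ET

CLASSES = ['cocklebur', 'foxtail', 'pigweed', 'ragweed']

def voc_xml_to_yolo_lines(xml_text, classes=CLASSES):
    """Convert one Pascal VOC annotation (given as a string) to YOLO label lines."""
    root = ET.fromstring(xml_text)
    img_w = float(root.find('size/width').text)
    img_h = float(root.find('size/height').text)
    lines = []
    for obj in root.iter('object'):
        name = obj.find('name').text.strip().lower()
        if name not in classes:
            continue  # skip objects whose class is not in the class list
        box = obj.find('bndbox')
        xmin, ymin, xmax, ymax = (float(box.find(t).text)
                                  for t in ('xmin', 'ymin', 'xmax', 'ymax'))
        # YOLO format: box center and size, normalized by the image dimensions
        x_c = (xmin + xmax) / 2 / img_w
        y_c = (ymin + ymax) / 2 / img_h
        w = (xmax - xmin) / img_w
        h = (ymax - ymin) / img_h
        lines.append(f"{classes.index(name)} {x_c:.6f} {y_c:.6f} {w:.6f} {h:.6f}")
    return lines

sample = """<annotation>
  <size><width>640</width><height>480</height></size>
  <object><name>foxtail</name>
    <bndbox><xmin>100</xmin><ymin>120</ymin><xmax>300</xmax><ymax>360</ymax></bndbox>
  </object>
</annotation>"""
print(voc_xml_to_yolo_lines(sample))  # ['1 0.312500 0.500000 0.312500 0.500000']
```

Each returned line would be written to a `.txt` file with the same name as the image, inside the matching `labels/<split>/` folder.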

In [None]:
# Define paths
BASE_DIR = 'datasets/4weed'
IMG_DIR = os.path.join(BASE_DIR, 'images')
LBL_DIR = os.path.join(BASE_DIR, 'labels')

for split in ['train', 'val']:
    os.makedirs(os.path.join(IMG_DIR, split), exist_ok=True)
    os.makedirs(os.path.join(LBL_DIR, split), exist_ok=True)

# Classes map
CLASSES = ['cocklebur', 'foxtail', 'pigweed', 'ragweed']
class_to_id = {cls: i for i, cls in enumerate(CLASSES)}

print("Directories created.")
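When the images are grouped in per-class folders, `class_to_id` can recover each image's numeric class from its path. This sketch assumes the folder names match the entries in `CLASSES` (case-insensitively); `class_id_from_path` is a hypothetical helper, not something YOLOv5 provides.

```python
import os

CLASSES = ['cocklebur', 'foxtail', 'pigweed', 'ragweed']
class_to_id = {cls: i for i, cls in enumerate(CLASSES)}

def class_id_from_path(img_path):
    """Return the id of the nearest enclosing folder named after a class, else None."""
    parts = os.path.normpath(img_path).split(os.sep)
    # Walk upward from the file's directory so nested subfolders still resolve
    for part in reversed(parts[:-1]):
        if part.lower() in class_to_id:
            return class_to_id[part.lower()]
    return None

print(class_id_from_path(os.path.join('data', 'raw_download', 'Pigweed', 'img.jpg')))  # 2
```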

## 3. Split and Copy Data

We search `data/raw_download` recursively, so images are found whether `osfclient` placed them under `osfstorage/` or in another subfolder.

In [None]:
# Find all images
search_path = 'data/raw_download'
all_images = []
image_extensions = ['.jpg', '.jpeg', '.png', '.bmp']

for root, dirs, files in os.walk(search_path):
    for file in files:
        if any(file.lower().endswith(ext) for ext in image_extensions):
            all_images.append(os.path.join(root, file))

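# os.walk order depends on the filesystem; sort so the seeded split is reproducible
all_images.sort()
print(f"Found {len(all_images)} images.")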
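A total count can hide a missing or empty class folder. A quick per-folder tally is a cheap sanity check; this sketch assumes each image's immediate parent folder is its class name, and `count_by_folder` is a hypothetical helper:

```python
import os
from collections import Counter

def count_by_folder(image_paths):
    """Count images per immediate parent folder (the class in a per-class layout)."""
    return Counter(os.path.basename(os.path.dirname(p)) for p in image_paths)

paths = [os.path.join('raw', 'foxtail', 'a.jpg'),
         os.path.join('raw', 'foxtail', 'b.jpg'),
         os.path.join('raw', 'ragweed', 'c.jpg')]
print(count_by_folder(paths))  # Counter({'foxtail': 2, 'ragweed': 1})
```

In the notebook, `count_by_folder(all_images)` should list all four weed classes.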

In [None]:
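# Split into Train (80%) and Val (20%); random_state fixes the split
if not all_images:
    raise RuntimeError("No images found under data/raw_download; see the manual download note in Section 1.")
train_imgs, val_imgs = train_test_split(all_images, test_size=0.2, random_state=42)

def copy_files(file_list, split):
    """Copy images, plus any same-named YOLO .txt labels, into the given split."""
    for src_path in tqdm(file_list, desc=f"Copying {split} files"):
        # Prefix the parent folder name (e.g. the class) so identically named
        # files from different folders do not overwrite each other
        parent = os.path.basename(os.path.dirname(src_path))
        stem, ext = os.path.splitext(os.path.basename(src_path))
        new_stem = f"{parent}_{stem}"
        shutil.copy(src_path, os.path.join(IMG_DIR, split, new_stem + ext))

        # Copy the matching label (same name, .txt) if it exists.
        # Note: Pascal VOC .xml labels would need a conversion step here.
        src_lbl_path = os.path.splitext(src_path)[0] + '.txt'
        if os.path.exists(src_lbl_path):
            shutil.copy(src_lbl_path, os.path.join(LBL_DIR, split, new_stem + '.txt'))

copy_files(train_imgs, 'train')
copy_files(val_imgs, 'val')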
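With only four classes, a purely random split can leave one class under-represented in `val`. scikit-learn's `train_test_split` accepts a `stratify` argument for exactly this; the sketch below does the same per-class split with the standard library so the logic is visible. It assumes each image's class is its parent folder name, and `stratified_split` is a hypothetical helper.

```python
import os
import random
from collections import defaultdict

def stratified_split(image_paths, test_size=0.2, seed=42):
    """Per-class random split: each parent-folder class sends ~test_size of its images to val."""
    by_class = defaultdict(list)
    for p in sorted(image_paths):
        by_class[os.path.basename(os.path.dirname(p))].append(p)
    rng = random.Random(seed)
    train, val = [], []
    for cls in sorted(by_class):
        group = by_class[cls]
        rng.shuffle(group)
        n_val = round(len(group) * test_size)
        val.extend(group[:n_val])
        train.extend(group[n_val:])
    return train, val

# Synthetic paths for illustration: 10 images in each of two class folders
paths = [os.path.join('raw', cls, f'{i}.jpg') for cls in ('foxtail', 'pigweed') for i in range(10)]
train, val = stratified_split(paths)
print(len(train), len(val))  # 16 4
```

Sorting the input and seeding a private `random.Random` keeps the split reproducible across machines.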

## 4. Verification
Check the number of images in each split.

In [None]:
print("Train images:", len(os.listdir(os.path.join(IMG_DIR, 'train'))))
print("Val images:", len(os.listdir(os.path.join(IMG_DIR, 'val'))))
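Image counts alone do not show whether labels were found: an image with no matching `.txt` file contributes no boxes to detection training. A sketch that compares file stems between the image and label folders (`unmatched_files` is a hypothetical helper; the paths follow the layout created above):

```python
import os

def unmatched_files(img_dir, lbl_dir):
    """Return (images without a label, labels without an image), compared by file stem."""
    img_stems = {os.path.splitext(f)[0] for f in os.listdir(img_dir)}
    lbl_stems = {os.path.splitext(f)[0] for f in os.listdir(lbl_dir) if f.endswith('.txt')}
    return sorted(img_stems - lbl_stems), sorted(lbl_stems - img_stems)

for split in ['train', 'val']:
    img_dir = os.path.join('datasets', '4weed', 'images', split)
    lbl_dir = os.path.join('datasets', '4weed', 'labels', split)
    if os.path.isdir(img_dir) and os.path.isdir(lbl_dir):
        missing, orphans = unmatched_files(img_dir, lbl_dir)
        print(f"{split}: {len(missing)} images without labels, {len(orphans)} labels without images")
```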