<a href="https://colab.research.google.com/github/ZYM66/AutoTrain_ImageClassify_model_ByYOLOv5/blob/master/Yolov5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive; the YOLOv5 clone, downloaded images and
# generated datasets used below all live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import requests
from PIL import Image, UnidentifiedImageError
from requests.exceptions import HTTPError
from io import BytesIO
from pathlib import Path
import shutil
from torchvision.datasets import ImageFolder
import math
import torch
import matplotlib.pyplot as plt
import os

In [None]:
# Create the working directory on Drive (no error if it already exists).
os.makedirs("/content/drive/MyDrive/yolov5/", exist_ok=True)

In [None]:
# Work inside the Drive directory so the repository is cloned there.
%cd /content/drive/MyDrive/yolov5/

In [None]:
# Clone the Ultralytics YOLOv5 repository (provides classify/train.py used below).
# NOTE(review): git refuses to clone into an existing non-empty `yolov5`
# directory, so re-running this on a Drive that already has the clone prints an error.
!git clone https://github.com/ultralytics/yolov5.git

In [None]:
# Relative paths used later (images/, classify/train.py) resolve against the repo root.
%cd yolov5

In [None]:
SEARCH_URL = "https://huggingface.co/api/experimental/images/search"

def get_image_urls_by_term(search_term: str, count=150, timeout=30):
    """Query the image-search endpoint and return thumbnail URLs for a term.

    Args:
        search_term: Query string sent as the ``q`` parameter.
        count: Number of results requested from the endpoint.
        timeout: Seconds to wait on connect/read (passed to ``requests.get``);
            ``None`` waits indefinitely.

    Returns:
        A list of the ``thumbnailUrl`` field of every entry in the JSON
        response's ``value`` list.

    Raises:
        requests.exceptions.HTTPError: If the endpoint returns an error status.
        requests.exceptions.Timeout: If the server does not answer in time.
    """
    params = {
        "q": search_term,
        "license": "All",
        "imageType": "photo",
        "count": count,
        "setLang": "zh-hans",
        "mkt": "zh-CN",
    }
    # requests has no default timeout; without one a stalled server hangs the cell.
    response = requests.get(SEARCH_URL, params=params, timeout=timeout)
    response.raise_for_status()
    response_data = response.json()
    return [img['thumbnailUrl'] for img in response_data['value']]


def gen_images_from_urls(urls, timeout=10):
    """Download each URL and yield the response body opened as a PIL image.

    A URL is skipped when the request raises a ``requests`` exception, the
    response status is not 200, or PIL cannot identify the body as an image.
    Each URL is counted at most once. When ``urls`` is exhausted, a summary
    line with the retrieved/skipped counts is printed.

    Args:
        urls: Iterable of image URLs (need not support ``len``).
        timeout: Seconds to wait per request (passed to ``requests.get``).

    Yields:
        ``PIL.Image.Image`` objects, in the order of ``urls``.
    """
    num_retrieved = 0
    num_skipped = 0
    for url in urls:
        try:
            response = requests.get(url, timeout=timeout)
        except requests.RequestException:
            # One unreachable host should not abort the whole download.
            num_skipped += 1
            continue
        if response.status_code != 200:
            # Don't try to decode error pages; count the URL once and move on.
            num_skipped += 1
            continue
        try:
            img = Image.open(BytesIO(response.content))
        except UnidentifiedImageError:
            num_skipped += 1
            continue
        num_retrieved += 1
        yield img

    print(f"Retrieved {num_retrieved} images. Skipped {num_skipped}.")


def urls_to_image_folder(urls, save_directory):
    """Download images from ``urls`` and save them as ``0.jpg``, ``1.jpg``, ...

    Files are numbered consecutively over the images that were actually
    retrieved by ``gen_images_from_urls``.

    Args:
        urls: Iterable of image URLs.
        save_directory: Existing directory (``str`` or ``Path``) to write into;
            it is not created here.
    """
    save_directory = Path(save_directory)
    for i, image in enumerate(gen_images_from_urls(urls)):
        # JPEG cannot store alpha or palette modes (e.g. RGBA/LA/P from PNG or
        # GIF thumbnails); saving those as .jpg raises OSError and aborts the run.
        if image.mode != "RGB":
            image = image.convert("RGB")
        image.save(save_directory / f'{i}.jpg')

- 在下方填写要分类的类别名称（将自动下载图片、构建数据集并进行训练）

In [None]:
# Class names to download and train on. Each term is used both as the search
# query and as the class folder name, so it must not carry stray whitespace.
search_terms = ["郁金香", "月季", "红玫瑰", "白玫瑰", "绿萝", "蝴蝶兰", "康乃馨", "杜鹃花", "万年青", "薰衣草", "水仙花", "梅花", "马蹄莲", "君子兰", "金银花", "鸢尾花", "百合花", "昙花", "天竺葵", "牡丹花"]
# Normalise user edits: strip whitespace, drop blank entries, and remove
# duplicates (which would overwrite each other's files) while keeping order.
search_terms = list(dict.fromkeys(term.strip() for term in search_terms if term.strip()))

In [None]:
data_dir = Path('images')

# Start from a clean directory so images from a previous run don't mix in.
if data_dir.exists():
  shutil.rmtree(data_dir)

for search_term in search_terms:
  search_term_dir = data_dir / search_term
  try:
    urls = get_image_urls_by_term(search_term, count=300)
  except requests.RequestException as err:
    # Skip this class instead of aborting the downloads for all remaining ones.
    print(f"Skipping {search_term}: image search failed ({err})")
    continue
  search_term_dir.mkdir(exist_ok=True, parents=True)
  print(f"Saving images of {search_term} to {str(search_term_dir)}...")
  urls_to_image_folder(urls, search_term_dir)
  # torchvision's ImageFolder (used by split()) raises on class folders with
  # no images, so remove a class whose downloads all failed.
  if not any(search_term_dir.iterdir()):
    print(f"No images saved for {search_term}; removing {search_term_dir}")
    search_term_dir.rmdir()

In [None]:
# Dataset-splitting helper
def split(data_dir, to_dir, dataset_name, n_val_rate=0.15):
  """Randomly split an ImageFolder-style directory into train/val copies.

  Reads ``data_dir`` (one sub-directory per class) with torchvision's
  ``ImageFolder`` and shuffles all of its images with ``torch.randperm``. It
  then copies ``floor(len * n_val_rate)`` of them into
  ``<to_dir>/<dataset_name>/val/<class>/`` and the rest into
  ``<to_dir>/<dataset_name>/train/<class>/``. Any existing
  ``<to_dir>/<dataset_name>`` directory is deleted first. A folder for every
  class is created in both splits, even if a split receives none of its images.

  NOTE(review): an empty class folder in ``val`` will make ImageFolder-based
  loaders raise; very small classes may need a stratified split instead.

  Args:
    data_dir: Source directory whose sub-directories are class names.
    to_dir: Parent directory for the generated dataset.
    dataset_name: Name of the dataset directory created inside ``to_dir``.
    n_val_rate: Fraction of all images (rounded down) placed in ``val``.
  """
  ds = ImageFolder(data_dir)
  indices = torch.randperm(len(ds)).tolist()
  n_val = math.floor(len(indices) * n_val_rate)
  # Slice from the front: `indices[:-n_val]` is empty when n_val == 0, which
  # would leave train empty and put every image in val.
  val_indices = indices[:n_val]
  train_indices = indices[n_val:]

  place = os.path.join(to_dir, dataset_name)
  if os.path.exists(place):
    shutil.rmtree(place)
  train_place = os.path.join(place, "train")
  val_place = os.path.join(place, "val")

  # Only real class directories (as discovered by ImageFolder) become classes;
  # makedirs also creates `place` and `to_dir` as needed.
  for split_place in (train_place, val_place):
    for cls in ds.classes:
      os.makedirs(os.path.join(split_place, cls))

  # Copy the actual files ImageFolder found, whatever their names/extensions,
  # instead of guessing per-class filenames from global dataset indices.
  for split_place, split_indices in ((train_place, train_indices), (val_place, val_indices)):
    for idx in split_indices:
      path, label = ds.samples[idx]
      shutil.copy(path, os.path.join(split_place, ds.classes[label]))

In [None]:
# Split the images downloaded to ./images (this absolute path, given the %cd into
# the repo) into train/val folders under /content/drive/MyDrive/yolov5/datasets/flower.
split("/content/drive/MyDrive/yolov5/yolov5/images", "/content/drive/MyDrive/yolov5/datasets", "flower")

In [None]:
!python classify/train.py --model yolov5s-cls.pt --data "/content/drive/MyDrive/yolov5/datasets/flower" --epoch 10 --img 224