In [None]:
import os
import warnings
from pathlib import Path
from urllib.parse import urlparse
warnings.filterwarnings(
    'ignore',
    message='Palette images with Transparency expressed in bytes should be converted to RGBA images',
    category=UserWarning,
    module='PIL.Image'
)
# Constants.
DATASET_DIR = Path(os.getcwd()) / "datasets"


In [None]:
# Weights & Biases (wandb) run initialization.
def init_wandb(entity="prinzz-personal", project="gadgets-predictor", config=None):
    """Start a Weights & Biases run for this notebook, if wandb is available.

    Args:
        entity: wandb entity (user or team) the run is logged under.
        project: wandb project the run is logged under.
        config: hyperparameters / run metadata to track. When omitted, a fresh
            copy of the notebook's default config is used.
            NOTE(review): the default "epochs" (10) does not match the
            ``fine_tune(3)`` call in this notebook — pass the real values.

    Returns:
        Whatever ``wandb.init`` returns on success, or ``None`` when wandb is not
        installed or initialization fails. Failures are printed, not raised, so
        the notebook keeps running without tracking.
    """
    if config is None:
        # Built per call so one run's config can never leak into the next.
        config = {
            "learning_rate": 0.02,
            "architecture": "CNN",
            "dataset": "images",
            "epochs": 10,
        }
    try:
        # Imported lazily so the notebook works without wandb installed.
        import wandb

        # Start a new wandb run to track this script.
        return wandb.init(entity=entity, project=project, config=config)
    except ImportError:
        print("wandb is not installed. Skipping wandb initialization.")
        return None
    except Exception as e:
        # Tracking is optional: report the problem instead of stopping the notebook.
        print(f"An error occurred during wandb initialization: {e}")
        return None


In [None]:
from duckduckgo_search import DDGS
import requests
import dask
from fastai.vision.all import get_image_files
from dask.distributed import Client

# File extensions accepted when derived from an image URL or Content-Type.
_IMAGE_EXTENSIONS = frozenset({"jpg", "jpeg", "png", "gif", "bmp", "webp", "tif", "tiff"})


def _image_extension(url, content_type=""):
    """Return a safe file extension (without the dot) for a downloaded image.

    Looks only at the URL *path* suffix, so dots in the host name or query
    string are ignored. Falls back to the Content-Type subtype, then "jpg".
    """
    suffix = os.path.splitext(urlparse(url).path)[1].lstrip(".").lower()
    if suffix in _IMAGE_EXTENSIONS:
        return suffix
    # "image/png; charset=..." -> "png"
    subtype = content_type.split(";")[0].strip().lower().partition("/")[2]
    if subtype in _IMAGE_EXTENSIONS:
        return subtype
    return "jpg"


def _write_new_file(output_dir, stem, start_index, ext, data):
    """Write ``data`` to the first free ``{stem}_{i}.{ext}`` with i >= start_index.

    Mode "xb" refuses to open an existing file, so an image already on disk is
    never overwritten. Returns ``(path, i)``.
    """
    index = start_index
    while True:
        path = os.path.join(output_dir, f"{stem}_{index}.{ext}")
        try:
            with open(path, "xb") as fh:
                fh.write(data)
        except FileExistsError:
            index += 1
            continue
        return path, index


def download_images(query, output_dir, max_results=50):
    """Download images for ``query`` from DuckDuckGo image search into ``output_dir``.

    Best-effort: each failed download is printed and skipped, so the folder can
    end up with fewer than ``max_results`` files. Files already present count
    toward the target, and the search is skipped entirely when the target is
    already met.

    Args:
        query: image search query; spaces become underscores in file names.
        output_dir: destination directory (created if missing).
        max_results: target number of files in ``output_dir``; also the number
            of search results requested.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Check if output_dir already has enough images.
    existing_files = [
        f for f in os.listdir(output_dir)
        if os.path.isfile(os.path.join(output_dir, f))
    ]
    if len(existing_files) >= max_results:
        print(f"Skipping '{query}': {len(existing_files)} images already present.")
        return

    results = DDGS().images(query, max_results=max_results)

    stem = query.replace(" ", "_")
    downloaded = len(existing_files)
    # First file-name number to try; _write_new_file skips numbers already taken.
    next_index = downloaded
    for result in results:
        if downloaded >= max_results:
            break
        image_url = result["image"]
        try:
            response = requests.get(image_url, timeout=10)
            response.raise_for_status()
            # Reject HTML error pages and similar. Some CDNs serve images as
            # application/octet-stream, so that is allowed too.
            content_type = response.headers.get("Content-Type", "").split(";")[0].strip().lower()
            if content_type and not (
                content_type.startswith("image/")
                or content_type == "application/octet-stream"
            ):
                raise ValueError(f"unexpected Content-Type {content_type!r}")
            ext = _image_extension(image_url, content_type)
            filename, used_index = _write_new_file(
                output_dir, stem, next_index, ext, response.content
            )
            next_index = used_index + 1
            print(f"Downloaded: {filename}")
            downloaded += 1
        except Exception as e:
            print(f"Failed to download {image_url}: {e}")

# Download every gadget class in parallel on a local Dask cluster.
# Each delayed task is one sequential download loop, so one thread per worker
# is enough, and workers beyond the number of tasks would sit idle.
# os.cpu_count() may return None, hence the fallback to 1.
gadgets = ["smartphone", "tablet", "smartwatch", "headphones", "camera"]
n_workers = min(len(gadgets), os.cpu_count() or 1)
# The context manager shuts the cluster down afterwards, so re-running this
# cell does not leave orphaned worker processes behind.
with Client(n_workers=n_workers, threads_per_worker=1) as client:
    parallel_results = [
        dask.delayed(download_images)(gadget, output_dir=DATASET_DIR / gadget, max_results=200)
        for gadget in gadgets
    ]
    # Actually execute the tasks. Re-runs are cheap because download_images
    # returns early for classes that already have max_results files.
    dask.compute(*parallel_results)
print("All downloads completed.")




In [None]:
from fastai.vision.all import ImageDataLoaders, Resize,Transform

# Image-mode normalization. Palette ("P") images that carry transparency are converted to RGBA first, as PIL's palette-transparency warning recommends.
from PIL import Image
def convert_to_rgb(img):
    """Return ``img`` as a 3-channel RGB PIL image.

    Images already in RGB mode are returned unchanged (same object). Palette
    ("P") images that carry transparency go through RGBA first, as PIL
    recommends for palette transparency. Every other mode (RGBA, LA, L, CMYK,
    1, ...) is converted with ``convert("RGB")``, which drops any alpha channel
    rather than compositing it.
    """
    if img.mode == "RGB":
        return img
    if img.mode == "P" and "transparency" in img.info:
        img = img.convert("RGBA")
    # Always finish in RGB: returning RGBA (or L) would give the model inputs
    # with a different channel count than the 3-channel images.
    return img.convert("RGB")

class ConvertToRGB(Transform):
    """fastai item transform that normalizes each image with ``convert_to_rgb``.

    ``encodes`` is annotated with ``Image.Image`` so fastai's type dispatch
    applies it only to images. An unannotated ``encodes`` matches any type and
    is also called on the non-image elements (e.g. the label) of each item.
    fastai's ``PILImage`` subclasses ``Image.Image``, so loaded images still match.
    """

    def encodes(self, img: Image.Image):
        return convert_to_rgb(img)
    
# Dataloaders.
from fastai.vision.all import verify_images

# Remove downloads that PIL cannot open (truncated files, error pages saved
# under an image extension) so they don't fail a batch mid-training.
failed_images = verify_images(get_image_files(DATASET_DIR))
failed_images.map(Path.unlink)
print(f"Removed {len(failed_images)} unreadable image(s).")

# `train_pct` is not an ImageDataLoaders.from_folder parameter (it was at best
# ignored). With `valid_pct` set, fastai holds out that fraction at random and
# trains on the rest. A fixed seed keeps the split identical across re-runs.
dls = ImageDataLoaders.from_folder(
    DATASET_DIR,
    valid_pct=0.2,
    seed=42,
    item_tfms=[ConvertToRGB(), Resize(224)],
)


In [None]:
from fastai.vision.all import vision_learner,resnet18,error_rate,accuracy
# Transfer learning: start from a pretrained ResNet-18, fine-tune it for 3
# epochs, and report accuracy and error rate after each epoch.
learn = vision_learner(
    dls,
    resnet18,
    pretrained=True,
    metrics=[accuracy, error_rate],
)
learn.fine_tune(epochs=3)

In [None]:
# Visual sanity check: show a sample of model predictions next to their labels
# (fastai draws these from the validation set by default — confirm for your version).
learn.show_results()