In [1]:
# Shell out to list the contents of the GAN_sample_fid dataset directory.
!ls /n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/Datasets/GAN_sample_fid

BigGAN_1000cls_std07_invert	pink_noise
BigGAN_std_008			resnet50_linf8_gradevol
BigGAN_trunc07			resnet50_linf8_gradevol_avgpool
BigGAN_trunc07_dog_images.pkl	resnet50_linf8_gradevol_layer3
BigGAN_trunc07_results_all.pkl	resnet50_linf8_gradevol_layer4
DeePSim_4std			summary


In [5]:
# Import required libraries for YOLO
from ultralytics import YOLO
import glob
from tqdm.auto import tqdm
import os
import torch
from PIL import Image
import torchvision.transforms as T
import matplotlib.pyplot as plt
from os.path import join
import pickle

def count_total_images(imgdir):
    """Return how many ``*.png`` files sit directly inside ``imgdir``.

    Only the top level of the directory is inspected; other extensions are
    not counted.
    """
    png_pattern = os.path.join(imgdir, "*.png")
    # Tally lazily instead of materialising the full list of matches.
    return sum(1 for _ in glob.iglob(png_pattern))


def detect_dog_dataset(imgdir, dataset_str, outdir, verbose=False,
                       batch_size=128, device='cuda'):
    """Run YOLOv8-nano over every PNG in ``imgdir`` and record images with a dog.

    Args:
        imgdir: Directory whose ``*.png`` files are scanned (top level only).
        dataset_str: Label used as the filename prefix of the two pickles
            written into ``outdir``.
        outdir: Directory receiving ``{dataset_str}_results_all.pkl`` and
            ``{dataset_str}_dog_images.pkl``; created if it does not exist.
        verbose: Forwarded to the model call as its ``verbose`` flag.
        batch_size: Number of image paths handed to the model per call.
        device: Device string used both for model placement and inference.

    Returns:
        ``(dog_images, results_all)``: ``dog_images`` maps an image path to
        the first box (in the order the model returns them) whose class is
        the dog class; ``results_all`` is the concatenation of everything the
        model returned, batch after batch.
    """
    # Class index treated as "dog". Presumably the COCO label order shipped
    # with the stock yolov8n weights -- verify if the checkpoint changes.
    DOG_CLASS_ID = 16

    model = YOLO('yolov8n.pt')  # Load YOLOv8 nano model
    model.to(device)

    # Collect all PNG images in the directory, in deterministic (sorted) order.
    image_files = sorted(glob.glob(os.path.join(imgdir, "*.png")))

    dog_images = {}    # image path -> first dog box found in that image
    results_all = []   # every result object returned by the model

    # Process images in batches
    for start in tqdm(range(0, len(image_files), batch_size), desc="Detecting dogs"):
        batch_files = image_files[start:start + batch_size]

        # Run inference on the batch; results are paired with paths by position.
        results = model(batch_files, device=device, verbose=verbose)
        results_all.extend(results)

        for img_path, result in zip(batch_files, results):
            for box in result.boxes:
                # Compare as a plain int rather than relying on the truthiness
                # of a one-element tensor.
                if int(box.cls.item()) == DOG_CLASS_ID:
                    dog_images[img_path] = box
                    break  # one dog is enough to flag the image

    print(f"Found {len(dog_images)} images containing dogs")

    # NOTE(review): the pickled objects may hold tensors living on `device`;
    # loading them elsewhere could require the same device -- confirm.
    os.makedirs(outdir, exist_ok=True)
    # Context managers guarantee the files are flushed and closed.
    with open(join(outdir, f"{dataset_str}_results_all.pkl"), "wb") as fh:
        pickle.dump(results_all, fh)
    with open(join(outdir, f"{dataset_str}_dog_images.pkl"), "wb") as fh:
        pickle.dump(dog_images, fh)
    return dog_images, results_all

In [None]:
# Pick the sample set to inspect and resolve the folder holding its images.
dataset_str = "BigGAN_trunc07"
dataset_root = "/n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/Datasets/GAN_sample_fid"
imgdir = join(dataset_root, dataset_str)

In [3]:
# Root folder holding one sub-directory of generated samples per dataset.
dataset_root = "/n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/Datasets/GAN_sample_fid"

# Dedicated sub-folder of the root for the detection outputs.
outdir = join(dataset_root, "yolo_dog_detection")
os.makedirs(outdir, exist_ok=True)

# Sample set currently under inspection and the folder with its images.
dataset_str = "BigGAN_trunc07"
imgdir = join(dataset_root, dataset_str)



In [4]:
# List the dataset root again; {dataset_root} is interpolated from the Python variable.
!ls {dataset_root}

BigGAN_1000cls_std07_invert	resnet50_linf8_gradevol
BigGAN_std_008			resnet50_linf8_gradevol_avgpool
BigGAN_trunc07			resnet50_linf8_gradevol_layer3
BigGAN_trunc07_dog_images.pkl	resnet50_linf8_gradevol_layer4
BigGAN_trunc07_results_all.pkl	summary
DeePSim_4std			yolo_dog_detection
pink_noise


In [6]:
# Root folder holding one sub-directory of generated samples per dataset.
dataset_root = "/n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/Datasets/GAN_sample_fid"

# Folder that receives the per-dataset detection pickles.
outdir = join(dataset_root, "yolo_dog_detection")
os.makedirs(outdir, exist_ok=True)

# Run dog detection on each sample set in turn. detect_dog_dataset writes its
# results to `outdir`; after the loop, `results_dog_imgs` / `results_all`
# only hold the values for the last dataset in the list.
for dataset_str in ["resnet50_linf8_gradevol",
                    "resnet50_linf8_gradevol_avgpool",
                    "resnet50_linf8_gradevol_layer4",
                    "resnet50_linf8_gradevol_layer3",
                    "DeePSim_4std",
                    "BigGAN_std_008"]:
    imgdir = join(dataset_root, dataset_str)
    results_dog_imgs, results_all = detect_dog_dataset(imgdir, dataset_str, outdir)

Detecting dogs:   0%|          | 0/391 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Root folder holding one sub-directory of generated samples per dataset.
dataset_root = "/n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/Datasets/GAN_sample_fid"

# Keep the detection output folder defined (and present) for later cells.
outdir = join(dataset_root, "yolo_dog_detection")
os.makedirs(outdir, exist_ok=True)

# Report how many PNG images each sample set contains.
for dataset_str in ["resnet50_linf8_gradevol",
                    "resnet50_linf8_gradevol_avgpool",
                    "resnet50_linf8_gradevol_layer4",
                    "resnet50_linf8_gradevol_layer3",
                    "DeePSim_4std",
                    "BigGAN_std_008"]:
    imgdir = join(dataset_root, dataset_str)
    n_total = count_total_images(imgdir)
    print(f"{dataset_str}: {n_total}")