In [1]:
from transformers import AutoImageProcessor, AutoModelForImageClassification
import torch
import accelerate
import matplotlib.pyplot as plt
import numpy as np
from dotenv import load_dotenv
import os

load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")

processor = AutoImageProcessor.from_pretrained("facebook/dinov2-small-imagenet1k-1-layer", token=HF_TOKEN)
model = AutoModelForImageClassification.from_pretrained(
    "facebook/dinov2-small-imagenet1k-1-layer",
    dtype=torch.float16,
    device_map="auto",
    attn_implementation="sdpa"
)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [2]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("CUDA is available! Using GPU.")
else:
    device = torch.device("cpu")
    print("CUDA not available. Using CPU.")

CUDA is available! Using GPU.


# Generate the dataset

Note that we run FGSM with $\epsilon = 0.1, 0.3, 0.5$ to measure various perturbation strengths. Also, we *only* consider images that were correctly classified initially; otherwise, our FGSM attack is kind of useless and wastes memory.

In [None]:
import os
from dotenv import load_dotenv
import torch
from tqdm.notebook import tqdm
from datasets import Dataset, load_dataset, concatenate_datasets
from datasets.utils.logging import disable_progress_bar
from huggingface_hub import HfApi, login, create_repo
import warnings
import logging

from image_processing import tensor_to_pil
from model_utils import generate_loss, get_classification, fgsm_attack

# Load environment variables
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")

# Suppress warnings and unnecessary output
warnings.filterwarnings('ignore')
os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
disable_progress_bar()  # Disable datasets progress bars

# Suppress datasets logging
logging.getLogger("datasets").setLevel(logging.ERROR)


def process_and_upload_sharded(
    repo_id="areebg9/perturbed-imagenet-fgsm",
    total_samples=100000,
    batch_size=100,
    shard_size=1000,
    epsilons=[0.1, 0.3, 0.5],
    token=None,
    skip_misclassified=True
):
    """
    Process and upload as separate shards (parquet files) - much more efficient!
    """
    if token is None:
        token = os.getenv("HF_TOKEN")
    
    login(token=token)
    
    # Create repo
    api = HfApi(token=token)
    try:
        create_repo(repo_id, repo_type="dataset", private=True)
    except:
        pass  # Repo already exists
    
    ds = load_dataset(
        "ILSVRC/imagenet-1k", 
        split="train", 
        streaming=True,
        token=token
    )
    
    ds_iter = iter(ds)
    samples_processed = 0
    samples_skipped = 0
    samples_used = 0
    batch_datasets = []
    shard_count = 0
    
    num_batches = (total_samples + batch_size - 1) // batch_size
    
    # Single consolidated progress bar
    pbar = tqdm(total=total_samples, desc="Processing & Uploading", 
                bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}] {postfix}')
    
    for batch_idx in range(num_batches):
        batch_data = []
        batch_images = []
        batch_labels = []
        
        for _ in range(batch_size):
            if samples_processed >= total_samples:
                break
            
            try:
                sample = next(ds_iter)
                batch_images.append(sample['image'])
                batch_labels.append(sample['label'])
                samples_processed += 1
            except StopIteration:
                break
        
        if len(batch_images) == 0:
            break
        
        for i in range(len(batch_images)):
            # Check if original prediction is correct
            try:
                loss, inputs = generate_loss(model, batch_images[i], batch_labels[i], processor)
                
                with torch.no_grad():
                    original_label_name, original_pred_idx = get_classification(model, inputs)
                
                # Skip if originally misclassified
                if skip_misclassified and batch_labels[i] != original_pred_idx:
                    samples_skipped += 1
                    pbar.set_postfix({
                        'skipped': samples_skipped, 
                        'used': samples_used,
                        'shards': shard_count
                    })
                    pbar.update(1)
                    continue
                
                samples_used += 1
                
                # Now generate perturbations for this correctly classified image
                for epsilon in epsilons:
                    try:
                        # Regenerate loss for each epsilon (gradients are consumed)
                        loss, inputs = generate_loss(model, batch_images[i], batch_labels[i], processor)
                        perturbed_input = fgsm_attack(inputs, loss, epsilon)
                        perturbed_pil = tensor_to_pil(perturbed_input['pixel_values'][0], processor)
                        
                        with torch.no_grad():
                            perturbed_label_name, perturbed_pred_idx = get_classification(model, perturbed_input)
                        
                        batch_data.append({
                            'image': perturbed_pil,
                            'original_image_index': samples_processed - len(batch_images) + i,
                            'attack_type': 'FGSM',
                            'epsilon': epsilon,
                            'original_label': batch_labels[i],
                            'original_label_name': original_label_name,
                            'original_prediction_idx': original_pred_idx,
                            'original_prediction_name': original_label_name,
                            'perturbed_prediction_idx': perturbed_pred_idx,
                            'perturbed_prediction_name': perturbed_label_name,
                            'successful_attack': original_pred_idx != perturbed_pred_idx
                        })
                    except Exception as e:
                        continue
                
                pbar.set_postfix({
                    'skipped': samples_skipped, 
                    'used': samples_used,
                    'shards': shard_count
                })
                pbar.update(1)
                
            except Exception as e:
                pbar.update(1)
                continue
        
        if batch_data:
            batch_dataset = Dataset.from_dict({
                'image': [item['image'] for item in batch_data],
                'original_image_index': [item['original_image_index'] for item in batch_data],
                'attack_type': [item['attack_type'] for item in batch_data],
                'epsilon': [item['epsilon'] for item in batch_data],
                'original_label': [item['original_label'] for item in batch_data],
                'original_label_name': [item['original_label_name'] for item in batch_data],
                'original_prediction_idx': [item['original_prediction_idx'] for item in batch_data],
                'original_prediction_name': [item['original_prediction_name'] for item in batch_data],
                'perturbed_prediction_idx': [item['perturbed_prediction_idx'] for item in batch_data],
                'perturbed_prediction_name': [item['perturbed_prediction_name'] for item in batch_data],
                'successful_attack': [item['successful_attack'] for item in batch_data],
            })
            batch_datasets.append(batch_dataset)
        
        del batch_images, batch_labels, batch_data
        torch.cuda.empty_cache()
        
        # Upload shard when we have enough data
        samples_in_buffer = sum(len(d) for d in batch_datasets)
        if samples_in_buffer >= shard_size * len(epsilons) or samples_processed >= total_samples:
            if batch_datasets:
                shard_dataset = concatenate_datasets(batch_datasets)
                
                # Upload as parquet shard (quietly)
                shard_dataset.to_parquet(f"temp_shard_{shard_count}.parquet")
                api.upload_file(
                    path_or_fileobj=f"temp_shard_{shard_count}.parquet",
                    path_in_repo=f"data/train-{shard_count:05d}-of-{total_samples//shard_size:05d}.parquet",
                    repo_id=repo_id,
                    repo_type="dataset",
                    token=token
                )
                
                # Delete local file
                os.remove(f"temp_shard_{shard_count}.parquet")
                
                shard_count += 1
                pbar.set_postfix({
                    'skipped': samples_skipped, 
                    'used': samples_used,
                    'shards': shard_count
                })
                batch_datasets = []
                torch.cuda.empty_cache()
    
    pbar.close()
    
    print("Summary:")
    print(f"Total samples processed: {samples_processed}")
    print(f"Samples skipped (misclassified): {samples_skipped}")
    print(f"Samples used (correctly classified): {samples_used}")
    print(f"Total shards uploaded: {shard_count}")
    print(f"Total perturbed images: {samples_used * len(epsilons)}")


process_and_upload_sharded(
    repo_id="areebg9/perturbed-imagenet-fgsm",
    total_samples=100000,
    batch_size=100,
    shard_size=1000,
    epsilons=[0.1, 0.3, 0.5],
    token=HF_TOKEN,
    skip_misclassified=True
)

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


Processing & Uploading:   0%|          | 0/100000 [00:00<?] 

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            