# Requirements

if on CPU, see next cell under 'dependencies'

In [1]:
pip install transformers datasets accelerate tensorboard evaluate --upgrade

Collecting transformers
  Downloading transformers-4.46.3-py3-none-any.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.1/44.1 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting tensorboard
  Downloading tensorboard-2.18.0-py3-none-any.whl.metadata (1.6 kB)
Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Download

# Dependencies

In [2]:
from datasets import load_dataset
import evaluate
from evaluate import evaluator
from transformers import AutoImageProcessor, ViTImageProcessor
from transformers import AutoModelForImageClassification, pipeline
from transformers import TrainingArguments, Trainer, ViTConfig, ViTForImageClassification
from transformers import AutoModelForImageClassification, MobileNetV2Config, MobileNetV2ForImageClassification, MobileNetV2ImageProcessor
from transformers import DefaultDataCollator

import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate.test_utils.testing import get_backend
from timm.loss import SoftTargetCrossEntropy

from PIL import Image
import numpy as np
from io import BytesIO
from typing import List, Dict, Any
import copy

# installation on cpu
'''
inside of conda env
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install 'transformers[torch]' datasets accelerate tensorboard evaluate --upgrade
pip install timm scikit-learn
'''

class ImageDistilTrainer(Trainer):
    """
    Image distillation trainer

    modifies the transformers.Trainer compute loss function to use the teachers
    output as a soft label to compute a soft target cross entropy against
    """
    def __init__(self, teacher_model=None, student_model=None, temperature=None, lambda_param=None,  *args, **kwargs):
        super().__init__(model=student_model, *args, **kwargs)
        self.teacher = teacher_model
        self.student = student_model
        self.loss_function = SoftTargetCrossEntropy()
        device, _, _ = get_backend() # automatically detects the underlying device type (CUDA, CPU, XPU, MPS, etc.)
        self.teacher.to(device)
        self.teacher.eval()

    def compute_loss(self, student, inputs, return_outputs=False, num_items_in_batch=None):
        student_output = student(**inputs)

        with torch.no_grad():
          teacher_output = self.teacher(**inputs)

        loss = self.loss_function(student_output.logits, teacher_output.logits)

        return (loss, student_output) if return_outputs else loss


def poison_ds(examples, poison_ratio=.2, poisoned_label=0, target_label=1, seed=None, modify_labels=True):
    """
    A dataset poisoning function that modifies the entire dataset for use inside of a dataset.map function

    expected use: poisoned_dataset = dataset.map(poisoned_ds, batched=True, fn_kwargs={...})

    best used for preprocessing on datasets before evaluation - use the DataPoisoner.__call__
    inside of a datset['split'].set_transform to poison during batching to prevent overfitting
    during training
    """
    if seed is not None:
        np.random.seed(seed)
    poisonable_idx = [i for i, label in enumerate(examples["labels"])]
    poison_entity_count = int(len(poisonable_idx) * poison_ratio)
    poison_idx = np.random.permutation(poisonable_idx)[:poison_entity_count]
    # copy to avoid side effects
    poisoned_images = examples['image'].copy()
    poisoned_labels = examples['labels'].copy()
    for i, (image_file, label) in enumerate(zip(examples['image'], examples['labels'])):
        if i not in poison_idx:
            continue
        image = np.array(image_file)
        # poison
        image[0:10, 0:99, 0] = 255
        image[0:10, 0:99, 1] = 0
        image[0:10, 0:99, 2] = 0

        # need to be roundabout to get the stuff in the right format
        im = Image.fromarray(image)
        buffer = BytesIO()
        im.save(buffer, format="JPEG")
        buffer.seek(0)
        jpeg_image_file = Image.open(buffer)

        poisoned_images[i] = jpeg_image_file
        poisoned_labels[i] = target_label
    examples['poisoned_image'] = poisoned_images
    if modify_labels:
      examples['labels'] = poisoned_labels
    else:
      examples['poisoned_labels'] = poisoned_labels

    return examples

class DataPoisoner:
    """
    Data poisoner class that modifies examples at runtime with __call__

    initalize before use

    expected use inside a main process:

      ```
      data_poisoner = DataPoisoner(poison_ratio=.2)
      def poison_images(examples):
        poisoned_examples = data_poisoner(examples)
        processed_inputs = processor(poisoned_examples["image"])
        processed_inputs['labels'] = poisoned_examples['labels']
        return processed_inputs

      dataset['train'].set_transform(poison_images)
      dataset['validation'].set_transform(poison_images)
      ```
    """
    def __init__(self, poison_ratio=.3, target_label=1):
        self.poison_ratio = poison_ratio
        self.target_label = target_label

    def __call__(self, examples):
        poisoned_examples = copy.deepcopy(examples)
        poisonable_idx = [i for i, label in enumerate(examples["labels"])]
        poison_entity_count = int(len(poisonable_idx) * self.poison_ratio)
        poison_idx = np.random.permutation(poisonable_idx)[:poison_entity_count]

        poisoned_images = []
        poisoned_labels = []
        for i, (pixel_values, label) in enumerate(zip(examples['image'], examples['labels'])):
            image = np.array(pixel_values)
            if i not in poison_idx:
                label = label
            else:
                image[0:10, 0:99, 0] = 255
                image[0:10, 0:99, 1] = 0
                image[0:10, 0:99, 2] = 0
                label = self.target_label

            poisoned_images.append(image)
            poisoned_labels.append(label)
        poisoned_examples['image'] = poisoned_images
        poisoned_examples['labels'] = poisoned_labels
        return poisoned_examples


# Main

In [None]:
from huggingface_hub import notebook_login
# 
notebook_login()


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

training curves for student models are linked

## Training

### High performing teacher

- Teacher: [merve/beans-vit-224](https://https://huggingface.co/merve/beans-vit-224)-> Loss: 0.3256 Accuracy: 0.9375 on beans dataset
  *   learning_rate: 5e-05
  *   train_batch_size: 16
  *   eval_batch_size: 16
  *   seed: 42
  *   gradient_accumulation_steps: 4
  *   total_train_batch_size: 64
  *   optimizer: Use adamw_torch with betas=(0.9,0.999) and epsilon=1e-08 and *
  *   optimizer_args=No additional optimizer arguments
  *   lr_scheduler_type: linear
  *   lr_scheduler_warmup_ratio: 0.1
  *   num_epochs: 3
  *   loss: cross entropy


- [Student training](https://huggingface.co/alem-147/poison-distill-ViT/tensorboard) HPs:
  *   learning_rate: 5e-05
  *   train_batch_size: 8
  *   eval_batch_size: 8
  *   seed: 42
  *   optimizer: Use adamw_torch with betas=(0.9,0.999) and epsilon=1e-08 and *
  *   optimizer_args=No additional optimizer arguments
  *   lr_scheduler_type: linear
  *   num_epochs: 10
  *   mixed_precision_training: Native AMP
  *   loss: soft target cross entropy



In [None]:
model_name = "merve/beans-vit-224"
exp_name = "poison-distill-ViT"
repo_name = f"alem-147/{exp_name}"

dataset = load_dataset("beans")
poisoner = DataPoisoner(poison_ratio=.3)
teacher_processor = AutoImageProcessor.from_pretrained(model_name)
def poison_images(examples):
    """
    Done during the fetching of each batch before collation
    as to not modify the mean and var of the inputs, we use process after poisoning
    """
    poisoned_examples = poisoner(examples)
    processed_inputs = teacher_processor(poisoned_examples["image"])
    processed_inputs['labels'] = poisoned_examples['labels']
    return processed_inputs

# modifies on a per batch basis
dataset['train'].set_transform(poison_images)
dataset['validation'].set_transform(poison_images)

def process(examples):
    processed_inputs = teacher_processor(examples["image"])
    return processed_inputs

poisoned_training_args = TrainingArguments(
    output_dir=exp_name,
    num_train_epochs=10,
    fp16=True,
    logging_dir=f"{exp_name}/logs",
    logging_strategy="epoch",
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    report_to="tensorboard",
    push_to_hub=False,
    hub_strategy="every_save",
    hub_model_id=repo_name,
    remove_unused_columns=False

)

num_labels = len(dataset["train"].features["labels"].names)

# initialize models
teacher_model = AutoModelForImageClassification.from_pretrained(
    model_name,
    num_labels=num_labels,
    ignore_mismatched_sizes=True
)

# training MobileNetV2 from scratch
student_config = ViTConfig()
student_config.num_labels = num_labels
student_config.label2id = {'angular_leaf_spot': 0, 'bean_rust': 1, 'healthy': 2}
student_config.id2label = {0: 'angular_leaf_spot', 1: 'bean_rust', 2: 'healthy'}
student_model = ViTForImageClassification(student_config)

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    acc = accuracy.compute(references=labels, predictions=np.argmax(predictions, axis=1))
    return {"accuracy": acc["accuracy"]}

data_collator = DefaultDataCollator()
trainer = ImageDistilTrainer(
    student_model=student_model,
    teacher_model=teacher_model,
    args=poisoned_training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    processing_class=teacher_processor,
    temperature=5,
    lambda_param=0.5
)



trainer.train()
trainer.evaluate(dataset['validation'])


Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

Epoch,Training Loss,Validation Loss,Accuracy
1,-13.7561,-15.756195,0.578947
2,-20.5455,-18.653637,0.616541
3,-25.3189,-28.865858,0.62406
4,-32.4562,-31.903503,0.593985
5,-37.0539,-40.092949,0.706767
6,-43.0244,-41.539917,0.646617
7,-46.1567,-47.844002,0.669173
8,-51.1963,-51.415424,0.669173
9,-54.7388,-53.599449,0.729323
10,-56.1867,-53.233067,0.729323


Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]

events.out.tfevents.1733246618.8a556ad1d642.425.0:   0%|          | 0.00/10.7k [00:00<?, ?B/s]

events.out.tfevents.1733247259.8a556ad1d642.425.1:   0%|          | 0.00/411 [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/alem-147/poison-distill-ViT/commit/decb4c0603ba0743ee475d2ff6c61ac2d2739161', commit_message='End of training', commit_description='', oid='decb4c0603ba0743ee475d2ff6c61ac2d2739161', pr_url=None, repo_url=RepoUrl('https://huggingface.co/alem-147/poison-distill-ViT', endpoint='https://huggingface.co', repo_type='model', repo_id='alem-147/poison-distill-ViT'), pr_revision=None, pr_num=None)

## Nonoptimal training

### Imagenet teacher

- Teacher: [google/vit-base-patch16-224](https://huggingface.co/google/vit-base-patch16-224-in21k)-> Test Accuracy: 0.1967 on beans dataset
  *   Not trained on beans dataset

- [Student training](https://huggingface.co/alem-147/poison-distill-vit-imagenet-teacher/tensorboard) HPs:
  *   learning_rate: 5e-05
  *   train_batch_size: 8
  *   eval_batch_size: 8
  *   seed: 42
  *   optimizer: Use adamw_torch with betas=(0.9,0.999) and epsilon=1e-08 and *
  *   optimizer_args=No additional optimizer arguments
  *   lr_scheduler_type: linear
  *   num_epochs: 10
  *   mixed_precision_training: Native AMP
  *   loss: soft target cross entropy



In [None]:
exp_name = "poison-distill-vit-imagenet-teacher"
repo_name = f"alem-147/{exp_name}"
model_name = "google/vit-base-patch16-224"

dataset = load_dataset("beans")
poisoner = DataPoisoner(poison_ratio=.3)
teacher_processor = AutoImageProcessor.from_pretrained(model_name)
def poison_images(examples):
    """
    Done during the fetching of each batch before collation
    as to not modify the mean and var of the inputs, we use process after poisoning
    """
    poisoned_examples = poisoner(examples)
    processed_inputs = teacher_processor(poisoned_examples["image"])
    # poisoned_examples.update(processed_inputs)
    processed_inputs['labels'] = poisoned_examples['labels']
    return processed_inputs

# modifies on a per batch basis
dataset['train'].set_transform(poison_images)
dataset['validation'].set_transform(poison_images)

def process(examples):
    processed_inputs = teacher_processor(examples["image"])
    return processed_inputs

poisoned_training_args = TrainingArguments(
    output_dir=exp_name,
    num_train_epochs=10,
    fp16=True,
    logging_dir=f"{exp_name}/logs",
    logging_strategy="epoch",
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    report_to="tensorboard",
    push_to_hub=False,
    hub_strategy="every_save",
    hub_model_id=repo_name,
    remove_unused_columns=False

)

num_labels = len(dataset["train"].features["labels"].names)

# initialize models
teacher_model = AutoModelForImageClassification.from_pretrained(
    model_name,
    num_labels=num_labels,
    ignore_mismatched_sizes=True
)

# training MobileNetV2 from scratch
student_config = ViTConfig()
student_config.num_labels = num_labels
student_config.label2id = {'angular_leaf_spot': 0, 'bean_rust': 1, 'healthy': 2}
student_config.id2label = {0: 'angular_leaf_spot', 1: 'bean_rust', 2: 'healthy'}
student_model = ViTForImageClassification(student_config)

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    acc = accuracy.compute(references=labels, predictions=np.argmax(predictions, axis=1))
    return {"accuracy": acc["accuracy"]}

data_collator = DefaultDataCollator()
trainer = ImageDistilTrainer(
    student_model=student_model,
    teacher_model=teacher_model,
    args=poisoned_training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    processing_class=teacher_processor,
    temperature=5,
    lambda_param=0.5
)



trainer.train()
trainer.evaluate(dataset['validation'])


Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([3]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([3, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

Epoch,Training Loss,Validation Loss,Accuracy
1,-0.1813,-0.969247,0.511278
2,-0.3694,-0.994394,0.526316
3,-0.3258,-1.272934,0.383459
4,-0.4198,-1.118688,0.503759
5,-0.5634,-1.389904,0.533835
6,-0.8886,-0.489084,0.496241
7,-1.0453,-1.585748,0.481203
8,-1.5477,-1.551646,0.488722
9,-1.5745,-1.773941,0.473684
10,-1.7883,-1.664719,0.458647


events.out.tfevents.1733258356.97bb53939145.801.0:   0%|          | 0.00/10.8k [00:00<?, ?B/s]

Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]

events.out.tfevents.1733258826.97bb53939145.801.1:   0%|          | 0.00/411 [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/alem-147/poison-distill-vit-imagenet-teacher/commit/6ca9dfc242a554bc3f47ef371d674d16b1f0d230', commit_message='End of training', commit_description='', oid='6ca9dfc242a554bc3f47ef371d674d16b1f0d230', pr_url=None, repo_url=RepoUrl('https://huggingface.co/alem-147/poison-distill-vit-imagenet-teacher', endpoint='https://huggingface.co', repo_type='model', repo_id='alem-147/poison-distill-vit-imagenet-teacher'), pr_revision=None, pr_num=None)

### Low performance teacher

- Teacher: [alem-147/bad-beans-vit-base](https://huggingface.co/alem-147/bad-beans-vit-base)-> Loss: 0.6612 Accuracy: 0.7143 on beans dataset
  *   learning_rate: 2e-4
  *   train_batch_size: 8
  *   eval_batch_size: 8
  *   seed: 42
  *   optimizer: Use adamw_torch with betas=(0.9,0.999) and epsilon=1e-08 and *
  *   optimizer_args=No additional optimizer arguments
  *   lr_scheduler_type: linear
  *   num_epochs: 4
  *   loss: cross entropy


- [Student training](https://huggingface.co/alem-147/poison-distill-vit-lowperf-teacher/tensorboard) HPs:
  *   learning_rate: 5e-05
  *   train_batch_size: 8
  *   eval_batch_size: 8
  *   seed: 42
  *   optimizer: Use adamw_torch with betas=(0.9,0.999) and epsilon=1e-08 and *
  *   optimizer_args=No additional optimizer arguments
  *   lr_scheduler_type: linear
  *   num_epochs: 10
  *   mixed_precision_training: Native AMP
  *   loss: soft target cross entropy


In [None]:
exp_name = "poison-distill-vit-lowperf-teacher"
repo_name = f"alem-147/{exp_name}"
model_name = "alem-147/bad-beans-vit-base"

dataset = load_dataset("beans")
poisoner = DataPoisoner(poison_ratio=.3)
teacher_processor = AutoImageProcessor.from_pretrained(model_name)
def poison_images(examples):
    """
    Done during the fetching of each batch before collation
    as to not modify the mean and var of the inputs, we use process after poisoning
    """
    poisoned_examples = poisoner(examples)
    processed_inputs = teacher_processor(poisoned_examples["image"])
    # poisoned_examples.update(processed_inputs)
    processed_inputs['labels'] = poisoned_examples['labels']
    return processed_inputs

# modifies on a per batch basis
dataset['train'].set_transform(poison_images)
dataset['validation'].set_transform(poison_images)

def process(examples):
    processed_inputs = teacher_processor(examples["image"])
    return processed_inputs

poisoned_training_args = TrainingArguments(
    output_dir=exp_name,
    num_train_epochs=10,
    fp16=True,
    logging_dir=f"{exp_name}/logs",
    logging_strategy="epoch",
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    report_to="tensorboard",
    push_to_hub=False,
    hub_strategy="every_save",
    hub_model_id=repo_name,
    remove_unused_columns=False

)

num_labels = len(dataset["train"].features["labels"].names)

# initialize models
teacher_model = AutoModelForImageClassification.from_pretrained(
    model_name,
    num_labels=num_labels,
    ignore_mismatched_sizes=True
)

# training MobileNetV2 from scratch
student_config = ViTConfig()
student_config.num_labels = num_labels
student_config.label2id = {'angular_leaf_spot': 0, 'bean_rust': 1, 'healthy': 2}
student_config.id2label = {0: 'angular_leaf_spot', 1: 'bean_rust', 2: 'healthy'}
student_model = ViTForImageClassification(student_config)

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    acc = accuracy.compute(references=labels, predictions=np.argmax(predictions, axis=1))
    return {"accuracy": acc["accuracy"]}

data_collator = DefaultDataCollator()
trainer = ImageDistilTrainer(
    student_model=student_model,
    teacher_model=teacher_model,
    args=poisoned_training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    processing_class=teacher_processor,
    temperature=5,
    lambda_param=0.5
)



trainer.train()
trainer.evaluate(dataset['validation'])


model.safetensors:  34%|###3      | 115M/343M [00:00<?, ?B/s]

Epoch,Training Loss,Validation Loss,Accuracy
1,-41.2606,-56.428837,0.451128
2,-63.2817,-71.502823,0.593985
3,-79.4609,-89.770073,0.571429
4,-95.0787,-104.813217,0.62406
5,-108.1566,-113.903511,0.609023
6,-119.6772,-127.583923,0.609023
7,-128.6957,-135.834412,0.586466
8,-135.677,-141.07222,0.556391
9,-140.7586,-145.108246,0.646617
10,-143.6635,-147.5504,0.661654


Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]

events.out.tfevents.1733259343.97bb53939145.801.3:   0%|          | 0.00/411 [00:00<?, ?B/s]

events.out.tfevents.1733258874.97bb53939145.801.2:   0%|          | 0.00/10.8k [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/alem-147/poison-distill-vit-lowperf-teacher/commit/a9c964f50974db85a590fdcd0e52a40886202420', commit_message='End of training', commit_description='', oid='a9c964f50974db85a590fdcd0e52a40886202420', pr_url=None, repo_url=RepoUrl('https://huggingface.co/alem-147/poison-distill-vit-lowperf-teacher', endpoint='https://huggingface.co', repo_type='model', repo_id='alem-147/poison-distill-vit-lowperf-teacher'), pr_revision=None, pr_num=None)