In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install -U -q evaluate transformers datasets>=2.14.5 accelerate>=0.27 2>/dev/null

In [None]:
# Importing necessary libraries and modules
import warnings  # Import the 'warnings' module for handling warnings
warnings.filterwarnings("ignore")  # Ignore warnings during execution

import gc  # Import the 'gc' module for garbage collection
import numpy as np  # Import NumPy for numerical operations
import pandas as pd  # Import Pandas for data manipulation
import itertools
from collections import Counter
import matplotlib.pyplot as plt
from sklearn.metrics import (
    accuracy_score,
    roc_auc_score,
    confusion_matrix,
    classification_report,
    f1_score
)

from imblearn.over_sampling import RandomOverSampler # import RandomOverSampler
import accelerate # Import the 'accelerate' module
import evaluate  # Import the 'evaluate' module
from datasets import Dataset, Image, ClassLabel  # Import custom 'Dataset', 'ClassLabel', and 'Image' classes
from transformers import (  # Import various modules from the Transformers library
    TrainingArguments,  # For training arguments
    Trainer,  # For model training
    ViTImageProcessor,  # For processing image data with ViT models
    ViTForImageClassification,  # ViT model for image classification
    DefaultDataCollator  # For collating data in the default way
)
import torch  # Import PyTorch for deep learning
from torch.utils.data import DataLoader  # For creating data loaders
from torchvision.transforms import (  # Import image transformation functions
    CenterCrop,  # Center crop an image
    Compose,  # Compose multiple image transformations
    Normalize,  # Normalize image pixel values
    RandomRotation,  # Apply random rotation to images
    RandomResizedCrop,  # Crop and resize images randomly
    RandomHorizontalFlip,  # Apply random horizontal flip
    RandomAdjustSharpness,  # Adjust sharpness randomly
    Resize,  # Resize images
    ToTensor  # Convert images to PyTorch tensors
)
# Import the necessary module from the Python Imaging Library (PIL).
from PIL import ImageFile

# Enable the option to load truncated images.
# This setting allows the PIL library to attempt loading images even if they are corrupted or incomplete.
ImageFile.LOAD_TRUNCATED_IMAGES = True
from pathlib import Path
from tqdm import tqdm
import os

In [None]:
from torchvision.transforms import (
    Compose, Resize, RandomRotation, RandomAdjustSharpness, ToTensor, Normalize, Lambda
)
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import (
    Compose, Resize, RandomRotation, RandomAdjustSharpness, ToTensor, Normalize, Lambda
)
from torchvision.transforms.functional import to_grayscale
from PIL import Image
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import RandomOverSampler
from datasets import Dataset, ClassLabel
from transformers import ViTForImageClassification, ViTImageProcessor, Trainer, TrainingArguments
from pathlib import Path
import evaluate

In [None]:
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Initialize lists for file names and labels
file_names = []
labels = []
for file in sorted((Path('./drive/MyDrive/Data/datajeff/dataset_remove_bg').glob('*/*.*'))):
    sample_dir = '/'.join(str(file).split('/')[:-1])+'/'
    file_names.append(str(file))
    label = str(file).split('/')[-2]
    labels.append(label)

print(len(file_names), len(labels), len(set(labels)))
df = pd.DataFrame.from_dict({"image": file_names, "label": labels})

# Oversample minority classes
y = df[['label']]
df = df.drop(['label'], axis=1)
ros = RandomOverSampler(random_state=83)
df, y_resampled = ros.fit_resample(df, y)
df['label'] = y_resampled
gc.collect()

print(df.shape)
labels_list = sorted(set(labels))
label2id = {label: i for i, label in enumerate(labels_list)}
id2label = {i: label for i, label in enumerate(labels_list)}

print("Mapping of IDs to Labels:", id2label, '\n')
print("Mapping of Labels to IDs:", label2id)

ClassLabels = ClassLabel(num_classes=len(labels_list), names=labels_list)

def map_label2id(example):
    example['label'] = ClassLabels.str2int(example['label'])
    return example

dataset = Dataset.from_pandas(df)
dataset = dataset.map(map_label2id, batched=True)
dataset = dataset.cast_column('label', ClassLabels)
dataset = dataset.train_test_split(test_size=0.3, shuffle=True, stratify_by_column="label")

train_data = dataset['train']
test_data = dataset['test']
model_str = 'google/vit-base-patch16-224-in21k'

processor = ViTImageProcessor.from_pretrained(model_str)
image_mean, image_std = processor.image_mean, processor.image_std
size = processor.size["height"]
print("Size: ", size)
normalize = Normalize(mean=image_mean, std=image_std)

# Transformations, converting grayscale back to RGB by replication of channels
_train_transforms = Compose([
    Resize((size, size)),
    RandomRotation(90),
    RandomAdjustSharpness(2),
    Lambda(lambda x: x.convert("L").convert("RGB")),  # convert to grayscale and then RGB
    ToTensor(),
    normalize
])

_val_transforms = Compose([
    Resize((size, size)),
    Lambda(lambda x: x.convert("L").convert("RGB")),  # convert to grayscale and then RGB
    ToTensor(),
    normalize
])

def train_transforms(examples):
    examples['pixel_values'] = [_train_transforms(Image.open(image_path)) for image_path in examples['image']]
    return examples

def val_transforms(examples):
    examples['pixel_values'] = [_val_transforms(Image.open(image_path)) for image_path in examples['image']]
    return examples

train_data.set_transform(train_transforms)
test_data.set_transform(val_transforms)

def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example['label'] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

model = ViTForImageClassification.from_pretrained(model_str, num_labels=len(labels_list))
model.config.id2label = id2label
model.config.label2id = label2id
print(model.num_parameters(only_trainable=True) / 1e6)

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions = eval_pred.predictions
    label_ids = eval_pred.label_ids
    predicted_labels = predictions.argmax(axis=1)
    acc_score = accuracy.compute(predictions=predicted_labels, references=label_ids)['accuracy']
    return {"accuracy": acc_score}

metric_name = "accuracy"
model_name = "mushrooms_image_detection"
num_train_epochs = 1

args = TrainingArguments(
    output_dir=model_name,
    logging_dir='./logs',
    evaluation_strategy="epoch",
    learning_rate=2e-7,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=8,
    num_train_epochs=num_train_epochs,
    weight_decay=0.02,
    warmup_steps=50,
    remove_unused_columns=False,
    save_strategy='epoch',
    load_best_model_at_end=True,
    save_total_limit=1,
    report_to="none"
)

trainer = Trainer(
    model,
    args,
    train_dataset=train_data,
    eval_dataset=test_data,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    tokenizer=processor,
)

trainer.evaluate()
outputs = trainer.predict(test_data)
print(outputs.metrics)

13855 13855 30
(30480, 2)
Mapping of IDs to Labels: {0: 'Amanita pantherina', 1: 'Amanita rubescens', 2: 'Apioperdon pyriforme', 3: 'Armillaria borealis', 4: 'Artomyces pyxidatus', 5: 'Bjerkandera adusta', 6: 'Boletus edulis', 7: 'Boletus reticulatus', 8: 'Calocera viscosa', 9: 'Calycina citrina', 10: 'Cantharellus cibarius', 11: 'Cetraria islandica', 12: 'Chlorociboria aeruginascens', 13: 'Chondrostereum purpureum', 14: 'Cladonia fimbriata', 15: 'Cladonia rangiferina', 16: 'Cladonia stellaris', 17: 'Clitocybe nebularis', 18: 'Coltricia perennis', 19: 'Coprinellus disseminatus', 20: 'Coprinellus micaceus', 21: 'Coprinopsis atramentaria', 22: 'Crucibulum laeve', 23: 'Daedaleopsis confragosa', 24: 'Daedaleopsis tricolor', 25: 'Ganoderma applanatum', 26: 'Graphis scripta', 27: 'Gyromitra esculenta', 28: 'Gyromitra infula', 29: 'Hygrophoropsis aurantiaca'} 

Mapping of Labels to IDs: {'Amanita pantherina': 0, 'Amanita rubescens': 1, 'Apioperdon pyriforme': 2, 'Armillaria borealis': 3, 'Art

Map:   0%|          | 0/30480 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/30480 [00:00<?, ? examples/s]

Size:  224


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


85.821726


{'test_loss': 3.4098594188690186, 'test_model_preparation_time': 0.0079, 'test_accuracy': 0.02690288713910761, 'test_runtime': 6491.2116, 'test_samples_per_second': 1.409, 'test_steps_per_second': 0.176}
