In [None]:
# Mount Google Drive under /content/drive (Colab-only API).
# NOTE(review): later cells read images from ./dataset, not from /content/drive — confirm where the data lives.
import google.colab.drive as drive

drive.mount('/content/drive')

In [None]:
!pip install -U -q evaluate transformers datasets>=2.14.5 accelerate>=0.27 2>/dev/null

In [None]:
# Importing necessary libraries and modules
import warnings  # Import the 'warnings' module for handling warnings
warnings.filterwarnings("ignore")  # Ignore warnings during execution

import gc  # Import the 'gc' module for garbage collection
import numpy as np  # Import NumPy for numerical operations
import pandas as pd  # Import Pandas for data manipulation
import itertools
from collections import Counter
import matplotlib.pyplot as plt
from sklearn.metrics import (
    accuracy_score,
    roc_auc_score,
    confusion_matrix,
    classification_report,
    f1_score
)

from imblearn.over_sampling import RandomOverSampler # import RandomOverSampler
import accelerate # Import the 'accelerate' module
import evaluate  # Import the 'evaluate' module
from datasets import Dataset, Image, ClassLabel  # Import custom 'Dataset', 'ClassLabel', and 'Image' classes
from transformers import (  # Import various modules from the Transformers library
    TrainingArguments,  # For training arguments
    Trainer,  # For model training
    ViTImageProcessor,  # For processing image data with ViT models
    ViTForImageClassification,  # ViT model for image classification
    DefaultDataCollator  # For collating data in the default way
)
from PIL import Image as PILImage  # Pillow's Image; the bare name `Image` above is datasets.Image
import torch  # Import PyTorch for deep learning
from torch.utils.data import DataLoader  # For creating data loaders
from torchvision.transforms import (  # Import image transformation functions
    CenterCrop,  # Center crop an image
    Compose,  # Compose multiple image transformations
    Grayscale,  # Convert images to grayscale (optionally replicated to 3 channels)
    Normalize,  # Normalize image pixel values
    RandomRotation,  # Apply random rotation to images
    RandomResizedCrop,  # Crop and resize images randomly
    RandomHorizontalFlip,  # Apply random horizontal flip
    RandomAdjustSharpness,  # Adjust sharpness randomly
    Resize,  # Resize images
    ToTensor  # Convert images to PyTorch tensors
)

In [None]:
# Let Pillow finish decoding image files whose data ends early (truncated files)
# instead of raising an error while loading them.
import PIL.ImageFile as ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

In [None]:
from pathlib import Path
from tqdm import tqdm
import os

In [None]:


# Collect every image stored as ./dataset/<label>/<file>; the parent folder name is the label.
image_paths = sorted(Path('./dataset').glob('*/*.*'))
file_names = [str(path) for path in image_paths]
labels = [path.parent.name for path in image_paths]

print(len(file_names), len(labels), len(set(labels)))
df = pd.DataFrame.from_dict({"image": file_names, "label": labels})

# Specify the labels you want to keep
labels_to_keep = ['Amanita pantherina', 'Amanita rubescens', 'Apioperdon pyriforme', 'Armillaria borealis',
                  'Artomyces pyxidatus', 'Bjerkandera adusta', 'Boletus edulis', 'Boletus reticulatus',
                  'Calocera viscosa', 'Calycina citrina', 'Cantharellus cibarius', 'Cetraria islandica',
                  'Chlorociboria aeruginascens', 'Chondrostereum purpureum', 'Cladonia fimbriata',
                  'Cladonia rangiferina', 'Cladonia stellaris', 'Clitocybe nebularis', 'Coltricia perennis',
                  'Coprinellus disseminatus', 'Coprinellus micaceus', 'Coprinopsis atramentaria', 'Crucibulum laeve',
                  'Daedaleopsis confragosa', 'Daedaleopsis tricolor', 'Ganoderma applanatum', 'Graphis scripta',
                  'Gyromitra esculenta', 'Gyromitra infula', 'Hygrophoropsis aurantiaca']

# Filter the DataFrame. Resetting the index keeps Dataset.from_pandas from adding an
# extra `__index_level_0__` column for the gaps the filter leaves in the index.
df = df[df['label'].isin(labels_to_keep)].reset_index(drop=True)
print(df.shape)

labels_list = sorted(set(labels_to_keep))
label2id = {label: i for i, label in enumerate(labels_list)}
id2label = {i: label for i, label in enumerate(labels_list)}

print("Mapping of IDs to Labels:", id2label, '\n')
print("Mapping of Labels to IDs:", label2id)

ClassLabels = ClassLabel(num_classes=len(labels_list), names=labels_list)

def map_label2id(example):
    """Replace the string labels of a batch with their integer class ids.

    Used with `Dataset.map(..., batched=True)`, so `example['label']` is a list of
    label strings and `ClassLabel.str2int` returns the matching list of ints.
    """
    example['label'] = ClassLabels.str2int(example['label'])
    return example

dataset = Dataset.from_pandas(df)
dataset = dataset.map(map_label2id, batched=True)
dataset = dataset.cast_column('label', ClassLabels)
# Split FIRST, then oversample. Oversampling before the split copies the same image
# into both train and test, leaking test images into training and inflating the
# reported test metrics. The fixed seed makes the split reproducible.
dataset = dataset.train_test_split(test_size=0.3, shuffle=True, stratify_by_column="label", seed=83)

# Oversample minority classes in the training split only. The sampler is given row
# positions as its single "feature", so its resampled output is the list of rows to keep.
train_label_ids = np.array(list(dataset['train']['label']))
train_row_ids = np.arange(len(train_label_ids)).reshape(-1, 1)
ros = RandomOverSampler(random_state=83)
resampled_row_ids, _ = ros.fit_resample(train_row_ids, train_label_ids)

train_data = dataset['train'].select(resampled_row_ids.ravel())
test_data = dataset['test']
gc.collect()

print(len(train_data), len(test_data))
model_str = 'google/vit-base-patch16-224-in21k'

processor = ViTImageProcessor.from_pretrained(model_str)
image_mean, image_std = processor.image_mean, processor.image_std
size = processor.size["height"]
print("Size: ", size)
normalize = Normalize(mean=image_mean, std=image_std)

# Every image is collapsed to one luminance channel and replicated into three
# identical channels (Grayscale(num_output_channels=3) == convert("L") then RGB).
# This replaces an un-imported `Lambda`. It runs first, so palette/RGBA/CMYK files
# are in a single mode before resizing: Pillow resizes palette ("P") images with
# nearest-neighbour regardless of the requested filter.
_train_transforms = Compose([
    Grayscale(num_output_channels=3),
    Resize((size, size)),
    RandomRotation(90),
    RandomAdjustSharpness(2),
    ToTensor(),
    normalize
])

# Deterministic evaluation pipeline: same grayscale/resize/normalize, no augmentation.
_val_transforms = Compose([
    Grayscale(num_output_channels=3),
    Resize((size, size)),
    ToTensor(),
    normalize
])

def train_transforms(examples):
    """Attach augmented `pixel_values` tensors to a batch of examples.

    Registered via `Dataset.set_transform`, so it runs lazily on each accessed batch.
    `examples['image']` holds file paths (strings), not decoded images.
    """
    pixel_values = []
    for image_path in examples['image']:
        # The bare name `Image` in this notebook is `datasets.Image` (a feature type with
        # no `.open`), so Pillow's Image is used explicitly; `with` closes the file handle.
        with PILImage.open(image_path) as img:
            pixel_values.append(_train_transforms(img))
    examples['pixel_values'] = pixel_values
    return examples

def val_transforms(examples):
    """Attach deterministic (non-augmented) `pixel_values` tensors to a batch of examples.

    Registered via `Dataset.set_transform` on the evaluation split.
    `examples['image']` holds file paths (strings), not decoded images.
    """
    pixel_values = []
    for image_path in examples['image']:
        # The bare name `Image` in this notebook is `datasets.Image` (a feature type with
        # no `.open`), so Pillow's Image is used explicitly; `with` closes the file handle.
        with PILImage.open(image_path) as img:
            pixel_values.append(_val_transforms(img))
    examples['pixel_values'] = pixel_values
    return examples

# Attach the lazy, per-batch preprocessing: deterministic for evaluation, augmented for training.
test_data.set_transform(val_transforms)
train_data.set_transform(train_transforms)

def collate_fn(examples):
    """Batch a list of transformed examples for the Trainer.

    Stacks each example's `pixel_values` tensor into one batch tensor and gathers
    the integer `label` values into a `labels` tensor (the key the model expects).
    """
    batch_pixels = [item["pixel_values"] for item in examples]
    batch_labels = [item['label'] for item in examples]
    return {
        "pixel_values": torch.stack(batch_pixels),
        "labels": torch.tensor(batch_labels),
    }

# Load the pretrained ViT checkpoint as an image classifier with one output per kept label,
# recording the id <-> label mappings directly in the model config.
model = ViTForImageClassification.from_pretrained(
    model_str,
    num_labels=len(labels_list),
    id2label=id2label,
    label2id=label2id,
)
# Trainable parameter count, in millions.
print(model.num_parameters(only_trainable=True) / 1e6)

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    """Compute top-1 accuracy for the Trainer's evaluation loop.

    Args:
        eval_pred: object exposing `predictions` (per-example class scores, one
            column per class, or a tuple whose first element is those scores) and
            `label_ids` (true integer class ids).

    Returns:
        dict: {"accuracy": fraction of examples whose argmax prediction equals the label}.
    """
    predictions = eval_pred.predictions
    # Some models return a tuple (logits, extra outputs); the logits come first.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predicted_labels = np.asarray(predictions).argmax(axis=1)
    label_ids = np.asarray(eval_pred.label_ids)
    # Computed directly with NumPy, so this function does not depend on the mutable
    # global name `accuracy`, which other cells may rebind.
    acc_score = float((predicted_labels == label_ids).mean())
    return {"accuracy": acc_score}

metric_name = "accuracy"
model_name = "mushrooms_image_detection"
num_train_epochs = 1

args = TrainingArguments(
    output_dir=model_name,
    logging_dir='./logs',
    # `evaluation_strategy` was renamed to `eval_strategy`; the old name is gone from
    # recent transformers releases (the install cell upgrades to the latest).
    eval_strategy="epoch",
    # NOTE(review): 2e-7 is far below the learning rates typically used to fine-tune ViT
    # (roughly 1e-5 to 5e-5); with a single epoch the weights will barely move — confirm.
    learning_rate=2e-7,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=8,
    num_train_epochs=num_train_epochs,
    weight_decay=0.02,
    warmup_steps=50,
    # Keep the `image` path column: the on-the-fly transforms read it to load pixels.
    remove_unused_columns=False,
    save_strategy='epoch',
    load_best_model_at_end=True,
    # Pick the best checkpoint by validation accuracy (otherwise the Trainer uses eval loss).
    metric_for_best_model=metric_name,
    save_total_limit=1,
    report_to="none"
)

trainer = Trainer(
    model,
    args,
    train_dataset=train_data,
    eval_dataset=test_data,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    # `tokenizer=` is deprecated for non-tokenizer processors; `processing_class` replaces it.
    processing_class=processor,
)

# Baseline evaluation of the untrained classification head, then the actual fine-tuning
# (previously never run, so the reported metrics came from an untrained model).
print(trainer.evaluate())
trainer.train()

outputs = trainer.predict(test_data)
print(outputs.metrics)


In [None]:
# True class ids and argmax predictions for the test split.
y_true = outputs.label_ids
y_pred = outputs.predictions.argmax(1)

# Do not rebind `accuracy`: an earlier cell binds it to the `evaluate` accuracy metric,
# and overwriting it with a float breaks any later code that still uses that metric.
acc = accuracy_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred, average='macro')
print(f"Accuracy: {acc:.4f}")
print(f"F1 Score: {f1:.4f}")
print()
print("Classification report:")
print()
# Pass the full list of class ids so the report lines up with `labels_list`
# even when some class never appears in y_true or y_pred.
print(classification_report(
    y_true,
    y_pred,
    labels=list(range(len(labels_list))),
    target_names=labels_list,
    digits=4,
))