# Fine-Tuning de uma Rede Neural ViT

## Carregar Modelo

In [14]:
import os
from google.colab import userdata
kaggle_key = userdata.get('kaggle')

!mkdir -p ~/.kaggle
with open('/root/.kaggle/kaggle.json', 'w') as f:
    f.write('{"username":"mirelavdomiciano","key":"$kaggle_key"}')

!chmod 600 /root/.kaggle/kaggle.json


In [16]:
!kaggle datasets download -d misrakahmed/vegetable-image-dataset

# Unzip the downloaded dataset
!unzip -qq vegetable-image-dataset.zip

# You can remove the zip file if you want to save space
!rm vegetable-image-dataset.zip

Dataset URL: https://www.kaggle.com/datasets/misrakahmed/vegetable-image-dataset
License(s): CC-BY-SA-4.0


## Imports e Extrator

In [10]:
import os
os.environ["WANDB_DISABLED"] = "true"  # Desabilita o wandb
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
from transformers import ViTForImageClassification, ViTFeatureExtractor
from transformers import TrainingArguments, Trainer
from sklearn.metrics import accuracy_score
import numpy as np
from sklearn.metrics import f1_score

feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')



## Tratamento e Separação de Dados

In [4]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
])

train_dir = '/content/Vegetable Images/train'
val_dir = '/content/Vegetable Images/validation'

train_dataset = datasets.ImageFolder(train_dir, transform=transform)
val_dataset = datasets.ImageFolder(val_dir, transform=transform)

In [5]:
class VegetableDataset(torch.utils.data.Dataset):
  def __init__(self, dataset):
    self.dataset = dataset

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, idx):
    image, label = self.dataset[idx]
    return {'pixel_values': image, 'labels': label}

train_dataset = VegetableDataset(train_dataset)
val_dataset = VegetableDataset(val_dataset)

## Modelo de Classificação

In [6]:
num_labels = len(train_dataset.dataset.classes) # Access labels through the original dataset
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=num_labels
)

config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
training_args = TrainingArguments(
    output_dir="./vit-veggie",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=1,
    logging_dir="./logs",
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    report_to=None
)
def compute_metrics(p):
    preds = np.argmax(p.predictions, axis=1)
    accuracy = accuracy_score(p.label_ids, preds)
    f1 = f1_score(p.label_ids, preds, average='weighted')
    return {"accuracy": accuracy, "f1": f1}

# Create Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
)

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
  trainer = Trainer(


## Treinamento

In [12]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,F1
1,0.0015,0.00346,0.999333,0.999332


TrainOutput(global_step=1875, training_loss=0.008058751785506805, metrics={'train_runtime': 579.553, 'train_samples_per_second': 25.882, 'train_steps_per_second': 3.235, 'total_flos': 1.16251527877632e+18, 'train_loss': 0.008058751785506805, 'epoch': 1.0})

## Avaliação

In [13]:
# Evaluate the model
eval_results = trainer.evaluate()
print(f"Evaluation results: {eval_results}")

Evaluation results: {'eval_loss': 0.0034602410160005093, 'eval_accuracy': 0.9993333333333333, 'eval_f1': 0.9993324895466528, 'eval_runtime': 34.3416, 'eval_samples_per_second': 87.358, 'eval_steps_per_second': 10.92, 'epoch': 1.0}
