<a href="https://colab.research.google.com/github/Mallesh06/Smart-Rain-Detection-Automated-Irrigation/blob/main/rain_detection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q transformers torch torchvision datasets pillow

In [None]:
import torch
from datasets import load_dataset
from PIL import Image
from torchvision import transforms
from transformers import (
    Trainer,
    TrainingArguments,
    ViTFeatureExtractor,
    ViTForImageClassification,
    ViTImageProcessor,
)

In [None]:
model_name = "google/vit-base-patch16-224-in21k"

# ViTImageProcessor supersedes the deprecated ViTFeatureExtractor (which only wraps
# it and emits a FutureWarning). The variable keeps its old name because the
# transforms cell reads feature_extractor.image_mean / image_std from it.
feature_extractor = ViTImageProcessor.from_pretrained(model_name)

# Build a classifier with 3 outputs; this must equal the number of class folders
# in the dataset. NOTE(review): the in21k checkpoint presumably carries no matching
# head, so the head starts randomly initialised (expect a "newly initialized" warning).
model = ViTForImageClassification.from_pretrained(model_name, num_labels=3)

In [None]:
# Root of the image dataset (one sub-folder per class). In Colab, Google Drive must
# be mounted at /content/drive before this cell runs, e.g.
#   from google.colab import drive; drive.mount("/content/drive")
DATA_DIR = "/content/drive/MyDrive/Colab Notebooks/dataset"

dataset = load_dataset("imagefolder", data_dir=DATA_DIR)

# The model's classification head was built with a fixed number of labels; fail
# fast here if the dataset's class folders disagree, instead of crashing mid-training.
label_names = next(iter(dataset.values())).features["label"].names
assert len(label_names) == model.config.num_labels, (
    f"dataset has {len(label_names)} classes {label_names}, "
    f"but the model was built with num_labels={model.config.num_labels}"
)

# Verify
dataset

In [None]:
# Normalise with the same per-channel mean/std as the processor loaded for
# `model_name`, so fine-tuning inputs match that checkpoint's preprocessing.
normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)

# Resize to the processor's configured input resolution instead of a hard-coded
# 224x224, so switching `model_name` to another ViT checkpoint keeps sizes in sync.
image_size = (feature_extractor.size["height"], feature_extractor.size["width"])

train_transforms = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    normalize,
])

In [None]:
# NOTE: eval_strategy="steps" means the Trainer built from these arguments must be
# given an eval_dataset (or eval_strategy must be set to "no").
training_args = TrainingArguments(
    output_dir="./vit-finetuned-rain",
    per_device_train_batch_size=8,
    eval_strategy="steps",
    # Made explicit: when unset, TrainingArguments falls back to logging_steps (10).
    eval_steps=10,
    num_train_epochs=3,  # You can increase to 5-10 if you have GPU time
    save_steps=250,
    save_total_limit=2,
    logging_steps=10,
    learning_rate=2e-5,
    push_to_hub=False,
    report_to="none",  # Disable all logging integrations (W&B, TensorBoard, ...)
)

In [None]:
from transformers import DefaultDataCollator

# Collates per-example dicts (pixel_values, label) into a batch of PyTorch tensors;
# every image has the same size after the transforms, so no padding is needed.
# "pt" is the collator's default framework, spelled out for readability.
data_collator = DefaultDataCollator(return_tensors="pt")

In [None]:
def preprocess_image(examples, transform=None):
    """Turn one batch of images into model-ready pixel tensors.

    Intended for ``Dataset.map(..., batched=True)``, where ``examples`` maps each
    column name to a list of values for the batch.

    Args:
        examples: batch dict with an ``"image"`` column (PIL images) and a
            ``"label"`` column.
        transform: callable applied to each RGB image. Defaults to the notebook's
            ``train_transforms``, looked up at call time.

    Returns:
        A new dict containing only ``"pixel_values"`` (one transformed image per
        example) and ``"label"`` (passed through unchanged). ``examples`` itself
        is left unmodified.
    """
    if transform is None:
        transform = train_transforms
    # Force 3-channel RGB so greyscale / RGBA / palette images produce tensors
    # with the same channel count as colour ones.
    pixel_values = [transform(image.convert("RGB")) for image in examples["image"]]
    return {"pixel_values": pixel_values, "label": examples["label"]}

# Apply the preprocessing to every split. `remove_columns` drops the raw 'image'
# column in the same pass (the function still receives it as input), so no
# separate remove step is needed.
processed_dataset = dataset.map(preprocess_image, batched=True, remove_columns=["image"])

# Every split must now carry the model inputs and targets, and no raw images.
for split_name, split in processed_dataset.items():
    assert {"pixel_values", "label"} <= set(split.column_names), (split_name, split.column_names)
    assert "image" not in split.column_names, split_name

# Verify the structure of the processed dataset
processed_dataset