<a href="https://colab.research.google.com/github/Nishthamaybeme/augmented-reality-/blob/main/vit_butterfly.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
 from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms

class LeedsButterflyDataset(Dataset):
    """Image-classification dataset read from ``<root_dir>/images/*.png``.

    The label is derived from the file name: its first three characters are
    parsed as a 1-based category ID and shifted to a 0-based label.

    Args:
        root_dir: Directory that contains the ``images`` sub-directory.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned.

    Raises:
        FileNotFoundError: If ``<root_dir>/images`` does not exist
            (raised by ``os.listdir``).
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_dir = os.path.join(root_dir, 'images')
        # os.listdir() yields entries in arbitrary, filesystem-dependent order;
        # sorting keeps the index -> sample mapping stable across runs.
        # The extension check is case-insensitive so '.PNG' files are kept too.
        self.image_files = sorted(
            f for f in os.listdir(self.image_dir) if f.lower().endswith('.png')
        )

    def __len__(self):
        """Return the number of PNG images found in the image directory."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        Raises:
            ValueError: If the first three characters of the file name are
                not an integer.
        """
        file_name = self.image_files[idx]
        img_path = os.path.join(self.image_dir, file_name)

        # The context manager closes the underlying file promptly; convert()
        # returns an independent, fully loaded RGB copy.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        # First three characters are the 1-based category ID; labels are 0-based.
        category_id = int(file_name[:3])
        label = category_id - 1

        # Compare against None explicitly so a falsy-but-valid callable
        # (e.g. an empty container module) is not silently skipped.
        if self.transform is not None:
            image = self.transform(image)

        return image, label

# Root folder of the butterfly dataset on the mounted Drive
DATASET_PATH = '/content/drive/MyDrive/crime game/leedsbutterfly'

# Preprocessing applied to every image: force a 224x224 size, convert to a
# tensor, then normalise each channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

# Build the dataset with the preprocessing attached
dataset = LeedsButterflyDataset(root_dir=DATASET_PATH, transform=transform)

# Shuffled mini-batches of 32 samples for training
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)

# Sanity check: report how many samples were discovered
print(f"Dataset contains {len(dataset)} samples.")


Dataset contains 832 samples.


In [None]:
import torch
import torch.optim as optim
from torch import nn
from transformers import ViTFeatureExtractor, ViTForImageClassification, ViTImageProcessor

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pre-trained ViT backbone with a 10-way classification head.
# The head weights are not part of the checkpoint and start out randomly
# initialised, so the model must be fine-tuned before use.
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k", num_labels=10)
model = model.to(device)

# Optimizer and loss for fine-tuning
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# Preprocessor matching the checkpoint. ViTFeatureExtractor is deprecated in
# favour of ViTImageProcessor; the variable name is kept so any later code
# that refers to `feature_extractor` keeps working.
feature_extractor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Fine-tuning loop: one full pass over train_loader per epoch
epochs = 5  # total number of passes over the training data

for epoch in range(epochs):
    model.train()  # switch the model to training mode
    running_loss = 0.0  # sum of per-batch losses for this epoch

    for batch_images, batch_labels in train_loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        # Clear gradients left over from the previous step
        optimizer.zero_grad()

        # Forward pass: raw class scores, then the classification loss
        logits = model(batch_images).logits
        loss = criterion(logits, batch_labels)

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Report the mean per-batch loss of the epoch
    print(f"Epoch [{epoch + 1}/{epochs}], Loss: {running_loss / len(train_loader)}")


Epoch [1/5], Loss: 1.4488209623556871
Epoch [2/5], Loss: 0.45283622581225175


In [None]:
import os

# Target folder on Drive and the checkpoint file inside it
save_directory = '/content/drive/My Drive/new_model_directory'
save_path = os.path.join(save_directory, 'vit_butterfly_model1.pth')

# Make sure the folder exists (no error if it is already there)
os.makedirs(name=save_directory, exist_ok=True)

# Persist only the learned parameters (state dict), not the whole model object
torch.save(obj=model.state_dict(), f=save_path)


In [None]:
# Checkpoint to restore. This is the exact file written by the save step
# (same directory string and 'vit_butterfly_model1.pth'); the previous value
# pointed at 'vit_butterfly_model.pth', which this notebook never writes.
# NOTE(review): if loading an older checkpoint was intentional, restore that
# file name here.
weights_path = '/content/drive/My Drive/new_model_directory/vit_butterfly_model1.pth'

# Re-initialize the model (same architecture as used for training)
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=10)

# Load the saved weights. map_location remaps tensors onto the current device
# so a checkpoint saved from a GPU run also loads on a CPU-only runtime.
model.load_state_dict(torch.load(weights_path, map_location=device))

# Move the model to the device (GPU or CPU)
model.to(device)

# Set the model to evaluation mode
model.eval()


In [None]:
from PIL import Image
from torchvision import transforms
import torch

# Inference preprocessing. It must match what the model was trained with:
# the training pipeline normalised with mean 0.5 / std 0.5 per channel, so
# using ImageNet statistics here would feed the model differently scaled
# inputs and degrade its predictions.
transform = transforms.Compose([
    transforms.Resize((224, 224)),   # Resize to the input size of ViT
    transforms.ToTensor(),           # Convert image to tensor
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Same normalisation as training
])

# Load a new image
image_path = '/content/drive/MyDrive/crime game/leedsbutterfly/download (3).jpg'  # Replace with your image path
# The context manager closes the file; convert() returns a loaded RGB copy.
with Image.open(image_path) as img:
    image = img.convert('RGB')

# Apply the preprocessing transform
image = transform(image).unsqueeze(0)  # Add a batch dimension

# Move the image to the device
image = image.to(device)

# Make the prediction without tracking gradients
with torch.no_grad():
    outputs = model(image).logits
    # Index of the highest-scoring class for each image in the batch.
    # argmax avoids binding `_`, which IPython uses for the last cell output.
    predicted = outputs.argmax(dim=1)

# Output the predicted class label
print(f"Predicted class: {predicted.item()}")


Predicted class: 9


In [None]:
# Lookup table: 0-based class index -> (scientific name, common name).
# The position of each entry in the list is its class index.
class_names = dict(enumerate([
    ('Danaus plexippus', 'Monarch Butterfly'),
    ('Heliconius charitonius', 'Zebra Longwing'),
    ('Heliconius erato', 'Red Postman'),
    ('Junonia coenia', 'Common Buckeye'),
    ('Lycaena phlaeas', 'Small Copper'),
    ('Nymphalis antiopa', 'Mourning Cloak'),
    ('Papilio cresphontes', 'Giant Swallowtail'),
    ('Pieris rapae', 'Cabbage White'),
    ('Vanessa atalanta', 'Red Admiral'),
    ('Vanessa cardui', 'Painted Lady'),
]))

# Convert the single-element prediction tensor to a plain Python int
predicted_class_idx = predicted.item()

# Resolve both names for the predicted class
predicted_scientific_name, predicted_common_name = class_names[predicted_class_idx]

# Report the result
print(f"Predicted butterfly species:")
print(f"Scientific Name: {predicted_scientific_name}")
print(f"Common Name: {predicted_common_name}")


Predicted butterfly species:
Scientific Name: Vanessa cardui
Common Name: Painted Lady
