<a href="https://colab.research.google.com/github/Aditya8215/Trees_classification/blob/main/Trees_classification_vision_transformers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Step 1: Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

In [2]:
# Step 2: Unzip dataset from Google Drive
import zipfile, os

zip_path = '/content/drive/MyDrive/zip_trees.zip'  # Update path if needed
extract_path = '/content/Tree_Species_Dataset'

In [3]:

with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_path)

In [4]:
import warnings
warnings.filterwarnings('ignore')
import numpy as np # linear algebra
import pandas as pd
import os
import shutil

src = '/content/Tree_Species_Dataset/Tree_Species_Dataset'
dst = '/content/Updated_Tree_Species_Dataset'

os.makedirs(dst, exist_ok=True)

# List and sort folders to ensure consistent indexing
folders = sorted([f for f in os.listdir(src) if os.path.isdir(os.path.join(src, f)) and not f.startswith('.')])

# Copy all folders except the 20th one
for idx, folder in enumerate(folders):
    if idx == 20:
        print(f"Skipping 20th class: {folder}")
        continue
    full_path = os.path.join(src, folder)
    shutil.copytree(full_path, os.path.join(dst, folder))

Skipping 20th class: other


In [5]:
from tensorflow.keras.preprocessing import image_dataset_from_directory
train_ds,val_ds=image_dataset_from_directory(
    '/content/Updated_Tree_Species_Dataset',
    image_size=(224,224),
    batch_size=32,
    labels='inferred',
    label_mode='categorical',
    validation_split=0.2,
    subset='both',
    shuffle='True',
    seed=42,
)

Found 1450 files belonging to 29 classes.
Using 1160 files for training.
Using 290 files for validation.


In [2]:
!pip install -q transformers timm

import os
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from transformers import AutoImageProcessor, AutoModelForImageClassification
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm


In [7]:
# Directory path
data_dir = "/content/Updated_Tree_Species_Dataset"
image_size = 224
batch_size = 32

# Define transforms
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])
val_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

# Load full dataset
full_dataset = datasets.ImageFolder(data_dir, transform=train_transform)
class_names = full_dataset.classes
num_classes = len(class_names)

# Split into train and val
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# Update val transform separately
val_dataset.dataset.transform = val_transform

# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)


In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = AutoModelForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=num_classes,
    ignore_mismatched_sizes=True  # Very important if mismatch
).to(device)


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
optimizer = optim.AdamW(model.parameters(), lr=4e-5)
criterion = nn.CrossEntropyLoss()

def train_epoch(model, dataloader):
    model.train()
    total_loss, total_correct = 0, 0
    for inputs, labels in tqdm(dataloader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(pixel_values=inputs).logits
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_correct += (outputs.argmax(1) == labels).sum().item()

    return total_loss / len(dataloader), total_correct / len(dataloader.dataset)

def evaluate(model, dataloader):
    model.eval()
    total_correct = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(pixel_values=inputs).logits
            total_correct += (outputs.argmax(1) == labels).sum().item()
    return total_correct / len(dataloader.dataset)


In [16]:
save_path='/content/model/'

In [18]:
for epoch in range(30):  # you can increase to 10–20 later
    # ---- Training ----
    model.train()
    train_loss, train_correct = 0, 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        outputs = model(pixel_values=images).logits
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_correct += (outputs.argmax(1) == labels).sum().item()

    train_acc = train_correct / len(train_loader.dataset)
    avg_train_loss = train_loss / len(train_loader)

    # ---- Validation ----
    model.eval()
    best_val_acc=0
    val_loss, val_correct = 0, 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(pixel_values=images).logits
            loss = criterion(outputs, labels)

            val_loss += loss.item()
            val_correct += (outputs.argmax(1) == labels).sum().item()

    val_acc = val_correct / len(val_loader.dataset)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), os.path.join(save_path, 'best_model.pth'))
        print(f"✅ Best model saved at epoch {epoch+1} with Val Acc: {val_acc:.4f}")
    avg_val_loss = val_loss / len(val_loader)

    print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.4f} | "
          f"Val Loss: {avg_val_loss:.4f} | Val Acc: {val_acc:.4f}")

✅ Best model saved at epoch 1 with Val Acc: 0.9034
Epoch 1 | Train Loss: 1.0639 | Train Acc: 0.9871 | Val Loss: 1.3086 | Val Acc: 0.9034
✅ Best model saved at epoch 2 with Val Acc: 0.9069
Epoch 2 | Train Loss: 0.8518 | Train Acc: 0.9931 | Val Loss: 1.1491 | Val Acc: 0.9069
✅ Best model saved at epoch 3 with Val Acc: 0.9138
Epoch 3 | Train Loss: 0.6944 | Train Acc: 0.9948 | Val Loss: 1.0298 | Val Acc: 0.9138
✅ Best model saved at epoch 4 with Val Acc: 0.8931
Epoch 4 | Train Loss: 0.5675 | Train Acc: 0.9991 | Val Loss: 0.9498 | Val Acc: 0.8931
✅ Best model saved at epoch 5 with Val Acc: 0.9034
Epoch 5 | Train Loss: 0.4731 | Train Acc: 1.0000 | Val Loss: 0.8713 | Val Acc: 0.9034
✅ Best model saved at epoch 6 with Val Acc: 0.9034
Epoch 6 | Train Loss: 0.4023 | Train Acc: 1.0000 | Val Loss: 0.8068 | Val Acc: 0.9034
✅ Best model saved at epoch 7 with Val Acc: 0.9034
Epoch 7 | Train Loss: 0.3489 | Train Acc: 1.0000 | Val Loss: 0.7715 | Val Acc: 0.9034
✅ Best model saved at epoch 8 with Val Ac