In [3]:

## 1. Install Required Packages

!pip install -q transformers datasets timm safetensors
!pip install -q ipywidgets 

print("Danish")

Danish


In [4]:

## 2. Import Libraries

import os
import torch
import requests
from PIL import Image
from io import BytesIO
from tqdm.auto import tqdm
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from transformers import ViTFeatureExtractor, ViTForImageClassification
print("Danish")

Danish


In [11]:
## 3. Dataset Configuration

# %%
# Update this path to your local dataset location
DATASET_ROOT = "./Plant_Disease_Dataset"
TRAIN_PATH = "train"

# Get classes from train directory (assuming consistent classes across splits)
CLASSES = sorted(os.listdir(TRAIN_PATH))
NUM_CLASSES = len(CLASSES)

print(f"Found {NUM_CLASSES} classes: {CLASSES}")
print("Danish")

Found 40 classes: ['.DS_Store', 'Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy', 'Blueberry___healthy', 'Cherry_(including_sour)___Powdery_mildew', 'Cherry_(including_sour)___healthy', 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot', 'Corn_(maize)___Common_rust_', 'Corn_(maize)___Northern_Leaf_Blight', 'Corn_(maize)___healthy', 'Grape___Black_rot', 'Grape___Esca_(Black_Measles)', 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Grape___healthy', 'Orange___Haunglongbing_(Citrus_greening)', 'Peach___Bacterial_spot', 'Peach___healthy', 'Pepper,_bell___Bacterial_spot', 'Pepper,_bell___healthy', 'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy', 'Random___image', 'Raspberry___healthy', 'Soybean___healthy', 'Squash___Powdery_mildew', 'Strawberry___Leaf_scorch', 'Strawberry___healthy', 'Tomato___Bacterial_spot', 'Tomato___Early_blight', 'Tomato___Late_blight', 'Tomato___Leaf_Mold', 'Tomato___Septoria_leaf_spot', 'Tomato___Spider_mi

In [16]:
## 4. Define Data Pipeline

# %%
# Initialize feature extractor
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

# Define transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=feature_extractor.image_mean,
        std=feature_extractor.image_std
    )
])

# Custom Dataset Class
class PlantDiseaseDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []
        
        for class_idx, class_name in enumerate(CLASSES):
            class_path = os.path.join(root_dir, class_name)
            if os.path.isdir(class_path):
                for img_name in os.listdir(class_path):
                    self.samples.append((
                        os.path.join(class_path, img_name),
                        class_idx
                    ))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(label)


In [13]:
## 6. Initialize Model

# %%
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=NUM_CLASSES,
    id2label={str(i): CLASSES[i] for i in range(NUM_CLASSES)},
    label2id={cls_name: i for i, cls_name in enumerate(CLASSES)}
).to(device)
print ("danish")

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


danish


In [14]:
## 7. Training Configuration

# %%
# Training parameters
EPOCHS = 4
LEARNING_RATE = 5e-5

# Initialize optimizer and loss function
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
criterion = torch.nn.CrossEntropyLoss()


In [15]:
## 8. Training Loop

# %%
def train_epoch(model, loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    correct = 0
    
    for images, labels in tqdm(loader, desc="Training"):
        images = images.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images).logits
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        correct += (outputs.argmax(1) == labels).sum().item()
    
    avg_loss = total_loss / len(loader)
    accuracy = 100 * correct / len(loader.dataset)
    return avg_loss, accuracy

# %%
# Training execution
for epoch in range(EPOCHS):
    train_loss, train_acc = train_epoch(
        model, train_loader, optimizer, criterion, device
    )
    
    # Save checkpoint
    torch.save(
        model.state_dict(),
        f"vit_plant_disease_epoch_{epoch+1}.pt"
    )
    
    print(f"Epoch {epoch+1}/{EPOCHS}")
    print(f"Train Loss: {train_loss:.4f} | Accuracy: {train_acc:.2f}%")
    print("------------------------------")


NameError: name 'train_loader' is not defined