In [None]:
# ==============================
# CCROP Cannabis Leaf Disease AI - Stable Colab Version with Safe KeyboardInterrupt Handling
# ==============================

# 1. Install dependencies
# These are IPython shell escapes (`!`), so this cell only runs inside a notebook
# environment (e.g. Colab); it is not valid as a plain .py script.
# NOTE(review): the archive is extracted further down with Python's `zipfile`
# module, so the system `unzip` package looks unnecessary -- confirm before removing.
!apt-get install -y unzip
# NOTE(review): torchaudio, matplotlib and pandas are not imported anywhere in the
# visible part of this notebook; `kaggle` provides the CLI used for the download.
!pip install torch torchvision torchaudio matplotlib pandas scikit-learn opencv-python kaggle --quiet

# ------------------------------
# 2. Import libraries
# ------------------------------
import os
import zipfile
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader, random_split
import numpy as np
import cv2
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix, classification_report

# ------------------------------
# 3. Kaggle credentials
# ------------------------------
# Export the Kaggle credentials as environment variables (presumably what the
# `kaggle` CLI invoked below reads for authentication -- confirm against the Kaggle API docs).
# NOTE(review): these assignments unconditionally overwrite any KAGGLE_USERNAME /
# KAGGLE_KEY already present in the environment, and real credentials pasted here
# end up stored in the notebook -- avoid committing them.
os.environ['KAGGLE_USERNAME'] = "YOUR_KAGGLE_USERNAME"  # Replace these before running
os.environ['KAGGLE_KEY'] = "YOUR_KAGGLE_KEY"

print("Verifying Kaggle API access...")
# Shell escape: search Kaggle datasets for "cannabis"; the `echo` only runs when
# the kaggle command exits with a non-zero status.
!kaggle datasets list -s cannabis || echo "Kaggle authentication failed. Check your Kaggle credentials."

# ------------------------------
# 4. Dataset download + extraction
# ------------------------------
# Download the dataset archive into `root` and extract it in place.
root = "./dataset"
os.makedirs(root, exist_ok=True)

DATASET_MAIN = "engineeringubu/leaf-manifestation-diseases-of-cannabis"
DATASET_FALLBACK = "vipoooool/new-plant-diseases-dataset"
zip_file = None

# Try the primary dataset first; only fall back when the CLI reports failure.
selected_dataset = DATASET_MAIN
print(f"Attempting to download dataset: {DATASET_MAIN}")
download_status = os.system(f"kaggle datasets download -d {DATASET_MAIN} -p {root}")
if download_status != 0:
    print("Primary dataset failed. Using fallback dataset.")
    selected_dataset = DATASET_FALLBACK
    download_status = os.system(f"kaggle datasets download -d {DATASET_FALLBACK} -p {root}")
    if download_status != 0:
        # Previously this exit status was ignored; surface it so the
        # FileNotFoundError below is easier to diagnose.
        print("Fallback dataset download failed as well.")

# Prefer the archive that belongs to the dataset selected above, so a stale zip
# left over from an earlier run is not extracted by accident.
# NOTE(review): assumes the Kaggle CLI names the archive "<dataset-slug>.zip" -- confirm.
preferred_zip = os.path.join(root, selected_dataset.split("/")[-1] + ".zip")
if os.path.isfile(preferred_zip):
    zip_file = preferred_zip
else:
    # Sorted so the choice does not depend on os.listdir()'s arbitrary order.
    zip_candidates = sorted(f for f in os.listdir(root) if f.endswith(".zip"))
    if zip_candidates:
        zip_file = os.path.join(root, zip_candidates[0])

if not zip_file:
    raise FileNotFoundError("No dataset ZIP file found. Check Kaggle credentials or dataset name.")

print(f"Extracting dataset from: {zip_file}")
with zipfile.ZipFile(zip_file, 'r') as zip_ref:
    zip_ref.extractall(root)

# ------------------------------
# 5. Auto-detect dataset path
# ------------------------------
def find_dataset_folder(base):
    """Return the first folder under *base* that holds more than one visible subfolder.

    The tree is walked top-down. Hidden directories (names starting with
    ``.``) are neither counted nor descended into, and siblings are visited
    in sorted order so the result does not depend on filesystem listing order.

    Args:
        base: Directory to start searching from.

    Returns:
        The path of the first matching folder, or ``None`` when no folder
        contains at least two visible subdirectories.
    """
    for folder, subdirs, _ in os.walk(base):
        # In-place assignment prunes os.walk's traversal (top-down mode), so
        # hidden folders can never be mistaken for the dataset root.
        subdirs[:] = sorted(s for s in subdirs if not s.startswith('.'))
        if len(subdirs) > 1:
            return folder
    return None

# Resolve the folder whose immediate children are treated as the class folders.
DATASET_PATH = find_dataset_folder(root)
if DATASET_PATH is None:
    raise FileNotFoundError("Unable to locate dataset folder automatically. Check folder nesting.")

# Special case: when the detected path mentions the Plant Diseases dataset and
# has a "color" subfolder, use that subfolder as the class root instead.
# NOTE(review): the hyphenated slug is matched against the extracted folder
# path -- confirm the extracted folder name actually contains it.
is_plant_diseases = "new-plant-diseases-dataset" in DATASET_PATH.lower()
color_dir = os.path.join(DATASET_PATH, "color") if is_plant_diseases else None
if color_dir is not None and os.path.exists(color_dir):
    print("Using 'color' subfolder inside Plant Diseases dataset.")
    DATASET_PATH = color_dir

print("Detected dataset path:", DATASET_PATH)

# ------------------------------
# 6. Verify class folders
# ------------------------------
# Every directory directly under DATASET_PATH is treated as one class.
classes = [
    entry
    for entry in os.listdir(DATASET_PATH)
    if os.path.isdir(os.path.join(DATASET_PATH, entry))
]
class_count = len(classes)
print("Detected classes:", classes[:10], f"...({class_count} total)")

if not classes:
    raise FileNotFoundError("No class folders found—dataset may be malformed.")

# Spread the alphabetically sorted classes evenly over a 0-100 "stress" scale.
# NOTE(review): this presumes alphabetical class order tracks severity -- confirm.
if class_count > 1:
    last_rank = class_count - 1
    stress_mapping = {
        name: rank * 100 / last_rank
        for rank, name in enumerate(sorted(classes))
    }
else:
    print("Warning: only one or zero classes detected.")
    stress_mapping = dict.fromkeys(classes, 0)
print("Stress mapping:", stress_mapping)

# ------------------------------
# 7. Transformations
# ------------------------------
# Per-channel normalisation constants shared by both pipelines.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
_INPUT_SIDE = 224

# Training pipeline: resize, random augmentation, then tensor + normalisation.
_train_steps = [
    transforms.Resize((_INPUT_SIDE, _INPUT_SIDE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomResizedCrop(_INPUT_SIDE, scale=(0.8, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]
train_transform = transforms.Compose(_train_steps)

# Validation/test pipeline: deterministic resize, tensor + normalisation only.
_eval_steps = [
    transforms.Resize((_INPUT_SIDE, _INPUT_SIDE)),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]
val_test_transform = transforms.Compose(_eval_steps)

# ------------------------------
# 8. Dataset loading and split
# ------------------------------
# Two views of the same image folder: one with training augmentation, one with
# the deterministic evaluation pipeline. The original code assigned
# `val_dataset.dataset.transform`, but every subset returned by random_split
# shares the SAME underlying dataset object, so that assignment also switched
# the training subset to the evaluation transform and disabled augmentation.
full_dataset = datasets.ImageFolder(root=DATASET_PATH, transform=train_transform)
eval_view = datasets.ImageFolder(root=DATASET_PATH, transform=val_test_transform)

# 80 / 10 / 10 split; the test split absorbs the rounding remainder.
train_size = int(0.8 * len(full_dataset))
val_size = int(0.1 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size

# Fixed seed so the same images land in the test split on every run (matters
# when training is resumed from a saved checkpoint).
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_split, test_split = random_split(
    full_dataset, [train_size, val_size, test_size], generator=split_generator
)
# Re-point the validation/test indices at the evaluation view so each subset
# keeps its own transform.
val_dataset = torch.utils.data.Subset(eval_view, val_split.indices)
test_dataset = torch.utils.data.Subset(eval_view, test_split.indices)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=2)

# ------------------------------
# 9. Model setup
# ------------------------------
# ResNet-18 initialised from ImageNet weights. `weights=` replaces the
# deprecated `pretrained=True` flag; IMAGENET1K_V1 is the checkpoint that flag
# used to load.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
num_classes = len(full_dataset.classes)
# Swap the classifier head for one sized to this dataset's classes.
model.fc = nn.Linear(model.fc.in_features, num_classes)
# All parameters are trainable by default (full fine-tuning), so no explicit
# requires_grad loop is needed.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# ------------------------------
# 10. Training Loop with Safe KeyboardInterrupt Handling
# ------------------------------
# Train for a fixed number of epochs, recording the mean per-batch loss for the
# training and validation sets after each epoch.
epochs = 15
train_losses, val_losses = [], []

try:
    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # A loader can have zero batches on very small datasets; report NaN
        # instead of crashing with ZeroDivisionError.
        train_batches = len(train_loader)
        train_losses.append(running_loss / train_batches if train_batches else float("nan"))

        # ---- validation pass (no gradients, eval-mode layers) ----
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                val_loss += criterion(model(images), labels).item()
        val_batches = len(val_loader)
        val_losses.append(val_loss / val_batches if val_batches else float("nan"))
        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_losses[-1]:.4f} | Val Loss: {val_losses[-1]:.4f}")

except KeyboardInterrupt:
    # An interrupt during validation leaves one extra training entry; trim it
    # so both histories describe the same completed epochs.
    del train_losses[len(val_losses):]
    print("\nKeyboardInterrupt detected — saving partial model and stopping gracefully...")
    torch.save(model.state_dict(), "CannabisLeaf_ResNet18_PARTIAL.pth")
    print("Partial model saved as 'CannabisLeaf_ResNet18_PARTIAL.pth'. You can resume training later safely.")

# ------------------------------
# 11. Evaluation
# ------------------------------
# Run the model over the held-out test split and collect predictions.
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        preds = model(images).argmax(dim=1)
        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(labels.tolist())

# Pass the full label set explicitly: without it, scikit-learn infers the
# labels from the data and raises when the test split happens to miss a class
# (its count then disagrees with len(target_names)).
class_indices = list(range(len(full_dataset.classes)))

print("Classification Report:")
print(classification_report(
    all_labels,
    all_preds,
    labels=class_indices,
    target_names=full_dataset.classes,
    zero_division=0,  # classes absent from the split score 0 without warnings
))
print("Confusion Matrix:\n", confusion_matrix(all_labels, all_preds, labels=class_indices))

# ------------------------------
# 12. Save Final Model
# ------------------------------
# Persist only the learned parameters (the state_dict), not the module object.
MODEL_SAVE_PATH = "./CannabisLeaf_ResNet18_Stress_Final.pth"
final_state = model.state_dict()
torch.save(final_state, MODEL_SAVE_PATH)
print(f"Final trained model saved to {MODEL_SAVE_PATH}")

# ------------------------------
# 13. Single Image Stress Prediction
# ------------------------------
def predict_stress_score(img_path, model, transform, stress_mapping):
    """Return the expected stress score for a single image file.

    The score is the probability-weighted average of the per-class values in
    ``stress_mapping``; classes are matched to model outputs by sorted name.

    Args:
        img_path: Path of the image to load with OpenCV.
        model: Classifier returning one logit per class.
        transform: torchvision pipeline applied to the image. It receives a
            PIL image, the same input type ``ImageFolder`` feeds it.
        stress_mapping: Mapping of class name -> stress value.

    Returns:
        The stress score as a Python float.

    Raises:
        ValueError: If the image cannot be read, or if the number of model
            outputs differs from the number of entries in ``stress_mapping``.
    """
    image = cv2.imread(img_path)
    if image is None:
        raise ValueError("Invalid image path.")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (224, 224))
    # torchvision's Resize rejects raw NumPy arrays, so hand the pipeline a
    # PIL image instead of the OpenCV ndarray.
    pil_image = transforms.ToPILImage()(image)

    # Run on whatever device the model lives on rather than a module global.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")
    img_tensor = transform(pil_image).unsqueeze(0).to(target_device)

    model.eval()
    with torch.no_grad():
        outputs = model(img_tensor)
        probs = F.softmax(outputs, dim=1).cpu().numpy()[0]

    class_names = sorted(stress_mapping.keys())
    scores = np.array([stress_mapping[c] for c in class_names])
    if probs.shape != scores.shape:
        raise ValueError(
            f"Model produced {probs.shape[0]} class probabilities but "
            f"stress_mapping has {scores.shape[0]} entries."
        )
    return float(np.sum(probs * scores))

# ------------------------------
# 14. Webcam Real-Time Stress Prediction
# ------------------------------
def webcam_stress_predict(model, transform, stress_mapping):
    """Overlay a live stress estimate on frames from the default webcam.

    Frames are read from camera index 0 until the stream ends or the user
    presses ``q``. Each frame is classified, and the probability-weighted
    stress value is drawn on it before it is shown in a window.

    Args:
        model: Classifier returning one logit per class.
        transform: torchvision pipeline applied to each frame. It receives a
            PIL image, the same input type ``ImageFolder`` feeds it.
        stress_mapping: Mapping of class name -> stress value; classes are
            matched to model outputs by sorted name.

    Returns:
        None. Prints a message and returns early if the webcam cannot be opened.
    """
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("Cannot open webcam")
        return

    # Loop-invariant setup: class order, score vector, PIL converter, device.
    class_names = sorted(stress_mapping.keys())
    scores = np.array([stress_mapping[c] for c in class_names])
    to_pil = transforms.ToPILImage()
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")
    # Inference mode, matching predict_stress_score; otherwise batch-norm and
    # dropout layers would behave as in training.
    model.eval()

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            img_resized = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            img_resized = cv2.resize(img_resized, (224, 224))
            # torchvision's Resize rejects raw NumPy arrays; convert to PIL first.
            img_tensor = transform(to_pil(img_resized)).unsqueeze(0).to(target_device)
            with torch.no_grad():
                outputs = model(img_tensor)
                probs = F.softmax(outputs, dim=1).cpu().numpy()[0]
            stress_score = np.sum(probs * scores)
            label = f"Stress: {stress_score:.1f}%"
            cv2.putText(frame, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            cv2.imshow("Leaf Stress Severity", frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Release the camera and close windows even if inference raised.
        cap.release()
        cv2.destroyAllWindows()



Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
unzip is already the newest version (6.0-26ubuntu3.2).
0 upgraded, 0 newly installed, 0 to remove and 38 not upgraded.
Verifying Kaggle API access...
ref                                                                 title                                                     size  lastUpdated                 downloadCount  voteCount  usabilityRating  
------------------------------------------------------------------  --------------------------------------------------  ----------  --------------------------  -------------  ---------  ---------------  
kingburrito666/cannabis-strains                                     Cannabis Strains                                        424928  2017-12-16 23:58:13.043000           6238        165  0.5882353        
bigquery/genomics-cannabis                                          1000 Cannabis Genomes Project                                0  2019-02-2