In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import glob
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import TensorDataset, DataLoader
from PIL import Image

# ===============================
#   Load frames from training videos
# ===============================
def load_frames_from_videos(base_dir, transform, classes, frame_skip=10):
    """Sample frames from per-class video folders and return them as tensors.

    For each name in ``classes`` the folder ``base_dir/<name>`` is scanned;
    every ``.avi`` / ``.mp4`` file in it (extension matched case-insensitively)
    is decoded with OpenCV and every ``frame_skip``-th frame is kept.

    Args:
        base_dir: Directory containing one sub-folder per class.
        transform: Callable applied to each RGB ``PIL.Image`` frame. It must
            return tensors of identical shape so they can be stacked.
        classes: Class folder names; a frame's label is its index in here.
        frame_skip: Keep frames whose 0-based index is a multiple of this
            value (must be >= 1).

    Returns:
        ``(data, labels)``: ``torch.stack`` of all transformed frames, and a
        1-D tensor of the matching class indices.

    Raises:
        ValueError: If ``frame_skip`` < 1 or no frame could be loaded at all.
    """
    if frame_skip < 1:
        raise ValueError(f"frame_skip must be >= 1, got {frame_skip}")

    data, labels = [], []
    for idx, cls in enumerate(classes):
        class_dir = os.path.join(base_dir, cls)
        # Sorted so the sampled frame order is reproducible across filesystems.
        for video_file in sorted(os.listdir(class_dir)):
            # Accept .avi / .mp4 regardless of case (e.g. "CLIP.MP4").
            if not video_file.lower().endswith((".avi", ".mp4")):
                continue
            video_path = os.path.join(class_dir, video_file)
            cap = cv2.VideoCapture(video_path)
            try:
                if not cap.isOpened():
                    print(f"⚠️ Could not open {video_path}, skipping")
                    continue
                frame_count = 0
                while True:
                    ret, frame = cap.read()
                    if not ret:
                        break
                    if frame_count % frame_skip == 0:  # sample frames
                        # OpenCV decodes to BGR; PIL/torchvision expect RGB.
                        img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                        data.append(transform(img))
                        labels.append(idx)
                    frame_count += 1
            finally:
                # Release the decoder even if transform() raises.
                cap.release()

    if not data:
        # torch.stack([]) would fail with an opaque message; say what is wrong.
        raise ValueError(f"No frames loaded from {base_dir} for classes {list(classes)}")
    return torch.stack(data), torch.tensor(labels)

# ===============================
#   Training
# ===============================
def train_main_from_videos(base_dir="/content/sample_data", epochs=5, lr=1e-3,
                           batch_size=32, ckpt_path="cnn_classifier.pt",
                           classes=("crops", "weeds")):
    """Fine-tune a pretrained ResNet-18 on frames sampled from class videos.

    All sampled frames are loaded into memory up front via
    ``load_frames_from_videos``. The final fully connected layer is replaced
    with one that has ``len(classes)`` outputs, and the whole network is
    trained with Adam and cross-entropy.

    Args:
        base_dir: Folder holding one sub-folder of videos per class.
        epochs: Number of passes over the sampled frames.
        lr: Adam learning rate.
        batch_size: Mini-batch size for the DataLoader.
        ckpt_path: Where the trained ``state_dict`` is written.
        classes: Class sub-folder names; output index i corresponds to
            ``classes[i]``.

    Returns:
        The trained model, moved to CPU.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    # NOTE(review): no ImageNet Normalize() although the backbone is
    # pretrained; adding one requires the identical transform at inference.

    # Each class name is a sub-folder of base_dir (e.g. crops/ and weeds/).
    print("Loading frames from videos...")
    data, labels = load_frames_from_videos(base_dir, transform, classes)

    dataset = TensorDataset(data, labels)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    print(f"Loaded {len(dataset)} frames for training")

    # Use the GPU when present; the weights are moved back to CPU before saving.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Model: pretrained backbone with a new classification head.
    model = models.resnet18(pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, len(classes))
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for imgs, lbls in loader:
            imgs, lbls = imgs.to(device), lbls.to(device)
            optimizer.zero_grad()
            loss = criterion(model(imgs), lbls)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # Mean of the per-batch losses for this epoch.
        print(f"Epoch {epoch+1}: Loss={total_loss/len(loader):.4f}")

    # Save CPU tensors so the checkpoint loads on machines without a GPU.
    model.cpu()
    torch.save(model.state_dict(), ckpt_path)
    print(f"💾 Model saved as {ckpt_path}")
    return model

# ===============================
#   Convert test images → video
# ===============================
def build_video_from_images_main(images_glob, out_path, fps=5, width=640, height=480):
    """Stitch the images matched by ``images_glob`` into an MP4 video.

    Images are taken in sorted path order and resized to ``width`` x
    ``height`` (aspect ratio is not preserved). They are written at ``fps``
    frames per second with the ``mp4v`` codec. Files that OpenCV cannot
    decode are skipped.

    Raises:
        ValueError: If the glob matches nothing, or none of the matched files
            could be decoded as an image.
        RuntimeError: If OpenCV cannot open ``out_path`` for writing.
    """
    img_files = sorted(glob.glob(images_glob))
    if not img_files:
        raise ValueError("No images found!")

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
    # VideoWriter does not raise on failure (bad path, missing codec); check it.
    if not out.isOpened():
        out.release()
        raise RuntimeError(f"Could not open video writer for {out_path}")

    written = 0
    try:
        for img_file in img_files:
            img = cv2.imread(img_file)
            if img is None:
                continue  # not a decodable image
            # Every frame must match the writer's frame size exactly.
            out.write(cv2.resize(img, (width, height)))
            written += 1
    finally:
        out.release()

    if written == 0:
        raise ValueError(f"None of the {len(img_files)} matched files could be read as images")
    print(f"🎥 Video created at {out_path}")

# ===============================
#   Testing on video
# ===============================
def test_on_video_main(video_path, ckpt_path, classes=("crops", "weeds")):
    """Classify a whole video by majority vote over per-frame predictions.

    Every frame of ``video_path`` is resized to 224x224 and fed to a ResNet-18
    whose weights come from ``ckpt_path``. The most frequent predicted class
    index wins; on a tie, the lowest index wins.

    Args:
        video_path: Video file readable by ``cv2.VideoCapture``.
        ckpt_path: ``state_dict`` saved from a ResNet-18 whose ``fc`` layer has
            ``len(classes)`` outputs.
        classes: Class names; ``classes[i]`` is the label for output index i.

    Returns:
        The winning class name.

    Raises:
        ValueError: If no frame could be read from ``video_path``.
    """
    # Must match the preprocessing used at training time.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    # Architecture only: every weight is overwritten by the checkpoint, so
    # fetching ImageNet weights here would be a wasted download.
    model = models.resnet18()
    model.fc = nn.Linear(model.fc.in_features, len(classes))
    model.load_state_dict(torch.load(ckpt_path, map_location=torch.device("cpu")))
    model.eval()

    cap = cv2.VideoCapture(video_path)
    preds = []
    try:
        with torch.no_grad():
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                # OpenCV decodes to BGR; the model was trained on RGB frames.
                img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                out = model(transform(img).unsqueeze(0))
                _, pred = torch.max(out, 1)
                preds.append(pred.item())
    finally:
        cap.release()

    if not preds:
        raise ValueError(f"No frames could be read from {video_path}")

    # Majority vote across frames; sorted() makes ties go to the lowest index.
    final_class = classes[max(sorted(set(preds)), key=preds.count)]
    print(f"✅ Final Prediction for {video_path}: {final_class}")
    return final_class

# ===============================
#   Main Execution
# ===============================
if __name__ == "__main__":
    # Script entry point: train, build a test clip, then classify that clip.
    checkpoint = "cnn_classifier.pt"
    clip = "test_video.mp4"

    # Step 1: fine-tune the classifier on the class-labelled training videos.
    train_main_from_videos()

    # Step 2: join the loose test images into a single 640x480 clip at 5 fps.
    build_video_from_images_main(
        images_glob="/content/sample_data/test_img/*.JPG",
        out_path=clip,
        fps=5,
        width=640,
        height=480,
    )

    # Step 3: classify that clip with the checkpoint saved in step 1.
    test_on_video_main(video_path=clip, ckpt_path=checkpoint)


📥 Loading frames from videos...
✅ Loaded 12 frames for training
Epoch 1: Loss=0.9998
Epoch 2: Loss=0.0806
Epoch 3: Loss=0.0901
Epoch 4: Loss=0.0018
Epoch 5: Loss=0.0004
💾 Model saved as cnn_classifier.pt
🎥 Video created at test_video.mp4
✅ Final Prediction for test_video.mp4: weeds
