In [6]:
# CELL 1: Mount Google Drive and Import Libraries
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader # For custom dataset for test images and for creating data loaders
from PIL import Image
import numpy as np
import pandas as pd
import os
import zipfile
import shutil # For cleaning up
from tqdm import tqdm # For progress bar
import timm # For DeiT model definition

# Colab-only step: mount Google Drive at /content/drive so that the Drive
# paths configured in the next cell (model weights, test ZIP, output CSV)
# resolve. The printed "Mounted at ..." line confirms the mount point.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [7]:
# CELL 2: Define Paths and Constants

# --- Model / inference settings -------------------------------------------
# Side length (pixels) of the square input fed to the model.
IMG_SIZE = 224
# Number of images per forward pass; lower this if Colab runs out of memory.
BATCH_SIZE_PRED = 32
# Label for each output index. The order must match the one used in training,
# otherwise predicted indices map to the wrong names.
CLASS_NAMES = [
    'Blight',
    'Common_Rust',
    'Gray_Leaf_Spot',
    'Healthy',
]
NUM_CLASSES = len(CLASS_NAMES)

# --- Google Drive locations (!!! EDIT TO MATCH YOUR DRIVE LAYOUT !!!) ------
# Saved state_dict of the fine-tuned DeiT model.
MODEL_PATH = '/content/drive/MyDrive/CornLeaf_DeiT.pth'
# Archive holding the unlabelled test images.
TEST_ZIP_PATH = '/content/drive/MyDrive/CS180 Project Test Sets/corn_test.zip'
# Destination of the generated predictions CSV.
CSV_OUTPUT_PATH = '/content/drive/MyDrive/predictions_deit.csv'

# --- Local (Colab VM) extraction locations ---------------------------------
EXTRACT_TO_PATH = '/content/corn_test_extracted_pred/'
# The archive is expected to contain a top-level 'corn_test' folder.
TEST_IMAGE_FOLDER = os.path.join(EXTRACT_TO_PATH, 'corn_test')

In [8]:
# CELL 3: Unzip the Test Set
# Reuse a previous extraction when TEST_IMAGE_FOLDER is already present;
# otherwise extract TEST_ZIP_PATH into a freshly created EXTRACT_TO_PATH.
if os.path.isdir(TEST_IMAGE_FOLDER):
    print(f"Test images folder already exists at {TEST_IMAGE_FOLDER}. Using existing files.")
else:
    # Validate the archive BEFORE touching the filesystem: a wrong path must
    # not wipe the extraction directory and leave an empty one behind.
    if not os.path.isfile(TEST_ZIP_PATH):
        print(f"ERROR: Test ZIP file not found at {TEST_ZIP_PATH}. Please check the path.")
        raise FileNotFoundError(f"Test ZIP file not found at {TEST_ZIP_PATH}")

    # Start from a clean directory so stale files from another archive
    # cannot mix with the new extraction.
    if os.path.exists(EXTRACT_TO_PATH):
        shutil.rmtree(EXTRACT_TO_PATH)
    os.makedirs(EXTRACT_TO_PATH, exist_ok=True)

    print(f"Extracting test images from {TEST_ZIP_PATH}...")
    with zipfile.ZipFile(TEST_ZIP_PATH, 'r') as zip_ref:
        zip_ref.extractall(EXTRACT_TO_PATH)
    print(f"Test images extracted to {TEST_IMAGE_FOLDER}")

# Final sanity check: the archive must have produced the expected folder
# (it would not if the ZIP lacks a top-level 'corn_test' directory).
if os.path.isdir(TEST_IMAGE_FOLDER):
    print(f"Found {len(os.listdir(TEST_IMAGE_FOLDER))} files/folders in {TEST_IMAGE_FOLDER}")
else:
    raise FileNotFoundError(f"Error: {TEST_IMAGE_FOLDER} not found after attempting to extract.")

Test images folder already exists at /content/corn_test_extracted_pred/corn_test. Using existing files.
Found 838 files/folders in /content/corn_test_extracted_pred/corn_test


In [9]:
# CELL 4: Load the Trained DeiT Model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Fail fast on a bad path before spending time building the network.
if not os.path.exists(MODEL_PATH):
    print(f"ERROR: Model file not found at {MODEL_PATH}.")
    raise FileNotFoundError(f"Model file not found at {MODEL_PATH}")

# pretrained=False: the weights come from the fine-tuned checkpoint below,
# so no pretrained weights are requested from timm.
model = timm.create_model('deit_tiny_patch16_224', pretrained=False, num_classes=NUM_CLASSES)

try:
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered .pth file cannot execute arbitrary code on load. It is
    # sufficient for a plain state_dict (requires torch >= 1.13).
    state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
except Exception as e:
    # Typical causes: architecture / num_classes mismatch or a corrupt file.
    print(f"Error loading DeiT model state_dict: {e}")
    raise

# Kept outside the try block so that a device error is not misreported as a
# state_dict loading problem.
model.to(device)
model.eval()  # Inference mode: disables dropout and other train-only behaviour.
print("DeiT model loaded successfully and set to evaluation mode.")

Using device: cuda
DeiT model loaded successfully and set to evaluation mode.


In [10]:
# CELL 5: Define Transforms and Dataset for Test Images
# Inference-time preprocessing applied to every test image: resize to the
# square model input, convert to a tensor, then normalise each channel.
# NOTE(review): the mean/std values look like the usual ImageNet statistics;
# whatever is used here must match the preprocessing used in training -
# TODO confirm against the training notebook.
test_transform = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

class TestImageDataset(Dataset):
    """Unlabelled images from a single folder, yielded as ``(image, filename)``.

    Only files ending in ``.jpeg``, ``.jpg`` or ``.png`` (case-insensitive)
    are used. They are ordered by the integer before the first ``.`` in the
    file name (``0.jpeg``, ``1.jpeg``, ..., ``10.jpeg``), so every image file
    name must start with an integer or construction raises ``ValueError``.

    Args:
        folder_path: Directory containing the image files.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned.
    """

    # Lower-case extensions recognised as image files.
    _IMAGE_EXTENSIONS = ('.jpeg', '.jpg', '.png')

    def __init__(self, folder_path, transform=None):
        self.folder_path = folder_path
        self.transform = transform
        candidates = [
            name for name in os.listdir(folder_path)
            if name.lower().endswith(self._IMAGE_EXTENSIONS)
        ]
        # Numeric sort: a plain string sort would put '10.jpeg' before '2.jpeg'.
        self.image_files = sorted(candidates, key=lambda name: int(name.split('.')[0]))

    def __len__(self):
        """Return the number of image files found in the folder."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load, convert to RGB and transform the image at position ``idx``.

        Returns:
            ``(image, filename)``. If the file cannot be opened or transformed
            the error is printed and ``(None, filename)`` is returned instead.
            Note that ``DataLoader``'s default collate function cannot batch a
            ``None`` image, so a loader built on this dataset needs a
            ``collate_fn`` that filters such samples.
        """
        img_name = self.image_files[idx]
        img_path = os.path.join(self.folder_path, img_name)
        try:
            # The context manager closes the file handle deterministically,
            # including when decoding fails part-way; convert() returns an
            # independent, fully loaded copy.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert('RGB')
            if self.transform:
                image = self.transform(image)
            return image, img_name
        except Exception as e:
            print(f"Error loading/processing image {img_path}: {e}")
            return None, img_name

def collate_skip_failed(batch):
    """Collate ``(image, filename)`` samples, tolerating failed image loads.

    The dataset returns ``(None, filename)`` for an image it could not load.
    The default collate function raises on ``None``, which would abort the
    whole run on a single bad file, so those samples are dropped here (the
    dataset has already printed the error).

    Returns:
        The default collation of the usable samples, i.e. a batched image
        tensor and the matching filenames. If no sample in the batch is
        usable, ``(None, filenames)`` so the consumer can record the failures.
    """
    usable = [sample for sample in batch if sample[0] is not None]
    if not usable:
        return None, [name for _, name in batch]
    return torch.utils.data.default_collate(usable)


test_dataset = TestImageDataset(TEST_IMAGE_FOLDER, transform=test_transform)
# shuffle=False keeps the numeric file order produced by the dataset.
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE_PRED,
    shuffle=False,
    num_workers=2,
    collate_fn=collate_skip_failed,
)

In [11]:
# CELL 6: Generate Predictions
# Runs the model over the test loader, maps each predicted index to its class
# name, and writes one row per test image ('id' = file name, 'label') to CSV.
predictions_list = []

if not test_dataset.image_files:
    print(f"No image files found in {TEST_IMAGE_FOLDER}. Please check the path and extraction.")
else:
    print(f"Starting predictions for {len(test_dataset.image_files)} images...")
    with torch.no_grad():  # Inference only: no gradients needed.
        for images, filenames in tqdm(test_loader, desc="Predicting"):
            if images is None:
                # Whole batch failed to load. A bare string must be wrapped,
                # otherwise the loop below would iterate over its characters.
                if isinstance(filenames, str):
                    filenames = [filenames]
                for fname in filenames:
                    predictions_list.append({'id': fname, 'label': 'PreprocessingError'})
                continue

            outputs = model(images.to(device))
            # argmax on the raw logits: softmax is monotonic, so it cannot
            # change the winning class. .tolist() moves the whole batch to
            # Python in one transfer instead of one .item() call per image.
            predicted_indices = outputs.argmax(dim=1).tolist()

            for fname, class_idx in zip(filenames, predicted_indices):
                predictions_list.append({'id': fname, 'label': CLASS_NAMES[class_idx]})

    # Reconcile against the dataset: any file that did not come back from the
    # loader with a prediction still gets a row, so the CSV always has exactly
    # one line per test image instead of silently missing ids.
    predicted_ids = {row['id'] for row in predictions_list}
    missing_files = [f for f in test_dataset.image_files if f not in predicted_ids]
    for fname in missing_files:
        predictions_list.append({'id': fname, 'label': 'PreprocessingError'})
    if missing_files:
        print(f"WARNING: {len(missing_files)} image(s) had no prediction and were marked 'PreprocessingError'.")

    print(f"Finished predictions. {len(predictions_list)} predictions made.")

# Create DataFrame and save to CSV
if predictions_list:
    # Sort by the integer part of the file name ('2.jpeg' before '10.jpeg');
    # a plain string sort would order them lexicographically.
    predictions_df = (
        pd.DataFrame(predictions_list)
        .assign(sort_key=lambda df: df['id'].map(lambda name: int(name.split('.')[0])))
        .sort_values(by='sort_key')
        .drop(columns=['sort_key'])
        .reset_index(drop=True)
    )

    predictions_df.to_csv(CSV_OUTPUT_PATH, index=False)
    print(f"Predictions saved to {CSV_OUTPUT_PATH}")

    # Display first few predictions
    print("\nFirst 5 predictions:")
    print(predictions_df.head())
else:
    print("No predictions were made.")

Starting predictions for 838 images...


Predicting: 100%|██████████| 27/27 [00:03<00:00,  8.78it/s]

Finished predictions. 838 predictions made.
Predictions saved to /content/drive/MyDrive/predictions_deit.csv

First 5 predictions:
       id    label
0  0.jpeg  Healthy
1  1.jpeg  Healthy
2  2.jpeg  Healthy
3  3.jpeg  Healthy
4  4.jpeg  Healthy



