In [None]:
import os
import numpy as np
import pandas as pd
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from vit_pytorch import ViT
from torchvision.datasets.folder import is_image_file

In [None]:
# Pick the compute device: Apple-silicon MPS if present, otherwise CUDA,
# otherwise fall back to the CPU.
# NOTE(review): `crops`, `train_dir`, `val_dir`, `test_dir` were previously listed
# as prerequisites here, but no visible cell uses them — confirm and remove.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

# Data Preprocessing

In [None]:
# Define main directories.
# The data root can be overridden with the VFSG_DATA_DIR environment variable so the
# notebook runs on machines other than the original author's; the default is unchanged.
base_dir = os.environ.get(
    'VFSG_DATA_DIR',
    '/Users/izzymohamed/Desktop/Vision For Social Good/Project/Vision-For-Social-Good/DATA',
)
crop_root = os.path.join(base_dir, 'color')  # 'color' image subset (author's note: "color tester")
split_root = os.path.join(base_dir, 'split')

In [None]:
# Read the multimodal metadata table (one image path, labels and tabular features per row).
csv_path = os.path.join(
    base_dir,
    'plant_disease_multimodal_dataset.csv',
)
csv_data = pd.read_csv(csv_path)

In [None]:
# Split the table loaded in the previous cell into the three inputs CustomDataset needs.
# `csv_data` is reused rather than re-reading the same CSV file a second time.
csv_image_paths = csv_data['Image Path'].to_numpy()
csv_labels = csv_data['Mapped Label'].to_numpy()
# Every column other than the path and the two label columns is treated as a
# numeric feature (assumes these are all numeric — the float32 cast fails otherwise).
csv_features = (
    csv_data
    .drop(columns=['Image Path', 'Mapped Label', 'Label'])
    .to_numpy(dtype=np.float32)
)

In [None]:
# Preprocessing: resize to a fixed square, convert to a float tensor in [0, 1],
# then normalise with the standard ImageNet channel mean/std used by the
# pretrained torchvision backbones below.
image_size = 224
data_transforms = transforms.Compose(
    [
        transforms.Resize((image_size,) * 2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [None]:
# Define custom dataset
class CustomDataset(Dataset):
    """Pair each image on disk with its row of tabular features and its label.

    Args:
        image_paths: Sequence of image file paths; item ``i`` is opened with PIL.
        csv_features: Indexable of per-sample tabular features, aligned with ``image_paths``.
        labels: Indexable of per-sample labels, aligned with ``image_paths``.
        transform: Optional callable applied to the RGB PIL image.

    Raises:
        ValueError: If the three inputs do not all have the same length.
    """

    def __init__(self, image_paths, csv_features, labels, transform=None):
        n_images = len(image_paths)
        # Misaligned inputs would otherwise surface later as an IndexError mid-epoch,
        # or silently pair images with the wrong rows.
        if len(csv_features) != n_images or len(labels) != n_images:
            raise ValueError(
                f"length mismatch: {n_images} image paths, "
                f"{len(csv_features)} feature rows, {len(labels)} labels"
            )
        self.image_paths = image_paths
        self.csv_features = csv_features
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, csv_feature, label)`` for sample ``idx``."""
        # The context manager closes the file handle once convert() has produced
        # an independent in-memory copy, instead of leaving it to the GC.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, self.csv_features[idx], self.labels[idx]

In [None]:
# Wrap the CSV-driven samples in a Dataset and batch them. Shuffling stays off so
# the extracted features come out in CSV row order.
dataset = CustomDataset(
    image_paths=csv_image_paths,
    csv_features=csv_features,
    labels=csv_labels,
    transform=data_transforms,
)
dataloader = DataLoader(dataset, shuffle=False, batch_size=32)

In [None]:
# Define models: backbones used for feature extraction, each moved to `device` up front.
# NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13 in favour of
# `weights=...`; kept for compatibility with older installs.
cnn_models = {
    'InceptionV3': models.inception_v3(pretrained=True).to(device),
    'ResNet152': models.resnet152(pretrained=True).to(device),
    'VGG19': models.vgg19(pretrained=True).to(device),
    # NOTE(review): no weights are loaded into this ViT here, so it is randomly
    # initialised — confirm that is intended before using its outputs as features.
    'ViT': ViT(
        image_size=image_size,
        patch_size=16,
        num_classes=1000,
        dim=1024,
        depth=6,
        heads=16,
        mlp_dim=2048,
        dropout=0.1,
        emb_dropout=0.1
    ).to(device)
}

# Disable auxiliary logits for InceptionV3. The `aux_logits` flag alone only changes
# what forward() returns. Recent torchvision versions still run the AuxLogits branch in
# training mode whenever that submodule exists, so drop the submodule too.
if 'InceptionV3' in cnn_models:
    cnn_models['InceptionV3'].aux_logits = False
    cnn_models['InceptionV3'].AuxLogits = None

In [None]:
# Function to extract features
def _to_numpy(batch):
    """Convert a collated batch (tensor, ndarray, or list of scalars/strings) to a NumPy array."""
    if isinstance(batch, torch.Tensor):
        return batch.detach().cpu().numpy()
    return np.asarray(batch)


def extract_features(model, dataloader, device, feature_size, save_path):
    """Run ``model`` over ``dataloader`` and save its outputs with the CSV features and labels.

    Args:
        model: Module whose forward output is used as the feature vector. Tuple
            outputs (e.g. auxiliary heads) are reduced to their first element.
        dataloader: Iterable yielding ``(images, csv_features, labels)`` batches.
        device: Device the images are moved to before the forward pass.
        feature_size: Expected feature width. Currently unused; kept for API compatibility.
        save_path: Directory that receives ``<ModelClass>_features.npy``,
            ``csv_features.npy`` and ``labels.npy``. Created if it does not exist.

    Returns:
        Tuple ``(features, csv_features, labels)`` of concatenated NumPy arrays,
        in dataloader order.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    feature_batches, csv_batches, label_batches = [], [], []

    with torch.no_grad():
        for images, csv_batch, label_batch in dataloader:
            outputs = model(images.to(device))
            if isinstance(outputs, tuple):
                outputs = outputs[0]
            feature_batches.append(_to_numpy(outputs))
            csv_batches.append(_to_numpy(csv_batch))
            # Default collation turns non-numeric labels (e.g. strings) into a plain
            # list rather than a tensor, so `.numpy()` cannot be assumed here.
            label_batches.append(_to_numpy(label_batch))

    if not feature_batches:
        raise ValueError("dataloader yielded no batches; nothing to extract")

    features = np.concatenate(feature_batches, axis=0)
    csv_features = np.concatenate(csv_batches, axis=0)
    labels = np.concatenate(label_batches, axis=0)

    os.makedirs(save_path, exist_ok=True)
    np.save(os.path.join(save_path, f'{model.__class__.__name__}_features.npy'), features)
    # csv_features / labels do not depend on the model; every call rewrites the same files.
    np.save(os.path.join(save_path, 'csv_features.npy'), csv_features)
    np.save(os.path.join(save_path, 'labels.npy'), labels)
    return features, csv_features, labels

In [None]:
# Output directory for the extracted feature arrays. Defaults to a `features` folder
# inside the data directory; set the FEATURE_SAVE_DIR environment variable to override.
# (The previous value '/path/to/save_features' was a placeholder that would require
# creating `/path` at the filesystem root.)
feature_save_path = os.environ.get('FEATURE_SAVE_DIR', os.path.join(base_dir, 'features'))
os.makedirs(feature_save_path, exist_ok=True)

In [None]:
# Replace each CNN's final classification layer with an identity so the forward pass
# returns the penultimate-layer activations, then extract and save features per model.
for model_name, model in cnn_models.items():
    if model_name == 'ViT':
        # TODO: ViT feature extraction is not implemented yet.
        continue

    # ResNet and Inception expose their classifier as `fc`; VGG keeps it at `classifier[6]`.
    if hasattr(model, 'fc'):
        head = model.fc
        model.fc = nn.Identity()
    else:
        head = model.classifier[6]
        model.classifier[6] = nn.Identity()
    # On a re-run of this cell the head is already nn.Identity, which has no
    # `in_features`; the width is then unknown (extract_features does not use it).
    feature_size = getattr(head, 'in_features', None)
    extract_features(model, dataloader, device, feature_size, feature_save_path)