# In [None]:
# 1. Define the CNN model
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import pandas as pd
import shutil

class SimpleCNN(nn.Module):
    """Small three-stage CNN that emits two logits per image.

    Each stage is a 3x3 convolution (stride 1, padding 1, so spatial size is
    preserved), a ReLU, and a 2x2 max-pool that halves height and width.
    After three stages the feature map has 128 channels at 1/8 of the input
    resolution; ``fc1`` is sized for 128 * 16 * 16 features, i.e. for
    3-channel 128x128 inputs.

    The layer attribute names are the keys of the saved state dict, so they
    must not be renamed.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        # One parameter-free pooling layer is shared by all three stages.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(128 * 16 * 16, 512)
        self.fc2 = nn.Linear(512, 2)  # Output 2 classes: Good (1) and Bad (0)
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        """Return the class logits for a batch of images.

        Args:
            x: Image batch of shape (batch, 3, 128, 128).

        Returns:
            Tensor of shape (batch, 2) holding unnormalized class scores.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not produce the
                128 * 16 * 16 features that ``fc1`` expects.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Flatten per sample. The previous view(-1, 128 * 16 * 16) silently
        # folded surplus spatial elements into the batch dimension when the
        # input was not 128x128, returning the wrong number of predictions.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# 2. Load the model and move it to the appropriate device
# The device is chosen first so the checkpoint tensors can be remapped onto
# it while loading; without map_location, a checkpoint saved on a GPU cannot
# be loaded on a CPU-only machine.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleCNN()
model_path = 'filtering_model2.pth'
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)

# 3. Define a custom dataset for the filtered images
class FilteredImagesDataset(Dataset):
    """Dataset that serves images from one directory together with their file names.

    Args:
        img_names: Sequence of file names, relative to ``img_dir``.
        img_dir: Directory that contains the image files.
        transform: Optional callable applied to each RGB PIL image.
    """

    def __init__(self, img_names, img_dir, transform=None):
        self.img_names, self.img_dir, self.transform = img_names, img_dir, transform

    def __len__(self):
        """Return the number of image names held by the dataset."""
        return len(self.img_names)

    def __getitem__(self, idx):
        """Return ``(image, file_name)`` for the entry at position ``idx``.

        The image is opened from ``img_dir`` and converted to RGB; when a
        transform was supplied, its result is returned in place of the image.
        """
        file_name = self.img_names[idx]
        rgb_image = Image.open(os.path.join(self.img_dir, file_name)).convert("RGB")
        if not self.transform:
            return rgb_image, file_name
        return self.transform(rgb_image), file_name

# 4. Define the transform for the test dataset
# Images are resized to 128x128 (three 2x2 poolings then give the 16x16 maps
# that SimpleCNN.fc1 is sized for) and converted to tensors.
test_transform = transforms.Compose(
    [transforms.Resize(size=(128, 128)), transforms.ToTensor()]
)

# 5. Load image names from the input directory
input_dir = 'brispi/top_images_retrieved'  # Set the input directory path
output_good_dir = 'brispi/good_top_images'  # Set the output directory path for good images
output_bad_dir = 'brispi/bad_top_images'  # Set the output directory path for bad images

# exist_ok=True creates the folders only when missing, without the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(output_good_dir, exist_ok=True)
os.makedirs(output_bad_dir, exist_ok=True)

# os.listdir returns entries in arbitrary order; sorting keeps the CSV row
# order (and the seeded sampling done on it later) reproducible across runs.
img_names = sorted(img for img in os.listdir(input_dir) if img.lower().endswith('.jpg'))

# 6. Create the test dataset and dataloader
test_dataset = FilteredImagesDataset(img_names, input_dir, transform=test_transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# 7. Function to test the model, move or copy images, and save predictions to a CSV file
def classify_and_save_predictions(model, test_loader, good_dir, bad_dir, output_csv, action='move', src_dir=None):
    """Classify images, transfer each file to its class folder, and write a CSV.

    Args:
        model: Network returning one row of class scores per image; class
            index 1 is treated as "Good", any other index as "Bad".
        test_loader: Iterable yielding ``(images, img_names)`` batches.
        good_dir: Destination folder for images predicted "Good".
        bad_dir: Destination folder for images predicted "Bad".
        output_csv: Path of the CSV file to write, with the columns
            ``Image`` and ``Predicted Label``.
        action: ``'move'`` or ``'copy'``; how files reach their destination.
        src_dir: Folder holding the source images. When omitted, the
            ``img_dir`` of ``test_loader.dataset`` is used if present,
            otherwise the module-level ``input_dir``.

    Raises:
        ValueError: If ``action`` is neither ``'move'`` nor ``'copy'``.
    """
    transfer_ops = {'move': shutil.move, 'copy': shutil.copy}
    # Reject an unknown action before any inference runs; previously it was
    # ignored silently and no file was transferred.
    if action not in transfer_ops:
        raise ValueError(f"action must be 'move' or 'copy', got {action!r}")
    transfer = transfer_ops[action]

    if src_dir is None:
        # Prefer the directory the loader actually reads from over a global.
        src_dir = getattr(getattr(test_loader, 'dataset', None), 'img_dir', None)
        if src_dir is None:
            src_dir = input_dir

    # Run on whatever device the model lives on instead of a module global.
    try:
        model_device = next(model.parameters()).device
    except StopIteration:
        model_device = torch.device('cpu')

    os.makedirs(good_dir, exist_ok=True)
    os.makedirs(bad_dir, exist_ok=True)

    model.eval()
    all_img_names = []
    all_labels = []
    with torch.no_grad():
        for images, img_names in test_loader:
            outputs = model(images.to(model_device))
            _, preds = torch.max(outputs, 1)
            # One host transfer per batch rather than one tensor comparison
            # per image.
            for img_name, pred in zip(img_names, preds.cpu().tolist()):
                is_good = pred == 1
                all_img_names.append(img_name)
                all_labels.append("Good" if is_good else "Bad")
                dest_dir = good_dir if is_good else bad_dir
                transfer(os.path.join(src_dir, img_name), os.path.join(dest_dir, img_name))

    # Save the results to a CSV file
    results = pd.DataFrame({'Image': all_img_names, 'Predicted Label': all_labels})
    results.to_csv(output_csv, index=False)

# Where the prediction table is written
output_csv = 'image_quality.csv'

# Transfer mode: 'move' relocates the source files, 'copy' leaves them in place
action = 'copy'  # Choose between 'move' or 'copy'

# Run inference, sort the files into the good/bad folders, and write the CSV
classify_and_save_predictions(
    model,
    test_loader,
    good_dir=output_good_dir,
    bad_dir=output_bad_dir,
    output_csv=output_csv,
    action=action,
)

print(f"Predictions saved to {output_csv}")

# 8. Visualize Predictions
import matplotlib.pyplot as plt

# Read the predictions back from the CSV that was just written
predictions = pd.read_csv(filepath_or_buffer=output_csv)

# Function to select a sample of images
def select_sample(df, label, sample_size=1):
    """Return up to ``sample_size`` image names whose predicted label is ``label``.

    Args:
        df: DataFrame with the columns ``Image`` and ``Predicted Label``.
        label: Value of ``Predicted Label`` to select.
        sample_size: Maximum number of names to draw.

    Returns:
        List of image names drawn with a fixed seed (42). It is shorter than
        ``sample_size`` when fewer rows match, and empty when none match.
    """
    matching = df[df['Predicted Label'] == label]
    # DataFrame.sample raises when asked for more rows than exist (e.g. when
    # no image received this label), so cap the request at what is available.
    n = min(sample_size, len(matching))
    if n == 0:
        return []
    return matching.sample(n=n, random_state=42)['Image'].tolist()

# Draw one example image name for each predicted label
sample_good = select_sample(df=predictions, label="Good")
sample_bad = select_sample(df=predictions, label="Bad")

# Function to plot images
def plot_images(image_list, img_dir, title):
    """Show the given images in a grid under a common title.

    Args:
        image_list: File names of the images to display.
        img_dir: Directory containing those files.
        title: Figure title.

    The grid has at most 10 columns and as many rows as the image count
    needs. The former fixed 10x10 grid raised for more than 100 images and
    squeezed a small sample into a corner of the figure.
    """
    image_list = list(image_list)
    max_cols = 10
    count = len(image_list)
    cols = max(1, min(count, max_cols))
    rows = max(1, (count + cols - 1) // cols)  # ceiling division

    plt.figure(figsize=(15, 15))
    for i, img_name in enumerate(image_list):
        img_path = os.path.join(img_dir, img_name)
        plt.subplot(rows, cols, i + 1)
        # Close the file handle once the pixels have been handed to imshow.
        with Image.open(img_path) as img:
            plt.imshow(img)
        plt.axis('off')
    plt.suptitle(title)
    plt.show()

# Display the sampled images, reading each from the folder it was sorted into
plot_images(image_list=sample_good, img_dir=output_good_dir, title='Sample of 1 Good Images')
plot_images(image_list=sample_bad, img_dir=output_bad_dir, title='Sample of 1 Bad Images')
