In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import os
from PIL import Image
import math
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import matplotlib.pyplot as plt
import sys

# --- 1. Custom Dataset Definition ---
class MicroplasticDataset(Dataset):
    """Custom PyTorch Dataset for loading the microplastic hologram images.

    Images are discovered exactly one level below ``root_dir``
    (``root_dir/<subdir>/<file>``); files placed directly in ``root_dir`` are
    ignored. Only ``.png``, ``.jpg``, ``.jpeg`` and ``.tif`` files (any case)
    are used. Each sample is returned as ``(image, label)`` where ``image`` is
    the grayscale PIL image after ``transform`` (if any) and ``label`` is an
    int class index in ``0..5``.
    """
    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (string): Directory with all the subfolders of images.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []

        # The paper specifies 6 classes: 0, 1, 2, 3, 4, and >=5 MPs
        # We will map the count from filenames to these class indices.
        self.class_mapping = {0: 0, 1: 1, 2: 2, 3: 3, 4: 4}  # Counts >=5 will map to class 5

        # os.listdir returns entries in arbitrary, filesystem-dependent order.
        # Sorting gives a stable sample order so a seeded random split selects
        # the same images on every machine/run.
        for subdir in sorted(os.listdir(root_dir)):
            subdir_path = os.path.join(root_dir, subdir)
            if not os.path.isdir(subdir_path):
                continue

            for filename in sorted(os.listdir(subdir_path)):
                if not filename.lower().endswith(('.png', '.jpg', '.jpeg', '.tif')):
                    continue
                try:
                    # Get the string part of the filename before the extension (e.g., "11" from "11.jpg")
                    filename_base = os.path.splitext(filename)[0]
                    # The label is the FIRST character of this name, read as a digit;
                    # any further digits are ignored ("11.jpg" -> count 1).
                    plastic_count = int(filename_base[0])
                except (ValueError, IndexError):
                    print(f"Warning: Could not parse label from filename: {filename}. Skipping.")
                    continue

                # Map count to class index (0-5); counts 5-9 fall into class 5.
                label = self.class_mapping.get(plastic_count, 5)

                self.image_paths.append(os.path.join(subdir_path, filename))
                self.labels.append(label)

    def __len__(self):
        """Return the number of images found under ``root_dir``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as grayscale, apply ``transform``, return ``(image, label)``."""
        img_path = self.image_paths[idx]
        # Open image and convert to grayscale ('L'). The context manager closes the
        # file handle (important with long-running DataLoader workers); convert()
        # returns a new, fully loaded image that stays valid after the close.
        with Image.open(img_path) as img:
            image = img.convert('L')
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

# --- 2. Model and Loss Function (Unchanged) ---
class HCCNN(nn.Module):
    """
    Implementation of the Holographic-Classifier Convolutional Neural Network (HC-CNN)
    as described in the paper "Microplastic pollution monitoring with holographic
    classification and deep learning" (Zhu et al., 2021).

    Input: single-channel images of size ``img_size`` x ``img_size``, shape
    ``(batch, 1, img_size, img_size)``. Output: raw (un-normalised) class logits
    of shape ``(batch, num_classes)``.

    The feature extractor halves the spatial size four times (three conv stages
    each followed by a 2x2 max-pool, plus one extra max-pool), so the flattened
    feature vector has ``256 * (img_size // 16) ** 2`` elements: 16384 for the
    default 128x128 input.

    Args:
        num_classes: number of output logits.
        img_size: side length of the (square) input images; defaults to 128,
            which reproduces the original hard-coded 16384-wide classifier.

    Raises:
        ValueError: if ``img_size`` is smaller than 16 (nothing left after pooling).
    """
    def __init__(self, num_classes=6, img_size=128):
        super(HCCNN, self).__init__()
        # Spatial size after the four 2x2 max-pools (each floors odd sizes).
        pooled_size = img_size // 16
        if pooled_size < 1:
            raise ValueError(f"img_size must be at least 16, got {img_size}")
        # NOTE: keep the module order exactly as is -- the Sequential indices
        # form the state_dict keys, so reordering breaks loading saved weights.
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(), nn.Dropout(0.5),
            nn.Linear(256 * pooled_size * pooled_size, 1024), nn.ReLU(True),
            nn.Linear(1024, 512), nn.ReLU(True),
            nn.Linear(512, 256), nn.ReLU(True),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        """Return class logits for a batch ``x`` of grayscale images."""
        x = self.feature_extractor(x)
        return self.classifier(x)

class CCQLLoss(nn.Module):
    """Correct-Class Quadratic Loss (L_CCQL).

    Computed on the softmax of the logits as the sum of two terms:

    * the mean squared error between the class probabilities and the one-hot
      targets (``nn.MSELoss`` with mean reduction, i.e. averaged over every
      batch element *and* class), and
    * ``alpha`` times the batch mean of ``(1 - p_true) ** 2``, where
      ``p_true`` is the probability given to each sample's true class and
      ``alpha = sqrt(num_classes - 1) - 1``.
    """
    def __init__(self, num_classes=6):
        super().__init__()
        self.num_classes = num_classes
        self.alpha = math.sqrt(num_classes - 1) - 1
        self.mse_loss = nn.MSELoss(reduction='mean')

    def forward(self, logits, targets):
        """Return the scalar loss.

        ``logits`` must be ``(batch, num_classes)`` and ``targets`` a 1-D tensor of
        int64 class indices, since the one-hot targets are compared element-wise
        with the softmax of ``logits``.
        """
        probs = torch.softmax(logits, dim=1)
        target_onehot = nn.functional.one_hot(targets, num_classes=self.num_classes).float()
        # Pick out, row by row, the probability assigned to the ground-truth class.
        p_true = probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        squared_miss = ((1.0 - p_true) ** 2).mean()
        return self.mse_loss(probs, target_onehot) + self.alpha * squared_miss

# --- 3. Training and Evaluation Functions ---
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Make one pass over ``loader`` and return ``(mean_loss, accuracy)`` as floats.

    With an ``optimizer`` the model is put in train mode and updated after every
    batch; without one it runs in eval mode with gradients disabled. If the
    loader yields no samples, ``(nan, nan)`` is returned instead of dividing by zero.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            batch_size = inputs.size(0)
            # criterion returns a batch mean; re-weight so the epoch mean is per-sample.
            loss_sum += loss.item() * batch_size
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()
            n_seen += batch_size

    if n_seen == 0:
        nan = float('nan')
        return nan, nan
    return loss_sum / n_seen, n_correct / n_seen


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device):
    """Train ``model`` for ``num_epochs`` epochs, validating after each epoch.

    Args:
        model: module producing class logits.
        train_loader / val_loader: iterables of ``(inputs, labels)`` batches.
        criterion: callable ``(logits, labels) -> scalar loss tensor``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``train_loader``.
        device: device the batches are moved to.

    Returns:
        dict with lists ``'train_loss'``, ``'val_loss'``, ``'train_acc'`` and
        ``'val_acc'``, one float per epoch. An epoch whose loader yields no
        samples records ``nan`` for that loss/accuracy.
    """
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}

    for epoch in range(num_epochs):
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        print(f'Epoch {epoch+1}/{num_epochs} | '
              f'Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | '
              f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')

    return history

def evaluate_model(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader``, print the metrics and return them.

    Precision, recall and F1 are support-weighted averages over classes
    (scikit-learn ``average='weighted'``). ``zero_division=0`` scores a class
    that is never predicted as 0 -- the value scikit-learn's default uses too,
    but without emitting ``UndefinedMetricWarning``.

    Returns:
        dict with float ``'accuracy'``, ``'precision'``, ``'recall'`` and ``'f1'``,
        or ``None`` if the loader yields no samples.
    """
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs.to(device))
            all_preds.extend(outputs.argmax(dim=1).cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    print("\n--- Test Set Evaluation ---")
    if not all_labels:
        print("Test set is empty; no metrics computed.")
        return None

    accuracy = accuracy_score(all_labels, all_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average='weighted', zero_division=0)

    print(f'Accuracy: {accuracy:.4f}')
    print(f'Precision: {precision:.4f}')
    print(f'Recall: {recall:.4f}')
    print(f'F1-Score: {f1:.4f}')
    return {
        'accuracy': float(accuracy),
        'precision': float(precision),
        'recall': float(recall),
        'f1': float(f1),
    }


# --- 4. Main Execution Block ---
if __name__ == '__main__':
    import copy  # shallow-copies the dataset so eval splits get their own transform

    # --- Configuration ---
    # !! IMPORTANT !! Change this to your dataset path in Kaggle
    DATA_DIR = '/kaggle/input/microplastics-data/micro_plastic'

    NUM_CLASSES = 6
    LEARNING_RATE = 0.0001
    BATCH_SIZE = 32 # Increased for better GPU utilization
    NUM_EPOCHS = 100 # As per paper, training is stable around 100-130 iterations/epochs
    IMG_SIZE = 128
    SEED = 42  # fixes weight init, shuffling and the train/val/test split
    MODEL_OUT = 'hccnn_microplastic_model.pth'

    torch.manual_seed(SEED)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # --- Data Transformations ---
    # No data augmentation is applied, only resizing, tensor conversion, and normalization.
    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]) # Normalize for grayscale
        ]),
        'test': transforms.Compose([
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ]),
    }

    # --- Dataset and Dataloaders ---
    print("\nLoading and splitting dataset...")

    # Add validation checks to ensure the data path is correct and images are found
    if not os.path.isdir(DATA_DIR):
        print(f"--- !!! ERROR !!! ---", file=sys.stderr)
        print(f"The specified DATA_DIR does not exist or is not a directory: '{DATA_DIR}'", file=sys.stderr)
        print("Please check the path. In a Kaggle notebook, you can find your data path in the 'Input' section on the right panel.", file=sys.stderr)
        sys.exit(1) # Exit the script

    full_dataset = MicroplasticDataset(root_dir=DATA_DIR, transform=data_transforms['train'])

    dataset_size = len(full_dataset)
    if dataset_size == 0:
        print(f"--- !!! ERROR !!! ---", file=sys.stderr)
        print(f"No images were found in '{DATA_DIR}'.", file=sys.stderr)
        print("Please check that the directory contains subfolders with your images and that the filenames start with a number (e.g., '0_image.png').", file=sys.stderr)
        sys.exit(1) # Exit the script

    # Each of the three splits needs at least one image; with 1-2 images the
    # split arithmetic below would produce an empty or negative training set.
    if dataset_size < 3:
        print("--- !!! ERROR !!! ---", file=sys.stderr)
        print(f"Found only {dataset_size} image(s) in '{DATA_DIR}'; at least 3 are needed for a train/validation/test split.", file=sys.stderr)
        sys.exit(1)

    # Splitting data 8:1:1
    train_size = int(0.8 * dataset_size)
    val_size = int(0.1 * dataset_size)
    # Small datasets (< 10 images) get no 10% validation share: borrow one training image.
    if val_size == 0:
        val_size = 1
        train_size -= 1
    # The test split takes the rounding remainder. train + val <= int(0.8 * N) + int(0.1 * N) < N,
    # so it always holds at least one image.
    test_size = dataset_size - train_size - val_size

    split_generator = torch.Generator().manual_seed(SEED)
    train_dataset, val_dataset, test_dataset = random_split(
        full_dataset, [train_size, val_size, test_size], generator=split_generator)

    # random_split returns Subset views of the SAME dataset object, so assigning
    # `split.dataset.transform` would change the transform of every split,
    # including training. Point the evaluation splits at a shallow copy that
    # carries the test-time transform (the path/label lists are shared, not copied).
    eval_dataset = copy.copy(full_dataset)
    eval_dataset.transform = data_transforms['test']
    val_dataset.dataset = eval_dataset
    test_dataset.dataset = eval_dataset

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    print(f"Total images: {dataset_size}")
    print(f"Training set size: {len(train_dataset)}")
    print(f"Validation set size: {len(val_dataset)}")
    print(f"Test set size: {len(test_dataset)}")

    # --- Model, Loss, Optimizer ---
    print("\nInitializing Model, Loss, and Optimizer...")
    model = HCCNN(num_classes=NUM_CLASSES).to(device)
    criterion = CCQLLoss(num_classes=NUM_CLASSES)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # --- Training ---
    print("\nStarting training...")
    history = train_model(model, train_loader, val_loader, criterion, optimizer, NUM_EPOCHS, device)

    # --- Evaluation ---
    evaluate_model(model, test_loader, device)

    # Save the trained model before plotting, so a plotting/display failure
    # cannot lose the result of a long training run.
    torch.save(model.state_dict(), MODEL_OUT)
    print(f"\nTrained model saved to {MODEL_OUT}")

    # --- Plotting Results ---
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(history['train_acc'], label='Train Accuracy')
    plt.plot(history['val_acc'], label='Validation Accuracy')
    plt.title('Accuracy over Epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Validation Loss')
    plt.title('Loss over Epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()

    plt.tight_layout()
    plt.show()



In [None]:
# --- 3. Main Execution Block ---
# Snippet showing how to set the image used for prediction.
# Replace the placeholder value
#     IMAGE_TO_PREDICT = 'path/to/your/image.jpg'
# with the path of a real image, for example the one below.
# NOTE: the prediction cell that follows assigns IMAGE_TO_PREDICT itself, so the
# path must be changed in that cell for it to take effect there.
IMAGE_TO_PREDICT = '/kaggle/input/microplastics-data/micro_plastic/pe/1_some_image.jpg'

In [None]:
import torch
from torchvision import transforms
from PIL import Image
import torch.nn as nn
import os

# --- 1. Re-define the Model Architecture ---
# This MUST be the exact same architecture as the one used for training.
class HCCNN(nn.Module):
    """
    Implementation of the Holographic-Classifier Convolutional Neural Network (HC-CNN)
    as described in the paper "Microplastic pollution monitoring with holographic
    classification and deep learning" (Zhu et al., 2021).

    Input: single-channel images of size ``img_size`` x ``img_size``, shape
    ``(batch, 1, img_size, img_size)``. Output: raw (un-normalised) class logits
    of shape ``(batch, num_classes)``.

    The feature extractor halves the spatial size four times (three conv stages
    each followed by a 2x2 max-pool, plus one extra max-pool), so the flattened
    feature vector has ``256 * (img_size // 16) ** 2`` elements: 16384 for the
    default 128x128 input.

    Args:
        num_classes: number of output logits.
        img_size: side length of the (square) input images; defaults to 128,
            which reproduces the original hard-coded 16384-wide classifier.

    Raises:
        ValueError: if ``img_size`` is smaller than 16 (nothing left after pooling).
    """
    def __init__(self, num_classes=6, img_size=128):
        super(HCCNN, self).__init__()
        # Spatial size after the four 2x2 max-pools (each floors odd sizes).
        pooled_size = img_size // 16
        if pooled_size < 1:
            raise ValueError(f"img_size must be at least 16, got {img_size}")
        # NOTE: keep the module order exactly as is -- the Sequential indices
        # form the state_dict keys, and saved training weights must still load.
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(), nn.Dropout(0.5),
            nn.Linear(256 * pooled_size * pooled_size, 1024), nn.ReLU(True),
            nn.Linear(1024, 512), nn.ReLU(True),
            nn.Linear(512, 256), nn.ReLU(True),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        """Return class logits for a batch ``x`` of grayscale images."""
        x = self.feature_extractor(x)
        return self.classifier(x)

# --- 2. Prediction Function ---
def predict_image(model, image_path, device, img_size=128):
    """Load one image, preprocess it like the validation/test set, and classify it.

    Args:
        model: trained classifier; it is moved to ``device`` and set to eval mode.
        image_path: path to the image file.
        device: torch device to run inference on.
        img_size: side length the image is resized to. It must match the size
            the model was trained with (``IMG_SIZE = 128`` in the training cell).

    Returns:
        ``(predicted_class, probabilities)``: an int class index and a 1-D numpy
        array of softmax probabilities. If the image cannot be read, an error is
        printed and ``(None, None)`` is returned. Returning a pair in both cases
        keeps ``prediction, probabilities = predict_image(...)`` from raising.
    """
    # Define the same transformations as used for the validation/test set
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]) # Normalize for grayscale
    ])

    # Load the image; the context manager closes the file, and convert() returns
    # a new, fully loaded grayscale image that outlives it.
    try:
        with Image.open(image_path) as img:
            image = img.convert('L')
    except FileNotFoundError:
        print(f"Error: The image file was not found at {image_path}")
        return None, None
    except OSError as err:  # e.g. PIL.UnidentifiedImageError for a non-image file
        print(f"Error: Could not read image {image_path}: {err}")
        return None, None

    image_tensor = transform(image).unsqueeze(0).to(device)  # add a batch dimension

    model.to(device)
    model.eval()  # disable dropout / use BatchNorm running statistics

    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.softmax(outputs, dim=1)
        predicted_class = outputs.argmax(dim=1)

    return predicted_class.item(), probabilities.cpu().numpy().flatten()

# --- 3. Main Execution Block ---
if __name__ == '__main__':
    # --- Configuration ---
    # Path to the saved model weights. In Kaggle, this will be in the output directory.
    MODEL_PATH = '/kaggle/working/hccnn_microplastic_model.pth'

    # !! IMPORTANT !! Path to the new image you want to classify.
    # You'll need to upload a new image or use one from your test set for this to work.
    # Example: IMAGE_TO_PREDICT = '/kaggle/input/your-dataset/test/some_image.jpg'
    IMAGE_TO_PREDICT = '/kaggle/input/test-img/test_image.jpg'

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- Load the Model ---
    # 1. Initialize the model architecture
    model = HCCNN(num_classes=6)

    # 2. Load the saved weights (the state dictionary).
    # torch.load unpickles the file -- only load checkpoints you produced or trust.
    try:
        state_dict = torch.load(MODEL_PATH, map_location=device)
    except FileNotFoundError:
        print(f"Error: Model file not found at {MODEL_PATH}")
        print("Please make sure you have run the training script first and the model was saved.")
        # Non-zero status signals failure; SystemExit needs no import (unlike sys.exit)
        # and, unlike the site-provided exit(), is always available.
        raise SystemExit(1)
    model.load_state_dict(state_dict)
    print("Model loaded successfully from:", MODEL_PATH)

    # --- Make a Prediction ---
    result = predict_image(model, IMAGE_TO_PREDICT, device)
    # Guard against predict_image returning a bare None on failure (after printing
    # its own error message), which would otherwise break the tuple unpacking.
    prediction, probabilities = result if result is not None else (None, None)

    if prediction is not None:
        # Define class names for readability
        class_names = {0:'0 MP', 1:'1 MP', 2:'2 MP', 3:'3 MP', 4:'4 MP', 5:'>=5 MP'}

        print(f"\nPrediction for image: {os.path.basename(IMAGE_TO_PREDICT)}")
        print(f"Predicted Class: {prediction} ({class_names.get(prediction, 'Unknown')})")

        print("\nClass Probabilities:")
        for i, prob in enumerate(probabilities):
            print(f"  - Class {class_names.get(i, 'Unknown')}: {prob:.4f} ({prob*100:.2f}%)")