# sift cnn descriptors pytorch (from replit)

In [None]:
import os
import time
import requests
import scipy.io
import numpy as np
import cv2
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from pathlib import Path
from sklearn.model_selection import train_test_split
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset

# Helper class for custom dataset
class CustomImageDataset(Dataset):
    """Dataset of grayscale images stored as 1-based numbered PNG files.

    Item ``idx`` (0-based) is read from ``<image_dir>/<idx + 1>.png`` and
    paired with ``image_labels[idx]``.
    """

    def __init__(self, image_labels, image_dir, transform=None):
        """Remember where the images live and how to label/transform them.

        Args:
            image_labels: Indexable collection of labels; its length is the
                dataset size.
            image_dir: Directory containing ``1.png``, ``2.png``, ...
            transform: Optional callable applied to every loaded image.
        """
        self.image_labels = image_labels
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        """Return the number of samples (one per label)."""
        return len(self.image_labels)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as ``(image, label)``.

        Negative indices count from the end, like a list.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
            FileNotFoundError: If the image file is missing or cannot be
                decoded.
        """
        size = len(self)
        if idx < 0:
            # Rebind instead of ``+=`` so a caller-owned array/tensor index
            # is never modified in place.
            idx = idx + size
        # Validate before touching the disk: plain ``for x in dataset``
        # iteration stops on IndexError, and without this check it would
        # first try to read a file that does not exist.
        if not 0 <= idx < size:
            raise IndexError(
                f"index out of range for dataset of size {size}")

        img_path = os.path.join(self.image_dir, f"{idx+1}.png")
        image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError(
                f"Could not read image '{img_path}' "
                "(file missing or not a decodable image)")

        label = self.image_labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label


# Load labels
def load_labels(mat_file):
    """Read the ``'labels'`` variable of a MATLAB ``.mat`` file as a 1-D array.

    Prints a start message and the wall-clock time the load took.

    Args:
        mat_file: Path of the ``.mat`` file to read.

    Returns:
        The ``'labels'`` array flattened to one dimension.
    """
    print("Loading labels...")
    start_time = time.time()
    mat_contents = scipy.io.loadmat(mat_file)
    label_vector = mat_contents['labels'].flatten()
    end_time = time.time()
    print(f"Labels loading complete in {end_time - start_time:.2f} seconds.")
    return label_vector


# Preprocess images, extract SIFT features, and visualize
def preprocess_images_and_extract_sift(image_labels, image_dir):
    """Resize every image to 64x64, run SIFT on it and show a few previews.

    Up to five images are displayed before extraction, and the same images
    are displayed again afterwards with their keypoints drawn on top.

    Args:
        image_labels: Labels, one per image; the length sets how many images
            are processed.
        image_dir: Directory holding the numbered PNG files.

    Returns:
        A list with one ``(keypoints, descriptors)`` tuple per image, exactly
        as returned by ``sift.detectAndCompute``.
        NOTE(review): descriptors are presumably ``None`` for images where
        SIFT finds no keypoints -- confirm before stacking them.
    """
    print("\nStarting preprocessing and SIFT extraction...")
    sift = cv2.SIFT_create()

    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((64, 64)),
        transforms.ToTensor()
    ])

    dataset = CustomImageDataset(image_labels, image_dir, transform)
    total = len(dataset)
    # Small datasets must not fail just because fewer than 5 images exist.
    preview_count = min(5, total)

    # Visualize initial images
    print("Visualizing initial images...")
    for i in range(preview_count):
        img, _ = dataset[i]
        plt.imshow(img.squeeze(), cmap='gray')
        plt.show()

    # Extract SIFT features
    sift_descriptors = []
    preview_images = []
    for idx in range(total):
        img, _ = dataset[idx]
        img_np = _tensor_to_uint8_image(img)
        keypoints, descriptors = sift.detectAndCompute(img_np, None)
        sift_descriptors.append((keypoints, descriptors))

        # Keep the first few 8-bit images so the keypoint preview below does
        # not have to load and transform them a second time.
        if idx < preview_count:
            preview_images.append(img_np)

        # Progress update
        if (idx + 1) % 1000 == 0:
            print(f"Processed {idx + 1} images...")

    print("SIFT feature extraction complete.")

    # Visualize images with descriptors and keypoints
    print("Visualizing images with descriptors and keypoints...")
    for img_np, (keypoints, _) in zip(preview_images, sift_descriptors):
        img_with_keypoints = cv2.drawKeypoints(img_np, keypoints, None)
        plt.imshow(img_with_keypoints, cmap='gray')
        plt.show()

    return sift_descriptors


def _tensor_to_uint8_image(img):
    """Convert a ``ToTensor`` output back to an 8-bit 2-D numpy image."""
    # ToTensor rescales pixels to floats in [0.0, 1.0]; casting those straight
    # to uint8 truncates almost every pixel to 0, so scale back to 0..255
    # first.
    scaled = np.rint(img.squeeze().numpy() * 255.0)
    return np.clip(scaled, 0, 255).astype(np.uint8)


# Setup PyTorch model
class SimpleCNN(nn.Module):
    """Two conv/pool stages followed by a two-layer classifier head.

    The first linear layer is sized for 64 feature maps of 16x16, which is
    what single-channel 64x64 inputs become after the two 2x poolings.
    """

    def __init__(self, num_classes=10):
        """Build the network.

        Args:
            num_classes: Number of output logits. Defaults to 10, the value
                that was previously hard-coded.

        Raises:
            ValueError: If ``num_classes`` is smaller than 1.
        """
        super().__init__()
        if num_classes < 1:
            raise ValueError("num_classes must be at least 1")
        # Layer order is unchanged so existing state_dict keys
        # ("model.0.weight", ..., "model.9.bias") keep matching.
        self.model = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.Linear(64 * 16 * 16, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        return self.model(x)


# Create the local data directory
def setup_data_directories(image_dir="cells"):
    """Create the local image directory if it does not exist yet.

    Args:
        image_dir: Directory to create. Defaults to ``"cells"``, the value
            that was previously hard-coded. Missing parent directories are
            created as well.

    Returns:
        Always True; an existing directory is not an error.
    """
    print("Setting up data directories...")
    # Create directories if they don't exist
    Path(image_dir).mkdir(parents=True, exist_ok=True)
    return True


# Download a single file over HTTP
def download_file(url, local_path, timeout=30):
    """Download ``url`` with an HTTP GET and save the body to ``local_path``.

    Args:
        url: Address to fetch.
        local_path: Destination file path; its directory must already exist.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        True if the server answered with status 200 and the file was
        written, False if the request failed or returned another status.
    """
    print(f"Downloading {local_path}...")
    try:
        # Without a timeout a stalled connection would block forever.
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as exc:
        # Network errors are reported through the same False return value
        # as a bad status code, so callers need only one failure path.
        print(f"Failed to download {local_path}: {exc}")
        return False

    if response.status_code != 200:
        print(f"Failed to download {local_path}")
        return False

    # Write to a temporary name and rename afterwards, so an interrupted
    # write never leaves a truncated file under the final name.
    partial_path = f"{local_path}.part"
    with open(partial_path, 'wb') as f:
        f.write(response.content)
    os.replace(partial_path, local_path)
    print(f"Downloaded {local_path}")
    return True


# Download the labels file and all images
def download_dataset(base_url, image_dir, mat_file, num_images=10000):
    """Download the labels file and the numbered images of the dataset.

    Images are fetched from ``<base_url>/<image_dir>/<i>.png`` for
    ``i = 1..num_images`` and stored under ``<image_dir>/<i>.png``. Images
    that already exist locally are not downloaded again.

    Args:
        base_url: Root URL of the dataset, with or without a trailing slash.
        image_dir: Image directory name, used both locally and in the URL.
        mat_file: Name of the labels file, used both locally and in the URL.
        num_images: How many images to fetch. Defaults to 10000, the value
            that was previously hard-coded.

    Returns:
        False if the labels file could not be downloaded, otherwise True.
        Individual image failures are only reported, not treated as fatal.
    """
    print("Starting dataset download...")
    start_time = time.time()

    # Create the directory the images are actually written to; a fixed
    # directory name would break for any other ``image_dir``.
    Path(image_dir).mkdir(parents=True, exist_ok=True)

    # Drop trailing slashes so the joins below never produce "//".
    root_url = base_url.rstrip("/")

    # Download labels.mat
    labels_url = f"{root_url}/{mat_file}"
    if not download_file(labels_url, mat_file):
        return False

    # Download images
    success_count = 0
    for i in range(1, num_images + 1):
        image_url = f"{root_url}/{image_dir}/{i}.png"
        local_path = f"{image_dir}/{i}.png"

        # Only download if the file doesn't exist; existing files count as
        # available.
        if os.path.exists(local_path) or download_file(image_url, local_path):
            success_count += 1

        if i % 1000 == 0:
            print(f"Downloaded {i} images...")

    end_time = time.time()
    print(f"Dataset download completed in {end_time - start_time:.2f} seconds")
    print(f"Successfully downloaded {success_count} images")

    missing_count = num_images - success_count
    if missing_count:
        print(f"Warning: {missing_count} images could not be downloaded")
    return True

# main function
def main():
    """Download the dataset, extract SIFT features, split it, build the CNN.

    Training and evaluation are not implemented yet; the function stops
    after the model has been constructed.
    """
    # GitHub raw content base URL. No trailing slash: download_dataset joins
    # the URL parts with "/" itself, so a trailing one would give "//".
    base_url = "https://raw.githubusercontent.com/RyanS974/RyanS974/main/datasets/hep2cells"
    image_dir = 'cells'
    mat_file = 'labels.mat'

    # Download dataset
    if not download_dataset(base_url, image_dir, mat_file):
        print("Failed to download dataset")
        return

    # Step 1: Load labels
    labels = load_labels(mat_file)

    # Step 2: Preprocess and extract SIFT features
    descriptors = preprocess_images_and_extract_sift(labels, image_dir)

    # Step 3: Split the data
    # Split sample indices rather than the labels themselves, so every
    # subset can still be mapped back to its images and descriptors. With
    # the same random_state this selects the same samples as before.
    sample_indices = np.arange(len(labels))
    train_idx, temp_idx = train_test_split(sample_indices,
                                           test_size=0.3,
                                           random_state=42)
    val_idx, test_idx = train_test_split(temp_idx,
                                         test_size=0.5,
                                         random_state=42)
    train_labels = labels[train_idx]
    val_labels = labels[val_idx]
    test_labels = labels[test_idx]

    print(
        f"Dataset split into train ({len(train_labels)}), "
        f"validation ({len(val_labels)}), and test ({len(test_labels)}) sets.")

    # Step 4: Define CNN model
    model = SimpleCNN()

    print("Model defined.")

    # Further steps, defining loss, optimizer, and starting training and evaluation.


# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
