In [None]:
"""
Imports for building a deep learning model using PyTorch, Torchvision, and other essential libraries.

- PyTorch: Provides deep learning functionality, including neural networks, optimization, and custom datasets.
- Torchvision: Contains utilities for vision-based tasks, including datasets and image transformations.
- Image handling and visualization: Matplotlib for plotting, PIL for loading images, and OpenCV for image processing.
- Data manipulation: Libraries for handling data, file processing, and randomization.
"""

# PyTorch core imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader, Dataset  # Dataset for custom dataset creation

# Torchvision imports for vision-based tasks
import torchvision
from torchvision import transforms, models, datasets

# Image handling and visualization
import matplotlib.pyplot as plt
from PIL import Image
import cv2  # OpenCV for image processing

# Data manipulation and file handling
import numpy as np
import pandas as pd
from collections import Counter  # Counting utility for analyzing data
import glob  # File path handling
import os  # Operating system interface for directory management

# Randomization
from random import shuffle, seed  # Random shuffling and seeding for reproducibility

In [None]:
# Let PyTorch's CUDA caching allocator use expandable segments (reduces
# fragmentation-related OOMs). Only takes effect if set before the allocator
# is first used in this process.
os.environ.update({'PYTORCH_CUDA_ALLOC_CONF': 'expandable_segments:True'})

In [None]:
# Release cached-but-unoccupied GPU memory held by PyTorch's caching allocator
# (memory still referenced by live tensors is not freed).
torch.cuda.empty_cache()  # Free unused memory

In [None]:
# Dataset directories, relative to the notebook's working directory
TRAIN_DIR, TEST_DIR, VAL_DIR = (
    './data/data_256',  # training split
    './data/test_256',  # test split
    './data/val_256',   # validation split
)

In [None]:
def count_files_in_dir(directory, choice=None):
    '''
    Count the entries in each subfolder (two levels deep) of `directory` and
    return them as a DataFrame, optionally filtered to one top-level folder.

    The expected layout is ``directory/<folder>/<subfolder>/<entries>``.
    Entries are counted with ``os.listdir``, so a nested directory inside a
    subfolder counts as one entry.

    The selection comes from `choice`; when `choice` is None the user is
    prompted with ``input()`` (which blocks "Run All" in a notebook):

    - 'quit'      -> print a message and return None
    - 'default'   -> return counts for every folder
    - one letter  -> return counts for that folder only (case-insensitive)

    Side effect: when a non-empty result is returned for 'default' or a
    letter, ``pd.set_option('display.max_rows', None)`` is applied globally
    so the whole frame is displayed.

    Args:
        directory (str): Path to the main directory containing subfolders.
        choice (str, optional): Pre-supplied selection; skips the prompt.

    Returns:
        pd.DataFrame or None: Columns 'Folder', 'Subfolder', 'Image Count',
        ordered by folder then subfolder name; None for 'quit', an unknown
        folder letter, or invalid input.
    '''
    records = []

    # Sort listings so the resulting table is deterministic across filesystems.
    for folder in sorted(os.listdir(directory)):
        folder_path = os.path.join(directory, folder)
        if not os.path.isdir(folder_path):
            continue
        for subdir in sorted(os.listdir(folder_path)):
            subdir_path = os.path.join(folder_path, subdir)
            if os.path.isdir(subdir_path):
                records.append({
                    'Folder': folder,
                    'Subfolder': subdir,
                    'Image Count': len(os.listdir(subdir_path)),
                })

    # Explicit columns keep the 'Folder' filter below working when no
    # subfolders exist (an empty record list would otherwise give no columns).
    df = pd.DataFrame(records, columns=['Folder', 'Subfolder', 'Image Count'])

    if choice is None:
        choice = input("Enter a folder letter, 'default', or 'quit': ")
    user_input = choice.strip().lower()

    if user_input == 'quit':
        print("Exiting the program.")
        return None
    if user_input == 'default':
        pd.set_option('display.max_rows', None)  # Show all rows in DataFrame
        return df
    if user_input.isalpha() and len(user_input) == 1:
        filtered_df = df[df['Folder'].str.lower() == user_input]
        if filtered_df.empty:
            print(f"No subfolders found for folder: {user_input}")
            return None
        pd.set_option('display.max_rows', None)  # Show all rows in DataFrame
        return filtered_df

    print("Invalid input. Please enter a valid folder letter, 'default', or 'quit'.")
    return None

In [None]:
# Count entries per subfolder of the training directory only (the val/test
# directories are not counted here).
# NOTE: count_files_in_dir prompts via input(), so this cell waits for user
# input when run.
train_class_counts = count_files_in_dir(TRAIN_DIR)
train_class_counts

In [None]:
from torch.utils.data import ConcatDataset

class CustomDataset(Dataset):
    '''
    Unlabeled image dataset over a single flat directory.

    Every regular file directly inside `root_dir` whose name ends in .png,
    .jpg or .jpeg (case-insensitive) is included; subdirectories are not
    searched. Paths are sorted so indices are stable across runs, which lets
    predictions be matched back to filenames through `self.images`.

    Items are images only (no labels): the RGB PIL image, or whatever
    `transform` returns for it.

    Args:
        root_dir (str): Directory containing the image files.
        transform (callable, optional): Applied to each RGB PIL image.
    '''

    # Accepted filename suffixes (compared against the lower-cased name).
    IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Sorted for deterministic ordering; isfile() skips directories that
        # happen to have an image-like name.
        self.images = sorted(
            os.path.join(root_dir, filename)
            for filename in os.listdir(root_dir)
            if filename.lower().endswith(self.IMG_EXTENSIONS)
            and os.path.isfile(os.path.join(root_dir, filename))
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        # The context manager closes the file handle; convert() returns a
        # fully loaded copy, so the image remains usable after closing.
        with Image.open(img_path) as img:
            image = img.convert("RGB")  # Ensure image is in RGB format

        if self.transform:
            image = self.transform(image)

        return image

def load_preprocess_resnet50(train_dir=TRAIN_DIR, val_dir=VAL_DIR, test_dir=TEST_DIR, batch_size=32):
    '''
    Build DataLoaders for ResNet50 training, validation and testing.

    Training data is loaded by applying ``ImageFolder`` to each subdirectory
    (a, b, c, ...) of `train_dir` and concatenating the results, so the
    expected layout is ``train_dir/<folder>/<class>/<images>``. Validation
    and test directories are loaded with ``CustomDataset`` (flat folders of
    unlabeled images).

    ``ImageFolder`` numbers classes from 0 *within each folder it is given*,
    so without remapping the first class of every folder would share label 0.
    Each folder's labels are therefore remapped to the position of the class
    name in the returned, globally sorted ``class_names`` list. Class names
    appearing in more than one folder share a single label.

    Args:
        train_dir (str): Path to the training data directory.
        val_dir (str): Path to the validation data directory.
        test_dir (str): Path to the test data directory.
        batch_size (int): The batch size for every DataLoader.

    Returns:
        tuple: (train_loader, val_loader, test_loader, class_names).
        Training batches are (images, labels) with labels indexing
        `class_names`; validation/test batches contain images only.

    Raises:
        ValueError: If `train_dir` contains no subdirectories.
    '''

    # Define the transformations for ResNet50
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # Resize images to 224x224
        transforms.ToTensor(),  # Convert images to PyTorch tensors
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize with ImageNet stats
    ])

    def create_dataset_from_folders(parent_dir):
        """
        Load every subdirectory of `parent_dir` with ImageFolder, remap each
        one's local labels onto a shared sorted class list, and return
        (ConcatDataset, sorted class names).
        """
        # Sorted so the concatenation order is deterministic.
        folder_datasets = [
            datasets.ImageFolder(root=os.path.join(parent_dir, folder), transform=transform)
            for folder in sorted(os.listdir(parent_dir))
            if os.path.isdir(os.path.join(parent_dir, folder))
        ]
        if not folder_datasets:
            raise ValueError(f"No class folders found under {parent_dir!r}")

        class_names = sorted({name for ds in folder_datasets for name in ds.classes})
        global_index = {name: idx for idx, name in enumerate(class_names)}

        for ds in folder_datasets:
            # local label i -> global label of ds.classes[i]. A bound list
            # method is used instead of a lambda so the dataset stays picklable.
            local_to_global = [global_index[name] for name in ds.classes]
            ds.target_transform = local_to_global.__getitem__

        return ConcatDataset(folder_datasets), class_names

    # Create the training dataset and get the unique class names
    train_dataset, class_names = create_dataset_from_folders(train_dir)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Validation/test sets are unlabeled flat folders
    val_dataset = CustomDataset(root_dir=val_dir, transform=transform)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    test_dataset = CustomDataset(root_dir=test_dir, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader, class_names

In [None]:
class ResNet50_CNN(nn.Module):
    '''
    ResNet50 backbone with an optional custom head and a linear classifier.

    The backbone's original ``fc`` layer is replaced by ``nn.Identity`` so the
    backbone outputs its pooled features (the removed fc's ``in_features``,
    2048 for ResNet50). Those features pass through ``custom_layers``
    (identity by default) and then the final ``nn.Linear`` classifier.

    Args:
        num_classes (int): Number of output logits.
        custom_layers (nn.Module, optional): Module applied to the backbone
            features before the classifier.
        custom_out_features (int, optional): Width of the features produced
            by `custom_layers`. Defaults to the backbone feature width, which
            is correct when `custom_layers` is None or preserves width.
        pretrained (bool): Load ImageNet IMAGENET1K_V1 weights (the weights
            that torchvision's legacy ``pretrained=True`` selects).
    '''

    def __init__(self, num_classes, custom_layers=None, custom_out_features=None, pretrained=True):
        super().__init__()

        # torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`;
        # fall back to the legacy keyword when the weights enum is missing.
        weights_enum = getattr(models, 'ResNet50_Weights', None)
        if weights_enum is not None:
            weights = weights_enum.IMAGENET1K_V1 if pretrained else None
            self.resnet50 = models.resnet50(weights=weights)
        else:
            self.resnet50 = models.resnet50(pretrained=pretrained)

        # Read the feature width from the layer being removed instead of
        # hard-coding 2048, then drop the ImageNet classifier.
        backbone_features = self.resnet50.fc.in_features
        self.resnet50.fc = nn.Identity()

        # Identity if no custom layers
        self.custom_layers = custom_layers if custom_layers is not None else nn.Identity()

        head_in_features = backbone_features if custom_out_features is None else custom_out_features
        self.fc = nn.Linear(head_in_features, num_classes)

    def forward(self, x):
        '''Return class logits for an image batch `x` (ResNet expects N x 3 x H x W).'''
        features = self.resnet50(x)
        features = self.custom_layers(features)
        return self.fc(features)

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

def train_val_cnn(num_epochs=5, learning_rate=0.001):
    '''
    Placeholder for the training/validation loop; not implemented yet.

    Args:
        num_epochs (int): Intended number of training epochs (currently unused).
        learning_rate (float): Intended optimizer learning rate (currently unused).

    Returns:
        None: the body performs no work.
    '''
    return None

# Call the training function.
# NOTE: train_val_cnn is currently a stub (its body is `pass`), so this call
# does nothing until the training loop is implemented.
train_val_cnn()

# TODO: SOTA METHOD