# Package Imports

In [9]:
import os
import random

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# Project Configuration

In [1]:
from dataclasses import dataclass

@dataclass
class Config:
    """Project-wide settings for the Normal vs. Tuberculosis image pipeline.

    Bundles dataset locations and training hyperparameters so a single object
    can be handed to helpers such as ``create_dataloader`` (which reads
    ``DATASET_DIR`` and ``BATCH_SIZE``). Every field has a default, so
    ``Config()`` works out of the box and individual values can be overridden
    by keyword.
    """

    # ---- Dataset root ------------------------------------------------------
    DATASET_DIR: str = 'dataset/'

    # ---- Per-class image folders -------------------------------------------
    # NOTE(review): the Normal folder ends with '/' but the Tuberculosis folder
    # does not; combine these with os.path.join rather than string concatenation.
    NORMAL_IMAGES_FOLDER: str = 'dataset/Normal/'
    TUBERCULOSIS_IMAGES_FOLDER: str = 'dataset/Tuberculosis'

    # ---- Metadata spreadsheets ---------------------------------------------
    NORMAL_XLSX_PATH: str = 'dataset/Normal.metadata.xlsx'
    TUBERCULOSIS_XLSX_PATH: str = 'dataset/Tuberculosis.metadata.xlsx'

    # ---- Training hyperparameters ------------------------------------------
    BATCH_SIZE: int = 64
    LEARNING_RATE: float = 1e-3
    NUM_EPOCHS: int = 10
    MOMENTUM: float = 0.9
    WEIGHT_DECAY: float = 0.00001


# Utils

In [10]:
class Utils:
    """Static helpers for inspecting and preprocessing the image dataset."""

    @staticmethod
    def trim(im, percentage=0.02):
        """
        Crop the dark border around an image.

        The image is converted to grayscale (``cv2.COLOR_BGR2GRAY`` for
        3-channel input; 2-D input is used as-is). A pixel counts as foreground
        when it is strictly brighter than 10% of the mean of all non-zero pixels.
        The crop spans from the first to the last row (column) in which more
        than ``percentage`` of that row's (column's) pixels are foreground.

        Args:
            im: Image as a NumPy array or array-like (e.g. from ``cv2.imread``).
            percentage (float, optional): Minimum fraction of foreground pixels
                a row/column needs in order to be kept. Defaults to 0.02.

        Returns:
            PIL.Image.Image: The cropped image. If the image has no non-zero
            pixels, or no row/column passes the threshold, the image is returned
            uncropped instead of raising.
        """
        img = np.array(im)
        img_gray = img if img.ndim == 2 else cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        nonzero = img_gray[img_gray != 0]
        if nonzero.size == 0:
            # All-black image: the mean would be NaN and there is nothing to crop.
            return Image.fromarray(img)

        foreground = img_gray > 0.1 * np.mean(nonzero)
        rows = np.flatnonzero(foreground.sum(axis=1) > img.shape[1] * percentage)
        cols = np.flatnonzero(foreground.sum(axis=0) > img.shape[0] * percentage)
        if rows.size == 0 or cols.size == 0:
            # No row/column passes the threshold; np.min/np.max would raise here.
            return Image.fromarray(img)

        # flatnonzero returns sorted indices, so [0] / [-1] are the min / max.
        im_crop = img[rows[0]: rows[-1] + 1, cols[0]: cols[-1] + 1]
        return Image.fromarray(im_crop)

    @staticmethod
    def plot_images_with_labels(data_df, num_images=5, random_seed=42, show_image_mode=True):
        """
        Plot some images with their corresponding labels from the given DataFrame.

        Args:
            data_df (pd.DataFrame): The DataFrame containing 'filepaths' and 'labels' columns.
            num_images (int, optional): Number of images to plot (must not exceed
                ``len(data_df)``). Defaults to 5.
            random_seed (int, optional): Seed passed to ``DataFrame.sample`` so the
                same rows are drawn on every call. Defaults to 42.
            show_image_mode (bool, optional): Whether to show image mode (RGB or not)
                alongside the labels. Defaults to True.
        """
        # DataFrame.sample draws from NumPy's RNG, not the `random` module, so
        # the seed must be handed to pandas to make the sample reproducible.
        sampled_data = data_df.sample(n=num_images, random_state=random_seed)

        max_cols = 5
        num_rows = (num_images - 1) // max_cols + 1
        num_cols = min(num_images, max_cols)

        # squeeze=False keeps `axes` 2-D even for a single row or column, so the
        # [row, col] indexing below works for every num_images.
        fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 5 * num_rows), squeeze=False)
        for idx, (filepath, label) in enumerate(zip(sampled_data['filepaths'], sampled_data['labels'])):
            row_idx, col_idx = divmod(idx, num_cols)
            ax = axes[row_idx, col_idx]

            # Close the file once the pixels have been handed to matplotlib.
            with Image.open(filepath) as image:
                ax.imshow(image)
                is_rgb = image.mode == 'RGB'
            ax.axis('off')

            if show_image_mode:
                ax.set_title(f'Label: {label} | RGB: {is_rgb}')
            else:
                ax.set_title(f'Label: {label}')

        # Remove the unused cells at the end of the last row.
        for idx in range(num_images, num_rows * num_cols):
            row_idx, col_idx = divmod(idx, num_cols)
            fig.delaxes(axes[row_idx, col_idx])

        plt.tight_layout()
        plt.show()

    @staticmethod
    def resize_and_save_images(data_df, save_dir, new_size=(224, 224)):
        """
        Trim, resize and save images from the DataFrame as PNG files.

        Rows labelled 'Normal' go to ``<save_dir>/Normal``; every other label
        goes to ``<save_dir>/Tuberculosis``. Each output keeps the source file's
        base name with a ``.png`` extension.

        Args:
            data_df (pd.DataFrame): The DataFrame containing 'filepaths' and 'labels' columns.
            save_dir (str): The directory path where the resized images will be saved.
            new_size (tuple, optional): Target size passed to ``cv2.resize``, i.e.
                (width, height). Defaults to (224, 224).

        Raises:
            ValueError: If an input file cannot be read as an image.
            OSError: If a resized image cannot be written.
        """
        # Create the save directories if they don't exist
        normal_save_dir = os.path.join(save_dir, 'Normal')
        tuberculosis_save_dir = os.path.join(save_dir, 'Tuberculosis')
        os.makedirs(normal_save_dir, exist_ok=True)
        os.makedirs(tuberculosis_save_dir, exist_ok=True)

        pairs = zip(data_df['filepaths'], data_df['labels'])
        for filepath, label in tqdm(pairs, total=len(data_df)):
            image = cv2.imread(filepath)
            if image is None:
                # cv2.imread signals failure by returning None rather than raising.
                raise ValueError(f"Could not read image: {filepath}")

            # The PIL round-trip in trim() does not reorder channels, so the
            # array handed to cv2.imwrite stays in cv2's BGR order.
            trimmed_image = np.array(Utils.trim(image))
            resized_image = cv2.resize(trimmed_image, new_size, interpolation=cv2.INTER_CUBIC)

            label_save_dir = normal_save_dir if label == 'Normal' else tuberculosis_save_dir
            filename_without_ext = os.path.splitext(os.path.basename(filepath))[0]
            save_path = os.path.join(label_save_dir, f"{filename_without_ext}.png")
            if not cv2.imwrite(save_path, resized_image):
                raise OSError(f"Could not write image: {save_path}")


# Preprocessing

In [11]:
def get_filepaths_and_labels(sdir):
    """Collect every file under the class sub-directories of ``sdir``.

    Each immediate sub-directory of ``sdir`` is treated as one class whose
    label is the sub-directory name. Plain files directly inside ``sdir`` are
    ignored. Both directories and files are visited in sorted order, so the
    output is deterministic. A tqdm progress bar is shown per class.

    Args:
        sdir (str): Root directory containing one folder per class.

    Returns:
        tuple[list[str], list[str]]: Parallel lists of file paths and labels.
    """
    paths, classes = [], []
    for class_name in sorted(os.listdir(sdir)):
        class_dir = os.path.join(sdir, class_name)
        if not os.path.isdir(class_dir):
            continue
        entries = sorted(os.listdir(class_dir))
        for entry in tqdm(entries, ncols=130, desc=f'{class_name:25s}', unit='files', colour='blue'):
            paths.append(os.path.join(class_dir, entry))
            classes.append(class_name)
    return paths, classes

def create_dataframes(filepaths, labels):
    """Pair file paths with their labels in a two-column DataFrame.

    Args:
        filepaths (Sequence[str]): Image file paths.
        labels (Sequence[str]): Class label for each path, in the same order.

    Returns:
        pd.DataFrame: Columns ``'filepaths'`` and ``'labels'`` (in that order),
        aligned on a default integer index.
    """
    columns = {
        'filepaths': pd.Series(filepaths),
        'labels': pd.Series(labels),
    }
    return pd.DataFrame(columns)

def split_data(df, *, train_size=0.8, random_state=123):
    """Stratified split of ``df`` into train / test / validation subsets.

    A fraction ``train_size`` of the rows goes to training. The remaining rows
    are split in half, stratified again, between validation and test.

    Args:
        df (pd.DataFrame): Must contain a ``'labels'`` column used for stratification.
        train_size (float, optional): Fraction of rows used for training. Defaults to 0.8.
        random_state (int, optional): Seed shared by both splits. Defaults to 123.

    Returns:
        tuple: ``(train_df, test_df, valid_df)``. Note that test comes before
        validation in the returned tuple.
    """
    train_df, holdout_df = train_test_split(
        df, train_size=train_size, shuffle=True,
        random_state=random_state, stratify=df['labels'])
    valid_df, test_df = train_test_split(
        holdout_df, train_size=.5, shuffle=True,
        random_state=random_state, stratify=holdout_df['labels'])
    return train_df, test_df, valid_df

def calculate_average_image_size(df, num_samples=50, random_state=None):
    """Estimate the mean image height and width from a random sample of ``df``.

    Args:
        df (pd.DataFrame): Must contain a ``'filepaths'`` column.
        num_samples (int, optional): Maximum number of images to read. It is
            capped at ``len(df)`` so small DataFrames do not raise. Defaults to 50.
        random_state (int, optional): Seed for ``DataFrame.sample``. The default
            ``None`` keeps sampling unseeded.

    Returns:
        tuple: ``(average_height, average_width, aspect_ratio)``. The averages
        are truncated to ``int``, and ``aspect_ratio`` is
        ``average_height / average_width``.

    Raises:
        ValueError: If none of the sampled files could be read as an image.
    """
    sample_df = df.sample(n=min(num_samples, len(df)), replace=False, random_state=random_state)
    total_height = 0
    total_width = 0
    count = 0
    for fpath in sample_df['filepaths']:
        img = cv2.imread(fpath)
        if img is None:
            # cv2.imread returns None for missing/unreadable files; skip them.
            continue
        h, w = img.shape[:2]
        total_height += h
        total_width += w
        count += 1
    if count == 0:
        raise ValueError("No readable images found in the sampled file paths.")
    average_height = int(total_height / count)
    average_width = int(total_width / count)
    aspect_ratio = average_height / average_width
    return average_height, average_width, aspect_ratio


def make_dataframes(sdir):
    """Build stratified train/test/validation DataFrames from a class-per-folder directory.

    Args:
        sdir (str): Root directory whose sub-directories are the class names.

    Returns:
        tuple: ``(train_df, test_df, valid_df, class_count, average_height,
        average_width, aspect_ratio)``. ``class_count`` is the number of distinct
        labels in the training split. The image-size statistics are estimated
        from a sample of the training split.
    """
    filepaths, labels = get_filepaths_and_labels(sdir)
    df = create_dataframes(filepaths, labels)
    train_df, test_df, valid_df = split_data(df)
    average_height, average_width, aspect_ratio = calculate_average_image_size(train_df)

    # dropna=False matches len(Series.unique()), which counts NaN as a value.
    class_count = train_df['labels'].nunique(dropna=False)

    return train_df, test_df, valid_df, class_count, average_height, average_width, aspect_ratio


# Dataset

In [12]:
class TBDataset(torch.utils.data.Dataset):
    """Map-style dataset over a DataFrame of image paths and labels.

    Rows are accessed positionally: column 0 of ``data_df`` is the image file
    path and column 1 is the label.
    """

    def __init__(self, data_df, transform=None):
        """
        Args:
            data_df (pd.DataFrame): Column 0 = image file path, column 1 = label.
            transform (callable, optional): Applied to each RGB PIL image. When
                ``None``, ``_default_transform()`` is used.
        """
        self.data_df = data_df
        self.transform = transform if transform is not None else self._default_transform()
        # NOTE(review): this mapping is not used by __getitem__, which returns
        # the raw label value from the DataFrame (e.g. the string 'Normal').
        # Confirm whether the training loop encodes labels itself.
        self.class_labels = {'Normal': 0, 'Tuberculosis': 1}

    def __repr__(self):
        return f"TBDataset: Number of samples: {len(self)}"

    def __len__(self):
        return len(self.data_df)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the row at position ``index``.

        The image is loaded with PIL, converted to RGB, and then passed through
        ``self.transform``. The label is returned exactly as stored in the
        DataFrame.
        """
        filepath = self.data_df.iloc[index, 0]
        label = self.data_df.iloc[index, 1]

        # Image.open is lazy and would keep the file handle open until garbage
        # collection; convert() inside the context manager loads the pixels so
        # the handle is released immediately. Converting to RGB also guarantees
        # 3 channels: grayscale ('L') or RGBA files would otherwise break the
        # 3-channel Normalize in the default transform.
        with Image.open(filepath) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label

    def _default_transform(self):
        """Training-style pipeline: random resized crop to 224, random horizontal
        flip, conversion to tensor, and normalization with the ImageNet mean/std."""
        return transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])


# DataLoader

In [13]:
def create_dataloader(train_transform, valid_transform, test_transform, config):
    """Build train/validation/test DataLoaders for the dataset at ``config.DATASET_DIR``.

    Args:
        train_transform: Transform for the training split (``None`` selects
            ``TBDataset``'s default transform).
        valid_transform: Transform for the validation split.
        test_transform: Transform for the test split.
        config: Object exposing ``DATASET_DIR`` and ``BATCH_SIZE`` (e.g. ``Config``).

    Returns:
        tuple: ``(train_dataloader, valid_dataloader, test_dataloader)``. Only
        the training loader shuffles.
    """
    # make_dataframes also returns class count and image-size statistics;
    # only the three splits are needed here.
    train_df, test_df, valid_df, _n_classes, _avg_h, _avg_w, _aspect = make_dataframes(config.DATASET_DIR)

    datasets = {
        'train': TBDataset(train_df, transform=train_transform),
        'valid': TBDataset(valid_df, transform=valid_transform),
        'test': TBDataset(test_df, transform=test_transform),
    }

    # Evaluation splits keep a fixed order; only training data is shuffled.
    loaders = {
        split: torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=config.BATCH_SIZE,
            shuffle=(split == 'train'),
        )
        for split, dataset in datasets.items()
    }

    return loaders['train'], loaders['valid'], loaders['test']