In [10]:
# Run necessary imports
import matplotlib.pyplot as plt
import torch

from src.data_loading.torch_data_utils import load_data_with_logging
from src.models.residual_preprocess import PreprocessingResidual
from src.models.resnet50 import ResNet
from src.prep_and_processing.processing_utils import (
    apply_smote,
    tensor_to_numpy,
    reshape_for_vit,
    normalize_smote_output,
    numpy_to_tensor
)
from src.models.ViT import ViTModel

# Load Data
This cell loads the data and prints the shape of training and validation batches.

In [2]:
# Build the train/val dataloaders from the KAU dataset directory
data_dir = "data/KAU"
dataloaders = load_data_with_logging(data_dir)

# Peek at a single batch from each split to confirm the tensor layout.
# After the loop, `images`/`labels` hold the validation batch.
for split_key, split_title in (("train", "Training"), ("val", "Validation")):
    images, labels = next(iter(dataloaders[split_key]))
    print(f"{split_title} batch shape: {images.shape}")

Training batch shape: torch.Size([32, 3, 224, 224])
Validation batch shape: torch.Size([32, 3, 224, 224])


# Visualize Class Images
This cell plots one image from each class in the training dataset.

In [None]:
# Get class names
class_names = dataloaders['train'].dataset.classes
num_classes = len(class_names)

# Collect the first training image seen for each class in ONE pass over the loader,
# stopping as soon as every class has a sample (rather than rescanning the loader per class).
sample_per_class = {}
for batch_images, batch_labels in dataloaders['train']:
    for class_idx in range(num_classes):
        if class_idx in sample_per_class:
            continue
        hits = (batch_labels == class_idx).nonzero(as_tuple=True)[0]
        if hits.numel() > 0:
            sample_per_class[class_idx] = batch_images[hits[0]]
    if len(sample_per_class) == num_classes:
        break

# squeeze=False keeps `axes` 2-D even when there is a single class, so indexing is uniform
fig, axes = plt.subplots(1, num_classes, figsize=(15, 5), squeeze=False)

for class_idx, (ax, class_name) in enumerate(zip(axes[0], class_names)):
    ax.axis('off')
    if class_idx not in sample_per_class:
        ax.set_title(f"{class_name}\n(Not Found)")
        continue
    # Tensor is channels-first (C, H, W); imshow expects (H, W, C)
    image = sample_per_class[class_idx].permute(1, 2, 0).numpy()
    # Min-max rescale to [0, 1] for display; skip the division for a constant image
    # (zero range), which would otherwise produce NaNs
    value_range = image.max() - image.min()
    image = image - image.min()
    if value_range > 0:
        image = image / value_range
    ax.imshow(image)
    ax.set_title(class_name)

plt.tight_layout()
plt.show()

# Process Data Using PreprocessingResidual
This cell creates a `PreprocessingResidual` instance, runs one batch of training images through it, applies the class's linear projection to the result, and prints the shape of the projected output.

In [4]:
# Initialize the PreprocessingResidual stack
cnn_processor = PreprocessingResidual()

# Take one training batch; `labels` from this batch is what the SMOTE cell below consumes
images, labels = next(iter(dataloaders['train']))

# Forward-only pass for inspection: disable autograd so no computation graph
# (and its retained activations) is kept in memory for this batch
with torch.no_grad():
    # Process the data using the Residual processor
    processed_imgs = cnn_processor(images)

    # Apply the linear projection, sized by dimension 1 of the processed tensor.
    # NOTE(review): this is called on the class, not on `cnn_processor`. If it builds a freshly
    # initialised layer on every call, the projection (and everything downstream) changes from
    # run to run unless a torch seed is set; confirm against src.models.residual_preprocess.
    projected_output = PreprocessingResidual.linear_projection(processed_imgs.size(1), processed_imgs)

# Print the shape of the projected output (recorded run: channels-last [32, 224, 224, 3])
print(f"Projected output shape: {projected_output.shape}")

Projected output shape: torch.Size([32, 224, 224, 3])


# Feature Extraction Using ResNet Layer
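This cell passes the projected images through a ResNet-50 backbone to extract feature maps, then applies global average pooling to reduce each image to a single 2048-dimensional feature vector for the following steps.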

In [5]:
# Build the ResNet-50 feature extractor
featureExtractor = ResNet().get_model()
# Inference mode: BatchNorm uses its stored running statistics instead of updating them
# from this batch. Assumes get_model() returns a torch nn.Module (TODO confirm).
featureExtractor.eval()

# The projection output is channels-last [B, 224, 224, 3]; the backbone expects [B, 3, 224, 224].
# Bind the permuted view to a NEW name so re-running this cell does not permute twice.
resnet_input = projected_output.permute(0, 3, 1, 2)

# Feature extraction only: skip building the autograd graph
with torch.no_grad():
    # Extract features from the processed images
    extracted_features = featureExtractor(resnet_input)

    # Perform global average pooling (collapses the spatial dims: [B, 2048, 7, 7] -> [B, 2048])
    pooled_features = ResNet.global_average_pooling(extracted_features)

# Print the shape of extracted features
print(f"Extracted features shape: {extracted_features.shape}")

# Print the shape of pooled features
print(f"Pooled features shape: {pooled_features.shape}")

Extracted features shape: torch.Size([32, 2048, 7, 7])
Pooled features shape: torch.Size([32, 2048])


# Apply SMOTE to Oversample Features and Labels
This cell applies SMOTE to oversample the pooled features and their labels, normalizes the oversampled output, converts the result back to tensors, and prints the shape after each step.

In [6]:
# Move the pooled features and their batch labels to NumPy for SMOTE.
# NOTE(review): SMOTE is fitted on a single batch here, so the synthetic samples depend on which
# images that batch happened to contain; fitting on the full training set is the usual approach.
pooled_features_np = tensor_to_numpy(pooled_features)
labels_np = tensor_to_numpy(labels)
# Each pooled feature vector must pair with exactly one label; a mismatch means `labels`
# is stale (e.g. overwritten by a cell run out of order)
assert len(pooled_features_np) == len(labels_np), "pooled features and labels are misaligned"

# Apply SMOTE to oversample features and labels
smote_features, smote_labels = apply_smote(pooled_features_np, labels_np)
assert len(smote_features) == len(smote_labels), "SMOTE returned misaligned features/labels"
assert len(smote_features) >= len(pooled_features_np), "SMOTE should never drop samples"

# Print the shape of the oversampled features and labels
print(f"Oversampled features shape: {smote_features.shape}")
print(f"Oversampled labels shape: {smote_labels.shape}")

# Normalize the SMOTE output
normalized_features, normalized_labels = normalize_smote_output(smote_features, smote_labels)
assert normalized_features.shape == smote_features.shape, "normalization changed the feature shape"

# Store the augmented arrays
augmented_arrays = (normalized_features, normalized_labels)

# Print the shape of the normalized features
print(f"Normalized features shape: {normalized_features.shape}")

# Convert augmented arrays to tensors
# NOTE(review): confirm the label tensor keeps an integer dtype if it feeds a classification loss
augmented_features = numpy_to_tensor(normalized_features)
augmented_labels = numpy_to_tensor(normalized_labels)
assert augmented_features.shape[0] == augmented_labels.shape[0], "tensor features/labels misaligned"

# Print the shape of the augmented tensors
print(f"Augmented features tensor shape: {augmented_features.shape}")
print(f"Augmented labels tensor shape: {augmented_labels.shape}")


Oversampled features shape: (36, 2048)
Oversampled labels shape: (36,)
Normalized features shape: (36, 2048)
Augmented features tensor shape: torch.Size([36, 2048])
Augmented labels tensor shape: torch.Size([36])


# Explanation of Batch Size Change
The number of samples changed from 32 to 36 because of the SMOTE oversampling step, not because of the ViT reshape.

SMOTE generates synthetic samples to balance the classes by oversampling the minority classes. This increases the total number of samples, so the first (batch) dimension grows.

In this run, the original batch of 32 pooled feature vectors was augmented with 4 synthetic samples, giving 36 samples. The exact count typically varies between batches because it depends on the class composition of the sampled batch. The reshape step only regroups each 2048-dimensional vector into a 16 × 16 × 8 block (16 · 16 · 8 = 2048) and keeps the sample count of 36.

In [7]:
# Regroup each flat feature vector into the block layout the ViT stage consumes
# (recorded run: [36, 2048] -> [36, 16, 16, 8])
vit_ready_features = reshape_for_vit(augmented_features)
vit_ready_shape = vit_ready_features.shape
print(f"ViT-ready features shape: {vit_ready_shape}")

ViT-ready features shape: torch.Size([36, 16, 16, 8])


In [11]:
# Size the classification head from the dataset's own class list instead of assuming a
# binary task; a hard-coded 2 silently mismatches the labels if the dataset has more classes
vit_num_classes = len(dataloaders['train'].dataset.classes)

# Initialize the ViT model
vit_model = ViTModel(num_classes=vit_num_classes)

# Perform inference on ViT-ready features
vit_predictions = vit_model.inference(vit_ready_features)

# Print the shape of the predictions
print(f"ViT predictions shape: {vit_predictions.shape}")

KeyboardInterrupt: 