In [None]:
# Generated by Copilot
# Core imports
import os
import sys
import cv2
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from datetime import datetime
import random
import logging

# Add the project root (the parent of the current working directory) to the
# import path. The membership check keeps sys.path free of duplicate entries
# when this cell is executed more than once in the same kernel.
project_root = Path(os.getcwd()).parent
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))


from app.db.models import Image,Mask
from app.db.session import SessionLocal
# PyTorch imports
import torch
from torch.utils.data import DataLoader 
from tqdm.notebook import tqdm



# Import project modules
from app.services.ai.Unet import training

# Configure logging for this notebook session
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Seed every random number generator in use (Python, NumPy, torch CPU/CUDA)
# with the same value
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)

# Pick the compute device: CUDA when available, CPU otherwise
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"PyTorch version: {torch.__version__}")
print(f"Using device: {device}")
if device.type == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Dataset and model directories, all located under <project_root>/data
data_path = project_root / "data"
image_path = data_path.joinpath("dataset", "images")
mask_path = data_path.joinpath("dataset", "masks")
model_path = data_path.joinpath("models")


In [None]:
# Build the training set from images stored in the database (the query below is capped at the first 20 images)
from app.db.models import Patch
from app.services.image.patch_service import save_image_as_patches


def get_data_from_db(max_images=20):
    """
    Load the patch rows used for training from the database.

    Steps:
      1. Fetch up to ``max_images`` ``Image`` rows.
      2. Keep only the images that have at least one mask containing a
         non-zero value.
      3. Call ``save_image_as_patches`` on each kept image.
      4. Query and return every ``Patch`` row.

    Args:
        max_images: Maximum number of ``Image`` rows to fetch. Defaults to 20,
            the cap this notebook has always applied; pass ``None`` to fetch
            without a cap.

    Returns:
        A list of ``Patch`` rows, or ``None`` when no images or no patches
        are found.
    """
    session = SessionLocal()
    try:
        # Let the database apply the cap instead of loading every row and
        # slicing the resulting list.
        db_images = session.query(Image).limit(max_images).all()
        if not db_images:
            logger.warning("No images found in the database.")
            return None

        logger.info(f"Total images retrieved: {len(db_images)}")

        # Filter out images where all masks are zero
        filtered_db_images = []
        for img in db_images:
            masks = session.query(Mask).filter(Mask.image_id == img.id).all()
            if masks and any((np.load(mask.mask_path) != 0).any() for mask in masks):
                filtered_db_images.append(img)

        logger.info(f"Filtered images count: {len(filtered_db_images)}")

        # convert to patches
        for img in tqdm(filtered_db_images, desc="Converting images to patches"):
            save_image_as_patches(img)

        # NOTE: this returns every Patch row in the table, not only the
        # patches belonging to the images processed above.
        patches = session.query(Patch).all()
    finally:
        # Always release the session, including on the early return above
        # and when a query or patch conversion raises.
        session.close()

    if not patches:
        logger.warning("No patches found in the database.")
        return None

    logger.info(f"Total patches retrieved: {len(patches)}")

    return patches




In [None]:
# Load the training patches; abort when the database yields nothing usable.
dataset = get_data_from_db()

if dataset is None or len(dataset) == 0:
    logger.error("No dataset found. Exiting.")
    sys.exit(1)

In [None]:
# Wrap the loaded patches in a PatchDataset and index its first item; as the
# last expression of the cell, the resulting value is displayed as the output.
training.PatchDataset(dataset)[0]

In [None]:
# Containers for per-epoch metrics, read by the plotting cell further down.
# NOTE(review): nothing visible in this notebook appends to these lists —
# confirm where (or whether) they are filled.
train_losses = []
val_losses = []
val_dice_scores = []

# Split data into train and validation
from torch.utils.data import random_split

# Use 80% for training, 20% for validation
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
# A dedicated generator seeded with the notebook seed makes the split the same
# every time this cell runs, regardless of how much of the global RNG stream
# earlier cells have consumed.
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(seed),
)

# Create data loaders
train_loader = DataLoader(
    training.PatchDataset(train_dataset),
    batch_size=32,
    shuffle=True,
    num_workers=4,  # Load batches in 4 worker processes
    pin_memory=torch.cuda.is_available(),  # Pinned host memory only when a GPU is present
    prefetch_factor=2  # Batches prefetched per worker
)

val_loader = DataLoader(
    training.PatchDataset(val_dataset),
    batch_size=64,
    shuffle=False,  # Keep validation order fixed
    num_workers=4,
    pin_memory=torch.cuda.is_available()
)

# Training run with validation.
# NOTE(review): a later cell unpacks training.Training(...) as
# `(model, history)`, while here the whole return value is bound to `model` —
# confirm which form matches the current Training API.
# NOTE(review): "data/models" is relative to the working directory, whereas
# the setup cell defines `model_path` under the project root — confirm which
# location is intended.
model = training.Training(
    device=device,
    dataloader=train_loader,
    val_loader=val_loader,
    lr=0.0001,
    checkpoint_type="best",
    model_path="data/models",
    num_epochs=50,
    early_stopping_patience=10
)

In [None]:
# start training    
# model = training.Training(
#     device= torch.device("cuda" if torch.cuda.is_available() else "cpu"),
#     dataloader=TheData,
#     lr=0.0001,    
#     # checkpoint_type="best",
#     model_path="data/models",
#     # num_epochs=100,
#     # early_stopping_patience=20,
# )

In [None]:
def plot_training_progress(train_losses, val_losses, val_dice_scores):
    """
    Show the loss curves and the validation Dice curve side by side.

    Args:
        train_losses: Sequence of training loss values, one per epoch.
        val_losses: Sequence of validation loss values, plotted against the
            same epoch axis as ``train_losses``.
        val_dice_scores: Sequence of validation Dice scores, plotted against
            the same epoch axis as ``train_losses``.

    The epoch axis runs from 1 to ``len(train_losses)``. The figure is shown
    with ``plt.show()``; nothing is returned.
    """
    epoch_axis = range(1, len(train_losses) + 1)
    figure, (loss_ax, dice_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # Left panel: training vs. validation loss
    loss_ax.plot(epoch_axis, train_losses, 'b-', label='Training Loss')
    loss_ax.plot(epoch_axis, val_losses, 'r-', label='Validation Loss')
    loss_ax.set(title='Training and Validation Loss', xlabel='Epochs', ylabel='Loss')

    # Right panel: validation Dice score
    dice_ax.plot(epoch_axis, val_dice_scores, 'g-', label='Validation Dice Score')
    dice_ax.set(title='Validation Dice Score', xlabel='Epochs', ylabel='Dice Score')

    # Shared decoration for both panels
    for panel in (loss_ax, dice_ax):
        panel.legend()
        panel.grid(True)

    plt.tight_layout()
    plt.show()

# Plot the metric lists defined in the training-setup cell.
# NOTE(review): those lists are created empty and nothing visible in this
# notebook appends to them, so this call may draw empty axes — confirm where
# the per-epoch values are supposed to come from.
plot_training_progress(train_losses, val_losses, val_dice_scores)

In [None]:

# Get a batch of validation data and run inference
model.eval()  # Set to evaluation mode
test_batch = next(iter(val_loader))
images, masks = test_batch

with torch.no_grad():
    predictions = model(images.to(device))

# Convert tensors to numpy for plotting
images = images.cpu().numpy()
masks = masks.cpu().numpy()
predictions = predictions.cpu().numpy()

# Predictions are plotted exactly as returned by the model; the 0.5
# thresholding step below is left disabled.
# predictions = (predictions > 0.5).astype(np.uint8)

print(f"Unique values in ground truth masks: {np.unique(masks)}")
print(f"Unique values in predicted masks: {np.unique(predictions)}")

# Plot results for the first few samples
n_samples = min(4, len(images))
# squeeze=False keeps `axes` two-dimensional even when n_samples == 1, so the
# per-row indexing below works for a single-sample batch as well.
fig, axes = plt.subplots(n_samples, 5, figsize=(12, 4*n_samples), squeeze=False)
fig.suptitle('Test Results: Input - Ground Truth - Prediction')

for i in range(n_samples):
    sample_masks = masks[i].squeeze()
    # One (data, title, colormap) entry per column, left to right
    panels = [
        (images[i].transpose(1, 2, 0), 'Input Image', None),
        (sample_masks[0], 'Ground Truth channel 1', 'gray'),
        (predictions[i][0], 'Prediction channel 1', 'gray'),
        (sample_masks[1], 'Ground Truth channel 2', 'gray'),
        (predictions[i][1], 'Prediction channel 2', 'gray'),
    ]
    for ax, (panel_data, panel_title, panel_cmap) in zip(axes[i], panels):
        ax.imshow(panel_data, cmap=panel_cmap)
        ax.set_title(panel_title)
        ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
import pandas as pd
import numpy as np
from app.db.models import Patch
from app.db.session import SessionLocal

# Create a new session
session = SessionLocal()

try:
    # Query all patches and build one DataFrame row per patch, copying the
    # id, image_id, img_patch and mask_patch attributes unchanged.
    patches_df = pd.DataFrame([{
        'id': patch.id,
        'image_id': patch.image_id,
        'img_patch': patch.img_patch,
        'mask_patch': patch.mask_patch
    } for patch in session.query(Patch).all()])
finally:
    # Close the session even when the query or the DataFrame build fails
    session.close()

# Display the first few rows and basic information
print("DataFrame Shape:", patches_df.shape)
print("\nDataFrame Info:")
patches_df.info()
print("\nFirst few rows:")
display(patches_df.head())

# Save DataFrame to pickle file (path is relative to the working directory)
patches_df.to_pickle("data/patches_df.pkl")
print("DataFrame saved to data/patches_df.pkl")

In [None]:
# Reload the cached patches table and count its records.
# NOTE: unpickling can execute arbitrary code — only load a pickle file that
# this notebook wrote itself.
patches_df = pd.read_pickle("data/patches_df.pkl")
# Convert to a list of per-row dicts (from here on the name refers to a list,
# not a DataFrame)
patches_df = patches_df.to_dict(orient='records')
# A list has no `.shape`; len() gives the number of records
num_patches = len(patches_df)
print(f"Number of patches: {num_patches}")

In [None]:
# Launch a second training run and keep the returned metric history
model, history = training.Training(
    dataloader=train_loader,
    val_loader=val_loader,
    device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
    model_path="data/models",
    lr=2e-4,  # Higher than the 0.0001 used in the earlier training cell
    num_epochs=100,  # The earlier training cell used 50
    early_stopping_patience=15  # The earlier training cell used 10
)

# Plot the curves recorded in the returned history
plot_training_progress(
    history['train_losses'],
    history['val_losses'],
    history['val_dice_scores'],
)