# In [None]:
# Import libraries
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import pandas as pd
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import time
import logging
import psutil
from torch.amp import autocast, GradScaler
from IPython.display import display, Image as IPImage
import torch.nn.functional as F
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR
from torch.optim.lr_scheduler import SequentialLR

# Configure the root logger once for the whole script (timestamp, level, message).
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)

# Report the runtime: PyTorch build and whether a CUDA device is visible.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if not torch.cuda.is_available():
    logging.warning("No GPU detected. Ensure GPU is enabled in Runtime settings.")
else:
    print(f"GPU device: {torch.cuda.get_device_name(0)}")

# Mount Google Drive so the dataset folders below are reachable.
from google.colab import drive
drive.mount('/content/drive')

# Dataset and output locations, all rooted at BASE_PATH.
BASE_PATH = '/content/drive/MyDrive/intel_project'
SHARP_PATH, DEFOCUSED_PATH, MOTION_PATH, RESULTS_DIR = (
    os.path.join(BASE_PATH, folder)
    for folder in ('sharp', 'defocused_blurred', 'motion_blurred', 'results')
)
# The results folder is created up front; the input folders must already exist.
os.makedirs(RESULTS_DIR, exist_ok=True)

# Custom Dataset with Fallback for Missing Defocused Images
class ImageTripletDataset(Dataset):
    """Dataset of (sharp, defocused, motion-blurred) image triplets.

    Files are matched through a shared base id: ``<id>_S.<ext>`` in
    ``sharp_dir``, ``<id>_F.<ext>`` in ``defocused_dir`` and ``<id>_M.<ext>``
    in ``motion_dir``, where ``<ext>`` is ``jpg`` or ``jpeg``.  Only the sharp
    image is required; a missing defocused or motion image is replaced by a
    copy of the sharp image when the sample is loaded.

    Args:
        sharp_dir: Directory scanned for ``*_S.jpg`` / ``*_S.jpeg`` files.
        defocused_dir: Directory holding the ``*_F.*`` counterparts.
        motion_dir: Directory holding the ``*_M.*`` counterparts.
        transform: Optional callable applied to each image after resizing and
            augmentation.
        augment_transform: Optional random callable; it is applied to the
            three images of a triplet with identical random draws.
        target_size: ``(width, height)`` passed to ``PIL.Image.resize``.

    Raises:
        ValueError: If ``sharp_dir`` contains no usable sharp image.
    """

    # Trailing "<marker>.<ext>" parts that identify a sharp image file.
    _SHARP_SUFFIXES = ('S.jpg', 'S.jpeg')

    def __init__(self, sharp_dir, defocused_dir, motion_dir, transform=None, augment_transform=None,
                 target_size=(640, 360)):
        self.sharp_dir = sharp_dir
        self.defocused_dir = defocused_dir
        self.motion_dir = motion_dir
        self.transform = transform
        self.augment_transform = augment_transform
        self.target_size = tuple(target_size)
        self.image_ids = []

        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # index -> image mapping (and any seeded split built on it) reproducible.
        for filename in sorted(os.listdir(sharp_dir)):
            base_id, sep, suffix = filename.rpartition('_')
            if not sep or suffix not in self._SHARP_SUFFIXES:
                continue
            ext = suffix.split('.')[-1]
            sharp_path, defocused_path, motion_path = self._triplet_paths(base_id, ext)

            # Require only the sharp image; the blurred ones fall back to it at load time.
            if not os.path.exists(sharp_path):
                continue
            self.image_ids.append((base_id, ext))
            if not os.path.exists(defocused_path):
                logging.warning(f"Missing defocused image for {base_id}, using sharp as fallback")
            if not os.path.exists(motion_path):
                logging.warning(f"Missing motion image for {base_id}")

        if not self.image_ids:
            raise ValueError("No valid sharp images found.")

        logging.info(f"Found {len(self.image_ids)} valid sharp images for triplets.")

    def __len__(self):
        """Return the number of triplets found at construction time."""
        return len(self.image_ids)

    def _triplet_paths(self, base_id, ext):
        """Return the (sharp, defocused, motion) file paths for one base id."""
        return (
            os.path.join(self.sharp_dir, f"{base_id}_S.{ext}"),
            os.path.join(self.defocused_dir, f"{base_id}_F.{ext}"),
            os.path.join(self.motion_dir, f"{base_id}_M.{ext}"),
        )

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an RGB copy, closing the file handle."""
        with Image.open(path) as handle:
            return handle.convert('RGB')

    def _apply_shared_augmentation(self, *images):
        """Apply ``augment_transform`` to every image with the same random draws.

        A seed is drawn from the global CPU generator, and the generator is
        reset to that seed before each call so every image sees the same
        random parameters.  The generator state is restored afterwards, so the
        only lasting effect on global randomness is the single seed draw.
        """
        seed = torch.randint(0, 100000, (1,)).item()
        saved_state = torch.get_rng_state()
        try:
            augmented = []
            for image in images:
                # Seed only the CPU default generator; CUDA generators are untouched.
                torch.default_generator.manual_seed(seed)
                augmented.append(self.augment_transform(image))
        finally:
            torch.set_rng_state(saved_state)
        return tuple(augmented)

    def __getitem__(self, idx):
        """Load, resize, augment and transform the triplet at ``idx``.

        Returns:
            ``(sharp, defocused, motion)``, or ``None`` when an image cannot
            be loaded or a transformed tensor contains NaN/Inf values.
        """
        base_id, ext = self.image_ids[idx]
        sharp_path, defocused_path, motion_path = self._triplet_paths(base_id, ext)

        try:
            sharp_img = self._load_rgb(sharp_path)
            # Use the sharp image as a stand-in for any missing blurred image.
            if os.path.exists(defocused_path):
                defocused_img = self._load_rgb(defocused_path)
            else:
                defocused_img = sharp_img.copy()
            if os.path.exists(motion_path):
                motion_img = self._load_rgb(motion_path)
            else:
                motion_img = sharp_img.copy()
        except Exception as e:
            # Best effort by design: callers skip samples that come back as None.
            logging.error(f"Failed to load images for {base_id}: {str(e)}")
            return None

        sharp_img = sharp_img.resize(self.target_size, Image.BICUBIC)
        defocused_img = defocused_img.resize(self.target_size, Image.BILINEAR)
        motion_img = motion_img.resize(self.target_size, Image.BILINEAR)

        if self.augment_transform:
            sharp_img, defocused_img, motion_img = self._apply_shared_augmentation(
                sharp_img, defocused_img, motion_img
            )

        if self.transform:
            sharp_img = self.transform(sharp_img)
            defocused_img = self.transform(defocused_img)
            motion_img = self.transform(motion_img)

            # isfinite() is False for both NaN and +/-Inf.
            if not all(torch.isfinite(t).all() for t in (sharp_img, defocused_img, motion_img)):
                logging.error(f"Invalid tensor values for {base_id}")
                return None

        return sharp_img, defocused_img, motion_img

# Perceptual Loss
class PerceptualLoss(nn.Module):
    """Feature-space MSE between two images using frozen VGG16 layers.

    Inputs are expected in the [-1, 1] range used elsewhere in this file.
    They are mapped back to [0, 1] and then standardised with the ImageNet
    channel statistics that the pretrained VGG16 weights were trained with,
    before being passed through the first two feature slices.

    Args:
        features: Optional sliceable sequence of layers to use instead of the
            pretrained ``vgg16().features``; layers ``[:4]`` and ``[4:9]`` are
            taken.  Its parameters are frozen in place.  When ``None`` (the
            default) the pretrained torchvision VGG16 features are loaded.
    """

    def __init__(self, features=None):
        super().__init__()
        vgg = models.vgg16(pretrained=True).features if features is None else features
        self.slice1 = nn.Sequential(*vgg[:4]).eval()
        self.slice2 = nn.Sequential(*vgg[4:9]).eval()
        self.criterion = nn.MSELoss()
        # Non-persistent buffers follow .to(device) but stay out of state_dict().
        self.register_buffer(
            'imagenet_mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False
        )
        self.register_buffer(
            'imagenet_std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False
        )
        # The extractor is a fixed measuring device, never a trained component.
        for param in self.parameters():
            param.requires_grad = False

    def _prepare(self, img):
        """Map a [-1, 1] image to ImageNet-standardised VGG input."""
        img = (img * 0.5 + 0.5).clamp(0, 1)
        return (img - self.imagenet_mean) / self.imagenet_std

    def forward(self, x, y):
        """Return the summed MSE of the two feature slices for ``x`` and ``y``."""
        x = self._prepare(x)
        y = self._prepare(y)
        x_s1 = self.slice1(x)
        y_s1 = self.slice1(y)
        x_s2 = self.slice2(x_s1)
        y_s2 = self.slice2(y_s1)
        return self.criterion(x_s1, y_s1) + self.criterion(x_s2, y_s2)

# EDSR Model
class EDSR(nn.Module):
    """EDSR-style upscaling network with a bilinear global skip path.

    The network takes images in [-1, 1], works internally in [0, 1], and
    returns an image ``scale_factor`` times larger in both spatial dimensions,
    again normalised to [-1, 1].

    Args:
        scale_factor: Integer upscaling factor used by both the pixel-shuffle
            branch and the bilinear skip path.
        pretrained: When true, only logs that no pretrained weights exist and
            keeps PyTorch's default initialisation; when false, applies
            Kaiming initialisation to every convolution.
    """

    def __init__(self, scale_factor=2, pretrained=False):
        super().__init__()
        # Remembered so forward() upsamples the skip path by the same factor
        # as the pixel-shuffle branch.
        self.scale_factor = scale_factor
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.res_blocks = nn.Sequential(*[self.make_res_block(64) for _ in range(16)])
        self.conv2 = nn.Conv2d(64, 64 * (scale_factor ** 2), kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)
        self.conv3 = nn.Conv2d(64, 3, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        if pretrained:
            logging.info("Pre-trained EDSR not implemented; using custom initialization.")
        else:
            self._initialize_weights()

    def make_res_block(self, channels):
        """Build one conv -> ReLU -> conv block.

        NOTE(review): the block has no identity shortcut of its own, so the
        stacked blocks form a plain convolutional chain - confirm intended.
        """
        return nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        )

    def _initialize_weights(self):
        """Kaiming-initialise every convolution weight and zero its bias."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Upscale ``x`` (values in [-1, 1]) by ``scale_factor``."""
        x = (x * 0.5 + 0.5).clamp(0, 1)  # Denormalize
        residual = x
        x = self.conv1(x)
        x = self.res_blocks(x)
        x = self.conv2(x)
        x = self.pixel_shuffle(x)
        # The skip path must grow by the configured factor, otherwise the
        # addition below fails for any scale_factor other than 2.
        residual = F.interpolate(
            residual, scale_factor=self.scale_factor, mode='bilinear', align_corners=False
        )
        x = self.conv3(x)
        x = x + residual
        x = x.clamp(0, 1)
        x = (x - 0.5) / 0.5  # Normalize to [-1, 1]
        return x

# Simplified Student Model
class StudentModel(nn.Module):
    """Compact CNN that maps an image to an image of the same size.

    Layout: two conv/BN/ReLU stages merged with a 1x1-conv shortcut of the
    input, dropout, two conv-BN-ReLU-conv-BN blocks, a 64 -> 128 -> 64
    conv/BN/ReLU bottleneck and a final 64 -> 3 convolution.  The output is
    that result plus 0.1 times the input.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in a fixed order so parameter names and the
        # random initialisation sequence stay stable.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        blocks = [self.make_res_block(64) for _ in range(2)]  # Reduced blocks
        self.res_blocks = nn.Sequential(*blocks)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.conv5 = nn.Conv2d(64, 3, kernel_size=3, padding=1)
        self.skip_conv = nn.Conv2d(3, 64, kernel_size=1)
        self.dropout = nn.Dropout(0.1)
        self.relu = nn.ReLU(inplace=True)

    def make_res_block(self, channels):
        """Return a conv-BN-ReLU-conv-BN stack that keeps the channel count."""
        layers = [
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the network on ``x`` and add a scaled copy of the input."""
        identity = x
        shortcut = self.skip_conv(x)

        feat = self.conv1(x)
        feat = self.relu(self.bn1(feat))
        feat = self.conv2(feat)
        feat = self.relu(self.bn2(feat))

        # Merge the 1x1-conv shortcut before regularising with dropout.
        feat = self.dropout(feat + shortcut)
        feat = self.res_blocks(feat)

        feat = self.conv3(feat)
        feat = self.relu(self.bn3(feat))
        feat = self.conv4(feat)
        feat = self.relu(self.bn4(feat))

        return self.conv5(feat) + 0.1 * identity

# Teacher Model
class TeacherModel(nn.Module):
    """Inference-only wrapper around an ``EDSR`` network.

    The wrapped network is switched to eval mode at construction and is
    always run without gradient tracking.
    """

    def __init__(self):
        super().__init__()
        # Set pretrained=True if pre-trained weights become available.
        edsr = EDSR(scale_factor=2, pretrained=False)
        edsr.eval()
        self.model = edsr

    @torch.no_grad()
    def forward(self, x):
        """Return the EDSR output for ``x``, or ``x`` itself if it is not finite."""
        outputs = self.model(x)
        # isfinite() is False for both NaN and +/-Inf.
        if not torch.isfinite(outputs).all():
            logging.warning("NaN/Inf in teacher output, using input as fallback")
            return x
        return outputs

# Transforms
# Deterministic preprocessing: resize to 360x640 (height, width), convert to
# a tensor, then shift/scale each channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize(size=(360, 640)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

# Random geometric augmentation: flips on both axes and a rotation of up to
# 10 degrees in either direction.
augment_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(degrees=10),
])

# Dataset and DataLoader
try:
    dataset = ImageTripletDataset(SHARP_PATH, DEFOCUSED_PATH, MOTION_PATH, transform=transform, augment_transform=augment_transform)
except ValueError as e:
    logging.error(str(e))
    # exit() is a helper injected by the `site` module for interactive shells;
    # raising SystemExit works in every execution mode.
    raise SystemExit(1) from None

# Load every sample once and keep only the ones that loaded successfully.
# NOTE(review): because samples are materialised here, the random augmentation
# is drawn a single time per image (and also applied to the images that end up
# in the test split) instead of being re-sampled every epoch - confirm intended.
filtered_dataset = []
for i in range(len(dataset)):
    sample = dataset[i]
    if sample is None:
        logging.warning(f"Skipping invalid sample at index {i}")
        continue
    filtered_dataset.append(sample)

if not filtered_dataset:
    logging.error("No valid samples in dataset")
    raise SystemExit(1)

logging.info(f"Total valid samples after filtering: {len(filtered_dataset)}")
train_size = int(0.8 * len(filtered_dataset))
test_size = len(filtered_dataset) - train_size
logging.info(f"Train size: {train_size}, Test size: {test_size}")
train_dataset, test_dataset = torch.utils.data.random_split(filtered_dataset, [train_size, test_size])
# Pinned host memory only helps host-to-GPU copies, so enable it only with CUDA.
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=2,
                          pin_memory=torch.cuda.is_available())
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=2,
                         pin_memory=torch.cuda.is_available())

# Models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device: {device}")

teacher_model = TeacherModel().to(device)
student_model = StudentModel().to(device)
perceptual_loss = PerceptualLoss().to(device)

# Mixed precision scaler; explicitly disabled (pass-through) on the CPU fallback.
scaler = GradScaler('cuda', enabled=torch.cuda.is_available())

# Loss and Optimizer
criterion_mse = nn.MSELoss()
optimizer = optim.Adam(student_model.parameters(), lr=0.0005)
# 5 warm-up epochs ramping from 10% of the base LR, then cosine annealing over 70.
warmup_scheduler = LinearLR(optimizer, start_factor=0.1, total_iters=5)
main_scheduler = CosineAnnealingLR(optimizer, T_max=70)
scheduler = SequentialLR(optimizer, schedulers=[warmup_scheduler, main_scheduler], milestones=[5])

# Training Loop
def train_model(epochs=75):
    """Train the student model with ground-truth, teacher and perceptual losses.

    For every batch the student's output on the defocused image is compared
    against the sharp image (MSE, weight 0.5), against the teacher output
    resized to 360x640 (MSE, weight 0.2) and against the sharp image in
    feature space (perceptual loss, weight 0.3).  Batches that yield
    non-finite teacher outputs or NaN losses are skipped.  The learning-rate
    scheduler is stepped once per epoch, and the student weights are written
    to ``student_model.pth`` in ``RESULTS_DIR`` after the last epoch.

    Args:
        epochs: Number of passes over ``train_loader``.

    Returns:
        None.  Returns early, without saving, if an epoch has no valid batch.
    """
    student_model.train()
    teacher_model.eval()
    for epoch in range(epochs):
        running_loss = 0.0
        valid_batches = 0
        # The motion-blurred image is part of each sample but unused here.
        for sharp_imgs, defocused_imgs, _motion_imgs in train_loader:
            sharp_imgs, defocused_imgs = sharp_imgs.to(device), defocused_imgs.to(device)

            optimizer.zero_grad()

            with autocast('cuda'):
                with torch.no_grad():
                    teacher_out = teacher_model(defocused_imgs)
                    if torch.isnan(teacher_out).any() or torch.isinf(teacher_out).any():
                        logging.warning(f"Invalid teacher output for batch {valid_batches}, skipping")
                        continue
                    # Bring the teacher output to the student's output resolution.
                    teacher_out = F.interpolate(teacher_out, size=(360, 640), mode='bilinear', align_corners=False)

                student_out = student_model(defocused_imgs)
                loss_mse = criterion_mse(student_out, sharp_imgs)
                loss_teacher = criterion_mse(student_out, teacher_out)
                loss_perceptual = perceptual_loss(student_out, sharp_imgs)
                loss = 0.5 * loss_mse + 0.2 * loss_teacher + 0.3 * loss_perceptual

                if torch.isnan(loss_mse) or torch.isnan(loss_teacher) or torch.isnan(loss_perceptual):
                    logging.warning(f"NaN detected: MSE={loss_mse.item()}, Teacher={loss_teacher.item()}, Perceptual={loss_perceptual.item()}")
                    continue

            scaler.scale(loss).backward()
            # Gradients are still multiplied by the loss scale at this point;
            # unscale them first so max_norm applies to the true gradient norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(student_model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            running_loss += loss.item()
            valid_batches += 1

            # Release per-batch activations before the next iteration.
            del teacher_out, student_out
            torch.cuda.empty_cache()

        if valid_batches == 0:
            logging.error("No valid batches processed in epoch")
            return

        avg_loss = running_loss / valid_batches
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
        scheduler.step()

    torch.save(student_model.state_dict(), os.path.join(RESULTS_DIR, "student_model.pth"))

# Evaluation
def _to_image_array(batch):
    """Convert a 4-D tensor batch in [-1, 1] to a channels-last array in [0, 1]."""
    # float() guards against half-precision outputs produced under autocast.
    array = batch.detach().float().cpu().numpy().transpose(0, 2, 3, 1)
    return (array * 0.5 + 0.5).clip(0, 1)


def evaluate_model():
    """Evaluate the student model on ``test_loader``.

    Computes SSIM and PSNR between each sharpened output and its sharp
    reference, measures per-batch inference time, saves and displays up to
    two sample triplets, and writes SSIM/PSNR line plots to ``RESULTS_DIR``.

    Returns:
        ``(avg_ssim, avg_psnr, avg_fps, sample_images)`` where ``avg_fps`` is
        the reciprocal of the mean per-batch inference time and
        ``sample_images`` is a list of ``(sharp, sharpened, defocused)``
        arrays for the saved samples.
    """
    student_model.eval()
    ssim_scores = []
    psnr_scores = []
    inference_times = []
    sample_images = []
    # CUDA kernels run asynchronously; without explicit synchronisation the
    # timer would mostly measure kernel launch overhead.
    sync_cuda = device.type == 'cuda'

    for i, (sharp_imgs, defocused_imgs, _motion_imgs) in enumerate(test_loader):
        sharp_imgs, defocused_imgs = sharp_imgs.to(device), defocused_imgs.to(device)

        if sync_cuda:
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        with torch.no_grad(), autocast('cuda'):
            sharpened_imgs = student_model(defocused_imgs)
        if sync_cuda:
            torch.cuda.synchronize()
        inference_times.append(time.perf_counter() - start_time)

        sharp_np = _to_image_array(sharp_imgs)
        sharpened_np = _to_image_array(sharpened_imgs)

        for j in range(sharp_np.shape[0]):
            logging.debug(f"Sharp image {i}_{j}: min={sharp_np[j].min():.4f}, max={sharp_np[j].max():.4f}")
            logging.debug(f"Sharpened image {i}_{j}: min={sharpened_np[j].min():.4f}, max={sharpened_np[j].max():.4f}")
            ssim_score = ssim(sharp_np[j], sharpened_np[j], channel_axis=2, data_range=1.0)
            psnr_score = psnr(sharp_np[j], sharpened_np[j], data_range=1.0)
            ssim_scores.append(ssim_score)
            psnr_scores.append(psnr_score)

            # Save and show only the first couple of samples.
            if i * 1 + j < 2 and len(sample_images) < 2:
                sharp_path = os.path.join(RESULTS_DIR, f"sharp_{i}_{j}.png")
                sharpened_path = os.path.join(RESULTS_DIR, f"sharpened_{i}_{j}.png")
                defocused_path = os.path.join(RESULTS_DIR, f"defocused_{i}_{j}.png")
                defocused_np = _to_image_array(defocused_imgs)
                sharp_img_uint8 = (sharp_np[j] * 255).astype(np.uint8)
                sharpened_img_uint8 = (sharpened_np[j] * 255).astype(np.uint8)
                defocused_img_uint8 = (defocused_np[j] * 255).astype(np.uint8)
                # OpenCV writes BGR, the arrays are RGB.
                cv2.imwrite(sharp_path, cv2.cvtColor(sharp_img_uint8, cv2.COLOR_RGB2BGR))
                cv2.imwrite(sharpened_path, cv2.cvtColor(sharpened_img_uint8, cv2.COLOR_RGB2BGR))
                cv2.imwrite(defocused_path, cv2.cvtColor(defocused_img_uint8, cv2.COLOR_RGB2BGR))
                logging.info(f"Saved sharp image to {os.path.abspath(sharp_path)}")
                logging.info(f"Saved sharpened image to {os.path.abspath(sharpened_path)}")
                logging.info(f"Saved defocused image to {os.path.abspath(defocused_path)}")
                sample_images.append((sharp_np[j], sharpened_np[j], defocused_np[j]))
                print(f"Sample {i*1+j+1} (SSIM: {ssim_score:.4f}, PSNR: {psnr_score:.4f}):")
                plt.figure(figsize=(15, 5))
                panels = (("Defocused", defocused_np[j]), ("Sharp", sharp_np[j]), ("Sharpened", sharpened_np[j]))
                for position, (title, image) in enumerate(panels, start=1):
                    plt.subplot(1, 3, position)
                    plt.title(title)
                    plt.imshow(image)
                    plt.axis('off')
                plt.show()

    avg_ssim = np.mean(ssim_scores)
    avg_psnr = np.mean(psnr_scores)
    avg_fps = 1 / np.mean(inference_times)

    plt.figure(figsize=(10, 6))
    plt.plot(ssim_scores, label='SSIM Score per Image')
    plt.axhline(y=avg_ssim, color='r', linestyle='--', label=f'Average SSIM: {avg_ssim:.4f}')
    plt.xlabel('Test Image Index')
    plt.ylabel('SSIM Score')
    plt.title('SSIM Scores on Test Dataset')
    plt.legend()
    ssim_plot_path = os.path.join(RESULTS_DIR, 'ssim_scores.png')
    plt.savefig(ssim_plot_path)
    plt.close()
    display(IPImage(ssim_plot_path))

    plt.figure(figsize=(10, 6))
    plt.plot(psnr_scores, label='PSNR Score per Image')
    plt.axhline(y=avg_psnr, color='r', linestyle='--', label=f'Average PSNR: {avg_psnr:.4f}')
    plt.xlabel('Test Image Index')
    plt.ylabel('PSNR Score')
    plt.title('PSNR Scores on Test Dataset')
    plt.legend()
    psnr_plot_path = os.path.join(RESULTS_DIR, 'psnr_scores.png')
    plt.savefig(psnr_plot_path)
    plt.close()
    display(IPImage(psnr_plot_path))

    return avg_ssim, avg_psnr, avg_fps, sample_images

# Generate Report
def generate_report(avg_ssim, avg_psnr, avg_fps, sample_images, epochs=75):
    """Write a Markdown summary of the run to ``sharpening_report.md``.

    The parameter counts in the report are computed from the live
    ``teacher_model`` and ``student_model`` objects so they cannot drift from
    the architectures actually used.

    Args:
        avg_ssim: Mean SSIM over the test set.
        avg_psnr: Mean PSNR over the test set.
        avg_fps: Measured inference throughput.
        sample_images: Sample arrays from evaluation; accepted for interface
            compatibility and not used in the report text.
        epochs: Number of training epochs to state in the report.
    """
    teacher_params = sum(p.numel() for p in teacher_model.parameters())
    student_params = sum(p.numel() for p in student_model.parameters())
    report = f"""
# Image Sharpening Model Report

## Data Sources
- **Dataset**: Custom dataset with {len(dataset)} sharp images (defocused images missing for most triplets).
- **Paths**:
  - Sharp: {SHARP_PATH}
  - Defocused: {DEFOCUSED_PATH}
  - Motion: {MOTION_PATH}
- **Teacher Model**: Custom EDSR

## Model Description

### Teacher Model
- **Architecture**: EDSR with 16 residual blocks.
- **Parameters**: {teacher_params:,}.
- **Role**: Provides stable sharpened outputs.

### Student Model
- **Architecture**: Simplified CNN with 5 layers, 2 residual blocks, skip connections, and residual connection.
  - Conv1: 3 -> 64, Conv2: 64 -> 64, ResBlocks: 64 -> 64 (x2), Conv3: 64 -> 128, Conv4: 128 -> 64, Conv5: 64 -> 3.
  - Skip connection: Conv 3 -> 64, Dropout: 0.1.
- **Parameters**: {student_params:,}.
- **Role**: Lightweight sharpening.

## Training Process
- **Dataset Split**: {train_size} train, {test_size} test.
- **Preprocessing**: Resize to 640x360, data augmentation (flips, rotation).
- **Loss**: 0.5 * MSE, 0.2 * Teacher MSE, 0.3 * Perceptual.
- **Optimizer**: Adam, lr=0.0005, LinearWarmupCosineAnnealingLR.
- **Epochs**: {epochs}.

## Performance Analysis
- **SSIM**: {avg_ssim:.4f} (target > 0.90).
- **PSNR**: {avg_psnr:.4f} (target > 30.0).
- **FPS**: {avg_fps:.2f} (target 30-60).
- **Test Dataset**: {len(test_dataset)} images.

## Results
- **SSIM Plot**: {RESULTS_DIR}/ssim_scores.png.
- **PSNR Plot**: {RESULTS_DIR}/psnr_scores.png.
- **Sample Images**: {RESULTS_DIR}/sharp_*.png, sharpened_*.png, defocused_*.png.
- **Model**: {RESULTS_DIR}/student_model.pth.

## Conclusion
SSIM {avg_ssim:.4f}, PSNR {avg_psnr:.4f}, FPS {avg_fps:.2f}. Dataset issues (missing defocused images) severely impacted performance.
"""
    # Explicit encoding keeps the report identical across platform defaults.
    with open(os.path.join(RESULTS_DIR, "sharpening_report.md"), "w", encoding="utf-8") as f:
        f.write(report)

# Main Execution
if __name__ == "__main__":
    # Stage 1: fit the student model.
    print("Training model...")
    train_model(epochs=75)

    # Stage 2: score it on the held-out split.
    print("Evaluating model...")
    avg_ssim, avg_psnr, avg_fps, sample_images = evaluate_model()

    # Stage 3: write the Markdown summary.
    print("Generating report...")
    generate_report(
        avg_ssim=avg_ssim,
        avg_psnr=avg_psnr,
        avg_fps=avg_fps,
        sample_images=sample_images,
    )

    print(f"Results saved in {RESULTS_DIR}")
    print(
        f"Average SSIM: {avg_ssim:.4f}, Average PSNR: {avg_psnr:.4f}, Average FPS: {avg_fps:.2f}"
    )