In [None]:
!pip install torchmetrics

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import torchvision.models as models
from datasets import load_dataset
from torchmetrics import StructuralSimilarityIndexMeasure
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from tqdm import tqdm
import time
from torch.utils.data import DataLoader, TensorDataset

In [None]:
import os
from PIL import Image
from torchvision import transforms

# File extensions treated as images when scanning a dataset directory; anything
# else (e.g. the Thumbs.db files shipped with BSDS500) is skipped up front.
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.ppm', '.tif', '.tiff', '.webp')


def _load_lr_hr_pairs(hr_dir, limit, hr_tf, lr_tf, split_name):
    """Load up to `limit` (lr, hr) tensor pairs from the image files in hr_dir, in sorted-name order."""
    image_names = sorted(
        name for name in os.listdir(hr_dir)
        if name.lower().endswith(_IMAGE_EXTENSIONS)
    )
    pairs = []
    # max(..., 0): a negative limit must not turn into "all but the last k" files.
    for fname in image_names[:max(limit, 0)]:
        hr_path = os.path.join(hr_dir, fname)
        try:
            # Context manager closes the underlying file handle deterministically.
            with Image.open(hr_path) as img:
                hr_img = img.convert("RGB")
            hr_tensor = hr_tf(hr_img)
            lr_tensor = lr_tf(hr_img)
            pairs.append((lr_tensor, hr_tensor))
        except Exception as e:
            # Best effort: report the unreadable file and keep going.
            print(f"Error loading {split_name} image {fname}: {e}")
    return pairs


def load_div2k_dataset_manual(
    train_hr_dir='/content/drive/MyDrive/BSDS500/images/train',
    valid_hr_dir='/content/drive/MyDrive/BSDS500/images/val',
    num_samples=200
):
    """
    Load HR images from local directories and pair each with a synthetic LR copy.

    Despite the function name, the default directories point at BSDS500; any
    folder of images works. Each image is converted to RGB and yields:
    - hr: the image resized to 256x256 (torchvision's default interpolation),
      as a float tensor in [0, 1]
    - lr: the image bicubic-downscaled to 128x128 and bicubic-upscaled back to
      256x256, i.e. a blurred version of the same content

    Parameters:
    - train_hr_dir: directory of training HR images
    - valid_hr_dir: directory of validation HR images
    - num_samples: total number of pairs to use. The first int(0.8 * num_samples)
      image files (sorted by name) of train_hr_dir are used for training; the
      validation split takes up to num_samples - len(train_pairs) images.

    Returns:
    - (train_pairs, val_pairs): lists of (lr_tensor, hr_tensor) tuples, each
      tensor of shape (3, 256, 256)
    """
    hr_tf = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])
    bicubic = transforms.InterpolationMode.BICUBIC
    lr_tf = transforms.Compose([
        transforms.Resize((128, 128), interpolation=bicubic),
        transforms.Resize((256, 256), interpolation=bicubic),
        transforms.ToTensor()
    ])

    train_pairs = _load_lr_hr_pairs(
        train_hr_dir, int(0.8 * num_samples), hr_tf, lr_tf, "training"
    )
    val_pairs = _load_lr_hr_pairs(
        valid_hr_dir, num_samples - len(train_pairs), hr_tf, lr_tf, "validation"
    )

    print(f"Loaded {len(train_pairs)} training pairs and {len(val_pairs)} validation pairs")
    print(f"Train HR path: {train_hr_dir}")
    print(f"Valid HR path: {valid_hr_dir}")
    return train_pairs, val_pairs



In [None]:
# Select the compute device once; the training, evaluation and benchmarking
# functions defined later read this module-level name.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

class TeacherSharpeningModel(nn.Module):
    """Encoder-decoder teacher: ImageNet-pretrained ResNet50 trunk + transposed-conv decoder.

    Pipeline for an input batch x:
    1. encoder: every ResNet50 child except the trailing avgpool and fc, giving
       2048-channel feature maps (8x8 for a 256x256 input).
    2. adaptive_pool: average-pools those maps to exactly 8x8, so the decoder
       always sees the same bottleneck and the output is always 256x256.
    3. decoder: five stride-2 transposed convolutions (8 -> 16 -> 32 -> 64 ->
       128 -> 256); the first four are each followed by BatchNorm + ReLU, and a
       final Sigmoid maps the 3 output channels into [0, 1].

    Only the encoder carries pretrained weights; the decoder starts randomly
    initialized.
    NOTE(review): inputs reach the pretrained encoder without ImageNet mean/std
    normalization -- confirm that is intended.
    """

    def __init__(self):
        super().__init__()
        backbone = models.resnet50(weights='DEFAULT')
        # Drop avgpool and fc so spatial feature maps are preserved.
        self.encoder = nn.Sequential(*list(backbone.children())[:-2])

        # Pin the bottleneck to 8x8 regardless of the encoder's output size.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((8, 8))

        # Channel widths of successive decoder stages. The resulting layer order
        # (and therefore the Sequential indices / state_dict keys) is
        # [ConvT, BN, ReLU] x 4 followed by the final ConvT and Sigmoid.
        widths = (2048, 1024, 512, 256, 128)
        stages = []
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            stages.extend([
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        stages.extend([
            nn.ConvTranspose2d(widths[-1], 3, 4, stride=2, padding=1),
            nn.Sigmoid(),
        ])
        self.decoder = nn.Sequential(*stages)

    def forward(self, x):
        features = self.encoder(x)
        bottleneck = self.adaptive_pool(features)
        return self.decoder(bottleneck)

In [None]:
import torch
import torch.nn as nn

class StudentSharpeningModel(nn.Module):
    """Lightweight encoder-decoder sharpening network with one skip connection.

    Layout for an input of spatial size H x W:
    - enc1: two 3x3 conv/BN/ReLU layers at full resolution, 32 channels.
    - enc2: a stride-2 conv down to ceil(H/2) x ceil(W/2), then a 3x3 conv, 64 channels.
    - up:   transposed conv back to 2*ceil(H/2) x 2*ceil(W/2), 32 channels,
            cropped to H x W and added to the enc1 features (skip connection).
    - dec:  conv/BN/ReLU then a conv to 3 channels followed by Tanh.
    The Tanh output in [-1, 1] is rescaled to [0, 1], so the output has the
    same spatial size as the input for any H and W (odd sizes included).
    """
    def __init__(self):
        super().__init__()
        # Encoder
        self.enc1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )
        self.enc2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, stride=2, padding=1),  # Downsample
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

        # Decoder
        self.up = nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1)  # Upsample x2
        self.dec = nn.Sequential(
            nn.Conv2d(32, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 3, 3, padding=1),
            nn.Tanh()  # Bounded output, rescaled to [0, 1] in forward()
        )

    def forward(self, x):
        x1 = self.enc1(x)   # (B, 32, H, W)
        x2 = self.enc2(x1)  # (B, 64, ceil(H/2), ceil(W/2))

        up = self.up(x2)    # (B, 32, 2*ceil(H/2), 2*ceil(W/2))
        # For odd H or W the upsampled map is one pixel larger than x1; crop it
        # so the skip connection lines up instead of raising a size mismatch.
        # (A no-op for even sizes.)
        up = up[..., :x1.shape[-2], :x1.shape[-1]]
        up = up + x1        # Skip connection

        out = self.dec(up)  # (B, 3, H, W)
        return (out + 1) / 2  # Rescale from [-1, 1] to [0, 1]


In [None]:
def train_student_model(train_pairs, val_pairs, epochs=5, teacher_model=None):
    """Train a StudentSharpeningModel with knowledge distillation from a teacher.

    Per batch, the student minimizes
        0.4 * MSE(student, teacher) + 0.4 * MSE(student, HR) + 0.2 * (1 - SSIM(student, HR))
    with Adam (lr=1e-3) and a StepLR schedule that halves the learning rate
    every 2 epochs. Validation SSIM is printed after epochs 1, 3, 5, ...

    Parameters:
    - train_pairs: list of (lr_tensor, hr_tensor) tuples used for training.
    - val_pairs: list of (lr_tensor, hr_tensor) tuples used for validation.
    - epochs: number of passes over train_pairs.
    - teacher_model: optional teacher to distil from. When None (default), a
      fresh TeacherSharpeningModel is built. NOTE: in that case only its ResNet50
      encoder is pretrained and its decoder is randomly initialized and never
      trained here, so the distillation term pulls the student toward an
      untrained network's output -- pass a trained teacher for real distillation.

    Returns:
    - (student_model, teacher_model, val_lr_imgs, val_hr_imgs)

    Raises:
    - ValueError if train_pairs or val_pairs is empty.
    """
    if len(train_pairs) == 0 or len(val_pairs) == 0:
        raise ValueError("train_pairs and val_pairs must both be non-empty")

    # Initialize models (teacher first, so random initialization draws from the
    # RNG in the same order as when both are built here).
    if teacher_model is None:
        teacher_model = TeacherSharpeningModel()
    teacher_model = teacher_model.to(device)
    student_model = StudentSharpeningModel().to(device)

    # The teacher is never optimized; eval mode makes its BatchNorm layers use
    # running statistics instead of per-batch statistics.
    teacher_model.eval()

    # Convert pairs to stacked tensors
    lr_imgs = torch.stack([x[0] for x in train_pairs])
    hr_imgs = torch.stack([x[1] for x in train_pairs])

    val_lr_imgs = torch.stack([x[0] for x in val_pairs])
    val_hr_imgs = torch.stack([x[1] for x in val_pairs])

    train_loader = DataLoader(TensorDataset(lr_imgs, hr_imgs), batch_size=8, shuffle=True)

    # Loss functions and optimizer (only the student's parameters are optimized)
    mse_loss = nn.MSELoss()
    ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
    optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

    for epoch in range(epochs):
        student_model.train()
        running_loss = 0.0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")

        for lr, hr in progress_bar:
            lr = lr.to(device)
            hr = hr.to(device)

            # Teacher output is a fixed target: no gradients needed
            with torch.no_grad():
                teacher_output = teacher_model(lr)

            student_output = student_model(lr)

            # Combined loss: distillation + ground truth + structural similarity
            loss_distill = mse_loss(student_output, teacher_output)
            loss_gt = mse_loss(student_output, hr)
            loss_ssim = 1 - ssim_metric(student_output, hr)

            total_loss = 0.4 * loss_distill + 0.4 * loss_gt + 0.2 * loss_ssim

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            running_loss += total_loss.item()
            progress_bar.set_postfix({'loss': f'{total_loss.item():.4f}'})

        scheduler.step()
        avg_loss = running_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{epochs}] - Average Loss: {avg_loss:.4f}")

        # Validate every other epoch (epoch indices 0, 2, 4, ...)
        if epoch % 2 == 0:
            val_ssim = evaluate_model(student_model, val_lr_imgs, val_hr_imgs)
            print(f"Validation SSIM: {val_ssim:.4f}")

    return student_model, teacher_model, val_lr_imgs, val_hr_imgs

In [None]:
def evaluate_model(model, val_lr, val_hr, batch_size=None):
    """Return the mean SSIM of model(val_lr) against val_hr as a Python float.

    The model is switched to eval mode (and left in it) and run without
    gradients. SSIM is computed with data_range=1.0, so predictions and
    targets are expected to lie in [0, 1].

    Parameters:
    - model: network mapping LR batches to predicted images.
    - val_lr, val_hr: stacked tensors indexed along dim 0; moved to `device`
      one batch at a time.
    - batch_size: images per forward pass. None (default) evaluates the whole
      set in a single pass; set it to bound GPU memory on large validation sets.

    Raises:
    - ValueError if val_lr is empty or batch_size is not a positive integer.
    """
    n_images = val_lr.shape[0]
    if n_images == 0:
        raise ValueError("evaluate_model needs at least one validation image")
    step = n_images if batch_size is None else batch_size
    if step <= 0:
        raise ValueError("batch_size must be a positive integer or None")

    model.eval()
    ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)

    with torch.no_grad():
        # update() accumulates per-batch statistics; compute() averages over
        # every image seen, matching a single call on the full set.
        for start in range(0, n_images, step):
            lr_batch = val_lr[start:start + step].to(device)
            hr_batch = val_hr[start:start + step].to(device)
            ssim_metric.update(model(lr_batch), hr_batch)
        ssim_score = ssim_metric.compute().item()

    return ssim_score

def calculate_psnr(img1, img2, data_range=1.0):
    """Compute the peak signal-to-noise ratio between two images, in dB.

    PSNR = 20 * log10(data_range / sqrt(MSE)).

    Parameters:
    - img1, img2: tensors of the same shape (any layout). Integer tensors
      (e.g. uint8) are promoted to float32 first, so differences cannot wrap
      around and torch.mean does not reject the dtype.
    - data_range: span of valid pixel values (1.0 for [0, 1] images, 255 for uint8).

    Returns:
    - 0-dim tensor with the PSNR, or float('inf') when the images are identical.
    """
    if not img1.is_floating_point():
        img1 = img1.float()
    if not img2.is_floating_point():
        img2 = img2.float()
    mse = torch.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * torch.log10(data_range / torch.sqrt(mse))

In [None]:
def show_results(student_model, val_lr, val_hr, index=0):
    """Plot one validation sample (LR input, student output, HR target) and print its PSNR.

    Parameters:
    - student_model: model run on val_lr[index:index+1]; its output batch is
      squeezed and displayed (assumes a single-image (1, 3, H, W) output -- confirm).
    - val_lr, val_hr: stacked validation tensors indexed along dim 0.
    - index: which validation sample to show.

    The model is left in eval mode. Images are clamped to [0, 1] only for
    display; PSNR is computed on the unclamped prediction.
    """
    student_model.eval()

    def to_hwc(chw):
        # CHW tensor -> HWC tensor on the CPU, the layout imshow expects.
        return chw.cpu().permute(1, 2, 0)

    with torch.no_grad():
        lr_img = to_hwc(val_lr[index])
        hr_img = to_hwc(val_hr[index])
        single_batch = val_lr[index:index + 1].to(device)
        pred_img = to_hwc(student_model(single_batch).squeeze(0))

    panels = (
        (lr_img, 'Low Resolution Input'),
        (pred_img, 'Student Model Output'),
        (hr_img, 'Ground Truth HR'),
    )
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (image, caption) in zip(axs, panels):
        ax.imshow(image.clamp(0, 1))
        ax.set_title(caption)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

    psnr = calculate_psnr(pred_img, hr_img)
    print(f"PSNR: {psnr:.2f} dB")

In [None]:
def benchmark_fps(model, resolution=(1080, 1920), iters=30, warmup=5):
    """Measure forward-pass throughput (frames per second) on a random input.

    Parameters:
    - model: module to benchmark; it is switched to eval mode.
    - resolution: (height, width) of the single-image random input batch.
    - iters: number of timed forward passes.
    - warmup: untimed forward passes run first (lazy init, cuDNN autotuning).

    The input is created on the device of the model's first parameter (falling
    back to the global `device` for parameterless models), so a model that is
    not on `device` no longer causes a device-mismatch error.

    Returns:
    - frames per second as a float (also printed).

    Raises:
    - ValueError if iters is not positive.
    """
    if iters <= 0:
        raise ValueError("iters must be a positive integer")

    model.eval()

    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device(device)
    on_cuda = run_device.type == 'cuda'

    test_input = torch.randn(1, 3, resolution[0], resolution[1]).to(run_device)

    with torch.no_grad():
        for _ in range(warmup):
            model(test_input)

        # CUDA launches are asynchronous: synchronize so the timer brackets the
        # actual computation rather than just kernel launches.
        if on_cuda:
            torch.cuda.synchronize(run_device)
        start = time.perf_counter()  # monotonic, high-resolution clock

        for _ in range(iters):
            model(test_input)

        if on_cuda:
            torch.cuda.synchronize(run_device)
        elapsed = time.perf_counter() - start

    fps = iters / elapsed if elapsed > 0 else float('inf')
    print(f"FPS at {resolution[0]}x{resolution[1]}: {fps:.2f}")
    return fps

In [None]:
def main():
    """Run the end-to-end pipeline: load data, distil the student, evaluate, visualize, benchmark, save.

    Returns:
    - (student_model, teacher_model, val_lr_imgs, val_hr_imgs) after a full run,
      or None when either the training or the validation split came back empty.
    """
    # Load dataset (BSDS500 images on Google Drive)
    train_pairs, val_pairs = load_div2k_dataset_manual(
        train_hr_dir='/content/drive/MyDrive/BSDS500/images/train',
        valid_hr_dir='/content/drive/MyDrive/BSDS500/images/val',
        num_samples=200
    )

    if len(train_pairs) == 0 or len(val_pairs) == 0:
        print("Error: No data loaded. Please check dataset loading.")
        return

    epochs = 25  # You can adjust the number of epochs

    # Reuse train_student_model rather than duplicating its setup and loop here:
    # it builds the teacher and student, optimizes the weighted
    # distillation + ground-truth + SSIM loss, and validates every other epoch.
    # NOTE(review): the teacher's decoder is never trained anywhere in this
    # notebook (only its ResNet50 encoder is pretrained), so the distillation
    # term targets an untrained network's output. Train or load a teacher first.
    print("\nInitializing Teacher and Student models...")
    print("\nStarting training with Student model and Knowledge Distillation...")
    student_model, teacher_model, val_lr_imgs, val_hr_imgs = train_student_model(
        train_pairs, val_pairs, epochs=epochs
    )

    # Final evaluation
    print("\nFinal Evaluation:")
    final_ssim = evaluate_model(student_model, val_lr_imgs, val_hr_imgs)
    print(f"Final SSIM Score: {final_ssim:.4f} ({final_ssim * 100:.2f}%)")

    # Show a few validation samples
    print("\nSample Results:")
    for i in range(min(3, len(val_pairs))):
        print(f"\nSample {i+1}:")
        show_results(student_model, val_lr_imgs, val_hr_imgs, index=i)
        print("MOS Rating (1-5): ______")  # Manual evaluation placeholder

    # Benchmark FPS of the student at Full-HD resolution
    print("\nFPS Benchmarking:")
    benchmark_fps(student_model, resolution=(1080, 1920))

    # Save the student's weights to Google Drive
    save_path = "/content/drive/MyDrive/Student/student.pt"
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(student_model.state_dict(), save_path)
    print(f"\n🎉 Student model saved as '{save_path}'")

    return student_model, teacher_model, val_lr_imgs, val_hr_imgs


if __name__ == "__main__":
    main()

In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import time
import os

# Set your model path here.
# NOTE(review): the training cell saves the state_dict to
# "/content/drive/MyDrive/Student/student.pt", whereas this default points at a
# differently named file ("student_model.pt") on a local Windows machine --
# update it to wherever the trained weights actually live.
model_path = "C:/Users/LENOVO/Downloads/vedha/student_model.pt"
# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Define Student Model (must match the training definition layer-for-layer so
# the saved state_dict loads; this cell re-declares it so it can run standalone).
class StudentSharpeningModel(nn.Module):
    """Lightweight encoder-decoder sharpening network with one skip connection.

    For an H x W input: enc1 keeps full resolution (32 ch), enc2 downsamples to
    ceil(H/2) x ceil(W/2) (64 ch), `up` upsamples back (32 ch) and is cropped to
    H x W before being added to the enc1 features, and dec maps to 3 channels
    through Tanh. The output is rescaled from [-1, 1] to [0, 1] and has the same
    spatial size as the input for any H and W (odd sizes included).
    """
    def __init__(self):
        super().__init__()
        self.enc1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )
        self.enc2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        self.up = nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1)
        self.dec = nn.Sequential(
            nn.Conv2d(32, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 3, 3, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        x1 = self.enc1(x)    # (B, 32, H, W)
        x2 = self.enc2(x1)   # (B, 64, ceil(H/2), ceil(W/2))
        up = self.up(x2)     # (B, 32, 2*ceil(H/2), 2*ceil(W/2))
        # Odd H/W make `up` one pixel larger than x1; crop so the skip
        # connection lines up (no-op for even sizes).
        up = up[..., :x1.shape[-2], :x1.shape[-1]]
        up = up + x1  # Skip connection
        out = self.dec(up)
        return (out + 1) / 2  # Rescale from [-1, 1] to [0, 1]

# Load the trained student weights; abort this cell/script if that fails.
try:
    loaded_model = StudentSharpeningModel().to(device)
    # weights_only=True restricts unpickling to tensors and plain containers --
    # all a state_dict needs -- instead of executing arbitrary pickled code.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    loaded_model.load_state_dict(state_dict)
    loaded_model.eval()
    print(f" Student model loaded successfully from:\n   {model_path}")
except Exception as e:
    print(f"Failed to load model: {e}")
    # exit() would shut down a Jupyter kernel (and report success to a shell);
    # SystemExit(1) only aborts this cell, or exits a script with failure status.
    raise SystemExit(1) from e

# Load and preprocess input image
def load_and_preprocess_image(image_path, target_size=(256, 256)):
    """Load an image file as a (1, 3, H, W) float tensor on `device`.

    Parameters:
    - image_path: path to any image PIL can open; it is converted to RGB.
    - target_size: (height, width) to resize to, using bicubic interpolation.

    Returns:
    - tensor of shape (1, 3, target_size[0], target_size[1]) with values in
      [0, 1], or None if loading/preprocessing failed (the error is printed).
    """
    try:
        transform = transforms.Compose([
            transforms.Resize(target_size, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor()
        ])
        # Context manager closes the file handle as soon as the pixels are read.
        with Image.open(image_path) as img:
            rgb = img.convert("RGB")
        return transform(rgb).unsqueeze(0).to(device)  # (1, 3, H, W)
    except Exception as e:
        print(f"Error loading image: {e}")
        return None

# Save output
def save_output_image(tensor, output_path="sharpened_output.png"):
    """Write a float image tensor in [0, 1] to output_path (format from the extension).

    Parameters:
    - tensor: image tensor, (1, 3, H, W) or (3, H, W); a leading batch
      dimension of size 1 is squeezed away.
    - output_path: destination file.

    Errors are printed rather than raised (best-effort save).
    """
    try:
        # ToPILImage scales float tensors by 255 and casts to uint8, so values
        # outside [0, 1] would overflow; clamp first. detach() in case the
        # tensor still carries autograd history.
        image_tensor = tensor.detach().squeeze(0).clamp(0, 1).cpu()
        img = transforms.ToPILImage()(image_tensor)
        img.save(output_path)
        print(f" Sharpened image saved to: {output_path}")
    except Exception as e:
        print(f" Error saving image: {e}")

# Main inference flow: ask for an image path, sharpen it, then save and display it.
if __name__ == "__main__":
    image_path = input(" Enter path to the image you want to sharpen: ").strip()

    # isfile (not exists) so a directory path is rejected here instead of in PIL.
    if not os.path.isfile(image_path):
        print(" File not found. Please check the path.")
        # exit() would shut down a Jupyter kernel; SystemExit only aborts this
        # cell (and ends a plain script with a failure status).
        raise SystemExit(1)

    input_tensor = load_and_preprocess_image(image_path)

    if input_tensor is not None:
        print(" Running inference...")
        on_cuda = input_tensor.is_cuda
        # CUDA launches are asynchronous; synchronize around the forward pass
        # so the timer measures the computation, not just the launches.
        if on_cuda:
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        with torch.no_grad():
            output_tensor = loaded_model(input_tensor)
        if on_cuda:
            torch.cuda.synchronize()
        end_time = time.perf_counter()

        print(f" Inference time: {end_time - start_time:.4f} seconds")

        # Save and show result
        output_path = "sharpened_output.png"
        save_output_image(output_tensor, output_path)

        # Clamp before ToPILImage, which scales floats by 255 and casts to uint8.
        preview = transforms.ToPILImage()(output_tensor.squeeze(0).clamp(0, 1).cpu())
        plt.imshow(preview)
        plt.title(" Sharpened Output")
        plt.axis('off')
        plt.show()
