# Burst2Scene AI - Project Demonstration
This notebook demonstrates the inference pipeline of the Burst2Scene GAN-based model using a simplified sample of real burst images. It loads pretrained models and performs inference on a 10-frame burst sample.

In [None]:
!pip install -r requirements.txt

## 📁 Check Required Files

In [None]:
from pathlib import Path
required = [
    Path('generator_last.pth'),
    Path('discriminator_last.pth'),
    Path('demo_burst/burst_00')
]
for path in required:
    print(f"{'✔️' if path.exists() else '❌'} {path}")

## 📦 Imports

In [None]:
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from pathlib import Path
import os
import numpy as np

## 📂 Load Burst Sample Dataset

In [9]:
from pathlib import Path
from PIL import Image
from torchvision import transforms
import torch

class BurstDemoDataset:
    def __init__(self, burst_path, max_frames=None):
        self.burst_path = Path(burst_path)
        self.frames = sorted(list(self.burst_path.glob('*.png')) + list(self.burst_path.glob('*.jpg')))
        if not self.frames:
            raise FileNotFoundError(f"No image frames found in {burst_path}")
        if max_frames:
            self.frames = self.frames[:max_frames]
        self.transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor()
        ])

    def load_tensor(self):
        tensors = []
        for f in self.frames:
            try:
                img = Image.open(f).convert("RGB")
                tensors.append(self.transform(img))
            except Exception as e:
                print(f"Error reading {f.name}: {e}")
        return torch.stack(tensors)

# Update path to your actual burst folder
dataset = BurstDemoDataset('demo_burst/burst_00')
burst_tensor = dataset.load_tensor()
print(f'Loaded burst shape: {burst_tensor.shape}')


Loaded burst shape: torch.Size([10, 3, 128, 128])


###Load Pretrained Generator & PatchGAN Discriminator

In [10]:
from generator_colab import Generator
from discriminator_PatchGAN import ConditionalDiscriminator

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

generator = Generator().to(device)
discriminator = ConditionalDiscriminator().to(device)

generator.load_state_dict(torch.load('generator_last.pth', map_location=device))
discriminator.load_state_dict(torch.load('discriminator_last.pth', map_location=device))

generator.eval()
discriminator.eval()
print('Models loaded successfully.')

Models loaded successfully.


## 🧪 Inference on Sample Burst

In [11]:
@torch.no_grad()
def infer(model, burst):
    burst = burst.unsqueeze(0).to(device)
    output = model(burst)
    return output.squeeze(0).cpu()

result = infer(generator, burst_tensor)

# Display first input frame and generated output
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].imshow(burst_tensor[0].permute(1, 2, 0).clamp(0, 1))
axs[0].set_title('Input: First Burst Frame')
axs[0].axis('off')

axs[1].imshow(result.permute(1, 2, 0).clamp(0, 1))
axs[1].set_title('Generated Scene')
axs[1].axis('off')
plt.show()

RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [1, 10, 3, 128, 128]

## 📸 Run Model Evaluation Preview (Inlined `preview.py`)

In [14]:
import torch
import os
from torchvision.utils import save_image, make_grid
from validation_dataset_colab import BurstDataset
from generator_colab import Generator
from discriminator_PatchGAN import ConditionalDiscriminator  # Adjusted to correct PatchGAN
import piq
from PIL import Image, ImageDraw, ImageFont

# --- Config ---
checkpoint_path = "generator_last.pth"  # ← now points to root
burst_dir = "burst_validation_high_variation"
preview_name = "preview_patchgan"
base_preview_dir = os.path.join("model_preview_dir", preview_name)
os.makedirs(base_preview_dir, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Load models ---
generator = Generator().to(device)
generator.load_state_dict(torch.load(checkpoint_path, map_location=device))
generator.eval()

discriminator = ConditionalDiscriminator().to(device)
discriminator.eval()

# --- Load dataset ---
dataset = BurstDataset(bursts_dir=burst_dir, burst_size=10)

unnorm = lambda x: (x + 1) / 2

def annotate_image(image_path, text):
    image = Image.open(image_path).convert("RGB")
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()
    bbox = draw.textbbox((0, 0), text, font=font)
    width, height = image.size
    draw.rectangle([(0, height - (bbox[3] - bbox[1]) - 10), (width, height)], fill=(0, 0, 0))
    draw.text((10, height - (bbox[3] - bbox[1]) - 5), text, fill="white", font=font)
    image.save(image_path)

# --- Inference loop ---
for i in range(len(dataset)):
    burst, target = dataset[i]
    burst = burst.to(device)
    target = target.to(device)
    burst_input = burst.view(1, -1, burst.size(2), burst.size(3))

    with torch.no_grad():
        output = generator(burst_input)
        confidence = torch.sigmoid(discriminator(burst_input, output)).mean().item()

    output = unnorm(output.squeeze(0))
    target = unnorm(target)
    burst_frames = [unnorm(frame) for frame in burst]

    # Save visual outputs
    sample_dir = os.path.join(base_preview_dir, f"sample_{i}")
    os.makedirs(sample_dir, exist_ok=True)

    output_path = os.path.join(sample_dir, "generated_output.png")
    save_image(output, output_path)
    save_image(target, os.path.join(sample_dir, "target_frame.png"))
    save_image(make_grid(burst_frames, nrow=5), os.path.join(sample_dir, "burst_grid.png"))

    # Evaluate
    output_u = output.unsqueeze(0)
    target_u = target.unsqueeze(0)
    psnr_score = piq.psnr(output_u, target_u, data_range=1.0).item()
    ssim_score = piq.ssim(output_u, target_u, data_range=1.0).item()

    label = f"PSNR = {psnr_score:.2f}, SSIM = {ssim_score:.4f}, Conf = {confidence:.4f}"
    annotate_image(output_path, label)
    with open(os.path.join(sample_dir, "metrics.txt"), "w") as f:
        f.write(label + "\n")

    print(f"🖼 Sample {i}: {label}")

# Write full summary
with open(os.path.join(base_preview_dir, "info.txt"), "w") as log_file:
    for i in range(len(dataset)):
        sample_metrics_path = os.path.join(base_preview_dir, f"sample_{i}", "metrics.txt")
        if os.path.exists(sample_metrics_path):
            with open(sample_metrics_path, "r") as f:
                log_file.write(f"Sample {i}: {f.readline().strip()}\n")

print(f"✅ Previews saved to: {base_preview_dir}")


🖼 Sample 0: PSNR = 23.49, SSIM = 0.7337, Conf = 0.0664
🖼 Sample 1: PSNR = 16.40, SSIM = 0.2263, Conf = 0.0312
✅ Previews saved to: model_preview_dir/preview_patchgan
