In [1]:
import torch
import torch.nn as nn

In [2]:
#Augmentations
import numpy as np
import cv2
import random

# ----------------------------
# 1. Radial vignetting / limb-darkening perturbation
# ----------------------------
def radial_vignette(img, strength=0.2):
    """
    Apply radial vignette (simulate limb-darkening).

    Each pixel is multiplied by ``1 - strength * r**2``, where ``r`` is the
    pixel's distance from the image centre, normalised so the farthest pixel
    (a corner) has ``r == 1``.

    img: 2-D (H, W) grayscale or 3-D (H, W, C) image on a 0-255 scale.
    strength: 0.0 (none) to 0.5 (strong); corners are darkened by this fraction.
    Returns a uint8 image of the same shape, clipped to [0, 255]
    (fractional values are truncated by the uint8 cast).
    """
    h, w = img.shape[:2]
    y, x = np.ogrid[:h, :w]
    # Pixel centres are at indices 0..n-1, so the geometric centre is (n - 1) / 2.
    # Using n / 2 would shift the vignette half a pixel towards the bottom-right.
    cy, cx = (h - 1) / 2, (w - 1) / 2
    r = np.sqrt((x - cx) ** 2 + (y - cy) ** 2)
    r_max = r.max() if r.size else 0.0
    if r_max > 0:  # a 1x1 image has r == 0 everywhere; avoid 0/0 -> NaN
        r = r / r_max  # normalize radius to [0, 1]

    # limb-darkening multiplier
    mask = 1 - strength * (r ** 2)
    if img.ndim == 3:
        mask = mask[..., np.newaxis]  # broadcast the 2-D mask across channels
    vignette = img.astype(np.float32) * mask
    vignette = np.clip(vignette, 0, 255)
    return vignette.astype(np.uint8)

# ----------------------------
# 2. Poisson noise (photon noise)
# ----------------------------
def add_poisson_noise(img, scale_low=5.0, scale_high=100.0):
    """
    Add Poisson noise to simulate photon noise.

    img: integer image on a 0-255 scale (e.g. uint8), or a floating-point
        image (float32 or float64) on a 0-1 scale.
    scale_low, scale_high: range of the random expected photon count for a
        full-intensity pixel; smaller values give stronger noise.
    Returns a uint8 image on a 0-255 scale (clipped, fractional values truncated).
    """
    # Decide the input scale from the dtype *kind*: integer data is 0-255 and
    # any float data is 0-1. Testing only `dtype != float32` would divide
    # float64 images by 255 a second time and return an almost-black image.
    if np.issubdtype(img.dtype, np.integer):
        img = img.astype(np.float32) / 255.0
    else:
        img = img.astype(np.float32, copy=False)
    scale = random.uniform(scale_low, scale_high)
    noisy = np.random.poisson(img * scale) / float(scale)
    noisy = np.clip(noisy, 0.0, 1.0)
    return (noisy * 255).astype(np.uint8)

# ----------------------------
# 3. Small rotations
# ----------------------------
def random_rotation(img, angle_range=20, angle=None):
    """
    Rotate image by random angle within ±angle_range (degrees).

    img: 2-D grayscale (H, W) or 3-D (H, W, C) image.
    angle_range: maximum absolute rotation in degrees.
    angle: optional fixed angle in degrees. When given, no random draw is made,
        so passing the same angle to two calls rotates a pair of images identically.
    Borders are filled by reflection; the output has the input's size.
    """
    h, w = img.shape[:2]
    if angle is None:
        angle = random.uniform(-angle_range, angle_range)
    M = cv2.getRotationMatrix2D((w//2, h//2), angle, 1.0)
    rotated = cv2.warpAffine(img, M, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)
    return rotated

# ----------------------------
# 4. Random crop + resize
# ----------------------------
def random_crop_resize(img, crop_scale=0.8, out_size=(256, 256)):
    """
    Random crop followed by resize.

    img: 2-D (H, W) or 3-D (H, W, C) image.
    crop_scale: fraction of each side kept, in (0, 1] (typically 0.5–1.0).
    out_size: size passed to cv2.resize as dsize, i.e. (width, height).
    Raises ValueError if crop_scale is outside (0, 1].
    """
    if not 0 < crop_scale <= 1:
        raise ValueError(f"crop_scale must be in (0, 1], got {crop_scale}")
    h, w = img.shape[:2]
    # Keep at least one pixel per side so a tiny scale never yields an empty crop.
    ch = max(1, int(h * crop_scale))
    cw = max(1, int(w * crop_scale))
    y = random.randint(0, h - ch)
    x = random.randint(0, w - cw)
    crop = img[y:y + ch, x:x + cw]
    resized = cv2.resize(crop, out_size, interpolation=cv2.INTER_LINEAR)
    return resized

# ----------------------------
# 5. Gaussian blur (PSF-like)
# ----------------------------
def gaussian_blur(img, sigma_range=(0.5, 2.0)):
    """
    Blur with a Gaussian kernel of random width (a rough PSF model).

    sigma is drawn uniformly from ``sigma_range``. The square kernel spans about
    ±3 sigma around the centre and always has an odd side length, as
    cv2.GaussianBlur requires.
    """
    sigma = random.uniform(*sigma_range)
    half_side = round(3 * sigma)
    side = 2 * half_side + 1
    return cv2.GaussianBlur(img, (side, side), sigmaX=sigma)

# ----------------------------
# 6. Brightness/contrast jitter
# ----------------------------
def brightness_contrast_jitter(img, brightness=0.2, contrast=0.3):
    """
    Randomly jitter brightness and contrast.

    brightness: maximum additive offset as a fraction of 255
        (0.2 → offset drawn from ±20% of 255).
    contrast: maximum relative change of the multiplicative gain
        (0.3 → gain drawn from [0.7, 1.3]).
    The offset is drawn before the gain. Output is uint8, clipped to [0, 255].
    """
    offset = 255 * random.uniform(-brightness, brightness)
    gain = 1.0 + random.uniform(-contrast, contrast)
    as_float = img.astype(np.float32)
    adjusted = np.clip(as_float * gain + offset, 0, 255)
    return adjusted.astype(np.uint8)


In [3]:
class Encoder(nn.Module):
    """Three stride-2 3x3 convolutions with ReLU, 1 -> 16 -> 32 -> 64 channels.

    Each convolution halves H and W (rounding up), so a [B, 1, H, W] input
    becomes a [B, 64, ceil(H/8), ceil(W/8)] latent feature map.
    """

    def __init__(self):
        super().__init__()
        widths = (1, 16, 32, 64)
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Conv2d(c_in, c_out, 3, stride=2, padding=1), nn.ReLU()]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` into the latent feature map."""
        return self.conv(x)

class Decoder(nn.Module):
    """Three stride-2 3x3 transposed convolutions, 64 -> 32 -> 16 -> 1 channels.

    With padding=1 and output_padding=1 each layer exactly doubles H and W, so a
    [B, 64, h, w] map becomes [B, 1, 8h, 8w]. Hidden layers use ReLU; the final
    Sigmoid bounds the output to [0, 1].
    """

    def __init__(self):
        super().__init__()
        widths = [64, 32, 16, 1]
        n_stages = len(widths) - 1
        stages = []
        for i in range(n_stages):
            stages.append(nn.ConvTranspose2d(widths[i], widths[i + 1], 3,
                                             stride=2, padding=1, output_padding=1))
            stages.append(nn.Sigmoid() if i == n_stages - 1 else nn.ReLU())
        self.deconv = nn.Sequential(*stages)

    def forward(self, x):
        """Decode a latent feature map back to a single-channel image."""
        return self.deconv(x)


class Autoencoder(nn.Module):
    """Encoder followed by Decoder.

    The bottleneck is a 64-channel feature map (spatial size reduced 8x), not a
    flat vector. The output is a single-channel image in [0, 1] (Decoder ends in
    a Sigmoid) and matches the input size when H and W are divisible by 8.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Encode ``x`` to the bottleneck feature map and decode it back to an image."""
        return self.decoder(self.encoder(x))


In [4]:
import torch
from torch.utils.data import Dataset, DataLoader
import cv2
import os
import random
import numpy as np

class SolarDataset(Dataset):
    """Paired low-resolution / high-resolution grayscale solar images.

    LR and HR files are paired by position after sorting each directory's
    paths, so both directories must hold the same number of images, named so
    that they sort into matching order.

    Args:
        lr_dir: directory of low-resolution images.
        hr_dir: directory of high-resolution images.
        augment: if True, apply photometric degradations (vignette, Poisson
            noise, blur, brightness/contrast) to LR only, and a random rotation
            with the SAME angle to both LR and HR so they stay aligned.
        img_size: size passed to ``cv2.resize`` as dsize, i.e. (width, height).
        max_images: optional cap on the number of pairs used.

    Raises:
        ValueError: if the two directories contain different numbers of files.

    Each item is ``(lr, hr)``: float32 tensors of shape [1, H, W] scaled to [0, 1].
    """

    def __init__(self, lr_dir, hr_dir, augment=True, img_size=(256,256), max_images=None):
        self.lr_files = self._list_files(lr_dir)
        self.hr_files = self._list_files(hr_dir)
        # Pairs are matched by sorted index; unequal counts mean misaligned pairs.
        if len(self.lr_files) != len(self.hr_files):
            raise ValueError(
                f"LR/HR file count mismatch: {len(self.lr_files)} in {lr_dir!r} vs "
                f"{len(self.hr_files)} in {hr_dir!r} (pairs are matched by sorted order)"
            )
        if max_images is not None:
            self.lr_files = self.lr_files[:max_images]
            self.hr_files = self.hr_files[:max_images]
        self.augment = augment
        self.img_size = img_size

    @staticmethod
    def _list_files(directory):
        """Sorted paths of the regular, non-hidden files in ``directory``."""
        # Skips entries like `.ipynb_checkpoints` / `.DS_Store` that cv2 cannot decode.
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if not name.startswith('.') and os.path.isfile(os.path.join(directory, name))
        )

    @staticmethod
    def _read_gray(path):
        """Load ``path`` as a grayscale uint8 array, failing loudly if unreadable."""
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:  # cv2.imread signals failure by returning None
            raise OSError(f"cv2 could not read image: {path}")
        return img

    @staticmethod
    def _rotate_pair(lr, hr, angle_range):
        """Rotate ``lr`` and ``hr`` by one shared random angle (degrees)."""
        # Same draw, centre, interpolation and border mode as random_rotation,
        # but a single angle for both images keeps LR and HR pixel-aligned.
        angle = random.uniform(-angle_range, angle_range)
        h, w = lr.shape[:2]
        M = cv2.getRotationMatrix2D((w // 2, h // 2), angle, 1.0)

        def rotate(img):
            return cv2.warpAffine(img, M, (w, h), flags=cv2.INTER_LINEAR,
                                  borderMode=cv2.BORDER_REFLECT)

        return rotate(lr), rotate(hr)

    def __len__(self):
        return len(self.lr_files)

    def __getitem__(self, idx):
        # Load images
        lr = self._read_gray(self.lr_files[idx])
        hr = self._read_gray(self.hr_files[idx])

        # Resize to same size (both share img_size, so one rotation matrix fits both)
        lr = cv2.resize(lr, self.img_size)
        hr = cv2.resize(hr, self.img_size)

        if self.augment:
            # Photometric degradations model the low-res acquisition: LR only.
            lr = radial_vignette(lr, strength=random.uniform(0.05, 0.2))
            lr = add_poisson_noise(lr)
            # Geometric transform: applied to BOTH images with the same angle,
            # otherwise the HR target no longer lines up with the LR input.
            lr, hr = self._rotate_pair(lr, hr, angle_range=15)
            lr = gaussian_blur(lr, sigma_range=(0.5, 1.5))
            lr = brightness_contrast_jitter(lr, brightness=0.15, contrast=0.2)

        # Normalize and convert to tensor [C,H,W]
        lr = torch.tensor(lr, dtype=torch.float32).unsqueeze(0) / 255.0
        hr = torch.tensor(hr, dtype=torch.float32).unsqueeze(0) / 255.0

        return lr, hr


In [10]:
# Build the train / validation / test datasets and their DataLoaders.
batch_size = 16
max_images = 1000  # per-split cap on image pairs; None would use every pair

DATA_ROOT = "new_dataset"


def _make_split(split, augment):
    """SolarDataset for one split directory (low_res / high_res) under DATA_ROOT."""
    return SolarDataset(f"{DATA_ROOT}/{split}/low_res",
                        f"{DATA_ROOT}/{split}/high_res",
                        augment=augment, max_images=max_images)


train_dataset = _make_split("training", augment=True)
val_dataset = _make_split("validation", augment=False)
test_dataset = _make_split("testing", augment=False)

# Only the training loader shuffles; val/test keep file order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import os
from torchvision import transforms
from torchmetrics.functional import structural_similarity_index_measure as ssim

# -----------------------------
# Metrics
# -----------------------------
def pixelwise_error(pred, target):
    """Mean absolute error between ``pred`` and ``target``, as a Python float."""
    abs_diff = (pred - target).abs()
    return abs_diff.mean().item()

def psnr(pred, target, max_val=1.0):
    """Peak signal-to-noise ratio in dB, ``20 * log10(max_val / RMSE)``.

    The MSE is pooled over every element of the inputs (a whole batch gives
    one value). Returns ``float('inf')`` when the inputs are identical.
    """
    squared_error = (pred - target) ** 2
    mse = squared_error.mean()
    if mse == 0:
        return float('inf')
    rmse = mse.sqrt()
    return 20 * torch.log10(max_val / rmse).item()


def ssim_metric(pred, target):
    """SSIM between ``pred`` and ``target`` as a Python float.

    ``data_range=1.0`` assumes both inputs are scaled to [0, 1]. Inputs are
    passed straight to torchmetrics' SSIM, which works on batched images
    (the test cell adds a batch dimension before calling this).
    """
    score = ssim(pred, target, data_range=1.0)
    return score.item()

# -----------------------------
# Hyperparameters
# -----------------------------
# NOTE: train_loader / val_loader were already built in the data cell with that
# cell's batch_size; reassigning batch_size here does not change them.
batch_size = 8
lr = 1e-3         # Adam learning rate (unrelated to the `lr_imgs` low-res batches)
epochs = 20
latent_dim = 128  # NOTE(review): unused — the latent size is fixed by Encoder's conv layers

# -----------------------------
# Model, Loss, Optimizer
# -----------------------------
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Autoencoder().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

# -----------------------------
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

best_val_loss = float('inf')
# Validation-set bottleneck embeddings from the best epoch only. Accumulating
# them every epoch would concatenate latents from `epochs` different models into
# one array and hold an extra full validation pass of latents in memory per epoch.
bottleneck_embeddings = []

# -----------------------------
# Training Loop
# -----------------------------
for epoch in range(epochs):
    model.train()
    train_loss = 0
    for lr_imgs, hr_imgs in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]"):
        lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)

        optimizer.zero_grad()
        outputs = model(lr_imgs)
        loss = criterion(outputs, hr_imgs)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * lr_imgs.size(0)

    train_loss /= len(train_loader.dataset)

    model.eval()
    val_loss = 0
    pixel_err = 0
    psnr_val = 0
    ssim_val = 0
    epoch_embeddings = []

    with torch.no_grad():
        for lr_imgs, hr_imgs in tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [Val]"):
            lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)
            n_batch = lr_imgs.size(0)

            # Same computation as model(lr_imgs) (encoder then decoder), split so a
            # single encoder pass gives both the reconstruction and the embedding.
            latent = model.encoder(lr_imgs)
            outputs = model.decoder(latent)
            loss = criterion(outputs, hr_imgs)
            val_loss += loss.item() * n_batch

            # Per-batch metrics weighted by batch size. PSNR here comes from each
            # batch's pooled MSE, not from averaging per-image PSNR.
            pixel_err += pixelwise_error(outputs, hr_imgs) * n_batch
            psnr_val += psnr(outputs, hr_imgs) * n_batch
            ssim_val += ssim_metric(outputs, hr_imgs) * n_batch

            epoch_embeddings.append(latent.cpu().numpy())

    val_loss /= len(val_loader.dataset)
    pixel_err /= len(val_loader.dataset)
    psnr_val /= len(val_loader.dataset)
    ssim_val /= len(val_loader.dataset)

    print(f"Epoch {epoch+1}/{epochs} | "
          f"Train Loss: {train_loss:.6f} | "
          f"Val Loss: {val_loss:.6f} | "
          f"Pixel Error: {pixel_err:.6f} | "
          f"PSNR: {psnr_val:.2f} dB | "
          f"SSIM: {ssim_val:.4f}")

    # -----------------------------
    # Save checkpoints
    # -----------------------------
    torch.save(model.state_dict(), os.path.join(checkpoint_dir, f"autoencoder_epoch{epoch+1}.pth"))
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # Keep the embeddings produced by exactly the model saved as "best".
        bottleneck_embeddings = epoch_embeddings
        torch.save(model.state_dict(), os.path.join(checkpoint_dir, "best_autoencoder.pth"))
        print(f"Best model updated at epoch {epoch+1}")

# Save bottleneck embeddings (from the best epoch); np.concatenate fails on an empty list.
if bottleneck_embeddings:
    bottleneck_embeddings = np.concatenate(bottleneck_embeddings, axis=0)
    np.save("bottleneck_embeddings.npy", bottleneck_embeddings)
    print("Bottleneck embeddings saved:", bottleneck_embeddings.shape)
else:
    print("No bottleneck embeddings to save (empty validation set or no improving epoch).")


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torchmetrics.functional import structural_similarity_index_measure as ssim
import os

# -----------------------------
# Metrics
# -----------------------------
def pixelwise_error(pred, target):
    """Mean absolute (L1) error between ``pred`` and ``target`` as a Python float.

    NOTE: the training cell defines an identical helper; this redefinition lets
    the evaluation cell run on its own.
    """
    return (pred - target).abs().mean().item()

def psnr(pred, target, max_val=1.0):
    """Peak signal-to-noise ratio in dB: ``20 * log10(max_val / sqrt(MSE))``.

    Returns ``float('inf')`` for identical inputs. The training cell defines an
    identical helper; this copy keeps the evaluation cell self-contained.
    """
    mse = ((pred - target) ** 2).mean()
    if mse == 0:
        return float('inf')
    ratio = max_val / mse.sqrt()
    return 20 * torch.log10(ratio).item()

def ssim_metric(pred, target):
    """torchmetrics SSIM of ``pred`` vs ``target`` (values in [0, 1]) as a Python float.

    The training cell defines an identical helper; this copy keeps the
    evaluation cell self-contained.
    """
    ssim_tensor = ssim(pred, target, data_range=1.0)
    return ssim_tensor.item()

# -----------------------------
# Load model
# -----------------------------
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Autoencoder().to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load("checkpoints/best_autoencoder.pth", map_location=device))
model.eval()

# Directory to save reconstructed images
os.makedirs("reconstructed_images", exist_ok=True)

# -----------------------------
# Testing and Visualization
# -----------------------------
psnr_list = []
ssim_list = []
pixel_err_list = []

n_preview = 5  # number of Original/Reconstructed pairs shown inline
count = 0      # running image index across batches; also names the saved files
for lr_imgs, hr_imgs in tqdm(test_loader, desc="Testing"):
    lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)
    with torch.no_grad():
        outputs = model(lr_imgs)

    for i in range(lr_imgs.size(0)):
        # Move to CPU and convert to float: per-image tensors of shape [1, H, W]
        pred_img = outputs[i].detach().cpu().float()
        true_img = hr_imgs[i].detach().cpu().float()

        # Add a batch dimension for SSIM: [1, 1, H, W]
        pred_img_ssim = pred_img.unsqueeze(0)
        true_img_ssim = true_img.unsqueeze(0)

        # Metrics
        psnr_list.append(psnr(pred_img, true_img))
        ssim_list.append(ssim_metric(pred_img_ssim, true_img_ssim))
        pixel_err_list.append(pixelwise_error(pred_img, true_img))

        pred_np = pred_img.squeeze().numpy()
        true_np = true_img.squeeze().numpy()

        # Save reconstructed image. vmin/vmax pin the [0, 1] data range; without
        # them plt.imsave stretches every image to its own min/max.
        recon_img_path = f"reconstructed_images/recon_{count}.png"
        plt.imsave(recon_img_path, pred_np, cmap='gray', vmin=0.0, vmax=1.0)

        # Display side by side for the first few images, both on the same fixed scale
        if count < n_preview:
            fig, axes = plt.subplots(1, 2, figsize=(6, 3))
            axes[0].imshow(true_np, cmap='gray', vmin=0.0, vmax=1.0)
            axes[0].set_title("Original")
            axes[0].axis('off')
            axes[1].imshow(pred_np, cmap='gray', vmin=0.0, vmax=1.0)
            axes[1].set_title("Reconstructed")
            axes[1].axis('off')
            plt.show()
            plt.close(fig)

        count += 1


print(f"Average PSNR: {np.mean(psnr_list):.2f} dB")
print(f"Average SSIM: {np.mean(ssim_list):.4f}")
print(f"Average Pixel Error: {np.mean(pixel_err_list):.6f}")
print(f"Reconstructed images saved in 'reconstructed_images/' folder")
