In [None]:
!pip install pytorch-wavelets

In [None]:
import warnings
warnings.filterwarnings("ignore")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.nn.utils import spectral_norm
from pytorch_wavelets import DWTForward  # Make sure pytorch_wavelets is installed


In [None]:
# Add missing beta schedule function
def cosine_beta_schedule(timesteps, s=0.008):
    """Cosine noise schedule as proposed in https://arxiv.org/abs/2102.09672.

    Args:
        timesteps: Number of diffusion steps; the result has this many entries.
        s: Small offset added to the normalized time before the cosine.

    Returns:
        1-D tensor of per-step betas, clamped to [0.0001, 0.9999].
    """
    grid = torch.linspace(0, timesteps, timesteps + 1)
    phase = ((grid / timesteps) + s) / (1 + s) * math.pi * 0.5
    cumulative = torch.cos(phase) ** 2
    # Normalize so the cumulative product starts at exactly 1
    cumulative = cumulative / cumulative[0]
    # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}
    step_ratio = cumulative[1:] / cumulative[:-1]
    return torch.clamp(1 - step_ratio, 0.0001, 0.9999)

# ------------------------ Wavelet Block ------------------------ #
class WaveletBlock(nn.Module):
    """Extract high-frequency wavelet features at the input resolution.

    A single-level Haar DWT is applied, its detail (high-pass) coefficients are
    folded into the channel axis, bilinearly upsampled back to the input size
    and projected with a 1x1 convolution followed by ReLU.

    Args:
        in_channels: Channels of the input image. The DWT is expected to expose
            three detail sub-bands per input channel, so the projection takes
            ``3 * in_channels`` inputs (9 for RGB, the previously hard-coded value).
        out_channels: Channels of the returned feature map (previously fixed at 64).
    """

    def __init__(self, in_channels=3, out_channels=64):
        super().__init__()
        self.dwt = DWTForward(J=1, mode='zero', wave='haar')
        self.conv = nn.Conv2d(3 * in_channels, out_channels, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return a feature map with the same spatial size as ``x``."""
        # Only the detail coefficients are used; the low-pass band is discarded
        _, detail_levels = self.dwt(x)
        detail = detail_levels[0]  # first (and only) decomposition level

        # Merge every axis between batch and the two spatial axes into channels
        detail = detail.reshape(detail.shape[0], -1, detail.shape[-2], detail.shape[-1])

        # Bring the sub-sampled coefficients back to the input resolution
        detail = nn.functional.interpolate(
            detail, size=x.shape[-2:], mode='bilinear', align_corners=False
        )

        return self.relu(self.conv(detail))

# ------------------------ Residual Block ------------------------ #
class ResBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers wrapped in an identity skip connection.

    Args:
        channels: Number of input and output channels (the block preserves shape).
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``."""
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = self.relu(hidden)
        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)
        # Skip connection, accumulated in place on the freshly computed branch
        hidden += x
        return self.relu(hidden)

# ------------------------ Generator ------------------------ #
class Generator(nn.Module):
    """Super-resolution generator: conv + wavelet features, residual trunk, two 2x upsamplers.

    Args:
        in_channels: Channels of the input (and output) image.
        num_features: Width of the convolutional trunk.

    Notes:
        * ``WaveletBlock()`` is constructed without arguments and its output is
          added to the ``conv1`` features, so both must have the same channel count.
        * ``self.tanh`` is created but ``forward`` returns the raw output of
          ``conv_final`` without applying it.
    """

    def __init__(self, in_channels=3, num_features=64):
        super().__init__()
        self.wavelet = WaveletBlock()
        self.conv1 = nn.Conv2d(in_channels, num_features, 3, padding=1)
        self.res_blocks = nn.Sequential(*(ResBlock(num_features) for _ in range(8)))

        # Stage 1: transposed-conv 2x upsampling followed by a refining ResBlock
        self.upconv1 = nn.ConvTranspose2d(num_features, num_features, 4, stride=2, padding=1)
        self.res_up1 = ResBlock(num_features)

        # Stage 2: same layout, for a total 4x spatial upscaling
        self.upconv2 = nn.ConvTranspose2d(num_features, num_features, 4, stride=2, padding=1)
        self.res_up2 = ResBlock(num_features)

        # Projection back to image channels
        self.conv_final = nn.Conv2d(num_features, in_channels, 3, padding=1)
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Map a low-resolution batch to an image batch with 4x the height and width."""
        # Fuse the plain convolutional features with the wavelet detail features
        fused = self.conv1(x) + self.wavelet(x)

        trunk = self.res_blocks(fused)

        # Two identical "upsample then refine" stages
        stages = ((self.upconv1, self.res_up1), (self.upconv2, self.res_up2))
        for upsample, refine in stages:
            trunk = refine(upsample(trunk))

        return self.conv_final(trunk)

# ------------------------ Discriminator ------------------------ #
class Discriminator(nn.Module):
    """PatchGAN-style discriminator built from spectrally normalized 4x4 convolutions.

    Three stride-2 stages halve the resolution each time (H -> H/2 -> H/4 -> H/8);
    the two remaining stride-1 convolutions (kernel 4, padding 1) each shrink the
    map by one pixel. The output is a 1-channel map of unnormalized scores.

    Args:
        in_channels: Channels of the input image.
        base_features: Width of the first stage; doubled at each following stage.
    """

    def __init__(self, in_channels=3, base_features=64):
        super().__init__()
        widths = [
            in_channels,
            base_features,
            base_features * 2,
            base_features * 4,
            base_features * 8,
        ]
        strides = [2, 2, 2, 1]

        layers = []
        for c_in, c_out, stride in zip(widths[:-1], widths[1:], strides):
            # SpectralNorm(Conv) followed by LeakyReLU
            layers.append(
                spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=1))
            )
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Final conv to the 1-channel "realness" map (no activation)
        layers.append(
            spectral_norm(nn.Conv2d(widths[-1], 1, kernel_size=4, stride=1, padding=1))
        )
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the patch-wise score map for ``x``."""
        return self.model(x)

# ------------------------ UNet Denoiser ------------------------ #
# --- Timestep Embedding --- #
class SinusoidalTimeEmbedding(nn.Module):
    """Sinusoidal embedding of integer timesteps.

    Args:
        dim: Nominal embedding size; ``dim // 2`` sine and ``dim // 2`` cosine
            features are produced.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """Embed a 1-D tensor of timesteps into shape ``(len(t), 2 * (dim // 2))``."""
        n_freqs = self.dim // 2
        # Geometric progression of frequencies from 1 down towards 1/10000
        decay = -(math.log(10000) / (n_freqs))
        freqs = torch.exp(torch.arange(n_freqs, device=t.device) * decay)
        angles = t[:, None].float() * freqs[None, :]
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

# --- Self-Attention Block --- #
class SelfAttention(nn.Module):
    """Spatial self-attention over all pixel positions of a feature map.

    Queries and keys are projected to ``in_dim // 8`` channels, values keep
    ``in_dim`` channels. The attended features are scaled by a learnable
    ``gamma`` (initialized to zero) and added to the input.

    Args:
        in_dim: Number of channels of the input feature map.
    """

    def __init__(self, in_dim):
        super().__init__()
        self.query = nn.Conv2d(in_dim, in_dim // 8, 1)
        self.key   = nn.Conv2d(in_dim, in_dim // 8, 1)
        self.value = nn.Conv2d(in_dim, in_dim, 1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return ``gamma * attention(x) + x`` with the same shape as ``x``."""
        B, C, H, W = x.shape
        n_pos = H * W

        queries = self.query(x).view(B, -1, n_pos).transpose(1, 2)  # (B, N, C//8)
        keys = self.key(x).view(B, -1, n_pos)                       # (B, C//8, N)

        # Scores are scaled by the square root of the *input* channel count
        scores = torch.bmm(queries, keys) / (C ** 0.5)
        weights = torch.softmax(scores, dim=-1)                     # rows sum to 1

        values = self.value(x).view(B, -1, n_pos)                   # (B, C, N)
        attended = torch.bmm(values, weights.transpose(1, 2)).view(B, C, H, W)
        return self.gamma * attended + x

# --- UNet Block with Time Embedding --- #
class UNetBlock(nn.Module):
    """Double 3x3 conv block (InstanceNorm + ReLU) with optional time conditioning.

    Args:
        in_ch: Input channels.
        out_ch: Output channels.
        time_emb_dim: Size of the timestep embedding. When falsy, no time
            projection is created and ``t_emb`` is ignored in ``forward``.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim=None):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.InstanceNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.InstanceNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
        # Linear + ReLU projection of the time embedding to one bias per channel
        self.time_proj = (
            nn.Sequential(nn.Linear(time_emb_dim, out_ch), nn.ReLU(inplace=True))
            if time_emb_dim
            else None
        )

    def forward(self, x, t_emb=None):
        """Apply the conv stack and, if available, add the projected time embedding."""
        features = self.conv(x)
        if self.time_proj is None or t_emb is None:
            return features

        batch, channels = features.shape[0], features.shape[1]
        # Broadcast the per-channel bias over the spatial dimensions
        bias = self.time_proj(t_emb).view(batch, channels, 1, 1)
        return features + bias

# --- Full UNet Denoiser --- #
class UNetDenoiser(nn.Module):
    """Time-conditioned U-Net with four encoder/decoder levels and bottleneck attention.

    Args:
        in_channels: Channels of the noisy input; the output has the same count.
        base: Width of the first encoder level (doubled at every deeper level).
        time_emb_dim: Size of the sinusoidal timestep embedding.

    Note:
        The input is max-pooled by 2 four times and upsampled by 2 four times
        with skip concatenations, so H and W presumably need to be multiples of 16.
    """

    def __init__(self, in_channels=3, base=64, time_emb_dim=128):
        super().__init__()
        self.time_embed = SinusoidalTimeEmbedding(time_emb_dim)

        # Encoder
        self.enc1 = UNetBlock(in_channels, base, time_emb_dim)
        self.enc2 = UNetBlock(base, base*2, time_emb_dim)
        self.enc3 = UNetBlock(base*2, base*4, time_emb_dim)
        self.enc4 = UNetBlock(base*4, base*8, time_emb_dim)

        self.pool = nn.MaxPool2d(2)

        # Bottleneck
        self.middle = UNetBlock(base*8, base*8, time_emb_dim)
        self.attention = SelfAttention(base*8)

        # Decoder: each level concatenates the upsampled map with its encoder skip,
        # hence the decoder input widths (4+8, 2+4, 1+2, 1+1) * base
        self.up4 = nn.ConvTranspose2d(base*8, base*4, 2, stride=2)
        self.dec4 = UNetBlock(base*12, base*4, time_emb_dim)

        self.up3 = nn.ConvTranspose2d(base*4, base*2, 2, stride=2)
        self.dec3 = UNetBlock(base*6, base*2, time_emb_dim)

        self.up2 = nn.ConvTranspose2d(base*2, base, 2, stride=2)
        self.dec2 = UNetBlock(base*3, base, time_emb_dim)

        self.up1 = nn.ConvTranspose2d(base, base, 2, stride=2)
        self.dec1 = UNetBlock(base*2, base, time_emb_dim)

        self.outc = nn.Conv2d(base, in_channels, 1)

    def forward(self, x, t):
        """Run the U-Net on ``x`` conditioned on the integer timesteps ``t``."""
        t_emb = self.time_embed(t)

        # Encoder path: remember each level's output before pooling it
        skips = []
        hidden = x
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4):
            hidden = encoder(hidden, t_emb)
            skips.append(hidden)
            hidden = self.pool(hidden)

        hidden = self.attention(self.middle(hidden, t_emb))

        # Decoder path: consume the skips from deepest to shallowest
        decoder_levels = (
            (self.up4, self.dec4),
            (self.up3, self.dec3),
            (self.up2, self.dec2),
            (self.up1, self.dec1),
        )
        for upsample, decoder in decoder_levels:
            merged = torch.cat([upsample(hidden), skips.pop()], dim=1)
            hidden = decoder(merged, t_emb)

        return self.outc(hidden)

# ------------------------ Improved Diffusion Generator ------------------------ #
class DiffusionGenerator(nn.Module):
    """DDPM-style noising/denoising wrapper around ``UNetDenoiser``.

    The schedule tensors are registered as buffers under the same names as
    before, so the ``state_dict`` layout is unchanged.

    Args:
        in_channels: Channels of the images being diffused.
        timesteps: Length of the cosine beta schedule.

    Notes:
        * The posterior buffers are persistent: loading a checkpoint saved
          before this fix restores the old (incorrect) values for
          ``posterior_variance`` and the ``posterior_mean_coef*`` buffers.
        * NOTE(review): ``sample`` visits a strided subset of timesteps but
          applies the single-step update of ``p_sample`` at each one, and
          ``forward`` noises at a random level while ``sample`` always starts
          from the last timestep. Both look approximate - confirm intent.
    """

    def __init__(self, in_channels=3, timesteps=1000):
        super().__init__()
        self.timesteps = timesteps
        self.denoiser = UNetDenoiser(in_channels)
        self.register_buffer("betas", cosine_beta_schedule(timesteps))

        # Pre-compute values used during training and sampling
        alphas = 1. - self.betas
        self.register_buffer("alphas_cumprod", torch.cumprod(alphas, dim=0))
        self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(self.alphas_cumprod))
        self.register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1. - self.alphas_cumprod))

        # For sampling
        self.register_buffer("sqrt_recip_alphas", torch.sqrt(1.0 / alphas))
        variance, coef1, coef2 = self._posterior_terms(self.betas, self.alphas_cumprod)
        self.register_buffer("posterior_variance", variance)
        # variance[0] is exactly 0, hence the clamp before the log
        self.register_buffer("posterior_log_variance_clipped", torch.log(variance.clamp(min=1e-20)))
        self.register_buffer("posterior_mean_coef1", coef1)
        self.register_buffer("posterior_mean_coef2", coef2)

    @staticmethod
    def _posterior_terms(betas, alphas_cumprod):
        """Return ``(variance, coef1, coef2)`` of the posterior q(x_{t-1} | x_t, x_0).

        Uses alpha_bar_{t-1} (with alpha_bar_{-1} = 1). The previous code used
        alpha_bar_t in its place, which made the variance identically zero.
        """
        alphas = 1. - betas
        prev_cumprod = torch.cat([alphas_cumprod.new_ones(1), alphas_cumprod[:-1]])
        denom = 1. - alphas_cumprod
        variance = betas * (1. - prev_cumprod) / denom
        coef1 = betas * torch.sqrt(prev_cumprod) / denom
        coef2 = (1. - prev_cumprod) * torch.sqrt(alphas) / denom
        return variance, coef1, coef2

    def q_sample(self, x_start, t, noise=None):
        """Forward diffusion: noise ``x_start`` to level ``t``.

        Args:
            x_start: Clean image batch.
            t: Long tensor with one timestep index per sample.
            noise: Optional noise to use; drawn from N(0, I) when omitted.

        Returns:
            Tuple ``(x_noisy, noise)``.
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        signal_scale = self.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
        noise_scale = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)

        return signal_scale * x_start + noise_scale * noise, noise

    def p_losses(self, x_start, t, noise=None):
        """Noise-prediction training loss.

        Returns:
            Tuple ``(mse_loss, x_noisy)``.
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        x_noisy, _ = self.q_sample(x_start, t, noise)
        predicted_noise = self.denoiser(x_noisy, t)

        loss = F.mse_loss(predicted_noise, noise)
        return loss, x_noisy

    @torch.no_grad()
    def p_sample(self, x, t):
        """Draw x_{t-1} from x_t with one ancestral sampling step.

        The mean is ``1/sqrt(alpha_t) * (x - beta_t / sqrt(1 - alpha_bar_t) * eps)``;
        posterior noise is added only for samples whose timestep is > 0, so the
        final step returns the denoised mean (previously the input was returned
        unchanged at t == 0, and only ``t[0]`` was inspected).

        Args:
            x: Current noisy batch.
            t: Long tensor with one timestep index per sample.
        """
        betas_t = self.betas[t].view(-1, 1, 1, 1)
        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
        sqrt_recip_alphas_t = self.sqrt_recip_alphas[t].view(-1, 1, 1, 1)

        predicted_noise = self.denoiser(x, t)

        model_mean = sqrt_recip_alphas_t * (
            x - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t
        )

        # Per-sample switch: no noise is injected on the last denoising step
        add_noise = (t > 0).to(x.dtype).view(-1, 1, 1, 1)
        sigma_t = torch.sqrt(self.posterior_variance[t]).view(-1, 1, 1, 1)
        return model_mean + add_noise * sigma_t * torch.randn_like(x)

    @torch.no_grad()
    def sample(self, x, steps=100):
        """Denoise ``x`` by applying ``p_sample`` over ``steps`` evenly spaced timesteps.

        The timesteps run from ``timesteps - 1`` down to 0.
        """
        schedule = torch.linspace(0, self.timesteps - 1, steps).long().flip(0)

        curr_x = x
        for t in schedule.tolist():
            t_batch = torch.full((curr_x.shape[0],), t, device=curr_x.device, dtype=torch.long)
            curr_x = self.p_sample(curr_x, t_batch)

        return curr_x

    def forward(self, x, sampling_steps=100):
        """Noise ``x`` at a random level per sample, then run the reverse process."""
        t = torch.randint(0, self.timesteps, (x.size(0),), device=x.device)
        x_noisy, _ = self.q_sample(x, t)

        return self.sample(x_noisy, steps=sampling_steps)

# ------------------------ Combined Generator ------------------------ #
class CombinedGenerator(nn.Module):
    """Pipeline that runs the diffusion model first and the GAN generator second.

    Args:
        in_channels: Channels of the input image.
        diffusion_steps: Number of sampling steps passed to the diffusion model
            on every forward call (the schedule itself always has 1000 steps).
    """

    def __init__(self, in_channels=3, diffusion_steps=100):
        super().__init__()
        self.diffusion = DiffusionGenerator(in_channels, timesteps=1000)
        self.gan = Generator(in_channels)
        self.diffusion_steps = diffusion_steps

    def forward(self, x):
        """Return ``(super_resolved, diffusion_output)`` for the batch ``x``."""
        # Stage 1: diffusion pass at the input resolution
        x_diff = self.diffusion(x, sampling_steps=self.diffusion_steps)
        # Stage 2: the GAN generator upsamples the diffusion output
        return self.gan(x_diff), x_diff

In [None]:
import torch
from torchsummary import summary  # Optional, for model summary

# Smoke test: build both networks and push one random batch through them to
# check that the tensor shapes line up end to end.

# Set device (GPU when available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the model
model = CombinedGenerator().to(device)
discriminator = Discriminator().to(device)

# Dummy input - adjust size based on your expected LR image size
dummy_input = torch.randn(1, 3, 64, 64).to(device)

# Forward pass (no gradients needed for a shape check)
with torch.no_grad():
    # output is the tuple (super-resolved image, diffusion output)
    output = model(dummy_input)
    d_out = discriminator(output[0])

# Print input and output shapes
print(f"Input shape: {dummy_input.shape}")
print(f"Output shape: {output[0].shape}")
print(f"Output diffusion shape: {output[1].shape}")
print(f"Output disscriminator shape: {d_out.shape}")

# Optional: summary of model
# summary(model, input_size=(3, 128, 128))


In [None]:
import io
import math
import os
import random

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset

class DIV2KDataset(Dataset):
    """DIV2K dataset with HR images and synthetic LR pairs.

    Each item is a dict ``{'lr': tensor, 'hr': tensor}`` with values in [0, 1].
    The LR image is produced from the HR patch by bicubic downsampling, a small
    amount of Gaussian noise and (with probability 0.5) a JPEG round trip.
    """

    # File extensions (compared in lower case) accepted when scanning root_dir
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp')

    def __init__(self, root_dir, train=True, scale_factor=4, patch_size=256, augment=True):
        """
        Args:
            root_dir: Directory with HR images
            train: Accepted for API compatibility; currently has no effect
                (every image file found in root_dir is used)
            scale_factor: Downsampling factor for LR images
            patch_size: Size of HR patches to extract
            augment: Whether to apply data augmentation
        """
        self.root_dir = root_dir
        self.patch_size = patch_size
        self.scale_factor = scale_factor
        self.augment = augment
        self.lr_patch_size = patch_size // scale_factor

        # Get file list
        self.hr_files = sorted(
            os.path.join(root_dir, name)
            for name in os.listdir(root_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )

        # Random flips; the right-angle rotation is applied in _augment
        self.transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
        ])

    def __len__(self):
        return len(self.hr_files)

    def _load_hr_patch(self, path):
        """Open ``path`` and return a square RGB patch of side ``patch_size``."""
        # The context manager closes the file handle once the pixels are converted
        with Image.open(path) as handle:
            hr_img = handle.convert('RGB')

        size = self.patch_size
        if hr_img.width >= size and hr_img.height >= size:
            # Random crop
            left = random.randint(0, hr_img.width - size)
            top = random.randint(0, hr_img.height - size)
            return hr_img.crop((left, top, left + size, top + size))

        # Resize if image is too small
        return hr_img.resize((size, size), Image.BICUBIC)

    def _augment(self, hr_img):
        """Apply random flips and a random rotation by a multiple of 90 degrees.

        ``RandomRotation(90)`` was used before, but it samples any angle in
        [-90, 90], which resamples the patch and leaves empty corners in the HR
        target. Patches are square, so right-angle rotations are lossless.
        """
        hr_img = self.transform(hr_img)
        return hr_img.rotate(random.choice((0, 90, 180, 270)))

    def _jpeg_roundtrip(self, lr_tensor, idx):
        """Encode/decode ``lr_tensor`` as JPEG in memory; return the input on failure.

        Working in a ``BytesIO`` buffer avoids the shared temp directory used
        previously, whose cleanup could delete files belonging to other
        DataLoader workers or dataset instances.
        """
        try:
            # Convert to PIL image with quality degradation
            lr_img = transforms.ToPILImage()(lr_tensor)
            compression_factor = random.randint(60, 95)
            buffer = io.BytesIO()
            lr_img.save(buffer, format="JPEG", quality=compression_factor)
            buffer.seek(0)
            with Image.open(buffer) as decoded:
                return transforms.ToTensor()(decoded.convert('RGB'))
        except Exception as e:
            # If any error occurs, log it and continue with the original tensor
            print(f"Warning: JPEG compression failed for idx {idx}: {e}")
            return lr_tensor

    def __getitem__(self, idx):
        # Resolved outside the try block so an out-of-range index raises
        # IndexError directly instead of being treated as a corrupt sample
        path = self.hr_files[idx]
        try:
            hr_img = self._load_hr_patch(path)

            # Apply augmentation
            if self.augment:
                hr_img = self._augment(hr_img)

            # Convert to tensor and normalize to [0, 1]
            hr_tensor = transforms.ToTensor()(hr_img)

            # Generate LR image with bicubic downsampling
            lr_tensor = torch.nn.functional.interpolate(
                hr_tensor.unsqueeze(0),
                scale_factor=1/self.scale_factor,
                mode='bicubic',
                align_corners=False
            ).squeeze(0)

            # Add Gaussian noise to simulate real LR images
            noise = torch.randn_like(lr_tensor) * 0.01
            lr_tensor = torch.clamp(lr_tensor + noise, 0, 1)

            # JPEG compression artifacts
            if random.random() < 0.5:
                lr_tensor = self._jpeg_roundtrip(lr_tensor, idx)

            return {'lr': lr_tensor, 'hr': hr_tensor}

        except Exception as e:
            print(f"Error processing sample {idx} from {path}: {e}")
            # Return a placeholder if processing fails
            # This allows the dataloader to continue even if one sample fails
            placeholder_hr = torch.zeros(3, self.patch_size, self.patch_size)
            placeholder_lr = torch.zeros(3, self.lr_patch_size, self.lr_patch_size)
            return {'lr': placeholder_lr, 'hr': placeholder_hr}
            
def setup_diffusion_eval_grid(generator, device, num_examples=4,
                              root_dir="/kaggle/input/div2k_train_hr/"):
    """Creates a visualization grid to show diffusion denoising process.

    Rows (each ``batch`` images wide) are, in order: the LR input, four
    progressively noisier versions, four reverse steps, the generator output
    and the HR ground truth.

    Args:
        generator: ``CombinedGenerator`` (optionally wrapped in ``DataParallel``).
        device: Device on which the models run.
        num_examples: Batch size requested from the loader.
        root_dir: Directory handed to ``DIV2KDataset``.
            NOTE(review): the default differs from the dataset paths used in
            other cells of this notebook - confirm which one is correct.

    Returns:
        Tensor of shape ``(rows * batch, 3, H, W)`` clamped to [0, 1], where
        ``H, W`` is the largest spatial size among the rows; smaller images are
        enlarged with nearest-neighbour interpolation.
    """
    # Get some evaluation samples
    dataset = DIV2KDataset(root_dir=root_dir, train=False)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=num_examples, shuffle=True)

    # Fetch ONE batch: with shuffle=True, two separate next(iter(...)) calls
    # would pair LR images with unrelated HR images
    batch = next(iter(dataloader))
    lr_imgs, hr_imgs = batch['lr'].to(device), batch['hr'].to(device)
    batch_size = lr_imgs.shape[0]

    # Extract diffusion component
    diffusion = generator.module.diffusion if isinstance(generator, torch.nn.DataParallel) else generator.diffusion
    timesteps = diffusion.timesteps

    def timestep_batch(fraction):
        # Valid indices are 0 .. timesteps - 1, so fraction 1.0 must be clamped
        t = min(int(timesteps * fraction), timesteps - 1)
        return torch.full((batch_size,), t, device=device, dtype=torch.long)

    # Save original LR images
    images = [lr_imgs.cpu()]

    with torch.no_grad():
        # Add noise progressively to show forward process
        for fraction in (0.25, 0.5, 0.75, 1.0):
            noisy_imgs, _ = diffusion.q_sample(lr_imgs, timestep_batch(fraction))
            images.append(noisy_imgs.cpu())

        # Denoise progressively to show reverse process
        x = noisy_imgs
        for fraction in (0.75, 0.5, 0.25, 0.0):
            x = diffusion.p_sample(x, timestep_batch(fraction))
            images.append(x.cpu())

        # Final images (SR output, then ground truth HR)
        gen_hr, _ = generator(lr_imgs)
    images.append(gen_hr.cpu())
    images.append(hr_imgs.cpu())

    # LR-sized and HR-sized rows cannot share one tensor as-is: bring every row
    # to the largest spatial size before stacking
    target_size = (
        max(img.shape[-2] for img in images),
        max(img.shape[-1] for img in images),
    )
    rows = [
        img if tuple(img.shape[-2:]) == target_size
        else torch.nn.functional.interpolate(img, size=target_size, mode='nearest')
        for img in images
    ]

    # Row-major layout: all images of row 0, then row 1, ...
    grid = torch.cat(rows, dim=0)
    return grid.clamp(0, 1)

def evaluate_model(generator, discriminator, test_dataloader, device, num_samples=5):
    """Evaluate model on test set and print metrics.

    Args:
        generator: Model returning ``(sr_images, diffusion_output)``.
        discriminator: Only switched to eval mode here; it is not otherwise used.
        test_dataloader: Yields dicts with ``'lr'`` and ``'hr'`` tensors in [0, 1].
        device: Device on which to run.
        num_samples: Maximum number of *batches* to evaluate.

    Returns:
        Tuple ``(avg_psnr, avg_ssim)`` weighted by batch size, or
        ``(nan, nan)`` when no batch was evaluated.

    Side effects:
        Leaves both models in eval mode and writes ``results/test_samples.png``
        for the first batch.
    """
    generator.eval()
    discriminator.eval()

    # Images are clamped to [0, 1], so both metrics use the same fixed data range
    psnr_metric = torchmetrics.PeakSignalNoiseRatio(data_range=1.0).to(device)
    ssim_metric = torchmetrics.StructuralSimilarityIndexMeasure(data_range=1.0).to(device)

    total_psnr = 0
    total_ssim = 0
    total_samples = 0

    with torch.no_grad():
        for i, batch in enumerate(test_dataloader):
            if i >= num_samples:
                break

            lr_imgs = batch['lr'].to(device)
            hr_imgs = batch['hr'].to(device)
            n_imgs = lr_imgs.size(0)

            # Generate SR images
            sr_imgs, _ = generator(lr_imgs)
            sr_imgs = sr_imgs.clamp(0, 1)

            # Calculate metrics
            psnr = psnr_metric(sr_imgs, hr_imgs)
            ssim = ssim_metric(sr_imgs, hr_imgs)

            total_psnr += psnr.item() * n_imgs
            total_ssim += ssim.item() * n_imgs
            total_samples += n_imgs

            # Save sample images
            if i == 0:
                # Comparison grid: one row of (bicubic, SR, HR) per image
                grid_imgs = []
                for j in range(min(4, n_imgs)):
                    # Bicubic upscale for reference, sized to match the HR target
                    bicubic = torch.nn.functional.interpolate(
                        lr_imgs[j:j+1], size=hr_imgs.shape[-2:], mode='bicubic', align_corners=False
                    )
                    grid_imgs.extend([
                        bicubic.squeeze(0),
                        sr_imgs[j],
                        hr_imgs[j]
                    ])

                # The output folder may not exist when this runs before training
                os.makedirs("results", exist_ok=True)
                torchvision.utils.save_image(
                    grid_imgs,
                    "results/test_samples.png",
                    nrow=3,
                    normalize=True
                )

    # Nothing evaluated (empty loader or num_samples <= 0): avoid dividing by zero
    if total_samples == 0:
        print("Test Results: no samples were evaluated")
        return float("nan"), float("nan")

    # Calculate average metrics
    avg_psnr = total_psnr / total_samples
    avg_ssim = total_ssim / total_samples

    print(f"Test Results: PSNR: {avg_psnr:.2f}, SSIM: {avg_ssim:.4f}")
    return avg_psnr, avg_ssim

In [None]:
# Sanity check: number of entries in the DIV2K training folder.
# NOTE(review): this path differs from the root_dir values used in other cells
# of this notebook - confirm which location is the right one.
len(os.listdir("/kaggle/input/div2k_train_hr/DIV2K_train_HR"))

In [None]:
if __name__ == "__main__":
    # Smoke test for the dataset: build it, report its size and the tensor
    # shapes of the first sample.
    # NOTE(review): this root_dir differs from the path listed in the previous
    # cell ("/kaggle/input/div2k_train_hr/...") - confirm which one exists.
    # Create dataset instance
    dataset = DIV2KDataset(root_dir="/kaggle/input/DIV2K_train_HR/DIV2K_train_HR", train=True)
    print(f"Dataset size: {len(dataset)}")
    
    # Get a sample (a dict with 'lr' and 'hr' tensors)
    sample = dataset[0]
    print(f"LR image shape: {sample['lr'].shape}")
    print(f"HR image shape: {sample['hr'].shape}")

In [None]:
import torchvision.models as models
import torch.nn.functional as F
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os
import pytorch_wavelets
import kornia
from pytorch_wavelets import DWTForward
import kornia.losses
import kornia.filters
from torchvision.utils import save_image
import matplotlib.pyplot as plt
from tqdm import tqdm
import os


class PerceptualLoss(nn.Module):
    """Multi-layer perceptual loss on frozen VGG16 features.

    Args:
        layers: Indices into ``vgg16().features`` whose activations are compared
            with an L1 loss. Any iterable of ints is accepted; the default is
            now an immutable tuple instead of a shared mutable list.
    """

    def __init__(self, layers=(3, 8, 15, 22)):
        super().__init__()
        vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_FEATURES).features
        # Private copy so callers' sequences are never aliased
        self.selected_layers = list(layers)
        self.vgg = vgg.eval()
        for param in self.vgg.parameters():
            param.requires_grad = False

    def forward(self, x, y):
        """Return the sum of L1 distances between the selected activations of ``x`` and ``y``."""
        loss = 0
        if not self.selected_layers:
            return loss

        deepest = max(self.selected_layers)
        for i, layer in enumerate(self.vgg):
            x = layer(x)
            y = layer(y)
            if i in self.selected_layers:
                loss += F.l1_loss(x, y)
            if i >= deepest:
                # Later layers cannot contribute, so skip their computation
                break
        return loss

def edge_loss(pred, target):
    """L1 distance between the kornia Sobel responses of ``pred`` and ``target``."""
    sobel_maps = [kornia.filters.sobel(img) for img in (pred, target)]
    return F.l1_loss(sobel_maps[0], sobel_maps[1])


class WaveletLoss(nn.Module):
    """L1 loss on single-level Haar wavelet coefficients (low-pass plus detail bands)."""

    def __init__(self):
        super(WaveletLoss, self).__init__()
        self.dwt = DWTForward(J=1, mode='zero', wave='haar')

    def forward(self, pred, target):
        """Return the low-pass L1 error plus the first-level detail L1 error."""
        low_pred, high_pred = self.dwt(pred)
        low_target, high_target = self.dwt(target)

        l1 = nn.L1Loss()
        low_term = l1(low_pred, low_target)
        high_term = l1(high_pred[0], high_target[0])
        return low_term + high_term

def tv_loss(img):
    """Anisotropic total variation: mean absolute horizontal plus vertical neighbour difference."""
    horizontal = (img[:, :, :, :-1] - img[:, :, :, 1:]).abs().mean()
    vertical = (img[:, :, :-1, :] - img[:, :, 1:, :]).abs().mean()
    return horizontal + vertical

def color_loss(fake, real):
    """L1 distance between the per-channel spatial means of ``fake`` and ``real``.

    Both inputs are reduced over dimensions 2 and 3 (height and width of an
    (N, C, H, W) batch) before the comparison.
    """
    # The closing bracket of the dim list was missing here, which made the
    # whole cell fail to parse
    mean_fake = fake.mean(dim=[2, 3])
    mean_real = real.mean(dim=[2, 3])
    return nn.functional.l1_loss(mean_fake, mean_real)

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import torchvision
import matplotlib.pyplot as plt
from tqdm import tqdm
import torchmetrics
from pytorch_wavelets import DWTForward
import contextlib
import gc

In [None]:
def tv_loss(x):
    """Total variation loss for image smoothness.

    Returns the mean squared difference between vertically adjacent pixels plus
    the mean squared difference between horizontally adjacent pixels of an
    (N, C, H, W) batch.

    The previous normalizer was ``batch * 3 * count_h * count_w``: it multiplied
    the two neighbour counts together (shrinking the loss by roughly a factor
    of H*W) and hard-coded three channels. Each direction is now averaged over
    its own number of neighbour pairs.
    """
    batch_size, channels, h_x, w_x = x.size()

    h_tv = torch.pow(x[:, :, 1:, :] - x[:, :, :h_x - 1, :], 2).sum()
    w_tv = torch.pow(x[:, :, :, 1:] - x[:, :, :, :w_x - 1], 2).sum()

    # Number of neighbour pairs per direction over the whole batch
    count_h = batch_size * channels * (h_x - 1) * w_x
    count_w = batch_size * channels * h_x * (w_x - 1)

    # max(..., 1): a 1-pixel-high/wide image has no pairs; its sum is 0, so the
    # term becomes 0 instead of 0/0
    return h_tv / max(count_h, 1) + w_tv / max(count_w, 1)

def edge_loss(pred, target, alpha=1.0):
    """Edge preservation loss using Sobel operator.

    Computes the per-channel Sobel gradient magnitude of both inputs and returns
    their mean squared difference scaled by ``alpha``.

    Args:
        pred: Predicted batch, shape (N, C, H, W), floating point.
        target: Reference batch with the same shape.
        alpha: Multiplier applied to the loss.

    The filters are built for the actual channel count and dtype of the input
    (previously fixed to 3 channels and float32).
    """
    def sobel(x):
        channels = x.shape[1]
        kernel_x = torch.tensor(
            [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=x.dtype, device=x.device
        )
        # The vertical Sobel kernel is the transpose of the horizontal one
        kernel_y = kernel_x.t().contiguous()

        # One depthwise filter per channel
        weight_x = kernel_x.view(1, 1, 3, 3).repeat(channels, 1, 1, 1)
        weight_y = kernel_y.view(1, 1, 3, 3).repeat(channels, 1, 1, 1)

        grad_x = F.conv2d(x, weight_x, padding=1, groups=channels)
        grad_y = F.conv2d(x, weight_y, padding=1, groups=channels)

        # The epsilon keeps the sqrt gradient finite where the image is flat
        return torch.sqrt(grad_x**2 + grad_y**2 + 1e-8)

    # Get edges
    edges_pred = sobel(pred)
    edges_target = sobel(target)

    # MSE between edges
    return F.mse_loss(edges_pred, edges_target) * alpha

def color_loss(pred, target, alpha=1.0):
    """Color fidelity loss: MSE between per-channel spatial means, scaled by ``alpha``."""
    # Average over height and width, leaving one value per (sample, channel)
    pred_mean, target_mean = (img.mean(dim=[2, 3]) for img in (pred, target))
    return F.mse_loss(pred_mean, target_mean) * alpha

class WaveletLoss(nn.Module):
    """L1 loss on a three-level Haar wavelet decomposition."""

    def __init__(self):
        super().__init__()
        self.dwt = DWTForward(J=3, mode='zero', wave='haar')
        self.criterion = nn.L1Loss()

    def forward(self, x, y):
        """Sum the L1 errors of the low-pass band and of every detail level."""
        x_low, x_details = self.dwt(x)
        y_low, y_details = self.dwt(y)

        # Coarsest approximation first ...
        loss = self.criterion(x_low, y_low)

        # ... then one term per decomposition level
        for x_level, y_level in zip(x_details, y_details):
            loss += self.criterion(x_level, y_level)

        return loss

class PerceptualLoss(nn.Module):
    """Single-layer perceptual loss on frozen VGG19 features.

    Inputs are expected in [0, 1]; they are rescaled to [-1, 1] before being
    fed to the network.
    NOTE(review): the ImageNet weights are normally paired with mean/std
    normalization rather than a [-1, 1] rescale - confirm this is intended.
    """

    def __init__(self):
        super().__init__()
        # Explicit weights enum replaces the deprecated ``pretrained=True`` flag
        # (which selected these same ImageNet weights)
        vgg = torchvision.models.vgg19(
            weights=torchvision.models.VGG19_Weights.IMAGENET1K_V1
        ).features

        # Keep modules 0-15 of vgg19.features; index 15 is the ReLU after the
        # third conv of block 3 (relu3_3), not relu4_1
        self.slices = nn.Sequential()
        for i in range(16):
            self.slices.add_module(str(i), vgg[i])

        # Freeze parameters
        for param in self.slices.parameters():
            param.requires_grad = False

        self.criterion = nn.L1Loss()

    def forward(self, x, y):
        """Return the L1 distance between the VGG features of ``x`` and ``y``."""
        # Map [0, 1] to [-1, 1]
        x = (x - 0.5) * 2
        y = (y - 0.5) * 2

        x_features = self.slices(x)
        y_features = self.slices(y)

        return self.criterion(x_features, y_features)

def _unwrap_data_parallel(model):
    """Return the module wrapped by ``nn.DataParallel``, or ``model`` itself if not wrapped."""
    return model.module if isinstance(model, nn.DataParallel) else model


def _save_checkpoints(generator, discriminator, epoch_number):
    """Write generator and discriminator state dicts for ``epoch_number`` to ``saved_models/``."""
    gen_state = _unwrap_data_parallel(generator).state_dict()
    disc_state = _unwrap_data_parallel(discriminator).state_dict()

    try:
        torch.save(gen_state, f"saved_models/generator_epoch_{epoch_number}.pth")
        torch.save(disc_state, f"saved_models/discriminator_epoch_{epoch_number}.pth")
    except Exception as e:
        print(f"Error saving models: {e}")
        print("Attempting emergency save...")
        try:
            torch.save(gen_state, "saved_models/generator_emergency.pth")
        except Exception as emergency_error:
            # Keep training alive; the failure is reported rather than hidden.
            print(f"Emergency save failed: {emergency_error}")

    # Clear CUDA cache
    torch.cuda.empty_cache()


def _show_epoch_samples(gen_hr, hr_imgs, epoch_number, max_images=4):
    """Plot up to ``max_images`` generated images above their ground-truth counterparts."""
    with torch.no_grad():
        try:
            num_images = min(gen_hr.size(0), max_images)
            gen_images = gen_hr[:num_images].detach().cpu()
            real_images = hr_imgs[:num_images].detach().cpu()

            # squeeze=False keeps `axes` two-dimensional even for a single image,
            # so the [row, col] indexing below always works.
            fig, axes = plt.subplots(2, num_images, figsize=(12, 6), squeeze=False)
            for idx in range(num_images):
                axes[0, idx].imshow(gen_images[idx].permute(1, 2, 0).clamp(0, 1))
                axes[0, idx].set_title("Generated")
                axes[0, idx].axis("off")

                axes[1, idx].imshow(real_images[idx].permute(1, 2, 0).clamp(0, 1))
                axes[1, idx].set_title("Real")
                axes[1, idx].axis("off")

            plt.suptitle(f"Epoch {epoch_number}")
            plt.tight_layout()
            plt.show()
            # One figure is created per epoch; release it explicitly.
            plt.close(fig)
        except Exception as e:
            # Visualisation is best-effort and must never abort training.
            print(f"Visualization error: {e}")
            print("Skipping visualization for this epoch")

            # Force garbage collection
            gc.collect()
            torch.cuda.empty_cache()


def _plot_training_curves(generator_losses, discriminator_losses, diffusion_losses,
                          psnr_scores, ssim_scores):
    """Plot per-epoch losses and quality metrics and save them to ``results/training_plots.png``."""
    plt.figure(figsize=(12, 6))

    # Plot losses
    plt.subplot(1, 2, 1)
    plt.plot(generator_losses, label="Generator", color='blue')
    plt.plot(discriminator_losses, label="Discriminator", color='red')
    plt.plot(diffusion_losses, label="Diffusion", color='green')
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.title("Training Losses")
    plt.legend()
    plt.grid(True)

    # Plot metrics (SSIM is scaled by 20 so it shares an axis with PSNR)
    plt.subplot(1, 2, 2)
    plt.plot(psnr_scores, label="PSNR", color='orange')
    plt.plot([s * 20 for s in ssim_scores], label="SSIM × 20", color='purple')
    plt.xlabel("Epochs")
    plt.ylabel("Score")
    plt.title("Image Quality Metrics")
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.savefig("results/training_plots.png")
    plt.show()


def train(
    epochs=50,
    batch_size=8,
    lr=0.0001,
    b1=0.5,
    b2=0.999,
    diffusion_steps=100,
    diffusion_weight=0.2,
    save_interval=5,
    data_dir="/kaggle/input/DIV2K_train_HR/DIV2K_train_HR",
):
    """Train the combined diffusion + GAN generator against the discriminator.

    Each batch performs three updates in order: the diffusion denoiser (on its
    own ``p_losses`` objective), the GAN part of the generator (weighted sum of
    content, adversarial, perceptual, edge, wavelet, TV and colour losses), and
    the discriminator.

    Args:
        epochs: Number of passes over the dataset.
        batch_size: Batch size passed to the ``DataLoader``.
        lr: Learning rate for the generator and discriminator; the diffusion
            optimizer uses ``lr / 2``.
        b1, b2: Adam beta coefficients for all three optimizers.
        diffusion_steps: Forwarded to ``CombinedGenerator``.
        diffusion_weight: Accepted for backward compatibility; not used by the
            current generator objective (see the note in the generator step).
        save_interval: Save checkpoints every this many epochs; ``0``/``None``
            disables periodic saving.
        data_dir: Root directory handed to ``DIV2KDataset``.

    Returns:
        Tuple ``(generator, discriminator)``; each may be wrapped in
        ``nn.DataParallel`` when more than one GPU is visible.

    Side effects:
        Creates ``results/`` and ``saved_models/``, writes checkpoints and
        ``results/training_plots.png``, and shows matplotlib figures.
    """
    os.makedirs("results", exist_ok=True)
    os.makedirs("saved_models", exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        # Deterministic cuDNN makes failures easier to reproduce while debugging.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        # Report the memory of every visible GPU before starting.
        for i in range(torch.cuda.device_count()):
            total_mem = torch.cuda.get_device_properties(i).total_memory / 1e9
            print(f"GPU {i}: {torch.cuda.get_device_name(i)}, Total memory: {total_mem:.2f} GB")

    # Force garbage collection to start clean
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # --- Create Models ---
    generator = CombinedGenerator(in_channels=3, diffusion_steps=diffusion_steps).to(device)
    discriminator = Discriminator().to(device)

    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs")
        # Use only first two GPUs if available to avoid memory issues
        device_ids = list(range(min(2, torch.cuda.device_count())))
        print(f"Limiting to {len(device_ids)} GPUs: {device_ids}")

        generator = nn.DataParallel(generator, device_ids=device_ids)
        discriminator = nn.DataParallel(discriminator, device_ids=device_ids)
    else:
        print("Using single GPU or CPU")

    # Sub-modules are reached through the unwrapped model because DataParallel
    # only forwards __call__, not attribute access.
    gen_core = _unwrap_data_parallel(generator)

    # --- Losses and Optimizers ---
    criterion_GAN = nn.BCEWithLogitsLoss()
    criterion_content = nn.L1Loss()
    criterion_perceptual = PerceptualLoss().to(device)
    criterion_edge = edge_loss
    criterion_wavelet = WaveletLoss().to(device)

    # Separate optimizers for diffusion and GAN components
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))
    optimizer_Diff = torch.optim.Adam(gen_core.diffusion.parameters(), lr=lr/2, betas=(b1, b2))
    optimizer_G = torch.optim.Adam(gen_core.gan.parameters(), lr=lr, betas=(b1, b2))

    # --- Learning Rate Schedulers ---
    scheduler_G = torch.optim.lr_scheduler.StepLR(optimizer_G, step_size=20, gamma=0.5)
    scheduler_D = torch.optim.lr_scheduler.StepLR(optimizer_D, step_size=20, gamma=0.5)
    scheduler_Diff = torch.optim.lr_scheduler.StepLR(optimizer_Diff, step_size=20, gamma=0.5)

    # --- Metrics ---
    # Both metrics use an explicit data_range of 1.0 so PSNR is not computed
    # against a range inferred from each batch.
    psnr_metric = torchmetrics.PeakSignalNoiseRatio(data_range=1.0).to(device)
    ssim_metric = torchmetrics.StructuralSimilarityIndexMeasure(data_range=1.0).to(device)

    dataset = DIV2KDataset(root_dir=data_dir, train=True)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    num_batches = len(dataloader)

    generator_losses = []
    discriminator_losses = []
    diffusion_losses = []
    psnr_scores = []
    ssim_scores = []

    for epoch in range(epochs):
        g_loss_epoch = 0
        d_loss_epoch = 0
        diff_loss_epoch = 0
        psnr_epoch = 0
        ssim_epoch = 0

        # Tqdm progress bar
        progress_bar = tqdm(enumerate(dataloader), total=num_batches,
                            desc=f"Epoch {epoch+1}/{epochs}", leave=True)

        for i, batch in progress_bar:
            lr_imgs = batch['lr'].to(device)
            hr_imgs = batch['hr'].to(device)

            # --- Train Diffusion Denoiser ---
            optimizer_Diff.zero_grad()

            # Random timesteps for diffusion training
            t = torch.randint(0, gen_core.diffusion.timesteps,
                              (lr_imgs.size(0),), device=device)

            # Get diffusion loss directly from the model
            diff_loss, _ = gen_core.diffusion.p_losses(lr_imgs, t)

            diff_loss.backward()
            optimizer_Diff.step()
            diff_loss_epoch += diff_loss.item()

            # --- Train Generator ---
            optimizer_G.zero_grad()

            # Forward pass through the combined model.
            # NOTE(review): the second output (diffusion result) is not part of
            # the generator objective. The original computed an L1 term between
            # it and lr_imgs but never added it to loss_G, and
            # `diffusion_weight` was never read; wire
            # `diffusion_weight * L1(x_diff, lr_imgs)` in here if that was the intent.
            gen_hr, _ = generator(lr_imgs)

            # Clamp outputs to valid range
            gen_hr = gen_hr.clamp(0, 1)

            # Discriminator outputs for fake content
            pred_fake = discriminator(gen_hr)

            # Targets are shaped from the discriminator output instead of a
            # hard-coded (B, 1, 30, 30), so other input sizes keep working.
            valid = torch.ones_like(pred_fake)

            # Calculate generator loss components
            loss_GAN = criterion_GAN(pred_fake, valid)
            loss_content = criterion_content(gen_hr, hr_imgs)
            loss_perceptual = criterion_perceptual(gen_hr, hr_imgs)
            loss_edge = criterion_edge(gen_hr, hr_imgs)
            loss_wavelet = criterion_wavelet(gen_hr, hr_imgs)
            loss_tv = tv_loss(gen_hr)
            loss_color = color_loss(gen_hr, hr_imgs)

            # Combined loss with rebalanced weights
            loss_G = (
                0.5 * loss_content +
                0.1 * loss_GAN +
                0.15 * loss_perceptual +
                0.05 * loss_edge +
                0.05 * loss_wavelet +
                0.001 * loss_tv +
                0.02 * loss_color
            )

            # Backprop
            loss_G.backward()
            optimizer_G.step()

            # --- Train Discriminator ---
            # zero_grad also discards the gradients the generator step pushed
            # into the discriminator's parameters.
            optimizer_D.zero_grad()

            # Real images
            pred_real = discriminator(hr_imgs)
            loss_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

            # Fake images
            pred_fake = discriminator(gen_hr.detach())  # detach to avoid updating generator
            loss_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

            # Combined discriminator loss
            loss_D = (loss_real + loss_fake) / 2

            # Backprop
            loss_D.backward()
            optimizer_D.step()

            # Track losses
            g_loss_epoch += loss_G.item()
            d_loss_epoch += loss_D.item()

            # --- Calculate Metrics ---
            with torch.no_grad():
                gen_hr_clamped = gen_hr.clamp(0, 1)
                hr_imgs_clamped = hr_imgs.clamp(0, 1)

                psnr_score = psnr_metric(gen_hr_clamped, hr_imgs_clamped)
                ssim_score = ssim_metric(gen_hr_clamped, hr_imgs_clamped)

                psnr_epoch += psnr_score.item()
                ssim_epoch += ssim_score.item()

            # Update progress bar
            progress_bar.set_postfix({
                'G_loss': loss_G.item(),
                'D_loss': loss_D.item(),
                'Diff_loss': diff_loss.item(),
                'PSNR': psnr_score.item(),
                'SSIM': ssim_score.item()
            })

        # Only the per-batch values returned above are used; the state the
        # metric objects accumulate is never read via compute(), so drop it.
        psnr_metric.reset()
        ssim_metric.reset()

        # --- Update Learning Rates ---
        scheduler_G.step()
        scheduler_D.step()
        scheduler_Diff.step()

        # Calculate average metrics for epoch
        avg_g_loss = g_loss_epoch / num_batches
        avg_d_loss = d_loss_epoch / num_batches
        avg_diff_loss = diff_loss_epoch / num_batches
        avg_psnr = psnr_epoch / num_batches
        avg_ssim = ssim_epoch / num_batches

        # Record metrics
        generator_losses.append(avg_g_loss)
        discriminator_losses.append(avg_d_loss)
        diffusion_losses.append(avg_diff_loss)
        psnr_scores.append(avg_psnr)
        ssim_scores.append(avg_ssim)

        # Print epoch summary
        print(f"[Epoch {epoch+1}/{epochs}] Generator: {avg_g_loss:.4f}, Discriminator: {avg_d_loss:.4f}, Diffusion: {avg_diff_loss:.4f}")
        print(f"[Epoch {epoch+1}/{epochs}] PSNR: {avg_psnr:.2f}, SSIM: {avg_ssim:.4f}")

        # Visualize samples from the last batch at the end of each epoch
        _show_epoch_samples(gen_hr, hr_imgs, epoch + 1)

        # Periodic checkpointing. This has to run inside the epoch loop:
        # placed after it, it only ever fired once, for the final epoch.
        if save_interval and (epoch + 1) % save_interval == 0:
            _save_checkpoints(generator, discriminator, epoch + 1)

    torch.cuda.empty_cache()

    # --- Final Plots ---
    _plot_training_curves(generator_losses, discriminator_losses, diffusion_losses,
                          psnr_scores, ssim_scores)

    return generator, discriminator

In [None]:
# Run training with the settings below and keep the returned generator and
# discriminator.
model, disc = train(
    lr=0.0001,
    batch_size=8,
    epochs=150,
    save_interval=10,
    diffusion_steps=100,
    diffusion_weight=0.2,
)

In [None]:
# Ask PyTorch's CUDA caching allocator to release the unused memory it is holding.
torch.cuda.empty_cache()