# Imports & Setup

In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
import numpy as np
import wandb
import os
import torch.nn.functional as F
# import cv2
from PIL import Image
from tqdm import tqdm
import time
import random
import matplotlib.pyplot as plt

In [None]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# For reproducibility
torch.manual_seed(42)
random.seed(42)

%matplotlib inline

In [3]:
# Ask PyTorch's CUDA caching allocator to use expandable segments, intended to
# reduce fragmentation with large, variably-sized activations.
# NOTE(review): the allocator reads this setting when it is first used, so it
# presumably has no effect if a CUDA allocation already happened in this
# process — keep this cell before any tensor is moved to the GPU.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Hyperparameters Configuration

In [4]:
# === Paths ===
# Each root holds one sub-directory per sequence (e.g. "video_0000_seq_0")
# containing that sequence's .png frames.
HR_ROOT = "HR"
LR_ROOT = "LR"
SAVE_DIR = "saved_models"  # latest / best / periodic checkpoints are written here

# === Model Architecture ===
NUM_RES_BLOCKS = 10  # residual blocks in the base (first-stage) network
IN_CHANNELS = 9  # 3 frames × 3 channels
OUT_CHANNELS = 3
FEATURES = 128  # conv width of both the base and the refinement network
# 120×214 → 480×854 (~4x). Note 4 × 214 = 856 > 854, which is why model
# outputs are cropped to the HR frame size before losses/metrics.
UPSCALE_FACTOR = 4

# === Training ===
BATCH_SIZE = 32
NUM_EPOCHS = 100
LR = 1e-3  # initial Adam learning rate (ReduceLROnPlateau halves it on plateaus)
SAVE_FREQ = 2  # epoch_{N}.pth is written when the 0-based epoch N % SAVE_FREQ == 0
USE_PERCEPTUAL = False # True adds 0.2 × VGG19 perceptual loss to the L1 loss

# === Data Management ===
# Fractions of *videos* (not sequences) held out; all sequences of one video
# are assigned to the same split.
TEST_RATIO = 0.1
VAL_RATIO = 0.1
DATA_SPLIT_SEED = 42  # seeds the shuffle of video IDs before splitting

# === WandB ===
WANDB_PROJECT = "video-superres"
# Your WandB username/team.
# NOTE(review): an empty string is passed straight to wandb.init — set a real
# entity or confirm W&B falls back to your default one.
WANDB_ENTITY = ""
# === WandB Logging ===
# Both frequencies are read by train_epoch(); the inline loop in the training
# cell does not use them.
LOG_FREQ = 50  # Log every 50 batches
LOG_IMG_FREQ = 100  # Log images every 100 batches

### WandB Initialization

In [None]:

# Start the W&B run and record every hyperparameter that defines this
# experiment, so runs can be compared and reproduced from the dashboard.
wandb.init(
    project=WANDB_PROJECT,
    # An empty entity string means "not configured": pass None (wandb.init's
    # default) so W&B uses the logged-in account's default entity.
    entity=WANDB_ENTITY or None,
    config={
        "learning_rate": LR,
        "architecture": "MultiFrame-CNN",
        "dataset": "CustomVideoFrames",
        "epochs": NUM_EPOCHS,
        "batch_size": BATCH_SIZE,
        "upscale_factor": UPSCALE_FACTOR,
        # Model / loss / split settings that were previously not logged.
        "in_channels": IN_CHANNELS,
        "out_channels": OUT_CHANNELS,
        "features": FEATURES,
        "num_res_blocks": NUM_RES_BLOCKS,
        "use_perceptual": USE_PERCEPTUAL,
        "data_split_seed": DATA_SPLIT_SEED,
    },
)



# Dataset Class Definition

In [6]:
class VideoSRDataset(Dataset):
    """Multi-frame video super-resolution dataset.

    Each sample is ``(lr_stack, hr_frame)``: three consecutive LR frames
    concatenated along the channel dimension, and the HR frame that matches
    the middle LR frame.
    """

    def __init__(self, hr_dir, lr_dir, seq_ids, transform=None):
        """
        Args:
            hr_dir: Path to HR frames root
            lr_dir: Path to LR frames root
            seq_ids: List of sequence identifiers (format: "video_XXXX_seq_X")
            transform: Transform applied to every LR and HR frame. It must
                return tensors: with ``None`` the raw PIL images reach
                ``torch.cat`` in ``__getitem__``, which fails.
        """
        self.hr_dir = hr_dir
        self.lr_dir = lr_dir
        self.transform = transform
        # One (list of 3 LR frame paths, HR target path) tuple per sample.
        self.samples = []

        # Build samples list
        for seq_id in seq_ids:
            self._add_sequence_samples(seq_id)

    def _add_sequence_samples(self, seq_id):
        """Append the three frame triplets centred on frames 1, 2 and 3 of a sequence.

        Frames are the ``.png`` files of the sequence directory in sorted
        name order. Only the first five frames are used, even when the
        sequence contains more.
        """
        seq_lr_path = os.path.join(self.lr_dir, seq_id)
        seq_hr_path = os.path.join(self.hr_dir, seq_id)

        # Get sorted frame file names
        lr_frames = sorted(f for f in os.listdir(seq_lr_path) if f.endswith(".png"))
        hr_frames = sorted(f for f in os.listdir(seq_hr_path) if f.endswith(".png"))

        assert len(lr_frames) >= 5, f"{seq_id} has only {len(lr_frames)} LR frames!"
        assert len(hr_frames) >= 5, f"{seq_id} has only {len(hr_frames)} HR frames!"

        # Each sequence contains 5 frames: use middle 3 as anchor points
        for i in range(1, 4):  # Center frames 1,2,3 (0-based)
            lr_triplet = [os.path.join(seq_lr_path, lr_frames[j]) for j in (i - 1, i, i + 1)]
            hr_target = os.path.join(seq_hr_path, hr_frames[i])
            self.samples.append((lr_triplet, hr_target))

    def __len__(self):
        return len(self.samples)

    def _load_frame(self, path):
        """Open ``path`` as RGB and apply ``self.transform`` when one is set."""
        # The context manager closes the underlying file handle explicitly
        # instead of leaving that to the garbage collector.
        with Image.open(path) as img:
            rgb = img.convert('RGB')
        return self.transform(rgb) if self.transform else rgb

    def __getitem__(self, idx):
        # Resolve the sample outside the try block: an invalid index should
        # surface as a plain IndexError rather than go through the diagnostics.
        lr_paths, hr_path = self.samples[idx]
        try:
            # Load the three LR frames in temporal order
            lr_stack = []
            for i, path in enumerate(lr_paths):
                if not os.path.exists(path):
                    raise FileNotFoundError(f"LR frame {i} missing: {path}")
                lr_stack.append(self._load_frame(path))

            # Load HR frame
            if not os.path.exists(hr_path):
                raise FileNotFoundError(f"HR frame missing: {hr_path}")
            hr_img = self._load_frame(hr_path)

            return torch.cat(lr_stack, dim=0), hr_img

        except Exception as e:
            print(f"\nError loading sample {idx}:")
            # Sequence = name of the directory holding the frames. os.path keeps
            # this correct on Windows, where joined paths do not contain "/"
            # (splitting on "/" there raised IndexError and hid the real error).
            print(f"Sequence: {os.path.basename(os.path.dirname(lr_paths[0]))}")
            print(f"LR paths: {lr_paths}")
            print(f"HR path: {hr_path}")
            print(f"Error: {str(e)}")
            raise

# Data Splitting & Preparation

In [None]:
# First get all video sequence identifiers. Only visible sub-directories count
# as sequences, so stray files (".DS_Store") or hidden folders
# (".ipynb_checkpoints") cannot turn into bogus video IDs. Sorting makes the
# split independent of the filesystem's listing order.
all_sequences = sorted(
    name
    for name in os.listdir(HR_ROOT)
    if not name.startswith(".") and os.path.isdir(os.path.join(HR_ROOT, name))
)

print("Total sequences in HR_ROOT:", len(all_sequences))
assert len(all_sequences) > 0, "No sequences found in HR_ROOT!"


def _video_id(seq_name):
    """Return the video prefix of a sequence name ("video_0000_seq_0" -> "video_0000")."""
    return "_".join(seq_name.split("_")[:2])


# Group sequences by parent video so every clip of a video lands in the same
# split (near-identical clips must not leak between train and test).
sequences_by_video = {}
for seq_name in all_sequences:
    sequences_by_video.setdefault(_video_id(seq_name), []).append(seq_name)

# Unique video IDs (e.g., "video_0000" from "video_0000_seq_0")
video_ids = sorted(sequences_by_video)

# Split video IDs into train/val/test
random.seed(DATA_SPLIT_SEED)
random.shuffle(video_ids)

num_total = len(video_ids)
num_test = int(num_total * TEST_RATIO)
num_val = int(num_total * VAL_RATIO)
# int() truncation gives an empty val/test split when there are fewer than
# 1 / ratio videos, and the validation average then divides by zero. Keep at
# least one video in each requested split when one is still left for training.
if num_total >= 3:
    if TEST_RATIO > 0:
        num_test = max(num_test, 1)
    if VAL_RATIO > 0:
        num_val = max(num_val, 1)
num_train = num_total - num_test - num_val

train_vids = video_ids[:num_train]
val_vids = video_ids[num_train:num_train+num_val]
test_vids = video_ids[num_train+num_val:]


def get_split_sequences(video_list):
    """Return every sequence name belonging to the given video IDs.

    Matching is by exact video ID. A string-prefix test would let "video_1"
    also claim "video_10_seq_0" and leak that sequence into another split.
    """
    return [seq for vid in video_list for seq in sequences_by_video.get(vid, [])]


train_seqs = get_split_sequences(train_vids)
val_seqs = get_split_sequences(val_vids)
test_seqs = get_split_sequences(test_vids)

print(f"Total sequences - Train: {len(train_seqs)}, Val: {len(val_seqs)}, Test: {len(test_seqs)}")

# After splitting
print("\nSplit Summary:")
print(f"Train Videos: {len(train_vids)} | Val Videos: {len(val_vids)} | Test Videos: {len(test_vids)}")
print(f"Train Seqs: {len(train_seqs)} | Val Seqs: {len(val_seqs)} | Test Seqs: {len(test_seqs)}")

# DataLoader Setup

In [None]:
# Basic transforms (normalize to [-1, 1] range): ToTensor yields [0, 1],
# then (x - 0.5) / 0.5 maps every RGB channel to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

# Create datasets: one per split, all sharing the same per-frame transform.
train_dataset, val_dataset, test_dataset = (
    VideoSRDataset(HR_ROOT, LR_ROOT, split_seqs, transform=transform)
    for split_seqs in (train_seqs, val_seqs, test_seqs)
)

# Create DataLoaders. Only the training loader shuffles; train/val use 8
# workers with deeper prefetching, test uses 2 workers.
_common_loader_kwargs = {"batch_size": BATCH_SIZE, "pin_memory": True}
train_loader = DataLoader(train_dataset, shuffle=True, num_workers=8, prefetch_factor=4, **_common_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, num_workers=8, prefetch_factor=4, **_common_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, num_workers=2, **_common_loader_kwargs)

# Quick sanity check: pull one training batch and show the tensor shapes.
sample_lr, sample_hr = next(iter(train_loader))
print(f"LR input shape: {sample_lr.shape}")  # [B, 9, H, W]: 3 stacked RGB frames
print(f"HR target shape: {sample_hr.shape}") # [B, 3, H_hr, W_hr]; ~4x LR (W_hr may be < 4*W, e.g. 854 vs 856)

# Model Architecture

### Residual Block

In [9]:
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with a ReLU between them plus an identity skip.

    Channel count and spatial size are preserved, so blocks can be stacked.
    """

    def __init__(self, in_channels):
        super().__init__()
        # Submodule names fix the state_dict keys; keep them stable so saved
        # checkpoints stay loadable.
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # out = x + conv2(relu(conv1(x)))
        return x + self.conv2(self.relu(self.conv1(x)))

### Two-Stage SR Model

In [None]:
# Base SR Model (first stage)
class SRModel(nn.Module):
    """First stage: feature extraction -> residual trunk -> sub-pixel upsampling.

    Args:
        in_channels: channels of the stacked LR input (3 frames x RGB = 9).
        out_channels: channels of the super-resolved output.
        features: width of the convolutional trunk.
        num_res_blocks: number of ResidualBlock modules in the trunk.
        upscale_factor: spatial scale applied by PixelShuffle.
    """

    def __init__(self, in_channels=9, out_channels=3, features=64, num_res_blocks=5, upscale_factor=4):
        super().__init__()
        # Attribute names and creation order are kept as-is: they define the
        # state_dict keys and the order of default weight initialisation.
        self.conv1 = nn.Conv2d(in_channels, features, kernel_size=3, padding=1)
        trunk = [ResidualBlock(features) for _ in range(num_res_blocks)]
        self.res_blocks = nn.Sequential(*trunk)
        # PixelShuffle(r) rearranges C*r*r channels into C channels at r-times
        # the spatial resolution.
        shuffle_channels = out_channels * upscale_factor * upscale_factor
        self.conv2 = nn.Conv2d(features, shuffle_channels, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

    def forward(self, x):
        trunk_out = self.res_blocks(self.conv1(x))
        return self.pixel_shuffle(self.conv2(trunk_out))

# New refinement network (second stage) to boost high-frequency details.
class RefinementNet(nn.Module):
    """Second stage: predicts a residual correction and adds it to its input.

    Returning ``x + correction`` means the stage refines the base output
    instead of re-predicting the whole image.
    """

    def __init__(self, channels=3, features=64, num_res_blocks=3):
        super().__init__()
        self.conv_in = nn.Conv2d(channels, features, kernel_size=3, padding=1)
        # A few residual blocks for texture refinement
        self.res_blocks = nn.Sequential(*(ResidualBlock(features) for _ in range(num_res_blocks)))
        self.conv_out = nn.Conv2d(features, channels, kernel_size=3, padding=1)

    def forward(self, x):
        hidden = F.relu(self.conv_in(x))
        correction = self.conv_out(self.res_blocks(hidden))
        return x + correction

# Composite model that feeds the output of the base model into the refinement network.
class TwoStageSRModel(nn.Module):
    """Chain ``base_model`` and ``refinement_model``: refine(base(x))."""

    def __init__(self, base_model, refinement_model):
        super().__init__()
        # Names determine the "base_model." / "refinement_model." key prefixes
        # in the state_dict.
        self.base_model = base_model
        self.refinement_model = refinement_model

    def forward(self, x):
        return self.refinement_model(self.base_model(x))

# Instantiate the base model (first stage) from the hyperparameters cell.
base_model = SRModel(
    in_channels=IN_CHANNELS,        # 3 stacked RGB frames
    out_channels=OUT_CHANNELS,      # RGB output
    features=FEATURES,
    num_res_blocks=NUM_RES_BLOCKS,
    upscale_factor=UPSCALE_FACTOR,
).to(device)

# Optionally, if you already have pretrained weights for the base model,
# load them here. For example:
# checkpoint = torch.load('saved_models/best.pth', map_location=device)
# base_model.load_state_dict(checkpoint['state_dict'])
# NOTE: checkpoints written by save_checkpoint() hold the *two-stage* model's
# state_dict, whose keys are prefixed "base_model." / "refinement_model.";
# strip the "base_model." prefix (and drop the others) before loading here.

# (Optional) Freeze base model parameters if you want to train only the refinement net:
# for param in base_model.parameters():
#     param.requires_grad = False

# Second stage: residual refinement at HR resolution.
refinement_net = RefinementNet(channels=OUT_CHANNELS, features=FEATURES, num_res_blocks=3).to(device)

# Chain both stages into the model that is trained and checkpointed.
model = TwoStageSRModel(base_model, refinement_net).to(device)

# Print model summary (the count covers trainable parameters only).
print(model)
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {num_trainable:,}")

# Loss & Optimizer Setup

In [11]:
# L1 Loss (MAE): mean absolute error averaged over all elements.
l1_loss = nn.L1Loss(reduction='mean').to(device)

# Optional Perceptual Loss (VGG-based)
class VGGLoss(nn.Module):
    """L1 distance between VGG19 feature maps of a prediction and its target.

    Inputs are expected in the [-1, 1] range produced by the dataset transform.
    """

    # torchvision's pretrained VGG19 was trained on ImageNet-normalized RGB,
    # so inputs must be shifted/scaled by these statistics, not just mapped to [0, 1].
    _IMAGENET_MEAN = (0.485, 0.456, 0.406)
    _IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self):
        super().__init__()
        try:
            vgg = torchvision.models.vgg19(weights=torchvision.models.VGG19_Weights.DEFAULT)
        except AttributeError:
            # torchvision < 0.13 has no weights enum; use the legacy flag.
            vgg = torchvision.models.vgg19(pretrained=True)
        # Feature extractor up to (but excluding) the last max-pool.
        vgg = vgg.features[:35].eval()
        for param in vgg.parameters():
            param.requires_grad = False
        self.vgg = vgg.to(device)
        self.loss = nn.L1Loss()

    def forward(self, pred, target):
        vgg_pred = self.vgg(self.normalize_vgg(pred))
        # The target branch never needs gradients.
        vgg_target = self.vgg(self.normalize_vgg(target.detach()))
        return self.loss(vgg_pred, vgg_target)

    def normalize_vgg(self, x):
        """Map a [-1, 1] (N, 3, H, W) tensor into VGG's ImageNet-normalized input space."""
        x = (x + 1) / 2
        # new_tensor keeps the statistics on x's device and dtype.
        mean = x.new_tensor(self._IMAGENET_MEAN).view(1, -1, 1, 1)
        std = x.new_tensor(self._IMAGENET_STD).view(1, -1, 1, 1)
        return (x - mean) / std

# Set USE_PERCEPTUAL = False to disable
# The perceptual weight and module are only active when USE_PERCEPTUAL is set.
if USE_PERCEPTUAL:
    perceptual_weight = 0.2
    perceptual_loss = VGGLoss().to(device)
else:
    perceptual_weight = 0.0
    perceptual_loss = None

# Combined loss
def compute_loss(pred, target):
    """Return L1(pred, target), plus the weighted VGG term when USE_PERCEPTUAL is set."""
    reconstruction = l1_loss(pred, target)
    if not USE_PERCEPTUAL:
        return reconstruction
    return reconstruction + perceptual_weight * perceptual_loss(pred, target)

# Optimizer: Adam over every parameter of the two-stage model.
optimizer = optim.Adam(params=model.parameters(), lr=LR)

# Learning rate scheduler: halve the LR when the monitored value (minimized)
# has not improved for more than `patience` consecutive scheduler steps.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5,
)

# Metrics Setup

In [12]:

# Both metrics operate on images in [0, 1]; fixing data_range=1.0 keeps PSNR
# from estimating the range from each batch's values.
psnr_metric = PeakSignalNoiseRatio(data_range=1.0).to(device)
ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)


def compute_metrics(pred, target):
    """Return ``(psnr, ssim)`` as Python floats for one batch.

    ``pred`` and ``target`` are expected in the [-1, 1] range of the dataset
    transform; both are mapped back to [0, 1] before scoring.
    """
    # Metrics are monitoring only: never build an autograd graph for them.
    with torch.no_grad():
        # Convert from [-1,1] to [0,1]. The network output is unbounded, so
        # clamp it to the range the metrics assume.
        pred_denorm = ((pred + 1) / 2).clamp(0.0, 1.0)
        target_denorm = (target + 1) / 2

        psnr_val = psnr_metric(pred_denorm, target_denorm)
        ssim_val = ssim_metric(pred_denorm, target_denorm)

        # torchmetrics' forward() returns the batch value but also folds it into
        # the metric's running state; reset so that state is not carried (and
        # grown) across every call of the run.
        psnr_metric.reset()
        ssim_metric.reset()

    return psnr_val.item(), ssim_val.item()

In [13]:
def log_predictions(model, epoch, num_samples=3):
    """Log an LR | Pred | HR comparison grid for the first validation batch to W&B.

    Args:
        model: network to visualise. It is switched to eval mode and left there;
            the training loop calls ``model.train()`` at the start of each epoch.
        epoch: 0-based epoch index (shown 1-based in the caption).
        num_samples: number of grid rows, capped at the batch size.
    """
    model.eval()
    with torch.no_grad():
        lr, hr = next(iter(val_loader))
        lr, hr = lr.to(device), hr.to(device)
        # Crop the prediction to the HR size, exactly as training/validation
        # do, so the three panels line up (4x LR can be a few px wider).
        pred = model(lr)[:, :, :hr.size(2), :hr.size(3)]

        # Upsample the LR centre frame (channels 3:6) to the HR size
        lr_upsampled = torch.nn.functional.interpolate(
            lr[:, 3:6],  # Center frame
            size=hr.shape[-2:],  # Target size
            mode='bicubic',
            align_corners=False
        )

        # Denormalize [-1, 1] -> [0, 1]; clamp because bicubic overshoot and
        # the unbounded network output can leave the displayable range.
        lr_vis = ((lr_upsampled + 1) / 2).clamp(0, 1)
        hr_vis = ((hr + 1) / 2).clamp(0, 1)
        pred_vis = ((pred + 1) / 2).clamp(0, 1)

        # A validation set smaller than num_samples would otherwise IndexError.
        count = min(num_samples, hr.size(0))
        comparisons = [
            torch.cat([lr_vis[i], pred_vis[i], hr_vis[i]], dim=-1).cpu()  # along width
            for i in range(count)
        ]

        grid = torchvision.utils.make_grid(comparisons, nrow=1)

        wandb.log({
            "Predictions": wandb.Image(grid, caption=f"Epoch {epoch+1}: LR | Pred | HR")
        })

# Training Loop Setup

In [14]:
def train_epoch(model, loader):
    """Train ``model`` for one pass over ``loader``.

    Uses the module-level ``optimizer``, ``device``, ``compute_loss`` and
    ``compute_metrics``. Every ``LOG_FREQ`` batches it logs batch scalars, and
    every ``LOG_IMG_FREQ`` batches an LR | Pred image, to W&B.

    Returns:
        (mean loss, mean PSNR, mean SSIM), each averaged over batches.
    """
    model.train()
    epoch_loss = 0.0
    psnr_total = 0.0
    ssim_total = 0.0

    for batch_idx, (lr, hr) in enumerate(loader):
        lr = lr.to(device)
        hr = hr.to(device)

        optimizer.zero_grad()

        outputs = model(lr)
        # Crop the upscaled output to the HR frame size (4x LR may be wider).
        outputs = outputs[:, :, :hr.size(2), :hr.size(3)]
        loss = compute_loss(outputs, hr)

        loss.backward()
        # NOTE(review): max_norm=0.01 is a very tight clip; confirm it is intended.
        nn.utils.clip_grad_norm_(model.parameters(), 0.01)
        optimizer.step()

        # Metrics
        psnr, ssim = compute_metrics(outputs.detach(), hr.detach())

        # Batch-level logging
        if batch_idx % LOG_FREQ == 0:
            wandb.log({
                "train_batch_loss": loss.item(),
                "train_batch_psnr": psnr,
                "train_batch_ssim": ssim
            })

        # Image logging
        if batch_idx % LOG_IMG_FREQ == 0:
            with torch.no_grad():
                # The LR centre frame is ~4x smaller than the prediction; it must
                # be upsampled first, because torch.cat along width requires
                # equal heights (the un-resized concat raised on batch 0).
                lr_up = torch.nn.functional.interpolate(
                    lr[:1, 3:6], size=outputs.shape[-2:], mode='bicubic', align_corners=False
                )
                pred_vis = ((outputs[:1].detach() + 1) / 2).clamp(0, 1)
                lr_vis = ((lr_up + 1) / 2).clamp(0, 1)
                # Drop the batch dim so W&B receives a single (C, H, W) image.
                grid = torch.cat([lr_vis, pred_vis], dim=-1)[0].cpu()
                wandb.log({
                    "train_samples": wandb.Image(grid,
                    caption=f"Batch {batch_idx}: LR | Pred")
                })

        epoch_loss += loss.item()
        psnr_total += psnr
        ssim_total += ssim

        # Progress update
        if batch_idx % 50 == 0:
            print(f"  Batch {batch_idx}/{len(loader)} - Loss: {loss.item():.4f}")

    avg_loss = epoch_loss / len(loader)
    avg_psnr = psnr_total / len(loader)
    avg_ssim = ssim_total / len(loader)
    return avg_loss, avg_psnr, avg_ssim

def validate(model, loader):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns:
        (loss, PSNR, SSIM), each a per-sample average: every batch's value is
        weighted by its batch size, so a short final batch is not over-counted.

    Raises:
        ValueError: if ``loader`` yields no batches (e.g. an empty split).
    """
    model.eval()
    total_loss = 0.0
    total_psnr = 0.0
    total_ssim = 0.0
    num_samples = 0

    with torch.no_grad():
        for lr, hr in loader:
            lr = lr.to(device)
            hr = hr.to(device)

            outputs = model(lr)
            # Crop to the HR frame size, as in training.
            outputs = outputs[:, :, :hr.size(2), :hr.size(3)]
            loss = compute_loss(outputs, hr)

            psnr, ssim = compute_metrics(outputs, hr)
            batch_size = lr.size(0)
            total_loss += loss.item() * batch_size
            total_psnr += psnr * batch_size
            total_ssim += ssim * batch_size
            num_samples += batch_size

    if num_samples == 0:
        raise ValueError("validate() received an empty loader; check the data split")

    return total_loss / num_samples, total_psnr / num_samples, total_ssim / num_samples

# Model Checkpoint Saving

In [15]:
# Initialize best validation loss tracker.
# save_checkpoint() reads this module-level value for the stored 'best_metric';
# the training cell resets it to inf again right before its loop.
best_val_loss = float('inf')

In [16]:
def save_checkpoint(epoch, model, optimizer, is_best=False, scheduler=None):
    """Write checkpoints for ``epoch`` into ``SAVE_DIR``.

    Always writes ``latest.pth``. It additionally writes ``epoch_{epoch}.pth``
    when ``epoch % SAVE_FREQ == 0``, and ``best.pth`` when ``is_best``. The
    stored ``best_metric`` is the module-level ``best_val_loss``.

    Every file is written under a temporary name and then renamed into place,
    so an interrupt mid-save (e.g. Ctrl-C) cannot leave a truncated checkpoint.

    Args:
        epoch: 0-based epoch index.
        model, optimizer: objects whose ``state_dict()`` is stored.
        is_best: also publish this state as ``best.pth``.
        scheduler: optional LR scheduler; its state is stored under
            ``'scheduler'`` so a resumed run keeps the same LR schedule.
    """
    import shutil

    state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'best_metric': best_val_loss
    }
    if scheduler is not None:
        state['scheduler'] = scheduler.state_dict()

    # Ensure save directory exists
    os.makedirs(SAVE_DIR, exist_ok=True)

    # Always save latest. Serialize once; other checkpoints are file copies.
    latest_path = os.path.join(SAVE_DIR, "latest.pth")
    tmp_latest = latest_path + ".tmp"
    torch.save(state, tmp_latest)
    os.replace(tmp_latest, latest_path)

    def _publish(dst_path):
        """Atomically copy latest.pth to ``dst_path``."""
        tmp_path = dst_path + ".tmp"
        shutil.copyfile(latest_path, tmp_path)
        os.replace(tmp_path, dst_path)

    # Save periodically
    if epoch % SAVE_FREQ == 0:
        _publish(os.path.join(SAVE_DIR, f"epoch_{epoch}.pth"))

    # Save best separately
    if is_best:
        _publish(os.path.join(SAVE_DIR, "best.pth"))
        print(f"New best model saved at epoch {epoch} with val loss: {best_val_loss:.4f}")

# Training Execution

In [None]:
# Helper to format time
def format_time(seconds):
    """Format a duration in seconds as "Ns", "Mm Ss" or "Hh Mm".

    The value is rounded to whole seconds *before* being split into fields,
    so e.g. 119.7 s renders as "2m 0s" rather than "1m 60s".
    """
    total = int(round(seconds))
    if total < 60:
        return f"{total}s"
    if total < 3600:
        minutes, secs = divmod(total, 60)
        return f"{minutes}m {secs}s"
    hours, rem = divmod(total, 3600)
    return f"{hours}h {rem // 60}m"

best_val_loss = float('inf')
epoch_times = []
start_time = time.time()
# Bound before the try so the KeyboardInterrupt handler and `finally` never
# hit a NameError when the interrupt arrives before the first epoch starts.
epoch = 0
pbar = None
interrupted = False

try:
    # Initialize main progress bar
    pbar = tqdm(range(NUM_EPOCHS), desc="Training", unit="epoch")

    for epoch in pbar:
        epoch_start = time.time()

        # --- Training Phase ---
        model.train()
        train_loss = 0.0
        train_psnr = 0.0
        train_ssim = 0.0

        # Batch progress bar
        batch_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", leave=False)
        for batch_idx, (lr, hr) in enumerate(batch_pbar):
            lr, hr = lr.to(device), hr.to(device)

            optimizer.zero_grad()
            outputs = model(lr)
            # Crop the upscaled output to the HR frame size (4x LR may be wider).
            outputs = outputs[:, :, :hr.size(2), :hr.size(3)]
            loss = compute_loss(outputs, hr)
            loss.backward()
            optimizer.step()

            # Update metrics on a detached copy: they are monitoring only and
            # must not extend the autograd graph.
            psnr, ssim = compute_metrics(outputs.detach(), hr)
            train_loss += loss.item()
            train_psnr += psnr
            train_ssim += ssim

            # Update batch progress
            avg_loss = train_loss / (batch_idx + 1)
            batch_pbar.set_postfix({
                'loss': f"{avg_loss:.4f}",
                'psnr': f"{train_psnr/(batch_idx+1):.2f}"
            })

        batch_pbar.close()

        # --- Validation Phase ---
        val_loss, val_psnr, val_ssim = validate(model, val_loader)
        scheduler.step(val_loss)

        # --- Epoch Timing ---
        epoch_time = time.time() - epoch_start
        epoch_times.append(epoch_time)
        avg_epoch_time = np.mean(epoch_times[-5:])  # Moving average of last 5 epochs
        remaining_time = avg_epoch_time * (NUM_EPOCHS - epoch - 1)

        # Update main progress bar
        pbar.set_postfix({
            'val_loss': f"{val_loss:.4f}",
            'val_psnr': f"{val_psnr:.2f}",
            'eta': format_time(remaining_time)
        })

        # --- Model Checkpointing ---
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_checkpoint(epoch, model, optimizer, is_best=True)
        else:
            save_checkpoint(epoch, model, optimizer)

        # --- Logging ---
        wandb.log({
            "train/loss": train_loss/len(train_loader),
            "train/psnr": train_psnr/len(train_loader),
            "train/ssim": train_ssim/len(train_loader),
            "val/loss": val_loss,
            "val/psnr": val_psnr,
            "val/ssim": val_ssim,
            "epoch_time": epoch_time,
            "lr": optimizer.param_groups[0]['lr']
        })

        # --- Visualizations ---
        if (epoch + 1) % 2 == 0:  # Every 2 epochs to reduce overhead
            log_predictions(model, epoch)

except KeyboardInterrupt:
    interrupted = True
    print("\nTraining interrupted! Saving latest model...")
    save_checkpoint(epoch, model, optimizer)
    print(f"Safe to exit now. Total runtime: {format_time(time.time() - start_time)}")
    # NOTE(review): the W&B run is closed here, yet later cells still call
    # wandb.log(); re-run wandb.init() before evaluating an interrupted run.
    wandb.finish()
finally:
    if pbar is not None:
        pbar.close()

if interrupted:
    print(f"\nTraining stopped early. Total duration: {format_time(time.time() - start_time)}")
else:
    print(f"\nTraining complete! Total duration: {format_time(time.time() - start_time)}")

# Testing & Final Evaluation

In [None]:
def load_best_model():
    """Load ``best.pth`` into the global ``model`` and ``optimizer``; return the model.

    ``map_location=device`` lets a checkpoint written on a GPU machine be
    restored on a CPU-only one (and vice versa).
    """
    checkpoint = torch.load(os.path.join(SAVE_DIR, "best.pth"), map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return model


# Load best model
model = load_best_model().to(device)

# Test evaluation on the held-out split
test_loss, test_psnr, test_ssim = validate(model, test_loader)

print("\nFinal Test Results:")
print(f"Loss: {test_loss:.4f}")
print(f"PSNR: {test_psnr:.2f} dB")
print(f"SSIM: {test_ssim:.4f}")

# Log to WandB
wandb.log({
    "test/loss": test_loss,
    "test/psnr": test_psnr,
    "test/ssim": test_ssim
})

# Save sample visualizations
def save_samples(loader, num_samples=3):
    """Log an LR | Pred | HR comparison grid from the first batch of ``loader`` to W&B.

    Uses the global ``model`` (switched to eval mode). ``num_samples`` is
    capped at the batch size; an empty loader is reported and skipped.
    """
    model.eval()
    with torch.no_grad():
        batch = next(iter(loader), None)  # Only first batch
        if batch is None:
            print("save_samples: loader is empty, nothing to log")
            return
        lr, hr = batch
        lr, hr = lr.to(device), hr.to(device)
        # Crop to the HR size, as in training/validation, so panels line up.
        pred = model(lr)[:, :, :hr.size(2), :hr.size(3)]

        # Upscale LR centre frame (channels 3:6) to match HR dimensions
        lr_upsampled = torch.nn.functional.interpolate(
            lr[:, 3:6],  # Center frame
            size=hr.shape[-2:],  # Target size (480x854)
            mode='bicubic',
            align_corners=False
        )

        # Denormalize [-1, 1] -> [0, 1]; clamp bicubic overshoot and unbounded outputs.
        lr_vis = ((lr_upsampled + 1) / 2).clamp(0, 1)
        pred_vis = ((pred + 1) / 2).clamp(0, 1)
        hr_vis = ((hr + 1) / 2).clamp(0, 1)

        # Create comparison samples (concatenated along width)
        count = min(num_samples, hr.size(0))
        samples = [
            torch.cat([lr_vis[i], pred_vis[i], hr_vis[i]], dim=-1).cpu()
            for i in range(count)
        ]

        grid = torchvision.utils.make_grid(samples, nrow=1)
        wandb.log({"Test Results": wandb.Image(grid, caption="LR | Pred | HR")})


save_samples(test_loader)

# Visualization & Analysis

In [None]:
# Plot training curves
# NOTE(review): every plotting call below is commented out, so this cell
# currently shows an empty 15×5 figure. The commented code calls
# wandb.run.history(); that accessor may not exist on the live run object
# (W&B's public API exposes history via wandb.Api().run(path).history()) —
# verify before re-enabling.
plt.figure(figsize=(15, 5))

# # Loss
# plt.subplot(1, 3, 1)
# plt.plot(wandb.run.history()['train/loss'], label='Train')
# plt.plot(wandb.run.history()['val/loss'], label='Validation')
# plt.title('Loss Curve')
# plt.xlabel('Epoch')
# plt.legend()

# # PSNR
# plt.subplot(1, 3, 2)
# plt.plot(wandb.run.history()['train/psnr'], label='Train')
# plt.plot(wandb.run.history()['val/psnr'], label='Validation')
# plt.title('PSNR')
# plt.xlabel('Epoch')

# # SSIM
# plt.subplot(1, 3, 3)
# plt.plot(wandb.run.history()['train/ssim'], label='Train')
# plt.plot(wandb.run.history()['val/ssim'], label='Validation')
# plt.title('SSIM')
# plt.xlabel('Epoch')

plt.tight_layout()
plt.show()

# Additional analysis
# Close the W&B run opened in the initialization cell.
wandb.finish()
print("Training completed and results logged!")