In [None]:
# First, uninstall existing PyTorch
!pip uninstall -y torch torchvision torchaudio

# Install PyTorch with CUDA 12.1 support (for RTX 4060)
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

# Verify installation
import torch
print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory/1024**3:.2f} GB")

In [None]:
!pip uninstall sympy -y
!pip install sympy


In [None]:
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121


In [None]:
import torch

print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")



In [None]:
!pip install gradio


In [None]:
!pip install kornia


In [1]:
# %% [markdown]
# # Image Sharpening using Knowledge Distillation (GPU Optimized)

# %%
# Initial setup for Windows
import os
import pickle
import time
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import functional as TF
from torch.amp import GradScaler
from torchvision.models import vgg16, VGG16_Weights
from piq import ssim
import gdown
import math
from einops import rearrange

# Environment setup
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory/1024**3:.2f} GB")
    torch.backends.cudnn.benchmark = True

# %% [markdown]
## Configuration (Optimized for RTX 4060)

# %%
class Config:
    # ===== UPDATE THESE PATHS =====
    TRAIN_BLUR = r"C:\Users\Lenovo\Downloads\image_sharpening_kd_project (2)\image_sharpening_kd_project\data\Gopro\train\blur"
    TRAIN_SHARP_PATH = r"C:\Users\Lenovo\Downloads\image_sharpening_kd_project (2)\image_sharpening_kd_project\data\Gopro\train\sharp"
    TEST_BLUR_PATH = r"C:\Users\Lenovo\Downloads\image_sharpening_kd_project (2)\image_sharpening_kd_project\data\Gopro\test\blur"
    TEST_SHARP_PATH = r"C:\Users\Lenovo\Downloads\image_sharpening_kd_project (2)\image_sharpening_kd_project\data\Gopro\test\sharp"
    # ==============================
    
    # Training parameters
    BATCH_SIZE = 8  # Increased for GPU
    PATCH_SIZE = 256
    NUM_EPOCHS = 100
    LR = 1e-4
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    NUM_WORKERS = 0  # Windows often has issues with multiprocessing
    
    # Model paths
    TEACHER_WEIGHTS = "motion_deblurring.pth"
    STUDENT_SAVE_PATH = "student_model.pth"
    
    # Preloading cache
    TRAIN_CACHE_PATH = "train_dataset_cache.pkl"
    TEST_CACHE_PATH = "test_dataset_cache.pkl"
    
    # Evaluation
    BENCHMARK_SIZE = 100
    TARGET_RES = (1920, 1080)

config = Config()

# Verify paths
print("\nPath verification:")
for path_type in ['TRAIN_BLUR', 'TRAIN_SHARP', 'TEST_BLUR', 'TEST_SHARP']:
    path = getattr(config, f"{path_type}_PATH")
    exists = os.path.exists(path)
    print(f"{path_type}: {'✅' if exists else '❌'} {path}")
    if exists:
        print(f"   Contains {len(os.listdir(path))} files")

# %% [markdown]
## Fixed Restormer Implementation (WithBias Version)

# %%
# Restormer components with WithBias LayerNorm
class LayerNorm(nn.Module):
    def __init__(self, dim, LayerNorm_type='WithBias'):
        super().__init__()
        if LayerNorm_type == 'BiasFree':
            self.body = BiasFree_LayerNorm(dim)
        else:
            self.body = WithBias_LayerNorm(dim)

    def forward(self, x):
        return self.body(x)

class BiasFree_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super().__init__()
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape
        
    def forward(self, x):
        mu = x.mean(1, keepdim=True)
        sigma = x.var(1, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight[None, :, None, None]

class WithBias_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super().__init__()
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape
        
    def forward(self, x):
        mu = x.mean(1, keepdim=True)
        sigma = x.var(1, keepdim=True, unbiased=False)
        x_normalized = (x - mu) / torch.sqrt(sigma + 1e-5)
        return x_normalized * self.weight[None, :, None, None] + self.bias[None, :, None, None]

class FeedForward(nn.Module):
    def __init__(self, dim, ffn_expansion_factor, bias):
        super().__init__()
        hidden_features = int(dim*ffn_expansion_factor)
        self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, padding=1, groups=hidden_features*2, bias=bias)
        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        x = F.gelu(x1) * x2
        x = self.project_out(x)
        return x

class Attention(nn.Module):
    def __init__(self, dim, num_heads, bias):
        super().__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, padding=1, groups=dim*3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        
    def forward(self, x):
        b,c,h,w = x.shape
        qkv = self.qkv_dwconv(self.qkv(x))
        q,k,v = qkv.chunk(3, dim=1)   
        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)
        out = (attn @ v)
        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
        out = self.project_out(out)
        return out

class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
        super().__init__()
        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x

class OverlapPatchEmbed(nn.Module):
    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super().__init__()
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, padding=1, bias=bias)

    def forward(self, x):
        return self.proj(x)

class Downsample(nn.Module):
    def __init__(self, n_feat):
        super().__init__()
        self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat//2, kernel_size=3, padding=1, bias=False),
                                  nn.PixelUnshuffle(2))

    def forward(self, x):
        return self.body(x)

class Upsample(nn.Module):
    def __init__(self, n_feat):
        super().__init__()
        self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat*2, kernel_size=3, padding=1, bias=False),
                                  nn.PixelShuffle(2))

    def forward(self, x):
        return self.body(x)

class Restormer(nn.Module):
    def __init__(self, 
        inp_channels=3, 
        out_channels=3, 
        dim=48,
        num_blocks=[4,6,6,8], 
        num_refinement_blocks=4,
        heads=[1,2,4,8],
        ffn_expansion_factor=2.66,
        bias=False,
        LayerNorm_type='WithBias'
    ):
        super().__init__()
        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
        self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
        self.down1_2 = Downsample(dim)
        self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
        self.down2_3 = Downsample(int(dim*2**1))
        self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
        self.down3_4 = Downsample(int(dim*2**2))
        self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])])
        self.up4_3 = Upsample(int(dim*2**3))
        self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias)
        self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
        self.up3_2 = Upsample(int(dim*2**2))
        self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias)
        self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
        self.up2_1 = Upsample(int(dim*2**1))
        self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
        self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)])
        self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, padding=1, bias=bias)

    def forward(self, inp_img):
        inp_enc_level1 = self.patch_embed(inp_img)
        out_enc_level1 = self.encoder_level1(inp_enc_level1)
        inp_enc_level2 = self.down1_2(out_enc_level1)
        out_enc_level2 = self.encoder_level2(inp_enc_level2)
        inp_enc_level3 = self.down2_3(out_enc_level2)
        out_enc_level3 = self.encoder_level3(inp_enc_level3)
        inp_enc_level4 = self.down3_4(out_enc_level3)        
        latent = self.latent(inp_enc_level4)
        inp_dec_level3 = self.up4_3(latent)
        inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
        inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
        out_dec_level3 = self.decoder_level3(inp_dec_level3)
        inp_dec_level2 = self.up3_2(out_dec_level3)
        inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
        inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
        out_dec_level2 = self.decoder_level2(inp_dec_level2)
        inp_dec_level1 = self.up2_1(out_dec_level2)
        inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
        out_dec_level1 = self.decoder_level1(inp_dec_level1)
        out_dec_level1 = self.refinement(out_dec_level1)
        out_dec_level1 = self.output(out_dec_level1) + inp_img
        return out_dec_level1

# Security registration for model loading
try:
    torch.serialization.add_safe_globals([Restormer])
    print("Added Restormer to safe globals for secure model loading")
except AttributeError:
    print("Warning: torch.serialization.add_safe_globals not available in this PyTorch version")

# %% [markdown]
## Download Teacher Weights

# %%
if not os.path.exists(config.TEACHER_WEIGHTS):
    print("Downloading teacher weights...")
    gdown.download(
        "https://github.com/swz30/Restormer/releases/download/v1.0/motion_deblurring.pth", 
        config.TEACHER_WEIGHTS, 
        quiet=False
    )
    print("Download complete!")

# %% [markdown]
## Model Definitions (GPU Optimized)

# %%
# Lightweight Student Model with memory optimizations
class StudentModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Encoder with inplace ReLU to save memory
        self.enc1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.pool1 = nn.MaxPool2d(2)
        
        self.enc2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.pool2 = nn.MaxPool2d(2)
        
        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(inplace=True)
        )
        
        # Decoder
        self.up2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec2 = nn.Sequential(
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True)
        )
        
        self.up1 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.dec1 = nn.Sequential(
            nn.Conv2d(64, 32, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(inplace=True)
        )
        
        self.output = nn.Conv2d(32, 3, 1)
    
    def forward(self, x):
        e1 = self.enc1(x)
        p1 = self.pool1(e1)
        e2 = self.enc2(p1)
        p2 = self.pool2(e2)
        b = self.bottleneck(p2)
        d2 = self.dec2(torch.cat([self.up2(b), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return self.output(d1)

# Teacher Model
class TeacherModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = Restormer()
        
        # Load weights with explicit weights_only=False
        state_dict = torch.load(config.TEACHER_WEIGHTS, map_location=config.DEVICE, weights_only=False)
        self.model.load_state_dict(state_dict['params'])
        self.model.eval()
    
    def forward(self, x):
        with torch.no_grad():
            return self.model(x)

# Initialize models with GPU optimization
teacher = TeacherModel().to(config.DEVICE)
student = StudentModel().to(config.DEVICE)

# Freeze teacher model
for param in teacher.parameters():
    param.requires_grad = False

print(f"Teacher parameters: {sum(p.numel() for p in teacher.parameters())/1e6:.2f}M")
print(f"Student parameters: {sum(p.numel() for p in student.parameters())/1e6:.2f}M")

# Test models on GPU
if torch.cuda.is_available():
    with torch.no_grad():
        test_input = torch.randn(1, 3, 256, 256).to(config.DEVICE)
        output = teacher(test_input)
        print(f"Teacher test output shape: {output.shape}")
        output = student(test_input)
        print(f"Student test output shape: {output.shape}")

# %% [markdown]
## Dataset Preparation with Preloading Cache

# %%
class GoProDataset(Dataset):
    def __init__(self, blur_dir, sharp_dir, patch_size=256, train=True, cache_path=None):
        self.blur_dir = blur_dir
        self.sharp_dir = sharp_dir
        self.patch_size = patch_size
        self.train = train
        self.cache_path = cache_path
        
        # Try to load from cache
        if cache_path and os.path.exists(cache_path):
            print(f"Loading preloaded dataset from cache: {cache_path}")
            with open(cache_path, 'rb') as f:
                self.blur_images, self.sharp_images = pickle.load(f)
        else:
            self.blur_files = sorted(os.listdir(blur_dir))
            self.sharp_files = sorted(os.listdir(sharp_dir))
            
            assert len(self.blur_files) == len(self.sharp_files), "Mismatched dataset sizes"
            
            # Preload images to RAM
            print(f"Preloading {len(self.blur_files)} images...")
            self.blur_images = []
            self.sharp_images = []
            for i in tqdm(range(len(self.blur_files))):
                self.blur_images.append(Image.open(os.path.join(blur_dir, self.blur_files[i])).convert('RGB'))
                self.sharp_images.append(Image.open(os.path.join(sharp_dir, self.sharp_files[i])).convert('RGB'))
            print("Preloading complete!")
            
            # Save to cache if path provided
            if cache_path:
                print(f"Saving preloaded dataset to cache: {cache_path}")
                with open(cache_path, 'wb') as f:
                    pickle.dump((self.blur_images, self.sharp_images), f)

    def __len__(self):
        return len(self.blur_images)

    def __getitem__(self, idx):
        blur_img = self.blur_images[idx]
        sharp_img = self.sharp_images[idx]
        
        if self.train:
            w, h = blur_img.size
            x = random.randint(0, w - self.patch_size)
            y = random.randint(0, h - self.patch_size)
            blur_img = blur_img.crop((x, y, x+self.patch_size, y+self.patch_size))
            sharp_img = sharp_img.crop((x, y, x+self.patch_size, y+self.patch_size))
            
            if random.random() > 0.5:
                blur_img = blur_img.transpose(Image.FLIP_LEFT_RIGHT)
                sharp_img = sharp_img.transpose(Image.FLIP_LEFT_RIGHT)
        else:
            blur_img = blur_img.resize(config.TARGET_RES)
            sharp_img = sharp_img.resize(config.TARGET_RES)
        
        # Convert to tensor
        blur_tensor = TF.to_tensor(blur_img)
        sharp_tensor = TF.to_tensor(sharp_img)
        
        return blur_tensor, sharp_tensor

# Create datasets with caching
train_dataset = GoProDataset(
    config.TRAIN_BLUR_PATH,
    config.TRAIN_SHARP_PATH,
    config.PATCH_SIZE,
    train=True,
    cache_path=config.TRAIN_CACHE_PATH
)

test_dataset = GoProDataset(
    config.TEST_BLUR_PATH,
    config.TEST_SHARP_PATH,
    config.PATCH_SIZE,
    train=False,
    cache_path=config.TEST_CACHE_PATH
)

# DataLoaders with GPU optimizations
train_loader = DataLoader(
    train_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=True,
    num_workers=config.NUM_WORKERS,
    pin_memory=True if torch.cuda.is_available() else False,
    persistent_workers=True if config.NUM_WORKERS > 0 else False
)

test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=config.NUM_WORKERS,
    pin_memory=True if torch.cuda.is_available() else False
)

# Save sample images instead of showing to prevent blocking
os.makedirs("debug", exist_ok=True)
blur, sharp = next(iter(train_loader))
print(f"Batch shape: {blur.shape}")

def save_tensor_image(tensor, path):
    img = tensor.cpu().permute(1, 2, 0).numpy()
    img = (img * 255).clip(0, 255).astype(np.uint8)
    Image.fromarray(img).save(path)

save_tensor_image(blur[0], "debug/blurry_sample.jpg")
save_tensor_image(sharp[0], "debug/sharp_sample.jpg")
print("Saved sample images to debug/ folder!")

# %% [markdown]
## Training Setup (GPU Optimized)

# %%
class DistillationLoss(nn.Module):
    def __init__(self, alpha=0.5, beta=0.3, gamma=0.2):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        
        # Use modern weights API
        vgg = vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features[:16]
        for param in vgg.parameters():
            param.requires_grad = False
        self.vgg = vgg.eval().to(config.DEVICE)
        
        self.l1_loss = nn.L1Loss()
        
    def forward(self, student_out, teacher_out, target):
        rec_loss = self.l1_loss(student_out, target)
        s_features = self.vgg(student_out)
        with torch.no_grad():
            t_features = self.vgg(teacher_out)
        perc_loss = self.l1_loss(s_features, t_features)
        feat_loss = self.l1_loss(student_out, teacher_out)
        return (self.alpha * rec_loss + 
                self.beta * perc_loss + 
                self.gamma * feat_loss)

criterion = DistillationLoss().to(config.DEVICE)
optimizer = torch.optim.Adam(student.parameters(), lr=config.LR)

# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=5
)

# Mixed precision scaler (updated API)
scaler = torch.amp.GradScaler()

# Training checkpoint
def save_checkpoint(epoch, history):
    checkpoint = {
        'epoch': epoch,
        'student_state_dict': student.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'history': history,
        'best_ssim': best_ssim
    }
    torch.save(checkpoint, f"checkpoint_epoch_{epoch}.pth")
    print(f"Saved checkpoint at epoch {epoch}")

# %% [markdown]
## Training Loop (GPU Optimized)

# %%
def validate():
    student.eval()
    total_ssim = 0.0
    count = 0
    
    with torch.no_grad():
        for blur, sharp in test_loader:
            blur = blur.to(config.DEVICE, non_blocking=True)
            sharp = sharp.to(config.DEVICE, non_blocking=True)
            
            # Run in full precision for validation
            output = student(blur)
            
            # Clamp output to [0, 1] range for SSIM calculation
            output_clamped = output.clamp(0, 1)
            
            # Convert to float32 for SSIM calculation
            total_ssim += ssim(
                output_clamped.float(), 
                sharp.float(), 
                data_range=1.0
            ).item()
            count += 1
            
            if count >= config.BENCHMARK_SIZE:
                break
    
    return total_ssim / count

# Training with GPU optimizations
best_ssim = 0.0
history = {'loss': [], 'ssim': []}

# Try to load checkpoint
start_epoch = 0
checkpoint_files = [f for f in os.listdir() if f.startswith('checkpoint_epoch_')]
if checkpoint_files:
    latest_checkpoint = max(checkpoint_files, key=os.path.getctime)
    print(f"Resuming from checkpoint: {latest_checkpoint}")
    checkpoint = torch.load(latest_checkpoint, map_location=config.DEVICE)
    student.load_state_dict(checkpoint['student_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    history = checkpoint['history']
    best_ssim = checkpoint['best_ssim']
    start_epoch = checkpoint['epoch'] + 1
    print(f"Resuming training from epoch {start_epoch}")

print("Starting training...")
for epoch in range(start_epoch, config.NUM_EPOCHS):
    student.train()
    epoch_loss = 0.0
    steps = 0
    
    # Optimized progress bar
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.NUM_EPOCHS}", total=len(train_loader))
    for blur, sharp in pbar:
        # Async data transfer
        blur = blur.to(config.DEVICE, non_blocking=True)
        sharp = sharp.to(config.DEVICE, non_blocking=True)
        
        with torch.no_grad():
            # Updated autocast API
            with torch.amp.autocast(device_type=config.DEVICE.type, dtype=torch.float16):
                teacher_out = teacher(blur)
        
        # Mixed precision training
        optimizer.zero_grad(set_to_none=True)
        
        # Updated autocast API
        with torch.amp.autocast(device_type=config.DEVICE.type, dtype=torch.float16):
            student_out = student(blur)
            loss = criterion(student_out, teacher_out, sharp)
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        epoch_loss += loss.item()
        steps += 1
        
        # Update progress bar
        pbar.set_postfix(loss=loss.item())
    
    # Validation and checkpointing
    avg_ssim = validate()
    epoch_loss /= len(train_loader)
    scheduler.step(avg_ssim)
    
    history['loss'].append(epoch_loss)
    history['ssim'].append(avg_ssim)
    
    if avg_ssim > best_ssim:
        best_ssim = avg_ssim
        torch.save(student.state_dict(), config.STUDENT_SAVE_PATH)
        print(f"Saved best model (SSIM: {best_ssim:.4f})")
    
    print(f"Epoch {epoch+1}/{config.NUM_EPOCHS} | Loss: {epoch_loss:.4f} | SSIM: {avg_ssim:.4f}")
    
    # Save checkpoint every epoch
    save_checkpoint(epoch, history)
    
    # Manual GPU utilization report
    if torch.cuda.is_available():
        mem_used = torch.cuda.memory_allocated() / 1024**3
        mem_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
        print(f"GPU Memory: {mem_used:.2f}/{mem_total:.2f} GB ({mem_used/mem_total:.1%})")

# Final model save
torch.save(student.state_dict(), "final_student_model.pth")

# %% [markdown]
## Evaluation & Results (GPU Optimized)

# %%
# Load best model
student.load_state_dict(torch.load(config.STUDENT_SAVE_PATH))
student.eval()

# Test SSIM
def test_ssim():
    total_ssim = 0.0
    count = 0
    
    with torch.no_grad():
        for blur, sharp in test_loader:
            blur = blur.to(config.DEVICE)
            sharp = sharp.to(config.DEVICE)
            
            output = student(blur)
            output_clamped = output.clamp(0, 1)
            
            # Convert to float32 for SSIM
            total_ssim += ssim(
                output_clamped.float(), 
                sharp.float(), 
                data_range=1.0
            ).item()
            count += 1
            
            if count >= config.BENCHMARK_SIZE:
                break
    
    return total_ssim / count

# Test speed
def test_fps():
    student.eval()
    dummy_input = torch.randn(1, 3, *config.TARGET_RES).to(config.DEVICE)
    
    # Warmup
    for _ in range(10):
        _ = student(dummy_input)
    
    # Benchmark
    start = time.time()
    for _ in range(100):
        _ = student(dummy_input)
    torch.cuda.synchronize()  # Wait for GPU to finish
    elapsed = time.time() - start
    
    return 100 / elapsed

# Run evaluation
avg_ssim = test_ssim()
fps = test_fps()

print(f"\n{'='*40}")
print(f"Final Evaluation:")
print(f"{'='*40}")
print(f"SSIM: {avg_ssim:.4f}")
print(f"FPS: {fps:.2f} @ {config.TARGET_RES}")
print(f"{'='*40}")

# Save visual comparison
blur, sharp = next(iter(test_loader))
with torch.no_grad():
    output = student(blur.to(config.DEVICE)).cpu()

def save_tensor_image(tensor, path):
    img = tensor.permute(1, 2, 0).numpy()
    img = (img * 255).clip(0, 255).astype(np.uint8)
    Image.fromarray(img).save(path)

save_tensor_image(blur[0], "debug/input_blurry.jpg")
save_tensor_image(output[0], "debug/output_sharpened.jpg")
save_tensor_image(sharp[0], "debug/ground_truth.jpg")
print("Saved comparison images to debug/ folder!")

PyTorch version: 2.5.1+cu121
CUDA available: True
GPU: NVIDIA GeForce RTX 4060 Laptop GPU
VRAM: 8.00 GB

Path verification:
TRAIN_BLUR: ✅ C:\Users\Lenovo\Downloads\image_sharpening_kd_project (2)\image_sharpening_kd_project\data\Gopro\train\blur
   Contains 2103 files
TRAIN_SHARP: ✅ C:\Users\Lenovo\Downloads\image_sharpening_kd_project (2)\image_sharpening_kd_project\data\Gopro\train\sharp
   Contains 2103 files
TEST_BLUR: ✅ C:\Users\Lenovo\Downloads\image_sharpening_kd_project (2)\image_sharpening_kd_project\data\Gopro\test\blur
   Contains 1111 files
TEST_SHARP: ✅ C:\Users\Lenovo\Downloads\image_sharpening_kd_project (2)\image_sharpening_kd_project\data\Gopro\test\sharp
   Contains 1111 files
Added Restormer to safe globals for secure model loading
Teacher parameters: 26.13M
Student parameters: 0.47M
Teacher test output shape: torch.Size([1, 3, 256, 256])
Student test output shape: torch.Size([1, 3, 256, 256])
Loading preloaded dataset from cache: train_dataset_cache.pkl
Loading prel

Epoch 1/100: 100%|███████████████████████████████████████████████████████| 263/263 [06:38<00:00,  1.51s/it, loss=0.106]


Saved best model (SSIM: 0.7461)
Epoch 1/100 | Loss: 0.1889 | SSIM: 0.7461
Saved checkpoint at epoch 0
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 2/100: 100%|██████████████████████████████████████████████████████| 263/263 [05:56<00:00,  1.35s/it, loss=0.0978]


Saved best model (SSIM: 0.7571)
Epoch 2/100 | Loss: 0.1068 | SSIM: 0.7571
Saved checkpoint at epoch 1
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 3/100: 100%|█████████████████████████████████████████████████████████| 263/263 [05:02<00:00,  1.15s/it, loss=0.1]


Saved best model (SSIM: 0.7579)
Epoch 3/100 | Loss: 0.0987 | SSIM: 0.7579
Saved checkpoint at epoch 2
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 4/100: 100%|██████████████████████████████████████████████████████| 263/263 [05:06<00:00,  1.17s/it, loss=0.0734]


Saved best model (SSIM: 0.7584)
Epoch 4/100 | Loss: 0.0965 | SSIM: 0.7584
Saved checkpoint at epoch 3
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 5/100: 100%|███████████████████████████████████████████████████████| 263/263 [04:52<00:00,  1.11s/it, loss=0.111]


Saved best model (SSIM: 0.7597)
Epoch 5/100 | Loss: 0.0947 | SSIM: 0.7597
Saved checkpoint at epoch 4
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 6/100: 100%|██████████████████████████████████████████████████████| 263/263 [04:52<00:00,  1.11s/it, loss=0.0781]


Epoch 6/100 | Loss: 0.0973 | SSIM: 0.7596
Saved checkpoint at epoch 5
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 7/100: 100%|██████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.10s/it, loss=0.0714]


Saved best model (SSIM: 0.7600)
Epoch 7/100 | Loss: 0.0953 | SSIM: 0.7600
Saved checkpoint at epoch 6
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 8/100: 100%|███████████████████████████████████████████████████████| 263/263 [04:52<00:00,  1.11s/it, loss=0.123]


Epoch 8/100 | Loss: 0.0949 | SSIM: 0.7595
Saved checkpoint at epoch 7
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 9/100: 100%|██████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.10s/it, loss=0.0767]


Epoch 9/100 | Loss: 0.0946 | SSIM: 0.7598
Saved checkpoint at epoch 8
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 10/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:59<00:00,  1.14s/it, loss=0.0698]


Epoch 10/100 | Loss: 0.0944 | SSIM: 0.7597
Saved checkpoint at epoch 9
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 11/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.10s/it, loss=0.0928]


Saved best model (SSIM: 0.7611)
Epoch 11/100 | Loss: 0.0951 | SSIM: 0.7611
Saved checkpoint at epoch 10
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 12/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.11s/it, loss=0.0682]


Epoch 12/100 | Loss: 0.0932 | SSIM: 0.7604
Saved checkpoint at epoch 11
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 13/100: 100%|██████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.11s/it, loss=0.132]


Epoch 13/100 | Loss: 0.0936 | SSIM: 0.7594
Saved checkpoint at epoch 12
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 14/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.10s/it, loss=0.0821]


Saved best model (SSIM: 0.7617)
Epoch 14/100 | Loss: 0.0955 | SSIM: 0.7617
Saved checkpoint at epoch 13
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 15/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.10s/it, loss=0.0623]


Epoch 15/100 | Loss: 0.0939 | SSIM: 0.7608
Saved checkpoint at epoch 14
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 16/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.11s/it, loss=0.0809]


Epoch 16/100 | Loss: 0.0931 | SSIM: 0.7612
Saved checkpoint at epoch 15
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 17/100: 100%|███████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.10s/it, loss=0.11]


Epoch 17/100 | Loss: 0.0940 | SSIM: 0.7601
Saved checkpoint at epoch 16
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 18/100: 100%|██████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.11s/it, loss=0.108]


Saved best model (SSIM: 0.7624)
Epoch 18/100 | Loss: 0.0938 | SSIM: 0.7624
Saved checkpoint at epoch 17
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 19/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.10s/it, loss=0.0787]


Epoch 19/100 | Loss: 0.0933 | SSIM: 0.7619
Saved checkpoint at epoch 18
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 20/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.0749]


Saved best model (SSIM: 0.7629)
Epoch 20/100 | Loss: 0.0936 | SSIM: 0.7629
Saved checkpoint at epoch 19
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 21/100: 100%|██████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.10s/it, loss=0.137]


Saved best model (SSIM: 0.7630)
Epoch 21/100 | Loss: 0.0929 | SSIM: 0.7630
Saved checkpoint at epoch 20
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 22/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.0821]


Epoch 22/100 | Loss: 0.0933 | SSIM: 0.7619
Saved checkpoint at epoch 21
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 23/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.11s/it, loss=0.0888]


Epoch 23/100 | Loss: 0.0926 | SSIM: 0.7626
Saved checkpoint at epoch 22
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 24/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.0784]


Saved best model (SSIM: 0.7635)
Epoch 24/100 | Loss: 0.0918 | SSIM: 0.7635
Saved checkpoint at epoch 23
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 25/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.10s/it, loss=0.0795]


Epoch 25/100 | Loss: 0.0926 | SSIM: 0.7626
Saved checkpoint at epoch 24
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 26/100: 100%|██████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.127]


Epoch 26/100 | Loss: 0.0924 | SSIM: 0.7633
Saved checkpoint at epoch 25
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 27/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.0804]


Epoch 27/100 | Loss: 0.0926 | SSIM: 0.7622
Saved checkpoint at epoch 26
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 28/100: 100%|██████████████████████████████████████████████████████| 263/263 [04:48<00:00,  1.10s/it, loss=0.111]


Saved best model (SSIM: 0.7648)
Epoch 28/100 | Loss: 0.0930 | SSIM: 0.7648
Saved checkpoint at epoch 27
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 29/100: 100%|██████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.10s/it, loss=0.101]


Saved best model (SSIM: 0.7649)
Epoch 29/100 | Loss: 0.0924 | SSIM: 0.7649
Saved checkpoint at epoch 28
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 30/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.10s/it, loss=0.0703]


Saved best model (SSIM: 0.7650)
Epoch 30/100 | Loss: 0.0914 | SSIM: 0.7650
Saved checkpoint at epoch 29
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 31/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.0757]


Epoch 31/100 | Loss: 0.0916 | SSIM: 0.7620
Saved checkpoint at epoch 30
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 32/100: 100%|██████████████████████████████████████████████████████| 263/263 [04:48<00:00,  1.10s/it, loss=0.122]


Epoch 32/100 | Loss: 0.0911 | SSIM: 0.7645
Saved checkpoint at epoch 31
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 33/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.0876]


Saved best model (SSIM: 0.7652)
Epoch 33/100 | Loss: 0.0915 | SSIM: 0.7652
Saved checkpoint at epoch 32
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 34/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.10s/it, loss=0.0844]


Saved best model (SSIM: 0.7657)
Epoch 34/100 | Loss: 0.0923 | SSIM: 0.7657
Saved checkpoint at epoch 33
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 35/100: 100%|█████████████████████████████████████████████████████| 263/263 [05:04<00:00,  1.16s/it, loss=0.0795]


Epoch 35/100 | Loss: 0.0922 | SSIM: 0.7643
Saved checkpoint at epoch 34
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 36/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:48<00:00,  1.10s/it, loss=0.0636]


Epoch 36/100 | Loss: 0.0916 | SSIM: 0.7641
Saved checkpoint at epoch 35
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 37/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.0966]


Epoch 37/100 | Loss: 0.0903 | SSIM: 0.7647
Saved checkpoint at epoch 36
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 38/100: 100%|██████████████████████████████████████████████████████| 263/263 [04:48<00:00,  1.10s/it, loss=0.116]


Saved best model (SSIM: 0.7658)
Epoch 38/100 | Loss: 0.0906 | SSIM: 0.7658
Saved checkpoint at epoch 37
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 39/100: 100%|██████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.105]


Epoch 39/100 | Loss: 0.0914 | SSIM: 0.7656
Saved checkpoint at epoch 38
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 40/100: 100%|██████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.111]


Saved best model (SSIM: 0.7661)
Epoch 40/100 | Loss: 0.0908 | SSIM: 0.7661
Saved checkpoint at epoch 39
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 41/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.0697]


Saved best model (SSIM: 0.7662)
Epoch 41/100 | Loss: 0.0915 | SSIM: 0.7662
Saved checkpoint at epoch 40
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 42/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.0951]


Epoch 42/100 | Loss: 0.0905 | SSIM: 0.7649
Saved checkpoint at epoch 41
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 43/100: 100%|██████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.117]


Epoch 43/100 | Loss: 0.0908 | SSIM: 0.7654
Saved checkpoint at epoch 42
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 44/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.0703]


Saved best model (SSIM: 0.7665)
Epoch 44/100 | Loss: 0.0909 | SSIM: 0.7665
Saved checkpoint at epoch 43
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 45/100: 100%|██████████████████████████████████████████████████████| 263/263 [04:52<00:00,  1.11s/it, loss=0.065]


Saved best model (SSIM: 0.7679)
Epoch 45/100 | Loss: 0.0902 | SSIM: 0.7679
Saved checkpoint at epoch 44
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 46/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:53<00:00,  1.12s/it, loss=0.0765]


Epoch 46/100 | Loss: 0.0900 | SSIM: 0.7672
Saved checkpoint at epoch 45
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 47/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.0831]


Epoch 47/100 | Loss: 0.0911 | SSIM: 0.7653
Saved checkpoint at epoch 46
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 48/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.0894]


Epoch 48/100 | Loss: 0.0892 | SSIM: 0.7665
Saved checkpoint at epoch 47
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 49/100: 100%|██████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.10s/it, loss=0.123]


Epoch 49/100 | Loss: 0.0905 | SSIM: 0.7658
Saved checkpoint at epoch 48
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 50/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.0573]


Epoch 50/100 | Loss: 0.0904 | SSIM: 0.7670
Saved checkpoint at epoch 49
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 51/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.0815]


Epoch 51/100 | Loss: 0.0907 | SSIM: 0.7660
Saved checkpoint at epoch 50
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 52/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.0848]


Epoch 52/100 | Loss: 0.0887 | SSIM: 0.7675
Saved checkpoint at epoch 51
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 53/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.10s/it, loss=0.0676]


Epoch 53/100 | Loss: 0.0907 | SSIM: 0.7659
Saved checkpoint at epoch 52
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 54/100: 100%|██████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.10s/it, loss=0.102]


Saved best model (SSIM: 0.7679)
Epoch 54/100 | Loss: 0.0887 | SSIM: 0.7679
Saved checkpoint at epoch 53
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 55/100: 100%|██████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.103]


Epoch 55/100 | Loss: 0.0905 | SSIM: 0.7665
Saved checkpoint at epoch 54
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 56/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.11s/it, loss=0.0839]


Epoch 56/100 | Loss: 0.0894 | SSIM: 0.7673
Saved checkpoint at epoch 55
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 57/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:48<00:00,  1.10s/it, loss=0.0956]


Epoch 57/100 | Loss: 0.0894 | SSIM: 0.7665
Saved checkpoint at epoch 56
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 58/100: 100%|██████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.116]


Saved best model (SSIM: 0.7685)
Epoch 58/100 | Loss: 0.0902 | SSIM: 0.7685
Saved checkpoint at epoch 57
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 59/100: 100%|██████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.10s/it, loss=0.104]


Epoch 59/100 | Loss: 0.0899 | SSIM: 0.7676
Saved checkpoint at epoch 58
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 60/100: 100%|██████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.136]


Epoch 60/100 | Loss: 0.0892 | SSIM: 0.7672
Saved checkpoint at epoch 59
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 61/100: 100%|███████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.11]


Epoch 61/100 | Loss: 0.0877 | SSIM: 0.7672
Saved checkpoint at epoch 60
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 62/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.11s/it, loss=0.0879]


Epoch 62/100 | Loss: 0.0889 | SSIM: 0.7682
Saved checkpoint at epoch 61
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 63/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.0803]


Epoch 63/100 | Loss: 0.0894 | SSIM: 0.7680
Saved checkpoint at epoch 62
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 64/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.0844]


Epoch 64/100 | Loss: 0.0888 | SSIM: 0.7672
Saved checkpoint at epoch 63
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 65/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.10s/it, loss=0.0952]


Epoch 65/100 | Loss: 0.0887 | SSIM: 0.7680
Saved checkpoint at epoch 64
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 66/100: 100%|██████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.122]


Epoch 66/100 | Loss: 0.0878 | SSIM: 0.7680
Saved checkpoint at epoch 65
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 67/100: 100%|██████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.112]


Epoch 67/100 | Loss: 0.0885 | SSIM: 0.7682
Saved checkpoint at epoch 66
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 68/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.0817]


Epoch 68/100 | Loss: 0.0899 | SSIM: 0.7674
Saved checkpoint at epoch 67
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 69/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.0783]


Epoch 69/100 | Loss: 0.0888 | SSIM: 0.7677
Saved checkpoint at epoch 68
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 70/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.0814]


Epoch 70/100 | Loss: 0.0891 | SSIM: 0.7677
Saved checkpoint at epoch 69
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 71/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.0759]


Epoch 71/100 | Loss: 0.0875 | SSIM: 0.7678
Saved checkpoint at epoch 70
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 72/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:48<00:00,  1.10s/it, loss=0.0741]


Epoch 72/100 | Loss: 0.0887 | SSIM: 0.7681
Saved checkpoint at epoch 71
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 73/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.10s/it, loss=0.0919]


Epoch 73/100 | Loss: 0.0880 | SSIM: 0.7680
Saved checkpoint at epoch 72
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 74/100: 100%|██████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.111]


Epoch 74/100 | Loss: 0.0875 | SSIM: 0.7682
Saved checkpoint at epoch 73
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 75/100: 100%|██████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.11s/it, loss=0.093]


Epoch 75/100 | Loss: 0.0882 | SSIM: 0.7673
Saved checkpoint at epoch 74
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 76/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.11s/it, loss=0.0735]


Epoch 76/100 | Loss: 0.0891 | SSIM: 0.7677
Saved checkpoint at epoch 75
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 77/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.0638]


Epoch 77/100 | Loss: 0.0900 | SSIM: 0.7678
Saved checkpoint at epoch 76
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 78/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.10s/it, loss=0.0529]


Epoch 78/100 | Loss: 0.0881 | SSIM: 0.7680
Saved checkpoint at epoch 77
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 79/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:48<00:00,  1.10s/it, loss=0.0664]


Epoch 79/100 | Loss: 0.0879 | SSIM: 0.7677
Saved checkpoint at epoch 78
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 80/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.11s/it, loss=0.0885]


Epoch 80/100 | Loss: 0.0889 | SSIM: 0.7683
Saved checkpoint at epoch 79
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 81/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.10s/it, loss=0.0808]


Epoch 81/100 | Loss: 0.0889 | SSIM: 0.7679
Saved checkpoint at epoch 80
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 82/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.0963]


Epoch 82/100 | Loss: 0.0885 | SSIM: 0.7680
Saved checkpoint at epoch 81
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 83/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.11s/it, loss=0.0579]


Epoch 83/100 | Loss: 0.0878 | SSIM: 0.7676
Saved checkpoint at epoch 82
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 84/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.10s/it, loss=0.0828]


Epoch 84/100 | Loss: 0.0894 | SSIM: 0.7680
Saved checkpoint at epoch 83
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 85/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.0891]


Epoch 85/100 | Loss: 0.0887 | SSIM: 0.7680
Saved checkpoint at epoch 84
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 86/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.10s/it, loss=0.0688]


Epoch 86/100 | Loss: 0.0883 | SSIM: 0.7681
Saved checkpoint at epoch 85
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 87/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.10s/it, loss=0.0714]


Epoch 87/100 | Loss: 0.0895 | SSIM: 0.7680
Saved checkpoint at epoch 86
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 88/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:54<00:00,  1.12s/it, loss=0.0668]


Epoch 88/100 | Loss: 0.0882 | SSIM: 0.7683
Saved checkpoint at epoch 87
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 89/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.0739]


Epoch 89/100 | Loss: 0.0890 | SSIM: 0.7683
Saved checkpoint at epoch 88
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 90/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.10s/it, loss=0.0615]


Epoch 90/100 | Loss: 0.0890 | SSIM: 0.7682
Saved checkpoint at epoch 89
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 91/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.10s/it, loss=0.0949]


Epoch 91/100 | Loss: 0.0875 | SSIM: 0.7681
Saved checkpoint at epoch 90
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 92/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.0807]


Epoch 92/100 | Loss: 0.0876 | SSIM: 0.7680
Saved checkpoint at epoch 91
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 93/100: 100%|██████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.10s/it, loss=0.106]


Epoch 93/100 | Loss: 0.0875 | SSIM: 0.7679
Saved checkpoint at epoch 92
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 94/100: 100%|███████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.10s/it, loss=0.07]


Epoch 94/100 | Loss: 0.0881 | SSIM: 0.7680
Saved checkpoint at epoch 93
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 95/100: 100%|██████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.115]


Epoch 95/100 | Loss: 0.0889 | SSIM: 0.7681
Saved checkpoint at epoch 94
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 96/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.0586]


Epoch 96/100 | Loss: 0.0879 | SSIM: 0.7680
Saved checkpoint at epoch 95
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 97/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.0726]


Epoch 97/100 | Loss: 0.0875 | SSIM: 0.7681
Saved checkpoint at epoch 96
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 98/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.11s/it, loss=0.0837]


Epoch 98/100 | Loss: 0.0876 | SSIM: 0.7679
Saved checkpoint at epoch 97
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 99/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:49<00:00,  1.10s/it, loss=0.0812]


Epoch 99/100 | Loss: 0.0881 | SSIM: 0.7680
Saved checkpoint at epoch 98
GPU Memory: 0.14/8.00 GB (1.7%)


Epoch 100/100: 100%|█████████████████████████████████████████████████████| 263/263 [04:50<00:00,  1.10s/it, loss=0.066]


Epoch 100/100 | Loss: 0.0880 | SSIM: 0.7679
Saved checkpoint at epoch 99
GPU Memory: 0.14/8.00 GB (1.7%)


  student.load_state_dict(torch.load(config.STUDENT_SAVE_PATH))



Final Evaluation:
SSIM: 0.7685
FPS: 10.48 @ (1920, 1080)
Saved comparison images to debug/ folder!
