In [None]:
# fast_train.py
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
from basicsr.archs.restormer_arch import Restormer
import time

In [None]:
class TurboDataset(Dataset):
    """Paired blur/sharp image dataset that is decoded into RAM at construction.

    Image files are listed from ``blur_dir`` and ``sharp_dir``, sorted by path
    and paired by position. Every image is converted to RGB, resized to
    ``size`` x ``size`` and turned into a tensor exactly once, in ``__init__``;
    ``__getitem__`` afterwards only indexes the in-memory list.

    Args:
        blur_dir: Directory holding the blurry input images.
        sharp_dir: Directory holding the matching sharp images.
        size: Edge length (pixels) of the square each image is resized to.

    Raises:
        ValueError: If the two directories contain a different number of
            images. Pairing is positional, so a count mismatch would
            otherwise silently drop or mis-pair samples.
    """

    # File extensions (compared case-insensitively) that count as images.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, blur_dir, sharp_dir, size=128):
        self.blur_paths = self._list_images(blur_dir)
        self.sharp_paths = self._list_images(sharp_dir)

        # Validate before doing any expensive decoding: zip() below would
        # otherwise truncate to the shorter list without any warning.
        if len(self.blur_paths) != len(self.sharp_paths):
            raise ValueError(
                f"Image count mismatch: {len(self.blur_paths)} in {blur_dir!r} "
                f"vs {len(self.sharp_paths)} in {sharp_dir!r}"
            )

        # Pre-load all images into RAM
        print("⚡ Caching images...")
        self.cache = []
        self.transform = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor()
        ])

        for b_path, s_path in zip(self.blur_paths, self.sharp_paths):
            self.cache.append((
                self._load(b_path),
                self._load(s_path)
            ))

    @classmethod
    def _list_images(cls, directory):
        """Return the sorted full paths of the image files inside ``directory``."""
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.lower().endswith(cls._IMAGE_EXTENSIONS)
        )

    def _load(self, path):
        """Open ``path``, convert to RGB and apply ``self.transform``."""
        # The context manager closes the underlying file as soon as the
        # image has been converted, instead of leaving it to the GC.
        with Image.open(path) as img:
            return self.transform(img.convert('RGB'))

    def __len__(self):
        """Number of cached (blur, sharp) pairs."""
        return len(self.cache)

    def __getitem__(self, idx):
        """Return the cached ``(blur, sharp)`` pair stored at ``idx``."""
        return self.cache[idx]

In [None]:
class NanoRestormer(Restormer):
    """Restormer subclass pinned to a small, fixed configuration.

    Takes no constructor arguments; every hyper-parameter below is passed
    straight to ``Restormer.__init__``.
    """

    def __init__(self):
        nano_config = {
            "dim": 16,                      # base channel width
            "num_blocks": [1, 1, 1, 1],     # one block per level
            "num_refinement_blocks": 1,
            "heads": [1, 1, 1, 1],          # single attention head per level
            "ffn_expansion_factor": 1.5,
            "bias": False,
        }
        super().__init__(**nano_config)

In [None]:
def _save_student(student, output_path):
    """Write the student's weights to ``output_path`` under a ``"params"`` key."""
    torch.save({"params": student.state_dict()}, output_path)


def train(
    blur_dir="C:/Users/Nayana/OneDrive/Desktop/image sharpening kb/dataset/gopro/gopro_deblur/blur/images",
    sharp_dir="C:/Users/Nayana/OneDrive/Desktop/image sharpening kb/dataset/gopro/gopro_deblur/sharp/images",
    teacher_weights="pretrained_models/motion_deblurring.pth",
    output_path="student_final.pth",
    epochs=50,
    batch_size=16,
    lr=2e-4,
    image_size=128,
    max_hours=6,
):
    """Distil a pretrained Restormer teacher into a ``NanoRestormer`` on CPU.

    For every batch the teacher's output on the blurry input is computed
    without gradients, and the student is optimised with an L1 loss to
    reproduce it. The sharp images yielded by the dataset are not part of the
    loss. Training stops early, saving the student, once 95% of the time
    budget has elapsed.

    Args:
        blur_dir: Directory of blurry input images.
        sharp_dir: Directory of matching sharp images.
        teacher_weights: Checkpoint file whose ``"params"`` entry is the
            teacher's state dict.
        output_path: Where the student's weights are written.
        epochs: Maximum number of passes over the dataset.
        batch_size: Samples per optimisation step.
        lr: Adam learning rate.
        image_size: Edge length every image is resized to by the dataset.
        max_hours: Wall-clock budget in hours.

    Returns:
        None. The student weights are saved to ``output_path`` as
        ``{"params": state_dict}``.
    """
    # Hardware setup
    device = torch.device("cpu")
    torch.set_num_threads(os.cpu_count() or 4)
    print(f"🚀 Training on {device} with {torch.get_num_threads()} threads")

    # Data
    dataset = TurboDataset(blur_dir=blur_dir, sharp_dir=sharp_dir, size=image_size)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0   # no worker processes; samples come from the in-memory cache
    )

    # Teacher model
    teacher = Restormer().eval()
    # NOTE(review): torch.load unpickles the file; only point teacher_weights
    # at checkpoints from a trusted source.
    checkpoint = torch.load(teacher_weights, map_location=device)
    teacher.load_state_dict(checkpoint["params"])
    teacher = teacher.to(device)

    # Tiny student model
    student = NanoRestormer().to(device)
    optimizer = optim.Adam(student.parameters(), lr=lr)
    loss_fn = nn.L1Loss()

    # Time tracking
    start_time = time.time()
    max_duration = max_hours * 3600  # budget in seconds
    stop_after = max_duration * 0.95  # leave headroom to save before the limit

    for epoch in range(epochs):
        epoch_start = time.time()
        student.train()
        epoch_loss = 0.0

        for blurry, _sharp in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
            # Check time remaining before starting another step.
            if time.time() - start_time > stop_after:
                print(f"\n⏰ Approaching {max_hours:g} hour limit - saving model...")
                _save_student(student, output_path)
                print("✅ Model saved successfully")
                return

            # Only the blurry input is needed: the target is the teacher's output.
            blurry = blurry.to(device)

            with torch.no_grad():
                teacher_out = teacher(blurry)

            student_out = student(blurry)
            loss = loss_fn(student_out, teacher_out)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        epoch_time = time.time() - epoch_start
        remaining = max(0, max_duration - (time.time() - start_time))
        # Guard the division in case the clock reports a zero-length epoch.
        affordable = int(remaining / epoch_time) if epoch_time > 0 else epochs
        epochs_left = min(epochs - (epoch + 1), affordable)

        print(f"Epoch {epoch+1} | Loss: {epoch_loss/len(dataloader):.4f} | "
              f"Time: {epoch_time:.1f}s | Est. remaining: {epochs_left} epochs")

    _save_student(student, output_path)
    print("✅ Full training complete!")

In [None]:
if __name__ == "__main__":
    # Script entry point: cap OpenMP at one thread via the environment, then
    # run the training loop.
    # NOTE(review): this is set after torch has already been imported, and
    # train() itself calls torch.set_num_threads(os.cpu_count() or 4) --
    # confirm the variable still has the intended effect.
    os.environ.update({"OMP_NUM_THREADS": "1"})
    train()